diff --git a/src/examples/hello/hello.go b/src/examples/hello/hello.go index ea6d14ac..a47ff510 100644 --- a/src/examples/hello/hello.go +++ b/src/examples/hello/hello.go @@ -22,12 +22,25 @@ func main() { thing := Thing{"foo"} println("thing:", thing.String()) + printItf(5) + printItf(byte('x')) } func strlen(s string) int { return len(s) } +func printItf(val interface{}) { + switch val := val.(type) { + case int: + println("is int:", val) + case byte: + println("is byte:", val) + default: + println("is ?") + } +} + func calculateAnswer() int { seven := 7 return SIX * seven diff --git a/src/runtime/runtime.c b/src/runtime/runtime.c index 51cda99c..93ff2024 100644 --- a/src/runtime/runtime.c +++ b/src/runtime/runtime.c @@ -26,6 +26,10 @@ void __go_printint(intgo_t n) { putchar((n % 10) + '0'); } +void __go_printbyte(uint8_t c) { + putchar(c); +} + void __go_printspace() { putchar(' '); } diff --git a/tgo.go b/tgo.go index ab14d9e6..faa6215d 100644 --- a/tgo.go +++ b/tgo.go @@ -37,10 +37,15 @@ type Compiler struct { intType llvm.Type stringLenType llvm.Type stringType llvm.Type + interfaceType llvm.Type + typeassertType llvm.Type printstringFunc llvm.Value printintFunc llvm.Value + printbyteFunc llvm.Value printspaceFunc llvm.Value printnlFunc llvm.Value + itfTypeNumbers map[types.Type]uint64 + itfTypes []types.Type } type Frame struct { @@ -59,7 +64,8 @@ type Phi struct { func NewCompiler(pkgName, triple string) (*Compiler, error) { c := &Compiler{ - triple: triple, + triple: triple, + itfTypeNumbers: make(map[types.Type]uint64), } target, err := llvm.GetTargetFromTriple(triple) @@ -79,10 +85,18 @@ func NewCompiler(pkgName, triple string) (*Compiler, error) { // Go string: tuple of (len, ptr) c.stringType = llvm.StructType([]llvm.Type{c.stringLenType, llvm.PointerType(llvm.Int8Type(), 0)}, false) + // Go interface: tuple of (type, ptr) + c.interfaceType = llvm.StructType([]llvm.Type{llvm.Int32Type(), llvm.PointerType(llvm.Int8Type(), 0)}, false) + + // Go typeassert result: tuple of (ptr, bool) + c.typeassertType = llvm.StructType([]llvm.Type{llvm.PointerType(llvm.Int8Type(), 0), llvm.Int1Type()}, false) + printstringType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{c.stringType}, false) c.printstringFunc = llvm.AddFunction(c.mod, "__go_printstring", printstringType) printintType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{c.intType}, false) c.printintFunc = llvm.AddFunction(c.mod, "__go_printint", printintType) + printbyteType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{llvm.Int8Type()}, false) + c.printbyteFunc = llvm.AddFunction(c.mod, "__go_printbyte", printbyteType) printspaceType := llvm.FunctionType(llvm.VoidType(), nil, false) c.printspaceFunc = llvm.AddFunction(c.mod, "__go_printspace", printspaceType) printnlType := llvm.FunctionType(llvm.VoidType(), nil, false) @@ -154,16 +168,8 @@ func (c *Compiler) Parse(pkgName string) error { } frames[member] = frame case *ssa.NamedConst: - val, err := c.parseConst(member.Value) - if err != nil { - return err - } - global := llvm.AddGlobal(c.mod, val.Type(), pkgPrefix + "." + member.Name()) - global.SetInitializer(val) - global.SetGlobalConstant(true) - if ast.IsExported(member.Name()) { - global.SetLinkage(llvm.PrivateLinkage) - } + // Ignore package-level untyped constants. The SSA form doesn't + // need them. case *ssa.Global: typ, err := c.getLLVMType(member.Type()) if err != nil { @@ -225,6 +231,8 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) { switch typ.Kind() { case types.Bool: return llvm.Int1Type(), nil + case types.Uint8: + return llvm.Int8Type(), nil case types.Int: return c.intType, nil case types.Int32: @@ -236,6 +244,8 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) { default: return llvm.Type{}, errors.New("todo: unknown basic type: " + fmt.Sprintf("%#v", typ)) } + case *types.Interface: + return c.interfaceType, nil case *types.Named: return c.getLLVMType(typ.Underlying()) case *types.Pointer: @@ -259,6 +269,15 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) { } } +func (c *Compiler) getInterfaceType(typ types.Type) llvm.Value { + if _, ok := c.itfTypeNumbers[typ]; !ok { + num := uint64(len(c.itfTypes)) + c.itfTypes = append(c.itfTypes, typ) + c.itfTypeNumbers[typ] = num + } + return llvm.ConstInt(llvm.Int32Type(), c.itfTypeNumbers[typ], false) +} + func (c *Compiler) getFunctionName(pkgPrefix string, fn *ssa.Function) string { if fn.Signature.Recv() != nil { // Method on a defined type. @@ -316,6 +335,12 @@ func (c *Compiler) parseFunc(frame *Frame, f *ssa.Function) error { frame.blocks[block] = llvmBlock } + // Load function parameters + for _, param := range f.Params { + llvmParam := llvmFn.Param(frame.params[param]) + frame.locals[param] = llvmParam + } + // Fill those blocks with instructions. for _, block := range f.DomPreorder() { c.builder.SetInsertPointAtEnd(frame.blocks[block]) @@ -412,6 +437,8 @@ func (c *Compiler) parseBuiltin(frame *Frame, instr *ssa.CallCommon, call *ssa.B switch typ := arg.Type().(type) { case *types.Basic: switch typ.Kind() { + case types.Uint8: + c.builder.CreateCall(c.printbyteFunc, []llvm.Value{value}, "") case types.Int, types.Int32: // TODO: assumes a 32-bit int type c.builder.CreateCall(c.printintFunc, []llvm.Value{value}, "") case types.String: @@ -506,12 +533,19 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { } else { return c.builder.CreateAlloca(typ, expr.Comment), nil } - case *ssa.Const: - return c.parseConst(expr) case *ssa.BinOp: return c.parseBinOp(frame, expr) case *ssa.Call: return c.parseCall(frame, expr) + case *ssa.Const: + return c.parseConst(expr) + case *ssa.Extract: + value, err := c.parseExpr(frame, expr.Tuple) + if err != nil { + return llvm.Value{}, err + } + result := c.builder.CreateExtractValue(value, expr.Index, "") + return result, nil case *ssa.FieldAddr: val, err := c.parseExpr(frame, expr.X) if err != nil { @@ -524,9 +558,16 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { return c.builder.CreateGEP(val, indices, ""), nil case *ssa.Global: return c.mod.NamedGlobal(expr.Name()), nil - case *ssa.Parameter: - llvmFn := c.mod.NamedFunction(frame.name) - return llvmFn.Param(frame.params[expr]), nil + case *ssa.MakeInterface: + val, err := c.parseExpr(frame, expr.X) + if err != nil { + return llvm.Value{}, err + } + bitcast := c.builder.CreateIntToPtr(val, llvm.PointerType(llvm.Int8Type(), 0), "") + itfType := c.getInterfaceType(expr.X.Type()) + itf := c.ctx.ConstStruct([]llvm.Value{itfType, llvm.Undef(llvm.PointerType(llvm.Int8Type(), 0))}, false) + itf = c.builder.CreateInsertValue(itf, bitcast, 1, "") + return itf, nil case *ssa.Phi: t, err := c.getLLVMType(expr.Type()) if err != nil { @@ -535,6 +576,27 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { phi := c.builder.CreatePHI(t, "") frame.phis = append(frame.phis, Phi{expr, phi}) return phi, nil + case *ssa.TypeAssert: + if !expr.CommaOk { + return llvm.Value{}, errors.New("todo: type assert without comma-ok") + } + itf, err := c.parseExpr(frame, expr.X) + if err != nil { + return llvm.Value{}, err + } + assertedType, err := c.getLLVMType(expr.AssertedType) + if err != nil { + return llvm.Value{}, err + } + assertedTypeNum := c.getInterfaceType(expr.AssertedType) + actualTypeNum := c.builder.CreateExtractValue(itf, 0, "interface.type") + valuePtr := c.builder.CreateExtractValue(itf, 1, "interface.value") + value := c.builder.CreatePtrToInt(valuePtr, assertedType, "") + commaOk := c.builder.CreateICmp(llvm.IntEQ, assertedTypeNum, actualTypeNum, "") + tuple := llvm.ConstStruct([]llvm.Value{llvm.Undef(assertedType), llvm.Undef(llvm.Int1Type())}, false) // create empty tuple + tuple = c.builder.CreateInsertValue(tuple, value, 0, "") // insert value + tuple = c.builder.CreateInsertValue(tuple, commaOk, 1, "") // insert 'comma ok' boolean + return tuple, nil case *ssa.UnOp: return c.parseUnOp(frame, expr) default: @@ -602,8 +664,34 @@ func (c *Compiler) parseConst(expr *ssa.Const) (llvm.Value, error) { strObj := llvm.ConstStruct([]llvm.Value{strLen, strPtr}, false) return strObj, nil case constant.Int: - n, _ := constant.Int64Val(expr.Value) // TODO: do something with the 'exact' return value? - return llvm.ConstInt(c.intType, uint64(n), true), nil + switch expr.Type().(*types.Basic).Kind() { + case types.Bool: + n, _ := constant.Int64Val(expr.Value) + return llvm.ConstInt(llvm.Int1Type(), uint64(n), false), nil + case types.Int: + n, _ := constant.Int64Val(expr.Value) + return llvm.ConstInt(c.intType, uint64(n), false), nil + case types.Int8: + n, _ := constant.Int64Val(expr.Value) + return llvm.ConstInt(llvm.Int8Type(), uint64(n), false), nil + case types.Uint8: + n, _ := constant.Uint64Val(expr.Value) + return llvm.ConstInt(llvm.Int8Type(), n, false), nil + case types.Int32: + n, _ := constant.Int64Val(expr.Value) + return llvm.ConstInt(llvm.Int32Type(), uint64(n), false), nil + case types.Uint32: + n, _ := constant.Uint64Val(expr.Value) + return llvm.ConstInt(llvm.Int32Type(), n, false), nil + case types.Int64: + n, _ := constant.Int64Val(expr.Value) + return llvm.ConstInt(llvm.Int64Type(), uint64(n), false), nil + case types.Uint64: + n, _ := constant.Uint64Val(expr.Value) + return llvm.ConstInt(llvm.Int64Type(), n, false), nil + default: + return llvm.Value{}, errors.New("todo: unknown integer constant") + } default: return llvm.Value{}, errors.New("todo: unknown constant") }