Browse Source

Add rudimentary interface support

pull/6/head
Ayke van Laethem 7 years ago
parent
commit
3e3c3d259f
  1. 13
      src/examples/hello/hello.go
  2. 4
      src/runtime/runtime.c
  3. 122
      tgo.go

13
src/examples/hello/hello.go

@ -22,12 +22,25 @@ func main() {
thing := Thing{"foo"} thing := Thing{"foo"}
println("thing:", thing.String()) println("thing:", thing.String())
printItf(5)
printItf(byte('x'))
} }
func strlen(s string) int { func strlen(s string) int {
return len(s) return len(s)
} }
func printItf(val interface{}) {
switch val := val.(type) {
case int:
println("is int:", val)
case byte:
println("is byte:", val)
default:
println("is ?")
}
}
func calculateAnswer() int { func calculateAnswer() int {
seven := 7 seven := 7
return SIX * seven return SIX * seven

4
src/runtime/runtime.c

@ -26,6 +26,10 @@ void __go_printint(intgo_t n) {
putchar((n % 10) + '0'); putchar((n % 10) + '0');
} }
void __go_printbyte(uint8_t c) {
putchar(c);
}
void __go_printspace() { void __go_printspace() {
putchar(' '); putchar(' ');
} }

122
tgo.go

@ -37,10 +37,15 @@ type Compiler struct {
intType llvm.Type intType llvm.Type
stringLenType llvm.Type stringLenType llvm.Type
stringType llvm.Type stringType llvm.Type
interfaceType llvm.Type
typeassertType llvm.Type
printstringFunc llvm.Value printstringFunc llvm.Value
printintFunc llvm.Value printintFunc llvm.Value
printbyteFunc llvm.Value
printspaceFunc llvm.Value printspaceFunc llvm.Value
printnlFunc llvm.Value printnlFunc llvm.Value
itfTypeNumbers map[types.Type]uint64
itfTypes []types.Type
} }
type Frame struct { type Frame struct {
@ -60,6 +65,7 @@ type Phi struct {
func NewCompiler(pkgName, triple string) (*Compiler, error) { func NewCompiler(pkgName, triple string) (*Compiler, error) {
c := &Compiler{ c := &Compiler{
triple: triple, triple: triple,
itfTypeNumbers: make(map[types.Type]uint64),
} }
target, err := llvm.GetTargetFromTriple(triple) target, err := llvm.GetTargetFromTriple(triple)
@ -79,10 +85,18 @@ func NewCompiler(pkgName, triple string) (*Compiler, error) {
// Go string: tuple of (len, ptr) // Go string: tuple of (len, ptr)
c.stringType = llvm.StructType([]llvm.Type{c.stringLenType, llvm.PointerType(llvm.Int8Type(), 0)}, false) c.stringType = llvm.StructType([]llvm.Type{c.stringLenType, llvm.PointerType(llvm.Int8Type(), 0)}, false)
// Go interface: tuple of (type, ptr)
c.interfaceType = llvm.StructType([]llvm.Type{llvm.Int32Type(), llvm.PointerType(llvm.Int8Type(), 0)}, false)
// Go typeassert result: tuple of (ptr, bool)
c.typeassertType = llvm.StructType([]llvm.Type{llvm.PointerType(llvm.Int8Type(), 0), llvm.Int1Type()}, false)
printstringType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{c.stringType}, false) printstringType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{c.stringType}, false)
c.printstringFunc = llvm.AddFunction(c.mod, "__go_printstring", printstringType) c.printstringFunc = llvm.AddFunction(c.mod, "__go_printstring", printstringType)
printintType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{c.intType}, false) printintType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{c.intType}, false)
c.printintFunc = llvm.AddFunction(c.mod, "__go_printint", printintType) c.printintFunc = llvm.AddFunction(c.mod, "__go_printint", printintType)
printbyteType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{llvm.Int8Type()}, false)
c.printbyteFunc = llvm.AddFunction(c.mod, "__go_printbyte", printbyteType)
printspaceType := llvm.FunctionType(llvm.VoidType(), nil, false) printspaceType := llvm.FunctionType(llvm.VoidType(), nil, false)
c.printspaceFunc = llvm.AddFunction(c.mod, "__go_printspace", printspaceType) c.printspaceFunc = llvm.AddFunction(c.mod, "__go_printspace", printspaceType)
printnlType := llvm.FunctionType(llvm.VoidType(), nil, false) printnlType := llvm.FunctionType(llvm.VoidType(), nil, false)
@ -154,16 +168,8 @@ func (c *Compiler) Parse(pkgName string) error {
} }
frames[member] = frame frames[member] = frame
case *ssa.NamedConst: case *ssa.NamedConst:
val, err := c.parseConst(member.Value) // Ignore package-level untyped constants. The SSA form doesn't
if err != nil { // need them.
return err
}
global := llvm.AddGlobal(c.mod, val.Type(), pkgPrefix + "." + member.Name())
global.SetInitializer(val)
global.SetGlobalConstant(true)
if ast.IsExported(member.Name()) {
global.SetLinkage(llvm.PrivateLinkage)
}
case *ssa.Global: case *ssa.Global:
typ, err := c.getLLVMType(member.Type()) typ, err := c.getLLVMType(member.Type())
if err != nil { if err != nil {
@ -225,6 +231,8 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) {
switch typ.Kind() { switch typ.Kind() {
case types.Bool: case types.Bool:
return llvm.Int1Type(), nil return llvm.Int1Type(), nil
case types.Uint8:
return llvm.Int8Type(), nil
case types.Int: case types.Int:
return c.intType, nil return c.intType, nil
case types.Int32: case types.Int32:
@ -236,6 +244,8 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) {
default: default:
return llvm.Type{}, errors.New("todo: unknown basic type: " + fmt.Sprintf("%#v", typ)) return llvm.Type{}, errors.New("todo: unknown basic type: " + fmt.Sprintf("%#v", typ))
} }
case *types.Interface:
return c.interfaceType, nil
case *types.Named: case *types.Named:
return c.getLLVMType(typ.Underlying()) return c.getLLVMType(typ.Underlying())
case *types.Pointer: case *types.Pointer:
@ -259,6 +269,15 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) {
} }
} }
func (c *Compiler) getInterfaceType(typ types.Type) llvm.Value {
if _, ok := c.itfTypeNumbers[typ]; !ok {
num := uint64(len(c.itfTypes))
c.itfTypes = append(c.itfTypes, typ)
c.itfTypeNumbers[typ] = num
}
return llvm.ConstInt(llvm.Int32Type(), c.itfTypeNumbers[typ], false)
}
func (c *Compiler) getFunctionName(pkgPrefix string, fn *ssa.Function) string { func (c *Compiler) getFunctionName(pkgPrefix string, fn *ssa.Function) string {
if fn.Signature.Recv() != nil { if fn.Signature.Recv() != nil {
// Method on a defined type. // Method on a defined type.
@ -316,6 +335,12 @@ func (c *Compiler) parseFunc(frame *Frame, f *ssa.Function) error {
frame.blocks[block] = llvmBlock frame.blocks[block] = llvmBlock
} }
// Load function parameters
for _, param := range f.Params {
llvmParam := llvmFn.Param(frame.params[param])
frame.locals[param] = llvmParam
}
// Fill those blocks with instructions. // Fill those blocks with instructions.
for _, block := range f.DomPreorder() { for _, block := range f.DomPreorder() {
c.builder.SetInsertPointAtEnd(frame.blocks[block]) c.builder.SetInsertPointAtEnd(frame.blocks[block])
@ -412,6 +437,8 @@ func (c *Compiler) parseBuiltin(frame *Frame, instr *ssa.CallCommon, call *ssa.B
switch typ := arg.Type().(type) { switch typ := arg.Type().(type) {
case *types.Basic: case *types.Basic:
switch typ.Kind() { switch typ.Kind() {
case types.Uint8:
c.builder.CreateCall(c.printbyteFunc, []llvm.Value{value}, "")
case types.Int, types.Int32: // TODO: assumes a 32-bit int type case types.Int, types.Int32: // TODO: assumes a 32-bit int type
c.builder.CreateCall(c.printintFunc, []llvm.Value{value}, "") c.builder.CreateCall(c.printintFunc, []llvm.Value{value}, "")
case types.String: case types.String:
@ -506,12 +533,19 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
} else { } else {
return c.builder.CreateAlloca(typ, expr.Comment), nil return c.builder.CreateAlloca(typ, expr.Comment), nil
} }
case *ssa.Const:
return c.parseConst(expr)
case *ssa.BinOp: case *ssa.BinOp:
return c.parseBinOp(frame, expr) return c.parseBinOp(frame, expr)
case *ssa.Call: case *ssa.Call:
return c.parseCall(frame, expr) return c.parseCall(frame, expr)
case *ssa.Const:
return c.parseConst(expr)
case *ssa.Extract:
value, err := c.parseExpr(frame, expr.Tuple)
if err != nil {
return llvm.Value{}, err
}
result := c.builder.CreateExtractValue(value, expr.Index, "")
return result, nil
case *ssa.FieldAddr: case *ssa.FieldAddr:
val, err := c.parseExpr(frame, expr.X) val, err := c.parseExpr(frame, expr.X)
if err != nil { if err != nil {
@ -524,9 +558,16 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
return c.builder.CreateGEP(val, indices, ""), nil return c.builder.CreateGEP(val, indices, ""), nil
case *ssa.Global: case *ssa.Global:
return c.mod.NamedGlobal(expr.Name()), nil return c.mod.NamedGlobal(expr.Name()), nil
case *ssa.Parameter: case *ssa.MakeInterface:
llvmFn := c.mod.NamedFunction(frame.name) val, err := c.parseExpr(frame, expr.X)
return llvmFn.Param(frame.params[expr]), nil if err != nil {
return llvm.Value{}, err
}
bitcast := c.builder.CreateIntToPtr(val, llvm.PointerType(llvm.Int8Type(), 0), "")
itfType := c.getInterfaceType(expr.X.Type())
itf := c.ctx.ConstStruct([]llvm.Value{itfType, llvm.Undef(llvm.PointerType(llvm.Int8Type(), 0))}, false)
itf = c.builder.CreateInsertValue(itf, bitcast, 1, "")
return itf, nil
case *ssa.Phi: case *ssa.Phi:
t, err := c.getLLVMType(expr.Type()) t, err := c.getLLVMType(expr.Type())
if err != nil { if err != nil {
@ -535,6 +576,27 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
phi := c.builder.CreatePHI(t, "") phi := c.builder.CreatePHI(t, "")
frame.phis = append(frame.phis, Phi{expr, phi}) frame.phis = append(frame.phis, Phi{expr, phi})
return phi, nil return phi, nil
case *ssa.TypeAssert:
if !expr.CommaOk {
return llvm.Value{}, errors.New("todo: type assert without comma-ok")
}
itf, err := c.parseExpr(frame, expr.X)
if err != nil {
return llvm.Value{}, err
}
assertedType, err := c.getLLVMType(expr.AssertedType)
if err != nil {
return llvm.Value{}, err
}
assertedTypeNum := c.getInterfaceType(expr.AssertedType)
actualTypeNum := c.builder.CreateExtractValue(itf, 0, "interface.type")
valuePtr := c.builder.CreateExtractValue(itf, 1, "interface.value")
value := c.builder.CreatePtrToInt(valuePtr, assertedType, "")
commaOk := c.builder.CreateICmp(llvm.IntEQ, assertedTypeNum, actualTypeNum, "")
tuple := llvm.ConstStruct([]llvm.Value{llvm.Undef(assertedType), llvm.Undef(llvm.Int1Type())}, false) // create empty tuple
tuple = c.builder.CreateInsertValue(tuple, value, 0, "") // insert value
tuple = c.builder.CreateInsertValue(tuple, commaOk, 1, "") // insert 'comma ok' boolean
return tuple, nil
case *ssa.UnOp: case *ssa.UnOp:
return c.parseUnOp(frame, expr) return c.parseUnOp(frame, expr)
default: default:
@ -602,8 +664,34 @@ func (c *Compiler) parseConst(expr *ssa.Const) (llvm.Value, error) {
strObj := llvm.ConstStruct([]llvm.Value{strLen, strPtr}, false) strObj := llvm.ConstStruct([]llvm.Value{strLen, strPtr}, false)
return strObj, nil return strObj, nil
case constant.Int: case constant.Int:
n, _ := constant.Int64Val(expr.Value) // TODO: do something with the 'exact' return value? switch expr.Type().(*types.Basic).Kind() {
return llvm.ConstInt(c.intType, uint64(n), true), nil case types.Bool:
n, _ := constant.Int64Val(expr.Value)
return llvm.ConstInt(llvm.Int1Type(), uint64(n), false), nil
case types.Int:
n, _ := constant.Int64Val(expr.Value)
return llvm.ConstInt(c.intType, uint64(n), false), nil
case types.Int8:
n, _ := constant.Int64Val(expr.Value)
return llvm.ConstInt(llvm.Int8Type(), uint64(n), false), nil
case types.Uint8:
n, _ := constant.Uint64Val(expr.Value)
return llvm.ConstInt(llvm.Int8Type(), n, false), nil
case types.Int32:
n, _ := constant.Int64Val(expr.Value)
return llvm.ConstInt(llvm.Int32Type(), uint64(n), false), nil
case types.Uint32:
n, _ := constant.Uint64Val(expr.Value)
return llvm.ConstInt(llvm.Int32Type(), n, false), nil
case types.Int64:
n, _ := constant.Int64Val(expr.Value)
return llvm.ConstInt(llvm.Int64Type(), uint64(n), false), nil
case types.Uint64:
n, _ := constant.Uint64Val(expr.Value)
return llvm.ConstInt(llvm.Int64Type(), n, false), nil
default:
return llvm.Value{}, errors.New("todo: unknown integer constant")
}
default: default:
return llvm.Value{}, errors.New("todo: unknown constant") return llvm.Value{}, errors.New("todo: unknown constant")
} }

Loading…
Cancel
Save