Browse Source

Add rudimentary interface support

pull/6/head
Ayke van Laethem 7 years ago
parent
commit
3e3c3d259f
  1. 13
      src/examples/hello/hello.go
  2. 4
      src/runtime/runtime.c
  3. 124
      tgo.go

13
src/examples/hello/hello.go

@ -22,12 +22,25 @@ func main() {
thing := Thing{"foo"}
println("thing:", thing.String())
printItf(5)
printItf(byte('x'))
}
func strlen(s string) int {
return len(s)
}
func printItf(val interface{}) {
switch val := val.(type) {
case int:
println("is int:", val)
case byte:
println("is byte:", val)
default:
println("is ?")
}
}
func calculateAnswer() int {
seven := 7
return SIX * seven

4
src/runtime/runtime.c

@ -26,6 +26,10 @@ void __go_printint(intgo_t n) {
putchar((n % 10) + '0');
}
void __go_printbyte(uint8_t c) {
putchar(c);
}
void __go_printspace() {
putchar(' ');
}

124
tgo.go

@ -37,10 +37,15 @@ type Compiler struct {
intType llvm.Type
stringLenType llvm.Type
stringType llvm.Type
interfaceType llvm.Type
typeassertType llvm.Type
printstringFunc llvm.Value
printintFunc llvm.Value
printbyteFunc llvm.Value
printspaceFunc llvm.Value
printnlFunc llvm.Value
itfTypeNumbers map[types.Type]uint64
itfTypes []types.Type
}
type Frame struct {
@ -59,7 +64,8 @@ type Phi struct {
func NewCompiler(pkgName, triple string) (*Compiler, error) {
c := &Compiler{
triple: triple,
triple: triple,
itfTypeNumbers: make(map[types.Type]uint64),
}
target, err := llvm.GetTargetFromTriple(triple)
@ -79,10 +85,18 @@ func NewCompiler(pkgName, triple string) (*Compiler, error) {
// Go string: tuple of (len, ptr)
c.stringType = llvm.StructType([]llvm.Type{c.stringLenType, llvm.PointerType(llvm.Int8Type(), 0)}, false)
// Go interface: tuple of (type, ptr)
c.interfaceType = llvm.StructType([]llvm.Type{llvm.Int32Type(), llvm.PointerType(llvm.Int8Type(), 0)}, false)
// Go typeassert result: tuple of (ptr, bool)
c.typeassertType = llvm.StructType([]llvm.Type{llvm.PointerType(llvm.Int8Type(), 0), llvm.Int1Type()}, false)
printstringType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{c.stringType}, false)
c.printstringFunc = llvm.AddFunction(c.mod, "__go_printstring", printstringType)
printintType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{c.intType}, false)
c.printintFunc = llvm.AddFunction(c.mod, "__go_printint", printintType)
printbyteType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{llvm.Int8Type()}, false)
c.printbyteFunc = llvm.AddFunction(c.mod, "__go_printbyte", printbyteType)
printspaceType := llvm.FunctionType(llvm.VoidType(), nil, false)
c.printspaceFunc = llvm.AddFunction(c.mod, "__go_printspace", printspaceType)
printnlType := llvm.FunctionType(llvm.VoidType(), nil, false)
@ -154,16 +168,8 @@ func (c *Compiler) Parse(pkgName string) error {
}
frames[member] = frame
case *ssa.NamedConst:
val, err := c.parseConst(member.Value)
if err != nil {
return err
}
global := llvm.AddGlobal(c.mod, val.Type(), pkgPrefix + "." + member.Name())
global.SetInitializer(val)
global.SetGlobalConstant(true)
if ast.IsExported(member.Name()) {
global.SetLinkage(llvm.PrivateLinkage)
}
// Ignore package-level untyped constants. The SSA form doesn't
// need them.
case *ssa.Global:
typ, err := c.getLLVMType(member.Type())
if err != nil {
@ -225,6 +231,8 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) {
switch typ.Kind() {
case types.Bool:
return llvm.Int1Type(), nil
case types.Uint8:
return llvm.Int8Type(), nil
case types.Int:
return c.intType, nil
case types.Int32:
@ -236,6 +244,8 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) {
default:
return llvm.Type{}, errors.New("todo: unknown basic type: " + fmt.Sprintf("%#v", typ))
}
case *types.Interface:
return c.interfaceType, nil
case *types.Named:
return c.getLLVMType(typ.Underlying())
case *types.Pointer:
@ -259,6 +269,15 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) {
}
}
func (c *Compiler) getInterfaceType(typ types.Type) llvm.Value {
if _, ok := c.itfTypeNumbers[typ]; !ok {
num := uint64(len(c.itfTypes))
c.itfTypes = append(c.itfTypes, typ)
c.itfTypeNumbers[typ] = num
}
return llvm.ConstInt(llvm.Int32Type(), c.itfTypeNumbers[typ], false)
}
func (c *Compiler) getFunctionName(pkgPrefix string, fn *ssa.Function) string {
if fn.Signature.Recv() != nil {
// Method on a defined type.
@ -316,6 +335,12 @@ func (c *Compiler) parseFunc(frame *Frame, f *ssa.Function) error {
frame.blocks[block] = llvmBlock
}
// Load function parameters
for _, param := range f.Params {
llvmParam := llvmFn.Param(frame.params[param])
frame.locals[param] = llvmParam
}
// Fill those blocks with instructions.
for _, block := range f.DomPreorder() {
c.builder.SetInsertPointAtEnd(frame.blocks[block])
@ -412,6 +437,8 @@ func (c *Compiler) parseBuiltin(frame *Frame, instr *ssa.CallCommon, call *ssa.B
switch typ := arg.Type().(type) {
case *types.Basic:
switch typ.Kind() {
case types.Uint8:
c.builder.CreateCall(c.printbyteFunc, []llvm.Value{value}, "")
case types.Int, types.Int32: // TODO: assumes a 32-bit int type
c.builder.CreateCall(c.printintFunc, []llvm.Value{value}, "")
case types.String:
@ -506,12 +533,19 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
} else {
return c.builder.CreateAlloca(typ, expr.Comment), nil
}
case *ssa.Const:
return c.parseConst(expr)
case *ssa.BinOp:
return c.parseBinOp(frame, expr)
case *ssa.Call:
return c.parseCall(frame, expr)
case *ssa.Const:
return c.parseConst(expr)
case *ssa.Extract:
value, err := c.parseExpr(frame, expr.Tuple)
if err != nil {
return llvm.Value{}, err
}
result := c.builder.CreateExtractValue(value, expr.Index, "")
return result, nil
case *ssa.FieldAddr:
val, err := c.parseExpr(frame, expr.X)
if err != nil {
@ -524,9 +558,16 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
return c.builder.CreateGEP(val, indices, ""), nil
case *ssa.Global:
return c.mod.NamedGlobal(expr.Name()), nil
case *ssa.Parameter:
llvmFn := c.mod.NamedFunction(frame.name)
return llvmFn.Param(frame.params[expr]), nil
case *ssa.MakeInterface:
val, err := c.parseExpr(frame, expr.X)
if err != nil {
return llvm.Value{}, err
}
bitcast := c.builder.CreateIntToPtr(val, llvm.PointerType(llvm.Int8Type(), 0), "")
itfType := c.getInterfaceType(expr.X.Type())
itf := c.ctx.ConstStruct([]llvm.Value{itfType, llvm.Undef(llvm.PointerType(llvm.Int8Type(), 0))}, false)
itf = c.builder.CreateInsertValue(itf, bitcast, 1, "")
return itf, nil
case *ssa.Phi:
t, err := c.getLLVMType(expr.Type())
if err != nil {
@ -535,6 +576,27 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
phi := c.builder.CreatePHI(t, "")
frame.phis = append(frame.phis, Phi{expr, phi})
return phi, nil
case *ssa.TypeAssert:
if !expr.CommaOk {
return llvm.Value{}, errors.New("todo: type assert without comma-ok")
}
itf, err := c.parseExpr(frame, expr.X)
if err != nil {
return llvm.Value{}, err
}
assertedType, err := c.getLLVMType(expr.AssertedType)
if err != nil {
return llvm.Value{}, err
}
assertedTypeNum := c.getInterfaceType(expr.AssertedType)
actualTypeNum := c.builder.CreateExtractValue(itf, 0, "interface.type")
valuePtr := c.builder.CreateExtractValue(itf, 1, "interface.value")
value := c.builder.CreatePtrToInt(valuePtr, assertedType, "")
commaOk := c.builder.CreateICmp(llvm.IntEQ, assertedTypeNum, actualTypeNum, "")
tuple := llvm.ConstStruct([]llvm.Value{llvm.Undef(assertedType), llvm.Undef(llvm.Int1Type())}, false) // create empty tuple
tuple = c.builder.CreateInsertValue(tuple, value, 0, "") // insert value
tuple = c.builder.CreateInsertValue(tuple, commaOk, 1, "") // insert 'comma ok' boolean
return tuple, nil
case *ssa.UnOp:
return c.parseUnOp(frame, expr)
default:
@ -602,8 +664,34 @@ func (c *Compiler) parseConst(expr *ssa.Const) (llvm.Value, error) {
strObj := llvm.ConstStruct([]llvm.Value{strLen, strPtr}, false)
return strObj, nil
case constant.Int:
n, _ := constant.Int64Val(expr.Value) // TODO: do something with the 'exact' return value?
return llvm.ConstInt(c.intType, uint64(n), true), nil
switch expr.Type().(*types.Basic).Kind() {
case types.Bool:
n, _ := constant.Int64Val(expr.Value)
return llvm.ConstInt(llvm.Int1Type(), uint64(n), false), nil
case types.Int:
n, _ := constant.Int64Val(expr.Value)
return llvm.ConstInt(c.intType, uint64(n), false), nil
case types.Int8:
n, _ := constant.Int64Val(expr.Value)
return llvm.ConstInt(llvm.Int8Type(), uint64(n), false), nil
case types.Uint8:
n, _ := constant.Uint64Val(expr.Value)
return llvm.ConstInt(llvm.Int8Type(), n, false), nil
case types.Int32:
n, _ := constant.Int64Val(expr.Value)
return llvm.ConstInt(llvm.Int32Type(), uint64(n), false), nil
case types.Uint32:
n, _ := constant.Uint64Val(expr.Value)
return llvm.ConstInt(llvm.Int32Type(), n, false), nil
case types.Int64:
n, _ := constant.Int64Val(expr.Value)
return llvm.ConstInt(llvm.Int64Type(), uint64(n), false), nil
case types.Uint64:
n, _ := constant.Uint64Val(expr.Value)
return llvm.ConstInt(llvm.Int64Type(), n, false), nil
default:
return llvm.Value{}, errors.New("todo: unknown integer constant")
}
default:
return llvm.Value{}, errors.New("todo: unknown constant")
}

Loading…
Cancel
Save