package main import ( "errors" "flag" "fmt" "go/build" "go/constant" "go/token" "go/types" "os" "sort" "strings" "github.com/aykevl/llvm/bindings/go/llvm" "golang.org/x/tools/go/loader" "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa/ssautil" ) func init() { llvm.InitializeAllTargets() llvm.InitializeAllTargetMCs() llvm.InitializeAllTargetInfos() llvm.InitializeAllAsmParsers() llvm.InitializeAllAsmPrinters() } type Compiler struct { triple string mod llvm.Module ctx llvm.Context builder llvm.Builder machine llvm.TargetMachine targetData llvm.TargetData intType llvm.Type i8ptrType llvm.Type // for convenience uintptrType llvm.Type stringLenType llvm.Type stringType llvm.Type interfaceType llvm.Type typeassertType llvm.Type allocFunc llvm.Value freeFunc llvm.Value itfTypeNumbers map[types.Type]uint64 itfTypes []types.Type initFuncs []llvm.Value } type Frame struct { llvmFn llvm.Value params map[*ssa.Parameter]int // arguments to the function locals map[ssa.Value]llvm.Value // local variables blocks map[*ssa.BasicBlock]llvm.BasicBlock phis []Phi } func pkgPrefix(pkg *ssa.Package) string { if pkg.Pkg.Name() == "main" { return "main" } return pkg.Pkg.Path() } type Phi struct { ssa *ssa.Phi llvm llvm.Value } func NewCompiler(pkgName, triple string) (*Compiler, error) { c := &Compiler{ triple: triple, itfTypeNumbers: make(map[types.Type]uint64), } target, err := llvm.GetTargetFromTriple(triple) if err != nil { return nil, err } c.machine = target.CreateTargetMachine(triple, "", "", llvm.CodeGenLevelDefault, llvm.RelocDefault, llvm.CodeModelDefault) c.targetData = c.machine.CreateTargetData() c.mod = llvm.NewModule(pkgName) c.ctx = c.mod.Context() c.builder = c.ctx.NewBuilder() // Depends on platform (32bit or 64bit), but fix it here for now. c.intType = llvm.Int32Type() c.stringLenType = llvm.Int32Type() c.uintptrType = c.targetData.IntPtrType() c.i8ptrType = llvm.PointerType(llvm.Int8Type(), 0) // Go string: tuple of (len, ptr) c.stringType = llvm.StructType([]llvm.Type{c.stringLenType, c.i8ptrType}, false) // Go interface: tuple of (type, ptr) c.interfaceType = llvm.StructType([]llvm.Type{llvm.Int32Type(), c.i8ptrType}, false) // Go typeassert result: tuple of (ptr, bool) c.typeassertType = llvm.StructType([]llvm.Type{c.i8ptrType, llvm.Int1Type()}, false) allocType := llvm.FunctionType(c.i8ptrType, []llvm.Type{c.uintptrType}, false) c.allocFunc = llvm.AddFunction(c.mod, "runtime.alloc", allocType) freeType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{c.i8ptrType}, false) c.freeFunc = llvm.AddFunction(c.mod, "runtime.free", freeType) return c, nil } func (c *Compiler) Parse(mainPath string, buildTags []string) error { tripleSplit := strings.Split(c.triple, "-") config := loader.Config { TypeChecker: types.Config{ Sizes: &types.StdSizes{ int64(c.targetData.PointerSize()), int64(c.targetData.PrefTypeAlignment(c.i8ptrType)), }, }, Build: &build.Context { GOARCH: tripleSplit[0], GOOS: tripleSplit[2], GOROOT: ".", CgoEnabled: true, UseAllFiles: false, Compiler: "gc", // must be one of the recognized compilers BuildTags: append([]string{"tgo"}, buildTags...), }, AllowErrors: true, } config.Import("runtime") config.Import(mainPath) lprogram, err := config.Load() if err != nil { return err } // TODO: pick the error of the first package, not a random package for _, pkgInfo := range lprogram.AllPackages { if len(pkgInfo.Errors) != 0 { return pkgInfo.Errors[0] } } program := ssautil.CreateProgram(lprogram, ssa.SanityCheckFunctions | ssa.BareInits) program.Build() // Make a list of packages in import order. packageList := []*ssa.Package{} packageSet := map[string]struct{}{} worklist := []string{"runtime", mainPath} for len(worklist) != 0 { pkgPath := worklist[0] pkg := program.ImportedPackage(pkgPath) if pkg == nil { // Non-SSA package (e.g. cgo). packageSet[pkgPath] = struct{}{} worklist = worklist[1:] continue } if _, ok := packageSet[pkgPath]; ok { // Package already in the final package list. worklist = worklist[1:] continue } unsatisfiedImports := make([]string, 0) imports := pkg.Pkg.Imports() for _, pkg := range imports { if _, ok := packageSet[pkg.Path()]; ok { continue } unsatisfiedImports = append(unsatisfiedImports, pkg.Path()) } if len(unsatisfiedImports) == 0 { // All dependencies of this package are satisfied, so add this // package to the list. packageList = append(packageList, pkg) packageSet[pkgPath] = struct{}{} worklist = worklist[1:] } else { // Prepend all dependencies to the worklist and reconsider this // package (by not removing it from the worklist). At that point, it // must be possible to add it to packageList. worklist = append(unsatisfiedImports, worklist...) } } // Transform each package into LLVM IR. for _, pkg := range packageList { err := c.parsePackage(program, pkg) if err != nil { return err } } // After all packages are imported, add a synthetic initializer function // that calls the initializer of each package. initFn := c.mod.NamedFunction("runtime.initAll") if initFn.IsNil() { initType := llvm.FunctionType(llvm.VoidType(), nil, false) initFn = llvm.AddFunction(c.mod, "runtime.initAll", initType) } initFn.SetLinkage(llvm.PrivateLinkage) block := c.ctx.AddBasicBlock(initFn, "entry") c.builder.SetInsertPointAtEnd(block) for _, fn := range c.initFuncs { c.builder.CreateCall(fn, nil, "") } c.builder.CreateRetVoid() // Set functions referenced in runtime.ll to internal linkage, to improve // optimization (hopefully). main := c.mod.NamedFunction("main.main") main.SetLinkage(llvm.PrivateLinkage) return nil } func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) { switch typ := goType.(type) { case *types.Array: elemType, err := c.getLLVMType(typ.Elem()) if err != nil { return llvm.Type{}, err } return llvm.ArrayType(elemType, int(typ.Len())), nil case *types.Basic: switch typ.Kind() { case types.Bool: return llvm.Int1Type(), nil case types.Int8, types.Uint8: return llvm.Int8Type(), nil case types.Int16, types.Uint16: return llvm.Int16Type(), nil case types.Int32, types.Uint32: return llvm.Int32Type(), nil case types.Int, types.Uint: return c.intType, nil case types.Int64, types.Uint64: return llvm.Int64Type(), nil case types.String: return c.stringType, nil case types.Uintptr: return c.uintptrType, nil case types.UnsafePointer: return c.i8ptrType, nil default: return llvm.Type{}, errors.New("todo: unknown basic type: " + fmt.Sprintf("%#v", typ)) } case *types.Interface: return c.interfaceType, nil case *types.Named: return c.getLLVMType(typ.Underlying()) case *types.Pointer: ptrTo, err := c.getLLVMType(typ.Elem()) if err != nil { return llvm.Type{}, err } return llvm.PointerType(ptrTo, 0), nil case *types.Struct: members := make([]llvm.Type, typ.NumFields()) for i := 0; i < typ.NumFields(); i++ { member, err := c.getLLVMType(typ.Field(i).Type()) if err != nil { return llvm.Type{}, err } members[i] = member } return llvm.StructType(members, false), nil default: return llvm.Type{}, errors.New("todo: unknown type: " + fmt.Sprintf("%#v", goType)) } } func (c *Compiler) getZeroValue(typ llvm.Type) (llvm.Value, error) { switch typ.TypeKind() { case llvm.ArrayTypeKind: subTyp := typ.ElementType() vals := make([]llvm.Value, typ.ArrayLength()) for i := range vals { val, err := c.getZeroValue(subTyp) if err != nil { return llvm.Value{}, err } vals[i] = val } return llvm.ConstArray(subTyp, vals), nil case llvm.IntegerTypeKind: return llvm.ConstInt(typ, 0, false), nil case llvm.PointerTypeKind: return llvm.ConstPointerNull(typ), nil case llvm.StructTypeKind: types := typ.StructElementTypes() vals := make([]llvm.Value, len(types)) for i, subTyp := range types { val, err := c.getZeroValue(subTyp) if err != nil { return llvm.Value{}, err } vals[i] = val } return llvm.ConstStruct(vals, false), nil default: return llvm.Value{}, errors.New("todo: LLVM zero initializer") } } func (c *Compiler) getInterfaceType(typ types.Type) llvm.Value { if _, ok := c.itfTypeNumbers[typ]; !ok { num := uint64(len(c.itfTypes)) c.itfTypes = append(c.itfTypes, typ) c.itfTypeNumbers[typ] = num } return llvm.ConstInt(llvm.Int32Type(), c.itfTypeNumbers[typ], false) } func (c *Compiler) isPointer(typ types.Type) bool { if _, ok := typ.(*types.Pointer); ok { return true } else if typ, ok := typ.(*types.Basic); ok && typ.Kind() == types.UnsafePointer { return true } else { return false } } func getFunctionName(fn *ssa.Function) string { if fn.Signature.Recv() != nil { // Method on a defined type. typeName := fn.Params[0].Type().(*types.Named).Obj().Name() return pkgPrefix(fn.Pkg) + "." + typeName + "." + fn.Name() } else { // Bare function. if strings.HasPrefix(fn.Name(), "_Cfunc_") { // Name CGo functions directly. return fn.Name()[len("_Cfunc_"):] } else { name := pkgPrefix(fn.Pkg) + "." + fn.Name() if fn.Pkg.Pkg.Path() == "runtime" && strings.HasPrefix(fn.Name(), "_llvm_") { // Special case for LLVM intrinsics in the runtime. name = "llvm." + strings.Replace(fn.Name()[len("_llvm_"):], "_", ".", -1) } return name } } } func getGlobalName(global *ssa.Global) string { if strings.HasPrefix(global.Name(), "_extern_") { return global.Name()[len("_extern_"):] } else { return pkgPrefix(global.Pkg) + "." + global.Name() } } func (c *Compiler) parsePackage(program *ssa.Program, pkg *ssa.Package) error { // Make sure we're walking through all members in a constant order every // run. memberNames := make([]string, 0) for name := range pkg.Members { if strings.HasPrefix(name, "_Cgo_") || strings.HasPrefix(name, "_cgo") { // _Cgo_ptr, _Cgo_use, _cgoCheckResult, _cgo_runtime_cgocall continue // CGo-internal functions } if strings.HasPrefix(name, "__cgofn__cgo_") { continue // CGo function pointer in global scope } memberNames = append(memberNames, name) } sort.Strings(memberNames) frames := make(map[*ssa.Function]*Frame) // First, build all function declarations. for _, name := range memberNames { member := pkg.Members[name] switch member := member.(type) { case *ssa.Function: frame, err := c.parseFuncDecl(member) if err != nil { return err } frames[member] = frame if member.Synthetic == "package initializer" { c.initFuncs = append(c.initFuncs, frame.llvmFn) } // TODO: recursively anonymous functions for _, child := range member.AnonFuncs { frame, err := c.parseFuncDecl(child) if err != nil { return err } frames[child] = frame } case *ssa.NamedConst: // Ignore package-level untyped constants. The SSA form doesn't need // them. case *ssa.Global: typ := member.Type() if typPtr, ok := typ.(*types.Pointer); ok { typ = typPtr.Elem() } else { return errors.New("global is not a pointer") } llvmType, err := c.getLLVMType(typ) if err != nil { return err } global := llvm.AddGlobal(c.mod, llvmType, getGlobalName(member)) if !strings.HasPrefix(member.Name(), "_extern_") { global.SetLinkage(llvm.PrivateLinkage) if getGlobalName(member) == "runtime.TargetBits" { bitness := c.targetData.PointerSize() * 8 if bitness < 32 { // Only 8 and 32+ architectures supported at the moment. // On 8 bit architectures, pointers are normally bigger // than 8 bits to do anything meaningful. // TODO: clean up this hack to support 16-bit // architectures. bitness = 8 } global.SetInitializer(llvm.ConstInt(llvm.Int8Type(), uint64(bitness), false)) global.SetGlobalConstant(true) } else { initializer, err := c.getZeroValue(llvmType) if err != nil { return err } global.SetInitializer(initializer) } } case *ssa.Type: ms := program.MethodSets.MethodSet(member.Type()) for i := 0; i < ms.Len(); i++ { fn := program.MethodValue(ms.At(i)) frame, err := c.parseFuncDecl(fn) if err != nil { return err } frames[fn] = frame } default: return errors.New("todo: member: " + fmt.Sprintf("%#v", member)) } } // Now, add definitions to those declarations. for _, name := range memberNames { member := pkg.Members[name] switch member := member.(type) { case *ssa.Function: if strings.HasPrefix(name, "_Cfunc_") { // CGo function. Don't implement it's body. continue } if member.Blocks == nil { continue // external function } var err error if member.Synthetic == "package initializer" { err = c.parseInitFunc(frames[member], member) } else { err = c.parseFunc(frames[member], member) } if err != nil { return err } case *ssa.Type: ms := program.MethodSets.MethodSet(member.Type()) for i := 0; i < ms.Len(); i++ { fn := program.MethodValue(ms.At(i)) err := c.parseFunc(frames[fn], fn) if err != nil { return err } } } } return nil } func (c *Compiler) parseFuncDecl(f *ssa.Function) (*Frame, error) { f.WriteTo(os.Stdout) frame := &Frame{ params: make(map[*ssa.Parameter]int), locals: make(map[ssa.Value]llvm.Value), blocks: make(map[*ssa.BasicBlock]llvm.BasicBlock), } var retType llvm.Type if f.Signature.Results() == nil { retType = llvm.VoidType() } else if f.Signature.Results().Len() == 1 { var err error retType, err = c.getLLVMType(f.Signature.Results().At(0).Type()) if err != nil { return nil, err } } else { return nil, errors.New("todo: return values") } var paramTypes []llvm.Type for i, param := range f.Params { paramType, err := c.getLLVMType(param.Type()) if err != nil { return nil, err } paramTypes = append(paramTypes, paramType) frame.params[param] = i } fnType := llvm.FunctionType(retType, paramTypes, false) name := getFunctionName(f) frame.llvmFn = c.mod.NamedFunction(name) if frame.llvmFn.IsNil() { frame.llvmFn = llvm.AddFunction(c.mod, name, fnType) } return frame, nil } // Special function parser for generated package initializers (which also // initializes global variables). func (c *Compiler) parseInitFunc(frame *Frame, f *ssa.Function) error { frame.llvmFn.SetLinkage(llvm.PrivateLinkage) llvmBlock := c.ctx.AddBasicBlock(frame.llvmFn, "entry") c.builder.SetInsertPointAtEnd(llvmBlock) for _, block := range f.DomPreorder() { for _, instr := range block.Instrs { var err error switch instr := instr.(type) { case *ssa.Call, *ssa.Return: err = c.parseInstr(frame, instr) case *ssa.Convert: // Ignore: CGo pointer conversion. case *ssa.FieldAddr, *ssa.IndexAddr: // Ignore: handled below with *ssa.Store. case *ssa.Store: switch addr := instr.Addr.(type) { case *ssa.Global: // Regular store, like a global int variable. if strings.HasPrefix(addr.Name(), "__cgofn__cgo_") || strings.HasPrefix(addr.Name(), "_cgo_") { // Ignore CGo global variables which we don't use. continue } val, err := c.parseExpr(frame, instr.Val) if err != nil { return err } llvmAddr := c.mod.NamedGlobal(getGlobalName(addr)) llvmAddr.SetInitializer(val) case *ssa.FieldAddr: // Initialize field of a global struct. // LLVM does not allow setting an initializer on part of a // global variable. So we take the current initializer, add // the field, and replace the initializer with the new // initializer. val, err := c.parseExpr(frame, instr.Val) if err != nil { return err } global := addr.X.(*ssa.Global) llvmAddr := c.mod.NamedGlobal(getGlobalName(global)) llvmValue := llvmAddr.Initializer() if llvmValue.IsNil() { llvmValue, err = c.getZeroValue(llvmAddr.Type().ElementType()) if err != nil { return err } } llvmValue = c.builder.CreateInsertValue(llvmValue, val, addr.Field, "") llvmAddr.SetInitializer(llvmValue) case *ssa.IndexAddr: val, err := c.parseExpr(frame, instr.Val) if err != nil { return err } constIndex := addr.Index.(*ssa.Const) index, exact := constant.Int64Val(constIndex.Value) if !exact { return errors.New("could not get store index: " + constIndex.Value.ExactString()) } fieldAddr := addr.X.(*ssa.FieldAddr) global := fieldAddr.X.(*ssa.Global) llvmAddr := c.mod.NamedGlobal(getGlobalName(global)) llvmValue := llvmAddr.Initializer() if llvmValue.IsNil() { llvmValue, err = c.getZeroValue(llvmAddr.Type().ElementType()) if err != nil { return err } } llvmFieldValue := c.builder.CreateExtractValue(llvmValue, fieldAddr.Field, "") llvmFieldValue = c.builder.CreateInsertValue(llvmFieldValue, val, int(index), "") llvmValue = c.builder.CreateInsertValue(llvmValue, llvmFieldValue, fieldAddr.Field, "") llvmAddr.SetInitializer(llvmValue) default: return errors.New("unknown init store: " + fmt.Sprintf("%#v", addr)) } default: return errors.New("unknown init instruction: " + fmt.Sprintf("%#v", instr)) } if err != nil { return err } } } return nil } func (c *Compiler) parseFunc(frame *Frame, f *ssa.Function) error { frame.llvmFn.SetLinkage(llvm.PrivateLinkage) // Pre-create all basic blocks in the function. for _, block := range f.DomPreorder() { llvmBlock := c.ctx.AddBasicBlock(frame.llvmFn, block.Comment) frame.blocks[block] = llvmBlock } // Load function parameters for _, param := range f.Params { llvmParam := frame.llvmFn.Param(frame.params[param]) frame.locals[param] = llvmParam } // Fill those blocks with instructions. for _, block := range f.DomPreorder() { c.builder.SetInsertPointAtEnd(frame.blocks[block]) for _, instr := range block.Instrs { err := c.parseInstr(frame, instr) if err != nil { return err } } } // Resolve phi nodes for _, phi := range frame.phis { block := phi.ssa.Block() for i, edge := range phi.ssa.Edges { llvmVal, err := c.parseExpr(frame, edge) if err != nil { return err } llvmBlock := frame.blocks[block.Preds[i]] phi.llvm.AddIncoming([]llvm.Value{llvmVal}, []llvm.BasicBlock{llvmBlock}) } } return nil } func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error { switch instr := instr.(type) { case ssa.Value: value, err := c.parseExpr(frame, instr) frame.locals[instr] = value return err case *ssa.If: cond, err := c.parseExpr(frame, instr.Cond) if err != nil { return err } block := instr.Block() blockThen := frame.blocks[block.Succs[0]] blockElse := frame.blocks[block.Succs[1]] c.builder.CreateCondBr(cond, blockThen, blockElse) return nil case *ssa.Jump: blockJump := frame.blocks[instr.Block().Succs[0]] c.builder.CreateBr(blockJump) return nil case *ssa.Panic: value, err := c.parseExpr(frame, instr.X) if err != nil { return err } c.builder.CreateCall(c.mod.NamedFunction("runtime._panic"), []llvm.Value{value}, "") c.builder.CreateUnreachable() return nil case *ssa.Return: if len(instr.Results) == 0 { c.builder.CreateRetVoid() return nil } else if len(instr.Results) == 1 { val, err := c.parseExpr(frame, instr.Results[0]) if err != nil { return err } c.builder.CreateRet(val) return nil } else { return errors.New("todo: return value") } case *ssa.Store: llvmAddr, err := c.parseExpr(frame, instr.Addr) if err != nil { return err } llvmVal, err := c.parseExpr(frame, instr.Val) if err != nil { return err } valType := instr.Addr.Type().(*types.Pointer).Elem() if valType, ok := valType.(*types.Named); ok && valType.Obj().Name() == "__reg" { // Magic type name to transform this store to a register store. registerAddr := c.builder.CreateLoad(llvmAddr, "") ptr := c.builder.CreateIntToPtr(registerAddr, llvmAddr.Type(), "") store := c.builder.CreateStore(llvmVal, ptr) store.SetVolatile(true) } else { c.builder.CreateStore(llvmVal, llvmAddr) } return nil default: return errors.New("unknown instruction: " + fmt.Sprintf("%#v", instr)) } } func (c *Compiler) parseBuiltin(frame *Frame, args []ssa.Value, callName string) (llvm.Value, error) { switch callName { case "print", "println": for i, arg := range args { if i >= 1 { c.builder.CreateCall(c.mod.NamedFunction("runtime.printspace"), nil, "") } value, err := c.parseExpr(frame, arg) if err != nil { return llvm.Value{}, err } typ := arg.Type() if _, ok := typ.(*types.Named); ok { typ = typ.Underlying() } switch typ := typ.(type) { case *types.Basic: switch typ.Kind() { case types.Int8: c.builder.CreateCall(c.mod.NamedFunction("runtime.printint8"), []llvm.Value{value}, "") case types.Uint8: c.builder.CreateCall(c.mod.NamedFunction("runtime.printuint8"), []llvm.Value{value}, "") case types.Int, types.Int32: // TODO: assumes a 32-bit int type c.builder.CreateCall(c.mod.NamedFunction("runtime.printint32"), []llvm.Value{value}, "") case types.Uint, types.Uint32: c.builder.CreateCall(c.mod.NamedFunction("runtime.printuint32"), []llvm.Value{value}, "") case types.Int64: c.builder.CreateCall(c.mod.NamedFunction("runtime.printint64"), []llvm.Value{value}, "") case types.Uint64: c.builder.CreateCall(c.mod.NamedFunction("runtime.printuint64"), []llvm.Value{value}, "") case types.String: c.builder.CreateCall(c.mod.NamedFunction("runtime.printstring"), []llvm.Value{value}, "") case types.Uintptr: c.builder.CreateCall(c.mod.NamedFunction("runtime.printptr"), []llvm.Value{value}, "") case types.UnsafePointer: ptrValue := c.builder.CreatePtrToInt(value, c.uintptrType, "") c.builder.CreateCall(c.mod.NamedFunction("runtime.printptr"), []llvm.Value{ptrValue}, "") default: return llvm.Value{}, errors.New("unknown basic arg type: " + fmt.Sprintf("%#v", typ)) } case *types.Pointer: ptrValue := c.builder.CreatePtrToInt(value, c.uintptrType, "") c.builder.CreateCall(c.mod.NamedFunction("runtime.printptr"), []llvm.Value{ptrValue}, "") default: return llvm.Value{}, errors.New("unknown arg type: " + fmt.Sprintf("%#v", typ)) } } if callName == "println" { c.builder.CreateCall(c.mod.NamedFunction("runtime.printnl"), nil, "") } return llvm.Value{}, nil // print() or println() returns void case "len": value, err := c.parseExpr(frame, args[0]) if err != nil { return llvm.Value{}, err } switch typ := args[0].Type().(type) { case *types.Basic: switch typ.Kind() { case types.String: return c.builder.CreateExtractValue(value, 0, "len"), nil default: return llvm.Value{}, errors.New("todo: len: unknown basic type") } default: return llvm.Value{}, errors.New("todo: len: unknown type") } default: return llvm.Value{}, errors.New("todo: builtin: " + callName) } } func (c *Compiler) parseFunctionCall(frame *Frame, call *ssa.CallCommon, fn *ssa.Function) (llvm.Value, error) { fmt.Printf(" function: %s\n", fn) name := getFunctionName(fn) target := c.mod.NamedFunction(name) if target.IsNil() { return llvm.Value{}, errors.New("undefined function: " + name) } var params []llvm.Value for _, param := range call.Args { val, err := c.parseExpr(frame, param) if err != nil { return llvm.Value{}, err } params = append(params, val) } return c.builder.CreateCall(target, params, ""), nil } func (c *Compiler) parseCall(frame *Frame, instr *ssa.Call) (llvm.Value, error) { fmt.Printf(" call: %s\n", instr) switch call := instr.Common().Value.(type) { case *ssa.Builtin: return c.parseBuiltin(frame, instr.Common().Args, call.Name()) case *ssa.Function: return c.parseFunctionCall(frame, instr.Common(), call) default: return llvm.Value{}, errors.New("todo: unknown call type: " + fmt.Sprintf("%#v", call)) } } func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { if frame != nil { if value, ok := frame.locals[expr]; ok { // Value is a local variable that has already been computed. if value.IsNil() { return llvm.Value{}, errors.New("undefined local var (from cgo?)") } return value, nil } } switch expr := expr.(type) { case *ssa.Alloc: typ, err := c.getLLVMType(expr.Type().Underlying().(*types.Pointer).Elem()) if err != nil { return llvm.Value{}, err } var buf llvm.Value if expr.Heap { // TODO: escape analysis size := llvm.ConstInt(c.uintptrType, c.targetData.TypeAllocSize(typ), false) buf = c.builder.CreateCall(c.allocFunc, []llvm.Value{size}, expr.Comment) buf = c.builder.CreateBitCast(buf, llvm.PointerType(typ, 0), "") } else { buf = c.builder.CreateAlloca(typ, expr.Comment) zero, err := c.getZeroValue(typ) if err != nil { return llvm.Value{}, err } c.builder.CreateStore(zero, buf) // zero-initialize var } return buf, nil case *ssa.BinOp: return c.parseBinOp(frame, expr) case *ssa.Call: return c.parseCall(frame, expr) case *ssa.ChangeType: return c.parseConvert(frame, expr.Type(), expr.X) case *ssa.Const: return c.parseConst(expr) case *ssa.Convert: return c.parseConvert(frame, expr.Type(), expr.X) case *ssa.Extract: value, err := c.parseExpr(frame, expr.Tuple) if err != nil { return llvm.Value{}, err } result := c.builder.CreateExtractValue(value, expr.Index, "") return result, nil case *ssa.FieldAddr: val, err := c.parseExpr(frame, expr.X) if err != nil { return llvm.Value{}, err } indices := []llvm.Value{ llvm.ConstInt(llvm.Int32Type(), 0, false), llvm.ConstInt(llvm.Int32Type(), uint64(expr.Field), false), } return c.builder.CreateGEP(val, indices, ""), nil case *ssa.Global: fullName := getGlobalName(expr) value := c.mod.NamedGlobal(fullName) if value.IsNil() { return llvm.Value{}, errors.New("global not found: " + fullName) } return value, nil case *ssa.IndexAddr: val, err := c.parseExpr(frame, expr.X) if err != nil { return llvm.Value{}, err } index, err := c.parseExpr(frame, expr.Index) if err != nil { return llvm.Value{}, err } // Get buffer length var buflen llvm.Value typ := expr.X.Type().(*types.Pointer).Elem() switch typ := typ.(type) { case *types.Array: buflen = llvm.ConstInt(llvm.Int32Type(), uint64(typ.Len()), false) default: return llvm.Value{}, errors.New("todo: indexaddr: len") } // Bounds check. // LLVM optimizes this away in most cases. // TODO: runtime.boundsCheck is undefined in packages imported by // package runtime, so we have to remove it. This should be fixed. boundsCheck := c.mod.NamedFunction("runtime.boundsCheck") if !boundsCheck.IsNil() { constZero := llvm.ConstInt(c.intType, 0, false) isNegative := c.builder.CreateICmp(llvm.IntSLT, index, constZero, "") // index < 0 isTooBig := c.builder.CreateICmp(llvm.IntSGE, index, buflen, "") // index >= len(value) isOverflow := c.builder.CreateOr(isNegative, isTooBig, "") c.builder.CreateCall(boundsCheck, []llvm.Value{isOverflow}, "") } indices := []llvm.Value{ llvm.ConstInt(llvm.Int32Type(), 0, false), index, } return c.builder.CreateGEP(val, indices, ""), nil case *ssa.Lookup: if expr.CommaOk { return llvm.Value{}, errors.New("todo: lookup with comma-ok") } if _, ok := expr.X.Type().(*types.Map); ok { return llvm.Value{}, errors.New("todo: lookup in map") } // Value type must be a string, which is a basic type. if expr.X.Type().(*types.Basic).Kind() != types.String { panic("lookup on non-string?") } value, err := c.parseExpr(frame, expr.X) if err != nil { return llvm.Value{}, nil } index, err := c.parseExpr(frame, expr.Index) if err != nil { return llvm.Value{}, nil } // Bounds check. // LLVM optimizes this away in most cases. if frame.llvmFn.Name() != "runtime.boundsCheck" { constZero := llvm.ConstInt(c.intType, 0, false) isNegative := c.builder.CreateICmp(llvm.IntSLT, index, constZero, "") // index < 0 strlen, err := c.parseBuiltin(frame, []ssa.Value{expr.X}, "len") if err != nil { return llvm.Value{}, err // shouldn't happen } isTooBig := c.builder.CreateICmp(llvm.IntSGE, index, strlen, "") // index >= len(value) isOverflow := c.builder.CreateOr(isNegative, isTooBig, "") c.builder.CreateCall(c.mod.NamedFunction("runtime.boundsCheck"), []llvm.Value{isOverflow}, "") } // Lookup byte buf := c.builder.CreateExtractValue(value, 1, "") bufPtr := c.builder.CreateGEP(buf, []llvm.Value{index}, "") return c.builder.CreateLoad(bufPtr, ""), nil case *ssa.MakeInterface: val, err := c.parseExpr(frame, expr.X) if err != nil { return llvm.Value{}, err } var itfValue llvm.Value switch typ := expr.X.Type().(type) { case *types.Basic: if typ.Info() & types.IsInteger != 0 { // TODO: 64-bit int on 32-bit platform itfValue = c.builder.CreateIntToPtr(val, c.i8ptrType, "") } else if typ.Kind() == types.String { // TODO: escape analysis size := c.targetData.TypeAllocSize(c.stringType) sizeValue := llvm.ConstInt(c.uintptrType, size, false) itfValue = c.builder.CreateCall(c.allocFunc, []llvm.Value{sizeValue}, "") itfValueCast := c.builder.CreateBitCast(itfValue, llvm.PointerType(c.stringType, 0), "") c.builder.CreateStore(val, itfValueCast) } else { return llvm.Value{}, errors.New("todo: make interface: unknown basic type") } default: return llvm.Value{}, errors.New("todo: make interface: unknown type") } itfType := c.getInterfaceType(expr.X.Type()) itf := c.ctx.ConstStruct([]llvm.Value{itfType, llvm.Undef(c.i8ptrType)}, false) itf = c.builder.CreateInsertValue(itf, itfValue, 1, "") return itf, nil case *ssa.Phi: t, err := c.getLLVMType(expr.Type()) if err != nil { return llvm.Value{}, err } phi := c.builder.CreatePHI(t, "") frame.phis = append(frame.phis, Phi{expr, phi}) return phi, nil case *ssa.TypeAssert: if !expr.CommaOk { return llvm.Value{}, errors.New("todo: type assert without comma-ok") } itf, err := c.parseExpr(frame, expr.X) if err != nil { return llvm.Value{}, err } assertedType, err := c.getLLVMType(expr.AssertedType) if err != nil { return llvm.Value{}, err } assertedTypeNum := c.getInterfaceType(expr.AssertedType) actualTypeNum := c.builder.CreateExtractValue(itf, 0, "interface.type") valuePtr := c.builder.CreateExtractValue(itf, 1, "interface.value") var value llvm.Value switch typ := expr.AssertedType.(type) { case *types.Basic: if typ.Info() & types.IsInteger != 0 { value = c.builder.CreatePtrToInt(valuePtr, assertedType, "") } else if typ.Kind() == types.String { valueStringPtr := c.builder.CreateBitCast(valuePtr, llvm.PointerType(c.stringType, 0), "") value = c.builder.CreateLoad(valueStringPtr, "") } else { return llvm.Value{}, errors.New("todo: typeassert: unknown basic type") } default: return llvm.Value{}, errors.New("todo: typeassert: unknown type") } commaOk := c.builder.CreateICmp(llvm.IntEQ, assertedTypeNum, actualTypeNum, "") tuple := llvm.ConstStruct([]llvm.Value{llvm.Undef(assertedType), llvm.Undef(llvm.Int1Type())}, false) // create empty tuple tuple = c.builder.CreateInsertValue(tuple, value, 0, "") // insert value tuple = c.builder.CreateInsertValue(tuple, commaOk, 1, "") // insert 'comma ok' boolean return tuple, nil case *ssa.UnOp: return c.parseUnOp(frame, expr) default: return llvm.Value{}, errors.New("todo: unknown expression: " + fmt.Sprintf("%#v", expr)) } } func (c *Compiler) parseBinOp(frame *Frame, binop *ssa.BinOp) (llvm.Value, error) { x, err := c.parseExpr(frame, binop.X) if err != nil { return llvm.Value{}, err } y, err := c.parseExpr(frame, binop.Y) if err != nil { return llvm.Value{}, err } typ := binop.X.Type() if typNamed, ok := typ.(*types.Named); ok { typ = typNamed.Underlying() } signed := typ.(*types.Basic).Info() & types.IsUnsigned == 0 switch binop.Op { case token.ADD: // + return c.builder.CreateAdd(x, y, ""), nil case token.SUB: // - return c.builder.CreateSub(x, y, ""), nil case token.MUL: // * return c.builder.CreateMul(x, y, ""), nil case token.QUO: // / if signed { return c.builder.CreateSDiv(x, y, ""), nil } else { return c.builder.CreateUDiv(x, y, ""), nil } case token.REM: // % if signed { return c.builder.CreateSRem(x, y, ""), nil } else { return c.builder.CreateURem(x, y, ""), nil } case token.AND: // & return c.builder.CreateAnd(x, y, ""), nil case token.OR: // | return c.builder.CreateOr(x, y, ""), nil case token.XOR: // ^ return c.builder.CreateXor(x, y, ""), nil case token.SHL: // << return c.builder.CreateShl(x, y, ""), nil case token.SHR: // >> if signed { return c.builder.CreateAShr(x, y, ""), nil } else { return c.builder.CreateLShr(x, y, ""), nil } case token.AND_NOT: // &^ // Go specific. Calculate "and not" with x & (~y) inv := c.builder.CreateNot(y, "") // ~y return c.builder.CreateAnd(x, inv, ""), nil case token.EQL: // == return c.builder.CreateICmp(llvm.IntEQ, x, y, ""), nil case token.NEQ: // != return c.builder.CreateICmp(llvm.IntNE, x, y, ""), nil case token.LSS: // < if signed { return c.builder.CreateICmp(llvm.IntSLT, x, y, ""), nil } else { return c.builder.CreateICmp(llvm.IntULT, x, y, ""), nil } case token.LEQ: // <= if signed { return c.builder.CreateICmp(llvm.IntSLE, x, y, ""), nil } else { return c.builder.CreateICmp(llvm.IntULE, x, y, ""), nil } case token.GTR: // > if signed { return c.builder.CreateICmp(llvm.IntSGT, x, y, ""), nil } else { return c.builder.CreateICmp(llvm.IntUGT, x, y, ""), nil } case token.GEQ: // >= if signed { return c.builder.CreateICmp(llvm.IntSGE, x, y, ""), nil } else { return c.builder.CreateICmp(llvm.IntUGE, x, y, ""), nil } default: return llvm.Value{}, errors.New("unknown binop") } } func (c *Compiler) parseConst(expr *ssa.Const) (llvm.Value, error) { typ := expr.Type() if named, ok := typ.(*types.Named); ok { typ = named.Underlying() } switch typ := typ.(type) { case *types.Basic: llvmType, err := c.getLLVMType(typ) if err != nil { return llvm.Value{}, err } if typ.Kind() == types.Bool { b := constant.BoolVal(expr.Value) n := uint64(0) if b { n = 1 } return llvm.ConstInt(llvmType, n, false), nil } else if typ.Kind() == types.String { str := constant.StringVal(expr.Value) strLen := llvm.ConstInt(c.stringLenType, uint64(len(str)), false) strPtr := c.builder.CreateGlobalStringPtr(str, ".str") // TODO: remove \0 at end strObj := llvm.ConstStruct([]llvm.Value{strLen, strPtr}, false) return strObj, nil } else if typ.Kind() == types.UnsafePointer { if !expr.IsNil() { return llvm.Value{}, errors.New("todo: non-null constant pointer") } return llvm.ConstNull(c.i8ptrType), nil } else if typ.Info() & types.IsUnsigned != 0 { n, _ := constant.Uint64Val(expr.Value) return llvm.ConstInt(llvmType, n, false), nil } else if typ.Info() & types.IsInteger != 0 { // signed n, _ := constant.Int64Val(expr.Value) return llvm.ConstInt(llvmType, uint64(n), true), nil } else { return llvm.Value{}, errors.New("todo: unknown constant: " + fmt.Sprintf("%v", typ)) } default: return llvm.Value{}, errors.New("todo: unknown constant: " + fmt.Sprintf("%#v", typ)) } } func (c *Compiler) parseConvert(frame *Frame, typeTo types.Type, x ssa.Value) (llvm.Value, error) { value, err := c.parseExpr(frame, x) if err != nil { return value, nil } llvmTypeFrom, err := c.getLLVMType(x.Type()) if err != nil { return llvm.Value{}, err } llvmTypeTo, err := c.getLLVMType(typeTo) if err != nil { return llvm.Value{}, err } switch typeTo := typeTo.(type) { case *types.Basic: isPtrFrom := c.isPointer(x.Type()) isPtrTo := c.isPointer(typeTo) if isPtrFrom && !isPtrTo { return c.builder.CreatePtrToInt(value, llvmTypeTo, ""), nil } else if !isPtrFrom && isPtrTo { return c.builder.CreateIntToPtr(value, llvmTypeTo, ""), nil } sizeFrom := c.targetData.TypeAllocSize(llvmTypeFrom) sizeTo := c.targetData.TypeAllocSize(llvmTypeTo) if sizeFrom == sizeTo { return c.builder.CreateBitCast(value, llvmTypeTo, ""), nil } if typeTo.Info() & types.IsInteger == 0 { // if not integer return llvm.Value{}, errors.New("todo: convert: extend non-integer type") } if sizeFrom > sizeTo { return c.builder.CreateTrunc(value, llvmTypeTo, ""), nil } else if typeTo.Info() & types.IsUnsigned != 0 { // if unsigned return c.builder.CreateZExt(value, llvmTypeTo, ""), nil } else { // if signed return c.builder.CreateSExt(value, llvmTypeTo, ""), nil } case *types.Named: return c.parseConvert(frame, typeTo.Underlying(), x) case *types.Pointer: return c.builder.CreateBitCast(value, llvmTypeTo, ""), nil default: return llvm.Value{}, errors.New("todo: convert: extend non-basic type: " + fmt.Sprintf("%#v", typeTo)) } } func (c *Compiler) parseUnOp(frame *Frame, unop *ssa.UnOp) (llvm.Value, error) { x, err := c.parseExpr(frame, unop.X) if err != nil { return llvm.Value{}, err } switch unop.Op { case token.NOT: // !x return c.builder.CreateNot(x, ""), nil case token.SUB: // -x return c.builder.CreateSub(llvm.ConstInt(x.Type(), 0, false), x, ""), nil case token.MUL: // *x, dereference pointer valType := unop.X.Type().(*types.Pointer).Elem() if valType, ok := valType.(*types.Named); ok && valType.Obj().Name() == "__reg" { // Magic type name: treat the value as a register pointer. register := unop.X.(*ssa.FieldAddr) global := register.X.(*ssa.Global) llvmGlobal := c.mod.NamedGlobal(getGlobalName(global)) llvmAddr := c.builder.CreateExtractValue(llvmGlobal.Initializer(), register.Field, "") ptr := llvm.ConstIntToPtr(llvmAddr, x.Type()) load := c.builder.CreateLoad(ptr, "") load.SetVolatile(true) return load, nil } else { return c.builder.CreateLoad(x, ""), nil } case token.XOR: // ^x, toggle all bits in integer return c.builder.CreateXor(x, llvm.ConstInt(x.Type(), ^uint64(0), false), ""), nil default: return llvm.Value{}, errors.New("todo: unknown unop") } } // IR returns the whole IR as a human-readable string. func (c *Compiler) IR() string { return c.mod.String() } func (c *Compiler) Verify() error { return llvm.VerifyModule(c.mod, 0) } func (c *Compiler) LinkModule(mod llvm.Module) error { return llvm.LinkModules(c.mod, mod) } func (c *Compiler) ApplyFunctionSections() { // Put every function in a separate section. This makes it possible for the // linker to remove dead code (--gc-sections). llvmFn := c.mod.FirstFunction() for !llvmFn.IsNil() { if !llvmFn.IsDeclaration() { llvmFn.SetSection(".text." + llvmFn.Name()) } llvmFn = llvm.NextFunction(llvmFn) } } func (c *Compiler) Optimize(optLevel, sizeLevel int) { builder := llvm.NewPassManagerBuilder() defer builder.Dispose() builder.SetOptLevel(optLevel) builder.SetSizeLevel(sizeLevel) builder.UseInlinerWithThreshold(200) // TODO depend on opt level, and -Os funcPasses := llvm.NewFunctionPassManagerForModule(c.mod) defer funcPasses.Dispose() builder.PopulateFunc(funcPasses) modPasses := llvm.NewPassManager() defer modPasses.Dispose() builder.Populate(modPasses) modPasses.Run(c.mod) } func (c *Compiler) EmitObject(path string) error { // Generate output var buf []byte if strings.HasSuffix(path, ".o") { llvmBuf, err := c.machine.EmitToMemoryBuffer(c.mod, llvm.ObjectFile) if err != nil { return err } buf = llvmBuf.Bytes() } else if strings.HasSuffix(path, ".bc") { buf = llvm.WriteBitcodeToMemoryBuffer(c.mod).Bytes() } else if strings.HasSuffix(path, ".ll") { buf = []byte(c.mod.String()) } else { return errors.New("unknown output file extension") } // Write output to file f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0666) if err != nil { return err } f.Write(buf) f.Close() return nil } // Helper function for Compiler object. func Compile(pkgName, runtimePath, outpath, target string, printIR bool) error { var buildTags []string // TODO: put this somewhere else if target == "pca10040" { buildTags = append(buildTags, "nrf", "nrf52", "nrf52832") target = "armv7m-none-eabi" } c, err := NewCompiler(pkgName, target) if err != nil { return err } // Add C/LLVM runtime. runtime, err := llvm.ParseBitcodeFile(runtimePath) if err != nil { return err } err = c.LinkModule(runtime) if err != nil { return err } // Compile Go code to IR. parseErr := c.Parse(pkgName, buildTags) if printIR { fmt.Println(c.IR()) } if parseErr != nil { return parseErr } c.ApplyFunctionSections() // -ffunction-sections if err := c.Verify(); err != nil { return err } c.Optimize(2, 1) // -O2 -Os if err := c.Verify(); err != nil { return err } err = c.EmitObject(outpath) if err != nil { return err } return nil } func main() { outpath := flag.String("o", "", "output filename") printIR := flag.Bool("printir", false, "print LLVM IR after optimizing") runtime := flag.String("runtime", "", "runtime LLVM bitcode files (from C sources)") target := flag.String("target", llvm.DefaultTargetTriple(), "LLVM target") flag.Parse() if *outpath == "" || flag.NArg() != 1 { fmt.Fprintf(os.Stderr, "usage: %s [-printir] -runtime= [-target=] -o ", os.Args[0]) flag.PrintDefaults() return } os.Setenv("CC", "clang -target=" + *target) err := Compile(flag.Args()[0], *runtime, *outpath, *target, *printIR) if err != nil { fmt.Fprintln(os.Stderr, "error:", err) os.Exit(1) } }