You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

338 lines
12 KiB

package compiler
// This file implements the 'defer' keyword in Go.
// Defer statements are implemented by transforming the function in the
// following way:
// * Creating an alloca in the entry block that contains a pointer (initially
// null) to the linked list of defer frames.
// * Every time a defer statement is executed, a new defer frame is created
// using alloca with a pointer to the previous defer frame, and the head
// pointer in the entry block is replaced with a pointer to this defer
// frame.
// * On return, runtime.rundefers is called which calls all deferred functions
// from the head of the linked list until it has gone through all defer
// frames.
import (
"github.com/tinygo-org/tinygo/ir"
"golang.org/x/tools/go/ssa"
"tinygo.org/x/go-llvm"
)
// deferInitFunc sets up this function for future deferred calls. It must be
// called from within the entry block when this function contains deferred
// calls.
func (c *Compiler) deferInitFunc(frame *Frame) {
// Some setup.
frame.deferFuncs = make(map[*ir.Function]int)
frame.deferInvokeFuncs = make(map[string]int)
frame.deferClosureFuncs = make(map[*ir.Function]int)
// Create defer list pointer.
deferType := llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0)
frame.deferPtr = c.builder.CreateAlloca(deferType, "deferPtr")
c.builder.CreateStore(llvm.ConstPointerNull(deferType), frame.deferPtr)
}
// emitDefer emits a single defer instruction, to be run when this function
// returns.
func (c *Compiler) emitDefer(frame *Frame, instr *ssa.Defer) error {
// The pointer to the previous defer struct, which we will replace to
// make a linked list.
next := c.builder.CreateLoad(frame.deferPtr, "defer.next")
var values []llvm.Value
valueTypes := []llvm.Type{c.uintptrType, next.Type()}
if instr.Call.IsInvoke() {
// Method call on an interface.
// Get callback type number.
methodName := instr.Call.Method.FullName()
if _, ok := frame.deferInvokeFuncs[methodName]; !ok {
frame.deferInvokeFuncs[methodName] = len(frame.allDeferFuncs)
frame.allDeferFuncs = append(frame.allDeferFuncs, &instr.Call)
}
callback := llvm.ConstInt(c.uintptrType, uint64(frame.deferInvokeFuncs[methodName]), false)
// Collect all values to be put in the struct (starting with
// runtime._defer fields, followed by the call parameters).
itf, err := c.parseExpr(frame, instr.Call.Value) // interface
if err != nil {
return err
}
receiverValue := c.builder.CreateExtractValue(itf, 1, "invoke.func.receiver")
values = []llvm.Value{callback, next, receiverValue}
valueTypes = append(valueTypes, c.i8ptrType)
for _, arg := range instr.Call.Args {
val, err := c.parseExpr(frame, arg)
if err != nil {
return err
}
values = append(values, val)
valueTypes = append(valueTypes, val.Type())
}
} else if callee, ok := instr.Call.Value.(*ssa.Function); ok {
// Regular function call.
fn := c.ir.GetFunction(callee)
if _, ok := frame.deferFuncs[fn]; !ok {
frame.deferFuncs[fn] = len(frame.allDeferFuncs)
frame.allDeferFuncs = append(frame.allDeferFuncs, fn)
}
callback := llvm.ConstInt(c.uintptrType, uint64(frame.deferFuncs[fn]), false)
// Collect all values to be put in the struct (starting with
// runtime._defer fields).
values = []llvm.Value{callback, next}
for _, param := range instr.Call.Args {
llvmParam, err := c.parseExpr(frame, param)
if err != nil {
return err
}
values = append(values, llvmParam)
valueTypes = append(valueTypes, llvmParam.Type())
}
} else if makeClosure, ok := instr.Call.Value.(*ssa.MakeClosure); ok {
// Immediately applied function literal with free variables.
// Extract the context from the closure. We won't need the function
// pointer.
// TODO: ignore this closure entirely and put pointers to the free
// variables directly in the defer struct, avoiding a memory allocation.
closure, err := c.parseExpr(frame, instr.Call.Value)
if err != nil {
return err
}
context := c.builder.CreateExtractValue(closure, 0, "")
// Get the callback number.
fn := c.ir.GetFunction(makeClosure.Fn.(*ssa.Function))
if _, ok := frame.deferClosureFuncs[fn]; !ok {
frame.deferClosureFuncs[fn] = len(frame.allDeferFuncs)
frame.allDeferFuncs = append(frame.allDeferFuncs, makeClosure)
}
callback := llvm.ConstInt(c.uintptrType, uint64(frame.deferClosureFuncs[fn]), false)
// Collect all values to be put in the struct (starting with
// runtime._defer fields, followed by all parameters including the
// context pointer).
values = []llvm.Value{callback, next}
for _, param := range instr.Call.Args {
llvmParam, err := c.parseExpr(frame, param)
if err != nil {
return err
}
values = append(values, llvmParam)
valueTypes = append(valueTypes, llvmParam.Type())
}
values = append(values, context)
valueTypes = append(valueTypes, context.Type())
} else {
return c.makeError(instr.Pos(), "todo: defer on uncommon function call type")
}
// Make a struct out of the collected values to put in the defer frame.
deferFrameType := c.ctx.StructType(valueTypes, false)
deferFrame, err := c.getZeroValue(deferFrameType)
if err != nil {
return err
}
for i, value := range values {
deferFrame = c.builder.CreateInsertValue(deferFrame, value, i, "")
}
// Put this struct in an alloca.
alloca := c.builder.CreateAlloca(deferFrameType, "defer.alloca")
c.builder.CreateStore(deferFrame, alloca)
// Push it on top of the linked list by replacing deferPtr.
allocaCast := c.builder.CreateBitCast(alloca, next.Type(), "defer.alloca.cast")
c.builder.CreateStore(allocaCast, frame.deferPtr)
return nil
}
// emitRunDefers emits code to run all deferred functions.
func (c *Compiler) emitRunDefers(frame *Frame) error {
// Add a loop like the following:
// for stack != nil {
// _stack := stack
// stack = stack.next
// switch _stack.callback {
// case 0:
// // run first deferred call
// case 1:
// // run second deferred call
// // etc.
// default:
// unreachable
// }
// }
// Create loop.
loophead := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.loophead")
loop := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.loop")
unreachable := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.default")
end := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.end")
c.builder.CreateBr(loophead)
// Create loop head:
// for stack != nil {
c.builder.SetInsertPointAtEnd(loophead)
deferData := c.builder.CreateLoad(frame.deferPtr, "")
stackIsNil := c.builder.CreateICmp(llvm.IntEQ, deferData, llvm.ConstPointerNull(deferData.Type()), "stackIsNil")
c.builder.CreateCondBr(stackIsNil, end, loop)
// Create loop body:
// _stack := stack
// stack = stack.next
// switch stack.callback {
c.builder.SetInsertPointAtEnd(loop)
nextStackGEP := c.builder.CreateGEP(deferData, []llvm.Value{
llvm.ConstInt(c.ctx.Int32Type(), 0, false),
llvm.ConstInt(c.ctx.Int32Type(), 1, false), // .next field
}, "stack.next.gep")
nextStack := c.builder.CreateLoad(nextStackGEP, "stack.next")
c.builder.CreateStore(nextStack, frame.deferPtr)
gep := c.builder.CreateGEP(deferData, []llvm.Value{
llvm.ConstInt(c.ctx.Int32Type(), 0, false),
llvm.ConstInt(c.ctx.Int32Type(), 0, false), // .callback field
}, "callback.gep")
callback := c.builder.CreateLoad(gep, "callback")
sw := c.builder.CreateSwitch(callback, unreachable, len(frame.allDeferFuncs))
for i, callback := range frame.allDeferFuncs {
// Create switch case, for example:
// case 0:
// // run first deferred call
block := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.callback")
sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(i), false), block)
c.builder.SetInsertPointAtEnd(block)
switch callback := callback.(type) {
case *ssa.CallCommon:
// Call on an interface value.
if !callback.IsInvoke() {
panic("expected an invoke call, not a direct call")
}
// Get the real defer struct type and cast to it.
valueTypes := []llvm.Type{c.uintptrType, llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0), c.i8ptrType}
for _, arg := range callback.Args {
llvmType, err := c.getLLVMType(arg.Type())
if err != nil {
return err
}
valueTypes = append(valueTypes, llvmType)
}
deferFrameType := c.ctx.StructType(valueTypes, false)
deferFramePtr := c.builder.CreateBitCast(deferData, llvm.PointerType(deferFrameType, 0), "deferFrame")
// Extract the params from the struct (including receiver).
forwardParams := []llvm.Value{}
zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false)
for i := 2; i < len(valueTypes); i++ {
gep := c.builder.CreateGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), uint64(i), false)}, "gep")
forwardParam := c.builder.CreateLoad(gep, "param")
forwardParams = append(forwardParams, forwardParam)
}
// Add the context parameter. An interface call cannot also be a
// closure but we have to supply the parameter anyway for platforms
// with a strict calling convention.
forwardParams = append(forwardParams, llvm.Undef(c.i8ptrType))
// Parent coroutine handle.
forwardParams = append(forwardParams, llvm.Undef(c.i8ptrType))
fnPtr, _, err := c.getInvokeCall(frame, callback)
if err != nil {
return err
}
c.createCall(fnPtr, forwardParams, "")
case *ir.Function:
// Direct call.
// Get the real defer struct type and cast to it.
valueTypes := []llvm.Type{c.uintptrType, llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0)}
for _, param := range callback.Params {
llvmType, err := c.getLLVMType(param.Type())
if err != nil {
return err
}
valueTypes = append(valueTypes, llvmType)
}
deferFrameType := c.ctx.StructType(valueTypes, false)
deferFramePtr := c.builder.CreateBitCast(deferData, llvm.PointerType(deferFrameType, 0), "deferFrame")
// Extract the params from the struct.
forwardParams := []llvm.Value{}
zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false)
for i := range callback.Params {
gep := c.builder.CreateGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), uint64(i+2), false)}, "gep")
forwardParam := c.builder.CreateLoad(gep, "param")
forwardParams = append(forwardParams, forwardParam)
}
// Add the context parameter. We know it is ignored by the receiving
// function, but we have to pass one anyway.
forwardParams = append(forwardParams, llvm.Undef(c.i8ptrType))
// Parent coroutine handle.
forwardParams = append(forwardParams, llvm.Undef(c.i8ptrType))
// Call real function.
c.createCall(callback.LLVMFn, forwardParams, "")
case *ssa.MakeClosure:
// Get the real defer struct type and cast to it.
fn := c.ir.GetFunction(callback.Fn.(*ssa.Function))
valueTypes := []llvm.Type{c.uintptrType, llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0)}
params := fn.Signature.Params()
for i := 0; i < params.Len(); i++ {
llvmType, err := c.getLLVMType(params.At(i).Type())
if err != nil {
return err
}
valueTypes = append(valueTypes, llvmType)
}
valueTypes = append(valueTypes, c.i8ptrType) // closure
deferFrameType := c.ctx.StructType(valueTypes, false)
deferFramePtr := c.builder.CreateBitCast(deferData, llvm.PointerType(deferFrameType, 0), "deferFrame")
// Extract the params from the struct.
forwardParams := []llvm.Value{}
zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false)
for i := 2; i < len(valueTypes); i++ {
gep := c.builder.CreateGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), uint64(i), false)}, "")
forwardParam := c.builder.CreateLoad(gep, "param")
forwardParams = append(forwardParams, forwardParam)
}
// Parent coroutine handle.
forwardParams = append(forwardParams, llvm.Undef(c.i8ptrType))
// Call deferred function.
c.createCall(fn.LLVMFn, forwardParams, "")
default:
panic("unknown deferred function type")
}
// Branch back to the start of the loop.
c.builder.CreateBr(loophead)
}
// Create default unreachable block:
// default:
// unreachable
// }
c.builder.SetInsertPointAtEnd(unreachable)
c.builder.CreateUnreachable()
// End of loop.
c.builder.SetInsertPointAtEnd(end)
return nil
}