From bd96330d037660d9a1769c6c0d989f017e5f0278 Mon Sep 17 00:00:00 2001 From: Roberto Ierusalimschy Date: Wed, 17 Oct 2018 10:44:42 -0300 Subject: [PATCH] First "complete" implementation of to-be-closed variables Still missing: - handling of memory errors when creating upvalue (must run closing method all the same) - interaction with coroutines --- ldo.c | 20 +++++++++------ ldo.h | 1 + lfunc.c | 47 +++++++++++++++++++++++++++++++--- lfunc.h | 2 +- lgc.c | 5 ++++ lparser.c | 2 +- lstate.c | 4 +-- ltests.c | 2 +- ltm.c | 2 +- ltm.h | 1 + lvm.c | 7 +++--- testes/api.lua | 14 ++++++++++- testes/locals.lua | 64 +++++++++++++++++++++++++++++++++++++++++++---- 13 files changed, 145 insertions(+), 26 deletions(-) diff --git a/ldo.c b/ldo.c index 2349aaed..f2d12f04 100644 --- a/ldo.c +++ b/ldo.c @@ -88,7 +88,7 @@ struct lua_longjmp { }; -static void seterrorobj (lua_State *L, int errcode, StkId oldtop) { +void luaD_seterrorobj (lua_State *L, int errcode, StkId oldtop) { switch (errcode) { case LUA_ERRMEM: { /* memory error? */ setsvalue2s(L, oldtop, G(L)->memerrmsg); /* reuse preregistered msg. */ @@ -121,7 +121,7 @@ l_noret luaD_throw (lua_State *L, int errcode) { } else { /* no handler at all; abort */ if (g->panic) { /* panic function? */ - seterrorobj(L, errcode, L->top); /* assume EXTRA_STACK */ + luaD_seterrorobj(L, errcode, L->top); /* assume EXTRA_STACK */ if (L->ci->top < L->top) L->ci->top = L->top; /* pushing msg. can break this invariant */ lua_unlock(L); @@ -584,8 +584,8 @@ static int recover (lua_State *L, int status) { if (ci == NULL) return 0; /* no recovery point */ /* "finish" luaD_pcall */ oldtop = restorestack(L, ci->u2.funcidx); - luaF_close(L, oldtop); - seterrorobj(L, status, oldtop); + luaF_close(L, oldtop, status); + luaD_seterrorobj(L, status, oldtop); L->ci = ci; L->allowhook = getoah(ci->callstatus); /* restore original 'allowhook' */ L->nny = 0; /* should be zero to be yieldable */ @@ -678,7 +678,7 @@ LUA_API int lua_resume (lua_State *L, lua_State *from, int nargs, } if (unlikely(errorstatus(status))) { /* unrecoverable error? */ L->status = cast_byte(status); /* mark thread as 'dead' */ - seterrorobj(L, status, L->top); /* push error message */ + luaD_seterrorobj(L, status, L->top); /* push error message */ L->ci->top = L->top; } else lua_assert(status == L->status); /* normal end or yield */ @@ -726,6 +726,11 @@ LUA_API int lua_yieldk (lua_State *L, int nresults, lua_KContext ctx, } +/* +** Call the C function 'func' in protected mode, restoring basic +** thread information ('allowhook', 'nny', etc.) and in particular +** its stack level in case of errors. +*/ int luaD_pcall (lua_State *L, Pfunc func, void *u, ptrdiff_t old_top, ptrdiff_t ef) { int status; @@ -737,11 +742,12 @@ int luaD_pcall (lua_State *L, Pfunc func, void *u, status = luaD_rawrunprotected(L, func, u); if (unlikely(status != LUA_OK)) { /* an error occurred? */ StkId oldtop = restorestack(L, old_top); - luaF_close(L, oldtop); /* close possible pending closures */ - seterrorobj(L, status, oldtop); L->ci = old_ci; L->allowhook = old_allowhooks; L->nny = old_nny; + status = luaF_close(L, oldtop, status); + oldtop = restorestack(L, old_top); /* previous call may change stack */ + luaD_seterrorobj(L, status, oldtop); luaD_shrinkstack(L); } L->errfunc = old_errfunc; diff --git a/ldo.h b/ldo.h index c836a2a1..7760f853 100644 --- a/ldo.h +++ b/ldo.h @@ -50,6 +50,7 @@ /* type of protected functions, to be ran by 'runprotected' */ typedef void (*Pfunc) (lua_State *L, void *ud); +LUAI_FUNC void luaD_seterrorobj (lua_State *L, int errcode, StkId oldtop); LUAI_FUNC int luaD_protectedparser (lua_State *L, ZIO *z, const char *name, const char *mode); LUAI_FUNC void luaD_hook (lua_State *L, int event, int line, diff --git a/lfunc.c b/lfunc.c index 16e00731..fde72b8c 100644 --- a/lfunc.c +++ b/lfunc.c @@ -14,6 +14,7 @@ #include "lua.h" +#include "ldo.h" #include "lfunc.h" #include "lgc.h" #include "lmem.h" @@ -83,6 +84,40 @@ UpVal *luaF_findupval (lua_State *L, StkId level) { } +static void callclose (lua_State *L, void *ud) { + luaD_callnoyield(L, cast(StkId, ud), 0); +} + + +static int closeupval (lua_State *L, UpVal *uv, StkId level, int status) { + StkId func = level + 1; /* save slot for old error message */ + if (status != LUA_OK) /* was there an error? */ + luaD_seterrorobj(L, status, level); /* save error message */ + else + setnilvalue(s2v(level)); + if (ttisfunction(uv->v)) { /* object to-be-closed is a function? */ + setobj2s(L, func, uv->v); /* will call it */ + setobjs2s(L, func + 1, level); /* error msg. as argument */ + } + else { /* try '__close' metamethod */ + const TValue *tm = luaT_gettmbyobj(L, uv->v, TM_CLOSE); + if (ttisnil(tm)) + return status; /* no metamethod */ + setobj2s(L, func, tm); /* will call metamethod */ + setobj2s(L, func + 1, uv->v); /* with 'self' as argument */ + } + L->top = func + 2; /* add function and argument */ + if (status == LUA_OK) /* not in "error mode"? */ + callclose(L, func); /* call closing method */ + else { /* already inside error handler; cannot raise another error */ + int newstatus = luaD_pcall(L, callclose, func, savestack(L, level), 0); + if (newstatus != LUA_OK) /* error when closing? */ + status = newstatus; /* this will be the new error */ + } + return status; +} + + void luaF_unlinkupval (UpVal *uv) { lua_assert(upisopen(uv)); *uv->u.open.previous = uv->u.open.next; @@ -91,10 +126,10 @@ void luaF_unlinkupval (UpVal *uv) { } -void luaF_close (lua_State *L, StkId level) { +int luaF_close (lua_State *L, StkId level, int status) { UpVal *uv; - while (L->openupval != NULL && - (uv = L->openupval, uplevel(uv) >= level)) { + while ((uv = L->openupval) != NULL && uplevel(uv) >= level) { + StkId upl = uplevel(uv); TValue *slot = &uv->u.value; /* new position for value */ luaF_unlinkupval(uv); setobj(L, slot, uv->v); /* move value to upvalue slot */ @@ -102,7 +137,13 @@ void luaF_close (lua_State *L, StkId level) { if (!iswhite(uv)) gray2black(uv); /* closed upvalues cannot be gray */ luaC_barrier(L, uv, slot); + if (status >= 0 && uv->tt == LUA_TUPVALTBC) { /* must be closed? */ + ptrdiff_t levelrel = savestack(L, level); + status = closeupval(L, uv, upl, status); /* may reallocate the stack */ + level = restorestack(L, levelrel); + } } + return status; } diff --git a/lfunc.h b/lfunc.h index 859ccc12..4c788005 100644 --- a/lfunc.h +++ b/lfunc.h @@ -47,7 +47,7 @@ LUAI_FUNC CClosure *luaF_newCclosure (lua_State *L, int nelems); LUAI_FUNC LClosure *luaF_newLclosure (lua_State *L, int nelems); LUAI_FUNC void luaF_initupvals (lua_State *L, LClosure *cl); LUAI_FUNC UpVal *luaF_findupval (lua_State *L, StkId level); -LUAI_FUNC void luaF_close (lua_State *L, StkId level); +LUAI_FUNC int luaF_close (lua_State *L, StkId level, int status); LUAI_FUNC void luaF_unlinkupval (UpVal *uv); LUAI_FUNC void luaF_freeproto (lua_State *L, Proto *f); LUAI_FUNC const char *luaF_getlocalname (const Proto *func, int local_number, diff --git a/lgc.c b/lgc.c index 39b3ab73..9d196a18 100644 --- a/lgc.c +++ b/lgc.c @@ -609,6 +609,7 @@ static int traverseLclosure (global_State *g, LClosure *cl) { ** That ensures that the entire stack have valid (non-dead) objects. */ static int traversethread (global_State *g, lua_State *th) { + UpVal *uv; StkId o = th->stack; if (o == NULL) return 1; /* stack not completely built yet */ @@ -616,6 +617,10 @@ static int traversethread (global_State *g, lua_State *th) { th->openupval == NULL || isintwups(th)); for (; o < th->top; o++) /* mark live elements in the stack */ markvalue(g, s2v(o)); + for (uv = th->openupval; uv != NULL; uv = uv->u.open.next) { + if (uv->tt == LUA_TUPVALTBC) /* to be closed? */ + markobject(g, uv); /* cannot be collected */ + } if (g->gcstate == GCSatomic) { /* final traversal? */ StkId lim = th->stack + th->stacksize; /* real end of stack */ for (; o < lim; o++) /* clear not-marked stack slice */ diff --git a/lparser.c b/lparser.c index 84abeb90..6b14b800 100644 --- a/lparser.c +++ b/lparser.c @@ -1536,9 +1536,9 @@ static void scopedlocalstat (LexState *ls) { FuncState *fs = ls->fs; new_localvar(ls, str_checkname(ls)); checknext(ls, '='); + exp1(ls, 0); luaK_codeABC(fs, OP_TBC, fs->nactvar, 0, 0); markupval(fs, fs->nactvar); - exp1(ls, 0); adjustlocalvars(ls, 1); } diff --git a/lstate.c b/lstate.c index 8b0219bc..4a2453d1 100644 --- a/lstate.c +++ b/lstate.c @@ -258,7 +258,7 @@ static void preinit_thread (lua_State *L, global_State *g) { static void close_state (lua_State *L) { global_State *g = G(L); - luaF_close(L, L->stack); /* close all upvalues for this thread */ + luaF_close(L, L->stack, -1); /* close all upvalues for this thread */ luaC_freeallobjects(L); /* collect all objects */ if (ttisnil(&g->nilvalue)) /* closing a fully built state? */ luai_userstateclose(L); @@ -301,7 +301,7 @@ LUA_API lua_State *lua_newthread (lua_State *L) { void luaE_freethread (lua_State *L, lua_State *L1) { LX *l = fromstate(L1); - luaF_close(L1, L1->stack); /* close all upvalues for this thread */ + luaF_close(L1, L1->stack, -1); /* close all upvalues for this thread */ lua_assert(L1->openupval == NULL); luai_userstatefree(L, L1); freestack(L1); diff --git a/ltests.c b/ltests.c index ff962543..a6968653 100644 --- a/ltests.c +++ b/ltests.c @@ -1208,7 +1208,7 @@ static int getindex_aux (lua_State *L, lua_State *L1, const char **pc) { static void pushcode (lua_State *L, int code) { static const char *const codes[] = {"OK", "YIELD", "ERRRUN", - "ERRSYNTAX", "ERRMEM", "ERRGCMM", "ERRERR"}; + "ERRSYNTAX", MEMERRMSG, "ERRGCMM", "ERRERR"}; lua_pushstring(L, codes[code]); } diff --git a/ltm.c b/ltm.c index 5c148180..53e15c7f 100644 --- a/ltm.c +++ b/ltm.c @@ -43,7 +43,7 @@ void luaT_init (lua_State *L) { "__div", "__idiv", "__band", "__bor", "__bxor", "__shl", "__shr", "__unm", "__bnot", "__lt", "__le", - "__concat", "__call" + "__concat", "__call", "__close" }; int i; for (i=0; itt = LUA_TUPVALTBC; /* mark it to be closed */ - setnilvalue(s2v(ra)); /* intialize it with nil */ vmbreak; } vmcase(OP_JMP) { @@ -1591,7 +1590,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) { int nparams1 = GETARG_C(i); if (nparams1) /* vararg function? */ delta = ci->u.l.nextraargs + nparams1; - luaF_close(L, base); /* close upvalues from current call */ + luaF_close(L, base, LUA_OK); /* close upvalues from current call */ } if (!ttisfunction(s2v(ra))) { /* not a function? */ luaD_tryfuncTM(L, ra); /* try '__call' metamethod */ @@ -1625,7 +1624,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) { int nparams1 = GETARG_C(i); if (nparams1) /* vararg function? */ ci->func -= ci->u.l.nextraargs + nparams1; - luaF_close(L, base); /* there may be open upvalues */ + luaF_close(L, base, LUA_OK); /* there may be open upvalues */ } halfProtect(luaD_poscall(L, ci, n)); return; diff --git a/testes/api.lua b/testes/api.lua index bebb6d2d..925a80c1 100644 --- a/testes/api.lua +++ b/testes/api.lua @@ -1,4 +1,4 @@ --- $Id: testes/api.lua $ +-- $Id: testes/api.lua 2018-07-25 15:31:04 -0300 $ -- See Copyright Notice in file all.lua if T==nil then @@ -1027,6 +1027,18 @@ testamem("coroutine creation", function() end) +-- testing to-be-closed variables +testamem("to-be-closed variables", function() + local flag + do + local scoped x = function () flag = true end + flag = false + local x = {} + end + return flag +end) + + -- testing threads -- get main thread from registry (at index LUA_RIDX_MAINTHREAD == 1) diff --git a/testes/locals.lua b/testes/locals.lua index 20ecae4b..8d55e9f5 100644 --- a/testes/locals.lua +++ b/testes/locals.lua @@ -173,15 +173,69 @@ end assert(x==20) --- tests for to-be-closed variables +print"testing to-be-closed variables" + +do + local a = {} + do + local scoped x = setmetatable({"x"}, {__close = function (self) + a[#a + 1] = self[1] end}) + local scoped y = function () a[#a + 1] = "y" end + a[#a + 1] = "in" + end + a[#a + 1] = "out" + assert(a[1] == "in" and a[2] == "y" and a[3] == "x" and a[4] == "out") +end + + +do -- errors in __close + local log = {} + local function foo (err) + local scoped x = function (msg) log[#log + 1] = msg; error(1) end + local scoped x1 = function (msg) log[#log + 1] = msg; end + local scoped gc = function () collectgarbage() end + local scoped y = function (msg) log[#log + 1] = msg; error(2) end + local scoped z = function (msg) log[#log + 1] = msg or 10; error(3) end + if err then error(4) end + end + local stat, msg = pcall(foo, false) + assert(msg == 1) + assert(log[1] == 10 and log[2] == 3 and log[3] == 2 and log[4] == 2 + and #log == 4) + + log = {} + local stat, msg = pcall(foo, true) + assert(msg == 1) + assert(log[1] == 4 and log[2] == 3 and log[3] == 2 and log[4] == 2 + and #log == 4) +end + do - local scoped x = 3 - local a - local scoped y = 5 - assert(x == 3 and y == 5) + -- memory error inside closing function + local function foo () + local scoped y = function () io.write(2); T.alloccount() end + local scoped x = setmetatable({}, {__close = function () + T.alloccount(0); local x = {} -- force a memory error + end}) + io.write("1\n") + error("a") -- common error inside the function's body + end + + local _, msg = pcall(foo) +T.alloccount() + assert(msg == "not enough memory") + end +-- a suspended coroutine should not close its variables when collected +local co = coroutine.wrap(function() + local scoped x = function () os.exit(1) end -- should not run + coroutine.yield() +end) +co() +co = nil + print('OK') return 5,f