diff --git a/lualib/snax/hotfix.lua b/lualib/snax/hotfix.lua index f56a50727..a73d680ea 100644 --- a/lualib/snax/hotfix.lua +++ b/lualib/snax/hotfix.lua @@ -53,8 +53,8 @@ local function collect_all_uv(funcs) end local function loader(source) - return function (filename, ...) - return load(source, "=patch", ...) + return function (path, name, G) + return load(source, "=patch", "bt", G) end end @@ -68,10 +68,10 @@ local function find_func(funcs, group , name) end local dummy_env = {} +for k,v in pairs(_ENV) do dummy_env[k] = v end -local function patch_func(funcs, global, group, name, f) - local desc = assert(find_func(funcs, group, name) , string.format("Patch mismatch %s.%s", group, name)) - local i = 1 +local function _patch(global, f) + local i = 1 while true do local name, value = debug.getupvalue(f, i) if name == nil then @@ -81,9 +81,18 @@ local function patch_func(funcs, global, group, name, f) if old_uv then debug.upvaluejoin(f, i, old_uv.func, old_uv.index) end + else + if type(value) == "function" then + _patch(global, value) + end end i = i + 1 end +end + +local function patch_func(funcs, global, group, name, f) + local desc = assert(find_func(funcs, group, name) , string.format("Patch mismatch %s.%s", group, name)) + _patch(global, f) desc[4] = f end diff --git a/lualib/snax/interface.lua b/lualib/snax/interface.lua index 740f09d93..e8c8d5961 100644 --- a/lualib/snax/interface.lua +++ b/lualib/snax/interface.lua @@ -1,8 +1,24 @@ local skynet = require "skynet" +local function dft_loader(path, name, G) + local errlist = {} + + for pat in string.gmatch(path,"[^;]+") do + local filename = string.gsub(pat, "?", name) + local f , err = loadfile(filename, "bt", G) + if f then + return f, pat + else + table.insert(errlist, err) + end + end + + error(table.concat(errlist, "\n")) +end + return function (name , G, loader) - loader = loader or loadfile - local mainfunc + loader = loader or dft_loader + local mainfunc local function func_id(id, group) local tmp = {} @@ -61,30 +77,11 @@ return function (name , G, loader) local pattern - do - local path = assert(skynet.getenv "snax" , "please set snax in config file") - - local errlist = {} - - for pat in string.gmatch(path,"[^;]+") do - local filename = string.gsub(pat, "?", name) - local f , err = loader(filename, "bt", G) - if f then - pattern = pat - mainfunc = f - break - else - table.insert(errlist, err) - end - end - - if mainfunc == nil then - error(table.concat(errlist, "\n")) - end - end + local path = assert(skynet.getenv "snax" , "please set snax in config file") + mainfunc, pattern = loader(path, name, G) setmetatable(G, { __index = env , __newindex = init_system }) - local ok, err = pcall(mainfunc) + local ok, err = xpcall(mainfunc, debug.traceback) setmetatable(G, nil) assert(ok,err) diff --git a/service/snaxd.lua b/service/snaxd.lua index 1b8ddee6e..5aaa33537 100644 --- a/service/snaxd.lua +++ b/service/snaxd.lua @@ -5,7 +5,9 @@ local profile = require "profile" local snax = require "snax" local snax_name = tostring(...) -local func, pattern = snax_interface(snax_name, _ENV) +local loaderpath = skynet.getenv"snax_loader" +local loader = loaderpath and assert(dofile(loaderpath)) or require"snax.loader" +local func, pattern = snax_interface(snax_name, _ENV, loader) local snax_path = pattern:sub(1,pattern:find("?", 1, true)-1) .. snax_name .. "/" package.path = snax_path .. "?.lua;" .. package.path