-- exec_lua.lua (3186B)
--- Run a local Lua function inside a remote Nvim instance (via `nvim_exec_lua`).
---
--- The function is serialized with `string.dump` and its upvalues are shipped
--- alongside it. After it runs remotely, its (possibly modified) upvalues, any
--- `print` output and its return values are sent back to the caller.

--- Collect the upvalues of `f` into a name -> value map.
--- Upvalues currently holding `nil` produce no entry.
--- @param f function
--- @return table<string,any>
local function get_upvalues(f)
  local i = 1
  local upvalues = {} --- @type table<string,any>
  while true do
    local n, v = debug.getupvalue(f, i)
    if not n then
      break
    end
    upvalues[n] = v
    i = i + 1
  end
  return upvalues
end

--- Overwrite the upvalues of `f` with the same-named entries of `upvalues`.
--- Upvalues that have no entry in the map are left untouched.
--- @param f function
--- @param upvalues table<string,any>
local function set_upvalues(f, upvalues)
  local i = 1
  while true do
    local n = debug.getupvalue(f, i)
    if not n then
      break
    end
    -- Compare against nil rather than testing truthiness, so that upvalues
    -- holding `false` are synced as well.
    local v = upvalues[n]
    if v ~= nil then
      debug.setupvalue(f, i, v)
    end
    i = i + 1
  end
end

--- Append one `print`-style line (args stringified, tab-separated) to `messages`.
--- @param messages string[]
--- @param ... ...
local function add_print(messages, ...)
  local msg = {} --- @type string[]
  for i = 1, select('#', ...) do
    msg[#msg + 1] = tostring(select(i, ...))
  end
  table.insert(messages, table.concat(msg, '\t'))
end

-- Value types that are rejected as return values by `check_returns`.
local invalid_types = {
  ['thread'] = true,
  ['function'] = true,
  ['userdata'] = true,
}

--- Raise an error if any value in `r` cannot be sent back over RPC.
--- @param r any[] Packed `pcall` result: `r[1]` is the status flag and `r[k]`
---   (for `k >= 2`) is return value number `k - 1` of the executed function.
local function check_returns(r)
  for k, v in pairs(r) do
    if invalid_types[type(v)] then
      error(
        string.format(
          "Return index %d with value '%s' of type '%s' cannot be serialized over RPC",
          -- Report the position as seen by the user's function, not the
          -- position inside the packed pcall result (slot 1 is the flag).
          k - 1,
          tostring(v),
          type(v)
        ),
        2
      )
    end
  end
end

local M = {}

--- This is run in the context of the remote Nvim instance.
--- Loads `bytecode`, installs `upvalues`, and runs it under `pcall` while
--- capturing `print` output.
--- @param bytecode string Output of `string.dump` on the local side.
--- @param upvalues table<string,any>
--- @param ... any[] Arguments forwarded to the loaded function.
--- @return any[] result Packed `pcall` result (`[1]` is the success flag).
--- @return table<string,any> upvalues Upvalues of the function after it ran.
--- @return string[] messages Lines passed to `print` while it ran.
function M.handler(bytecode, upvalues, ...)
  -- Prepare the function before touching any global state, so that a
  -- failure here cannot leave `_G.print` replaced.
  local f = assert(loadstring(bytecode))
  set_upvalues(f, upvalues)

  local messages = {} --- @type string[]
  local orig_print = _G.print

  function _G.print(...)
    add_print(messages, ...)
    return orig_print(...)
  end

  -- Run in pcall so we can always restore `print` and return any messages.
  local ret = { pcall(f, ...) } --- @type any[]

  _G.print = orig_print

  local new_upvalues = get_upvalues(f)

  -- Check return value types for better error messages
  check_returns(ret)

  return ret, new_upvalues, messages
end

--- Execute `code` in the Nvim instance behind `session`.
---
--- Remote `print` output is re-printed locally, and an error raised by `code`
--- is re-raised here. Upvalues returned by the remote side are written to the
--- function found at stack level `lvl`. Frames without a `func` (e.g. tail
--- calls on PUC-Lua) are skipped.
--- @param session test.Session
--- @param lvl integer Stack level, as seen from `run`, passed to `debug.getinfo`.
--- @param code function
--- @param ... ... Arguments forwarded to `code`.
--- @return any ... The values returned by `code`.
local function run(session, lvl, code, ...)
  local stat, rv = session:request(
    'nvim_exec_lua',
    [[return { require('test.functional.testnvim.exec_lua').handler(...) }]],
    { string.dump(code), get_upvalues(code), ... }
  )

  if not stat then
    error(rv[2], 2)
  end

  --- @type any[], table<string,any>, string[]
  local ret, upvalues, messages = unpack(rv)

  for _, m in ipairs(messages) do
    print(m)
  end

  if not ret[1] then
    error(ret[2], 2)
  end

  -- Update upvalues
  if next(upvalues) then
    local caller = debug.getinfo(lvl)
    local i = 0

    -- On PUC-Lua, if the function is a tail call, then func will be nil.
    -- In this case we need to use the caller.
    while not caller.func do
      i = i + 1
      caller = debug.getinfo(lvl + i)
    end
    set_upvalues(caller.func, upvalues)
  end

  -- `table.maxn` (not `#`) so that trailing/intermediate nil returns survive.
  return unpack(ret, 2, table.maxn(ret))
end

return setmetatable(M, {
  __call = function(_, ...)
    return run(...)
  end,
})