neovim

Neovim text editor
git clone https://git.dasho.dev/neovim.git
Log | Files | Refs | README

exec_lua.lua (3186B)


      1 --- @param f function
      2 --- @return table<string,any>
      3 local function get_upvalues(f)
      4  local i = 1
      5  local upvalues = {} --- @type table<string,any>
      6  while true do
      7    local n, v = debug.getupvalue(f, i)
      8    if not n then
      9      break
     10    end
     11    upvalues[n] = v
     12    i = i + 1
     13  end
     14  return upvalues
     15 end
     16 
     17 --- @param f function
     18 --- @param upvalues table<string,any>
     19 local function set_upvalues(f, upvalues)
     20  local i = 1
     21  while true do
     22    local n = debug.getupvalue(f, i)
     23    if not n then
     24      break
     25    end
     26    if upvalues[n] then
     27      debug.setupvalue(f, i, upvalues[n])
     28    end
     29    i = i + 1
     30  end
     31 end
     32 
     33 --- @param messages string[]
     34 --- @param ... ...
     35 local function add_print(messages, ...)
     36  local msg = {} --- @type string[]
     37  for i = 1, select('#', ...) do
     38    msg[#msg + 1] = tostring(select(i, ...))
     39  end
     40  table.insert(messages, table.concat(msg, '\t'))
     41 end
     42 
     43 local invalid_types = {
     44  ['thread'] = true,
     45  ['function'] = true,
     46  ['userdata'] = true,
     47 }
     48 
     49 --- @param r any[]
     50 local function check_returns(r)
     51  for k, v in pairs(r) do
     52    if invalid_types[type(v)] then
     53      error(
     54        string.format(
     55          "Return index %d with value '%s' of type '%s' cannot be serialized over RPC",
     56          k,
     57          tostring(v),
     58          type(v)
     59        ),
     60        2
     61      )
     62    end
     63  end
     64 end
     65 
--- Module table. `M.handler` is required and executed inside the remote Nvim
--- instance. At the end of the file the table is returned with a `__call`
--- metamethod that forwards to the local `run`.
local M = {}
     67 
     68 --- This is run in the context of the remote Nvim instance.
     69 --- @param bytecode string
     70 --- @param upvalues table<string,any>
     71 --- @param ... any[]
     72 --- @return any[] result
     73 --- @return table<string,any> upvalues
     74 --- @return string[] messages
     75 function M.handler(bytecode, upvalues, ...)
     76  local messages = {} --- @type string[]
     77  local orig_print = _G.print
     78 
     79  function _G.print(...)
     80    add_print(messages, ...)
     81    return orig_print(...)
     82  end
     83 
     84  local f = assert(loadstring(bytecode))
     85 
     86  set_upvalues(f, upvalues)
     87 
     88  -- Run in pcall so we can return any print messages
     89  local ret = { pcall(f, ...) } --- @type any[]
     90 
     91  _G.print = orig_print
     92 
     93  local new_upvalues = get_upvalues(f)
     94 
     95  -- Check return value types for better error messages
     96  check_returns(ret)
     97 
     98  return ret, new_upvalues, messages
     99 end
    100 
--- Execute `code` in the remote Nvim instance reached through `session`.
---
--- `code` is sent as bytecode (`string.dump`) together with its current
--- upvalues and the extra arguments, and is run remotely by `M.handler`.
--- Lines printed remotely are re-printed locally, and an error raised by
--- `code` is re-raised here at level 2. Upvalue values returned by the remote
--- side are written back by name onto the function found at stack level `lvl`.
---
--- @param session test.Session
--- @param lvl integer level passed to `debug.getinfo` from inside `run` (1 = `run` itself); selects the function whose upvalues are updated
--- @param code function
--- @param ... ...
--- @return any ... the values returned by `code`
local function run(session, lvl, code, ...)
  local stat, rv = session:request(
    'nvim_exec_lua',
    [[return { require('test.functional.testnvim.exec_lua').handler(...) }]],
    { string.dump(code), get_upvalues(code), ... }
  )

  -- The request itself failed. Errors raised by `code` do not land here,
  -- because M.handler catches them with pcall and reports them in `ret`.
  -- NOTE(review): rv[2] is assumed to be the error message; confirm against test.Session.
  if not stat then
    error(rv[2], 2)
  end

  --- @type any[], table<string,any>, string[]
  local ret, upvalues, messages = unpack(rv)

  -- Replay locally the lines that `code` printed in the remote instance.
  for _, m in ipairs(messages) do
    print(m)
  end

  -- `ret` is `{ pcall(f, ...) }` from M.handler: ret[1] is the status and,
  -- on failure, ret[2] is the error value.
  if not ret[1] then
    error(ret[2], 2)
  end

  -- Update upvalues
  -- Only upvalues of the function at `lvl` are written; if that function has
  -- the variable as a plain local rather than an upvalue, it is not updated.
  if next(upvalues) then
    local caller = debug.getinfo(lvl)
    local i = 0

    -- On PUC-Lua, if the function is a tail call, then func will be nil.
    -- In this case we need to use the caller.
    -- NOTE(review): if no level at or above `lvl` exists, debug.getinfo
    -- returns nil and indexing `caller` errors.
    while not caller.func do
      i = i + 1
      caller = debug.getinfo(lvl + i)
    end
    set_upvalues(caller.func, upvalues)
  end

  -- Skip the pcall status. table.maxn (rather than #) keeps values that
  -- follow a nil hole.
  return unpack(ret, 2, table.maxn(ret))
end
    143 
    144 return setmetatable(M, {
    145  __call = function(_, ...)
    146    return run(...)
    147  end,
    148 })