neovim

Neovim text editor
git clone https://git.dasho.dev/neovim.git
Log | Files | Refs | README

session.lua (6858B)


      1 ---
      2 --- Nvim msgpack-RPC protocol session. Manages requests/notifications/responses.
      3 ---
      4 
      5 local uv = vim.uv
      6 local RpcStream = require('test.client.rpc_stream')
      7 
      8 --- Nvim msgpack-RPC protocol session. Manages requests/notifications/responses.
      9 ---
     10 --- @class test.Session
     11 --- @field private _pending_messages string[] Requests/notifications received from the remote end.
     12 --- @field private _rpc_stream test.RpcStream
     13 --- @field private _prepare uv.uv_prepare_t
     14 --- @field private _timer uv.uv_timer_t
     15 --- @field private _is_running boolean true during `Session:run()` scope.
     16 --- @field data table Arbitrary user data.
     17 local Session = {}
     18 Session.__index = Session
     19 if package.loaded['jit'] then
     20  -- luajit pcall is already coroutine safe
     21  Session.safe_pcall = pcall
     22 else
     23  Session.safe_pcall = require 'coxpcall'.pcall
     24 end
     25 
     26 local function resume(co, ...)
     27  local status, result = coroutine.resume(co, ...)
     28 
     29  if coroutine.status(co) == 'dead' then
     30    if not status then
     31      error(result)
     32    end
     33    return
     34  end
     35 
     36  assert(coroutine.status(co) == 'suspended')
     37  result(co)
     38 end
     39 
     40 local function coroutine_exec(func, ...)
     41  local args = { ... }
     42  local on_complete --- @type function?
     43 
     44  if #args > 0 and type(args[#args]) == 'function' then
     45    -- completion callback
     46    on_complete = table.remove(args)
     47  end
     48 
     49  resume(coroutine.create(function()
     50    local status, result, flag = Session.safe_pcall(func, unpack(args))
     51    if on_complete then
     52      coroutine.yield(function()
     53        -- run the completion callback on the main thread
     54        on_complete(status, result, flag)
     55      end)
     56    end
     57  end))
     58 end
     59 
     60 --- Creates a new msgpack-RPC session.
     61 function Session.new(stream)
     62  return setmetatable({
     63    _rpc_stream = RpcStream.new(stream),
     64    _pending_messages = {},
     65    _prepare = uv.new_prepare(),
     66    _timer = uv.new_timer(),
     67    _is_running = false,
     68  }, Session)
     69 end
     70 
     71 --- @param timeout integer?
     72 --- @return string?
     73 function Session:next_message(timeout)
     74  local function on_request(method, args, response)
     75    table.insert(self._pending_messages, { 'request', method, args, response })
     76    uv.stop()
     77  end
     78 
     79  local function on_notification(method, args)
     80    table.insert(self._pending_messages, { 'notification', method, args })
     81    uv.stop()
     82  end
     83 
     84  if self._is_running then
     85    error('Event loop already running')
     86  end
     87 
     88  if #self._pending_messages > 0 then
     89    return table.remove(self._pending_messages, 1)
     90  end
     91 
     92  -- if closed, only return pending messages
     93  if self.closed then
     94    return nil
     95  end
     96 
     97  self:_run(on_request, on_notification, timeout)
     98  return table.remove(self._pending_messages, 1)
     99 end
    100 
    101 --- Sends a notification to the RPC endpoint.
    102 function Session:notify(method, ...)
    103  self._rpc_stream:write(method, { ... })
    104 end
    105 
    106 --- Sends a request to the RPC endpoint.
    107 ---
    108 --- @param method string
    109 --- @param ... any
    110 --- @return boolean, table
    111 function Session:request(method, ...)
    112  local args = { ... }
    113  local err, result
    114  if self._is_running then
    115    err, result = self:_yielding_request(method, args)
    116  else
    117    err, result = self:_blocking_request(method, args)
    118  end
    119 
    120  if err then
    121    return false, err
    122  end
    123 
    124  return true, result
    125 end
    126 
    127 --- Processes incoming RPC requests/notifications until exhausted.
    128 ---
    129 --- TODO(justinmk): luaclient2 avoids this via uvutil.cb_wait() + uvutil.add_idle_call()?
    130 ---
    131 --- @param request_cb function Handles requests from the sever to the local end.
    132 --- @param notification_cb function Handles notifications from the sever to the local end.
    133 --- @param setup_cb function
    134 --- @param timeout number
    135 function Session:run(request_cb, notification_cb, setup_cb, timeout)
    136  --- Handles an incoming request.
    137  local function on_request(method, args, response)
    138    coroutine_exec(request_cb, method, args, function(status, result, flag)
    139      if status then
    140        response:send(result, flag)
    141      else
    142        response:send(result, true)
    143      end
    144    end)
    145  end
    146 
    147  --- Handles an incoming notification.
    148  local function on_notification(method, args)
    149    coroutine_exec(notification_cb, method, args)
    150  end
    151 
    152  self._is_running = true
    153 
    154  if setup_cb then
    155    coroutine_exec(setup_cb)
    156  end
    157 
    158  while #self._pending_messages > 0 do
    159    local msg = table.remove(self._pending_messages, 1)
    160    if msg[1] == 'request' then
    161      on_request(msg[2], msg[3], msg[4])
    162    else
    163      on_notification(msg[2], msg[3])
    164    end
    165  end
    166 
    167  self:_run(on_request, on_notification, timeout)
    168  self._is_running = false
    169 end
    170 
    171 function Session:stop()
    172  uv.stop()
    173 end
    174 
    175 function Session:close(signal, noblock)
    176  if not self._timer:is_closing() then
    177    self._timer:close()
    178  end
    179  if not self._prepare:is_closing() then
    180    self._prepare:close()
    181  end
    182  self._rpc_stream:close(signal, noblock)
    183  self.closed = true
    184 end
    185 
    186 --- Sends a request to the RPC endpoint, without blocking (schedules a coroutine).
    187 function Session:_yielding_request(method, args)
    188  return coroutine.yield(function(co)
    189    self._rpc_stream:write(method, args, function(err, result)
    190      resume(co, err, result)
    191    end)
    192  end)
    193 end
    194 
    195 --- Sends a request to the RPC endpoint, and blocks (polls event loop) until a response is received.
    196 function Session:_blocking_request(method, args)
    197  local err, result
    198 
    199  -- Invoked when a request is received from the remote end.
    200  local function on_request(method_, args_, response)
    201    table.insert(self._pending_messages, { 'request', method_, args_, response })
    202  end
    203 
    204  -- Invoked when a notification is received from the remote end.
    205  local function on_notification(method_, args_)
    206    table.insert(self._pending_messages, { 'notification', method_, args_ })
    207  end
    208 
    209  self._rpc_stream:write(method, args, function(e, r)
    210    err = e
    211    result = r
    212    uv.stop()
    213  end)
    214 
    215  -- Poll for incoming requests/notifications received from the remote end.
    216  self:_run(on_request, on_notification)
    217  return (err or self.eof_err), result
    218 end
    219 
    220 --- Polls for incoming requests/notifications received from the remote end.
    221 function Session:_run(request_cb, notification_cb, timeout)
    222  if type(timeout) == 'number' then
    223    self._prepare:start(function()
    224      self._timer:start(timeout, 0, function()
    225        uv.stop()
    226      end)
    227      self._prepare:stop()
    228    end)
    229  end
    230  self._rpc_stream:read_start(request_cb, notification_cb, function()
    231    uv.stop()
    232 
    233    --- @diagnostic disable-next-line: invisible
    234    local stderr = self._rpc_stream._stream.stderr --[[@as string?]]
    235    -- See if `ProcStream.stderr` has anything useful.
    236    stderr = '' ~= ((stderr or ''):match('^%s*(.*%S)') or '') and ' stderr:\n' .. stderr or ''
    237 
    238    self.eof_err = { 1, 'EOF was received from Nvim. Likely the Nvim process crashed.' .. stderr }
    239  end)
    240  local ret, err, _ = uv.run()
    241  if ret == nil then
    242    error(err)
    243  end
    244  self._prepare:stop()
    245  self._timer:stop()
    246  self._rpc_stream:read_stop()
    247 end
    248 
--- Module export: the Nvim msgpack-RPC `Session` class (create instances with `Session.new(stream)`).
return Session