session.lua (6858B)
---
--- Nvim msgpack-RPC protocol session. Manages requests/notifications/responses.
---

local uv = vim.uv
local RpcStream = require('test.client.rpc_stream')

--- Nvim msgpack-RPC protocol session. Manages requests/notifications/responses.
---
--- @class test.Session
--- @field private _pending_messages table[] Requests/notifications received from the remote end,
---   queued as `{ 'request', method, args, response }` or `{ 'notification', method, args }`.
--- @field private _rpc_stream test.RpcStream
--- @field private _prepare uv.uv_prepare_t Arms `_timer` once the event loop starts iterating.
--- @field private _timer uv.uv_timer_t Stops the event loop when a poll timeout elapses.
--- @field private _is_running boolean true during `Session:run()` scope.
--- @field closed boolean? Set by `Session:close()`; afterwards only already-queued messages are returned.
--- @field eof_err table? `{ 1, msg }` error recorded when the remote end closes the stream.
--- @field data table Arbitrary user data.
local Session = {}
Session.__index = Session
if package.loaded['jit'] then
  -- luajit pcall is already coroutine safe
  Session.safe_pcall = pcall
else
  -- Other runtimes (e.g. PUC Lua 5.1) may not allow yielding across pcall;
  -- coxpcall provides a coroutine-safe replacement.
  Session.safe_pcall = require 'coxpcall'.pcall
end

--- Resumes coroutine `co`.
---
--- If the coroutine finishes, an error it raised is re-raised here. If it
--- yields, the yielded value must be a function; it is called with `co` on the
--- resuming thread so it can arrange for the coroutine to be resumed later.
---
--- @param co thread
--- @param ... any Values passed through to `coroutine.resume`.
local function resume(co, ...)
  local status, result = coroutine.resume(co, ...)

  if coroutine.status(co) == 'dead' then
    if not status then
      error(result)
    end
    return
  end

  assert(coroutine.status(co) == 'suspended')
  result(co)
end

--- Runs `func(...)` in a new coroutine (protected by `Session.safe_pcall`), so
--- that `func` may issue yielding requests.
---
--- If the last argument is a function, it is treated as a completion callback:
--- it is removed from the arguments and later called on the resuming thread as
--- `on_complete(ok, result, flag)` with the results of the protected call.
---
--- @param func function
--- @param ... any Arguments for `func`, optionally followed by the completion callback.
local function coroutine_exec(func, ...)
  -- Count with select('#') instead of `#args`: the length operator is
  -- unspecified when the arguments contain nils, which could hide the
  -- completion callback or silently drop arguments.
  local nargs = select('#', ...)
  local args = { ... }
  local on_complete --- @type function?

  if nargs > 0 and type(args[nargs]) == 'function' then
    -- completion callback
    on_complete = args[nargs]
    args[nargs] = nil
    nargs = nargs - 1
  end

  resume(coroutine.create(function()
    local status, result, flag = Session.safe_pcall(func, unpack(args, 1, nargs))
    if on_complete then
      coroutine.yield(function()
        -- run the completion callback on the main thread
        on_complete(status, result, flag)
      end)
    end
  end))
end

--- Creates a new msgpack-RPC session.
---
--- @param stream any Underlying transport; wrapped via `RpcStream.new(stream)`.
--- @return test.Session
function Session.new(stream)
  return setmetatable({
    _rpc_stream = RpcStream.new(stream),
    _pending_messages = {},
    _prepare = uv.new_prepare(),
    _timer = uv.new_timer(),
    _is_running = false,
  }, Session)
end

--- Returns the next request/notification from the remote end, polling the
--- event loop when none is already queued.
---
--- @param timeout integer? Milliseconds to poll before giving up; when nil, polls until a message or EOF arrives.
--- @return table? msg `{ 'request', method, args, response }` or `{ 'notification', method, args }`;
---   nil if nothing arrived before the timeout/EOF, or the session is closed and nothing is queued.
function Session:next_message(timeout)
  local function on_request(method, args, response)
    table.insert(self._pending_messages, { 'request', method, args, response })
    uv.stop()
  end

  local function on_notification(method, args)
    table.insert(self._pending_messages, { 'notification', method, args })
    uv.stop()
  end

  if self._is_running then
    error('Event loop already running')
  end

  if #self._pending_messages > 0 then
    return table.remove(self._pending_messages, 1)
  end

  -- if closed, only return pending messages
  if self.closed then
    return nil
  end

  self:_run(on_request, on_notification, timeout)
  return table.remove(self._pending_messages, 1)
end

--- Sends a notification to the RPC endpoint.
---
--- @param method string
--- @param ... any Notification arguments (sent as an array).
function Session:notify(method, ...)
  self._rpc_stream:write(method, { ... })
end

--- Sends a request to the RPC endpoint and waits for its response.
---
--- Inside `Session:run()` the calling coroutine yields until the response
--- arrives; otherwise the event loop is polled (blocking) until it does.
---
--- @param method string
--- @param ... any Request arguments (sent as an array).
--- @return boolean ok false when an error was reported (including EOF on the blocking path).
--- @return any result The response on success, otherwise the error.
function Session:request(method, ...)
  local args = { ... }
  local err, result
  if self._is_running then
    err, result = self:_yielding_request(method, args)
  else
    err, result = self:_blocking_request(method, args)
  end

  if err then
    return false, err
  end

  return true, result
end

--- Processes incoming RPC requests/notifications, dispatching each to a
--- callback running in its own coroutine, until the event loop is stopped
--- (`Session:stop()`, EOF from the remote end, or `timeout` elapsing).
--- Messages queued earlier (e.g. by `next_message()` or a blocking request)
--- are dispatched first.
---
--- TODO(justinmk): luaclient2 avoids this via uvutil.cb_wait() + uvutil.add_idle_call()?
---
--- @param request_cb function Handles requests from the server to the local end:
---   called as `request_cb(method, args)`; its `result, flag` are forwarded to
---   `response:send(result, flag)`, and if it raises, the error is sent with flag `true`.
--- @param notification_cb function Handles notifications from the server to the local end.
--- @param setup_cb function? Run once (in a coroutine) before processing starts.
--- @param timeout number? Milliseconds after which the event loop is stopped.
function Session:run(request_cb, notification_cb, setup_cb, timeout)
  --- Handles an incoming request.
  local function on_request(method, args, response)
    coroutine_exec(request_cb, method, args, function(status, result, flag)
      if status then
        response:send(result, flag)
      else
        response:send(result, true)
      end
    end)
  end

  --- Handles an incoming notification.
  local function on_notification(method, args)
    coroutine_exec(notification_cb, method, args)
  end

  self._is_running = true

  local ok, err = pcall(function()
    if setup_cb then
      coroutine_exec(setup_cb)
    end

    -- Dispatch messages that were queued while no handlers were installed.
    while #self._pending_messages > 0 do
      local msg = table.remove(self._pending_messages, 1)
      if msg[1] == 'request' then
        on_request(msg[2], msg[3], msg[4])
      else
        on_notification(msg[2], msg[3])
      end
    end

    self:_run(on_request, on_notification, timeout)
  end)

  -- Leave the running state even when an error escapes. Otherwise every later
  -- next_message() fails with 'Event loop already running' and request() takes
  -- the yielding path outside of any coroutine.
  self._is_running = false
  if not ok then
    error(err, 0) -- re-raise the original error value unchanged
  end
end

--- Stops the event loop, ending the current `run()`/poll.
function Session:stop()
  uv.stop()
end

--- Closes the session's timer and prepare handles (if not already closing) and
--- the underlying RPC stream, then marks the session as closed.
---
--- @param signal any? Forwarded to the RPC stream's `close`.
--- @param noblock any? Forwarded to the RPC stream's `close`.
function Session:close(signal, noblock)
  if not self._timer:is_closing() then
    self._timer:close()
  end
  if not self._prepare:is_closing() then
    self._prepare:close()
  end
  self._rpc_stream:close(signal, noblock)
  self.closed = true
end

--- Sends a request to the RPC endpoint, without blocking (schedules a coroutine).
---
--- Must be called from a coroutine driven by `resume` (as set up by
--- `Session:run()`): it yields a function that `resume` invokes to write the
--- request, and the coroutine is resumed with `(err, result)` once the response
--- arrives.
---
--- @param method string
--- @param args table
--- @return any err
--- @return any result
function Session:_yielding_request(method, args)
  return coroutine.yield(function(co)
    self._rpc_stream:write(method, args, function(err, result)
      resume(co, err, result)
    end)
  end)
end

--- Sends a request to the RPC endpoint, and blocks (polls event loop) until a response is received.
---
--- Requests/notifications arriving meanwhile are only queued in
--- `_pending_messages`, for a later `next_message()` or `run()`.
---
--- @param method string
--- @param args table
--- @return any err The reported error; falls back to `self.eof_err` when none was reported (e.g. after EOF).
--- @return any result
function Session:_blocking_request(method, args)
  local err, result

  -- Invoked when a request is received from the remote end.
  local function on_request(method_, args_, response)
    table.insert(self._pending_messages, { 'request', method_, args_, response })
  end

  -- Invoked when a notification is received from the remote end.
  local function on_notification(method_, args_)
    table.insert(self._pending_messages, { 'notification', method_, args_ })
  end

  self._rpc_stream:write(method, args, function(e, r)
    err = e
    result = r
    uv.stop()
  end)

  -- Poll for incoming requests/notifications received from the remote end.
  self:_run(on_request, on_notification)
  return (err or self.eof_err), result
end

--- Polls the event loop, handing incoming requests/notifications to the given
--- callbacks, until something calls `uv.stop()`. That can be a callback, EOF
--- from the remote end, `Session:stop()`, or the optional timeout.
---
--- @param request_cb function Called as `request_cb(method, args, response)`.
--- @param notification_cb function Called as `notification_cb(method, args)`.
--- @param timeout number? Milliseconds after which polling stops.
function Session:_run(request_cb, notification_cb, timeout)
  if type(timeout) == 'number' then
    -- Arm the timer from a prepare callback, i.e. only once the loop is actually
    -- iterating (presumably so the timeout is measured from the loop's refreshed
    -- clock), then disarm the prepare handle so this happens only once.
    self._prepare:start(function()
      self._timer:start(timeout, 0, function()
        uv.stop()
      end)
      self._prepare:stop()
    end)
  end
  self._rpc_stream:read_start(request_cb, notification_cb, function()
    uv.stop()

    --- @diagnostic disable-next-line: invisible
    local stderr = self._rpc_stream._stream.stderr --[[@as string?]]
    -- See if `ProcStream.stderr` has anything useful.
    stderr = '' ~= ((stderr or ''):match('^%s*(.*%S)') or '') and ' stderr:\n' .. stderr or ''

    self.eof_err = { 1, 'EOF was received from Nvim. Likely the Nvim process crashed.' .. stderr }
  end)

  local ret, err = uv.run()

  -- Tear down before reporting a failure. A still-armed timer/prepare handle or
  -- an active reader would otherwise fire during a later, unrelated loop run.
  self._prepare:stop()
  self._timer:stop()
  self._rpc_stream:read_stop()

  if ret == nil then
    error(err)
  end
end

--- Nvim msgpack-RPC session.
return Session