_async.lua (2659B)
local M = {}

--- Default upper bound, in milliseconds, for a task handle's `wait()`.
local max_timeout = 120000

-- LuaJIT allows yielding across `pcall`; on PUC Lua 5.1 fall back to
-- coxpcall, which provides a coroutine-safe `pcall`.
local copcall = package.loaded.jit and pcall or require('coxpcall').pcall

--- Resume `thread` once and arrange for it to be resumed again later.
---
--- A suspended coroutine is expected to have yielded a function that takes a
--- single callback (see `M.await`). That function is called under `copcall`.
--- When it invokes the callback, the coroutine is resumed with the callback's
--- arguments.
---
--- `on_finish` receives an error message if the coroutine (or the yielded
--- function) raises. Otherwise it receives `nil` followed by all of the
--- coroutine's return values once the coroutine is dead.
---
--- NOTE(review): if the yielded function raises *after* it has invoked its
--- callback (or if `on_finish` itself raises inside that callback), the error
--- is caught by `copcall` and `on_finish` is called a second time.
--- @param thread thread
--- @param on_finish fun(err: string?, ...:any)
--- @param ... any values passed to `coroutine.resume`
local function resume(thread, on_finish, ...)
  --- @type {n: integer, [1]:boolean, [2]:string|function}
  local ret = vim.F.pack_len(coroutine.resume(thread, ...))
  local stat = ret[1]

  if not stat then
    -- Coroutine raised; ret[2] is the error value.
    on_finish(ret[2] --[[@as string]])
  elseif coroutine.status(thread) == 'dead' then
    -- Coroutine returned; forward every return value, including trailing nils.
    on_finish(nil, unpack(ret, 2, ret.n))
  else
    -- Coroutine yielded; ret[2] is the function that will resume it.
    local fn = ret[2]
    --- @cast fn -string

    --- @type boolean, string?
    local ok, err = copcall(fn, function(...)
      resume(thread, on_finish, ...)
    end)

    if not ok then
      on_finish(err)
    end
  end
end

--- Start `func` as a coroutine-based async task.
---
--- The task runs synchronously until its first `M.await`.
--- @param func async fun(): ...:any
--- @param on_finish? fun(err: string?, ...:any) called when the task ends
--- @return table handle with a `wait(timeout?)` method
function M.run(func, on_finish)
  local res --- @type {n:integer, [integer]:any}?
  resume(coroutine.create(func), function(err, ...)
    res = vim.F.pack_len(err, ...)
    if on_finish then
      on_finish(err, ...)
    end
  end)

  return {
    --- Block via `vim.wait` until the task has finished.
    ---
    --- Raises 'timeout' if the task is still running after `timeout` ms.
    --- Re-raises the task's error if it failed.
    --- @param timeout? integer milliseconds, defaults to `max_timeout`
    --- @return any ... return values of `func`
    wait = function(_self, timeout)
      vim.wait(timeout or max_timeout, function()
        return res ~= nil
      end)
      assert(res, 'timeout')
      if res[1] then
        -- Level 0: the message already carries the location where the task
        -- failed; level 1 would prepend this line's location as well.
        error(res[1], 0)
      end
      return unpack(res, 2, res.n)
    end,
  }
end

--- Suspend the calling async function until `fun` invokes its callback.
---
--- `fun` is called with the given arguments, plus a callback placed at
--- position `argc`. The values later passed to that callback become the
--- return values of `await`.
--- @async
--- @param argc integer position of the callback in `fun`'s argument list
--- @param fun function
--- @param ... any func arguments
--- @return any ...
function M.await(argc, fun, ...)
  -- Lua 5.1 returns nil from the main thread; Lua 5.2+-compatible builds
  -- return (main_thread, true). Both mean "not inside an async function".
  local co, is_main = coroutine.running()
  assert(co and not is_main, 'Async.await() must be called from an async function')
  local args = vim.F.pack_len(...) --- @type {n:integer, [integer]:any}

  --- @param callback fun(...:any)
  return coroutine.yield(function(callback)
    args[argc] = assert(callback)
    -- max(argc, n): keep trailing nil arguments and the callback slot even
    -- when the callback is placed beyond the supplied arguments.
    fun(unpack(args, 1, math.max(argc, args.n)))
  end)
end

--- Run every function in `funs` as an async task, with at most `max_jobs`
--- running at once, and wait until all of them have finished.
---
--- Jobs are started in list order. Errors raised by individual jobs are not
--- propagated: the error argument of each job's completion callback is
--- ignored, so jobs must report their own failures.
--- @async
--- @param max_jobs integer maximum number of concurrently running jobs;
---   non-integers are floored and values below 1 are treated as 1
--- @param funs (async fun())[]
function M.join(max_jobs, funs)
  local total = #funs
  if total == 0 then
    return
  end

  -- With fewer than one job slot nothing would ever start, and the await
  -- below would never resume, so clamp concurrency to [1, total].
  max_jobs = math.max(1, math.min(math.floor(max_jobs), total))

  local next_idx = max_jobs + 1 -- index of the next job that has not started
  local to_go = total -- number of jobs that have not finished yet

  M.await(1, function(on_finish)
    local function run_next()
      to_go = to_go - 1
      if to_go == 0 then
        on_finish()
      elseif next_idx <= total then
        local next_fun = funs[next_idx]
        next_idx = next_idx + 1
        M.run(next_fun, run_next)
      end
    end

    for i = 1, max_jobs do
      M.run(funs[i], run_next)
    end
  end)
end

return M