neovim

Neovim text editor
git clone https://git.dasho.dev/neovim.git
Log | Files | Refs | README

coxpcall.lua (3760B)


      1 -------------------------------------------------------------------------------
      2 -- (Not needed for LuaJIT or Lua 5.2+)
      3 --
      4 -- Coroutine safe xpcall and pcall versions
      5 --
      6 -- https://keplerproject.github.io/coxpcall/
      7 --
-- Encapsulates the protected calls in a coroutine-based loop, so errors can
-- be handled without the usual Lua 5.1 pcall/xpcall issues with coroutines
-- yielding inside the call to pcall or xpcall.
     11 --
     12 -- Authors: Roberto Ierusalimschy and Andre Carregal
-- Contributors: Thomas Harning Jr., Ignacio Burgueño, Fabio Mascarenhas
     14 --
     15 -- Copyright 2005 - Kepler Project
     16 --
     17 -- $Id: coxpcall.lua,v 1.13 2008/05/19 19:20:02 mascarenhas Exp $
     18 -------------------------------------------------------------------------------
     19 
     20 -------------------------------------------------------------------------------
     21 -- Checks if (x)pcall function is coroutine safe
     22 -------------------------------------------------------------------------------
     23 local function isCoroutineSafe(func)
     24    local co = coroutine.create(function()
     25        return func(coroutine.yield, function() end)
     26    end)
     27 
     28    coroutine.resume(co)
     29    return coroutine.resume(co)
     30 end
     31 
     32 -- No need to do anything if pcall and xpcall are already safe.
     33 if isCoroutineSafe(pcall) and isCoroutineSafe(xpcall) then
     34    _G.copcall = pcall
     35    _G.coxpcall = xpcall
     36    return { pcall = pcall, xpcall = xpcall, running = coroutine.running }
     37 end
     38 
     39 -------------------------------------------------------------------------------
     40 -- Implements xpcall with coroutines
     41 -------------------------------------------------------------------------------
     42 ---@diagnostic disable-next-line
     43 local performResume
     44 local oldpcall, oldxpcall = pcall, xpcall
     45 local pack = table.pack or function(...) return {n = select("#", ...), ...} end
     46 local unpack = table.unpack or unpack
     47 local running = coroutine.running
     48 --- @type table<thread,thread>
     49 local coromap = setmetatable({}, { __mode = "k" })
     50 
     51 local function handleReturnValue(err, co, status, ...)
     52    if not status then
     53        return false, err(debug.traceback(co, (...)), ...)
     54    end
     55    if coroutine.status(co) == 'suspended' then
     56        return performResume(err, co, coroutine.yield(...))
     57    else
     58        return true, ...
     59    end
     60 end
     61 
     62 function performResume(err, co, ...)
     63    return handleReturnValue(err, co, coroutine.resume(co, ...))
     64 end
     65 
     66 --- @diagnostic disable-next-line: unused-vararg
     67 local function id(trace, ...)
     68    return trace
     69 end
     70 
     71 function _G.coxpcall(f, err, ...)
     72    local current = running()
     73    if not current then
     74        if err == id then
     75            return oldpcall(f, ...)
     76        else
     77            if select("#", ...) > 0 then
     78                local oldf, params = f, pack(...)
     79                f = function() return oldf(unpack(params, 1, params.n)) end
     80            end
     81            return oldxpcall(f, err)
     82        end
     83    else
     84        local res, co = oldpcall(coroutine.create, f)
     85        if not res then
     86            local newf = function(...) return f(...) end
     87            co = coroutine.create(newf)
     88        end
     89        coromap[co] = current
     90        return performResume(err, co, ...)
     91    end
     92 end
     93 
     94 --- @param coro? thread
     95 local function corunning(coro)
     96  if coro ~= nil then
     97    assert(type(coro)=="thread", "Bad argument; expected thread, got: "..type(coro))
     98  else
     99    coro = running()
    100  end
    101  while coromap[coro] do
    102    coro = coromap[coro]
    103  end
    104  if coro == "mainthread" then return nil end
    105  return coro
    106 end
    107 
    108 -------------------------------------------------------------------------------
    109 -- Implements pcall with coroutines
    110 -------------------------------------------------------------------------------
    111 
    112 function _G.copcall(f, ...)
    113    return coxpcall(f, id, ...)
    114 end
    115 
    116 return { pcall = copcall, xpcall = coxpcall, running = corunning }