neovim

Neovim text editor
git clone https://git.dasho.dev/neovim.git
Log | Files | Refs | README

gen_api_dispatch.lua (28250B)


      1 -- Generates C code to bridge API <=> Lua.
      2 
      3 -- to obtain how the script is invoked, look in build/build.ninja and grep for
      4 -- "gen_api_dispatch.lua"
      5 local hashy = require 'gen.hashy'
      6 local c_grammar = require('gen.c_grammar')
      7 
      8 -- output h file with generated dispatch functions (dispatch_wrappers.generated.h)
      9 local dispatch_outputf = arg[1]
     10 -- output file with exported functions metadata
     11 local exported_funcs_metadata_outputf = arg[2]
     12 -- output mpack file with raw metadata, for use by gen_eval.lua (funcs_metadata.mpack)
     13 local eval_funcs_metadata_outputf = arg[3]
     14 local lua_c_bindings_outputf = arg[4] -- lua_api_c_bindings.generated.c
     15 local keysets_outputf = arg[5] -- keysets_defs.generated.h
     16 local dispatch_deprecated_inputf = arg[6]
     17 local pre_args = 6
     18 assert(#arg >= 6)
     19 
     20 local function real_type(type, exported)
     21  local ptype = c_grammar.typed_container:match(type)
     22  if ptype then
     23    local container = ptype[1]
     24    if container == 'Union' then
     25      return 'Object'
     26    elseif container == 'Tuple' or container == 'ArrayOf' then
     27      return 'Array'
     28    elseif container == 'DictOf' or container == 'DictAs' then
     29      return 'Dict'
     30    elseif container == 'LuaRefOf' then
     31      return 'LuaRef'
     32    elseif container == 'Enum' then
     33      return 'String'
     34    elseif container == 'Dict' then
     35      if exported then
     36        return 'Dict'
     37      end
     38      -- internal type, used for keysets
     39      return 'KeyDict_' .. ptype[2]
     40    end
     41  end
     42  return type
     43 end
     44 
     45 --- @class gen_api_dispatch.Function : nvim.c_grammar.Proto
     46 --- @field method boolean
     47 --- @field receives_array_args? true
     48 --- @field receives_channel_id? true
     49 --- @field can_fail? true
     50 --- @field has_lua_imp? true
     51 --- @field receives_arena? true
     52 --- @field impl_name? string
     53 --- @field remote? boolean
     54 --- @field lua? boolean
     55 --- @field eval? boolean
     56 --- @field handler_id? integer
     57 
     58 --- @type gen_api_dispatch.Function[]
     59 local functions = {}
     60 
     61 --- Names of all headers relative to the source root (for inclusion in the
     62 --- generated file)
     63 --- @type string[]
     64 local headers = {}
     65 
     66 --- Set of function names, used to detect duplicates
     67 --- @type table<string, true>
     68 local function_names = {}
     69 
     70 local startswith = vim.startswith
     71 
     72 --- @param fn gen_api_dispatch.Function
     73 local function add_function(fn)
     74  local public = startswith(fn.name, 'nvim_') or fn.deprecated_since
     75  if public and not fn.noexport then
     76    functions[#functions + 1] = fn
     77    function_names[fn.name] = true
     78    if
     79      #fn.parameters >= 2
     80      and fn.parameters[2][1] == 'Array'
     81      and fn.parameters[2][2] == 'uidata'
     82    then
     83      -- function receives the "args" as a parameter
     84      fn.receives_array_args = true
     85      -- remove the args parameter
     86      table.remove(fn.parameters, 2)
     87    end
     88    if #fn.parameters ~= 0 and fn.parameters[1][2] == 'channel_id' then
     89      -- this function should receive the channel id
     90      fn.receives_channel_id = true
     91      -- remove the parameter since it won't be passed by the api client
     92      table.remove(fn.parameters, 1)
     93    end
     94    if #fn.parameters ~= 0 and fn.parameters[#fn.parameters][1] == 'error' then
     95      -- function can fail if the last parameter type is 'Error'
     96      fn.can_fail = true
     97      -- remove the error parameter, msgpack has it's own special field
     98      -- for specifying errors
     99      fn.parameters[#fn.parameters] = nil
    100    end
    101    if #fn.parameters ~= 0 and fn.parameters[#fn.parameters][1] == 'lstate' then
    102      fn.has_lua_imp = true
    103      fn.parameters[#fn.parameters] = nil
    104    end
    105    if #fn.parameters ~= 0 and fn.parameters[#fn.parameters][1] == 'arena' then
    106      fn.receives_arena = true
    107      fn.parameters[#fn.parameters] = nil
    108    end
    109  end
    110 end
    111 
    112 --- @class gen_api_dispatch.Keyset
    113 --- @field name string
    114 --- @field keys string[]
    115 --- @field c_names table<string, string>
    116 --- @field types table<string, string>
    117 --- @field has_optional boolean
    118 
    119 --- @type gen_api_dispatch.Keyset[]
    120 local keysets = {}
    121 
    122 --- @param val nvim.c_grammar.Keyset
    123 local function add_keyset(val)
    124  local keys = {} --- @type string[]
    125  local types = {} --- @type table<string, string>
    126  local c_names = {} --- @type table<string, string>
    127  local is_set_name = 'is_set__' .. val.keyset_name .. '_'
    128  local has_optional = false
    129  for i, field in ipairs(val.fields) do
    130    local dict_key = field.dict_key or field.name
    131    if field.type ~= 'Object' then
    132      types[dict_key] = field.type
    133    end
    134    if field.name ~= is_set_name and field.type ~= 'OptionalKeys' then
    135      table.insert(keys, dict_key)
    136      if dict_key ~= field.name then
    137        c_names[dict_key] = field.name
    138      end
    139    else
    140      if i > 1 then
    141        error("'is_set__{type}_' must be first if present")
    142      elseif field.name ~= is_set_name then
    143        error(val.keyset_name .. ': name of first key should be ' .. is_set_name)
    144      elseif field.type ~= 'OptionalKeys' then
    145        error("'" .. is_set_name .. "' must have type 'OptionalKeys'")
    146      end
    147      has_optional = true
    148    end
    149  end
    150  keysets[#keysets + 1] = {
    151    name = val.keyset_name,
    152    keys = keys,
    153    c_names = c_names,
    154    types = types,
    155    has_optional = has_optional,
    156  }
    157 end
    158 
    159 -- read each input file, parse and append to the api metadata
    160 for i = pre_args + 1, #arg do
    161  local full_path = arg[i]
    162  local parts = {} --- @type string[]
    163  for part in full_path:gmatch('[^/\\]+') do
    164    parts[#parts + 1] = part
    165  end
    166  headers[#headers + 1] = parts[#parts - 1] .. '/' .. parts[#parts]
    167 
    168  local input = assert(io.open(full_path, 'rb'))
    169 
    170  --- @type string
    171  local text = input:read('*all')
    172  for _, val in ipairs(c_grammar.grammar:match(text)) do
    173    if val.keyset_name then
    174      --- @cast val nvim.c_grammar.Keyset
    175      add_keyset(val)
    176    elseif val.name then
    177      --- @cast val gen_api_dispatch.Function
    178      add_function(val)
    179    end
    180  end
    181 
    182  input:close()
    183 end
    184 
    185 --- @generic T: table
    186 --- @param orig T
    187 --- @return T
    188 local function shallowcopy(orig)
    189  local copy = {}
    190  for orig_key, orig_value in pairs(orig) do
    191    copy[orig_key] = orig_value
    192  end
    193  return copy
    194 end
    195 
    196 --- Export functions under older deprecated names.
    197 --- These will be removed eventually.
    198 --- @type table<string, string>
    199 local deprecated_aliases = loadfile(dispatch_deprecated_inputf)()
    200 
    201 for _, f in ipairs(shallowcopy(functions)) do
    202  local ismethod = false
    203  if startswith(f.name, 'nvim_') then
    204    if startswith(f.name, 'nvim__') or f.name == 'nvim_error_event' then
    205      f.since = -1
    206    elseif f.since == nil then
    207      print('Function ' .. f.name .. ' lacks since field.\n')
    208      os.exit(1)
    209    end
    210    f.since = tonumber(f.since)
    211    if f.deprecated_since ~= nil then
    212      f.deprecated_since = tonumber(f.deprecated_since)
    213    end
    214 
    215    if startswith(f.name, 'nvim_buf_') then
    216      ismethod = true
    217    elseif startswith(f.name, 'nvim_win_') then
    218      ismethod = true
    219    elseif startswith(f.name, 'nvim_tabpage_') then
    220      ismethod = true
    221    end
    222    f.remote = f.remote_only or not f.lua_only
    223    f.lua = f.lua_only or not f.remote_only
    224    f.eval = (not f.lua_only) and not f.remote_only
    225  else
    226    f.deprecated_since = tonumber(f.deprecated_since)
    227    assert(f.deprecated_since == 1)
    228    f.remote = true
    229    f.since = 0
    230  end
    231  f.method = ismethod
    232  local newname = deprecated_aliases[f.name]
    233  if newname ~= nil then
    234    if function_names[newname] then
    235      -- duplicate
    236      print(
    237        'Function '
    238          .. f.name
    239          .. ' has deprecated alias\n'
    240          .. newname
    241          .. ' which has a separate implementation.\n'
    242          .. 'Remove it from src/nvim/api/dispatch_deprecated.lua'
    243      )
    244      os.exit(1)
    245    end
    246    local newf = shallowcopy(f)
    247    newf.name = newname
    248    if newname == 'ui_try_resize' then
    249      -- The return type was incorrectly set to Object in 0.1.5.
    250      -- Keep it that way for clients that rely on this.
    251      newf.return_type = 'Object'
    252    end
    253    newf.impl_name = f.name
    254    newf.lua = false
    255    newf.eval = false
    256    newf.since = 0
    257    newf.deprecated_since = 1
    258    functions[#functions + 1] = newf
    259  end
    260 end
    261 
    262 --- don't expose internal attributes like "impl_name" in public metadata
    263 --- @class gen_api_dispatch.Function.Exported
    264 --- @field name string
    265 --- @field parameters [string, string][]
    266 --- @field return_type string
    267 --- @field method boolean
    268 --- @field since integer
    269 --- @field deprecated_since integer
    270 
    271 --- @type gen_api_dispatch.Function.Exported[]
    272 local exported_functions = {}
    273 
    274 for _, f in ipairs(functions) do
    275  if not (startswith(f.name, 'nvim__') or f.name == 'nvim_error_event' or f.name == 'redraw') then
    276    --- @type gen_api_dispatch.Function.Exported
    277    local f_exported = {
    278      name = f.name,
    279      method = f.method,
    280      since = f.since,
    281      deprecated_since = f.deprecated_since,
    282      parameters = {},
    283      return_type = real_type(f.return_type, true),
    284    }
    285    for i, param in ipairs(f.parameters) do
    286      f_exported.parameters[i] = { real_type(param[1], true), param[2] }
    287    end
    288    exported_functions[#exported_functions + 1] = f_exported
    289  end
    290 end
    291 
    292 local metadata_output = assert(io.open(exported_funcs_metadata_outputf, 'wb'))
    293 metadata_output:write(vim.mpack.encode(exported_functions))
    294 metadata_output:close()
    295 
    296 --- @type integer[]
    297 -- start building the dispatch wrapper output
    298 local output = assert(io.open(dispatch_outputf, 'wb'))
    299 
    300 -- ===========================================================================
    301 -- NEW API FILES MUST GO HERE.
    302 --
    303 --  When creating a new API file, you must include it here,
    304 --  so that the dispatcher can find the C functions that you are creating!
    305 -- ===========================================================================
    306 output:write([[
    307 #include "nvim/errors.h"
    308 #include "nvim/ex_docmd.h"
    309 #include "nvim/ex_getln.h"
    310 #include "nvim/globals.h"
    311 #include "nvim/log.h"
    312 #include "nvim/map_defs.h"
    313 
    314 #include "nvim/api/autocmd.h"
    315 #include "nvim/api/buffer.h"
    316 #include "nvim/api/command.h"
    317 #include "nvim/api/deprecated.h"
    318 #include "nvim/api/events.h"
    319 #include "nvim/api/extmark.h"
    320 #include "nvim/api/options.h"
    321 #include "nvim/api/tabpage.h"
    322 #include "nvim/api/ui.h"
    323 #include "nvim/api/vim.h"
    324 #include "nvim/api/vimscript.h"
    325 #include "nvim/api/win_config.h"
    326 #include "nvim/api/window.h"
    327 #include "nvim/ui_client.h"
    328 
    329 ]])
    330 
    331 local keysets_defs = assert(io.open(keysets_outputf, 'wb'))
    332 
    333 keysets_defs:write('// IWYU pragma: private, include "nvim/api/private/dispatch.h"\n\n')
    334 
    335 for _, k in ipairs(keysets) do
    336  local neworder, hashfun = hashy.hashy_hash(k.name, k.keys, function(idx)
    337    return k.name .. '_table[' .. idx .. '].str'
    338  end)
    339 
    340  keysets_defs:write('extern KeySetLink ' .. k.name .. '_table[' .. (1 + #neworder) .. '];\n')
    341 
    342  local function typename(type)
    343    if type == 'HLGroupID' then
    344      return 'kObjectTypeInteger'
    345    elseif not type or startswith(type, 'Union') then
    346      return 'kObjectTypeNil'
    347    elseif type == 'StringArray' then
    348      return 'kUnpackTypeStringArray'
    349    end
    350    return 'kObjectType' .. real_type(type)
    351  end
    352 
    353  output:write('KeySetLink ' .. k.name .. '_table[] = {\n')
    354  for i, key in ipairs(neworder) do
    355    local ind = -1
    356    if k.has_optional then
    357      ind = i
    358      keysets_defs:write('#define KEYSET_OPTIDX_' .. k.name .. '__' .. key .. ' ' .. ind .. '\n')
    359    end
    360    output:write(
    361      '  {"'
    362        .. key
    363        .. '", offsetof(KeyDict_'
    364        .. k.name
    365        .. ', '
    366        .. (k.c_names[key] or key)
    367        .. '), '
    368        .. typename(k.types[key])
    369        .. ', '
    370        .. ind
    371        .. ', '
    372        .. (k.types[key] == 'HLGroupID' and 'true' or 'false')
    373        .. '},\n'
    374    )
    375  end
    376  output:write('  {NULL, 0, kObjectTypeNil, -1, false},\n')
    377  output:write('};\n\n')
    378 
    379  output:write(hashfun)
    380 
    381  output:write([[
    382 KeySetLink *KeyDict_]] .. k.name .. [[_get_field(const char *str, size_t len)
    383 {
    384  int hash = ]] .. k.name .. [[_hash(str, len);
    385  if (hash == -1) {
    386    return NULL;
    387  }
    388  return &]] .. k.name .. [[_table[hash];
    389 }
    390 
    391 ]])
    392 end
    393 
    394 keysets_defs:close()
    395 
    396 local function attr_name(rt)
    397  if rt == 'Float' then
    398    return 'floating'
    399  else
    400    return rt:lower()
    401  end
    402 end
    403 
    404 -- start the handler functions. Visit each function metadata to build the
    405 -- handler function with code generated for validating arguments and calling to
    406 -- the real API.
    407 for i = 1, #functions do
    408  local fn = functions[i]
    409  if fn.impl_name == nil and fn.remote then
    410    local args = {} --- @type string[]
    411 
    412    output:write(
    413      'Object handle_' .. fn.name .. '(uint64_t channel_id, Array args, Arena* arena, Error *error)'
    414    )
    415    output:write('\n{')
    416    output:write('\n#ifdef NVIM_LOG_DEBUG')
    417    output:write('\n  DLOG("RPC: ch %" PRIu64 ": invoke ' .. fn.name .. '", channel_id);')
    418    output:write('\n#endif')
    419    output:write('\n  Object ret = NIL;')
    420    -- Declare/initialize variables that will hold converted arguments
    421    for j = 1, #fn.parameters do
    422      local param = fn.parameters[j]
    423      local rt = real_type(param[1])
    424      local converted = 'arg_' .. j
    425      output:write('\n  ' .. rt .. ' ' .. converted .. ';')
    426    end
    427    output:write('\n')
    428    if not fn.receives_array_args then
    429      output:write('\n  if (args.size != ' .. #fn.parameters .. ') {')
    430      output:write(
    431        '\n    api_set_error(error, kErrorTypeException, \
    432        "Wrong number of arguments: expecting '
    433          .. #fn.parameters
    434          .. ' but got %zu", args.size);'
    435      )
    436      output:write('\n    goto cleanup;')
    437      output:write('\n  }\n')
    438    end
    439 
    440    -- Validation/conversion for each argument
    441    for j = 1, #fn.parameters do
    442      local converted, param
    443      param = fn.parameters[j]
    444      converted = 'arg_' .. j
    445      local rt = real_type(param[1])
    446      if rt == 'Object' then
    447        output:write('\n  ' .. converted .. ' = args.items[' .. (j - 1) .. '];\n')
    448      elseif rt:match('^KeyDict_') then
    449        converted = '&' .. converted
    450        output:write('\n  if (args.items[' .. (j - 1) .. '].type == kObjectTypeDict) {') --luacheck: ignore 631
    451        output:write('\n    memset(' .. converted .. ', 0, sizeof(*' .. converted .. '));') -- TODO: neeeee
    452        output:write(
    453          '\n    if (!api_dict_to_keydict('
    454            .. converted
    455            .. ', '
    456            .. rt
    457            .. '_get_field, args.items['
    458            .. (j - 1)
    459            .. '].data.dict, error)) {'
    460        )
    461        output:write('\n      goto cleanup;')
    462        output:write('\n    }')
    463        output:write(
    464          '\n  } else if (args.items['
    465            .. (j - 1)
    466            .. '].type == kObjectTypeArray && args.items['
    467            .. (j - 1)
    468            .. '].data.array.size == 0) {'
    469        ) --luacheck: ignore 631
    470        output:write('\n    memset(' .. converted .. ', 0, sizeof(*' .. converted .. '));')
    471 
    472        output:write('\n  } else {')
    473        output:write(
    474          '\n    api_set_error(error, kErrorTypeException, \
    475          "Wrong type for argument '
    476            .. j
    477            .. ' when calling '
    478            .. fn.name
    479            .. ', expecting '
    480            .. param[1]
    481            .. '");'
    482        )
    483        output:write('\n    goto cleanup;')
    484        output:write('\n  }\n')
    485      else
    486        if rt:match('^Buffer$') or rt:match('^Window$') or rt:match('^Tabpage$') then
    487          -- Buffer, Window, and Tabpage have a specific type, but are stored in integer
    488          output:write(
    489            '\n  if (args.items['
    490              .. (j - 1)
    491              .. '].type == kObjectType'
    492              .. rt
    493              .. ' && args.items['
    494              .. (j - 1)
    495              .. '].data.integer >= 0) {'
    496          )
    497          output:write(
    498            '\n    ' .. converted .. ' = (handle_T)args.items[' .. (j - 1) .. '].data.integer;'
    499          )
    500        else
    501          output:write('\n  if (args.items[' .. (j - 1) .. '].type == kObjectType' .. rt .. ') {')
    502          output:write(
    503            '\n    '
    504              .. converted
    505              .. ' = args.items['
    506              .. (j - 1)
    507              .. '].data.'
    508              .. attr_name(rt)
    509              .. ';'
    510          )
    511        end
    512        if
    513          rt:match('^Buffer$')
    514          or rt:match('^Window$')
    515          or rt:match('^Tabpage$')
    516          or rt:match('^Boolean$')
    517        then
    518          -- accept nonnegative integers for Booleans, Buffers, Windows and Tabpages
    519          output:write(
    520            '\n  } else if (args.items['
    521              .. (j - 1)
    522              .. '].type == kObjectTypeInteger && args.items['
    523              .. (j - 1)
    524              .. '].data.integer >= 0) {'
    525          )
    526          output:write(
    527            '\n    ' .. converted .. ' = (handle_T)args.items[' .. (j - 1) .. '].data.integer;'
    528          )
    529        end
    530        if rt:match('^Float$') then
    531          -- accept integers for Floats
    532          output:write('\n  } else if (args.items[' .. (j - 1) .. '].type == kObjectTypeInteger) {')
    533          output:write(
    534            '\n    ' .. converted .. ' = (Float)args.items[' .. (j - 1) .. '].data.integer;'
    535          )
    536        end
    537        -- accept empty lua tables as empty dictionaries
    538        if rt:match('^Dict') then
    539          output:write(
    540            '\n  } else if (args.items['
    541              .. (j - 1)
    542              .. '].type == kObjectTypeArray && args.items['
    543              .. (j - 1)
    544              .. '].data.array.size == 0) {'
    545          ) --luacheck: ignore 631
    546          output:write('\n    ' .. converted .. ' = (Dict)ARRAY_DICT_INIT;')
    547        end
    548        output:write('\n  } else {')
    549        output:write(
    550          '\n    api_set_error(error, kErrorTypeException, \
    551          "Wrong type for argument '
    552            .. j
    553            .. ' when calling '
    554            .. fn.name
    555            .. ', expecting '
    556            .. param[1]
    557            .. '");'
    558        )
    559        output:write('\n    goto cleanup;')
    560        output:write('\n  }\n')
    561      end
    562      args[#args + 1] = converted
    563    end
    564 
    565    if fn.textlock then
    566      output:write('\n  if (text_locked()) {')
    567      output:write('\n    api_set_error(error, kErrorTypeException, "%s", get_text_locked_msg());')
    568      output:write('\n    goto cleanup;')
    569      output:write('\n  }\n')
    570    elseif fn.textlock_allow_cmdwin then
    571      output:write('\n  if (textlock != 0 || expr_map_locked()) {')
    572      output:write('\n    api_set_error(error, kErrorTypeException, "%s", e_textlock);')
    573      output:write('\n    goto cleanup;')
    574      output:write('\n  }\n')
    575    end
    576 
    577    -- function call
    578    output:write('\n  ')
    579    if fn.return_type ~= 'void' then
    580      -- has a return value, prefix the call with a declaration
    581      output:write(fn.return_type .. ' rv = ')
    582    end
    583 
    584    -- write the function name and the opening parenthesis
    585    output:write(fn.name .. '(')
    586 
    587    local call_args = {}
    588    if fn.receives_channel_id then
    589      table.insert(call_args, 'channel_id')
    590    end
    591 
    592    if fn.receives_array_args then
    593      table.insert(call_args, 'args')
    594    end
    595 
    596    for _, a in ipairs(args) do
    597      table.insert(call_args, a)
    598    end
    599 
    600    if fn.receives_arena then
    601      table.insert(call_args, 'arena')
    602    end
    603 
    604    if fn.has_lua_imp then
    605      table.insert(call_args, 'NULL')
    606    end
    607 
    608    if fn.can_fail then
    609      table.insert(call_args, 'error')
    610    end
    611 
    612    output:write(table.concat(call_args, ', '))
    613    output:write(');\n')
    614 
    615    if fn.can_fail then
    616      -- if the function can fail, also pass a pointer to the local error object
    617      -- and check for the error
    618      output:write('\n  if (ERROR_SET(error)) {')
    619      output:write('\n    goto cleanup;')
    620      output:write('\n  }\n')
    621    end
    622 
    623    local ret_type = real_type(fn.return_type)
    624    if ret_type:match('^KeyDict_') then
    625      local table = ret_type:sub(9) .. '_table'
    626      output:write(
    627        '\n  ret = DICT_OBJ(api_keydict_to_dict(&rv, '
    628          .. table
    629          .. ', ARRAY_SIZE('
    630          .. table
    631          .. '), arena));'
    632      )
    633    elseif ret_type ~= 'void' then
    634      output:write('\n  ret = ' .. real_type(fn.return_type):upper() .. '_OBJ(rv);')
    635    end
    636    output:write('\n\ncleanup:')
    637 
    638    output:write('\n  return ret;\n}\n\n')
    639  end
    640 end
    641 
    642 --- @type {[string]: gen_api_dispatch.Function, redraw: {impl_name: string, fast: boolean}}
    643 local remote_fns = {}
    644 for _, fn in ipairs(functions) do
    645  if fn.remote then
    646    remote_fns[fn.name] = fn
    647  end
    648 end
    649 remote_fns.redraw = { impl_name = 'ui_client_redraw', fast = true }
    650 
    651 local names = vim.tbl_keys(remote_fns)
    652 table.sort(names)
    653 local hashorder, hashfun = hashy.hashy_hash('msgpack_rpc_get_handler_for', names, function(idx)
    654  return 'method_handlers[' .. idx .. '].name'
    655 end)
    656 
    657 output:write('const MsgpackRpcRequestHandler method_handlers[] = {\n')
    658 for n, name in ipairs(hashorder) do
    659  local fn = remote_fns[name]
    660  fn.handler_id = n - 1
    661  output:write(
    662    '  { .name = "'
    663      .. name
    664      .. '", .fn = handle_'
    665      .. (fn.impl_name or fn.name)
    666      .. ', .fast = '
    667      .. tostring(fn.fast)
    668      .. ', .ret_alloc = '
    669      .. tostring(not not fn.ret_alloc)
    670      .. '},\n'
    671  )
    672 end
    673 output:write('};\n\n')
    674 output:write(hashfun)
    675 
    676 output:close()
    677 
    678 --- @cast functions {[integer]: gen_api_dispatch.Function, keysets: gen_api_dispatch.Keyset[]}
    679 functions.keysets = keysets
    680 local mpack_output = assert(io.open(eval_funcs_metadata_outputf, 'wb'))
    681 mpack_output:write(vim.mpack.encode(functions))
    682 mpack_output:close()
    683 
    684 local function include_headers(output_handle, headers_to_include)
    685  for i = 1, #headers_to_include do
    686    if headers_to_include[i]:sub(-12) ~= '.generated.h' then
    687      output_handle:write('\n#include "nvim/' .. headers_to_include[i] .. '"')
    688    end
    689  end
    690 end
    691 
    692 --- @param str string
    693 local function write_shifted_output(str, ...)
    694  str = str:gsub('\n  ', '\n')
    695  str = str:gsub('^  ', '')
    696  str = str:gsub(' +$', '')
    697  output:write(str:format(...))
    698 end
    699 
    700 -- start building lua output
    701 output = assert(io.open(lua_c_bindings_outputf, 'wb'))
    702 
    703 include_headers(output, headers)
    704 output:write('\n')
    705 
    706 --- @type {binding: string, api:string}[]
    707 local lua_c_functions = {}
    708 
--- Generate the Lua C binding `nlua_api_<name>` for one API function, and
--- record it in `lua_c_functions`.
---
--- The generated wrapper does the following:
---   * checks the Lua argument count;
---   * pops and converts the arguments from the Lua stack, last parameter
---     first;
---   * calls the C API function and pushes its result (if non-void);
---   * frees the popped arguments.
--- When converting an argument fails, the wrapper jumps to an `exit_N` label
--- that frees only the arguments popped so far. The error is then raised with
--- `lua_error`, prefixed with "Invalid '<param>': " when the failing parameter
--- is known.
---
--- Output goes to the file opened from `lua_c_bindings_outputf` (arg[4],
--- lua_api_c_bindings.generated.c), not to dispatch_wrappers.generated.h.
--- @param fn gen_api_dispatch.Function
local function process_function(fn)
  local lua_c_function_name = ('nlua_api_%s'):format(fn.name)
  -- function header and argument-count check
  write_shifted_output(
    [[

  static int %s(lua_State *lstate)
  {
    Error err = ERROR_INIT;
    Arena arena = ARENA_EMPTY;
    char *err_param = 0;
    if (lua_gettop(lstate) != %i) {
      api_set_error(&err, kErrorTypeValidation, "Expected %i argument%s");
      goto exit_0;
    }
  ]],
    lua_c_function_name,
    #fn.parameters,
    #fn.parameters,
    (#fn.parameters == 1) and '' or 's'
  )
  lua_c_functions[#lua_c_functions + 1] = {
    binding = lua_c_function_name,
    api = fn.name,
  }

  -- non-fast functions raise a Lua error when nlua_is_deferred_safe() is false
  if not fn.fast then
    write_shifted_output(
      [[
    if (!nlua_is_deferred_safe()) {
      return luaL_error(lstate, e_fast_api_disabled, "%s");
    }
    ]],
      fn.name
    )
  end

  if fn.textlock then
    write_shifted_output([[
    if (text_locked()) {
      api_set_error(&err, kErrorTypeException, "%%s", get_text_locked_msg());
      goto exit_0;
    }
    ]])
  elseif fn.textlock_allow_cmdwin then
    write_shifted_output([[
    if (textlock != 0 || expr_map_locked()) {
      api_set_error(&err, kErrorTypeException, "%%s", e_textlock);
      goto exit_0;
    }
    ]])
  end

  -- Pop arguments from the top of the Lua stack, i.e. last parameter first.
  -- `cparams` is built by prepending, so it ends up in declaration order.
  -- `free_code[k]` holds the cleanup for parameter (#fn.parameters - k + 1).
  local cparams = ''
  local free_code = {} --- @type string[]
  for j = #fn.parameters, 1, -1 do
    local param = fn.parameters[j]
    local cparam = string.format('arg%u', j)
    local param_type = real_type(param[1])
    local extra = param_type == 'Dict' and 'false, ' or ''
    local arg_free_code = ''
    if param_type == 'Object' then
      extra = 'true, '
      arg_free_code = '  api_luarefs_free_object(' .. cparam .. ');'
    elseif param[1] == 'DictOf(LuaRef)' then
      extra = 'true, '
      arg_free_code = '  api_luarefs_free_dict(' .. cparam .. ');'
    elseif param[1] == 'LuaRef' then
      arg_free_code = '  api_free_luaref(' .. cparam .. ');'
    end
    local errshift = 0
    local seterr = ''
    if param_type:match('^KeyDict_') then
      write_shifted_output(
        [[
    %s %s = KEYDICT_INIT;
    nlua_pop_keydict(lstate, &%s, %s_get_field, &err_param, &arena, &err);
    ]],
        param_type,
        cparam,
        cparam,
        param_type
      )
      cparam = '&' .. cparam
      errshift = 1 -- free incomplete dict on error
      arg_free_code = '  api_luarefs_free_keydict('
        .. cparam
        .. ', '
        .. param_type:sub(9)
        .. '_table);'
    else
      write_shifted_output(
        [[
    const %s %s = nlua_pop_%s(lstate, %s&arena, &err);]],
        param[1],
        cparam,
        param_type,
        extra
      )
      -- name reported in the "Invalid '<param>': " error prefix
      seterr = '\n      err_param = "' .. param[2] .. '";'
    end

    -- On failure, jump to exit_<#params - j (+1 for keydicts)>. That label
    -- frees the arguments popped before this one and, for keydicts, the
    -- partially filled dict itself.
    write_shifted_output([[

    if (ERROR_SET(&err)) {]] .. seterr .. [[

      goto exit_%u;
    }

    ]], #fn.parameters - j + errshift)
    free_code[#free_code + 1] = arg_free_code
    cparams = cparam .. ', ' .. cparams
  end
  -- implicit arguments that are not popped from the Lua stack
  if fn.receives_channel_id then
    --- @type string
    cparams = 'LUA_INTERNAL_CALL, ' .. cparams
  end
  if fn.receives_arena then
    cparams = cparams .. '&arena, '
  end

  if fn.has_lua_imp then
    cparams = cparams .. 'lstate, '
  end

  if fn.can_fail then
    cparams = cparams .. '&err'
  else
    cparams = cparams:gsub(', $', '')
  end

  write_shifted_output('    ENTER_LUA_ACTIVE_STATE(lstate);\n')
  -- Emit argument cleanup in parameter order 1..n. The cleanup of parameter i
  -- is preceded by label exit_<n - i + 1>, so jumping there frees parameters
  -- i..n. Parameter 1 needs no label unless it is a keydict, because only the
  -- keydict errshift makes a jump target exit_n.
  local free_at_exit_code = ''
  for i = 1, #free_code do
    local rev_i = #free_code - i + 1
    local code = free_code[rev_i]
    if i == 1 and not real_type(fn.parameters[1][1]):match('^KeyDict_') then
      free_at_exit_code = free_at_exit_code .. ('\n%s'):format(code)
    else
      free_at_exit_code = free_at_exit_code .. ('\nexit_%u:\n%s'):format(rev_i, code)
    end
  end
  -- shared tail: release the arena, then raise any pending error as a Lua error
  local err_throw_code = [[

exit_0:
  arena_mem_free(arena_finish(&arena));
  if (ERROR_SET(&err)) {
    luaL_where(lstate, 1);
    if (err_param) {
      lua_pushstring(lstate, "Invalid '");
      lua_pushstring(lstate, err_param);
      lua_pushstring(lstate, "': ");
    }
    lua_pushstring(lstate, err.msg);
    api_clear_error(&err);
    lua_concat(lstate, err_param ? 5 : 2);
    return lua_error(lstate);
  }
]]
  if fn.return_type ~= 'void' then
    local return_type = real_type(fn.return_type)
    local free_retval = ''
    if fn.ret_alloc then
      free_retval = '  api_free_' .. return_type:lower() .. '(ret);'
    end
    write_shifted_output('    %s ret = %s(%s);\n', fn.return_type, fn.name, cparams)

    local ret_type = real_type(fn.return_type)
    local ret_mode = (ret_type == 'Object') and '&' or ''
    if fn.has_lua_imp then
      -- only push onto the Lua stack if we haven't already
      write_shifted_output(
        [[
    if (lua_gettop(lstate) == 0) {
      nlua_push_%s(lstate, %sret, kNluaPushSpecial | kNluaPushFreeRefs);
    }
      ]],
        return_type,
        ret_mode
      )
    elseif ret_type:match('^KeyDict_') then
      write_shifted_output('    nlua_push_keydict(lstate, &ret, %s_table);\n', return_type:sub(9))
    else
      -- functions introduced before API level 11 are pushed with
      -- kNluaPushSpecial (presumably kept for backward compatibility — confirm)
      local special = (fn.since ~= nil and fn.since < 11)
      write_shifted_output(
        '    nlua_push_%s(lstate, %sret, %s | kNluaPushFreeRefs);\n',
        return_type,
        ret_mode,
        special and 'kNluaPushSpecial' or '0'
      )
    end

    -- NOTE: we currently assume err_throw needs nothing from arena
    write_shifted_output(
      [[
    LEAVE_LUA_ACTIVE_STATE();
  %s
  %s
  %s
    return 1;
    ]],
      free_retval,
      free_at_exit_code,
      err_throw_code
    )
  else
    write_shifted_output(
      [[
    %s(%s);
    LEAVE_LUA_ACTIVE_STATE();
  %s
  %s
    return 0;
    ]],
      fn.name,
      cparams,
      free_at_exit_code,
      err_throw_code
    )
  end
  write_shifted_output([[
  }
  ]])
end
    937 
    938 for _, fn in ipairs(functions) do
    939  if fn.lua or fn.name:sub(1, 4) == '_vim' then
    940    process_function(fn)
    941  end
    942 end
    943 
    944 output:write(string.format(
    945  [[
    946 void nlua_add_api_functions(lua_State *lstate)
    947 {
    948  lua_createtable(lstate, 0, %u);
    949 ]],
    950  #lua_c_functions
    951 ))
    952 for _, func in ipairs(lua_c_functions) do
    953  output:write(string.format(
    954    [[
    955 
    956  lua_pushcfunction(lstate, &%s);
    957  lua_setfield(lstate, -2, "%s");]],
    958    func.binding,
    959    func.api
    960  ))
    961 end
    962 output:write([[
    963 
    964  lua_setfield(lstate, -2, "api");
    965 }
    966 ]])
    967 
    968 output:close()