neovim

Neovim text editor
git clone https://git.dasho.dev/neovim.git
Log | Files | Refs | README

treesitter.c (49514B)


      1 // lua bindings for treesitter.
      2 // NB: this file mostly contains a generic lua interface for treesitter
      3 // trees and nodes, and could be broken out as a reusable lua package
      4 
      5 #include <assert.h>
      6 #include <ctype.h>
      7 #include <lauxlib.h>
      8 #include <limits.h>
      9 #include <lua.h>
     10 #include <stdbool.h>
     11 #include <stdint.h>
     12 #include <stdio.h>
     13 #include <stdlib.h>
     14 #include <string.h>
     15 #include <tree_sitter/api.h>
     16 #include <uv.h>
     17 
     18 #include "nvim/os/time.h"
     19 
     20 #ifdef HAVE_WASMTIME
     21 # include <wasm.h>
     22 
     23 # include "nvim/os/fs.h"
     24 #endif
     25 
     26 #include "nvim/api/private/helpers.h"
     27 #include "nvim/ascii_defs.h"
     28 #include "nvim/buffer_defs.h"
     29 #include "nvim/globals.h"
     30 #include "nvim/lua/treesitter.h"
     31 #include "nvim/macros_defs.h"
     32 #include "nvim/map_defs.h"
     33 #include "nvim/memline.h"
     34 #include "nvim/memory.h"
     35 #include "nvim/pos_defs.h"
     36 #include "nvim/strings.h"
     37 #include "nvim/types_defs.h"
     38 
// Registry keys of the metatables for each userdata type exposed to Lua; e.g.
// tslua_push_parser() fetches TS_META_PARSER from LUA_REGISTRYINDEX.
#define TS_META_PARSER "treesitter_parser"
#define TS_META_TREE "treesitter_tree"
#define TS_META_NODE "treesitter_node"
#define TS_META_QUERY "treesitter_query"
#define TS_META_QUERYCURSOR "treesitter_querycursor"
#define TS_META_QUERYMATCH "treesitter_querymatch"

// Payload of the TSLogger installed by parser_set_logger() and consumed by logger_cb().
typedef struct {
  LuaRef cb;  // Registry reference (luaL_ref) to the Lua log callback.
  lua_State *lstate;  // Lua state the callback is invoked on.
  bool lex;  // Forward TSLogTypeLex messages.
  bool parse;  // Forward TSLogTypeParse messages.
} TSLuaLoggerOpts;

// Payload of tree userdata. The tree is owned by the Lua GC and deleted in tree_gc().
typedef struct {
  // TSNodes, TSQueryCursors, etc. are derived from the TSTree, so it must not be mutated.
  const TSTree *tree;
} TSLuaTree;

// Payload for on_parser_progress(): parsing is cancelled once os_hrtime() has advanced by
// at least timeout_threshold_ns since parse_start_time.
typedef struct {
  uint64_t parse_start_time;
  uint64_t timeout_threshold_ns;
} TSLuaParserCallbackPayload;

// Generated declarations for this file. NOTE(review): presumably the prototypes of the
// static functions, which are referenced before their definitions (e.g. add_language() and
// the luaL_Reg tables); kept after the typedefs above since those prototypes may use them.
#include "lua/treesitter.c.generated.h"

// Registered languages, keyed by language name. Keys are owned by the map (xstrdup'd in
// add_language(), freed in tslua_remove_lang()).
static PMap(cstr_t) langs = MAP_INIT;

#ifdef HAVE_WASMTIME
// WASM engine and store, created lazily by load_language_from_wasm() and shared by every
// WASM language and by parsers for WASM languages (see tslua_push_parser()).
static wasm_engine_t *wasmengine;
static TSWasmStore *ts_wasmstore;
#endif
     71 
     72 // TSLanguage
     73 
     74 static int tslua_has_language(lua_State *L)
     75 {
     76  const char *lang_name = luaL_checkstring(L, 1);
     77  lua_pushboolean(L, map_has(cstr_t, &langs, lang_name));
     78  return 1;
     79 }
     80 
     81 #ifdef HAVE_WASMTIME
     82 static char *read_file(const char *path, size_t *len)
     83  FUNC_ATTR_MALLOC
     84 {
     85  FILE *file = os_fopen(path, "r");
     86  if (file == NULL) {
     87    return NULL;
     88  }
     89  fseek(file, 0L, SEEK_END);
     90  *len = (size_t)ftell(file);
     91  fseek(file, 0L, SEEK_SET);
     92  char *data = xmalloc(*len);
     93  if (fread(data, *len, 1, file) != 1) {
     94    xfree(data);
     95    fclose(file);
     96    return NULL;
     97  }
     98  fclose(file);
     99  return data;
    100 }
    101 
    102 static const char *wasmerr_to_str(TSWasmErrorKind werr)
    103 {
    104  switch (werr) {
    105  case TSWasmErrorKindParse:
    106    return "PARSE";
    107  case TSWasmErrorKindCompile:
    108    return "COMPILE";
    109  case TSWasmErrorKindInstantiate:
    110    return "INSTANTIATE";
    111  case TSWasmErrorKindAllocate:
    112    return "ALLOCATE";
    113  default:
    114    return "UNKNOWN";
    115  }
    116 }
    117 #endif
    118 
    119 #ifdef HAVE_WASMTIME
    120 static int tslua_add_language_from_wasm(lua_State *L)
    121 {
    122  return add_language(L, true);
    123 }
    124 #endif
    125 
    126 // Creates the language into the internal language map.
    127 //
    128 // Returns true if the language is correctly loaded in the language map
    129 static int tslua_add_language_from_object(lua_State *L)
    130 {
    131  return add_language(L, false);
    132 }
    133 
    134 static const TSLanguage *load_language_from_object(lua_State *L, const char *path,
    135                                                   const char *lang_name, const char *symbol)
    136 {
    137  uv_lib_t lib;
    138  if (uv_dlopen(path, &lib)) {
    139    xstrlcpy(IObuff, uv_dlerror(&lib), sizeof(IObuff));
    140    uv_dlclose(&lib);
    141    luaL_error(L, "Failed to load parser for language '%s': uv_dlopen: %s", lang_name, IObuff);
    142  }
    143 
    144  char symbol_buf[128];
    145  snprintf(symbol_buf, sizeof(symbol_buf), "tree_sitter_%s", symbol);
    146 
    147  TSLanguage *(*lang_parser)(void);
    148  if (uv_dlsym(&lib, symbol_buf, (void **)&lang_parser)) {
    149    xstrlcpy(IObuff, uv_dlerror(&lib), sizeof(IObuff));
    150    uv_dlclose(&lib);
    151    luaL_error(L, "Failed to load parser: uv_dlsym: %s", IObuff);
    152  }
    153 
    154  TSLanguage *lang = lang_parser();
    155 
    156  if (lang == NULL) {
    157    uv_dlclose(&lib);
    158    luaL_error(L, "Failed to load parser %s: internal error", path);
    159  }
    160 
    161  return lang;
    162 }
    163 
    164 static const TSLanguage *load_language_from_wasm(lua_State *L, const char *path,
    165                                                 const char *lang_name)
    166 {
    167 #ifndef HAVE_WASMTIME
    168  luaL_error(L, "Not supported");
    169  return NULL;
    170 #else
    171  if (wasmengine == NULL) {
    172    wasmengine = wasm_engine_new();
    173  }
    174  assert(wasmengine != NULL);
    175 
    176  TSWasmError werr = { 0 };
    177  if (ts_wasmstore == NULL) {
    178    ts_wasmstore = ts_wasm_store_new(wasmengine, &werr);
    179  }
    180 
    181  if (werr.kind > 0) {
    182    luaL_error(L, "Failed to create WASM store: (%s) %s", wasmerr_to_str(werr.kind), werr.message);
    183  }
    184 
    185  size_t file_size = 0;
    186  char *data = read_file(path, &file_size);
    187 
    188  if (data == NULL) {
    189    luaL_error(L, "Unable to read file", path);
    190  }
    191 
    192  const TSLanguage *lang = ts_wasm_store_load_language(ts_wasmstore, lang_name, data,
    193                                                       (uint32_t)file_size, &werr);
    194 
    195  xfree(data);
    196 
    197  if (werr.kind > 0) {
    198    luaL_error(L, "Failed to load WASM parser %s: (%s) %s", path, wasmerr_to_str(werr.kind),
    199               werr.message);
    200  }
    201 
    202  if (lang == NULL) {
    203    luaL_error(L, "Failed to load parser %s: internal error", path);
    204  }
    205 
    206  return lang;
    207 #endif
    208 }
    209 
    210 static int add_language(lua_State *L, bool is_wasm)
    211 {
    212  const char *path = luaL_checkstring(L, 1);
    213  const char *lang_name = luaL_checkstring(L, 2);
    214  const char *symbol_name = lang_name;
    215 
    216  if (!is_wasm && lua_gettop(L) >= 3 && !lua_isnil(L, 3)) {
    217    symbol_name = luaL_checkstring(L, 3);
    218  }
    219 
    220  if (map_has(cstr_t, &langs, lang_name)) {
    221    lua_pushboolean(L, true);
    222    return 1;
    223  }
    224 
    225  const TSLanguage *lang = is_wasm
    226                           ? load_language_from_wasm(L, path, lang_name)
    227                           : load_language_from_object(L, path, lang_name, symbol_name);
    228 
    229  uint32_t lang_version = ts_language_abi_version(lang);
    230  if (lang_version < TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION
    231      || lang_version > TREE_SITTER_LANGUAGE_VERSION) {
    232    return luaL_error(L,
    233                      "ABI version mismatch for %s: supported between %d and %d, found %d",
    234                      path,
    235                      TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION,
    236                      TREE_SITTER_LANGUAGE_VERSION, lang_version);
    237  }
    238 
    239  pmap_put(cstr_t)(&langs, xstrdup(lang_name), (TSLanguage *)lang);
    240 
    241  lua_pushboolean(L, true);
    242  return 1;
    243 }
    244 
    245 static int tslua_remove_lang(lua_State *L)
    246 {
    247  const char *lang_name = luaL_checkstring(L, 1);
    248  bool present = map_has(cstr_t, &langs, lang_name);
    249  if (present) {
    250    cstr_t key;
    251    pmap_del(cstr_t)(&langs, lang_name, &key);
    252    xfree((void *)key);
    253  }
    254  lua_pushboolean(L, present);
    255  return 1;
    256 }
    257 
    258 static TSLanguage *lang_check(lua_State *L, int index)
    259 {
    260  const char *lang_name = luaL_checkstring(L, index);
    261  TSLanguage *lang = pmap_get(cstr_t)(&langs, lang_name);
    262  if (!lang) {
    263    luaL_error(L, "no such language: %s", lang_name);
    264  }
    265  return lang;
    266 }
    267 
    268 static int tslua_inspect_lang(lua_State *L)
    269 {
    270  TSLanguage *lang = lang_check(L, 1);
    271 
    272  lua_createtable(L, 0, 2);  // [retval]
    273 
    274  {  // Symbols
    275    uint32_t nsymbols = ts_language_symbol_count(lang);
    276    assert(nsymbols < INT_MAX);
    277 
    278    lua_createtable(L, (int)(nsymbols - 1), 1);  // [retval, symbols]
    279    for (uint32_t i = 0; i < nsymbols; i++) {
    280      TSSymbolType t = ts_language_symbol_type(lang, (TSSymbol)i);
    281      if (t == TSSymbolTypeAuxiliary) {
    282        // not used by the API
    283        continue;
    284      }
    285      const char *name = ts_language_symbol_name(lang, (TSSymbol)i);
    286      bool named = t != TSSymbolTypeAnonymous;
    287      lua_pushboolean(L, named);  // [retval, symbols, is_named]
    288      if (!named) {
    289        char buf[256];
    290        snprintf(buf, sizeof(buf), "\"%s\"", name);
    291        lua_setfield(L, -2, buf);  // [retval, symbols]
    292      } else {
    293        lua_setfield(L, -2, name);  // [retval, symbols]
    294      }
    295    }
    296 
    297    lua_setfield(L, -2, "symbols");  // [retval]
    298  }
    299 
    300  {  // Fields
    301    uint32_t nfields = ts_language_field_count(lang);
    302    lua_createtable(L, (int)nfields, 1);  // [retval, fields]
    303    // Field IDs go from 1 to nfields inclusive (extra index 0 maps to NULL)
    304    for (uint32_t i = 1; i <= nfields; i++) {
    305      lua_pushstring(L, ts_language_field_name_for_id(lang, (TSFieldId)i));
    306      lua_rawseti(L, -2, (int)i);  // [retval, fields]
    307    }
    308 
    309    lua_setfield(L, -2, "fields");  // [retval]
    310  }
    311 
    312  lua_pushboolean(L, ts_language_is_wasm(lang));
    313  lua_setfield(L, -2, "_wasm");
    314 
    315  lua_pushinteger(L, ts_language_abi_version(lang));  // [retval, version]
    316  lua_setfield(L, -2, "abi_version");
    317 
    318  {  // Metadata
    319    const TSLanguageMetadata *meta = ts_language_metadata(lang);
    320 
    321    if (meta != NULL) {
    322      lua_createtable(L, 0, 3);
    323 
    324      lua_pushinteger(L, meta->major_version);
    325      lua_setfield(L, -2, "major_version");
    326      lua_pushinteger(L, meta->minor_version);
    327      lua_setfield(L, -2, "minor_version");
    328      lua_pushinteger(L, meta->patch_version);
    329      lua_setfield(L, -2, "patch_version");
    330 
    331      lua_setfield(L, -2, "metadata");
    332    }
    333  }
    334 
    335  lua_pushinteger(L, ts_language_state_count(lang));
    336  lua_setfield(L, -2, "state_count");
    337 
    338  {  // Supertypes
    339    uint32_t nsupertypes;
    340    const TSSymbol *supertypes = ts_language_supertypes(lang, &nsupertypes);
    341 
    342    lua_createtable(L, 0, (int)nsupertypes);  // [retval, supertypes]
    343    for (uint32_t i = 0; i < nsupertypes; i++) {
    344      const TSSymbol supertype = *(supertypes + i);
    345 
    346      uint32_t nsubtypes;
    347      const TSSymbol *subtypes = ts_language_subtypes(lang, supertype, &nsubtypes);
    348 
    349      lua_createtable(L, (int)nsubtypes, 0);
    350      for (uint32_t j = 1; j <= nsubtypes; j++) {
    351        lua_pushstring(L, ts_language_symbol_name(lang, *(subtypes + j)));
    352        lua_rawseti(L, -2, (int)j);
    353      }
    354 
    355      lua_setfield(L, -2, ts_language_symbol_name(lang, supertype));
    356    }
    357 
    358    lua_setfield(L, -2, "supertypes");  // [retval]
    359  }
    360 
    361  return 1;
    362 }
    363 
    364 // TSParser
    365 
    366 static struct luaL_Reg parser_meta[] = {
    367  { "__gc", parser_gc },
    368  { "__tostring", parser_tostring },
    369  { "parse", parser_parse },
    370  { "reset", parser_reset },
    371  { "set_included_ranges", parser_set_ranges },
    372  { "included_ranges", parser_get_ranges },
    373  { "_set_logger", parser_set_logger },
    374  { "_logger", parser_get_logger },
    375  { NULL, NULL }
    376 };
    377 
    378 static int tslua_push_parser(lua_State *L)
    379 {
    380  TSLanguage *lang = lang_check(L, 1);
    381 
    382  TSParser **parser = lua_newuserdata(L, sizeof(TSParser *));
    383  *parser = ts_parser_new();
    384 
    385 #ifdef HAVE_WASMTIME
    386  if (ts_language_is_wasm(lang)) {
    387    assert(wasmengine != NULL);
    388    ts_parser_set_wasm_store(*parser, ts_wasmstore);
    389  }
    390 #endif
    391 
    392  if (!ts_parser_set_language(*parser, lang)) {
    393    ts_parser_delete(*parser);
    394    const char *lang_name = luaL_checkstring(L, 1);
    395    return luaL_error(L, "Failed to load language : %s", lang_name);
    396  }
    397 
    398  lua_getfield(L, LUA_REGISTRYINDEX, TS_META_PARSER);  // [udata, meta]
    399  lua_setmetatable(L, -2);  // [udata]
    400  return 1;
    401 }
    402 
    403 static TSParser *parser_check(lua_State *L, uint16_t index)
    404 {
    405  TSParser **ud = luaL_checkudata(L, index, TS_META_PARSER);
    406  luaL_argcheck(L, *ud, index, "TSParser expected");
    407  return *ud;
    408 }
    409 
    410 static void logger_gc(TSLogger logger)
    411 {
    412  if (!logger.log) {
    413    return;
    414  }
    415 
    416  TSLuaLoggerOpts *opts = (TSLuaLoggerOpts *)logger.payload;
    417  luaL_unref(opts->lstate, LUA_REGISTRYINDEX, opts->cb);
    418  xfree(opts);
    419 }
    420 
    421 static int parser_gc(lua_State *L)
    422 {
    423  TSParser *p = parser_check(L, 1);
    424  logger_gc(ts_parser_logger(p));
    425  ts_parser_delete(p);
    426  return 0;
    427 }
    428 
    429 static int parser_tostring(lua_State *L)
    430 {
    431  lua_pushstring(L, "<parser>");
    432  return 1;
    433 }
    434 
/// TSInput read callback for parsing a buffer (see parser_parse()).
///
/// Returns up to BUFSIZE (256) bytes of line `position.row` (0-based) starting at byte
/// `position.column`, followed by a '\n' when the whole rest of the line fits and the line
/// is meant to end with a newline. When a line does not fit, only the first chunk is
/// returned and the caller asks again with an advanced column.
///
/// @param payload     The buf_T being parsed.
/// @param byte_index  Unused; the position is taken from `position` instead.
/// @param position    Row/column to read from.
/// @param[out] bytes_read  Number of bytes returned; 0 past the end of the buffer/line.
/// @return Pointer to a static buffer (overwritten by the next call, so not reentrant), or
///         "" when nothing is left to read at `position`.
static const char *input_cb(void *payload, uint32_t byte_index, TSPoint position,
                            uint32_t *bytes_read)
{
  buf_T *bp = payload;
#define BUFSIZE 256
  static char buf[BUFSIZE];

  // Past the last line: signal end of input.
  if ((linenr_T)position.row >= bp->b_ml.ml_line_count) {
    *bytes_read = 0;
    return "";
  }
  linenr_T lnum = (linenr_T)position.row + 1;  // tree-sitter rows are 0-based, lnums 1-based
  char *line = ml_get_buf(bp, lnum);
  size_t len = (size_t)ml_get_buf_len(bp, lnum);
  if (position.column > len) {
    *bytes_read = 0;
    return "";
  }
  size_t tocopy = MIN(len - position.column, BUFSIZE);

  memcpy(buf, line + position.column, tocopy);
  // Translate embedded \n to NUL
  // NOTE(review): presumably because Vim stores NUL bytes of a line as '\n' in memory.
  memchrsub(buf, '\n', NUL, tocopy);
  *bytes_read = (uint32_t)tocopy;
  if (tocopy < BUFSIZE) {
    // now add the final \n, if it is meant to be present for this buffer. If it didn't fit,
    // input_cb will be called again on the same line with advanced column.
    // Every line but the last gets one; the last line gets one depending on the
    // 'binary', 'fixeol' and 'eol' options (and b_no_eol_lnum).
    if (lnum != bp->b_ml.ml_line_count || (!bp->b_p_bin && bp->b_p_fixeol)
        || (lnum != bp->b_no_eol_lnum && bp->b_p_eol)) {
      buf[tocopy] = '\n';
      (*bytes_read)++;
    }
  }
  return buf;
#undef BUFSIZE
}
    471 
    472 static void push_ranges(lua_State *L, const TSRange *ranges, const size_t length,
    473                        bool include_bytes)
    474 {
    475  lua_createtable(L, (int)length, 0);
    476  for (size_t i = 0; i < length; i++) {
    477    lua_createtable(L, include_bytes ? 6 : 4, 0);
    478    int j = 1;
    479    lua_pushinteger(L, ranges[i].start_point.row);
    480    lua_rawseti(L, -2, j++);
    481    lua_pushinteger(L, ranges[i].start_point.column);
    482    lua_rawseti(L, -2, j++);
    483    if (include_bytes) {
    484      lua_pushinteger(L, ranges[i].start_byte);
    485      lua_rawseti(L, -2, j++);
    486    }
    487    lua_pushinteger(L, ranges[i].end_point.row);
    488    lua_rawseti(L, -2, j++);
    489    lua_pushinteger(L, ranges[i].end_point.column);
    490    lua_rawseti(L, -2, j++);
    491    if (include_bytes) {
    492      lua_pushinteger(L, ranges[i].end_byte);
    493      lua_rawseti(L, -2, j++);
    494    }
    495 
    496    lua_rawseti(L, -2, (int)(i + 1));
    497  }
    498 }
    499 
    500 static bool on_parser_progress(TSParseState *state)
    501 {
    502  TSLuaParserCallbackPayload *payload = state->payload;
    503  uint64_t parse_time = os_hrtime() - payload->parse_start_time;
    504  return parse_time >= payload->timeout_threshold_ns;
    505 }
    506 
    507 static int parser_parse(lua_State *L)
    508 {
    509  TSParser *p = parser_check(L, 1);
    510  const TSTree *old_tree = NULL;
    511  if (!lua_isnil(L, 2)) {
    512    TSLuaTree *ud = luaL_checkudata(L, 2, TS_META_TREE);
    513    old_tree = ud ? ud->tree : NULL;
    514  }
    515 
    516  TSTree *new_tree = NULL;
    517  size_t len;
    518  const char *str;
    519  handle_T bufnr;
    520  buf_T *buf;
    521  TSInput input;
    522 
    523  // This switch is necessary because of the behavior of lua_isstring, that
    524  // consider numbers as strings...
    525  switch (lua_type(L, 3)) {
    526  case LUA_TSTRING:
    527    str = lua_tolstring(L, 3, &len);
    528    new_tree = ts_parser_parse_string(p, old_tree, str, (uint32_t)len);
    529    break;
    530 
    531  case LUA_TNUMBER:
    532    bufnr = (handle_T)lua_tointeger(L, 3);
    533    buf = handle_get_buffer(bufnr);
    534 
    535    if (!buf) {
    536 #define BUFSIZE 256
    537      char ebuf[BUFSIZE] = { 0 };
    538      vim_snprintf(ebuf, BUFSIZE, "invalid buffer handle: %d", bufnr);
    539      return luaL_argerror(L, 3, ebuf);
    540 #undef BUFSIZE
    541    }
    542 
    543    input = (TSInput){ (void *)buf, input_cb, TSInputEncodingUTF8, NULL };
    544    if (!lua_isnil(L, 5)) {
    545      uint64_t timeout_ns = (uint64_t)lua_tointeger(L, 5);
    546      TSLuaParserCallbackPayload payload =
    547        (TSLuaParserCallbackPayload){ .parse_start_time = os_hrtime(),
    548                                      .timeout_threshold_ns = timeout_ns };
    549      TSParseOptions parse_options = { .payload = &payload,
    550                                       .progress_callback = on_parser_progress };
    551      new_tree = ts_parser_parse_with_options(p, old_tree, input, parse_options);
    552    } else {
    553      new_tree = ts_parser_parse(p, old_tree, input);
    554    }
    555 
    556    break;
    557 
    558  default:
    559    return luaL_argerror(L, 3, "expected either string or buffer handle");
    560  }
    561 
    562  bool include_bytes = (lua_gettop(L) >= 4) && lua_toboolean(L, 4);
    563 
    564  if (!new_tree) {
    565    // Sometimes parsing fails (no language was set, or it was set to one with an incompatible ABI)
    566    // In those cases, just return an error.
    567    if (!ts_parser_language(p)) {
    568      return luaL_error(L, "Language was unset, or has an incompatible ABI.");
    569    }
    570    return 0;
    571  }
    572 
    573  // The new tree will be pushed to the stack, without copy, ownership is now to the lua GC.
    574  // Old tree is owned by lua GC since before
    575  uint32_t n_ranges = 0;
    576  TSRange *changed = old_tree ? ts_tree_get_changed_ranges(old_tree, new_tree, &n_ranges)
    577                              : ts_tree_included_ranges(new_tree, &n_ranges);
    578 
    579  push_tree(L, new_tree);  // [tree]
    580 
    581  push_ranges(L, changed, n_ranges, include_bytes);  // [tree, ranges]
    582 
    583  xfree(changed);
    584  return 2;
    585 }
    586 
    587 static int parser_reset(lua_State *L)
    588 {
    589  TSParser *p = parser_check(L, 1);
    590  ts_parser_reset(p);
    591  return 0;
    592 }
    593 
    594 static void range_err(lua_State *L)
    595 {
    596  luaL_error(L, "Ranges can only be made from 6 element long tables or nodes.");
    597 }
    598 
    599 // Use the top of the stack (without popping it) to create a TSRange, it can be
    600 // either a lua table or a TSNode
    601 static void range_from_lua(lua_State *L, TSRange *range)
    602 {
    603  TSNode node;
    604 
    605  if (lua_istable(L, -1)) {
    606    // should be a table of 6 elements
    607    if (lua_objlen(L, -1) != 6) {
    608      range_err(L);
    609    }
    610 
    611    lua_rawgeti(L, -1, 1);  // [ range, start_row]
    612    uint32_t start_row = (uint32_t)luaL_checkinteger(L, -1);
    613    lua_pop(L, 1);
    614 
    615    lua_rawgeti(L, -1, 2);  // [ range, start_col]
    616    uint32_t start_col = (uint32_t)luaL_checkinteger(L, -1);
    617    lua_pop(L, 1);
    618 
    619    lua_rawgeti(L, -1, 3);  // [ range, start_byte]
    620    uint32_t start_byte = (uint32_t)luaL_checkinteger(L, -1);
    621    lua_pop(L, 1);
    622 
    623    lua_rawgeti(L, -1, 4);  // [ range, end_row]
    624    uint32_t end_row = (uint32_t)luaL_checkinteger(L, -1);
    625    lua_pop(L, 1);
    626 
    627    lua_rawgeti(L, -1, 5);  // [ range, end_col]
    628    uint32_t end_col = (uint32_t)luaL_checkinteger(L, -1);
    629    lua_pop(L, 1);
    630 
    631    lua_rawgeti(L, -1, 6);  // [ range, end_byte]
    632    uint32_t end_byte = (uint32_t)luaL_checkinteger(L, -1);
    633    lua_pop(L, 1);  // [ range ]
    634 
    635    *range = (TSRange) {
    636      .start_point = (TSPoint) {
    637        .row = start_row,
    638        .column = start_col
    639      },
    640      .end_point = (TSPoint) {
    641        .row = end_row,
    642        .column = end_col
    643      },
    644      .start_byte = start_byte,
    645      .end_byte = end_byte,
    646    };
    647  } else if (node_check_opt(L, -1, &node)) {
    648    *range = (TSRange) {
    649      .start_point = ts_node_start_point(node),
    650      .end_point = ts_node_end_point(node),
    651      .start_byte = ts_node_start_byte(node),
    652      .end_byte = ts_node_end_byte(node)
    653    };
    654  } else {
    655    range_err(L);
    656  }
    657 }
    658 
    659 static int parser_set_ranges(lua_State *L)
    660 {
    661  if (lua_gettop(L) < 2) {
    662    return luaL_error(L, "not enough args to parser:set_included_ranges()");
    663  }
    664 
    665  TSParser *p = parser_check(L, 1);
    666 
    667  luaL_argcheck(L, lua_istable(L, 2), 2, "table expected.");
    668 
    669  size_t tbl_len = lua_objlen(L, 2);
    670  TSRange *ranges = xmalloc(sizeof(TSRange) * tbl_len);
    671 
    672  // [ parser, ranges ]
    673  for (size_t index = 0; index < tbl_len; index++) {
    674    lua_rawgeti(L, 2, (int)index + 1);  // [ parser, ranges, range ]
    675    range_from_lua(L, ranges + index);
    676    lua_pop(L, 1);
    677  }
    678 
    679  // This memcpies ranges, thus we can free it afterwards
    680  ts_parser_set_included_ranges(p, ranges, (uint32_t)tbl_len);
    681  xfree(ranges);
    682 
    683  return 0;
    684 }
    685 
    686 static int parser_get_ranges(lua_State *L)
    687 {
    688  TSParser *p = parser_check(L, 1);
    689 
    690  bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2);
    691 
    692  uint32_t len;
    693  const TSRange *ranges = ts_parser_included_ranges(p, &len);
    694 
    695  push_ranges(L, ranges, len, include_bytes);
    696  return 1;
    697 }
    698 
    699 static void logger_cb(void *payload, TSLogType logtype, const char *s)
    700 {
    701  TSLuaLoggerOpts *opts = (TSLuaLoggerOpts *)payload;
    702  if ((!opts->lex && logtype == TSLogTypeLex)
    703      || (!opts->parse && logtype == TSLogTypeParse)) {
    704    return;
    705  }
    706 
    707  lua_State *lstate = opts->lstate;
    708 
    709  lua_rawgeti(lstate, LUA_REGISTRYINDEX, opts->cb);
    710  lua_pushstring(lstate, logtype == TSLogTypeParse ? "parse" : "lex");
    711  lua_pushstring(lstate, s);
    712  if (lua_pcall(lstate, 2, 0, 0)) {
    713    luaL_error(lstate, "treesitter logger callback failed");
    714  }
    715 }
    716 
    717 static int parser_set_logger(lua_State *L)
    718 {
    719  TSParser *p = parser_check(L, 1);
    720 
    721  luaL_argcheck(L, lua_isboolean(L, 2), 2, "boolean expected");
    722  luaL_argcheck(L, lua_isboolean(L, 3), 3, "boolean expected");
    723  luaL_argcheck(L, lua_isfunction(L, 4), 4, "function expected");
    724 
    725  TSLuaLoggerOpts *opts = xmalloc(sizeof(TSLuaLoggerOpts));
    726  lua_pushvalue(L, 4);
    727  LuaRef ref = luaL_ref(L, LUA_REGISTRYINDEX);
    728 
    729  *opts = (TSLuaLoggerOpts){
    730    .lex = lua_toboolean(L, 2),
    731    .parse = lua_toboolean(L, 3),
    732    .cb = ref,
    733    .lstate = L
    734  };
    735 
    736  TSLogger logger = {
    737    .payload = (void *)opts,
    738    .log = logger_cb
    739  };
    740 
    741  ts_parser_set_logger(p, logger);
    742  return 0;
    743 }
    744 
    745 static int parser_get_logger(lua_State *L)
    746 {
    747  TSParser *p = parser_check(L, 1);
    748  TSLogger logger = ts_parser_logger(p);
    749  if (logger.log) {
    750    TSLuaLoggerOpts *opts = (TSLuaLoggerOpts *)logger.payload;
    751    lua_rawgeti(L, LUA_REGISTRYINDEX, opts->cb);
    752  } else {
    753    lua_pushnil(L);
    754  }
    755 
    756  return 1;
    757 }
    758 
    759 // TSTree
    760 
    761 static struct luaL_Reg tree_meta[] = {
    762  { "__gc", tree_gc },
    763  { "__tostring", tree_tostring },
    764  { "root", tree_root },
    765  { "edit", tree_edit },
    766  { "included_ranges", tree_get_ranges },
    767  { "copy", tree_copy },
    768  { NULL, NULL }
    769 };
    770 
    771 /// Push tree interface on to the lua stack.
    772 ///
    773 /// The tree is not copied. Ownership of the tree is transferred from C to
    774 /// Lua. If needed use ts_tree_copy() in the caller.
    775 static void push_tree(lua_State *L, const TSTree *tree)
    776 {
    777  if (tree == NULL) {
    778    lua_pushnil(L);
    779    return;
    780  }
    781 
    782  TSLuaTree *ud = lua_newuserdata(L, sizeof(TSLuaTree));  // [udata]
    783  ud->tree = tree;
    784  lua_getfield(L, LUA_REGISTRYINDEX, TS_META_TREE);  // [udata, meta]
    785  lua_setmetatable(L, -2);  // [udata]
    786 }
    787 
    788 static int tree_copy(lua_State *L)
    789 {
    790  TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE);
    791  TSTree *copy = ts_tree_copy(ud->tree);
    792  push_tree(L, copy);  // [tree]
    793 
    794  return 1;
    795 }
    796 
    797 static int tree_edit(lua_State *L)
    798 {
    799  if (lua_gettop(L) < 10) {
    800    lua_pushstring(L, "not enough args to tree:edit()");
    801    return lua_error(L);
    802  }
    803 
    804  TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE);
    805 
    806  uint32_t start_byte = (uint32_t)luaL_checkint(L, 2);
    807  uint32_t old_end_byte = (uint32_t)luaL_checkint(L, 3);
    808  uint32_t new_end_byte = (uint32_t)luaL_checkint(L, 4);
    809  TSPoint start_point = { (uint32_t)luaL_checkint(L, 5), (uint32_t)luaL_checkint(L, 6) };
    810  TSPoint old_end_point = { (uint32_t)luaL_checkint(L, 7), (uint32_t)luaL_checkint(L, 8) };
    811  TSPoint new_end_point = { (uint32_t)luaL_checkint(L, 9), (uint32_t)luaL_checkint(L, 10) };
    812 
    813  TSInputEdit edit = { start_byte, old_end_byte, new_end_byte,
    814                       start_point, old_end_point, new_end_point };
    815 
    816  TSTree *new_tree = ts_tree_copy(ud->tree);
    817  ts_tree_edit(new_tree, &edit);
    818 
    819  push_tree(L, new_tree);  // [tree]
    820 
    821  return 1;
    822 }
    823 
    824 static int tree_get_ranges(lua_State *L)
    825 {
    826  TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE);
    827 
    828  bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2);
    829 
    830  uint32_t len;
    831  TSRange *ranges = ts_tree_included_ranges(ud->tree, &len);
    832 
    833  push_ranges(L, ranges, len, include_bytes);
    834 
    835  xfree(ranges);
    836  return 1;
    837 }
    838 
    839 static int tree_gc(lua_State *L)
    840 {
    841  TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE);
    842 
    843  // SAFETY: we can cast the const away because the tree is only garbage collected after all of its
    844  // TSNode's, TSQuerCurors, etc., are unreachable (each contains a reference to the TSLuaTree)
    845  TSTree *tree = (TSTree *)ud->tree;
    846 
    847  ts_tree_delete(tree);
    848  return 0;
    849 }
    850 
    851 static int tree_tostring(lua_State *L)
    852 {
    853  lua_pushstring(L, "<tree>");
    854  return 1;
    855 }
    856 
    857 static int tree_root(lua_State *L)
    858 {
    859  TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE);
    860 
    861  TSNode root = ts_tree_root_node(ud->tree);
    862 
    863  TSNode *node_ud = lua_newuserdata(L, sizeof(TSNode));  // [node]
    864  *node_ud = root;
    865  lua_getfield(L, LUA_REGISTRYINDEX, TS_META_NODE);  // [node, meta]
    866  lua_setmetatable(L, -2);  // [node]
    867 
    868  // To prevent the tree from being garbage collected, create a reference to it
    869  // in the fenv which will be passed to userdata nodes of the tree.
    870  // Note: environments (fenvs) associated with userdata have no meaning in Lua
    871  // and are only used to associate a table.
    872  lua_createtable(L, 1, 0);  // [node, reftable]
    873  lua_pushvalue(L, 1);  // [node, reftable, tree]
    874  lua_rawseti(L, -2, 1);  // [node, reftable]
    875  lua_setfenv(L, -2);  // [node]
    876 
    877  return 1;
    878 }
    879 
    880 // TSNode
    881 static struct luaL_Reg node_meta[] = {
    882  { "__tostring", node_tostring },
    883  { "__eq", node_eq },
    884  { "__len", node_child_count },
    885  { "id", node_id },
    886  { "range", node_range },
    887  { "start", node_start },
    888  { "end_", node_end },
    889  { "type", node_type },
    890  { "symbol", node_symbol },
    891  { "field", node_field },
    892  { "named", node_named },
    893  { "missing", node_missing },
    894  { "extra", node_extra },
    895  { "has_changes", node_has_changes },
    896  { "has_error", node_has_error },
    897  { "sexpr", node_sexpr },
    898  { "child_count", node_child_count },
    899  { "named_child_count", node_named_child_count },
    900  { "child", node_child },
    901  { "named_child", node_named_child },
    902  { "descendant_for_range", node_descendant_for_range },
    903  { "named_descendant_for_range", node_named_descendant_for_range },
    904  { "parent", node_parent },
    905  { "__has_ancestor", __has_ancestor },
    906  { "child_with_descendant", node_child_with_descendant },
    907  { "iter_children", node_iter_children },
    908  { "next_sibling", node_next_sibling },
    909  { "prev_sibling", node_prev_sibling },
    910  { "next_named_sibling", node_next_named_sibling },
    911  { "prev_named_sibling", node_prev_named_sibling },
    912  { "named_children", node_named_children },
    913  { "root", node_root },
    914  { "tree", node_tree },
    915  { "byte_length", node_byte_length },
    916  { "equal", node_equal },
    917 
    918  { NULL, NULL }
    919 };
    920 
    921 /// Push node interface on to the Lua stack
    922 ///
    923 /// Stack at `uindex` must have a value with a fenv with a reference to node's
    924 /// tree. This value is not popped. Can only be called inside a cfunction with
    925 /// the tslua environment.
    926 static void push_node(lua_State *L, TSNode node, int uindex)
    927 {
    928  assert(uindex > 0 || uindex < -LUA_MINSTACK);
    929  if (ts_node_is_null(node)) {
    930    lua_pushnil(L);  // [nil]
    931    return;
    932  }
    933 
    934  TSNode *ud = lua_newuserdata(L, sizeof(TSNode));  // [udata]
    935  *ud = node;
    936  lua_getfield(L, LUA_REGISTRYINDEX, TS_META_NODE);  // [udata, meta]
    937  lua_setmetatable(L, -2);  // [udata]
    938 
    939  // Copy the fenv to keep alive a reference to the node's tree.
    940  lua_getfenv(L, uindex);  // [udata, reftable]
    941  lua_setfenv(L, -2);  // [udata]
    942 }
    943 
    944 static bool node_check_opt(lua_State *L, int index, TSNode *res)
    945 {
    946  TSNode *ud = luaL_checkudata(L, index, TS_META_NODE);
    947  if (ud) {
    948    *res = *ud;
    949    return true;
    950  }
    951  return false;
    952 }
    953 
    954 static TSNode node_check(lua_State *L, int index)
    955 {
    956  TSNode *ud = luaL_checkudata(L, index, TS_META_NODE);
    957  return *ud;
    958 }
    959 
    960 static int node_tostring(lua_State *L)
    961 {
    962  TSNode node = node_check(L, 1);
    963  lua_pushstring(L, "<node ");
    964  lua_pushstring(L, ts_node_type(node));
    965  lua_pushstring(L, ">");
    966  lua_concat(L, 3);
    967  return 1;
    968 }
    969 
    970 static int node_eq(lua_State *L)
    971 {
    972  TSNode node = node_check(L, 1);
    973  TSNode node2 = node_check(L, 2);
    974  lua_pushboolean(L, ts_node_eq(node, node2));
    975  return 1;
    976 }
    977 
    978 static int node_id(lua_State *L)
    979 {
    980  TSNode node = node_check(L, 1);
    981  lua_pushlstring(L, (const char *)&node.id, sizeof node.id);
    982  return 1;
    983 }
    984 
    985 static int node_range(lua_State *L)
    986 {
    987  TSNode node = node_check(L, 1);
    988 
    989  bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2);
    990 
    991  TSPoint start = ts_node_start_point(node);
    992  TSPoint end = ts_node_end_point(node);
    993 
    994  if (include_bytes) {
    995    lua_pushinteger(L, start.row);
    996    lua_pushinteger(L, start.column);
    997    lua_pushinteger(L, ts_node_start_byte(node));
    998    lua_pushinteger(L, end.row);
    999    lua_pushinteger(L, end.column);
   1000    lua_pushinteger(L, ts_node_end_byte(node));
   1001    return 6;
   1002  }
   1003 
   1004  lua_pushinteger(L, start.row);
   1005  lua_pushinteger(L, start.column);
   1006  lua_pushinteger(L, end.row);
   1007  lua_pushinteger(L, end.column);
   1008  return 4;
   1009 }
   1010 
   1011 static int node_start(lua_State *L)
   1012 {
   1013  TSNode node = node_check(L, 1);
   1014  TSPoint start = ts_node_start_point(node);
   1015  uint32_t start_byte = ts_node_start_byte(node);
   1016  lua_pushinteger(L, start.row);
   1017  lua_pushinteger(L, start.column);
   1018  lua_pushinteger(L, start_byte);
   1019  return 3;
   1020 }
   1021 
   1022 static int node_end(lua_State *L)
   1023 {
   1024  TSNode node = node_check(L, 1);
   1025  TSPoint end = ts_node_end_point(node);
   1026  uint32_t end_byte = ts_node_end_byte(node);
   1027  lua_pushinteger(L, end.row);
   1028  lua_pushinteger(L, end.column);
   1029  lua_pushinteger(L, end_byte);
   1030  return 3;
   1031 }
   1032 
   1033 static int node_child_count(lua_State *L)
   1034 {
   1035  TSNode node = node_check(L, 1);
   1036  uint32_t count = ts_node_child_count(node);
   1037  lua_pushinteger(L, count);
   1038  return 1;
   1039 }
   1040 
   1041 static int node_named_child_count(lua_State *L)
   1042 {
   1043  TSNode node = node_check(L, 1);
   1044  uint32_t count = ts_node_named_child_count(node);
   1045  lua_pushinteger(L, count);
   1046  return 1;
   1047 }
   1048 
   1049 static int node_type(lua_State *L)
   1050 {
   1051  TSNode node = node_check(L, 1);
   1052  lua_pushstring(L, ts_node_type(node));
   1053  return 1;
   1054 }
   1055 
   1056 static int node_symbol(lua_State *L)
   1057 {
   1058  TSNode node = node_check(L, 1);
   1059  TSSymbol symbol = ts_node_symbol(node);
   1060  lua_pushinteger(L, symbol);
   1061  return 1;
   1062 }
   1063 
   1064 static int node_field(lua_State *L)
   1065 {
   1066  TSNode node = node_check(L, 1);
   1067  uint32_t count = ts_node_child_count(node);
   1068  int curr_index = 0;
   1069 
   1070  size_t name_len;
   1071  const char *field_name = luaL_checklstring(L, 2, &name_len);
   1072 
   1073  lua_newtable(L);
   1074 
   1075  for (uint32_t i = 0; i < count; i++) {
   1076    const char *child_field_name = ts_node_field_name_for_child(node, i);
   1077    if (strequal(field_name, child_field_name)) {
   1078      TSNode child = ts_node_child(node, i);
   1079      push_node(L, child, 1);
   1080      lua_rawseti(L, -2, ++curr_index);
   1081    }
   1082  }
   1083 
   1084  return 1;
   1085 }
   1086 
   1087 static int node_named(lua_State *L)
   1088 {
   1089  TSNode node = node_check(L, 1);
   1090  lua_pushboolean(L, ts_node_is_named(node));
   1091  return 1;
   1092 }
   1093 
   1094 static int node_sexpr(lua_State *L)
   1095 {
   1096  TSNode node = node_check(L, 1);
   1097  char *allocated = ts_node_string(node);
   1098  lua_pushstring(L, allocated);
   1099  xfree(allocated);
   1100  return 1;
   1101 }
   1102 
   1103 static int node_missing(lua_State *L)
   1104 {
   1105  TSNode node = node_check(L, 1);
   1106  lua_pushboolean(L, ts_node_is_missing(node));
   1107  return 1;
   1108 }
   1109 
   1110 static int node_extra(lua_State *L)
   1111 {
   1112  TSNode node = node_check(L, 1);
   1113  lua_pushboolean(L, ts_node_is_extra(node));
   1114  return 1;
   1115 }
   1116 
   1117 static int node_has_changes(lua_State *L)
   1118 {
   1119  TSNode node = node_check(L, 1);
   1120  lua_pushboolean(L, ts_node_has_changes(node));
   1121  return 1;
   1122 }
   1123 
   1124 static int node_has_error(lua_State *L)
   1125 {
   1126  TSNode node = node_check(L, 1);
   1127  lua_pushboolean(L, ts_node_has_error(node));
   1128  return 1;
   1129 }
   1130 
   1131 static int node_child(lua_State *L)
   1132 {
   1133  TSNode node = node_check(L, 1);
   1134  uint32_t num = (uint32_t)lua_tointeger(L, 2);
   1135  TSNode child = ts_node_child(node, num);
   1136 
   1137  push_node(L, child, 1);
   1138  return 1;
   1139 }
   1140 
   1141 static int node_named_child(lua_State *L)
   1142 {
   1143  TSNode node = node_check(L, 1);
   1144  uint32_t num = (uint32_t)lua_tointeger(L, 2);
   1145  TSNode child = ts_node_named_child(node, num);
   1146 
   1147  push_node(L, child, 1);
   1148  return 1;
   1149 }
   1150 
   1151 static int node_descendant_for_range(lua_State *L)
   1152 {
   1153  TSNode node = node_check(L, 1);
   1154  TSPoint start = { (uint32_t)lua_tointeger(L, 2),
   1155                    (uint32_t)lua_tointeger(L, 3) };
   1156  TSPoint end = { (uint32_t)lua_tointeger(L, 4),
   1157                  (uint32_t)lua_tointeger(L, 5) };
   1158  TSNode child = ts_node_descendant_for_point_range(node, start, end);
   1159 
   1160  push_node(L, child, 1);
   1161  return 1;
   1162 }
   1163 
   1164 static int node_named_descendant_for_range(lua_State *L)
   1165 {
   1166  TSNode node = node_check(L, 1);
   1167  TSPoint start = { (uint32_t)lua_tointeger(L, 2),
   1168                    (uint32_t)lua_tointeger(L, 3) };
   1169  TSPoint end = { (uint32_t)lua_tointeger(L, 4),
   1170                  (uint32_t)lua_tointeger(L, 5) };
   1171  TSNode child = ts_node_named_descendant_for_point_range(node, start, end);
   1172 
   1173  push_node(L, child, 1);
   1174  return 1;
   1175 }
   1176 
   1177 static int node_next_child(lua_State *L)
   1178 {
   1179  uint32_t *child_index = lua_touserdata(L, lua_upvalueindex(1));
   1180  TSNode source = node_check(L, lua_upvalueindex(2));
   1181 
   1182  if (*child_index >= ts_node_child_count(source)) {
   1183    return 0;
   1184  }
   1185 
   1186  TSNode child = ts_node_child(source, *child_index);
   1187  push_node(L, child, lua_upvalueindex(2));
   1188 
   1189  const char *field = ts_node_field_name_for_child(source, *child_index);
   1190  if (field != NULL) {
   1191    lua_pushstring(L, field);
   1192  } else {
   1193    lua_pushnil(L);
   1194  }  // [node, field_name_or_nil]
   1195 
   1196  (*child_index)++;
   1197 
   1198  return 2;
   1199 }
   1200 
   1201 static int node_iter_children(lua_State *L)
   1202 {
   1203  node_check(L, 1);
   1204  uint32_t *child_index = lua_newuserdata(L, sizeof(uint32_t));  // [source_node,..., udata]
   1205  *child_index = 0;
   1206 
   1207  lua_pushvalue(L, 1);  // [source_node, ..., udata, source_node]
   1208  lua_pushcclosure(L, node_next_child, 2);
   1209 
   1210  return 1;
   1211 }
   1212 
   1213 static int node_parent(lua_State *L)
   1214 {
   1215  TSNode node = node_check(L, 1);
   1216  TSNode parent = ts_node_parent(node);
   1217  push_node(L, parent, 1);
   1218  return 1;
   1219 }
   1220 
   1221 static int __has_ancestor(lua_State *L)
   1222 {
   1223  TSNode descendant = node_check(L, 1);
   1224  if (lua_type(L, 2) != LUA_TTABLE) {
   1225    lua_pushboolean(L, false);
   1226    return 1;
   1227  }
   1228  int const pred_len = (int)lua_objlen(L, 2);
   1229 
   1230  TSNode node = ts_tree_root_node(descendant.tree);
   1231  while (node.id != descendant.id && !ts_node_is_null(node)) {
   1232    char const *node_type = ts_node_type(node);
   1233    size_t node_type_len = strlen(node_type);
   1234 
   1235    for (int i = 3; i <= pred_len; i++) {
   1236      lua_rawgeti(L, 2, i);
   1237      if (lua_type(L, -1) == LUA_TSTRING) {
   1238        size_t check_len;
   1239        char const *check_str = lua_tolstring(L, -1, &check_len);
   1240        if (node_type_len == check_len && memcmp(node_type, check_str, check_len) == 0) {
   1241          lua_pushboolean(L, true);
   1242          return 1;
   1243        }
   1244      }
   1245      lua_pop(L, 1);
   1246    }
   1247 
   1248    node = ts_node_child_with_descendant(node, descendant);
   1249  }
   1250 
   1251  lua_pushboolean(L, false);
   1252  return 1;
   1253 }
   1254 
   1255 static int node_child_with_descendant(lua_State *L)
   1256 {
   1257  TSNode node = node_check(L, 1);
   1258  TSNode descendant = node_check(L, 2);
   1259  TSNode child = ts_node_child_with_descendant(node, descendant);
   1260  push_node(L, child, 1);
   1261  return 1;
   1262 }
   1263 
   1264 static int node_next_sibling(lua_State *L)
   1265 {
   1266  TSNode node = node_check(L, 1);
   1267  TSNode sibling = ts_node_next_sibling(node);
   1268  push_node(L, sibling, 1);
   1269  return 1;
   1270 }
   1271 
   1272 static int node_prev_sibling(lua_State *L)
   1273 {
   1274  TSNode node = node_check(L, 1);
   1275  TSNode sibling = ts_node_prev_sibling(node);
   1276  push_node(L, sibling, 1);
   1277  return 1;
   1278 }
   1279 
   1280 static int node_next_named_sibling(lua_State *L)
   1281 {
   1282  TSNode node = node_check(L, 1);
   1283  TSNode sibling = ts_node_next_named_sibling(node);
   1284  push_node(L, sibling, 1);
   1285  return 1;
   1286 }
   1287 
   1288 static int node_prev_named_sibling(lua_State *L)
   1289 {
   1290  TSNode node = node_check(L, 1);
   1291  TSNode sibling = ts_node_prev_named_sibling(node);
   1292  push_node(L, sibling, 1);
   1293  return 1;
   1294 }
   1295 
   1296 static int node_named_children(lua_State *L)
   1297 {
   1298  TSNode source = node_check(L, 1);
   1299 
   1300  lua_newtable(L);
   1301  int curr_index = 0;
   1302 
   1303  uint32_t n = ts_node_child_count(source);
   1304  for (uint32_t i = 0; i < n; i++) {
   1305    TSNode child = ts_node_child(source, i);
   1306    if (ts_node_is_named(child)) {
   1307      push_node(L, child, 1);
   1308      lua_rawseti(L, -2, ++curr_index);
   1309    }
   1310  }
   1311 
   1312  return 1;
   1313 }
   1314 
   1315 static int node_root(lua_State *L)
   1316 {
   1317  TSNode node = node_check(L, 1);
   1318  TSNode root = ts_tree_root_node(node.tree);
   1319  push_node(L, root, 1);
   1320  return 1;
   1321 }
   1322 
   1323 static int node_tree(lua_State *L)
   1324 {
   1325  node_check(L, 1);
   1326 
   1327  // Get the tree from the node fenv. We cannot use `push_tree(node.tree)` here because that would
   1328  // cause a double free.
   1329  lua_getfenv(L, 1);  // [node, reftable]
   1330  lua_rawgeti(L, 2, 1);  // [node, reftable, tree]
   1331 
   1332  return 1;
   1333 }
   1334 
   1335 static int node_byte_length(lua_State *L)
   1336 {
   1337  TSNode node = node_check(L, 1);
   1338  uint32_t start_byte = ts_node_start_byte(node);
   1339  uint32_t end_byte = ts_node_end_byte(node);
   1340  lua_pushinteger(L, end_byte - start_byte);
   1341  return 1;
   1342 }
   1343 
   1344 static int node_equal(lua_State *L)
   1345 {
   1346  TSNode node1 = node_check(L, 1);
   1347  TSNode node2 = node_check(L, 2);
   1348  lua_pushboolean(L, ts_node_eq(node1, node2));
   1349  return 1;
   1350 }
   1351 
   1352 // TSQueryCursor
   1353 
   1354 static struct luaL_Reg querycursor_meta[] = {
   1355  { "remove_match", querycursor_remove_match },
   1356  { "next_capture", querycursor_next_capture },
   1357  { "next_match", querycursor_next_match },
   1358  { "__gc", querycursor_gc },
   1359  { NULL, NULL }
   1360 };
   1361 
   1362 static int tslua_push_querycursor(lua_State *L)
   1363 {
   1364  TSNode node = node_check(L, 1);
   1365 
   1366  TSQuery *query = query_check(L, 2);
   1367  TSQueryCursor *cursor = ts_query_cursor_new();
   1368 
   1369  if (lua_gettop(L) >= 3 && !lua_isnil(L, 3)) {
   1370    luaL_argcheck(L, lua_istable(L, 3), 3, "table expected");
   1371  }
   1372 
   1373  lua_getfield(L, 3, "start_row");
   1374  uint32_t start_row = (uint32_t)luaL_checkinteger(L, -1);
   1375  lua_pop(L, 1);
   1376 
   1377  lua_getfield(L, 3, "start_col");
   1378  uint32_t start_col = (uint32_t)luaL_checkinteger(L, -1);
   1379  lua_pop(L, 1);
   1380 
   1381  lua_getfield(L, 3, "end_row");
   1382  uint32_t end_row = (uint32_t)luaL_checkinteger(L, -1);
   1383  lua_pop(L, 1);
   1384 
   1385  lua_getfield(L, 3, "end_col");
   1386  uint32_t end_col = (uint32_t)luaL_checkinteger(L, -1);
   1387  lua_pop(L, 1);
   1388 
   1389  ts_query_cursor_set_point_range(cursor, (TSPoint){ start_row, start_col },
   1390                                  (TSPoint){ end_row, end_col });
   1391 
   1392  lua_getfield(L, 3, "max_start_depth");
   1393  if (!lua_isnil(L, -1)) {
   1394    uint32_t max_start_depth = (uint32_t)luaL_checkinteger(L, -1);
   1395    ts_query_cursor_set_max_start_depth(cursor, max_start_depth);
   1396  }
   1397  lua_pop(L, 1);
   1398 
   1399  lua_getfield(L, 3, "match_limit");
   1400  if (!lua_isnil(L, -1)) {
   1401    uint32_t match_limit = (uint32_t)luaL_checkinteger(L, -1);
   1402    ts_query_cursor_set_match_limit(cursor, match_limit);
   1403  }
   1404  lua_pop(L, 1);
   1405 
   1406  ts_query_cursor_exec(cursor, query, node);
   1407 
   1408  TSQueryCursor **ud = lua_newuserdata(L, sizeof(*ud));  // [node, query, ..., udata]
   1409  *ud = cursor;
   1410  lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYCURSOR);  // [node, query, ..., udata, meta]
   1411  lua_setmetatable(L, -2);  // [node, query, ..., udata]
   1412 
   1413  // Copy the fenv which contains the nodes tree.
   1414  lua_getfenv(L, 1);  // [udata, reftable]
   1415  lua_setfenv(L, -2);  // [udata]
   1416 
   1417  return 1;
   1418 }
   1419 
   1420 static int querycursor_remove_match(lua_State *L)
   1421 {
   1422  TSQueryCursor *cursor = querycursor_check(L, 1);
   1423  uint32_t match_id = (uint32_t)luaL_checkinteger(L, 2);
   1424  ts_query_cursor_remove_match(cursor, match_id);
   1425  return 0;
   1426 }
   1427 
   1428 static int querycursor_next_capture(lua_State *L)
   1429 {
   1430  TSQueryCursor *cursor = querycursor_check(L, 1);
   1431  TSQueryMatch match;
   1432  uint32_t capture_index;
   1433  if (!ts_query_cursor_next_capture(cursor, &match, &capture_index)) {
   1434    return 0;
   1435  }
   1436 
   1437  TSQueryCapture capture = match.captures[capture_index];
   1438 
   1439  // Handle capture quantifiers here
   1440  lua_pushinteger(L, capture.index + 1);  // [index]
   1441  push_node(L, capture.node, 1);  // [index, node]
   1442  push_querymatch(L, &match, 1);
   1443 
   1444  return 3;
   1445 }
   1446 
   1447 static int querycursor_next_match(lua_State *L)
   1448 {
   1449  TSQueryCursor *cursor = querycursor_check(L, 1);
   1450 
   1451  TSQueryMatch match;
   1452  if (!ts_query_cursor_next_match(cursor, &match)) {
   1453    return 0;
   1454  }
   1455 
   1456  push_querymatch(L, &match, 1);
   1457 
   1458  return 1;
   1459 }
   1460 
   1461 static TSQueryCursor *querycursor_check(lua_State *L, int index)
   1462 {
   1463  TSQueryCursor **ud = luaL_checkudata(L, index, TS_META_QUERYCURSOR);
   1464  luaL_argcheck(L, *ud, index, "TSQueryCursor expected");
   1465  return *ud;
   1466 }
   1467 
   1468 static int querycursor_gc(lua_State *L)
   1469 {
   1470  TSQueryCursor *cursor = querycursor_check(L, 1);
   1471  ts_query_cursor_delete(cursor);
   1472  return 0;
   1473 }
   1474 
   1475 // TSQueryMatch
   1476 
   1477 static struct luaL_Reg querymatch_meta[] = {
   1478  { "info", querymatch_info },
   1479  { "captures", querymatch_captures },
   1480  { NULL, NULL }
   1481 };
   1482 
   1483 static void push_querymatch(lua_State *L, TSQueryMatch *match, int uindex)
   1484 {
   1485  TSQueryMatch *ud = lua_newuserdata(L, sizeof(TSQueryMatch));  // [udata]
   1486  *ud = *match;
   1487  lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYMATCH);  // [udata, meta]
   1488  lua_setmetatable(L, -2);  // [udata]
   1489 
   1490  // Copy the fenv which contains the nodes tree.
   1491  lua_getfenv(L, uindex);  // [udata, reftable]
   1492  lua_setfenv(L, -2);  // [udata]
   1493 }
   1494 
   1495 static int querymatch_info(lua_State *L)
   1496 {
   1497  TSQueryMatch *match = luaL_checkudata(L, 1, TS_META_QUERYMATCH);
   1498  lua_pushinteger(L, match->id);
   1499  lua_pushinteger(L, match->pattern_index + 1);
   1500  return 2;
   1501 }
   1502 
   1503 static int querymatch_captures(lua_State *L)
   1504 {
   1505  TSQueryMatch *match = luaL_checkudata(L, 1, TS_META_QUERYMATCH);
   1506  lua_newtable(L);  // [match, nodes, captures]
   1507  for (size_t i = 0; i < match->capture_count; i++) {
   1508    TSQueryCapture capture = match->captures[i];
   1509    int index = (int)capture.index + 1;
   1510 
   1511    lua_rawgeti(L, -1, index);  // [match, node, captures]
   1512    if (lua_isnil(L, -1)) {  // [match, node, captures, nil]
   1513      lua_pop(L, 1);  // [match, node, captures]
   1514      lua_newtable(L);  // [match, node, captures, nodes]
   1515    }
   1516    push_node(L, capture.node, 1);  // [match, node, captures, nodes, node]
   1517    lua_rawseti(L, -2, (int)lua_objlen(L, -2) + 1);  // [match, node, captures, nodes]
   1518    lua_rawseti(L, -2, index);  // [match, node, captures]
   1519  }
   1520  return 1;
   1521 }
   1522 
   1523 // TSQuery
   1524 
   1525 static struct luaL_Reg query_meta[] = {
   1526  { "__gc", query_gc },
   1527  { "__tostring", query_tostring },
   1528  { "inspect", query_inspect },
   1529  { "disable_capture", query_disable_capture },
   1530  { "disable_pattern", query_disable_pattern },
   1531  { NULL, NULL }
   1532 };
   1533 
   1534 static int tslua_parse_query(lua_State *L)
   1535 {
   1536  if (lua_gettop(L) < 2 || !lua_isstring(L, 1) || !lua_isstring(L, 2)) {
   1537    return luaL_error(L, "string expected");
   1538  }
   1539 
   1540  TSLanguage *lang = lang_check(L, 1);
   1541 
   1542  size_t len;
   1543  const char *src = lua_tolstring(L, 2, &len);
   1544 
   1545  tslua_query_parse_count++;
   1546  uint32_t error_offset;
   1547  TSQueryError error_type;
   1548  TSQuery *query = ts_query_new(lang, src, (uint32_t)len, &error_offset, &error_type);
   1549 
   1550  if (!query) {
   1551    char err_msg[IOSIZE];
   1552    query_err_string(src, (int)error_offset, error_type, err_msg, sizeof(err_msg));
   1553    return luaL_error(L, "%s", err_msg);
   1554  }
   1555 
   1556  TSQuery **ud = lua_newuserdata(L, sizeof(TSQuery *));  // [udata]
   1557  *ud = query;
   1558  lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERY);  // [udata, meta]
   1559  lua_setmetatable(L, -2);  // [udata]
   1560  return 1;
   1561 }
   1562 
   1563 static const char *query_err_to_string(TSQueryError error_type)
   1564 {
   1565  switch (error_type) {
   1566  case TSQueryErrorSyntax:
   1567    return "Invalid syntax:\n";
   1568  case TSQueryErrorNodeType:
   1569    return "Invalid node type ";
   1570  case TSQueryErrorField:
   1571    return "Invalid field name ";
   1572  case TSQueryErrorCapture:
   1573    return "Invalid capture name ";
   1574  case TSQueryErrorStructure:
   1575    return "Impossible pattern:\n";
   1576  default:
   1577    return "error";
   1578  }
   1579 }
   1580 
   1581 static void query_err_string(const char *src, int error_offset, TSQueryError error_type, char *err,
   1582                             size_t errlen)
   1583 {
   1584  int line_start = 0;
   1585  int row = 0;
   1586  const char *error_line = NULL;
   1587  int error_line_len = 0;
   1588 
   1589  const char *end_str;
   1590  do {
   1591    const char *src_tmp = src + line_start;
   1592    end_str = strchr(src_tmp, '\n');
   1593    int line_length = end_str != NULL ? (int)(end_str - src_tmp) : (int)strlen(src_tmp);
   1594    int line_end = line_start + line_length;
   1595    if (line_end > error_offset) {
   1596      error_line = src_tmp;
   1597      error_line_len = line_length;
   1598      break;
   1599    }
   1600    line_start = line_end + 1;
   1601    row++;
   1602  } while (end_str != NULL);
   1603 
   1604  int column = error_offset - line_start;
   1605 
   1606  const char *type_msg = query_err_to_string(error_type);
   1607  snprintf(err, errlen, "Query error at %d:%d. %s", row + 1, column + 1, type_msg);
   1608  size_t offset = strlen(err);
   1609  errlen = errlen - offset;
   1610  err = err + offset;
   1611 
   1612  // Error types that report names
   1613  if (error_type == TSQueryErrorNodeType
   1614      || error_type == TSQueryErrorField
   1615      || error_type == TSQueryErrorCapture) {
   1616    const char *suffix = src + error_offset;
   1617    bool is_anonymous = error_type == TSQueryErrorNodeType && suffix[-1] == '"';
   1618    int suffix_len = 0;
   1619    char c = suffix[suffix_len];
   1620    if (is_anonymous) {
   1621      int backslashes = 0;
   1622      // Stop when we hit an unescaped double quote
   1623      while (c != '"' || backslashes % 2 != 0) {
   1624        if (c == '\\') {
   1625          backslashes += 1;
   1626        } else {
   1627          backslashes = 0;
   1628        }
   1629        c = suffix[++suffix_len];
   1630      }
   1631    } else {
   1632      // Stop when we hit the end of the identifier
   1633      while (isalnum(c) || c == '_' || c == '-' || c == '.') {
   1634        c = suffix[++suffix_len];
   1635      }
   1636    }
   1637    snprintf(err, errlen, "\"%.*s\":\n", suffix_len, suffix);
   1638    offset = strlen(err);
   1639    errlen = errlen - offset;
   1640    err = err + offset;
   1641  }
   1642 
   1643  if (!error_line) {
   1644    snprintf(err, errlen, "Unexpected EOF\n");
   1645    return;
   1646  }
   1647 
   1648  snprintf(err, errlen, "%.*s\n%*s^\n", error_line_len, error_line, column, "");
   1649 }
   1650 
   1651 static TSQuery *query_check(lua_State *L, int index)
   1652 {
   1653  TSQuery **ud = luaL_checkudata(L, index, TS_META_QUERY);
   1654  luaL_argcheck(L, *ud, index, "TSQuery expected");
   1655  return *ud;
   1656 }
   1657 
   1658 static int query_gc(lua_State *L)
   1659 {
   1660  TSQuery *query = query_check(L, 1);
   1661  ts_query_delete(query);
   1662  return 0;
   1663 }
   1664 
   1665 static int query_tostring(lua_State *L)
   1666 {
   1667  lua_pushstring(L, "<query>");
   1668  return 1;
   1669 }
   1670 
   1671 static int query_inspect(lua_State *L)
   1672 {
   1673  TSQuery *query = query_check(L, 1);
   1674 
   1675  // TSQueryInfo
   1676  lua_createtable(L, 0, 2);  // [retval]
   1677 
   1678  uint32_t n_pat = ts_query_pattern_count(query);
   1679  lua_createtable(L, (int)n_pat, 1);  // [retval, patterns]
   1680  for (size_t i = 0; i < n_pat; i++) {
   1681    uint32_t len;
   1682    const TSQueryPredicateStep *step = ts_query_predicates_for_pattern(query, (uint32_t)i, &len);
   1683    if (len == 0) {
   1684      continue;
   1685    }
   1686    lua_createtable(L, (int)len/4, 1);  // [retval, patterns, pat]
   1687    lua_createtable(L, 3, 0);  // [retval, patterns, pat, pred]
   1688    int nextpred = 1;
   1689    int nextitem = 1;
   1690    for (size_t k = 0; k < len; k++) {
   1691      if (step[k].type == TSQueryPredicateStepTypeDone) {
   1692        lua_rawseti(L, -2, nextpred++);  // [retval, patterns, pat]
   1693        lua_createtable(L, 3, 0);  // [retval, patterns, pat, pred]
   1694        nextitem = 1;
   1695        continue;
   1696      }
   1697 
   1698      if (step[k].type == TSQueryPredicateStepTypeString) {
   1699        uint32_t strlen;
   1700        const char *str = ts_query_string_value_for_id(query, step[k].value_id,
   1701                                                       &strlen);
   1702        lua_pushlstring(L, str, strlen);  // [retval, patterns, pat, pred, item]
   1703      } else if (step[k].type == TSQueryPredicateStepTypeCapture) {
   1704        lua_pushinteger(L, step[k].value_id + 1);  // [..., pat, pred, item]
   1705      } else {
   1706        abort();
   1707      }
   1708      lua_rawseti(L, -2, nextitem++);  // [retval, patterns, pat, pred]
   1709    }
   1710    // last predicate should have ended with TypeDone
   1711    lua_pop(L, 1);  // [retval, patterns, pat]
   1712    lua_rawseti(L, -2, (int)i + 1);  // [retval, patterns]
   1713  }
   1714  lua_setfield(L, -2, "patterns");  // [retval]
   1715 
   1716  uint32_t n_captures = ts_query_capture_count(query);
   1717  lua_createtable(L, (int)n_captures, 0);  // [retval, captures]
   1718  for (size_t i = 0; i < n_captures; i++) {
   1719    uint32_t strlen;
   1720    const char *str = ts_query_capture_name_for_id(query, (uint32_t)i, &strlen);
   1721    lua_pushlstring(L, str, strlen);  // [retval, captures, capture]
   1722    lua_rawseti(L, -2, (int)i + 1);
   1723  }
   1724  lua_setfield(L, -2, "captures");  // [retval]
   1725 
   1726  return 1;
   1727 }
   1728 
   1729 static int query_disable_capture(lua_State *L)
   1730 {
   1731  TSQuery *query = query_check(L, 1);
   1732  size_t name_len;
   1733  const char *name = luaL_checklstring(L, 2, &name_len);
   1734  ts_query_disable_capture(query, name, (uint32_t)name_len);
   1735  return 0;
   1736 }
   1737 
   1738 static int query_disable_pattern(lua_State *L)
   1739 {
   1740  TSQuery *query = query_check(L, 1);
   1741  const uint32_t pattern_index = (uint32_t)luaL_checkinteger(L, 2);
   1742  ts_query_disable_pattern(query, pattern_index - 1);
   1743  return 0;
   1744 }
   1745 
   1746 // Library init
   1747 
   1748 static void build_meta(lua_State *L, const char *tname, const luaL_Reg *meta)
   1749 {
   1750  if (luaL_newmetatable(L, tname)) {  // [meta]
   1751    luaL_register(L, NULL, meta);
   1752 
   1753    lua_pushvalue(L, -1);  // [meta, meta]
   1754    lua_setfield(L, -2, "__index");  // [meta]
   1755  }
   1756  lua_pop(L, 1);  // [] (don't use it now)
   1757 }
   1758 
   1759 /// Init the tslua library.
   1760 ///
   1761 /// All global state is stored in the registry of the lua_State.
   1762 static void tslua_init(lua_State *L)
   1763 {
   1764  // type metatables
   1765  build_meta(L, TS_META_PARSER, parser_meta);
   1766  build_meta(L, TS_META_TREE, tree_meta);
   1767  build_meta(L, TS_META_NODE, node_meta);
   1768  build_meta(L, TS_META_QUERY, query_meta);
   1769  build_meta(L, TS_META_QUERYCURSOR, querycursor_meta);
   1770  build_meta(L, TS_META_QUERYMATCH, querymatch_meta);
   1771 
   1772  ts_set_allocator(xmalloc, xcalloc, xrealloc, xfree);
   1773 }
   1774 
   1775 static int tslua_get_language_version(lua_State *L)
   1776 {
   1777  lua_pushnumber(L, TREE_SITTER_LANGUAGE_VERSION);
   1778  return 1;
   1779 }
   1780 
   1781 static int tslua_get_minimum_language_version(lua_State *L)
   1782 {
   1783  lua_pushnumber(L, TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION);
   1784  return 1;
   1785 }
   1786 
   1787 void nlua_treesitter_free(void)
   1788 {
   1789 #ifdef HAVE_WASMTIME
   1790  if (wasmengine != NULL) {
   1791    wasm_engine_delete(wasmengine);
   1792  }
   1793  if (ts_wasmstore != NULL) {
   1794    ts_wasm_store_delete(ts_wasmstore);
   1795  }
   1796 #endif
   1797 }
   1798 
   1799 void nlua_treesitter_init(lua_State *const lstate) FUNC_ATTR_NONNULL_ALL
   1800 {
   1801  tslua_init(lstate);
   1802 
   1803  lua_pushcfunction(lstate, tslua_push_parser);
   1804  lua_setfield(lstate, -2, "_create_ts_parser");
   1805 
   1806  lua_pushcfunction(lstate, tslua_push_querycursor);
   1807  lua_setfield(lstate, -2, "_create_ts_querycursor");
   1808 
   1809  lua_pushcfunction(lstate, tslua_add_language_from_object);
   1810  lua_setfield(lstate, -2, "_ts_add_language_from_object");
   1811 
   1812 #ifdef HAVE_WASMTIME
   1813  lua_pushcfunction(lstate, tslua_add_language_from_wasm);
   1814  lua_setfield(lstate, -2, "_ts_add_language_from_wasm");
   1815 #endif
   1816 
   1817  lua_pushcfunction(lstate, tslua_has_language);
   1818  lua_setfield(lstate, -2, "_ts_has_language");
   1819 
   1820  lua_pushcfunction(lstate, tslua_remove_lang);
   1821  lua_setfield(lstate, -2, "_ts_remove_language");
   1822 
   1823  lua_pushcfunction(lstate, tslua_inspect_lang);
   1824  lua_setfield(lstate, -2, "_ts_inspect_language");
   1825 
   1826  lua_pushcfunction(lstate, tslua_parse_query);
   1827  lua_setfield(lstate, -2, "_ts_parse_query");
   1828 
   1829  lua_pushcfunction(lstate, tslua_get_language_version);
   1830  lua_setfield(lstate, -2, "_ts_get_language_version");
   1831 
   1832  lua_pushcfunction(lstate, tslua_get_minimum_language_version);
   1833  lua_setfield(lstate, -2, "_ts_get_minimum_language_version");
   1834 }