treesitter.c (49514B)
1 // lua bindings for treesitter. 2 // NB: this file mostly contains a generic lua interface for treesitter 3 // trees and nodes, and could be broken out as a reusable lua package 4 5 #include <assert.h> 6 #include <ctype.h> 7 #include <lauxlib.h> 8 #include <limits.h> 9 #include <lua.h> 10 #include <stdbool.h> 11 #include <stdint.h> 12 #include <stdio.h> 13 #include <stdlib.h> 14 #include <string.h> 15 #include <tree_sitter/api.h> 16 #include <uv.h> 17 18 #include "nvim/os/time.h" 19 20 #ifdef HAVE_WASMTIME 21 # include <wasm.h> 22 23 # include "nvim/os/fs.h" 24 #endif 25 26 #include "nvim/api/private/helpers.h" 27 #include "nvim/ascii_defs.h" 28 #include "nvim/buffer_defs.h" 29 #include "nvim/globals.h" 30 #include "nvim/lua/treesitter.h" 31 #include "nvim/macros_defs.h" 32 #include "nvim/map_defs.h" 33 #include "nvim/memline.h" 34 #include "nvim/memory.h" 35 #include "nvim/pos_defs.h" 36 #include "nvim/strings.h" 37 #include "nvim/types_defs.h" 38 39 #define TS_META_PARSER "treesitter_parser" 40 #define TS_META_TREE "treesitter_tree" 41 #define TS_META_NODE "treesitter_node" 42 #define TS_META_QUERY "treesitter_query" 43 #define TS_META_QUERYCURSOR "treesitter_querycursor" 44 #define TS_META_QUERYMATCH "treesitter_querymatch" 45 46 typedef struct { 47 LuaRef cb; 48 lua_State *lstate; 49 bool lex; 50 bool parse; 51 } TSLuaLoggerOpts; 52 53 typedef struct { 54 // We derive TSNode's, TSQueryCursor's, etc., from the TSTree, so it must not be mutated. 55 const TSTree *tree; 56 } TSLuaTree; 57 58 typedef struct { 59 uint64_t parse_start_time; 60 uint64_t timeout_threshold_ns; 61 } TSLuaParserCallbackPayload; 62 63 #include "lua/treesitter.c.generated.h" 64 65 static PMap(cstr_t) langs = MAP_INIT; 66 67 #ifdef HAVE_WASMTIME 68 static wasm_engine_t *wasmengine; 69 static TSWasmStore *ts_wasmstore; 70 #endif 71 72 // TSLanguage 73 74 static int tslua_has_language(lua_State *L) 75 { 76 const char *lang_name = luaL_checkstring(L, 1); 77 lua_pushboolean(L, map_has(cstr_t, &langs, lang_name)); 78 return 1; 79 } 80 81 #ifdef HAVE_WASMTIME 82 static char *read_file(const char *path, size_t *len) 83 FUNC_ATTR_MALLOC 84 { 85 FILE *file = os_fopen(path, "r"); 86 if (file == NULL) { 87 return NULL; 88 } 89 fseek(file, 0L, SEEK_END); 90 *len = (size_t)ftell(file); 91 fseek(file, 0L, SEEK_SET); 92 char *data = xmalloc(*len); 93 if (fread(data, *len, 1, file) != 1) { 94 xfree(data); 95 fclose(file); 96 return NULL; 97 } 98 fclose(file); 99 return data; 100 } 101 102 static const char *wasmerr_to_str(TSWasmErrorKind werr) 103 { 104 switch (werr) { 105 case TSWasmErrorKindParse: 106 return "PARSE"; 107 case TSWasmErrorKindCompile: 108 return "COMPILE"; 109 case TSWasmErrorKindInstantiate: 110 return "INSTANTIATE"; 111 case TSWasmErrorKindAllocate: 112 return "ALLOCATE"; 113 default: 114 return "UNKNOWN"; 115 } 116 } 117 #endif 118 119 #ifdef HAVE_WASMTIME 120 static int tslua_add_language_from_wasm(lua_State *L) 121 { 122 return add_language(L, true); 123 } 124 #endif 125 126 // Creates the language into the internal language map. 127 // 128 // Returns true if the language is correctly loaded in the language map 129 static int tslua_add_language_from_object(lua_State *L) 130 { 131 return add_language(L, false); 132 } 133 134 static const TSLanguage *load_language_from_object(lua_State *L, const char *path, 135 const char *lang_name, const char *symbol) 136 { 137 uv_lib_t lib; 138 if (uv_dlopen(path, &lib)) { 139 xstrlcpy(IObuff, uv_dlerror(&lib), sizeof(IObuff)); 140 uv_dlclose(&lib); 141 luaL_error(L, "Failed to load parser for language '%s': uv_dlopen: %s", lang_name, IObuff); 142 } 143 144 char symbol_buf[128]; 145 snprintf(symbol_buf, sizeof(symbol_buf), "tree_sitter_%s", symbol); 146 147 TSLanguage *(*lang_parser)(void); 148 if (uv_dlsym(&lib, symbol_buf, (void **)&lang_parser)) { 149 xstrlcpy(IObuff, uv_dlerror(&lib), sizeof(IObuff)); 150 uv_dlclose(&lib); 151 luaL_error(L, "Failed to load parser: uv_dlsym: %s", IObuff); 152 } 153 154 TSLanguage *lang = lang_parser(); 155 156 if (lang == NULL) { 157 uv_dlclose(&lib); 158 luaL_error(L, "Failed to load parser %s: internal error", path); 159 } 160 161 return lang; 162 } 163 164 static const TSLanguage *load_language_from_wasm(lua_State *L, const char *path, 165 const char *lang_name) 166 { 167 #ifndef HAVE_WASMTIME 168 luaL_error(L, "Not supported"); 169 return NULL; 170 #else 171 if (wasmengine == NULL) { 172 wasmengine = wasm_engine_new(); 173 } 174 assert(wasmengine != NULL); 175 176 TSWasmError werr = { 0 }; 177 if (ts_wasmstore == NULL) { 178 ts_wasmstore = ts_wasm_store_new(wasmengine, &werr); 179 } 180 181 if (werr.kind > 0) { 182 luaL_error(L, "Failed to create WASM store: (%s) %s", wasmerr_to_str(werr.kind), werr.message); 183 } 184 185 size_t file_size = 0; 186 char *data = read_file(path, &file_size); 187 188 if (data == NULL) { 189 luaL_error(L, "Unable to read file", path); 190 } 191 192 const TSLanguage *lang = ts_wasm_store_load_language(ts_wasmstore, lang_name, data, 193 (uint32_t)file_size, &werr); 194 195 xfree(data); 196 197 if (werr.kind > 0) { 198 luaL_error(L, "Failed to load WASM parser %s: (%s) %s", path, wasmerr_to_str(werr.kind), 199 werr.message); 200 } 201 202 if (lang == NULL) { 203 luaL_error(L, "Failed to load parser %s: internal error", path); 204 } 205 206 return lang; 207 #endif 208 } 209 210 static int add_language(lua_State *L, bool is_wasm) 211 { 212 const char *path = luaL_checkstring(L, 1); 213 const char *lang_name = luaL_checkstring(L, 2); 214 const char *symbol_name = lang_name; 215 216 if (!is_wasm && lua_gettop(L) >= 3 && !lua_isnil(L, 3)) { 217 symbol_name = luaL_checkstring(L, 3); 218 } 219 220 if (map_has(cstr_t, &langs, lang_name)) { 221 lua_pushboolean(L, true); 222 return 1; 223 } 224 225 const TSLanguage *lang = is_wasm 226 ? load_language_from_wasm(L, path, lang_name) 227 : load_language_from_object(L, path, lang_name, symbol_name); 228 229 uint32_t lang_version = ts_language_abi_version(lang); 230 if (lang_version < TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION 231 || lang_version > TREE_SITTER_LANGUAGE_VERSION) { 232 return luaL_error(L, 233 "ABI version mismatch for %s: supported between %d and %d, found %d", 234 path, 235 TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION, 236 TREE_SITTER_LANGUAGE_VERSION, lang_version); 237 } 238 239 pmap_put(cstr_t)(&langs, xstrdup(lang_name), (TSLanguage *)lang); 240 241 lua_pushboolean(L, true); 242 return 1; 243 } 244 245 static int tslua_remove_lang(lua_State *L) 246 { 247 const char *lang_name = luaL_checkstring(L, 1); 248 bool present = map_has(cstr_t, &langs, lang_name); 249 if (present) { 250 cstr_t key; 251 pmap_del(cstr_t)(&langs, lang_name, &key); 252 xfree((void *)key); 253 } 254 lua_pushboolean(L, present); 255 return 1; 256 } 257 258 static TSLanguage *lang_check(lua_State *L, int index) 259 { 260 const char *lang_name = luaL_checkstring(L, index); 261 TSLanguage *lang = pmap_get(cstr_t)(&langs, lang_name); 262 if (!lang) { 263 luaL_error(L, "no such language: %s", lang_name); 264 } 265 return lang; 266 } 267 268 static int tslua_inspect_lang(lua_State *L) 269 { 270 TSLanguage *lang = lang_check(L, 1); 271 272 lua_createtable(L, 0, 2); // [retval] 273 274 { // Symbols 275 uint32_t nsymbols = ts_language_symbol_count(lang); 276 assert(nsymbols < INT_MAX); 277 278 lua_createtable(L, (int)(nsymbols - 1), 1); // [retval, symbols] 279 for (uint32_t i = 0; i < nsymbols; i++) { 280 TSSymbolType t = ts_language_symbol_type(lang, (TSSymbol)i); 281 if (t == TSSymbolTypeAuxiliary) { 282 // not used by the API 283 continue; 284 } 285 const char *name = ts_language_symbol_name(lang, (TSSymbol)i); 286 bool named = t != TSSymbolTypeAnonymous; 287 lua_pushboolean(L, named); // [retval, symbols, is_named] 288 if (!named) { 289 char buf[256]; 290 snprintf(buf, sizeof(buf), "\"%s\"", name); 291 lua_setfield(L, -2, buf); // [retval, symbols] 292 } else { 293 lua_setfield(L, -2, name); // [retval, symbols] 294 } 295 } 296 297 lua_setfield(L, -2, "symbols"); // [retval] 298 } 299 300 { // Fields 301 uint32_t nfields = ts_language_field_count(lang); 302 lua_createtable(L, (int)nfields, 1); // [retval, fields] 303 // Field IDs go from 1 to nfields inclusive (extra index 0 maps to NULL) 304 for (uint32_t i = 1; i <= nfields; i++) { 305 lua_pushstring(L, ts_language_field_name_for_id(lang, (TSFieldId)i)); 306 lua_rawseti(L, -2, (int)i); // [retval, fields] 307 } 308 309 lua_setfield(L, -2, "fields"); // [retval] 310 } 311 312 lua_pushboolean(L, ts_language_is_wasm(lang)); 313 lua_setfield(L, -2, "_wasm"); 314 315 lua_pushinteger(L, ts_language_abi_version(lang)); // [retval, version] 316 lua_setfield(L, -2, "abi_version"); 317 318 { // Metadata 319 const TSLanguageMetadata *meta = ts_language_metadata(lang); 320 321 if (meta != NULL) { 322 lua_createtable(L, 0, 3); 323 324 lua_pushinteger(L, meta->major_version); 325 lua_setfield(L, -2, "major_version"); 326 lua_pushinteger(L, meta->minor_version); 327 lua_setfield(L, -2, "minor_version"); 328 lua_pushinteger(L, meta->patch_version); 329 lua_setfield(L, -2, "patch_version"); 330 331 lua_setfield(L, -2, "metadata"); 332 } 333 } 334 335 lua_pushinteger(L, ts_language_state_count(lang)); 336 lua_setfield(L, -2, "state_count"); 337 338 { // Supertypes 339 uint32_t nsupertypes; 340 const TSSymbol *supertypes = ts_language_supertypes(lang, &nsupertypes); 341 342 lua_createtable(L, 0, (int)nsupertypes); // [retval, supertypes] 343 for (uint32_t i = 0; i < nsupertypes; i++) { 344 const TSSymbol supertype = *(supertypes + i); 345 346 uint32_t nsubtypes; 347 const TSSymbol *subtypes = ts_language_subtypes(lang, supertype, &nsubtypes); 348 349 lua_createtable(L, (int)nsubtypes, 0); 350 for (uint32_t j = 1; j <= nsubtypes; j++) { 351 lua_pushstring(L, ts_language_symbol_name(lang, *(subtypes + j))); 352 lua_rawseti(L, -2, (int)j); 353 } 354 355 lua_setfield(L, -2, ts_language_symbol_name(lang, supertype)); 356 } 357 358 lua_setfield(L, -2, "supertypes"); // [retval] 359 } 360 361 return 1; 362 } 363 364 // TSParser 365 366 static struct luaL_Reg parser_meta[] = { 367 { "__gc", parser_gc }, 368 { "__tostring", parser_tostring }, 369 { "parse", parser_parse }, 370 { "reset", parser_reset }, 371 { "set_included_ranges", parser_set_ranges }, 372 { "included_ranges", parser_get_ranges }, 373 { "_set_logger", parser_set_logger }, 374 { "_logger", parser_get_logger }, 375 { NULL, NULL } 376 }; 377 378 static int tslua_push_parser(lua_State *L) 379 { 380 TSLanguage *lang = lang_check(L, 1); 381 382 TSParser **parser = lua_newuserdata(L, sizeof(TSParser *)); 383 *parser = ts_parser_new(); 384 385 #ifdef HAVE_WASMTIME 386 if (ts_language_is_wasm(lang)) { 387 assert(wasmengine != NULL); 388 ts_parser_set_wasm_store(*parser, ts_wasmstore); 389 } 390 #endif 391 392 if (!ts_parser_set_language(*parser, lang)) { 393 ts_parser_delete(*parser); 394 const char *lang_name = luaL_checkstring(L, 1); 395 return luaL_error(L, "Failed to load language : %s", lang_name); 396 } 397 398 lua_getfield(L, LUA_REGISTRYINDEX, TS_META_PARSER); // [udata, meta] 399 lua_setmetatable(L, -2); // [udata] 400 return 1; 401 } 402 403 static TSParser *parser_check(lua_State *L, uint16_t index) 404 { 405 TSParser **ud = luaL_checkudata(L, index, TS_META_PARSER); 406 luaL_argcheck(L, *ud, index, "TSParser expected"); 407 return *ud; 408 } 409 410 static void logger_gc(TSLogger logger) 411 { 412 if (!logger.log) { 413 return; 414 } 415 416 TSLuaLoggerOpts *opts = (TSLuaLoggerOpts *)logger.payload; 417 luaL_unref(opts->lstate, LUA_REGISTRYINDEX, opts->cb); 418 xfree(opts); 419 } 420 421 static int parser_gc(lua_State *L) 422 { 423 TSParser *p = parser_check(L, 1); 424 logger_gc(ts_parser_logger(p)); 425 ts_parser_delete(p); 426 return 0; 427 } 428 429 static int parser_tostring(lua_State *L) 430 { 431 lua_pushstring(L, "<parser>"); 432 return 1; 433 } 434 435 static const char *input_cb(void *payload, uint32_t byte_index, TSPoint position, 436 uint32_t *bytes_read) 437 { 438 buf_T *bp = payload; 439 #define BUFSIZE 256 440 static char buf[BUFSIZE]; 441 442 if ((linenr_T)position.row >= bp->b_ml.ml_line_count) { 443 *bytes_read = 0; 444 return ""; 445 } 446 linenr_T lnum = (linenr_T)position.row + 1; 447 char *line = ml_get_buf(bp, lnum); 448 size_t len = (size_t)ml_get_buf_len(bp, lnum); 449 if (position.column > len) { 450 *bytes_read = 0; 451 return ""; 452 } 453 size_t tocopy = MIN(len - position.column, BUFSIZE); 454 455 memcpy(buf, line + position.column, tocopy); 456 // Translate embedded \n to NUL 457 memchrsub(buf, '\n', NUL, tocopy); 458 *bytes_read = (uint32_t)tocopy; 459 if (tocopy < BUFSIZE) { 460 // now add the final \n, if it is meant to be present for this buffer. If it didn't fit, 461 // input_cb will be called again on the same line with advanced column. 462 if (lnum != bp->b_ml.ml_line_count || (!bp->b_p_bin && bp->b_p_fixeol) 463 || (lnum != bp->b_no_eol_lnum && bp->b_p_eol)) { 464 buf[tocopy] = '\n'; 465 (*bytes_read)++; 466 } 467 } 468 return buf; 469 #undef BUFSIZE 470 } 471 472 static void push_ranges(lua_State *L, const TSRange *ranges, const size_t length, 473 bool include_bytes) 474 { 475 lua_createtable(L, (int)length, 0); 476 for (size_t i = 0; i < length; i++) { 477 lua_createtable(L, include_bytes ? 6 : 4, 0); 478 int j = 1; 479 lua_pushinteger(L, ranges[i].start_point.row); 480 lua_rawseti(L, -2, j++); 481 lua_pushinteger(L, ranges[i].start_point.column); 482 lua_rawseti(L, -2, j++); 483 if (include_bytes) { 484 lua_pushinteger(L, ranges[i].start_byte); 485 lua_rawseti(L, -2, j++); 486 } 487 lua_pushinteger(L, ranges[i].end_point.row); 488 lua_rawseti(L, -2, j++); 489 lua_pushinteger(L, ranges[i].end_point.column); 490 lua_rawseti(L, -2, j++); 491 if (include_bytes) { 492 lua_pushinteger(L, ranges[i].end_byte); 493 lua_rawseti(L, -2, j++); 494 } 495 496 lua_rawseti(L, -2, (int)(i + 1)); 497 } 498 } 499 500 static bool on_parser_progress(TSParseState *state) 501 { 502 TSLuaParserCallbackPayload *payload = state->payload; 503 uint64_t parse_time = os_hrtime() - payload->parse_start_time; 504 return parse_time >= payload->timeout_threshold_ns; 505 } 506 507 static int parser_parse(lua_State *L) 508 { 509 TSParser *p = parser_check(L, 1); 510 const TSTree *old_tree = NULL; 511 if (!lua_isnil(L, 2)) { 512 TSLuaTree *ud = luaL_checkudata(L, 2, TS_META_TREE); 513 old_tree = ud ? ud->tree : NULL; 514 } 515 516 TSTree *new_tree = NULL; 517 size_t len; 518 const char *str; 519 handle_T bufnr; 520 buf_T *buf; 521 TSInput input; 522 523 // This switch is necessary because of the behavior of lua_isstring, that 524 // consider numbers as strings... 525 switch (lua_type(L, 3)) { 526 case LUA_TSTRING: 527 str = lua_tolstring(L, 3, &len); 528 new_tree = ts_parser_parse_string(p, old_tree, str, (uint32_t)len); 529 break; 530 531 case LUA_TNUMBER: 532 bufnr = (handle_T)lua_tointeger(L, 3); 533 buf = handle_get_buffer(bufnr); 534 535 if (!buf) { 536 #define BUFSIZE 256 537 char ebuf[BUFSIZE] = { 0 }; 538 vim_snprintf(ebuf, BUFSIZE, "invalid buffer handle: %d", bufnr); 539 return luaL_argerror(L, 3, ebuf); 540 #undef BUFSIZE 541 } 542 543 input = (TSInput){ (void *)buf, input_cb, TSInputEncodingUTF8, NULL }; 544 if (!lua_isnil(L, 5)) { 545 uint64_t timeout_ns = (uint64_t)lua_tointeger(L, 5); 546 TSLuaParserCallbackPayload payload = 547 (TSLuaParserCallbackPayload){ .parse_start_time = os_hrtime(), 548 .timeout_threshold_ns = timeout_ns }; 549 TSParseOptions parse_options = { .payload = &payload, 550 .progress_callback = on_parser_progress }; 551 new_tree = ts_parser_parse_with_options(p, old_tree, input, parse_options); 552 } else { 553 new_tree = ts_parser_parse(p, old_tree, input); 554 } 555 556 break; 557 558 default: 559 return luaL_argerror(L, 3, "expected either string or buffer handle"); 560 } 561 562 bool include_bytes = (lua_gettop(L) >= 4) && lua_toboolean(L, 4); 563 564 if (!new_tree) { 565 // Sometimes parsing fails (no language was set, or it was set to one with an incompatible ABI) 566 // In those cases, just return an error. 567 if (!ts_parser_language(p)) { 568 return luaL_error(L, "Language was unset, or has an incompatible ABI."); 569 } 570 return 0; 571 } 572 573 // The new tree will be pushed to the stack, without copy, ownership is now to the lua GC. 574 // Old tree is owned by lua GC since before 575 uint32_t n_ranges = 0; 576 TSRange *changed = old_tree ? ts_tree_get_changed_ranges(old_tree, new_tree, &n_ranges) 577 : ts_tree_included_ranges(new_tree, &n_ranges); 578 579 push_tree(L, new_tree); // [tree] 580 581 push_ranges(L, changed, n_ranges, include_bytes); // [tree, ranges] 582 583 xfree(changed); 584 return 2; 585 } 586 587 static int parser_reset(lua_State *L) 588 { 589 TSParser *p = parser_check(L, 1); 590 ts_parser_reset(p); 591 return 0; 592 } 593 594 static void range_err(lua_State *L) 595 { 596 luaL_error(L, "Ranges can only be made from 6 element long tables or nodes."); 597 } 598 599 // Use the top of the stack (without popping it) to create a TSRange, it can be 600 // either a lua table or a TSNode 601 static void range_from_lua(lua_State *L, TSRange *range) 602 { 603 TSNode node; 604 605 if (lua_istable(L, -1)) { 606 // should be a table of 6 elements 607 if (lua_objlen(L, -1) != 6) { 608 range_err(L); 609 } 610 611 lua_rawgeti(L, -1, 1); // [ range, start_row] 612 uint32_t start_row = (uint32_t)luaL_checkinteger(L, -1); 613 lua_pop(L, 1); 614 615 lua_rawgeti(L, -1, 2); // [ range, start_col] 616 uint32_t start_col = (uint32_t)luaL_checkinteger(L, -1); 617 lua_pop(L, 1); 618 619 lua_rawgeti(L, -1, 3); // [ range, start_byte] 620 uint32_t start_byte = (uint32_t)luaL_checkinteger(L, -1); 621 lua_pop(L, 1); 622 623 lua_rawgeti(L, -1, 4); // [ range, end_row] 624 uint32_t end_row = (uint32_t)luaL_checkinteger(L, -1); 625 lua_pop(L, 1); 626 627 lua_rawgeti(L, -1, 5); // [ range, end_col] 628 uint32_t end_col = (uint32_t)luaL_checkinteger(L, -1); 629 lua_pop(L, 1); 630 631 lua_rawgeti(L, -1, 6); // [ range, end_byte] 632 uint32_t end_byte = (uint32_t)luaL_checkinteger(L, -1); 633 lua_pop(L, 1); // [ range ] 634 635 *range = (TSRange) { 636 .start_point = (TSPoint) { 637 .row = start_row, 638 .column = start_col 639 }, 640 .end_point = (TSPoint) { 641 .row = end_row, 642 .column = end_col 643 }, 644 .start_byte = start_byte, 645 .end_byte = end_byte, 646 }; 647 } else if (node_check_opt(L, -1, &node)) { 648 *range = (TSRange) { 649 .start_point = ts_node_start_point(node), 650 .end_point = ts_node_end_point(node), 651 .start_byte = ts_node_start_byte(node), 652 .end_byte = ts_node_end_byte(node) 653 }; 654 } else { 655 range_err(L); 656 } 657 } 658 659 static int parser_set_ranges(lua_State *L) 660 { 661 if (lua_gettop(L) < 2) { 662 return luaL_error(L, "not enough args to parser:set_included_ranges()"); 663 } 664 665 TSParser *p = parser_check(L, 1); 666 667 luaL_argcheck(L, lua_istable(L, 2), 2, "table expected."); 668 669 size_t tbl_len = lua_objlen(L, 2); 670 TSRange *ranges = xmalloc(sizeof(TSRange) * tbl_len); 671 672 // [ parser, ranges ] 673 for (size_t index = 0; index < tbl_len; index++) { 674 lua_rawgeti(L, 2, (int)index + 1); // [ parser, ranges, range ] 675 range_from_lua(L, ranges + index); 676 lua_pop(L, 1); 677 } 678 679 // This memcpies ranges, thus we can free it afterwards 680 ts_parser_set_included_ranges(p, ranges, (uint32_t)tbl_len); 681 xfree(ranges); 682 683 return 0; 684 } 685 686 static int parser_get_ranges(lua_State *L) 687 { 688 TSParser *p = parser_check(L, 1); 689 690 bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2); 691 692 uint32_t len; 693 const TSRange *ranges = ts_parser_included_ranges(p, &len); 694 695 push_ranges(L, ranges, len, include_bytes); 696 return 1; 697 } 698 699 static void logger_cb(void *payload, TSLogType logtype, const char *s) 700 { 701 TSLuaLoggerOpts *opts = (TSLuaLoggerOpts *)payload; 702 if ((!opts->lex && logtype == TSLogTypeLex) 703 || (!opts->parse && logtype == TSLogTypeParse)) { 704 return; 705 } 706 707 lua_State *lstate = opts->lstate; 708 709 lua_rawgeti(lstate, LUA_REGISTRYINDEX, opts->cb); 710 lua_pushstring(lstate, logtype == TSLogTypeParse ? "parse" : "lex"); 711 lua_pushstring(lstate, s); 712 if (lua_pcall(lstate, 2, 0, 0)) { 713 luaL_error(lstate, "treesitter logger callback failed"); 714 } 715 } 716 717 static int parser_set_logger(lua_State *L) 718 { 719 TSParser *p = parser_check(L, 1); 720 721 luaL_argcheck(L, lua_isboolean(L, 2), 2, "boolean expected"); 722 luaL_argcheck(L, lua_isboolean(L, 3), 3, "boolean expected"); 723 luaL_argcheck(L, lua_isfunction(L, 4), 4, "function expected"); 724 725 TSLuaLoggerOpts *opts = xmalloc(sizeof(TSLuaLoggerOpts)); 726 lua_pushvalue(L, 4); 727 LuaRef ref = luaL_ref(L, LUA_REGISTRYINDEX); 728 729 *opts = (TSLuaLoggerOpts){ 730 .lex = lua_toboolean(L, 2), 731 .parse = lua_toboolean(L, 3), 732 .cb = ref, 733 .lstate = L 734 }; 735 736 TSLogger logger = { 737 .payload = (void *)opts, 738 .log = logger_cb 739 }; 740 741 ts_parser_set_logger(p, logger); 742 return 0; 743 } 744 745 static int parser_get_logger(lua_State *L) 746 { 747 TSParser *p = parser_check(L, 1); 748 TSLogger logger = ts_parser_logger(p); 749 if (logger.log) { 750 TSLuaLoggerOpts *opts = (TSLuaLoggerOpts *)logger.payload; 751 lua_rawgeti(L, LUA_REGISTRYINDEX, opts->cb); 752 } else { 753 lua_pushnil(L); 754 } 755 756 return 1; 757 } 758 759 // TSTree 760 761 static struct luaL_Reg tree_meta[] = { 762 { "__gc", tree_gc }, 763 { "__tostring", tree_tostring }, 764 { "root", tree_root }, 765 { "edit", tree_edit }, 766 { "included_ranges", tree_get_ranges }, 767 { "copy", tree_copy }, 768 { NULL, NULL } 769 }; 770 771 /// Push tree interface on to the lua stack. 772 /// 773 /// The tree is not copied. Ownership of the tree is transferred from C to 774 /// Lua. If needed use ts_tree_copy() in the caller. 775 static void push_tree(lua_State *L, const TSTree *tree) 776 { 777 if (tree == NULL) { 778 lua_pushnil(L); 779 return; 780 } 781 782 TSLuaTree *ud = lua_newuserdata(L, sizeof(TSLuaTree)); // [udata] 783 ud->tree = tree; 784 lua_getfield(L, LUA_REGISTRYINDEX, TS_META_TREE); // [udata, meta] 785 lua_setmetatable(L, -2); // [udata] 786 } 787 788 static int tree_copy(lua_State *L) 789 { 790 TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE); 791 TSTree *copy = ts_tree_copy(ud->tree); 792 push_tree(L, copy); // [tree] 793 794 return 1; 795 } 796 797 static int tree_edit(lua_State *L) 798 { 799 if (lua_gettop(L) < 10) { 800 lua_pushstring(L, "not enough args to tree:edit()"); 801 return lua_error(L); 802 } 803 804 TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE); 805 806 uint32_t start_byte = (uint32_t)luaL_checkint(L, 2); 807 uint32_t old_end_byte = (uint32_t)luaL_checkint(L, 3); 808 uint32_t new_end_byte = (uint32_t)luaL_checkint(L, 4); 809 TSPoint start_point = { (uint32_t)luaL_checkint(L, 5), (uint32_t)luaL_checkint(L, 6) }; 810 TSPoint old_end_point = { (uint32_t)luaL_checkint(L, 7), (uint32_t)luaL_checkint(L, 8) }; 811 TSPoint new_end_point = { (uint32_t)luaL_checkint(L, 9), (uint32_t)luaL_checkint(L, 10) }; 812 813 TSInputEdit edit = { start_byte, old_end_byte, new_end_byte, 814 start_point, old_end_point, new_end_point }; 815 816 TSTree *new_tree = ts_tree_copy(ud->tree); 817 ts_tree_edit(new_tree, &edit); 818 819 push_tree(L, new_tree); // [tree] 820 821 return 1; 822 } 823 824 static int tree_get_ranges(lua_State *L) 825 { 826 TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE); 827 828 bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2); 829 830 uint32_t len; 831 TSRange *ranges = ts_tree_included_ranges(ud->tree, &len); 832 833 push_ranges(L, ranges, len, include_bytes); 834 835 xfree(ranges); 836 return 1; 837 } 838 839 static int tree_gc(lua_State *L) 840 { 841 TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE); 842 843 // SAFETY: we can cast the const away because the tree is only garbage collected after all of its 844 // TSNode's, TSQuerCurors, etc., are unreachable (each contains a reference to the TSLuaTree) 845 TSTree *tree = (TSTree *)ud->tree; 846 847 ts_tree_delete(tree); 848 return 0; 849 } 850 851 static int tree_tostring(lua_State *L) 852 { 853 lua_pushstring(L, "<tree>"); 854 return 1; 855 } 856 857 static int tree_root(lua_State *L) 858 { 859 TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE); 860 861 TSNode root = ts_tree_root_node(ud->tree); 862 863 TSNode *node_ud = lua_newuserdata(L, sizeof(TSNode)); // [node] 864 *node_ud = root; 865 lua_getfield(L, LUA_REGISTRYINDEX, TS_META_NODE); // [node, meta] 866 lua_setmetatable(L, -2); // [node] 867 868 // To prevent the tree from being garbage collected, create a reference to it 869 // in the fenv which will be passed to userdata nodes of the tree. 870 // Note: environments (fenvs) associated with userdata have no meaning in Lua 871 // and are only used to associate a table. 872 lua_createtable(L, 1, 0); // [node, reftable] 873 lua_pushvalue(L, 1); // [node, reftable, tree] 874 lua_rawseti(L, -2, 1); // [node, reftable] 875 lua_setfenv(L, -2); // [node] 876 877 return 1; 878 } 879 880 // TSNode 881 static struct luaL_Reg node_meta[] = { 882 { "__tostring", node_tostring }, 883 { "__eq", node_eq }, 884 { "__len", node_child_count }, 885 { "id", node_id }, 886 { "range", node_range }, 887 { "start", node_start }, 888 { "end_", node_end }, 889 { "type", node_type }, 890 { "symbol", node_symbol }, 891 { "field", node_field }, 892 { "named", node_named }, 893 { "missing", node_missing }, 894 { "extra", node_extra }, 895 { "has_changes", node_has_changes }, 896 { "has_error", node_has_error }, 897 { "sexpr", node_sexpr }, 898 { "child_count", node_child_count }, 899 { "named_child_count", node_named_child_count }, 900 { "child", node_child }, 901 { "named_child", node_named_child }, 902 { "descendant_for_range", node_descendant_for_range }, 903 { "named_descendant_for_range", node_named_descendant_for_range }, 904 { "parent", node_parent }, 905 { "__has_ancestor", __has_ancestor }, 906 { "child_with_descendant", node_child_with_descendant }, 907 { "iter_children", node_iter_children }, 908 { "next_sibling", node_next_sibling }, 909 { "prev_sibling", node_prev_sibling }, 910 { "next_named_sibling", node_next_named_sibling }, 911 { "prev_named_sibling", node_prev_named_sibling }, 912 { "named_children", node_named_children }, 913 { "root", node_root }, 914 { "tree", node_tree }, 915 { "byte_length", node_byte_length }, 916 { "equal", node_equal }, 917 918 { NULL, NULL } 919 }; 920 921 /// Push node interface on to the Lua stack 922 /// 923 /// Stack at `uindex` must have a value with a fenv with a reference to node's 924 /// tree. This value is not popped. Can only be called inside a cfunction with 925 /// the tslua environment. 926 static void push_node(lua_State *L, TSNode node, int uindex) 927 { 928 assert(uindex > 0 || uindex < -LUA_MINSTACK); 929 if (ts_node_is_null(node)) { 930 lua_pushnil(L); // [nil] 931 return; 932 } 933 934 TSNode *ud = lua_newuserdata(L, sizeof(TSNode)); // [udata] 935 *ud = node; 936 lua_getfield(L, LUA_REGISTRYINDEX, TS_META_NODE); // [udata, meta] 937 lua_setmetatable(L, -2); // [udata] 938 939 // Copy the fenv to keep alive a reference to the node's tree. 940 lua_getfenv(L, uindex); // [udata, reftable] 941 lua_setfenv(L, -2); // [udata] 942 } 943 944 static bool node_check_opt(lua_State *L, int index, TSNode *res) 945 { 946 TSNode *ud = luaL_checkudata(L, index, TS_META_NODE); 947 if (ud) { 948 *res = *ud; 949 return true; 950 } 951 return false; 952 } 953 954 static TSNode node_check(lua_State *L, int index) 955 { 956 TSNode *ud = luaL_checkudata(L, index, TS_META_NODE); 957 return *ud; 958 } 959 960 static int node_tostring(lua_State *L) 961 { 962 TSNode node = node_check(L, 1); 963 lua_pushstring(L, "<node "); 964 lua_pushstring(L, ts_node_type(node)); 965 lua_pushstring(L, ">"); 966 lua_concat(L, 3); 967 return 1; 968 } 969 970 static int node_eq(lua_State *L) 971 { 972 TSNode node = node_check(L, 1); 973 TSNode node2 = node_check(L, 2); 974 lua_pushboolean(L, ts_node_eq(node, node2)); 975 return 1; 976 } 977 978 static int node_id(lua_State *L) 979 { 980 TSNode node = node_check(L, 1); 981 lua_pushlstring(L, (const char *)&node.id, sizeof node.id); 982 return 1; 983 } 984 985 static int node_range(lua_State *L) 986 { 987 TSNode node = node_check(L, 1); 988 989 bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2); 990 991 TSPoint start = ts_node_start_point(node); 992 TSPoint end = ts_node_end_point(node); 993 994 if (include_bytes) { 995 lua_pushinteger(L, start.row); 996 lua_pushinteger(L, start.column); 997 lua_pushinteger(L, ts_node_start_byte(node)); 998 lua_pushinteger(L, end.row); 999 lua_pushinteger(L, end.column); 1000 lua_pushinteger(L, ts_node_end_byte(node)); 1001 return 6; 1002 } 1003 1004 lua_pushinteger(L, start.row); 1005 lua_pushinteger(L, start.column); 1006 lua_pushinteger(L, end.row); 1007 lua_pushinteger(L, end.column); 1008 return 4; 1009 } 1010 1011 static int node_start(lua_State *L) 1012 { 1013 TSNode node = node_check(L, 1); 1014 TSPoint start = ts_node_start_point(node); 1015 uint32_t start_byte = ts_node_start_byte(node); 1016 lua_pushinteger(L, start.row); 1017 lua_pushinteger(L, start.column); 1018 lua_pushinteger(L, start_byte); 1019 return 3; 1020 } 1021 1022 static int node_end(lua_State *L) 1023 { 1024 TSNode node = node_check(L, 1); 1025 TSPoint end = ts_node_end_point(node); 1026 uint32_t end_byte = ts_node_end_byte(node); 1027 lua_pushinteger(L, end.row); 1028 lua_pushinteger(L, end.column); 1029 lua_pushinteger(L, end_byte); 1030 return 3; 1031 } 1032 1033 static int node_child_count(lua_State *L) 1034 { 1035 TSNode node = node_check(L, 1); 1036 uint32_t count = ts_node_child_count(node); 1037 lua_pushinteger(L, count); 1038 return 1; 1039 } 1040 1041 static int node_named_child_count(lua_State *L) 1042 { 1043 TSNode node = node_check(L, 1); 1044 uint32_t count = ts_node_named_child_count(node); 1045 lua_pushinteger(L, count); 1046 return 1; 1047 } 1048 1049 static int node_type(lua_State *L) 1050 { 1051 TSNode node = node_check(L, 1); 1052 lua_pushstring(L, ts_node_type(node)); 1053 return 1; 1054 } 1055 1056 static int node_symbol(lua_State *L) 1057 { 1058 TSNode node = node_check(L, 1); 1059 TSSymbol symbol = ts_node_symbol(node); 1060 lua_pushinteger(L, symbol); 1061 return 1; 1062 } 1063 1064 static int node_field(lua_State *L) 1065 { 1066 TSNode node = node_check(L, 1); 1067 uint32_t count = ts_node_child_count(node); 1068 int curr_index = 0; 1069 1070 size_t name_len; 1071 const char *field_name = luaL_checklstring(L, 2, &name_len); 1072 1073 lua_newtable(L); 1074 1075 for (uint32_t i = 0; i < count; i++) { 1076 const char *child_field_name = ts_node_field_name_for_child(node, i); 1077 if (strequal(field_name, child_field_name)) { 1078 TSNode child = ts_node_child(node, i); 1079 push_node(L, child, 1); 1080 lua_rawseti(L, -2, ++curr_index); 1081 } 1082 } 1083 1084 return 1; 1085 } 1086 1087 static int node_named(lua_State *L) 1088 { 1089 TSNode node = node_check(L, 1); 1090 lua_pushboolean(L, ts_node_is_named(node)); 1091 return 1; 1092 } 1093 1094 static int node_sexpr(lua_State *L) 1095 { 1096 TSNode node = node_check(L, 1); 1097 char *allocated = ts_node_string(node); 1098 lua_pushstring(L, allocated); 1099 xfree(allocated); 1100 return 1; 1101 } 1102 1103 static int node_missing(lua_State *L) 1104 { 1105 TSNode node = node_check(L, 1); 1106 lua_pushboolean(L, ts_node_is_missing(node)); 1107 return 1; 1108 } 1109 1110 static int node_extra(lua_State *L) 1111 { 1112 TSNode node = node_check(L, 1); 1113 lua_pushboolean(L, ts_node_is_extra(node)); 1114 return 1; 1115 } 1116 1117 static int node_has_changes(lua_State *L) 1118 { 1119 TSNode node = node_check(L, 1); 1120 lua_pushboolean(L, ts_node_has_changes(node)); 1121 return 1; 1122 } 1123 1124 static int node_has_error(lua_State *L) 1125 { 1126 TSNode node = node_check(L, 1); 1127 lua_pushboolean(L, ts_node_has_error(node)); 1128 return 1; 1129 } 1130 1131 static int node_child(lua_State *L) 1132 { 1133 TSNode node = node_check(L, 1); 1134 uint32_t num = (uint32_t)lua_tointeger(L, 2); 1135 TSNode child = ts_node_child(node, num); 1136 1137 push_node(L, child, 1); 1138 return 1; 1139 } 1140 1141 static int node_named_child(lua_State *L) 1142 { 1143 TSNode node = node_check(L, 1); 1144 uint32_t num = (uint32_t)lua_tointeger(L, 2); 1145 TSNode child = ts_node_named_child(node, num); 1146 1147 push_node(L, child, 1); 1148 return 1; 1149 } 1150 1151 static int node_descendant_for_range(lua_State *L) 1152 { 1153 TSNode node = node_check(L, 1); 1154 TSPoint start = { (uint32_t)lua_tointeger(L, 2), 1155 (uint32_t)lua_tointeger(L, 3) }; 1156 TSPoint end = { (uint32_t)lua_tointeger(L, 4), 1157 (uint32_t)lua_tointeger(L, 5) }; 1158 TSNode child = ts_node_descendant_for_point_range(node, start, end); 1159 1160 push_node(L, child, 1); 1161 return 1; 1162 } 1163 1164 static int node_named_descendant_for_range(lua_State *L) 1165 { 1166 TSNode node = node_check(L, 1); 1167 TSPoint start = { (uint32_t)lua_tointeger(L, 2), 1168 (uint32_t)lua_tointeger(L, 3) }; 1169 TSPoint end = { (uint32_t)lua_tointeger(L, 4), 1170 (uint32_t)lua_tointeger(L, 5) }; 1171 TSNode child = ts_node_named_descendant_for_point_range(node, start, end); 1172 1173 push_node(L, child, 1); 1174 return 1; 1175 } 1176 1177 static int node_next_child(lua_State *L) 1178 { 1179 uint32_t *child_index = lua_touserdata(L, lua_upvalueindex(1)); 1180 TSNode source = node_check(L, lua_upvalueindex(2)); 1181 1182 if (*child_index >= ts_node_child_count(source)) { 1183 return 0; 1184 } 1185 1186 TSNode child = ts_node_child(source, *child_index); 1187 push_node(L, child, lua_upvalueindex(2)); 1188 1189 const char *field = ts_node_field_name_for_child(source, *child_index); 1190 if (field != NULL) { 1191 lua_pushstring(L, field); 1192 } else { 1193 lua_pushnil(L); 1194 } // [node, field_name_or_nil] 1195 1196 (*child_index)++; 1197 1198 return 2; 1199 } 1200 1201 static int node_iter_children(lua_State *L) 1202 { 1203 node_check(L, 1); 1204 uint32_t *child_index = lua_newuserdata(L, sizeof(uint32_t)); // [source_node,..., udata] 1205 *child_index = 0; 1206 1207 lua_pushvalue(L, 1); // [source_node, ..., udata, source_node] 1208 lua_pushcclosure(L, node_next_child, 2); 1209 1210 return 1; 1211 } 1212 1213 static int node_parent(lua_State *L) 1214 { 1215 TSNode node = node_check(L, 1); 1216 TSNode parent = ts_node_parent(node); 1217 push_node(L, parent, 1); 1218 return 1; 1219 } 1220 1221 static int __has_ancestor(lua_State *L) 1222 { 1223 TSNode descendant = node_check(L, 1); 1224 if (lua_type(L, 2) != LUA_TTABLE) { 1225 lua_pushboolean(L, false); 1226 return 1; 1227 } 1228 int const pred_len = (int)lua_objlen(L, 2); 1229 1230 TSNode node = ts_tree_root_node(descendant.tree); 1231 while (node.id != descendant.id && !ts_node_is_null(node)) { 1232 char const *node_type = ts_node_type(node); 1233 size_t node_type_len = strlen(node_type); 1234 1235 for (int i = 3; i <= pred_len; i++) { 1236 lua_rawgeti(L, 2, i); 1237 if (lua_type(L, -1) == LUA_TSTRING) { 1238 size_t check_len; 1239 char const *check_str = lua_tolstring(L, -1, &check_len); 1240 if (node_type_len == check_len && memcmp(node_type, check_str, check_len) == 0) { 1241 lua_pushboolean(L, true); 1242 return 1; 1243 } 1244 } 1245 lua_pop(L, 1); 1246 } 1247 1248 node = ts_node_child_with_descendant(node, descendant); 1249 } 1250 1251 lua_pushboolean(L, false); 1252 return 1; 1253 } 1254 1255 static int node_child_with_descendant(lua_State *L) 1256 { 1257 TSNode node = node_check(L, 1); 1258 TSNode descendant = node_check(L, 2); 1259 TSNode child = ts_node_child_with_descendant(node, descendant); 1260 push_node(L, child, 1); 1261 return 1; 1262 } 1263 1264 static int node_next_sibling(lua_State *L) 1265 { 1266 TSNode node = node_check(L, 1); 1267 TSNode sibling = ts_node_next_sibling(node); 1268 push_node(L, sibling, 1); 1269 return 1; 1270 } 1271 1272 static int node_prev_sibling(lua_State *L) 1273 { 1274 TSNode node = node_check(L, 1); 1275 TSNode sibling = ts_node_prev_sibling(node); 1276 push_node(L, sibling, 1); 1277 return 1; 1278 } 1279 1280 static int node_next_named_sibling(lua_State *L) 1281 { 1282 TSNode node = node_check(L, 1); 1283 TSNode sibling = ts_node_next_named_sibling(node); 1284 push_node(L, sibling, 1); 1285 return 1; 1286 } 1287 1288 static int node_prev_named_sibling(lua_State *L) 1289 { 1290 TSNode node = node_check(L, 1); 1291 TSNode sibling = ts_node_prev_named_sibling(node); 1292 push_node(L, sibling, 1); 1293 return 1; 1294 } 1295 1296 static int node_named_children(lua_State *L) 1297 { 1298 TSNode source = node_check(L, 1); 1299 1300 lua_newtable(L); 1301 int curr_index = 0; 1302 1303 uint32_t n = ts_node_child_count(source); 1304 for (uint32_t i = 0; i < n; i++) { 1305 TSNode child = ts_node_child(source, i); 1306 if (ts_node_is_named(child)) { 1307 push_node(L, child, 1); 1308 lua_rawseti(L, -2, ++curr_index); 1309 } 1310 } 1311 1312 return 1; 1313 } 1314 1315 static int node_root(lua_State *L) 1316 { 1317 TSNode node = node_check(L, 1); 1318 TSNode root = ts_tree_root_node(node.tree); 1319 push_node(L, root, 1); 1320 return 1; 1321 } 1322 1323 static int node_tree(lua_State *L) 1324 { 1325 node_check(L, 1); 1326 1327 // Get the tree from the node fenv. We cannot use `push_tree(node.tree)` here because that would 1328 // cause a double free. 1329 lua_getfenv(L, 1); // [node, reftable] 1330 lua_rawgeti(L, 2, 1); // [node, reftable, tree] 1331 1332 return 1; 1333 } 1334 1335 static int node_byte_length(lua_State *L) 1336 { 1337 TSNode node = node_check(L, 1); 1338 uint32_t start_byte = ts_node_start_byte(node); 1339 uint32_t end_byte = ts_node_end_byte(node); 1340 lua_pushinteger(L, end_byte - start_byte); 1341 return 1; 1342 } 1343 1344 static int node_equal(lua_State *L) 1345 { 1346 TSNode node1 = node_check(L, 1); 1347 TSNode node2 = node_check(L, 2); 1348 lua_pushboolean(L, ts_node_eq(node1, node2)); 1349 return 1; 1350 } 1351 1352 // TSQueryCursor 1353 1354 static struct luaL_Reg querycursor_meta[] = { 1355 { "remove_match", querycursor_remove_match }, 1356 { "next_capture", querycursor_next_capture }, 1357 { "next_match", querycursor_next_match }, 1358 { "__gc", querycursor_gc }, 1359 { NULL, NULL } 1360 }; 1361 1362 static int tslua_push_querycursor(lua_State *L) 1363 { 1364 TSNode node = node_check(L, 1); 1365 1366 TSQuery *query = query_check(L, 2); 1367 TSQueryCursor *cursor = ts_query_cursor_new(); 1368 1369 if (lua_gettop(L) >= 3 && !lua_isnil(L, 3)) { 1370 luaL_argcheck(L, lua_istable(L, 3), 3, "table expected"); 1371 } 1372 1373 lua_getfield(L, 3, "start_row"); 1374 uint32_t start_row = (uint32_t)luaL_checkinteger(L, -1); 1375 lua_pop(L, 1); 1376 1377 lua_getfield(L, 3, "start_col"); 1378 uint32_t start_col = (uint32_t)luaL_checkinteger(L, -1); 1379 lua_pop(L, 1); 1380 1381 lua_getfield(L, 3, "end_row"); 1382 uint32_t end_row = (uint32_t)luaL_checkinteger(L, -1); 1383 lua_pop(L, 1); 1384 1385 lua_getfield(L, 3, "end_col"); 1386 uint32_t end_col = (uint32_t)luaL_checkinteger(L, -1); 1387 lua_pop(L, 1); 1388 1389 ts_query_cursor_set_point_range(cursor, (TSPoint){ start_row, start_col }, 1390 (TSPoint){ end_row, end_col }); 1391 1392 lua_getfield(L, 3, "max_start_depth"); 1393 if (!lua_isnil(L, -1)) { 1394 uint32_t max_start_depth = (uint32_t)luaL_checkinteger(L, -1); 1395 ts_query_cursor_set_max_start_depth(cursor, max_start_depth); 1396 } 1397 lua_pop(L, 1); 1398 1399 lua_getfield(L, 3, "match_limit"); 1400 if (!lua_isnil(L, -1)) { 1401 uint32_t match_limit = (uint32_t)luaL_checkinteger(L, -1); 1402 ts_query_cursor_set_match_limit(cursor, match_limit); 1403 } 1404 lua_pop(L, 1); 1405 1406 ts_query_cursor_exec(cursor, query, node); 1407 1408 TSQueryCursor **ud = lua_newuserdata(L, sizeof(*ud)); // [node, query, ..., udata] 1409 *ud = cursor; 1410 lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYCURSOR); // [node, query, ..., udata, meta] 1411 lua_setmetatable(L, -2); // [node, query, ..., udata] 1412 1413 // Copy the fenv which contains the nodes tree. 1414 lua_getfenv(L, 1); // [udata, reftable] 1415 lua_setfenv(L, -2); // [udata] 1416 1417 return 1; 1418 } 1419 1420 static int querycursor_remove_match(lua_State *L) 1421 { 1422 TSQueryCursor *cursor = querycursor_check(L, 1); 1423 uint32_t match_id = (uint32_t)luaL_checkinteger(L, 2); 1424 ts_query_cursor_remove_match(cursor, match_id); 1425 return 0; 1426 } 1427 1428 static int querycursor_next_capture(lua_State *L) 1429 { 1430 TSQueryCursor *cursor = querycursor_check(L, 1); 1431 TSQueryMatch match; 1432 uint32_t capture_index; 1433 if (!ts_query_cursor_next_capture(cursor, &match, &capture_index)) { 1434 return 0; 1435 } 1436 1437 TSQueryCapture capture = match.captures[capture_index]; 1438 1439 // Handle capture quantifiers here 1440 lua_pushinteger(L, capture.index + 1); // [index] 1441 push_node(L, capture.node, 1); // [index, node] 1442 push_querymatch(L, &match, 1); 1443 1444 return 3; 1445 } 1446 1447 static int querycursor_next_match(lua_State *L) 1448 { 1449 TSQueryCursor *cursor = querycursor_check(L, 1); 1450 1451 TSQueryMatch match; 1452 if (!ts_query_cursor_next_match(cursor, &match)) { 1453 return 0; 1454 } 1455 1456 push_querymatch(L, &match, 1); 1457 1458 return 1; 1459 } 1460 1461 static TSQueryCursor *querycursor_check(lua_State *L, int index) 1462 { 1463 TSQueryCursor **ud = luaL_checkudata(L, index, TS_META_QUERYCURSOR); 1464 luaL_argcheck(L, *ud, index, "TSQueryCursor expected"); 1465 return *ud; 1466 } 1467 1468 static int querycursor_gc(lua_State *L) 1469 { 1470 TSQueryCursor *cursor = querycursor_check(L, 1); 1471 ts_query_cursor_delete(cursor); 1472 return 0; 1473 } 1474 1475 // TSQueryMatch 1476 1477 static struct luaL_Reg querymatch_meta[] = { 1478 { "info", querymatch_info }, 1479 { "captures", querymatch_captures }, 1480 { NULL, NULL } 1481 }; 1482 1483 static void push_querymatch(lua_State *L, TSQueryMatch *match, int uindex) 1484 { 1485 TSQueryMatch *ud = lua_newuserdata(L, sizeof(TSQueryMatch)); // [udata] 1486 *ud = *match; 1487 lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYMATCH); // [udata, meta] 1488 lua_setmetatable(L, -2); // [udata] 1489 1490 // Copy the fenv which contains the nodes tree. 1491 lua_getfenv(L, uindex); // [udata, reftable] 1492 lua_setfenv(L, -2); // [udata] 1493 } 1494 1495 static int querymatch_info(lua_State *L) 1496 { 1497 TSQueryMatch *match = luaL_checkudata(L, 1, TS_META_QUERYMATCH); 1498 lua_pushinteger(L, match->id); 1499 lua_pushinteger(L, match->pattern_index + 1); 1500 return 2; 1501 } 1502 1503 static int querymatch_captures(lua_State *L) 1504 { 1505 TSQueryMatch *match = luaL_checkudata(L, 1, TS_META_QUERYMATCH); 1506 lua_newtable(L); // [match, nodes, captures] 1507 for (size_t i = 0; i < match->capture_count; i++) { 1508 TSQueryCapture capture = match->captures[i]; 1509 int index = (int)capture.index + 1; 1510 1511 lua_rawgeti(L, -1, index); // [match, node, captures] 1512 if (lua_isnil(L, -1)) { // [match, node, captures, nil] 1513 lua_pop(L, 1); // [match, node, captures] 1514 lua_newtable(L); // [match, node, captures, nodes] 1515 } 1516 push_node(L, capture.node, 1); // [match, node, captures, nodes, node] 1517 lua_rawseti(L, -2, (int)lua_objlen(L, -2) + 1); // [match, node, captures, nodes] 1518 lua_rawseti(L, -2, index); // [match, node, captures] 1519 } 1520 return 1; 1521 } 1522 1523 // TSQuery 1524 1525 static struct luaL_Reg query_meta[] = { 1526 { "__gc", query_gc }, 1527 { "__tostring", query_tostring }, 1528 { "inspect", query_inspect }, 1529 { "disable_capture", query_disable_capture }, 1530 { "disable_pattern", query_disable_pattern }, 1531 { NULL, NULL } 1532 }; 1533 1534 static int tslua_parse_query(lua_State *L) 1535 { 1536 if (lua_gettop(L) < 2 || !lua_isstring(L, 1) || !lua_isstring(L, 2)) { 1537 return luaL_error(L, "string expected"); 1538 } 1539 1540 TSLanguage *lang = lang_check(L, 1); 1541 1542 size_t len; 1543 const char *src = lua_tolstring(L, 2, &len); 1544 1545 tslua_query_parse_count++; 1546 uint32_t error_offset; 1547 TSQueryError error_type; 1548 TSQuery *query = ts_query_new(lang, src, (uint32_t)len, &error_offset, &error_type); 1549 1550 if (!query) { 1551 char err_msg[IOSIZE]; 1552 query_err_string(src, (int)error_offset, error_type, err_msg, sizeof(err_msg)); 1553 return luaL_error(L, "%s", err_msg); 1554 } 1555 1556 TSQuery **ud = lua_newuserdata(L, sizeof(TSQuery *)); // [udata] 1557 *ud = query; 1558 lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERY); // [udata, meta] 1559 lua_setmetatable(L, -2); // [udata] 1560 return 1; 1561 } 1562 1563 static const char *query_err_to_string(TSQueryError error_type) 1564 { 1565 switch (error_type) { 1566 case TSQueryErrorSyntax: 1567 return "Invalid syntax:\n"; 1568 case TSQueryErrorNodeType: 1569 return "Invalid node type "; 1570 case TSQueryErrorField: 1571 return "Invalid field name "; 1572 case TSQueryErrorCapture: 1573 return "Invalid capture name "; 1574 case TSQueryErrorStructure: 1575 return "Impossible pattern:\n"; 1576 default: 1577 return "error"; 1578 } 1579 } 1580 1581 static void query_err_string(const char *src, int error_offset, TSQueryError error_type, char *err, 1582 size_t errlen) 1583 { 1584 int line_start = 0; 1585 int row = 0; 1586 const char *error_line = NULL; 1587 int error_line_len = 0; 1588 1589 const char *end_str; 1590 do { 1591 const char *src_tmp = src + line_start; 1592 end_str = strchr(src_tmp, '\n'); 1593 int line_length = end_str != NULL ? (int)(end_str - src_tmp) : (int)strlen(src_tmp); 1594 int line_end = line_start + line_length; 1595 if (line_end > error_offset) { 1596 error_line = src_tmp; 1597 error_line_len = line_length; 1598 break; 1599 } 1600 line_start = line_end + 1; 1601 row++; 1602 } while (end_str != NULL); 1603 1604 int column = error_offset - line_start; 1605 1606 const char *type_msg = query_err_to_string(error_type); 1607 snprintf(err, errlen, "Query error at %d:%d. %s", row + 1, column + 1, type_msg); 1608 size_t offset = strlen(err); 1609 errlen = errlen - offset; 1610 err = err + offset; 1611 1612 // Error types that report names 1613 if (error_type == TSQueryErrorNodeType 1614 || error_type == TSQueryErrorField 1615 || error_type == TSQueryErrorCapture) { 1616 const char *suffix = src + error_offset; 1617 bool is_anonymous = error_type == TSQueryErrorNodeType && suffix[-1] == '"'; 1618 int suffix_len = 0; 1619 char c = suffix[suffix_len]; 1620 if (is_anonymous) { 1621 int backslashes = 0; 1622 // Stop when we hit an unescaped double quote 1623 while (c != '"' || backslashes % 2 != 0) { 1624 if (c == '\\') { 1625 backslashes += 1; 1626 } else { 1627 backslashes = 0; 1628 } 1629 c = suffix[++suffix_len]; 1630 } 1631 } else { 1632 // Stop when we hit the end of the identifier 1633 while (isalnum(c) || c == '_' || c == '-' || c == '.') { 1634 c = suffix[++suffix_len]; 1635 } 1636 } 1637 snprintf(err, errlen, "\"%.*s\":\n", suffix_len, suffix); 1638 offset = strlen(err); 1639 errlen = errlen - offset; 1640 err = err + offset; 1641 } 1642 1643 if (!error_line) { 1644 snprintf(err, errlen, "Unexpected EOF\n"); 1645 return; 1646 } 1647 1648 snprintf(err, errlen, "%.*s\n%*s^\n", error_line_len, error_line, column, ""); 1649 } 1650 1651 static TSQuery *query_check(lua_State *L, int index) 1652 { 1653 TSQuery **ud = luaL_checkudata(L, index, TS_META_QUERY); 1654 luaL_argcheck(L, *ud, index, "TSQuery expected"); 1655 return *ud; 1656 } 1657 1658 static int query_gc(lua_State *L) 1659 { 1660 TSQuery *query = query_check(L, 1); 1661 ts_query_delete(query); 1662 return 0; 1663 } 1664 1665 static int query_tostring(lua_State *L) 1666 { 1667 lua_pushstring(L, "<query>"); 1668 return 1; 1669 } 1670 1671 static int query_inspect(lua_State *L) 1672 { 1673 TSQuery *query = query_check(L, 1); 1674 1675 // TSQueryInfo 1676 lua_createtable(L, 0, 2); // [retval] 1677 1678 uint32_t n_pat = ts_query_pattern_count(query); 1679 lua_createtable(L, (int)n_pat, 1); // [retval, patterns] 1680 for (size_t i = 0; i < n_pat; i++) { 1681 uint32_t len; 1682 const TSQueryPredicateStep *step = ts_query_predicates_for_pattern(query, (uint32_t)i, &len); 1683 if (len == 0) { 1684 continue; 1685 } 1686 lua_createtable(L, (int)len/4, 1); // [retval, patterns, pat] 1687 lua_createtable(L, 3, 0); // [retval, patterns, pat, pred] 1688 int nextpred = 1; 1689 int nextitem = 1; 1690 for (size_t k = 0; k < len; k++) { 1691 if (step[k].type == TSQueryPredicateStepTypeDone) { 1692 lua_rawseti(L, -2, nextpred++); // [retval, patterns, pat] 1693 lua_createtable(L, 3, 0); // [retval, patterns, pat, pred] 1694 nextitem = 1; 1695 continue; 1696 } 1697 1698 if (step[k].type == TSQueryPredicateStepTypeString) { 1699 uint32_t strlen; 1700 const char *str = ts_query_string_value_for_id(query, step[k].value_id, 1701 &strlen); 1702 lua_pushlstring(L, str, strlen); // [retval, patterns, pat, pred, item] 1703 } else if (step[k].type == TSQueryPredicateStepTypeCapture) { 1704 lua_pushinteger(L, step[k].value_id + 1); // [..., pat, pred, item] 1705 } else { 1706 abort(); 1707 } 1708 lua_rawseti(L, -2, nextitem++); // [retval, patterns, pat, pred] 1709 } 1710 // last predicate should have ended with TypeDone 1711 lua_pop(L, 1); // [retval, patterns, pat] 1712 lua_rawseti(L, -2, (int)i + 1); // [retval, patterns] 1713 } 1714 lua_setfield(L, -2, "patterns"); // [retval] 1715 1716 uint32_t n_captures = ts_query_capture_count(query); 1717 lua_createtable(L, (int)n_captures, 0); // [retval, captures] 1718 for (size_t i = 0; i < n_captures; i++) { 1719 uint32_t strlen; 1720 const char *str = ts_query_capture_name_for_id(query, (uint32_t)i, &strlen); 1721 lua_pushlstring(L, str, strlen); // [retval, captures, capture] 1722 lua_rawseti(L, -2, (int)i + 1); 1723 } 1724 lua_setfield(L, -2, "captures"); // [retval] 1725 1726 return 1; 1727 } 1728 1729 static int query_disable_capture(lua_State *L) 1730 { 1731 TSQuery *query = query_check(L, 1); 1732 size_t name_len; 1733 const char *name = luaL_checklstring(L, 2, &name_len); 1734 ts_query_disable_capture(query, name, (uint32_t)name_len); 1735 return 0; 1736 } 1737 1738 static int query_disable_pattern(lua_State *L) 1739 { 1740 TSQuery *query = query_check(L, 1); 1741 const uint32_t pattern_index = (uint32_t)luaL_checkinteger(L, 2); 1742 ts_query_disable_pattern(query, pattern_index - 1); 1743 return 0; 1744 } 1745 1746 // Library init 1747 1748 static void build_meta(lua_State *L, const char *tname, const luaL_Reg *meta) 1749 { 1750 if (luaL_newmetatable(L, tname)) { // [meta] 1751 luaL_register(L, NULL, meta); 1752 1753 lua_pushvalue(L, -1); // [meta, meta] 1754 lua_setfield(L, -2, "__index"); // [meta] 1755 } 1756 lua_pop(L, 1); // [] (don't use it now) 1757 } 1758 1759 /// Init the tslua library. 1760 /// 1761 /// All global state is stored in the registry of the lua_State. 1762 static void tslua_init(lua_State *L) 1763 { 1764 // type metatables 1765 build_meta(L, TS_META_PARSER, parser_meta); 1766 build_meta(L, TS_META_TREE, tree_meta); 1767 build_meta(L, TS_META_NODE, node_meta); 1768 build_meta(L, TS_META_QUERY, query_meta); 1769 build_meta(L, TS_META_QUERYCURSOR, querycursor_meta); 1770 build_meta(L, TS_META_QUERYMATCH, querymatch_meta); 1771 1772 ts_set_allocator(xmalloc, xcalloc, xrealloc, xfree); 1773 } 1774 1775 static int tslua_get_language_version(lua_State *L) 1776 { 1777 lua_pushnumber(L, TREE_SITTER_LANGUAGE_VERSION); 1778 return 1; 1779 } 1780 1781 static int tslua_get_minimum_language_version(lua_State *L) 1782 { 1783 lua_pushnumber(L, TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION); 1784 return 1; 1785 } 1786 1787 void nlua_treesitter_free(void) 1788 { 1789 #ifdef HAVE_WASMTIME 1790 if (wasmengine != NULL) { 1791 wasm_engine_delete(wasmengine); 1792 } 1793 if (ts_wasmstore != NULL) { 1794 ts_wasm_store_delete(ts_wasmstore); 1795 } 1796 #endif 1797 } 1798 1799 void nlua_treesitter_init(lua_State *const lstate) FUNC_ATTR_NONNULL_ALL 1800 { 1801 tslua_init(lstate); 1802 1803 lua_pushcfunction(lstate, tslua_push_parser); 1804 lua_setfield(lstate, -2, "_create_ts_parser"); 1805 1806 lua_pushcfunction(lstate, tslua_push_querycursor); 1807 lua_setfield(lstate, -2, "_create_ts_querycursor"); 1808 1809 lua_pushcfunction(lstate, tslua_add_language_from_object); 1810 lua_setfield(lstate, -2, "_ts_add_language_from_object"); 1811 1812 #ifdef HAVE_WASMTIME 1813 lua_pushcfunction(lstate, tslua_add_language_from_wasm); 1814 lua_setfield(lstate, -2, "_ts_add_language_from_wasm"); 1815 #endif 1816 1817 lua_pushcfunction(lstate, tslua_has_language); 1818 lua_setfield(lstate, -2, "_ts_has_language"); 1819 1820 lua_pushcfunction(lstate, tslua_remove_lang); 1821 lua_setfield(lstate, -2, "_ts_remove_language"); 1822 1823 lua_pushcfunction(lstate, tslua_inspect_lang); 1824 lua_setfield(lstate, -2, "_ts_inspect_language"); 1825 1826 lua_pushcfunction(lstate, tslua_parse_query); 1827 lua_setfield(lstate, -2, "_ts_parse_query"); 1828 1829 lua_pushcfunction(lstate, tslua_get_language_version); 1830 lua_setfield(lstate, -2, "_ts_get_language_version"); 1831 1832 lua_pushcfunction(lstate, tslua_get_minimum_language_version); 1833 lua_setfield(lstate, -2, "_ts_get_minimum_language_version"); 1834 }