query.lua (39182B)
--- @brief This Lua |treesitter-query| interface allows you to create queries and use them to parse
--- text. See |vim.treesitter.query.parse()| for a working example.

local api = vim.api
local language = require('vim.treesitter.language')
local memoize = vim.func._memoize
local cmp_ge = require('vim.treesitter._range').cmp_pos.ge

-- Modeline `;+ inherits: lang1,(lang2)` listing base languages (parenthesized = optional).
local MODELINE_FORMAT = '^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$'
-- Modeline `;+ extends` marking a query file as an extension of the same language's query.
local EXTENDS_FORMAT = '^;+%s*extends%s*$'

local M = {}

---Parsed query, see |vim.treesitter.query.parse()|
---
---@class vim.treesitter.Query
---@field lang string parser language name
---@field captures string[] list of (unique) capture names defined in query
---@field info vim.treesitter.QueryInfo query context (e.g. captures, predicates, directives)
---@field has_conceal_line boolean whether the query sets conceal_lines metadata
---@field has_combined_injections boolean whether the query contains combined injections
---@field query TSQuery userdata query object
---@field private _processed_patterns table<integer, vim.treesitter.query.ProcessedPattern>
local Query = {}
Query.__index = Query

--- Returns whether {name} names a directive (directive names end with `!`).
---
---@param name string predicate or directive name, without the leading `#`
---@return boolean
local function is_directive(name)
  return string.sub(name, -1) == '!'
end

---@nodoc
---@class vim.treesitter.query.ProcessedPredicate
---@field [1] string predicate name
---@field [2] boolean should match
---@field [3] (integer|string)[] the original predicate

---@alias vim.treesitter.query.ProcessedDirective (integer|string)[]

---@nodoc
---@class vim.treesitter.query.ProcessedPattern
---@field predicates vim.treesitter.query.ProcessedPredicate[]
---@field directives vim.treesitter.query.ProcessedDirective[]

--- Splits the query patterns into predicates and directives.
function Query:_process_patterns()
  self._processed_patterns = {}
  -- Default the documented boolean flags to false instead of leaving them nil; they are only
  -- switched on below when the corresponding directive is found. This also resets them if the
  -- patterns are processed again.
  self.has_combined_injections = false
  self.has_conceal_line = false

  for k, pattern_list in pairs(self.info.patterns) do
    ---@type vim.treesitter.query.ProcessedPredicate[]
    local predicates = {}
    ---@type vim.treesitter.query.ProcessedDirective[]
    local directives = {}

    for _, pattern in ipairs(pattern_list) do
      -- Note: tree-sitter strips the leading # from predicates for us.
      local pred_name = pattern[1]
      ---@cast pred_name string

      if is_directive(pred_name) then
        table.insert(directives, pattern)
        if vim.deep_equal(pattern, { 'set!', 'injection.combined' }) then
          self.has_combined_injections = true
        end
        if vim.deep_equal(pattern, { 'set!', 'conceal_lines', '' }) then
          self.has_conceal_line = true
        end
      else
        -- A `not-` prefix negates the predicate: strip it and record that the handler
        -- is expected to return false.
        local should_match = true
        if pred_name:match('^not%-') then
          pred_name = pred_name:sub(5)
          should_match = false
        end
        table.insert(predicates, { pred_name, should_match, pattern })
      end
    end

    self._processed_patterns[k] = { predicates = predicates, directives = directives }
  end
end

---@package
---@see vim.treesitter.query.parse
---@param lang string
---@param ts_query TSQuery
---@return vim.treesitter.Query
function Query.new(lang, ts_query)
  local self = setmetatable({}, Query)
  local query_info = ts_query:inspect() ---@type TSQueryInfo
  self.query = ts_query
  self.lang = lang
  self.info = {
    captures = query_info.captures,
    patterns = query_info.patterns,
  }
  self.captures = self.info.captures
  self:_process_patterns()
  return self
end

---@nodoc
---Information for Query, see |vim.treesitter.query.parse()|
---@class vim.treesitter.QueryInfo
---
---List of (unique) capture names defined in query.
---@field captures string[]
---
---Contains information about predicates and directives.
---Key is pattern id, and value is list of predicates or directives defined in the pattern.
---A predicate or directive is a list of (integer|string); integer represents `capture_id`, and
---string represents (literal) arguments to predicate/directive. See |treesitter-predicates|
---and |treesitter-directives| for more details.
---@field patterns table<integer, (integer|string)[][]>

--- Removes duplicate paths from {files}, keeping the first occurrence of each.
---
---@param files string[]
---@return string[]
local function dedupe_files(files)
  local result = {} ---@type string[]
  local seen = {} ---@type table<string,boolean>

  for _, path in ipairs(files) do
    if not seen[path] then
      result[#result + 1] = path
      seen[path] = true
    end
  end

  return result
end

--- Reads {filename} using the given `file:read()` format.
--- Raises the `io.open()` error message if the file cannot be opened.
---
---@param filename string
---@param read_quantifier string format passed to `file:read()`, e.g. `'*a'`
---@return string
local function safe_read(filename, read_quantifier)
  local file, err = io.open(filename, 'r')
  if not file then
    error(err)
  end
  local content = file:read(read_quantifier)
  io.close(file)
  return content
end

--- Adds {ilang} to {base_langs}, only if {ilang} is different than {lang}
---
---@param base_langs string[] list that collects base languages (mutated)
---@param lang string language whose query is being assembled
---@param ilang string language named in an `inherits` modeline
---@return boolean true If lang == ilang (nothing is added in that case)
local function add_included_lang(base_langs, lang, ilang)
  if lang == ilang then
    return true
  end
  table.insert(base_langs, ilang)
  return false
end

--- Gets the list of files used to make up a query
---
---@param lang string Language to get query for
---@param query_name string Name of the query to load (e.g., "highlights")
---@param is_included? boolean Internal parameter, most of the time left as `nil`
---@return string[] query_files List of files to load for given query and language
function M.get_files(lang, query_name, is_included)
  local query_path = string.format('queries/%s/%s.scm', lang, query_name)
  local lang_files = dedupe_files(api.nvim_get_runtime_file(query_path, true))

  if #lang_files == 0 then
    return {}
  end

  local base_query = nil ---@type string?
  local extensions = {} ---@type string[]

  local base_langs = {} ---@type string[]

  -- Now get the base languages by looking at the leading comment lines of every file
  -- The syntax is the following :
  -- ;+ inherits: ({language},)*{language}
  --
  -- {language} ::= {lang} | ({lang})
  for _, filename in ipairs(lang_files) do
    local file, err = io.open(filename, 'r')
    if not file then
      error(err)
    end

    -- Collect the leading `;` lines first and close the handle before parsing them, so the
    -- file is released even if parsing raises.
    local modelines = {} ---@type string[]
    for line in file:lines() do
      if not vim.startswith(line, ';') then
        break
      end
      modelines[#modelines + 1] = line
    end
    io.close(file)

    local extension = false

    for _, modeline in ipairs(modelines) do
      local langlist = modeline:match(MODELINE_FORMAT)
      if langlist then
        ---@diagnostic disable-next-line:param-type-mismatch
        for _, incllang in ipairs(vim.split(langlist, ',')) do
          local is_optional = incllang:match('%(.*%)')

          if is_optional then
            -- Optional (parenthesized) base languages are only followed for the top-level
            -- language, not when this call is itself resolving an inherited language.
            if not is_included then
              if add_included_lang(base_langs, lang, incllang:sub(2, #incllang - 1)) then
                extension = true
              end
            end
          else
            if add_included_lang(base_langs, lang, incllang) then
              extension = true
            end
          end
        end
      elseif modeline:match(EXTENDS_FORMAT) then
        extension = true
      end
    end

    if extension then
      table.insert(extensions, filename)
    elseif base_query == nil then
      base_query = filename
    end
  end

  local query_files = {} ---@type string[]
  for _, base_lang in ipairs(base_langs) do
    local base_files = M.get_files(base_lang, query_name, true)
    vim.list_extend(query_files, base_files)
  end
  -- `{ nil }` is empty, so a missing base query adds nothing.
  vim.list_extend(query_files, { base_query })
  vim.list_extend(query_files, extensions)

  return query_files
end

--- Reads and concatenates the contents of {filenames}, in order.
---
---@param filenames string[]
---@return string
local function read_query_files(filenames)
  local contents = {} ---@type string[]

  for _, filename in ipairs(filenames) do
    contents[#contents + 1] = safe_read(filename, '*a')
  end

  -- Separate files with a newline: a file that ends without a trailing newline (e.g. in a `;`
  -- comment) must not swallow the first line of the next file.
  return table.concat(contents, '\n')
end

-- The explicitly set query strings from |vim.treesitter.query.set()|
---@type table<string,table<string,string>>
local explicit_queries = setmetatable({}, {
  -- Lazily create (and store) the per-language table on first access.
  __index = function(t, k)
    local lang_queries = {}
    rawset(t, k, lang_queries)

    return lang_queries
  end,
})

--- Sets the runtime query named {query_name} for {lang}
---
--- This allows users to override or extend any runtime files and/or configuration
--- set by plugins.
---
--- For example, you could enable spellchecking of `C` identifiers with the
--- following code:
--- ```lua
--- vim.treesitter.query.set(
---   'c',
---   'highlights',
---   [[;extends
---   (identifier) @spell]])
--- ```
---
---@param lang string Language to use for the query
---@param query_name string Name of the query (e.g., "highlights")
---@param text string Query text (unparsed).
function M.set(lang, query_name, text)
  --- @diagnostic disable-next-line: undefined-field LuaLS bad at generics
  M.get:clear(lang, query_name)
  explicit_queries[lang][query_name] = text
end

--- Returns the runtime query {query_name} for {lang}.
---
---@param lang string Language to use for the query
---@param query_name string Name of the query (e.g. "highlights")
---
---@return vim.treesitter.Query? : Parsed query. `nil` if no query files are found.
M.get = memoize('concat-2', function(lang, query_name)
  local query_string ---@type string

  local explicit_query = explicit_queries[lang][query_name]
  if explicit_query then
    local query_files = {} ---@type string[]
    local base_langs = {} ---@type string[]

    -- Only the leading `;` lines of the explicit query can carry modelines.
    for line in explicit_query:gmatch('([^\n]*)\n?') do
      if not vim.startswith(line, ';') then
        break
      end

      local lang_list = line:match(MODELINE_FORMAT)
      if lang_list then
        for _, incl_lang in ipairs(vim.split(lang_list, ',')) do
          local is_optional = incl_lang:match('%(.*%)')
          local base_lang = is_optional and incl_lang:sub(2, #incl_lang - 1) or incl_lang

          -- add_included_lang() adds nothing and returns true when the query inherits its own
          -- language; treat that like `;extends`, as get_files() does for runtime files.
          if add_included_lang(base_langs, lang, base_lang) then
            table.insert(base_langs, lang)
          end
        end
      elseif line:match(EXTENDS_FORMAT) then
        table.insert(base_langs, lang)
      end
    end

    for _, base_lang in ipairs(base_langs) do
      vim.list_extend(query_files, M.get_files(base_lang, query_name, true))
    end

    local base_text = read_query_files(query_files)
    -- Keep inherited text and the explicit query on separate lines, so a base file lacking a
    -- trailing newline cannot merge with (or comment out) the first explicit line.
    if base_text ~= '' then
      base_text = base_text .. '\n'
    end
    query_string = base_text .. explicit_query
  else
    query_string = read_query_files(M.get_files(lang, query_name))
  end

  if #query_string == 0 then
    return nil
  end

  return M.parse(lang, query_string)
end, false)

-- Changing 'runtimepath' can change which query files exist, so drop all cached queries.
api.nvim_create_autocmd('OptionSet', {
  pattern = { 'runtimepath' },
  group = api.nvim_create_augroup('nvim.treesitter.query_cache_reset', { clear = true }),
  callback = function()
    --- @diagnostic disable-next-line: undefined-field LuaLS bad at generics
    M.get:clear()
  end,
})

--- Parses a {query} string and returns a `Query` object (|lua-treesitter-query|), which can be used
--- to search the tree for the query patterns (via |Query:iter_captures()|, |Query:iter_matches()|),
--- or inspect/modify the query via these fields:
---   - `captures`: a list of unique capture names defined in the query (alias: `info.captures`).
---   - `info.patterns`: information about predicates.
---   - `query`: the underlying |TSQuery| which can be used to disable patterns or captures.
---
--- Example:
--- ```lua
--- local query = vim.treesitter.query.parse('vimdoc', [[
---   ; query
---   ((h1) @str
---     (#trim! @str 1 1 1 1))
--- ]])
--- local tree = vim.treesitter.get_parser():parse()[1]
--- for id, node, metadata in query:iter_captures(tree:root(), 0) do
---    -- Print the node name and source text.
---    vim.print({node:type(), vim.treesitter.get_node_text(node, vim.api.nvim_get_current_buf())})
--- end
--- ```
---
---@param lang string Language to use for the query
---@param query string Query text, in s-expr syntax
---
---@return vim.treesitter.Query : Parsed query
---
---@see [vim.treesitter.query.get()]
M.parse = memoize('concat-2', function(lang, query)
  assert(language.add(lang))
  local ts_query = vim._ts_parse_query(lang, query)
  return Query.new(lang, ts_query)
end, false)

--- Implementations of predicates that can optionally be prefixed with "any-".
---
--- These functions contain the implementations for each predicate, correctly
--- handling the "any" vs "all" semantics. They are called from the
--- predicate_handlers table with the appropriate arguments for each predicate.
--- A capture with no nodes always satisfies the predicate.
local impl = {
  --- @param match table<integer,TSNode[]>
  --- @param source integer|string
  --- @param predicate any[]
  --- @param any boolean
  ['eq'] = function(match, source, predicate, any)
    local nodes = match[predicate[2]]
    if not nodes or #nodes == 0 then
      return true
    end

    -- The comparison text does not depend on the node being checked; compute it once.
    local str ---@type string
    if type(predicate[3]) == 'string' then
      -- (#eq? @aa "foo")
      str = predicate[3]
    else
      -- (#eq? @aa @bb)
      local other = assert(match[predicate[3]])
      assert(#other == 1, '#eq? does not support comparison with captures on multiple nodes')
      str = vim.treesitter.get_node_text(other[1], source)
    end

    for _, node in ipairs(nodes) do
      local node_text = vim.treesitter.get_node_text(node, source)
      local res = str ~= nil and node_text == str
      if any and res then
        return true
      elseif not any and not res then
        return false
      end
    end

    return not any
  end,

  --- @param match table<integer,TSNode[]>
  --- @param source integer|string
  --- @param predicate any[]
  --- @param any boolean
  ['lua-match'] = function(match, source, predicate, any)
    local nodes = match[predicate[2]]
    if not nodes or #nodes == 0 then
      return true
    end

    local regex = predicate[3]
    for _, node in ipairs(nodes) do
      local res = string.find(vim.treesitter.get_node_text(node, source), regex) ~= nil
      if any and res then
        return true
      elseif not any and not res then
        return false
      end
    end

    return not any
  end,

  ['match'] = (function()
    local magic_prefixes = { ['\\v'] = true, ['\\m'] = true, ['\\M'] = true, ['\\V'] = true }
    -- Default patterns without an explicit magic prefix to "very magic" (\v).
    local function check_magic(str)
      if string.len(str) < 2 or magic_prefixes[string.sub(str, 1, 2)] then
        return str
      end
      return '\\v' .. str
    end

    -- Cache of compiled vim regexes, keyed by the raw pattern string.
    local compiled_vim_regexes = setmetatable({}, {
      __index = function(t, pattern)
        local res = vim.regex(check_magic(pattern))
        rawset(t, pattern, res)
        return res
      end,
    })

    --- @param match table<integer,TSNode[]>
    --- @param source integer|string
    --- @param predicate any[]
    --- @param any boolean
    return function(match, source, predicate, any)
      local nodes = match[predicate[2]]
      if not nodes or #nodes == 0 then
        return true
      end

      local regex = compiled_vim_regexes[predicate[3]] ---@type vim.regex
      for _, node in ipairs(nodes) do
        local res = regex:match_str(vim.treesitter.get_node_text(node, source))
        if any and res then
          return true
        elseif not any and not res then
          return false
        end
      end
      return not any
    end
  end)(),

  --- @param match table<integer,TSNode[]>
  --- @param source integer|string
  --- @param predicate any[]
  --- @param any boolean
  ['contains'] = function(match, source, predicate, any)
    local nodes = match[predicate[2]]
    if not nodes or #nodes == 0 then
      return true
    end

    for _, node in ipairs(nodes) do
      local node_text = vim.treesitter.get_node_text(node, source)

      for i = 3, #predicate do
        -- Plain (non-pattern) substring search.
        local res = string.find(node_text, predicate[i], 1, true)
        if any and res then
          return true
        elseif not any and not res then
          return false
        end
      end
    end

    return not any
  end,
}

---@alias TSPredicate fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[]): boolean

-- Predicate handlers receive the following arguments:
-- (match, pattern, source, predicate)
---@type table<string,TSPredicate>
local predicate_handlers = {
  ['eq?'] = function(match, _, source, predicate)
    return impl['eq'](match, source, predicate, false)
  end,

  ['any-eq?'] = function(match, _, source, predicate)
    return impl['eq'](match, source, predicate, true)
  end,

  ['lua-match?'] = function(match, _, source, predicate)
    return impl['lua-match'](match, source, predicate, false)
  end,

  ['any-lua-match?'] = function(match, _, source, predicate)
    return impl['lua-match'](match, source, predicate, true)
  end,

  ['match?'] = function(match, _, source, predicate)
    return impl['match'](match, source, predicate, false)
  end,

  ['any-match?'] = function(match, _, source, predicate)
    return impl['match'](match, source, predicate, true)
  end,

  ['contains?'] = function(match, _, source, predicate)
    return impl['contains'](match, source, predicate, false)
  end,

  ['any-contains?'] = function(match, _, source, predicate)
    return impl['contains'](match, source, predicate, true)
  end,

  ['any-of?'] = function(match, _, source, predicate)
    local nodes = match[predicate[2]]
    if not nodes or #nodes == 0 then
      return true
    end

    -- Since 'predicate' will not be used by callers of this function, use it
    -- to store a string set built from the list of words to check against.
    local string_set = predicate['string_set'] --- @type table<string, boolean>
    if not string_set then
      string_set = {}
      for i = 3, #predicate do
        string_set[predicate[i]] = true
      end
      predicate['string_set'] = string_set
    end

    for _, node in ipairs(nodes) do
      if string_set[vim.treesitter.get_node_text(node, source)] then
        return true
      end
    end

    return false
  end,

  ['has-ancestor?'] = function(match, _, _, predicate)
    local nodes = match[predicate[2]]
    if not nodes or #nodes == 0 then
      return true
    end

    for _, node in ipairs(nodes) do
      if node:__has_ancestor(predicate) then
        return true
      end
    end
    return false
  end,

  ['has-parent?'] = function(match, _, _, predicate)
    local nodes = match[predicate[2]]
    if not nodes or #nodes == 0 then
      return true
    end

    local parent_types = { unpack(predicate, 3) }
    for _, node in ipairs(nodes) do
      -- The root node has no parent, so it can never satisfy this predicate.
      local parent = node:parent()
      if parent and vim.list_contains(parent_types, parent:type()) then
        return true
      end
    end
    return false
  end,
}

-- As we provide lua-match? also expose vim-match?
predicate_handlers['vim-match?'] = predicate_handlers['match?']
predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?']

---@nodoc
---@class vim.treesitter.query.TSMetadata
---@field range? Range
---@field offset? Range4
---@field conceal? string
---@field bo.commentstring? string
---@field [integer]? vim.treesitter.query.TSMetadata
---@field [string]? integer|string

---@alias TSDirective fun(match: table<integer,TSNode[]>, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata)

-- Directives store metadata or perform side effects against a match.
-- Directives should always end with a `!`.
-- Directive handlers receive the following arguments:
-- (match, pattern, source, predicate, metadata)
---@type table<string,TSDirective>
local directive_handlers = {
  ['set!'] = function(_, _, _, pred, metadata)
    if #pred >= 3 and type(pred[2]) == 'number' then
      -- (#set! @capture key value)
      local capture_id, key, value = pred[2], pred[3], pred[4]
      if not metadata[capture_id] then
        metadata[capture_id] = {}
      end
      metadata[capture_id][key] = value
    else
      -- (#set! key value)
      local key, value = pred[2], pred[3]
      metadata[key] = value or true
    end
  end,
  -- Shifts the range of a node.
  -- Example: (#offset! @_node 0 1 0 -1)
  ['offset!'] = function(match, _, _, pred, metadata)
    local capture_id = pred[2] --[[@as integer]]
    local nodes = match[capture_id]
    if not nodes or #nodes == 0 then
      return
    end

    if not metadata[capture_id] then
      metadata[capture_id] = {}
    end

    metadata[capture_id].offset = {
      pred[3] --[[@as integer]]
        or 0,
      pred[4] --[[@as integer]]
        or 0,
      pred[5] --[[@as integer]]
        or 0,
      pred[6] --[[@as integer]]
        or 0,
    }
  end,
  -- Transform the content of the node
  -- Example: (#gsub! @_node ".*%.(.*)" "%1")
  ['gsub!'] = function(match, _, bufnr, pred, metadata)
    assert(#pred == 4)

    local id = pred[2]
    assert(type(id) == 'number')

    local nodes = match[id]
    if not nodes or #nodes == 0 then
      return
    end
    assert(#nodes == 1, '#gsub! does not support captures on multiple nodes')
    local node = nodes[1]
    -- Read the text through any metadata already set on this capture by earlier directives.
    local text = vim.treesitter.get_node_text(node, bufnr, { metadata = metadata[id] }) or ''

    if not metadata[id] then
      metadata[id] = {}
    end

    local pattern, replacement = pred[3], pred[4]
    assert(type(pattern) == 'string')
    assert(type(replacement) == 'string')

    metadata[id].text = text:gsub(pattern, replacement)
  end,
  -- Trim whitespace from both sides of the node
  -- Example: (#trim! @fold 1 1 1 1)
  ['trim!'] = function(match, _, bufnr, pred, metadata)
    local capture_id = pred[2]
    assert(type(capture_id) == 'number')

    local trim_start_lines = pred[3] == '1'
    local trim_start_cols = pred[4] == '1'
    local trim_end_lines = pred[5] == '1' or not pred[3] -- default true for backwards compatibility
    local trim_end_cols = pred[6] == '1'

    local nodes = match[capture_id]
    if not nodes or #nodes == 0 then
      return
    end
    assert(#nodes == 1, '#trim! does not support captures on multiple nodes')
    local node = nodes[1]

    local start_row, start_col, end_row, end_col = node:range()

    local node_text = vim.split(vim.treesitter.get_node_text(node, bufnr), '\n')
    if end_col == 0 then
      -- get_node_text() will ignore the last line if the node ends at column 0
      node_text[#node_text + 1] = ''
    end

    local end_idx = #node_text
    local start_idx = 1

    if trim_end_lines then
      while end_idx > 0 and node_text[end_idx]:find('^%s*$') do
        end_idx = end_idx - 1
        end_row = end_row - 1
        -- Set the end position to the last column of the previous line, or 0 if we just trimmed
        -- the first line. Text on the node's first line begins at start_col, so add it there.
        if end_idx > 0 then
          end_col = #node_text[end_idx] + (end_idx == 1 and start_col or 0)
        else
          end_col = 0
        end
      end
    end
    if trim_end_cols then
      if end_idx == 0 then
        end_row = start_row
        end_col = start_col
      else
        local whitespace_start = node_text[end_idx]:find('(%s*)$')
        end_col = (whitespace_start - 1) + (end_idx == 1 and start_col or 0)
      end
    end

    if trim_start_lines then
      while start_idx <= end_idx and node_text[start_idx]:find('^%s*$') do
        start_idx = start_idx + 1
        start_row = start_row + 1
        start_col = 0
      end
    end
    if trim_start_cols and node_text[start_idx] then
      local _, whitespace_end = node_text[start_idx]:find('^(%s*)')
      whitespace_end = whitespace_end or 0
      start_col = (start_idx == 1 and start_col or 0) + whitespace_end
    end

    -- If this produces an invalid range, we just skip it.
    if start_row < end_row or (start_row == end_row and start_col <= end_col) then
      metadata[capture_id] = metadata[capture_id] or {}
      metadata[capture_id].range = { start_row, start_col, end_row, end_col }
    end
  end,
}

--- @class vim.treesitter.query.add_predicate.Opts
--- @inlinedoc
---
--- Override an existing predicate of the same name
--- @field force? boolean
---
--- Use the correct implementation of the match table where capture IDs map to
--- a list of nodes instead of a single node. Defaults to true. This option will
--- be removed in a future release.
--- @field all? boolean

--- Adds a new predicate to be used in queries
---
---@param name string Name of the predicate, without leading #
---@param handler fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: vim.treesitter.query.TSMetadata): boolean? #
---   - see |vim.treesitter.query.add_directive()| for argument meanings
---@param opts? vim.treesitter.query.add_predicate.Opts
function M.add_predicate(name, handler, opts)
  -- Backward compatibility: old signature had "force" as boolean argument
  if type(opts) == 'boolean' then
    opts = { force = opts }
  end

  opts = opts or {}

  if predicate_handlers[name] and not opts.force then
    error(string.format('Overriding existing predicate %s', name))
  end

  if opts.all ~= false then
    predicate_handlers[name] = handler
  else
    -- Legacy mode: pass only the last node of each capture to the handler.
    --- @param match table<integer, TSNode[]>
    local function wrapper(match, ...)
      local m = {} ---@type table<integer, TSNode>
      for k, v in pairs(match) do
        if type(k) == 'number' then
          m[k] = v[#v]
        end
      end
      return handler(m, ...)
    end
    predicate_handlers[name] = wrapper
  end
end

--- Adds a new directive to be used in queries
---
--- Handlers can set match level data by setting directly on the
--- metadata object `metadata.key = value`. Additionally, handlers
--- can set node level data by using the capture id on the
--- metadata table `metadata[capture_id].key = value`
---
--- Note: unlike |vim.treesitter.query.add_predicate()|, the handler receives the last captured
--- node per capture ID (instead of the full list) unless `opts.all` is `true`.
---
---@param name string Name of the directive, without leading #
---@param handler fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: vim.treesitter.query.TSMetadata) #
---   - match: A table mapping capture IDs to a list of captured nodes
---   - pattern: the index of the matching pattern in the query file
---   - source: the buffer number or string the match was taken from
---   - predicate: list of strings containing the full directive being called, e.g.
---     `(node (#set! conceal "-"))` would get the predicate `{ "set!", "conceal", "-" }`
---     (the leading `#` is stripped)
---@param opts? vim.treesitter.query.add_predicate.Opts
function M.add_directive(name, handler, opts)
  -- Backward compatibility: old signature had "force" as boolean argument
  if type(opts) == 'boolean' then
    opts = { force = opts }
  end

  opts = opts or {}

  if directive_handlers[name] and not opts.force then
    error(string.format('Overriding existing directive %s', name))
  end

  if opts.all then
    directive_handlers[name] = handler
  else
    -- Legacy mode: pass only the last node of each capture to the handler.
    --- @param match table<integer, TSNode[]>
    local function wrapper(match, ...)
      local m = {} ---@type table<integer, TSNode>
      for k, v in pairs(match) do
        if type(k) == 'number' then
          m[k] = v[#v]
        end
      end
      handler(m, ...)
    end
    directive_handlers[name] = wrapper
  end
end

--- Lists the currently available directives to use in queries.
---@return string[] : Supported directives.
function M.list_directives()
  return vim.tbl_keys(directive_handlers)
end

--- Lists the currently available predicates to use in queries.
---@return string[] : Supported predicates.
function M.list_predicates()
  return vim.tbl_keys(predicate_handlers)
end

--- Checks whether a match satisfies every processed predicate of its pattern.
---
---@private
---@param predicates vim.treesitter.query.ProcessedPredicate[]
---@param pattern_i integer
---@param captures table<integer, TSNode[]>
---@param source integer|string
---@return boolean whether the predicates match
function Query:_match_predicates(predicates, pattern_i, captures, source)
  for _, predicate in ipairs(predicates) do
    local processed_name, should_match, orig_predicate = predicate[1], predicate[2], predicate[3]

    local handler = predicate_handlers[processed_name]
    if not handler then
      error(string.format('No handler for %s', orig_predicate[1]))
    end

    -- Handlers may return nil (their type is `boolean?`); coerce the result to a boolean so a
    -- nil result counts as "no match", which also lets `not-` predicates succeed.
    local does_match = not not handler(captures, pattern_i, source, orig_predicate)
    if does_match ~= should_match then
      return false
    end
  end
  return true
end

--- Runs the directives of one pattern against a match and collects their metadata.
---
---@private
---@param directives vim.treesitter.query.ProcessedDirective[]
---@param pattern_i integer
---@param captures table<integer, TSNode[]>
---@param source integer|string
---@return vim.treesitter.query.TSMetadata metadata
function Query:_apply_directives(directives, pattern_i, captures, source)
  ---@type vim.treesitter.query.TSMetadata
  local metadata = {}

  -- Apply directives strictly in query order (ipairs): later directives such as `gsub!` can
  -- read metadata written by earlier ones.
  for _, directive in ipairs(directives) do
    local handler = directive_handlers[directive[1]]

    if not handler then
      error(string.format('No handler for %s', directive[1]))
    end

    handler(captures, pattern_i, source, directive, metadata)
  end

  return metadata
end

--- Returns the start and stop value if set else the node's range.
-- When the node's range is used, the stop is incremented by 1
-- to make the search inclusive.
---@param start integer?
---@param stop integer?
---@param node TSNode
---@return integer, integer
local function value_or_node_range(start, stop, node)
  if start == nil then
    start = node:start()
  end
  if stop == nil then
    stop = node:end_() + 1 -- Make stop inclusive
  end

  return start, stop
end

--- Iterates over all captures from all matches in {node}.
---
--- {source} is required if the query contains predicates; then the caller
--- must ensure to use a freshly parsed tree consistent with the current
--- text of the buffer (if relevant). {start_row} and {end_row} can be used to limit
--- matches inside a row range (this is typically used with root node
--- as the {node}, i.e., to get syntax highlight matches in the current
--- viewport). When omitted, the {start_row} and {end_row} values are used from the given node.
---
--- The iterator returns four values:
--- 1. the numeric id identifying the capture
--- 2. the captured node
--- 3. metadata from any directives processing the match
--- 4. the match itself
---
--- Example: how to get captures by name:
--- ```lua
--- for id, node, metadata, match in query:iter_captures(tree:root(), bufnr, first, last) do
---   local name = query.captures[id] -- name of the capture in the query
---   -- typically useful info about the node:
---   local type = node:type() -- type of the captured node
---   local row1, col1, row2, col2 = node:range() -- range of the capture
---   -- ... use the info here ...
--- end
--- ```
---
---@param node TSNode under which the search will occur
---@param source (integer|string) Source buffer or string to extract text from
---@param start_row? integer Starting line for the search. Defaults to `node:start()`.
---@param end_row? integer Stopping line for the search (end-inclusive, unless `opts.end_col` is provided). Defaults to `node:end_()`.
---@param opts? table Optional keyword arguments:
---   - max_start_depth (integer) if non-zero, sets the maximum start depth
---     for each match. This is used to prevent traversing too deep into a tree.
---   - match_limit (integer) Set the maximum number of in-progress matches (Default: 256).
---   - start_col (integer) Starting column for the search.
---   - end_col (integer) Stopping column for the search (end-exclusive).
---
---@return (fun(end_line: integer|nil, end_col: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch, TSTree):
---        capture id, capture node, metadata, match, tree
---
---@note Captures from matches whose predicates fail are skipped. If such a capture starts at or
---      after the (end_line, end_col) bound passed to the iterator, the iterator returns
---      `nil, node` instead of continuing.
function Query:iter_captures(node, source, start_row, end_row, opts)
  opts = opts or {}
  -- Resolve the default into a local rather than writing it back into the caller's table.
  local match_limit = opts.match_limit or 256

  if type(source) == 'number' and source == 0 then
    source = api.nvim_get_current_buf()
  end

  start_row, end_row = value_or_node_range(start_row, end_row, node)

  local tree = node:tree()
  local cursor = vim._create_ts_querycursor(node, self.query, {
    start_row = start_row,
    start_col = opts.start_col or 0,
    end_row = end_row,
    end_col = opts.end_col or 0,
    max_start_depth = opts.max_start_depth,
    match_limit = match_limit,
  })

  -- For faster checks that a match is not in the cache.
  local highest_cached_match_id = -1
  ---@type table<integer, vim.treesitter.query.TSMetadata>
  local match_cache = {}

  local function iter(end_line, end_col)
    local capture, captured_node, match = cursor:next_capture()

    if not capture then
      return
    end

    local match_id, pattern_i = match:info()

    --- @type vim.treesitter.query.TSMetadata
    local metadata
    if match_id <= highest_cached_match_id then
      metadata = match_cache[match_id]
    end

    if not metadata then
      metadata = {}

      local processed_pattern = self._processed_patterns[pattern_i]
      if processed_pattern then
        local captures = match:captures()

        local predicates = processed_pattern.predicates
        if not self:_match_predicates(predicates, pattern_i, captures, source) then
          cursor:remove_match(match_id)

          local row, col = captured_node:range()

          local outside = false
          if end_line then
            if end_col then
              outside = cmp_ge(row, col, end_line, end_col)
            else
              outside = row > end_line
            end
          end

          if outside then
            return nil, captured_node, nil, nil
          end

          -- Forward both bounds so the column-precise check still applies to the next match.
          return iter(end_line, end_col) -- tail call: try next match
        end

        local directives = processed_pattern.directives
        metadata = self:_apply_directives(directives, pattern_i, captures, source)
      end

      highest_cached_match_id = math.max(highest_cached_match_id, match_id)
      match_cache[match_id] = metadata
    end

    return capture, captured_node, metadata, match, tree
  end
  return iter
end

--- Iterates the matches of self on a given range.
---
--- Iterate over all matches within a {node}.
The arguments are the same as for 1053 --- |Query:iter_captures()| but the iterated values are different: an (1-based) 1054 --- index of the pattern in the query, a table mapping capture indices to a list 1055 --- of nodes, and metadata from any directives processing the match. 1056 --- 1057 --- Example: 1058 --- 1059 --- ```lua 1060 --- for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, 0, -1) do 1061 --- for id, nodes in pairs(match) do 1062 --- local name = query.captures[id] 1063 --- for _, node in ipairs(nodes) do 1064 --- -- `node` was captured by the `name` capture in the match 1065 --- 1066 --- local node_data = metadata[id] -- Node level metadata 1067 --- -- ... use the info here ... 1068 --- end 1069 --- end 1070 --- end 1071 --- ``` 1072 --- 1073 --- 1074 ---@param node TSNode under which the search will occur 1075 ---@param source (integer|string) Source buffer or string to search 1076 ---@param start? integer Starting line for the search. Defaults to `node:start()`. 1077 ---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`. 1078 ---@param opts? table Optional keyword arguments: 1079 --- - max_start_depth (integer) if non-zero, sets the maximum start depth 1080 --- for each match. This is used to prevent traversing too deep into a tree. 1081 --- - match_limit (integer) Set the maximum number of in-progress matches (Default: 256). 1082 --- - all (boolean) When `false` (default `true`), the returned table maps capture IDs to a single 1083 --- (last) node instead of the full list of matching nodes. This option is only for backward 1084 --- compatibility and will be removed in a future release. 
1085 --- 1086 ---@return (fun(): integer, table<integer, TSNode[]>, vim.treesitter.query.TSMetadata, TSTree): pattern id, match, metadata, tree 1087 function Query:iter_matches(node, source, start, stop, opts) 1088 opts = opts or {} 1089 opts.match_limit = opts.match_limit or 256 1090 1091 if type(source) == 'number' and source == 0 then 1092 source = api.nvim_get_current_buf() 1093 end 1094 1095 start, stop = value_or_node_range(start, stop, node) 1096 1097 local tree = node:tree() 1098 local cursor = vim._create_ts_querycursor(node, self.query, { 1099 start_row = start, 1100 start_col = 0, 1101 end_row = stop, 1102 end_col = 0, 1103 max_start_depth = opts.max_start_depth, 1104 match_limit = opts.match_limit or 256, 1105 }) 1106 1107 local function iter() 1108 local match = cursor:next_match() 1109 1110 if not match then 1111 return 1112 end 1113 1114 local match_id, pattern_i = match:info() 1115 local processed_pattern = self._processed_patterns[pattern_i] 1116 local captures = match:captures() 1117 1118 --- @type vim.treesitter.query.TSMetadata 1119 local metadata = {} 1120 if processed_pattern then 1121 local predicates = processed_pattern.predicates 1122 if not self:_match_predicates(predicates, pattern_i, captures, source) then 1123 cursor:remove_match(match_id) 1124 return iter() -- tail call: try next match 1125 end 1126 local directives = processed_pattern.directives 1127 metadata = self:_apply_directives(directives, pattern_i, captures, source) 1128 end 1129 1130 if opts.all == false then 1131 -- Convert the match table into the old buggy version for backward 1132 -- compatibility. This is slow, but we only do it when the caller explicitly opted into it by 1133 -- setting `all` to `false`. 
1134 local old_match = {} ---@type table<integer, TSNode> 1135 for k, v in pairs(captures or {}) do 1136 old_match[k] = v[#v] 1137 end 1138 return pattern_i, old_match, metadata 1139 end 1140 1141 -- TODO(lewis6991): create a new function that returns {match, metadata} 1142 return pattern_i, captures, metadata, tree 1143 end 1144 return iter 1145 end 1146 1147 --- Optional keyword arguments: 1148 --- @class vim.treesitter.query.lint.Opts 1149 --- @inlinedoc 1150 --- 1151 --- Language(s) to use for checking the query. 1152 --- If multiple languages are specified, queries are validated for all of them 1153 --- @field langs? string|string[] 1154 --- 1155 --- Just clear current lint errors 1156 --- @field clear boolean 1157 1158 --- Lint treesitter queries using installed parser, or clear lint errors. 1159 --- 1160 --- Use |treesitter-parsers| in runtimepath to check the query file in {buf} for errors: 1161 --- 1162 --- - verify that used nodes are valid identifiers in the grammar. 1163 --- - verify that predicates and directives are valid. 1164 --- - verify that top-level s-expressions are valid. 1165 --- 1166 --- The found diagnostics are reported using |diagnostic-api|. 1167 --- By default, the parser used for verification is determined by the containing folder 1168 --- of the query file, e.g., if the path ends in `/lua/highlights.scm`, the parser for the 1169 --- `lua` language will be used. 1170 ---@param buf (integer) Buffer handle 1171 ---@param opts? vim.treesitter.query.lint.Opts 1172 function M.lint(buf, opts) 1173 if opts and opts.clear then 1174 vim.treesitter._query_linter.clear(buf) 1175 else 1176 vim.treesitter._query_linter.lint(buf, opts) 1177 end 1178 end 1179 1180 --- Omnifunc for completing node names and predicates in treesitter queries. 
1181 --- 1182 --- Use via 1183 --- 1184 --- ```lua 1185 --- vim.bo.omnifunc = 'v:lua.vim.treesitter.query.omnifunc' 1186 --- ``` 1187 --- 1188 --- @param findstart 0|1 1189 --- @param base string 1190 function M.omnifunc(findstart, base) 1191 return vim.treesitter._query_linter.omnifunc(findstart, base) 1192 end 1193 1194 --- Opens a live editor to query the buffer you started from. 1195 --- 1196 --- Can also be shown with the [:EditQuery]() command. `:EditQuery <tab>` completes available 1197 --- parsers. 1198 --- 1199 --- If you move the cursor to a capture name ("@foo"), text matching the capture is highlighted 1200 --- with |hl-DiagnosticVirtualTextHint| in the source buffer. 1201 --- 1202 --- The query editor is a scratch buffer, use `:write` to save it. You can find example queries 1203 --- at `$VIMRUNTIME/queries/`. 1204 --- 1205 --- @param lang? string language to open the query editor for. If omitted, inferred from the current buffer's filetype. 1206 function M.edit(lang) 1207 assert(vim.treesitter.dev.edit_query(lang)) 1208 end 1209 1210 return M