neovim

Neovim text editor
git clone https://git.dasho.dev/neovim.git
Log | Files | Refs | README

treesitter.lua (16617B)


      1 local api = vim.api -- shorthand for the Nvim API table, used by several functions below
      2 
      3 ---@type table<integer,vim.treesitter.LanguageTree>
      4 local parsers = setmetatable({}, { __mode = 'v' }) -- weak values: this table alone does not keep a parser alive
      5 
      6 local M = vim._defer_require('vim.treesitter', { -- NOTE(review): presumably each listed submodule is loaded on first access; confirm in vim._defer_require
      7  _fold = ..., --- @module 'vim.treesitter._fold'
      8  _query_linter = ..., --- @module 'vim.treesitter._query_linter'
      9  _range = ..., --- @module 'vim.treesitter._range'
     10  dev = ..., --- @module 'vim.treesitter.dev'
     11  highlighter = ..., --- @module 'vim.treesitter.highlighter'
     12  language = ..., --- @module 'vim.treesitter.language'
     13  languagetree = ..., --- @module 'vim.treesitter.languagetree'
     14  query = ..., --- @module 'vim.treesitter.query'
     15 })
     16 
     17 local LanguageTree = M.languagetree -- local alias for the languagetree submodule
     18 
     19 --- @nodoc
     20 M.language_version = vim._ts_get_language_version() -- evaluated once, when this file is executed
     21 
     22 --- @nodoc
     23 M.minimum_language_version = vim._ts_get_minimum_language_version() -- evaluated once, when this file is executed
     24 
     --- Builds a new language tree for a buffer and subscribes it to that buffer's events
     ---
     --- Calling this directly is not recommended; use |get_parser()| instead.
     ---
     ---@param bufnr integer Buffer the parser will be tied to (0 for current buffer)
     ---@param lang string Language of the parser
     ---@param opts (table|nil) Options to pass to the created language tree
     ---
     ---@return vim.treesitter.LanguageTree object to use for parsing
     function M._create_parser(bufnr, lang, opts)
       bufnr = vim._resolve_bufnr(bufnr)

       local tree = LanguageTree.new(bufnr, lang, opts)

       api.nvim_buf_attach(tree:source() --[[@as integer]], false, {
         -- Each callback drops its leading argument and forwards the rest to the tree.
         on_bytes = function(_, ...)
           tree:_on_bytes(...)
         end,
         on_detach = function(_, ...)
           -- Clear the cache slot only while it still refers to this very parser.
           if parsers[bufnr] == tree then
             parsers[bufnr] = nil
           end
           tree:_on_detach(...)
         end,
         on_reload = function(_)
           tree:_on_reload()
         end,
         preview = true,
       })

       return tree
     end
     64 
     --- Checks that {lang} is usable as a language name: truthy and not the empty string.
     --- A falsy {lang} (nil or false) is handed back unchanged.
     ---@param lang string?
     ---@return boolean?
     local function valid_lang(lang)
       if not lang then
         return lang
       end
       return lang ~= ''
     end
     68 
     --- Returns the parser for a specific buffer and attaches it to the buffer
     ---
     --- If needed, this will create the parser.
     ---
     --- If no parser can be created (the language cannot be determined, the buffer is not loaded,
     --- or parser creation fails), an error is thrown. Set `opts.error = false` to suppress this and
     --- return nil (and an error message) instead. WARNING: This behavior will become default in Nvim
     --- 0.12 and the option will be removed.
     ---
     ---@param bufnr (integer|nil) Buffer the parser should be tied to (default: current buffer)
     ---@param lang (string|nil) Language of this parser (default: from buffer filetype)
     ---@param opts (table|nil) Options to pass to the created language tree
     ---
     ---@return vim.treesitter.LanguageTree? object to use for parsing
     ---@return string? error message, if applicable
     function M.get_parser(bufnr, lang, opts)
       opts = opts or {}
       -- Errors are thrown unless the caller passed a falsy, non-nil `opts.error`.
       local should_error = opts.error == nil or opts.error

       bufnr = vim._resolve_bufnr(bufnr)

       if not valid_lang(lang) then
         lang = M.language.get_lang(vim.bo[bufnr].filetype)
       end

       -- Hold the cached parser in a local: `parsers` has weak values, so the local is what
       -- guarantees the object stays reachable for the rest of this call.
       local parser = parsers[bufnr]

       if not valid_lang(lang) then
         -- Without a language, only an already cached parser can be handed out.
         if not parser then
           local err_msg =
             string.format('Parser not found for buffer %s: language could not be determined', bufnr)
           if should_error then
             error(err_msg)
           end
           return nil, err_msg
         end
       elseif parser == nil or parser:lang() ~= lang then
         if not api.nvim_buf_is_loaded(bufnr) then
           -- Honor `opts.error = false` here as well, so that callers which opted into the
           -- nil-plus-message convention never see an exception from this function.
           local err_msg = ('Buffer %s must be loaded to create parser'):format(bufnr)
           if should_error then
             error(err_msg)
           end
           return nil, err_msg
         end
         parser = vim.F.npcall(M._create_parser, bufnr, lang, opts)
         if not parser then
           local err_msg =
             string.format('Parser could not be created for buffer %s and language "%s"', bufnr, lang)
           if should_error then
             error(err_msg)
           end
           return nil, err_msg
         end
         parsers[bufnr] = parser
       end

       parser:register_cbs(opts.buf_attach_cbs)

       return parser
     end
    122 
    --- Creates a parser whose source is a string instead of a buffer
    ---
    ---@param str string Text to parse
    ---@param lang string Language of this string
    ---@param opts (table|nil) Options to pass to the created language tree
    ---
    ---@return vim.treesitter.LanguageTree object to use for parsing
    function M.get_string_parser(str, lang, opts)
      local validate = vim.validate

      -- Both the text and the language name are checked to be strings before a tree is built.
      validate('str', str, 'string')
      validate('lang', lang, 'string')

      return LanguageTree.new(str, lang, opts)
    end
    136 
    --- Checks whether one node is an ancestor of another
    ---
    ---@param dest TSNode Possible ancestor
    ---@param source TSNode Possible descendant
    ---
    ---@return boolean True if {dest} is an ancestor of {source}
    function M.is_ancestor(dest, source)
      -- A missing node on either side is answered with false rather than an error.
      if not dest or not source then
        return false
      end

      local child_on_path = dest:child_with_descendant(source)
      return child_on_path ~= nil
    end
    150 
    --- Returns the four range values of a node, or of a range table passed in its place
    ---
    ---@param node_or_range TSNode|Range4 Node or table of positions
    ---
    ---@return integer start_row
    ---@return integer start_col # (byte offset)
    ---@return integer end_row
    ---@return integer end_col # (byte offset)
    function M.get_node_range(node_or_range)
      -- Anything that is not a plain table is treated as a node and asked for its range.
      if type(node_or_range) ~= 'table' then
        return node_or_range:range(false)
      end

      --- @cast node_or_range -TSNode LuaLS bug
      return M._range.unpack4(node_or_range)
    end
    167 
    --- Shifts a node's range by {offset}. When the shifted end would come before the shifted
    --- start, the node's own unshifted range is returned instead.
    ---@param node TSNode
    ---@param source integer|string Buffer or string from which the {node} is extracted
    ---@param offset Range4
    ---@return Range6
    local function apply_range_offset(node, source, offset)
      local start_row, start_col, end_row, end_col = node:range()

      ---@type Range4
      local shifted = {
        start_row + offset[1],
        start_col + offset[2],
        end_row + offset[3],
        end_col + offset[4],
      }

      local start_before_end = shifted[1] < shifted[3]
        or (shifted[1] == shifted[3] and shifted[2] <= shifted[4])

      if not start_before_end then
        -- The offsets produced an invalid range, so they are skipped entirely.
        return { node:range(true) }
      end

      return M._range.add_bytes(source, shifted)
    end
    192 
    ---Get the range of a |TSNode|. When {metadata} carries a `range` or an `offset`
    ---(checked in that order), the result reflects that directive; {source} is then required.
    ---@param node TSNode
    ---@param source integer|string|nil Buffer or string from which the {node} is extracted
    ---@param metadata vim.treesitter.query.TSMetadata|nil
    ---@return Range6
    function M.get_range(node, source, metadata)
      -- An explicit range in the metadata takes precedence over an offset.
      local range_override = metadata and metadata.range
      if range_override then
        return M._range.add_bytes(assert(source), range_override)
      end

      local offset = metadata and metadata.offset
      if offset then
        return apply_range_offset(node, assert(source), offset)
      end

      return { node:range(true) }
    end
    209 
    --- Reads the text covered by {range} from a buffer and joins its lines with "\n".
    ---@param buf integer
    ---@param range Range
    ---@return string
    local function buf_range_get_text(buf, range)
      local srow, scol, erow, ecol = M._range.unpack4(range)

      if ecol == 0 then
        -- A range ending at column 0 is rewritten to end at column -1 of the preceding row;
        -- when it also starts on that same row, the start is moved along with it.
        if srow == erow then
          srow, scol = srow - 1, -1
        end
        erow, ecol = erow - 1, -1
      end

      return table.concat(api.nvim_buf_get_text(buf, srow, scol, erow, ecol, {}), '\n')
    end
    226 
    --- Returns the text that a node covers in its source
    ---
    ---@param node TSNode
    ---@param source (integer|string) Buffer or string from which the {node} is extracted
    ---@param opts (table|nil) Optional parameters.
    ---          - metadata (table) Metadata of a specific capture. This would be
    ---            set to `metadata[capture_id]` when using |vim.treesitter.query.add_directive()|.
    ---@return string
    function M.get_node_text(node, source, opts)
      local metadata = (opts or {}).metadata or {}

      -- Text supplied through the metadata wins over anything read from the source.
      local text_override = metadata.text
      if text_override then
        return text_override
      end

      if type(source) == 'number' then
        return buf_range_get_text(source, M.get_range(node, source, metadata))
      end

      ---@cast source string
      -- For a string source, slice by the third value reported by start()/end_().
      local _, _, first_byte = node:start()
      local _, _, last_byte = node:end_()
      return source:sub(first_byte + 1, last_byte)
    end
    249 
    --- Checks whether the (line, col) position lies inside a node's range
    ---
    ---@param node TSNode defining the range
    ---@param line integer Line (0-based)
    ---@param col integer Column (0-based)
    ---
    ---@return boolean True if the position is in node range
    function M.is_in_node_range(node, line, col)
      -- The position is tested as the one-column range that starts at (line, col).
      local cell = { line, col, line, col + 1 }
      return M.node_contains(node, cell)
    end
    260 
    --- Checks whether a node's range contains another range
    ---
    ---@param node TSNode
    ---@param range table
    ---
    ---@return boolean True if the {node} contains the {range}
    function M.node_contains(node, range)
      -- allow a table so nodes can be mocked
      vim.validate('node', node, { 'userdata', 'table' })
      vim.validate('range', range, M._range.validate, 'integer list with 4 or 6 elements')

      --- @diagnostic disable-next-line: missing-fields LuaLS bug
      local node_range = { node:range() } --- @type Range4

      return M._range.contains(node_range, range)
    end
    275 
    --- Returns a list of highlight captures at the given position
    ---
    --- Each capture is represented by a table containing the capture name as a string, the capture's
    --- language, a table of metadata (`priority`, `conceal`, ...; empty if none are defined), the
    --- id of the capture, and the id of the query pattern the capture's match belongs to.
    ---
    --- Returns an empty list when the buffer has no active highlighter.
    ---
    ---@param bufnr integer Buffer number (0 for current buffer)
    ---@param row integer Position row
    ---@param col integer Position column
    ---
    ---@return {capture: string, lang: string, metadata: vim.treesitter.query.TSMetadata, id: integer, pattern_id: integer}[]
    function M.get_captures_at_pos(bufnr, row, col)
      bufnr = vim._resolve_bufnr(bufnr)
      local buf_highlighter = M.highlighter.active[bufnr]

      if not buf_highlighter then
        return {}
      end

      local matches = {}

      buf_highlighter.tree:for_each_tree(function(tstree, tree)
        if not tstree then
          return
        end

        local root = tstree:root()
        local root_start_row, _, root_end_row, _ = root:range()

        -- Only worry about trees within the line range
        if root_start_row > row or root_end_row < row then
          return
        end

        -- The language is the same for every capture of this tree, so look it up once.
        local lang = tree:lang()
        local query = buf_highlighter:get_query(lang):query()

        -- Some injected languages may not have highlight queries.
        if not query then
          return
        end

        for id, node, metadata, match in
          query:iter_captures(root, buf_highlighter.bufnr, row, row + 1)
        do
          if M.is_in_node_range(node, row, col) then
            ---@diagnostic disable-next-line: invisible
            local capture = query.captures[id] -- name of the capture in the query
            if capture ~= nil then
              local _, pattern_id = match:info()
              matches[#matches + 1] = {
                capture = capture,
                metadata = metadata,
                lang = lang,
                id = id,
                pattern_id = pattern_id,
              }
            end
          end
        end
      end)

      return matches
    end
    339 
    --- Returns the names of the highlight captures under the cursor
    ---
    ---@param winnr (integer|nil): |window-ID| or 0 for current window (default)
    ---
    ---@return string[] List of capture names
    function M.get_captures_at_cursor(winnr)
      winnr = winnr or 0
      local bufnr = api.nvim_win_get_buf(winnr)
      local cursor = api.nvim_win_get_cursor(winnr)

      -- One is subtracted from the cursor row before it is used as the position row.
      local names = {}
      for _, entry in ipairs(M.get_captures_at_pos(bufnr, cursor[1] - 1, cursor[2])) do
        names[#names + 1] = entry.capture
      end

      return names
    end
    360 
    361 --- Optional keyword arguments:
    362 --- @class vim.treesitter.get_node.Opts : vim.treesitter.LanguageTree.tree_for_range.Opts
    363 --- @inlinedoc
    364 ---
    365 --- Buffer number (nil or 0 for current buffer)
    366 --- @field bufnr integer?
    367 ---
    368 --- 0-indexed (row, col) tuple. Defaults to cursor position in the
    369 --- current window. Required if {bufnr} is not the current buffer
    370 --- @field pos [integer, integer]?
    371 ---
    372 --- Parser language. (default: from buffer filetype)
    373 --- @field lang string?
    374 ---
    375 --- Ignore injected languages (default true)
    376 --- @field ignore_injections boolean?
    377 ---
    378 --- Include anonymous nodes (default false)
    379 --- @field include_anonymous boolean?
    380 
    --- Returns the smallest named node at the given position
    ---
    --- Returns nothing when no parser is available for the buffer.
    ---
    --- NOTE: Calling this on an unparsed tree can yield an invalid node.
    --- If the tree is not known to be parsed by, e.g., an active highlighter,
    --- parse the tree first via
    ---
    --- ```lua
    --- vim.treesitter.get_parser(bufnr):parse(range)
    --- ```
    ---
    ---@param opts vim.treesitter.get_node.Opts?
    ---
    ---@return TSNode | nil Node at the given position
    function M.get_node(opts)
      opts = opts or {}

      local bufnr = vim._resolve_bufnr(opts.bufnr)

      local pos = opts.pos
      local row, col --- @type integer, integer
      if pos then
        assert(#pos == 2, 'Position must be a (row, col) tuple')
        row, col = pos[1], pos[2]
      else
        -- The cursor can only stand in for the position when the buffer is the current one.
        assert(
          bufnr == api.nvim_get_current_buf(),
          'Position must be explicitly provided when not using the current buffer'
        )
        local cursor = api.nvim_win_get_cursor(0)
        -- Subtract one to account for 1-based row indexing in nvim_win_get_cursor
        row, col = cursor[1] - 1, cursor[2]
      end

      assert(row >= 0 and col >= 0, 'Invalid position: row and col must be non-negative')

      local root_lang_tree = M.get_parser(bufnr, opts.lang, { error = false })
      if not root_lang_tree then
        return
      end

      -- The position is looked up as a zero-width range.
      local ts_range = { row, col, row, col }

      if not opts.include_anonymous then
        return root_lang_tree:named_node_for_range(ts_range, opts)
      end
      return root_lang_tree:node_for_range(ts_range, opts)
    end
    427 
    --- Starts treesitter highlighting for a buffer
    ---
    --- Can be used in an ftplugin or FileType autocommand.
    ---
    --- Note: By default, disables regex syntax highlighting, which may be required for some plugins.
    --- In this case, add `vim.bo.syntax = 'ON'` after the call to `start`.
    ---
    --- Note: By default, the highlighter parses code asynchronously, using a segment time of 3ms.
    ---
    --- Example:
    ---
    --- ```lua
    --- vim.api.nvim_create_autocmd( 'FileType', { pattern = 'tex',
    ---     callback = function(args)
    ---         vim.treesitter.start(args.buf, 'latex')
    ---         vim.bo[args.buf].syntax = 'ON'  -- only if additional legacy syntax is needed
    ---     end
    --- })
    --- ```
    ---
    ---@param bufnr integer? Buffer to be highlighted (default: current buffer)
    ---@param lang string? Language of the parser (default: from buffer filetype)
    function M.start(bufnr, lang)
      bufnr = vim._resolve_bufnr(bufnr)

      -- Ensure buffer is loaded. `:edit` over `bufload()` to show swapfile prompt.
      if not api.nvim_buf_is_loaded(bufnr) then
        local unnamed = api.nvim_buf_get_name(bufnr) == ''
        if unnamed then
          vim.fn.bufload(bufnr)
        else
          -- A failing `:edit` is tolerated here; a missing parser is caught by the assert below.
          pcall(api.nvim_buf_call, bufnr, vim.cmd.edit)
        end
      end

      local parser, err = M.get_parser(bufnr, lang, { error = false })
      assert(parser, err)
      M.highlighter.new(parser)
    end
    463 
    --- Stops treesitter highlighting for a buffer
    ---
    --- Does nothing when the buffer has no active highlighter.
    ---
    ---@param bufnr (integer|nil) Buffer to stop highlighting (default: current buffer)
    function M.stop(bufnr)
      bufnr = vim._resolve_bufnr(bufnr)

      local active = M.highlighter.active[bufnr]
      if active then
        active:destroy()
      end
    end
    474 
    --- Open a window that displays a textual representation of the nodes in the language tree.
    ---
    --- While in the window, press "a" to toggle display of anonymous nodes, "I" to toggle the
    --- display of the source language of each node, "o" to toggle the query editor, and press
    --- [<Enter>] to jump to the node under the cursor in the source buffer. Folding also works
    --- (try |zo|, |zc|, etc.).
    ---
    --- Can also be shown with `:InspectTree`. [:InspectTree]()
    ---
    ---@since 11
    ---@param opts table|nil Optional options table with the following possible keys:
    ---                      - lang (string|nil): The language of the source buffer. If omitted, detect
    ---                        from the filetype of the source buffer.
    ---                      - bufnr (integer|nil): Buffer to draw the tree into. If omitted, a new
    ---                        buffer is created.
    ---                      - winid (integer|nil): Window id to display the tree buffer in. If omitted,
    ---                        a new window is created with {command}.
    ---                      - command (string|nil): Vimscript command to create the window. Default
    ---                        value is "60vnew". Only used when {winid} is nil.
    ---                      - title (string|fun(bufnr:integer):string|nil): Title of the window. If a
    ---                        function, it accepts the buffer number of the source buffer as its only
    ---                        argument and should return a string.
    function M.inspect_tree(opts)
      -- All of the work is delegated to the dev submodule; nothing is returned.
      local dev = M.dev
      ---@diagnostic disable-next-line: invisible
      dev.inspect_tree(opts)
    end
    501 
    --- Returns the fold level for {lnum} in the current buffer. Can be set directly to 'foldexpr':
    ---
    --- ```lua
    --- vim.wo.foldexpr = 'v:lua.vim.treesitter.foldexpr()'
    --- ```
    ---
    ---@since 11
    ---@param lnum integer|nil Line number to calculate fold level for
    ---@return string
    function M.foldexpr(lnum)
      -- The computation is delegated to the _fold submodule; its result is passed through as is.
      local fold = M._fold
      return fold.foldexpr(lnum)
    end
    514 
    515 return M