neovim

Neovim text editor
git clone https://git.dasho.dev/neovim.git
Log | Files | Refs | README

_query_linter.lua (8024B)


      1 local api = vim.api
      2 
      3 local namespace = api.nvim_create_namespace('nvim.treesitter.query_linter')
      4 
      5 local M = {}
      6 
      7 --- @class QueryLinterNormalizedOpts
      8 --- @field langs string[]
      9 --- @field clear boolean
     10 
     11 --- @alias vim.treesitter.ParseError {msg: string, range: Range4}
     12 
     13 --- Contains language dependent context for the query linter
     14 --- @class QueryLinterLanguageContext
     15 --- @field lang string? Current `lang` of the targeted parser
     16 --- @field parser_info table? Parser info returned by vim.treesitter.language.inspect
     17 --- @field is_first_lang boolean Whether this is the first language of a linter run checking queries for multiple `langs`
     18 
     19 --- Adds a diagnostic for node in the query buffer
     20 --- @param diagnostics vim.Diagnostic.Set[]
     21 --- @param range Range4
     22 --- @param lint string
     23 --- @param lang string?
     24 local function add_lint_for_node(diagnostics, range, lint, lang)
     25  local message = lint:gsub('\n', ' ')
     26  diagnostics[#diagnostics + 1] = {
     27    lnum = range[1],
     28    end_lnum = range[3],
     29    col = range[2],
     30    end_col = range[4],
     31    severity = vim.diagnostic.ERROR,
     32    message = message,
     33    source = lang,
     34  }
     35 end
     36 
     37 --- Determines the target language of a query file by its path: <lang>/<query_type>.scm
     38 --- @param buf integer
     39 --- @return string?
     40 local function guess_query_lang(buf)
     41  local filename = api.nvim_buf_get_name(buf)
     42  if filename ~= '' then
     43    -- get <lang> from /path/<lang>/<query_type>.scm
     44    local resolved = vim.fs.basename(vim.fs.dirname(vim.fs.abspath(filename)))
     45    return vim.treesitter.language.get_lang(resolved)
     46  end
     47 end
     48 
     49 --- @param buf integer
     50 --- @param opts vim.treesitter.query.lint.Opts|QueryLinterNormalizedOpts|nil
     51 --- @return QueryLinterNormalizedOpts
     52 local function normalize_opts(buf, opts)
     53  opts = opts or {}
     54  if not opts.langs then
     55    opts.langs = guess_query_lang(buf)
     56  end
     57 
     58  if type(opts.langs) ~= 'table' then
     59    --- @diagnostic disable-next-line:assign-type-mismatch
     60    opts.langs = { opts.langs }
     61  end
     62 
     63  --- @cast opts QueryLinterNormalizedOpts
     64  opts.langs = opts.langs or {}
     65  return opts
     66 end
     67 
-- Query run over the buffer's own 'query'-language parse tree to find what to lint.
-- Captures:
--   @toplevel  each toplevel pattern of the query file; lint_match() re-parses its
--              text as a query for the target language
--   @error     ERROR nodes in the query file, reported by lint_match() as syntax errors
--   @node.named, @node.anonymous, @field, @predicate.name, @predicate.type
--              captured here but not inspected by lint_match()
local lint_query = [[;; query
  (program [(named_node) (anonymous_node) (list) (grouping)] @toplevel)
  (named_node
    name: _ @node.named)
  (anonymous_node
    name: _ @node.anonymous)
  (field_definition
    name: (identifier) @field)
  (predicate
    name: (identifier) @predicate.name
    type: (predicate_type) @predicate.type)
  (ERROR) @error
]]
     81 
     82 --- @param err string
     83 --- @param node TSNode
     84 --- @return vim.treesitter.ParseError
     85 local function get_error_entry(err, node)
     86  local start_line, start_col = node:range()
     87  local line_offset, col_offset, msg = err:gmatch('.-:%d+: Query error at (%d+):(%d+)%. ([^:]+)')() ---@type string, string, string
     88  start_line, start_col =
     89    start_line + tonumber(line_offset) - 1, start_col + tonumber(col_offset) - 1
     90  local end_line, end_col = start_line, start_col
     91  if msg:match('^Invalid syntax') or msg:match('^Impossible') then
     92    -- Use the length of the underlined node
     93    local underlined = vim.split(err, '\n')[2]
     94    end_col = end_col + #underlined
     95  elseif msg:match('^Invalid') then
     96    -- Use the length of the problematic type/capture/field
     97    end_col = end_col + #(msg:match('"([^"]+)"') or '')
     98  end
     99 
    100  return {
    101    msg = msg,
    102    range = { start_line, start_col, end_line, end_col },
    103  }
    104 end
    105 
    106 --- @param node TSNode
    107 --- @param buf integer
    108 --- @param lang string
    109 local function hash_parse(node, buf, lang)
    110  return tostring(node:id()) .. tostring(buf) .. tostring(vim.b[buf].changedtick) .. lang
    111 end
    112 
    113 --- @param node TSNode
    114 --- @param buf integer
    115 --- @param lang string
    116 --- @return vim.treesitter.ParseError?
    117 local parse = vim.func._memoize(hash_parse, function(node, buf, lang)
    118  local query_text = vim.treesitter.get_node_text(node, buf)
    119  local ok, err = pcall(vim.treesitter.query.parse, lang, query_text) ---@type boolean|vim.treesitter.ParseError, string|vim.treesitter.Query
    120 
    121  if not ok and type(err) == 'string' then
    122    return get_error_entry(err, node)
    123  end
    124 end)
    125 
    126 --- @param buf integer
    127 --- @param match table<integer,TSNode[]>
    128 --- @param query vim.treesitter.Query
    129 --- @param lang_context QueryLinterLanguageContext
    130 --- @param diagnostics vim.Diagnostic.Set[]
    131 local function lint_match(buf, match, query, lang_context, diagnostics)
    132  local lang = lang_context.lang
    133  local parser_info = lang_context.parser_info
    134 
    135  for id, nodes in pairs(match) do
    136    for _, node in ipairs(nodes) do
    137      local cap_id = query.captures[id]
    138 
    139      -- perform language-independent checks only for first lang
    140      if lang_context.is_first_lang and cap_id == 'error' then
    141        local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ')
    142        ---@diagnostic disable-next-line: missing-fields LuaLS varargs bug
    143        local range = { node:range() } --- @type Range4
    144        add_lint_for_node(diagnostics, range, 'Syntax error: ' .. node_text)
    145      end
    146 
    147      -- other checks rely on Neovim parser introspection
    148      if lang and parser_info and cap_id == 'toplevel' then
    149        local err = parse(node, buf, lang)
    150        if err then
    151          add_lint_for_node(diagnostics, err.range, err.msg, lang)
    152        end
    153      end
    154    end
    155  end
    156 end
    157 
    158 --- @param buf integer Buffer to lint
    159 --- @param opts vim.treesitter.query.lint.Opts|QueryLinterNormalizedOpts|nil Options for linting
    160 function M.lint(buf, opts)
    161  if buf == 0 then
    162    buf = api.nvim_get_current_buf()
    163  end
    164 
    165  local diagnostics = {} --- @type vim.Diagnostic.Set[]
    166  local query = vim.treesitter.query.parse('query', lint_query)
    167 
    168  opts = normalize_opts(buf, opts)
    169 
    170  -- perform at least one iteration even with no langs to perform language independent checks
    171  for i = 1, math.max(1, #opts.langs) do
    172    local lang = opts.langs[i]
    173 
    174    --- @type (table|nil)
    175    local parser_info = vim.F.npcall(vim.treesitter.language.inspect, lang)
    176    local lang_context = {
    177      lang = lang,
    178      parser_info = parser_info,
    179      is_first_lang = i == 1,
    180    }
    181 
    182    local parser = assert(vim.treesitter.get_parser(buf, nil, { error = false }))
    183    parser:parse()
    184    parser:for_each_tree(function(tree, ltree)
    185      if ltree:lang() == 'query' then
    186        for _, match, _ in query:iter_matches(tree:root(), buf, 0, -1) do
    187          lint_match(buf, match, query, lang_context, diagnostics)
    188        end
    189      end
    190    end)
    191  end
    192 
    193  vim.diagnostic.set(namespace, buf, diagnostics)
    194 end
    195 
    196 --- @param buf integer
    197 function M.clear(buf)
    198  vim.diagnostic.reset(namespace, buf)
    199 end
    200 
    201 --- @param findstart 0|1
    202 --- @param base string
    203 function M.omnifunc(findstart, base)
    204  if findstart == 1 then
    205    local result =
    206      api.nvim_get_current_line():sub(1, api.nvim_win_get_cursor(0)[2]):find('["#%-%w]*$')
    207    return result - 1
    208  end
    209 
    210  local buf = api.nvim_get_current_buf()
    211  local query_lang = guess_query_lang(buf)
    212 
    213  local ok, parser_info = pcall(vim.treesitter.language.inspect, query_lang)
    214  if not ok then
    215    return -2
    216  end
    217 
    218  local items = {}
    219  for _, f in pairs(parser_info.fields) do
    220    if f:find(base, 1, true) then
    221      table.insert(items, f .. ':')
    222    end
    223  end
    224  for _, p in pairs(vim.treesitter.query.list_predicates()) do
    225    local text = '#' .. p
    226    local found = text:find(base, 1, true)
    227    if found and found <= 2 then -- with or without '#'
    228      table.insert(items, text)
    229    end
    230    text = '#not-' .. p
    231    found = text:find(base, 1, true)
    232    if found and found <= 2 then -- with or without '#'
    233      table.insert(items, text)
    234    end
    235  end
    236  for _, p in pairs(vim.treesitter.query.list_directives()) do
    237    local text = '#' .. p
    238    local found = text:find(base, 1, true)
    239    if found and found <= 2 then -- with or without '#'
    240      table.insert(items, text)
    241    end
    242  end
    243  for text, named in
    244    pairs(parser_info.symbols --[[@as table<string, boolean>]])
    245  do
    246    if not named then
    247      text = string.format('%q', text:sub(2, -2)):gsub('\n', 'n') ---@type string
    248    end
    249    if text:find(base, 1, true) then
    250      table.insert(items, text)
    251    end
    252  end
    253  return { words = items, refresh = 'always' }
    254 end
    255 
    256 return M