_query_linter.lua (8024B)
local api = vim.api

local namespace = api.nvim_create_namespace('nvim.treesitter.query_linter')

local M = {}

--- @class QueryLinterNormalizedOpts
--- @field langs string[]
--- @field clear boolean

--- @alias vim.treesitter.ParseError {msg: string, range: Range4}

--- Contains language dependent context for the query linter
--- @class QueryLinterLanguageContext
--- @field lang string? Current `lang` of the targeted parser
--- @field parser_info table? Parser info returned by vim.treesitter.language.inspect
--- @field is_first_lang boolean Whether this is the first language of a linter run checking queries for multiple `langs`

--- Appends an error diagnostic covering `range` to `diagnostics`.
---
--- Newlines in `lint` are replaced by single spaces before the text is stored
--- as the diagnostic message.
--- @param diagnostics vim.Diagnostic.Set[] List that receives the new entry (modified in place)
--- @param range Range4 Read as { start_row, start_col, end_row, end_col }
--- @param lint string Message text
--- @param lang string? Stored as the `source` of the diagnostic
local function add_lint_for_node(diagnostics, range, lint, lang)
  -- The parentheses discard the substitution count that gsub also returns.
  local message = (lint:gsub('\n', ' '))
  diagnostics[#diagnostics + 1] = {
    lnum = range[1],
    end_lnum = range[3],
    col = range[2],
    end_col = range[4],
    -- Use the documented severity enum rather than a field looked up directly
    -- on vim.diagnostic.
    severity = vim.diagnostic.severity.ERROR,
    message = message,
    source = lang,
  }
end

--- Determines the target language of a query file by its path: <lang>/<query_type>.scm
--- @param buf integer
--- @return string?
local function guess_query_lang(buf)
  local filename = api.nvim_buf_get_name(buf)
  if filename == '' then
    -- Unnamed buffer: there is no path to derive a language from.
    return nil
  end

  -- get <lang> from /path/<lang>/<query_type>.scm
  local resolved = vim.fs.basename(vim.fs.dirname(vim.fs.abspath(filename)))
  return vim.treesitter.language.get_lang(resolved)
end

--- Fills in and normalizes the `langs` option.
---
--- When `opts.langs` is unset it is guessed from the buffer's path. A non-table
--- value is wrapped in a list; a nil value therefore becomes an empty list.
--- Note that a caller-supplied `opts` table is updated in place and returned.
--- @param buf integer
--- @param opts vim.treesitter.query.lint.Opts|QueryLinterNormalizedOpts|nil
--- @return QueryLinterNormalizedOpts
local function normalize_opts(buf, opts)
  opts = opts or {}
  if not opts.langs then
    opts.langs = guess_query_lang(buf)
  end

  if type(opts.langs) ~= 'table' then
    --- @diagnostic disable-next-line:assign-type-mismatch
    opts.langs = { opts.langs }
  end

  --- @cast opts QueryLinterNormalizedOpts
  opts.langs = opts.langs or {}
  return opts
end

-- Query run against the query buffer itself; its capture names are what
-- the linter dispatches on.
local lint_query = [[;; query
  (program [(named_node) (anonymous_node) (list) (grouping)] @toplevel)
  (named_node
    name: _ @node.named)
  (anonymous_node
    name: _ @node.anonymous)
  (field_definition
    name: (identifier) @field)
  (predicate
    name: (identifier) @predicate.name
    type: (predicate_type) @predicate.type)
  (ERROR) @error
]]

--- Converts the error raised while parsing the text of `node` into a message
--- and a buffer range.
---
--- The error is expected to contain `Query error at <line>:<col>. <message>`
--- with 1-based positions relative to the parsed text. If it does not, the
--- whole error text is reported on the full range of `node` instead of failing.
--- @param err string
--- @param node TSNode
--- @return vim.treesitter.ParseError
local function get_error_entry(err, node)
  local node_start_line, node_start_col, node_end_line, node_end_col = node:range()
  local line_offset, col_offset, msg =
    err:match('.-:%d+: Query error at (%d+):(%d+)%. ([^:]+)') ---@type string?, string?, string?

  if not (line_offset and col_offset and msg) then
    -- Unrecognized error format: keep the information rather than erroring on
    -- nil arithmetic below.
    return {
      msg = err,
      range = { node_start_line, node_start_col, node_end_line, node_end_col },
    }
  end

  local row = tonumber(line_offset) - 1
  local col = tonumber(col_offset) - 1

  local start_line = node_start_line + row
  -- NOTE(review): the reported column is treated as relative to the start of
  -- its own line in the parsed text -- confirm against the error formatter.
  -- Only the first line of that text begins at the node's start column;
  -- presumably later lines begin at column 0 of the buffer, so the node's
  -- column offset must not be added for them.
  local start_col = col
  if row == 0 then
    start_col = node_start_col + col
  end

  local end_line, end_col = start_line, start_col
  if msg:match('^Invalid syntax') or msg:match('^Impossible') then
    -- Use the length of the underlined node
    local underlined = vim.split(err, '\n')[2]
    end_col = end_col + #(underlined or '')
  elseif msg:match('^Invalid') then
    -- Use the length of the problematic type/capture/field
    end_col = end_col + #(msg:match('"([^"]+)"') or '')
  end

  return {
    msg = msg,
    range = { start_line, start_col, end_line, end_col },
  }
end

--- Builds the memoization key for `parse` from the node id, buffer number,
--- the buffer's changedtick and the language.
---
--- The parts are joined with a separator so that adjacent numeric fields
--- cannot run together (e.g. buffer 1 / tick 12 versus buffer 11 / tick 2).
--- @param node TSNode
--- @param buf integer
--- @param lang string
--- @return string
local function hash_parse(node, buf, lang)
  return tostring(node:id())
    .. ':'
    .. tostring(buf)
    .. ':'
    .. tostring(vim.b[buf].changedtick)
    .. ':'
    .. lang
end

--- @param node TSNode
--- @param buf integer
--- @param lang string
--- @return vim.treesitter.ParseError?
local parse = vim.func._memoize(hash_parse, function(node, buf, lang)
  -- Try to compile the node's text as a query for `lang`; only a string error
  -- is turned into a ParseError, anything else yields nil.
  local query_text = vim.treesitter.get_node_text(node, buf)
  local ok, err = pcall(vim.treesitter.query.parse, lang, query_text) ---@type boolean, string|vim.treesitter.Query

  if not ok and type(err) == 'string' then
    return get_error_entry(err, node)
  end
end)

--- Runs the lint checks for a single match of the linter query and appends
--- any findings to `diagnostics`.
--- @param buf integer
--- @param match table<integer,TSNode[]>
--- @param query vim.treesitter.Query
--- @param lang_context QueryLinterLanguageContext
--- @param diagnostics vim.Diagnostic.Set[]
local function lint_match(buf, match, query, lang_context, diagnostics)
  local lang = lang_context.lang
  local parser_info = lang_context.parser_info

  for id, nodes in pairs(match) do
    -- The capture name depends only on the capture id, not on the node.
    local cap_id = query.captures[id]

    for _, node in ipairs(nodes) do
      -- perform language-independent checks only for first lang
      if lang_context.is_first_lang and cap_id == 'error' then
        local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ')
        ---@diagnostic disable-next-line: missing-fields LuaLS varargs bug
        local range = { node:range() } --- @type Range4
        add_lint_for_node(diagnostics, range, 'Syntax error: ' .. node_text)
      end

      -- other checks rely on Neovim parser introspection
      if lang and parser_info and cap_id == 'toplevel' then
        local err = parse(node, buf, lang)
        if err then
          add_lint_for_node(diagnostics, err.range, err.msg, lang)
        end
      end
    end
  end
end

--- Lints the query in `buf` and publishes the result as diagnostics in this
--- module's namespace.
---
--- Raises (via `assert`) when no parser can be obtained for the buffer.
--- @param buf integer Buffer to lint
--- @param opts vim.treesitter.query.lint.Opts|QueryLinterNormalizedOpts|nil Options for linting
function M.lint(buf, opts)
  if buf == 0 then
    buf = api.nvim_get_current_buf()
  end

  local diagnostics = {} --- @type vim.Diagnostic.Set[]
  local query = vim.treesitter.query.parse('query', lint_query)

  opts = normalize_opts(buf, opts)

  -- Nothing in the loop below edits the buffer, so the parser is obtained and
  -- run once instead of once per language.
  local parser = assert(vim.treesitter.get_parser(buf, nil, { error = false }))
  parser:parse()

  -- perform at least one iteration even with no langs to perform language independent checks
  for i = 1, math.max(1, #opts.langs) do
    local lang = opts.langs[i]

    --- @type (table|nil)
    local parser_info = vim.F.npcall(vim.treesitter.language.inspect, lang)
    local lang_context = {
      lang = lang,
      parser_info = parser_info,
      is_first_lang = i == 1,
    }

    parser:for_each_tree(function(tree, ltree)
      if ltree:lang() == 'query' then
        for _, match, _ in query:iter_matches(tree:root(), buf, 0, -1) do
          lint_match(buf, match, query, lang_context, diagnostics)
        end
      end
    end)
  end

  vim.diagnostic.set(namespace, buf, diagnostics)
end

--- Resets the diagnostics of this module's namespace for `buf`.
--- @param buf integer
function M.clear(buf)
  vim.diagnostic.reset(namespace, buf)
end

--- Completion function for query buffers.
---
--- With `findstart == 1`, returns the 0-based column at which the run of
--- `"`, `#`, `-` and alphanumeric characters before the cursor starts.
--- Otherwise returns the candidates containing `base`: field names (with a
--- trailing `:`), predicates (plain and `not-` forms), directives and symbols
--- of the language guessed from the buffer's path. Returns -2 when parser
--- information for that language cannot be obtained.
--- @param findstart 0|1
--- @param base string
function M.omnifunc(findstart, base)
  if findstart == 1 then
    -- The pattern can match the empty string at the end of the text, so `find`
    -- always yields a start index here.
    local result =
      api.nvim_get_current_line():sub(1, api.nvim_win_get_cursor(0)[2]):find('["#%-%w]*$')
    return result - 1
  end

  local buf = api.nvim_get_current_buf()
  local query_lang = guess_query_lang(buf)

  local ok, parser_info = pcall(vim.treesitter.language.inspect, query_lang)
  if not ok then
    return -2
  end

  local items = {}

  -- Adds `text` when `base` occurs at its first or second character,
  -- i.e. with or without the leading '#'.
  local function add_hash_candidate(text)
    local found = text:find(base, 1, true)
    if found and found <= 2 then
      table.insert(items, text)
    end
  end

  for _, f in pairs(parser_info.fields) do
    if f:find(base, 1, true) then
      table.insert(items, f .. ':')
    end
  end
  for _, p in pairs(vim.treesitter.query.list_predicates()) do
    add_hash_candidate('#' .. p)
    add_hash_candidate('#not-' .. p)
  end
  for _, p in pairs(vim.treesitter.query.list_directives()) do
    add_hash_candidate('#' .. p)
  end
  for text, named in
    pairs(parser_info.symbols --[[@as table<string, boolean>]])
  do
    if not named then
      -- Strip the first and last character and re-quote the remainder with %q.
      text = string.format('%q', text:sub(2, -2)):gsub('\n', 'n') ---@type string
    end
    if text:find(base, 1, true) then
      table.insert(items, text)
    end
  end
  return { words = items, refresh = 'always' }
end

return M