commit 0977f70f4dd3d14175697d5e6568d4019019506f
parent ee3f9a1e03a93ab2a75a6f7a2f3bb2f4c6b4e736
Author: Riley Bruins <ribru17@hotmail.com>
Date: Sun, 13 Apr 2025 14:22:17 -0700
fix(treesitter): injected lang ranges may cross capture boundaries #32549
Problem:
treesitter injected language ranges sometimes cross over the capture
boundaries when `@combined`.
Solution:
Clip child regions to not spill out of parent regions within
languagetree.lua, and only apply highlights within those regions in
highlighter.lua.
Co-authored-by: Cormac Relf <web@cormacrelf.net>
Diffstat:
4 files changed, 221 insertions(+), 50 deletions(-)
diff --git a/runtime/lua/vim/treesitter/_range.lua b/runtime/lua/vim/treesitter/_range.lua
@@ -115,6 +115,19 @@ function M.intercepts(r1, r2)
end
---@private
+---@param r1 Range6
+---@param r2 Range6
+---@return Range6?
+function M.intersection(r1, r2)
+ if not M.intercepts(r1, r2) then
+ return nil
+ end
+ local rs = M.cmp_pos.le(r1[1], r1[2], r2[1], r2[2]) and r2 or r1
+ local re = M.cmp_pos.ge(r1[4], r1[5], r2[4], r2[5]) and r2 or r1
+ return { rs[1], rs[2], rs[3], re[4], re[5], re[6] }
+end
+
+---@private
---@param r Range
---@return integer, integer, integer, integer
function M.unpack4(r)
diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua
@@ -322,6 +322,8 @@ local function on_line_impl(self, buf, line, on_spell, on_conceal)
return
end
+ local tree_region = state.tstree:included_ranges(true)
+
if state.iter == nil or state.next_row < line then
-- Mainly used to skip over folds
@@ -336,56 +338,63 @@ local function on_line_impl(self, buf, line, on_spell, on_conceal)
while line >= state.next_row do
local capture, node, metadata, match = state.iter(line)
- local range = { root_end_row + 1, 0, root_end_row + 1, 0 }
+ local outer_range = { root_end_row + 1, 0, root_end_row + 1, 0 }
if node then
- range = vim.treesitter.get_range(node, buf, metadata and metadata[capture])
+ outer_range = vim.treesitter.get_range(node, buf, metadata and metadata[capture])
end
- local start_row, start_col, end_row, end_col = Range.unpack4(range)
-
- if capture then
- local hl = state.highlighter_query:get_hl_from_capture(capture)
-
- local capture_name = captures[capture]
-
- local spell, spell_pri_offset = get_spell(capture_name)
-
- -- The "priority" attribute can be set at the pattern level or on a particular capture
- local priority = (
- tonumber(metadata.priority or metadata[capture] and metadata[capture].priority)
- or vim.hl.priorities.treesitter
- ) + spell_pri_offset
-
- -- The "conceal" attribute can be set at the pattern level or on a particular capture
- local conceal = metadata.conceal or metadata[capture] and metadata[capture].conceal
-
- local url = get_url(match, buf, capture, metadata)
-
- if hl and end_row >= line and not on_conceal and (not on_spell or spell ~= nil) then
- api.nvim_buf_set_extmark(buf, ns, start_row, start_col, {
- end_line = end_row,
- end_col = end_col,
- hl_group = hl,
- ephemeral = true,
- priority = priority,
- conceal = conceal,
- spell = spell,
- url = url,
- })
- end
-
- if
- (metadata.conceal_lines or metadata[capture] and metadata[capture].conceal_lines)
- and #api.nvim_buf_get_extmarks(buf, ns, { start_row, 0 }, { start_row, 0 }, {}) == 0
- then
- api.nvim_buf_set_extmark(buf, ns, start_row, 0, {
- end_line = end_row,
- conceal_lines = '',
- })
+ local outer_range_start_row = outer_range[1]
+
+ for _, range in ipairs(tree_region) do
+ local intersection = Range.intersection(range, outer_range)
+ if intersection then
+ local start_row, start_col, end_row, end_col = Range.unpack4(intersection)
+
+ if capture then
+ local hl = state.highlighter_query:get_hl_from_capture(capture)
+
+ local capture_name = captures[capture]
+
+ local spell, spell_pri_offset = get_spell(capture_name)
+
+ -- The "priority" attribute can be set at the pattern level or on a particular capture
+ local priority = (
+ tonumber(metadata.priority or metadata[capture] and metadata[capture].priority)
+ or vim.hl.priorities.treesitter
+ ) + spell_pri_offset
+
+ -- The "conceal" attribute can be set at the pattern level or on a particular capture
+ local conceal = metadata.conceal or metadata[capture] and metadata[capture].conceal
+
+ local url = get_url(match, buf, capture, metadata)
+
+ if hl and end_row >= line and not on_conceal and (not on_spell or spell ~= nil) then
+ api.nvim_buf_set_extmark(buf, ns, start_row, start_col, {
+ end_line = end_row,
+ end_col = end_col,
+ hl_group = hl,
+ ephemeral = true,
+ priority = priority,
+ conceal = conceal,
+ spell = spell,
+ url = url,
+ })
+ end
+
+ if
+ (metadata.conceal_lines or metadata[capture] and metadata[capture].conceal_lines)
+ and #api.nvim_buf_get_extmarks(buf, ns, { start_row, 0 }, { start_row, 0 }, {}) == 0
+ then
+ api.nvim_buf_set_extmark(buf, ns, start_row, 0, {
+ end_line = end_row,
+ conceal_lines = '',
+ })
+ end
+ end
end
end
- if start_row > line then
- state.next_row = start_row
+ if outer_range_start_row > line then
+ state.next_row = outer_range_start_row
end
end
end)
diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua
@@ -874,6 +874,39 @@ local function get_node_ranges(node, source, metadata, include_children)
return ranges
end
+---Finds the intersection between two regions, assuming they are sorted in ascending order by
+---starting point.
+---@param region1 Range6[]
+---@param region2 Range6[]?
+---@return Range6[]
+local function clip_regions(region1, region2)
+ if not region2 then
+ return region1
+ end
+
+ local result = {}
+ local i, j = 1, 1
+
+ while i <= #region1 and j <= #region2 do
+ local r1 = region1[i]
+ local r2 = region2[j]
+
+ local intersection = Range.intersection(r1, r2)
+ if intersection then
+ table.insert(result, intersection)
+ end
+
+ -- Advance the range that ends earlier
+ if Range.cmp_pos.le(r1[3], r1[4], r2[3], r2[4]) then
+ i = i + 1
+ else
+ j = j + 1
+ end
+ end
+
+ return result
+end
+
---@nodoc
---@class vim.treesitter.languagetree.InjectionElem
---@field combined boolean
@@ -886,8 +919,9 @@ end
---@param lang string
---@param combined boolean
---@param ranges Range6[]
+---@param parent_ranges Range6[]?
---@param result table<string,Range6[][]>
-local function add_injection(t, pattern, lang, combined, ranges, result)
+local function add_injection(t, pattern, lang, combined, ranges, parent_ranges, result)
if #ranges == 0 then
-- Make sure not to add an empty range set as this is interpreted to mean the whole buffer.
return
@@ -898,7 +932,7 @@ local function add_injection(t, pattern, lang, combined, ranges, result)
end
if not combined then
- table.insert(result[lang], ranges)
+ table.insert(result[lang], clip_regions(ranges, parent_ranges))
return
end
@@ -914,7 +948,7 @@ local function add_injection(t, pattern, lang, combined, ranges, result)
table.insert(result[lang], regions)
end
- for _, range in ipairs(ranges) do
+ for _, range in ipairs(clip_regions(ranges, parent_ranges)) do
table.insert(t[lang][pattern], range)
end
end
@@ -1007,10 +1041,11 @@ function LanguageTree:_get_injections(range, thread_state)
local full_scan = range == true or self._injection_query.has_combined_injections
- for _, tree in pairs(self._trees) do
+ for tree_index, tree in pairs(self._trees) do
---@type vim.treesitter.languagetree.Injection
local injections = {}
local root_node = tree:root()
+ local parent_ranges = self._regions and self._regions[tree_index] or nil
local start_line, end_line ---@type integer, integer
if full_scan then
start_line, _, end_line = root_node:range()
@@ -1023,7 +1058,7 @@ function LanguageTree:_get_injections(range, thread_state)
do
local lang, combined, ranges = self:_get_injection(match, metadata)
if lang then
- add_injection(injections, pattern, lang, combined, ranges, result)
+ add_injection(injections, pattern, lang, combined, ranges, parent_ranges, result)
else
self:_log('match from injection query failed for pattern', pattern)
end
diff --git a/test/functional/treesitter/highlight_spec.lua b/test/functional/treesitter/highlight_spec.lua
@@ -513,6 +513,120 @@ describe('treesitter highlighting (C)', function()
screen:expect { grid = injection_grid_expected_c }
end)
+ it('supports combined injections #31777', function()
+ insert([=[
+ -- print([[
+ -- some
+ -- random
+ -- text
+ -- here]])
+ ]=])
+
+ exec_lua(function()
+ local parser = vim.treesitter.get_parser(0, 'lua', {
+ injections = {
+ lua = [[
+ ; query
+ ((comment_content) @injection.content
+ (#set! injection.self)
+ (#set! injection.combined))
+ ]],
+ },
+ })
+ local highlighter = vim.treesitter.highlighter
+ highlighter.new(parser, {
+ queries = {
+ lua = [[
+ ; query
+ (string) @string
+ (comment) @comment
+ (function_call (identifier) @function.call)
+ [ "(" ")" ] @punctuation.bracket
+ ]],
+ },
+ })
+ end)
+
+ screen:expect([=[
+ {18:-- }{25:print}{16:(}{26:[[} |
+ {18:--}{26: some} |
+ {18:-- random} |
+ {18:-- text} |
+ {18:-- here]])} |
+ ^ |
+ {1:~ }|*11
+ |
+ ]=])
+ -- NOTE: Once #31777 is fixed, this test case should be updated to the following:
+ -- screen:expect([=[
+ -- {18:-- }{25:print}{16:(}{26:[[} |
+ -- {18:--}{26: some} |
+ -- {18:--}{26: random} |
+ -- {18:--}{26: text} |
+ -- {18:--}{26: here]]}{16:)} |
+ -- ^ |
+ -- {1:~ }|*11
+ -- |
+ -- ]=])
+ end)
+
+ it('supports complicated combined injections', function()
+ insert([[
+ -- # Markdown here
+ --
+ -- ```c
+ -- int main() {
+ -- printf("Hello, world!");
+ -- }
+ -- ```
+ ]])
+
+ exec_lua(function()
+ local parser = vim.treesitter.get_parser(0, 'lua', {
+ injections = {
+ lua = [[
+ ; query
+ ((comment) @injection.content
+ (#offset! @injection.content 0 3 0 1)
+ (#lua-match? @injection.content "[-][-] ")
+ (#set! injection.combined)
+ (#set! injection.include-children)
+ (#set! injection.language "markdown"))
+ ]],
+ },
+ })
+ local highlighter = vim.treesitter.highlighter
+ highlighter.new(parser, {
+ queries = {
+ lua = [[
+ ; query
+ (string) @string
+ (comment) @comment
+ (function_call (identifier) @function.call)
+ [ "(" ")" ] @punctuation.bracket
+ ]],
+ },
+ })
+ end)
+
+ screen:add_extra_attr_ids({
+ [131] = { foreground = Screen.colors.Fuchsia, bold = true },
+ })
+
+ screen:expect([[
+ {18:-- }{131:# Markdown here} |
+ {18:--} |
+ {18:-- ```}{15:c} |
+ {18:-- }{16:int}{18: }{25:main}{16:()}{18: }{16:{} |
+ {18:-- }{25:printf}{16:(}{26:"Hello, world!"}{16:);} |
+ {18:-- }{16:}} |
+ {18:-- ```} |
+ ^ |
+ {1:~ }|*9
+ |
+ ]])
+ end)
+
it("supports injecting by ft name in metadata['injection.language']", function()
insert(injection_text_c)