_fold.lua (15526B)
local ts = vim.treesitter

local Range = require('vim.treesitter._range')

local api = vim.api

---Treesitter folding is done in two steps:
---(1) compute the fold levels with the syntax tree and cache the result (`compute_folds_levels`)
---(2) evaluate foldexpr for each window, which reads from the cache (`foldupdate`)
---@class TS.FoldInfo
---
---@field levels string[] the cached foldexpr result for each line (indexed by 1-based lnum)
---@field levels0 integer[] the cached raw fold levels (indexed by 1-based lnum)
---
---The range edited since the last invocation of the callback scheduled in on_bytes.
---Should compute fold levels in this range.
---@field on_bytes_range? Range2
---
---The range on which to evaluate foldexpr.
---When in insert mode, the evaluation is deferred to InsertLeave.
---@field foldupdate_range? Range2
---
---The treesitter parser associated with this buffer.
---@field parser? vim.treesitter.LanguageTree
local FoldInfo = {}
FoldInfo.__index = FoldInfo

---Create an empty fold cache for `bufnr`. `parser` is nil when no parser is available
---(`get_parser` is called with `error = false`).
---@private
---@param bufnr integer
function FoldInfo.new(bufnr)
  return setmetatable({
    levels0 = {},
    levels = {},
    parser = ts.get_parser(bufnr, nil, { error = false }),
  }, FoldInfo)
end

---Drop the cached levels of rows [srow, erow) so later rows shift up.
---@package
---@param srow integer 0-indexed, inclusive
---@param erow integer 0-indexed, exclusive
function FoldInfo:remove_range(srow, erow)
  vim._list_remove(self.levels, srow + 1, erow)
  vim._list_remove(self.levels0, srow + 1, erow)
end

---Insert placeholder levels (-1) for rows [srow, erow) so later rows shift down.
---@package
---@param srow integer 0-indexed, inclusive
---@param erow integer 0-indexed, exclusive
function FoldInfo:add_range(srow, erow)
  vim._list_insert(self.levels, srow + 1, erow, -1)
  vim._list_insert(self.levels0, srow + 1, erow, -1)
end

---Grow `range` in place so it covers an edit of [srow, erow_old) -> [srow, erow_new),
---shifting its end when the edit happened before it.
---@param range Range2
---@param srow integer
---@param erow_old integer
---@param erow_new integer 0-indexed, exclusive
local function edit_range(range, srow, erow_old, erow_new)
  range[1] = math.min(srow, range[1])
  if erow_old <= range[2] then
    range[2] = range[2] + (erow_new - erow_old)
  end
  range[2] = math.max(range[2], erow_new)
end

-- TODO(lewis6991): Setup a decor provider so injections folds can be parsed
-- as the window is redrawn
---Recompute `info.levels`/`info.levels0` for rows [srow, erow) from the `folds` query, then
---invoke `callback`. Returns without calling `callback` when there is no parser, the buffer is
---empty, or parsing produced no trees.
---@param bufnr integer
---@param info TS.FoldInfo
---@param srow integer? 0-indexed, defaults to 0
---@param erow integer? 0-indexed, exclusive, defaults to the buffer line count
---@param callback function?
local function compute_folds_levels(bufnr, info, srow, erow, callback)
  local line_count = api.nvim_buf_line_count(bufnr)
  srow = srow or 0
  erow = erow or line_count

  local parser = info.parser
  if
    not parser
    -- Parsing an empty buffer results in problems with the parsing state,
    -- resulting in both a broken highlighter and foldexpr.
    or line_count == 1
      and api.nvim_buf_call(bufnr, function()
        return vim.fn.line2byte(1) <= 0
      end)
  then
    return
  end

  parser:parse(nil, function(_, trees)
    if not trees then
      return
    end

    local enter_counts = {} ---@type table<integer, integer>
    local leave_counts = {} ---@type table<integer, integer>
    local prev_start = -1
    local prev_stop = -1

    parser:for_each_tree(function(tree, ltree)
      local query = ts.query.get(ltree:lang(), 'folds')
      if not query then
        return
      end

      -- Collect folds starting from srow - 1, because we should first subtract the folds that end at
      -- srow - 1 from the level of srow - 1 to get accurate level of srow.
      for _, match, metadata in
        query:iter_matches(tree:root(), bufnr, math.max(srow - 1, 0), erow)
      do
        for id, nodes in pairs(match) do
          if query.captures[id] == 'fold' then
            local range = ts.get_range(nodes[1], bufnr, metadata[id])
            local start, _, stop, stop_col = Range.unpack4(range)

            if #nodes > 1 then
              -- assumes nodes are ordered by range
              local end_range = ts.get_range(nodes[#nodes], bufnr, metadata[id])
              local _, _, end_stop, end_stop_col = Range.unpack4(end_range)
              stop = end_stop
              stop_col = end_stop_col
            end

            -- A node ending at column 0 does not actually occupy its last row.
            if stop_col == 0 then
              stop = stop - 1
            end

            local fold_length = stop - start + 1

            -- Fold only multiline nodes that are not exactly the same as previously met folds
            -- Checking against just the previously found fold is sufficient if nodes
            -- are returned in preorder or postorder when traversing tree
            if
              fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop)
            then
              enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1
              leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1
              prev_start = start
              prev_stop = stop
            end
          end
        end
      end
    end)

    local nestmax = vim.wo.foldnestmax
    -- `levels0[srow]` is the cached level of the line just before the recomputed range.
    local level0_prev = info.levels0[srow] or 0
    -- When that line both started and ended folds, the pass that produced `levels0[srow]` moved the
    -- ending fold's close up one line (subtracting it into the stored level) and reset its leave
    -- count to 0. Mirror that here; otherwise the same leave would be subtracted a second time.
    local leave_prev = 0
    if (enter_counts[srow] or 0) == 0 then
      leave_prev = leave_counts[srow] or 0
    end

    -- We now have the list of fold opening and closing, fill the gaps and mark where fold start
    for lnum = srow + 1, erow do
      local enter_line = enter_counts[lnum] or 0
      local leave_line = leave_counts[lnum] or 0
      local level0 = level0_prev - leave_prev + enter_line

      -- Determine if it's the start/end of a fold
      -- NB: vim's fold-expr interface does not have a mechanism to indicate that
      -- two (or more) folds start at this line, so it cannot distinguish between
      --  ( \n ( \n )) \n (( \n ) \n )
      -- versus
      --  ( \n ( \n ) \n ( \n ) \n )
      -- Both are represented by ['>1', '>2', '2', '>2', '2', '1'], and
      -- vim interprets as the second case.
      -- If it did have such a mechanism, (clamped - clamped_prev)
      -- would be the correct number of starts to pass on.
      local adjusted = level0 ---@type integer
      local prefix = ''
      if enter_line > 0 then
        prefix = '>'
        if leave_line > 0 then
          -- If this line ends a fold f1 and starts a fold f2, then move f1's end to the previous line
          -- so that f2 gets the correct level on this line. This may reduce the size of f1 below
          -- foldminlines, but we don't handle it for simplicity.
          --- @type integer avoid flaky error
          adjusted = level0 - leave_line
          leave_line = 0
        end
      end

      -- Clamp at foldnestmax.
      local clamped = adjusted
      if adjusted > nestmax then
        prefix = ''
        clamped = nestmax
      end

      -- Record the "real" level, so that it can be used as "base" of later compute_folds_levels().
      info.levels0[lnum] = adjusted
      info.levels[lnum] = prefix .. tostring(clamped)

      leave_prev = leave_line
      level0_prev = adjusted
    end

    if callback then
      callback()
    end
  end)
end

local M = {}

---@type table<integer,TS.FoldInfo>
local foldinfos = {}

local group = api.nvim_create_augroup('nvim.treesitter.fold', {})

--- Update the folds in the windows that contain the buffer and use expr foldmethod (assuming that
--- the user doesn't use different foldexpr for the same buffer).
---
--- Nvim usually updates folds automatically when text changes, but that doesn't work here because
--- the FoldInfo update is scheduled. So we do it manually.
---@package
---@param bufnr integer
---@param srow integer
---@param erow integer 0-indexed, exclusive
function FoldInfo:foldupdate(bufnr, srow, erow)
  -- Merge the requested rows into any range still waiting to be applied.
  if self.foldupdate_range then
    edit_range(self.foldupdate_range, srow, erow, erow)
  else
    self.foldupdate_range = { srow, erow }
  end

  -- foldUpdate() is guarded while `State` has the insert flag, which is true for Insert mode and
  -- for the Replace modes (`R`, `Rv`, `Rx`, ...). Updating now would consume `foldupdate_range`
  -- without effect, so defer to InsertLeave, which also fires when leaving Replace mode.
  local mode = api.nvim_get_mode().mode
  if mode:match('^[iR]') then
    -- A deferred update is already pending for this buffer; it will pick up the merged range.
    if #(api.nvim_get_autocmds({
      group = group,
      buffer = bufnr,
    })) > 0 then
      return
    end
    api.nvim_create_autocmd('InsertLeave', {
      group = group,
      buffer = bufnr,
      once = true,
      callback = function()
        self:do_foldupdate(bufnr)
      end,
    })
    return
  end

  self:do_foldupdate(bufnr)
end

---Apply and clear the pending `foldupdate_range` in every window showing `bufnr` whose
---foldmethod is "expr".
---@package
---@param bufnr integer
function FoldInfo:do_foldupdate(bufnr)
  -- InsertLeave is not executed when <C-C> is used for exiting the insert mode, leaving
  -- do_foldupdate untouched. If another execution of foldupdate consumes foldupdate_range, the
  -- InsertLeave do_foldupdate gets nil foldupdate_range. In that case, skip the update. This is
  -- correct because the update that consumed the range must have incorporated the range that
  -- InsertLeave meant to update.
  if not self.foldupdate_range then
    return
  end

  local srow, erow = self.foldupdate_range[1], self.foldupdate_range[2]
  self.foldupdate_range = nil
  for _, win in ipairs(vim.fn.win_findbuf(bufnr)) do
    if vim.wo[win].foldmethod == 'expr' then
      vim._foldupdate(win, srow, erow)
    end
  end
end

--- Schedule a function only if bufnr is loaded.
--- We schedule fold level computation for the following reasons:
--- * queries seem to use the old buffer state in on_bytes for some unknown reason;
--- * to avoid textlock;
--- * to avoid infinite recursion:
---   compute_folds_levels → parse → _do_callback → on_changedtree → compute_folds_levels.
---@param bufnr integer
---@param fn function
local function schedule_if_loaded(bufnr, fn)
  vim.schedule(function()
    if not api.nvim_buf_is_loaded(bufnr) then
      return
    end
    fn()
  end)
end

---Recompute fold levels for every changed tree range, then request a single fold update that
---covers the union of those ranges once all recomputations have called back.
---@param bufnr integer
---@param tree_changes Range4[]
local function on_changedtree(bufnr, tree_changes)
  schedule_if_loaded(bufnr, function()
    -- Buffer reload clears `foldinfos[bufnr]`, which may still be nil when callback is invoked.
    local foldinfo = foldinfos[bufnr]
    if not foldinfo then
      return
    end

    local srow_upd, erow_upd ---@type integer?, integer?
    local max_erow = api.nvim_buf_line_count(bufnr)
    -- TODO(ribru17): Replace this with a proper .all() awaiter once #19624 is resolved
    local iterations = 0
    for _, change in ipairs(tree_changes) do
      local srow, _, erow, ecol = Range.unpack4(change)
      -- If a parser doesn't have any ranges explicitly set, treesitter will
      -- return a range with end_row and end_bytes with a value of UINT32_MAX,
      -- so clip end_row to the max buffer line.
      -- TODO(lewis6991): Handle this generally
      if erow > max_erow then
        erow = max_erow
      elseif ecol > 0 then
        -- The change extends into row `erow`, so include it (erow is exclusive).
        erow = erow + 1
      end
      -- Start from `srow - foldminlines`, because this edit may have shrunk the fold below the limit.
      srow = math.max(srow - vim.wo.foldminlines, 0)
      srow_upd = srow_upd and math.min(srow_upd, srow) or srow
      erow_upd = erow_upd and math.max(erow_upd, erow) or erow
      compute_folds_levels(bufnr, foldinfo, srow, erow, function()
        iterations = iterations + 1
        if iterations == #tree_changes then
          foldinfo:foldupdate(bufnr, srow_upd, erow_upd)
        end
      end)
    end
  end)
end

---Shift the cached levels to follow a line-count-changing edit and schedule a recomputation of
---the accumulated edited range. Edits that keep the line count are left to on_changedtree.
---@param bufnr integer
---@param start_row integer
---@param start_col integer
---@param old_row integer
---@param old_col integer
---@param new_row integer
---@param new_col integer
local function on_bytes(bufnr, start_row, start_col, old_row, old_col, new_row, new_col)
  -- Buffer reload clears `foldinfos[bufnr]`, which may still be nil when callback is invoked.
  local foldinfo = foldinfos[bufnr]
  if not foldinfo then
    return
  end

  -- extend the end to fully include the range
  local end_row_old = start_row + old_row + 1
  local end_row_new = start_row + new_row + 1

  if new_row ~= old_row then
    -- foldexpr can be evaluated before the scheduled callback is invoked. So it may observe the
    -- outdated levels, which may spuriously open the folds that didn't change. So we should shift
    -- folds as accurately as possible. For this to be perfectly accurate, we should track the
    -- actual TSNodes that account for each fold, and compare the node's range with the edited
    -- range. But for simplicity, we just check whether the start row is completely removed (e.g.,
    -- `dd`) or shifted (e.g., `o`).
    if new_row < old_row then
      if start_col == 0 and new_row == 0 and new_col == 0 then
        foldinfo:remove_range(start_row, start_row + (end_row_old - end_row_new))
      else
        foldinfo:remove_range(end_row_new, end_row_old)
      end
    else
      if start_col == 0 and old_row == 0 and old_col == 0 then
        foldinfo:add_range(start_row, start_row + (end_row_new - end_row_old))
      else
        foldinfo:add_range(end_row_old, end_row_new)
      end
    end

    if foldinfo.on_bytes_range then
      edit_range(foldinfo.on_bytes_range, start_row, end_row_old, end_row_new)
    else
      foldinfo.on_bytes_range = { start_row, end_row_new }
    end
    if foldinfo.foldupdate_range then
      edit_range(foldinfo.foldupdate_range, start_row, end_row_old, end_row_new)
    end

    -- This callback must not use on_bytes arguments, because they can be outdated when the callback
    -- is invoked. For example, `J` with non-zero count triggers multiple on_bytes before executing
    -- the scheduled callback. So we accumulate the edited ranges in `on_bytes_range`.
    schedule_if_loaded(bufnr, function()
      if not (foldinfo.on_bytes_range and foldinfos[bufnr]) then
        return
      end
      local srow, erow = foldinfo.on_bytes_range[1], foldinfo.on_bytes_range[2]
      foldinfo.on_bytes_range = nil
      -- Start from `srow - foldminlines`, because this edit may have shrunk the fold below the limit.
      srow = math.max(srow - vim.wo.foldminlines, 0)
      compute_folds_levels(bufnr, foldinfo, srow, erow, function()
        foldinfo:foldupdate(bufnr, srow, erow)
      end)
    end)
  end
end

local registered_cbs = {} ---@type table<integer, boolean>

---foldexpr for the current buffer. On first use for a buffer it creates the fold cache, computes
---all levels and registers parser callbacks; afterwards it only reads the cache.
---@param lnum integer|nil 1-based line number, defaults to |v:lnum|
---@return string
function M.foldexpr(lnum)
  lnum = lnum or vim.v.lnum
  local bufnr = api.nvim_get_current_buf()

  if not foldinfos[bufnr] then
    foldinfos[bufnr] = FoldInfo.new(bufnr)
    api.nvim_create_autocmd({ 'BufUnload', 'VimEnter', 'FileType' }, {
      buffer = bufnr,
      once = true,
      callback = function()
        foldinfos[bufnr] = nil
      end,
    })

    local parser = foldinfos[bufnr].parser
    if not parser then
      return '0'
    end

    compute_folds_levels(bufnr, foldinfos[bufnr])

    if not registered_cbs[bufnr] then
      parser:register_cbs({
        on_changedtree = function(tree_changes)
          on_changedtree(bufnr, tree_changes)
        end,

        on_bytes = function(
          _,
          _,
          start_row,
          start_col,
          _,
          old_row,
          old_col,
          _,
          new_row,
          new_col,
          _
        )
          on_bytes(bufnr, start_row, start_col, old_row, old_col, new_row, new_col)
        end,

        on_detach = function()
          foldinfos[bufnr] = nil
          registered_cbs[bufnr] = nil
        end,
      })

      registered_cbs[bufnr] = true
    end
  end

  return foldinfos[bufnr].levels[lnum] or '0'
end

api.nvim_create_autocmd('OptionSet', {
  pattern = { 'foldminlines', 'foldnestmax' },
  desc = 'Refresh treesitter folds',
  callback = function()
    local buf = api.nvim_get_current_buf()
    local bufs = vim.v.option_type == 'global' and vim.tbl_keys(foldinfos)
      or foldinfos[buf] and { buf }
      or {}
    for _, bufnr in ipairs(bufs) do
      -- Hold the new cache in a local. The parse callback may run after `foldinfos[bufnr]` has been
      -- cleared (BufUnload/FileType/on_detach), and re-reading it there would index nil.
      local info = FoldInfo.new(bufnr)
      foldinfos[bufnr] = info
      api.nvim_buf_call(bufnr, function()
        compute_folds_levels(bufnr, info, nil, nil, function()
          if not api.nvim_buf_is_loaded(bufnr) then
            return
          end
          info:foldupdate(bufnr, 0, api.nvim_buf_line_count(bufnr))
        end)
      end)
    end
  end,
})
return M