neovim

Neovim text editor
git clone https://git.dasho.dev/neovim.git
Log | Files | Refs | README

_fold.lua (15526B)


      1 local ts = vim.treesitter
      2 
      3 local Range = require('vim.treesitter._range')
      4 
      5 local api = vim.api
      6 
      7 ---Treesitter folding is done in two steps:
      8 ---(1) compute the fold levels with the syntax tree and cache the result (`compute_folds_levels`)
      9 ---(2) evaluate foldexpr for each window, which reads from the cache (`foldupdate`)
     10 ---@class TS.FoldInfo
     11 ---
     12 ---@field levels string[] the cached foldexpr result for each line
     13 ---@field levels0 integer[] the cached raw fold levels
     14 ---
     15 ---The range edited since the last invocation of the callback scheduled in on_bytes.
     16 ---Should compute fold levels in this range.
     17 ---@field on_bytes_range? Range2
     18 ---
     19 ---The range on which to evaluate foldexpr.
     20 ---When in insert mode, the evaluation is deferred to InsertLeave.
     21 ---@field foldupdate_range? Range2
     22 ---
     23 ---The treesitter parser associated with this buffer.
     24 ---@field parser? vim.treesitter.LanguageTree
     25 local FoldInfo = {}
     26 FoldInfo.__index = FoldInfo
     27 
     28 ---@private
     29 ---@param bufnr integer
     30 function FoldInfo.new(bufnr)
     31  return setmetatable({
     32    levels0 = {},
     33    levels = {},
     34    parser = ts.get_parser(bufnr, nil, { error = false }),
     35  }, FoldInfo)
     36 end
     37 
     38 ---@package
     39 ---@param srow integer
     40 ---@param erow integer 0-indexed, exclusive
     41 function FoldInfo:remove_range(srow, erow)
     42  vim._list_remove(self.levels, srow + 1, erow)
     43  vim._list_remove(self.levels0, srow + 1, erow)
     44 end
     45 
     46 ---@package
     47 ---@param srow integer
     48 ---@param erow integer 0-indexed, exclusive
     49 function FoldInfo:add_range(srow, erow)
     50  vim._list_insert(self.levels, srow + 1, erow, -1)
     51  vim._list_insert(self.levels0, srow + 1, erow, -1)
     52 end
     53 
     54 ---@param range Range2
     55 ---@param srow integer
     56 ---@param erow_old integer
     57 ---@param erow_new integer 0-indexed, exclusive
     58 local function edit_range(range, srow, erow_old, erow_new)
     59  range[1] = math.min(srow, range[1])
     60  if erow_old <= range[2] then
     61    range[2] = range[2] + (erow_new - erow_old)
     62  end
     63  range[2] = math.max(range[2], erow_new)
     64 end
     65 
     66 -- TODO(lewis6991): Setup a decor provider so injections folds can be parsed
     67 -- as the window is redrawn
     68 ---@param bufnr integer
     69 ---@param info TS.FoldInfo
     70 ---@param srow integer?
     71 ---@param erow integer? 0-indexed, exclusive
     72 ---@param callback function?
     73 local function compute_folds_levels(bufnr, info, srow, erow, callback)
     74  srow = srow or 0
     75  erow = erow or api.nvim_buf_line_count(bufnr)
     76 
     77  local parser = info.parser
     78  if
     79    not parser
     80    -- Parsing an empty buffer results in problems with the parsing state,
     81    -- resulting in both a broken highlighter and foldexpr.
     82    or api.nvim_buf_line_count(bufnr) == 1
     83      and api.nvim_buf_call(bufnr, function()
     84        return vim.fn.line2byte(1) <= 0
     85      end)
     86  then
     87    return
     88  end
     89 
     90  parser:parse(nil, function(_, trees)
     91    if not trees then
     92      return
     93    end
     94 
     95    local enter_counts = {} ---@type table<integer, integer>
     96    local leave_counts = {} ---@type table<integer, integer>
     97    local prev_start = -1
     98    local prev_stop = -1
     99 
    100    parser:for_each_tree(function(tree, ltree)
    101      local query = ts.query.get(ltree:lang(), 'folds')
    102      if not query then
    103        return
    104      end
    105 
    106      -- Collect folds starting from srow - 1, because we should first subtract the folds that end at
    107      -- srow - 1 from the level of srow - 1 to get accurate level of srow.
    108      for _, match, metadata in query:iter_matches(tree:root(), bufnr, math.max(srow - 1, 0), erow) do
    109        for id, nodes in pairs(match) do
    110          if query.captures[id] == 'fold' then
    111            local range = ts.get_range(nodes[1], bufnr, metadata[id])
    112            local start, _, stop, stop_col = Range.unpack4(range)
    113 
    114            if #nodes > 1 then
    115              -- assumes nodes are ordered by range
    116              local end_range = ts.get_range(nodes[#nodes], bufnr, metadata[id])
    117              local _, _, end_stop, end_stop_col = Range.unpack4(end_range)
    118              stop = end_stop
    119              stop_col = end_stop_col
    120            end
    121 
    122            if stop_col == 0 then
    123              stop = stop - 1
    124            end
    125 
    126            local fold_length = stop - start + 1
    127 
    128            -- Fold only multiline nodes that are not exactly the same as previously met folds
    129            -- Checking against just the previously found fold is sufficient if nodes
    130            -- are returned in preorder or postorder when traversing tree
    131            if
    132              fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop)
    133            then
    134              enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1
    135              leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1
    136              prev_start = start
    137              prev_stop = stop
    138            end
    139          end
    140        end
    141      end
    142    end)
    143 
    144    local nestmax = vim.wo.foldnestmax
    145    local level0_prev = info.levels0[srow] or 0
    146    local leave_prev = leave_counts[srow] or 0
    147 
    148    -- We now have the list of fold opening and closing, fill the gaps and mark where fold start
    149    for lnum = srow + 1, erow do
    150      local enter_line = enter_counts[lnum] or 0
    151      local leave_line = leave_counts[lnum] or 0
    152      local level0 = level0_prev - leave_prev + enter_line
    153 
    154      -- Determine if it's the start/end of a fold
    155      -- NB: vim's fold-expr interface does not have a mechanism to indicate that
    156      -- two (or more) folds start at this line, so it cannot distinguish between
    157      --  ( \n ( \n )) \n (( \n ) \n )
    158      -- versus
    159      --  ( \n ( \n ) \n ( \n ) \n )
    160      -- Both are represented by ['>1', '>2', '2', '>2', '2', '1'], and
    161      -- vim interprets as the second case.
    162      -- If it did have such a mechanism, (clamped - clamped_prev)
    163      -- would be the correct number of starts to pass on.
    164      local adjusted = level0 ---@type integer
    165      local prefix = ''
    166      if enter_line > 0 then
    167        prefix = '>'
    168        if leave_line > 0 then
    169          -- If this line ends a fold f1 and starts a fold f2, then move f1's end to the previous line
    170          -- so that f2 gets the correct level on this line. This may reduce the size of f1 below
    171          -- foldminlines, but we don't handle it for simplicity.
    172          --- @type integer avoid flaky error
    173          adjusted = level0 - leave_line
    174          leave_line = 0
    175        end
    176      end
    177 
    178      -- Clamp at foldnestmax.
    179      local clamped = adjusted
    180      if adjusted > nestmax then
    181        prefix = ''
    182        clamped = nestmax
    183      end
    184 
    185      -- Record the "real" level, so that it can be used as "base" of later compute_folds_levels().
    186      info.levels0[lnum] = adjusted
    187      info.levels[lnum] = prefix .. tostring(clamped)
    188 
    189      leave_prev = leave_line
    190      level0_prev = adjusted
    191    end
    192 
    193    if callback then
    194      callback()
    195    end
    196  end)
    197 end
    198 
    199 local M = {}
    200 
    201 ---@type table<integer,TS.FoldInfo>
    202 local foldinfos = {}
    203 
    204 local group = api.nvim_create_augroup('nvim.treesitter.fold', {})
    205 
    206 --- Update the folds in the windows that contain the buffer and use expr foldmethod (assuming that
    207 --- the user doesn't use different foldexpr for the same buffer).
    208 ---
    209 --- Nvim usually automatically updates folds when text changes, but it doesn't work here because
    210 --- FoldInfo update is scheduled. So we do it manually.
    211 ---@package
    212 ---@param srow integer
    213 ---@param erow integer 0-indexed, exclusive
    214 function FoldInfo:foldupdate(bufnr, srow, erow)
    215  if self.foldupdate_range then
    216    edit_range(self.foldupdate_range, srow, erow, erow)
    217  else
    218    self.foldupdate_range = { srow, erow }
    219  end
    220 
    221  if api.nvim_get_mode().mode:match('^i') then
    222    -- foldUpdate() is guarded in insert mode. So update folds on InsertLeave
    223    if #(api.nvim_get_autocmds({
    224      group = group,
    225      buffer = bufnr,
    226    })) > 0 then
    227      return
    228    end
    229    api.nvim_create_autocmd('InsertLeave', {
    230      group = group,
    231      buffer = bufnr,
    232      once = true,
    233      callback = function()
    234        self:do_foldupdate(bufnr)
    235      end,
    236    })
    237    return
    238  end
    239 
    240  self:do_foldupdate(bufnr)
    241 end
    242 
    243 ---@package
    244 function FoldInfo:do_foldupdate(bufnr)
    245  -- InsertLeave is not executed when <C-C> is used for exiting the insert mode, leaving
    246  -- do_foldupdate untouched. If another execution of foldupdate consumes foldupdate_range, the
    247  -- InsertLeave do_foldupdate gets nil foldupdate_range. In that case, skip the update. This is
    248  -- correct because the update that consumed the range must have incorporated the range that
    249  -- InsertLeave meant to update.
    250  if not self.foldupdate_range then
    251    return
    252  end
    253 
    254  local srow, erow = self.foldupdate_range[1], self.foldupdate_range[2]
    255  self.foldupdate_range = nil
    256  for _, win in ipairs(vim.fn.win_findbuf(bufnr)) do
    257    if vim.wo[win].foldmethod == 'expr' then
    258      vim._foldupdate(win, srow, erow)
    259    end
    260  end
    261 end
    262 
    263 --- Schedule a function only if bufnr is loaded.
    264 --- We schedule fold level computation for the following reasons:
    265 --- * queries seem to use the old buffer state in on_bytes for some unknown reason;
    266 --- * to avoid textlock;
    267 --- * to avoid infinite recursion:
    268 ---   compute_folds_levels → parse → _do_callback → on_changedtree → compute_folds_levels.
    269 ---@param bufnr integer
    270 ---@param fn function
    271 local function schedule_if_loaded(bufnr, fn)
    272  vim.schedule(function()
    273    if not api.nvim_buf_is_loaded(bufnr) then
    274      return
    275    end
    276    fn()
    277  end)
    278 end
    279 
    280 ---@param bufnr integer
    281 ---@param tree_changes Range4[]
    282 local function on_changedtree(bufnr, tree_changes)
    283  schedule_if_loaded(bufnr, function()
    284    -- Buffer reload clears `foldinfos[bufnr]`, which may still be nil when callback is invoked.
    285    local foldinfo = foldinfos[bufnr]
    286    if not foldinfo then
    287      return
    288    end
    289 
    290    local srow_upd, erow_upd ---@type integer?, integer?
    291    local max_erow = api.nvim_buf_line_count(bufnr)
    292    -- TODO(ribru17): Replace this with a proper .all() awaiter once #19624 is resolved
    293    local iterations = 0
    294    for _, change in ipairs(tree_changes) do
    295      local srow, _, erow, ecol = Range.unpack4(change)
    296      -- If a parser doesn't have any ranges explicitly set, treesitter will
    297      -- return a range with end_row and end_bytes with a value of UINT32_MAX,
    298      -- so clip end_row to the max buffer line.
    299      -- TODO(lewis6991): Handle this generally
    300      if erow > max_erow then
    301        erow = max_erow
    302      elseif ecol > 0 then
    303        erow = erow + 1
    304      end
    305      -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit.
    306      srow = math.max(srow - vim.wo.foldminlines, 0)
    307      srow_upd = srow_upd and math.min(srow_upd, srow) or srow
    308      erow_upd = erow_upd and math.max(erow_upd, erow) or erow
    309      compute_folds_levels(bufnr, foldinfo, srow, erow, function()
    310        iterations = iterations + 1
    311        if iterations == #tree_changes then
    312          foldinfo:foldupdate(bufnr, srow_upd, erow_upd)
    313        end
    314      end)
    315    end
    316  end)
    317 end
    318 
---Handle a buffer edit reported through the parser's on_bytes callback.
---
---Only edits that change the number of rows are acted upon. For those:
---  * the cached levels are shifted right away (rows removed, or -1 placeholders inserted);
---  * the edited rows are merged into `on_bytes_range` and into a pending `foldupdate_range`;
---  * a recomputation of the accumulated `on_bytes_range` is scheduled.
---@param bufnr integer
---@param start_row integer 0-indexed row where the edit starts
---@param start_col integer column where the edit starts
---@param old_row integer row extent of the replaced text, relative to `start_row`
---@param old_col integer end column of the replaced text (0 together with `old_row == 0` means nothing was replaced)
---@param new_row integer row extent of the inserted text, relative to `start_row`
---@param new_col integer end column of the inserted text (0 together with `new_row == 0` means nothing was inserted)
local function on_bytes(bufnr, start_row, start_col, old_row, old_col, new_row, new_col)
  -- Buffer reload clears `foldinfos[bufnr]`, which may still be nil when callback is invoked.
  local foldinfo = foldinfos[bufnr]
  if not foldinfo then
    return
  end

  -- extend the end to fully include the range
  -- (exclusive end rows of the old and new text)
  local end_row_old = start_row + old_row + 1
  local end_row_new = start_row + new_row + 1

  if new_row ~= old_row then
    -- foldexpr can be evaluated before the scheduled callback is invoked. So it may observe the
    -- outdated levels, which may spuriously open the folds that didn't change. So we should shift
    -- folds as accurately as possible. For this to be perfectly accurate, we should track the
    -- actual TSNodes that account for each fold, and compare the node's range with the edited
    -- range. But for simplicity, we just check whether the start row is completely removed (e.g.,
    -- `dd`) or shifted (e.g., `o`).
    if new_row < old_row then
      -- Rows were removed.
      if start_col == 0 and new_row == 0 and new_col == 0 then
        -- Whole rows deleted from column 0 with nothing inserted: drop rows beginning at start_row.
        foldinfo:remove_range(start_row, start_row + (end_row_old - end_row_new))
      else
        -- The start row survives; drop the rows that followed the new end.
        foldinfo:remove_range(end_row_new, end_row_old)
      end
    else
      -- Rows were added.
      if start_col == 0 and old_row == 0 and old_col == 0 then
        -- Pure insertion at column 0: the new rows go in before start_row's old content.
        foldinfo:add_range(start_row, start_row + (end_row_new - end_row_old))
      else
        -- The start row keeps its place; the new rows follow the old end.
        foldinfo:add_range(end_row_old, end_row_new)
      end
    end

    if foldinfo.on_bytes_range then
      edit_range(foldinfo.on_bytes_range, start_row, end_row_old, end_row_new)
    else
      foldinfo.on_bytes_range = { start_row, end_row_new }
    end
    -- Keep a pending (e.g. deferred to InsertLeave) foldupdate range aligned with the edit.
    if foldinfo.foldupdate_range then
      edit_range(foldinfo.foldupdate_range, start_row, end_row_old, end_row_new)
    end

    -- This callback must not use on_bytes arguments, because they can be outdated when the callback
    -- is invoked. For example, `J` with non-zero count triggers multiple on_bytes before executing
    -- the scheduled callback. So we accumulate the edited ranges in `on_bytes_range`.
    schedule_if_loaded(bufnr, function()
      -- Skip when an earlier scheduled callback already consumed the range, or the buffer's
      -- FoldInfo was cleared meanwhile.
      if not (foldinfo.on_bytes_range and foldinfos[bufnr]) then
        return
      end
      local srow, erow = foldinfo.on_bytes_range[1], foldinfo.on_bytes_range[2]
      foldinfo.on_bytes_range = nil
      -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit.
      srow = math.max(srow - vim.wo.foldminlines, 0)
      compute_folds_levels(bufnr, foldinfo, srow, erow, function()
        foldinfo:foldupdate(bufnr, srow, erow)
      end)
    end)
  end
end
    383 
    384 local registered_cbs = {} ---@type table<integer, boolean>
    385 
    386 ---@param lnum integer|nil
    387 ---@return string
    388 function M.foldexpr(lnum)
    389  lnum = lnum or vim.v.lnum
    390  local bufnr = api.nvim_get_current_buf()
    391 
    392  if not foldinfos[bufnr] then
    393    foldinfos[bufnr] = FoldInfo.new(bufnr)
    394    api.nvim_create_autocmd({ 'BufUnload', 'VimEnter', 'FileType' }, {
    395      buffer = bufnr,
    396      once = true,
    397      callback = function()
    398        foldinfos[bufnr] = nil
    399      end,
    400    })
    401 
    402    local parser = foldinfos[bufnr].parser
    403    if not parser then
    404      return '0'
    405    end
    406 
    407    compute_folds_levels(bufnr, foldinfos[bufnr])
    408 
    409    if not registered_cbs[bufnr] then
    410      parser:register_cbs({
    411        on_changedtree = function(tree_changes)
    412          on_changedtree(bufnr, tree_changes)
    413        end,
    414 
    415        on_bytes = function(
    416          _,
    417          _,
    418          start_row,
    419          start_col,
    420          _,
    421          old_row,
    422          old_col,
    423          _,
    424          new_row,
    425          new_col,
    426          _
    427        )
    428          on_bytes(bufnr, start_row, start_col, old_row, old_col, new_row, new_col)
    429        end,
    430 
    431        on_detach = function()
    432          foldinfos[bufnr] = nil
    433          registered_cbs[bufnr] = nil
    434        end,
    435      })
    436 
    437      registered_cbs[bufnr] = true
    438    end
    439  end
    440 
    441  return foldinfos[bufnr].levels[lnum] or '0'
    442 end
    443 
    444 api.nvim_create_autocmd('OptionSet', {
    445  pattern = { 'foldminlines', 'foldnestmax' },
    446  desc = 'Refresh treesitter folds',
    447  callback = function()
    448    local buf = api.nvim_get_current_buf()
    449    local bufs = vim.v.option_type == 'global' and vim.tbl_keys(foldinfos)
    450      or foldinfos[buf] and { buf }
    451      or {}
    452    for _, bufnr in ipairs(bufs) do
    453      foldinfos[bufnr] = FoldInfo.new(bufnr)
    454      api.nvim_buf_call(bufnr, function()
    455        compute_folds_levels(bufnr, foldinfos[bufnr], nil, nil, function()
    456          foldinfos[bufnr]:foldupdate(bufnr, 0, api.nvim_buf_line_count(bufnr))
    457        end)
    458      end)
    459    end
    460  end,
    461 })
    462 return M