neovim

Neovim text editor
git clone https://git.dasho.dev/neovim.git
Log | Files | Refs | README

_range.lua (4260B)


      1 local api = vim.api
      2 
      3 local M = {}
      4 
      5 ---@class Range2
      6 ---@inlinedoc
      7 ---@field [1] integer start row
      8 ---@field [2] integer end row
      9 
     10 ---@class Range4
     11 ---@inlinedoc
     12 ---@field [1] integer start row
     13 ---@field [2] integer start column
     14 ---@field [3] integer end row
     15 ---@field [4] integer end column
     16 
---@class Range6
---@inlinedoc
---@field [1] integer start row
---@field [2] integer start column
---@field [3] integer start byte offset
---@field [4] integer end row
---@field [5] integer end column
---@field [6] integer end byte offset
     25 
     26 ---@alias Range Range2|Range4|Range6
     27 
     28 ---@param a_row integer
     29 ---@param a_col integer
     30 ---@param b_row integer
     31 ---@param b_col integer
     32 ---@return integer
     33 --- 1: a > b
     34 --- 0: a == b
     35 --- -1: a < b
     36 local function cmp_pos(a_row, a_col, b_row, b_col)
     37  if a_row == b_row then
     38    if a_col > b_col then
     39      return 1
     40    elseif a_col < b_col then
     41      return -1
     42    else
     43      return 0
     44    end
     45  elseif a_row > b_row then
     46    return 1
     47  end
     48 
     49  return -1
     50 end
     51 
     52 M.cmp_pos = {
     53  lt = function(...)
     54    return cmp_pos(...) == -1
     55  end,
     56  le = function(...)
     57    return cmp_pos(...) ~= 1
     58  end,
     59  gt = function(...)
     60    return cmp_pos(...) == 1
     61  end,
     62  ge = function(...)
     63    return cmp_pos(...) ~= -1
     64  end,
     65  eq = function(...)
     66    return cmp_pos(...) == 0
     67  end,
     68  ne = function(...)
     69    return cmp_pos(...) ~= 0
     70  end,
     71 }
     72 
     73 setmetatable(M.cmp_pos, { __call = cmp_pos })
     74 
     75 ---Check if a variable is a valid range object
     76 ---@param r any
     77 ---@return boolean
     78 function M.validate(r)
     79  if type(r) ~= 'table' or #r ~= 6 and #r ~= 4 then
     80    return false
     81  end
     82 
     83  for _, e in
     84    ipairs(r --[[@as any[] ]])
     85  do
     86    if type(e) ~= 'number' then
     87      return false
     88    end
     89  end
     90 
     91  return true
     92 end
     93 
     94 ---@param r1 Range
     95 ---@param r2 Range
     96 ---@return boolean
     97 function M.intercepts(r1, r2)
     98  local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
     99  local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2)
    100 
    101  -- r1 is above r2
    102  if M.cmp_pos.le(erow_1, ecol_1, srow_2, scol_2) then
    103    return false
    104  end
    105 
    106  -- r1 is below r2
    107  if M.cmp_pos.ge(srow_1, scol_1, erow_2, ecol_2) then
    108    return false
    109  end
    110 
    111  return true
    112 end
    113 
    114 ---@param r1 Range6
    115 ---@param r2 Range6
    116 ---@return Range6?
    117 function M.intersection(r1, r2)
    118  if not M.intercepts(r1, r2) then
    119    return nil
    120  end
    121  local rs = M.cmp_pos.le(r1[1], r1[2], r2[1], r2[2]) and r2 or r1
    122  local re = M.cmp_pos.ge(r1[4], r1[5], r2[4], r2[5]) and r2 or r1
    123  return { rs[1], rs[2], rs[3], re[4], re[5], re[6] }
    124 end
    125 
    126 ---@param r Range
    127 ---@return integer, integer, integer, integer
    128 function M.unpack4(r)
    129  if #r == 2 then
    130    return r[1], 0, r[2], 0
    131  end
    132  local off_1 = #r == 6 and 1 or 0
    133  return r[1], r[2], r[3 + off_1], r[4 + off_1]
    134 end
    135 
    136 ---@param r Range6
    137 ---@return integer, integer, integer, integer, integer, integer
    138 function M.unpack6(r)
    139  return r[1], r[2], r[3], r[4], r[5], r[6]
    140 end
    141 
    142 ---@param r1 Range
    143 ---@param r2 Range
    144 ---@return boolean whether r1 contains r2
    145 function M.contains(r1, r2)
    146  local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
    147  local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2)
    148 
    149  -- start doesn't fit
    150  if M.cmp_pos.gt(srow_1, scol_1, srow_2, scol_2) then
    151    return false
    152  end
    153 
    154  -- end doesn't fit
    155  if M.cmp_pos.lt(erow_1, ecol_1, erow_2, ecol_2) then
    156    return false
    157  end
    158 
    159  return true
    160 end
    161 
    162 --- @param source integer|string
    163 --- @param index integer
    164 --- @return integer
    165 local function get_offset(source, index)
    166  if index == 0 then
    167    return 0
    168  end
    169 
    170  if type(source) == 'number' then
    171    return api.nvim_buf_get_offset(source, index)
    172  end
    173 
    174  local byte = 0
    175  local next_offset = source:gmatch('()\n')
    176  local line = 1
    177  while line <= index do
    178    byte = next_offset() --[[@as integer]]
    179    line = line + 1
    180  end
    181 
    182  return byte
    183 end
    184 
    185 ---@param source integer|string
    186 ---@param range Range
    187 ---@return Range6
    188 function M.add_bytes(source, range)
    189  if type(range) == 'table' and #range == 6 then
    190    return range --[[@as Range6]]
    191  end
    192 
    193  local start_row, start_col, end_row, end_col = M.unpack4(range)
    194  -- TODO(vigoux): proper byte computation here, and account for EOL ?
    195  local start_byte = get_offset(source, start_row) + start_col
    196  local end_byte = get_offset(source, end_row) + end_col
    197 
    198  return { start_row, start_col, start_byte, end_row, end_col, end_byte }
    199 end
    200 
    201 return M