_range.lua (4260B)
local api = vim.api

local M = {}

---@class Range2
---@inlinedoc
---@field [1] integer start row
---@field [2] integer end row

---@class Range4
---@inlinedoc
---@field [1] integer start row
---@field [2] integer start column
---@field [3] integer end row
---@field [4] integer end column

---@class Range6
---@inlinedoc
---@field [1] integer start row
---@field [2] integer start column
---@field [3] integer start bytes
---@field [4] integer end row
---@field [5] integer end column
---@field [6] integer end bytes

---@alias Range Range2|Range4|Range6

---Compare two (row, column) positions: rows are compared first and the
---column only breaks ties between equal rows.
---@param a_row integer
---@param a_col integer
---@param b_row integer
---@param b_col integer
---@return integer
--- 1: a > b
--- 0: a == b
--- -1: a < b
local function cmp_pos(a_row, a_col, b_row, b_col)
  if a_row == b_row then
    if a_col > b_col then
      return 1
    elseif a_col < b_col then
      return -1
    else
      return 0
    end
  elseif a_row > b_row then
    return 1
  end

  return -1
end

---Boolean position comparators. Each takes `(a_row, a_col, b_row, b_col)`
---and is derived from the three-way result of `cmp_pos`. Calling the table
---directly, `M.cmp_pos(a_row, a_col, b_row, b_col)`, returns that raw
----1/0/1 result.
M.cmp_pos = {
  lt = function(...)
    return cmp_pos(...) == -1
  end,
  le = function(...)
    return cmp_pos(...) ~= 1
  end,
  gt = function(...)
    return cmp_pos(...) == 1
  end,
  ge = function(...)
    return cmp_pos(...) ~= -1
  end,
  eq = function(...)
    return cmp_pos(...) == 0
  end,
  ne = function(...)
    return cmp_pos(...) ~= 0
  end,
}

setmetatable(M.cmp_pos, {
  -- Lua invokes `__call` with the called table as its first argument.
  -- Discard it so only the four position values reach `cmp_pos`; otherwise
  -- `a_row` would be bound to this table and the comparison would raise.
  __call = function(_, ...)
    return cmp_pos(...)
  end,
})

---Check if a variable is a valid range object.
---Only the 4-element (Range4) and 6-element (Range6) forms are accepted;
---a 2-element Range2 is rejected. Every element must be a number.
---@param r any
---@return boolean
function M.validate(r)
  if type(r) ~= 'table' or (#r ~= 6 and #r ~= 4) then
    return false
  end

  for _, e in
    ipairs(r --[[@as any[] ]])
  do
    if type(e) ~= 'number' then
      return false
    end
  end

  return true
end

---Check whether two ranges overlap.
---Ends are treated as exclusive: ranges that merely touch (one ends exactly
---where the other starts) do not intercept.
---@param r1 Range
---@param r2 Range
---@return boolean
function M.intercepts(r1, r2)
  local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
  local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2)

  -- r1 is above r2
  if M.cmp_pos.le(erow_1, ecol_1, srow_2, scol_2) then
    return false
  end

  -- r1 is below r2
  if M.cmp_pos.ge(srow_1, scol_1, erow_2, ecol_2) then
    return false
  end

  return true
end

---Compute the overlapping part of two 6-element ranges.
---The result starts at the later of the two starts and ends at the earlier
---of the two ends, carrying over the matching byte offsets.
---@param r1 Range6
---@param r2 Range6
---@return Range6? nil when the ranges do not intercept
function M.intersection(r1, r2)
  if not M.intercepts(r1, r2) then
    return nil
  end
  -- Later start wins (r2 when r1 starts at or before it).
  local rs = M.cmp_pos.le(r1[1], r1[2], r2[1], r2[2]) and r2 or r1
  -- Earlier end wins (r2 when r1 ends at or after it).
  local re = M.cmp_pos.ge(r1[4], r1[5], r2[4], r2[5]) and r2 or r1
  return { rs[1], rs[2], rs[3], re[4], re[5], re[6] }
end

---Normalize any range form to (start row, start col, end row, end col).
---A Range2 yields columns of 0; a Range6 has its byte fields dropped.
---@param r Range
---@return integer, integer, integer, integer
function M.unpack4(r)
  if #r == 2 then
    return r[1], 0, r[2], 0
  end
  -- In a Range6 the end row/col sit one slot later because of start bytes.
  local off_1 = #r == 6 and 1 or 0
  return r[1], r[2], r[3 + off_1], r[4 + off_1]
end

---Return all six fields of a Range6 as multiple values.
---@param r Range6
---@return integer, integer, integer, integer, integer, integer
function M.unpack6(r)
  return r[1], r[2], r[3], r[4], r[5], r[6]
end

---Check whether `r1` fully encloses `r2`. Equal boundaries count as
---contained.
---@param r1 Range
---@param r2 Range
---@return boolean whether r1 contains r2
function M.contains(r1, r2)
  local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
  local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2)

  -- start doesn't fit
  if M.cmp_pos.gt(srow_1, scol_1, srow_2, scol_2) then
    return false
  end

  -- end doesn't fit
  if M.cmp_pos.lt(erow_1, ecol_1, erow_2, ecol_2) then
    return false
  end

  return true
end

---Byte offset at which row `index` (0-based) begins in `source`.
---For a buffer handle this is delegated to `nvim_buf_get_offset`. For a
---string, newlines are counted. If the string has fewer than `index`
---newlines, the end of the string (`#source`) is returned instead of nil.
--- @param source integer|string buffer handle or source text
--- @param index integer
--- @return integer
local function get_offset(source, index)
  if index == 0 then
    return 0
  end

  if type(source) == 'number' then
    return api.nvim_buf_get_offset(source, index)
  end

  -- The `()` capture yields the 1-based position of each newline. That is
  -- the 0-based byte offset of the first character on the following line.
  local byte = 0
  local next_offset = source:gmatch('()\n')
  for _ = 1, index do
    local pos = next_offset()
    if pos == nil then
      -- Row lies past the last newline: clamp to the end of the text rather
      -- than returning nil (which would crash arithmetic in callers).
      return #source
    end
    byte = pos --[[@as integer]]
  end

  return byte
end

---Extend a range with start/end byte offsets, producing a Range6.
---A table that already has 6 elements is returned unchanged.
---@param source integer|string buffer handle or source text
---@param range Range
---@return Range6
function M.add_bytes(source, range)
  if type(range) == 'table' and #range == 6 then
    return range --[[@as Range6]]
  end

  local start_row, start_col, end_row, end_col = M.unpack4(range)
  -- TODO(vigoux): proper byte computation here, and account for EOL ?
  local start_byte = get_offset(source, start_row) + start_col
  local end_byte = get_offset(source, end_row) + end_col

  return { start_row, start_col, start_byte, end_row, end_col, end_byte }
end

return M