commit 7ed8cbd095805b9e6079def91b13ec24ecec5348
parent 40aef0d02e1a3f528e79ece3fe42b24f8d556251
Author: Yi Ming <ofseed@foxmail.com>
Date: Sun, 3 Aug 2025 22:45:49 +0800
feat(lua): vim.list.bisect() #35108
Diffstat:
6 files changed, 271 insertions(+), 65 deletions(-)
diff --git a/runtime/doc/lua.txt b/runtime/doc/lua.txt
@@ -1755,13 +1755,66 @@ vim.islist({t}) *vim.islist()*
See also: ~
• |vim.isarray()|
+vim.list.bisect({t}, {val}, {opts}) *vim.list.bisect()*
+ Search for a position in a sorted list {t} where {val} can be inserted
+ while keeping the list sorted.
+
+ Use {bound} to determine whether to return the first or the last position,
+ defaults to "lower", i.e., the first position.
+
+ NOTE: Behavior is undefined on unsorted lists!
+
+ Example: >lua
+
+ local t = { 1, 2, 2, 3, 3, 3 }
+ local first = vim.list.bisect(t, 3)
+ -- `first` is `val`'s first index if found,
+ -- useful for existence checks.
+ print(t[first]) -- 3
+
+ local last = vim.list.bisect(t, 3, { bound = 'upper' })
+ -- Note that `last` is 7, not 6,
+ -- this is suitable for insertion.
+
+ table.insert(t, last, 4)
+ -- t is now { 1, 2, 2, 3, 3, 3, 4 }
+
+ -- You can use lower bound and upper bound together
+ -- to obtain the range of occurrences of `val`.
+
+ -- 3 is in [first, last)
+ for i = first, last - 1 do
+ print(t[i]) -- { 3, 3, 3 }
+ end
+<
+
+ Parameters: ~
+ • {t} (`any[]`) A comparable list.
+ • {val} (`any`) The value to search.
+ • {opts} (`table?`) A table with the following fields:
+ • {lo}? (`integer`, default: `1`) Start index of the list.
+ • {hi}? (`integer`, default: `#t + 1`) End index of the list,
+ exclusive.
+ • {key}? (`fun(val: any): any`) Optional, compare the return
+ value instead of the {val} itself if provided.
+ • {bound}? (`'lower'|'upper'`, default: `'lower'`) Specifies
+ the search variant.
+ • "lower": returns the first position where inserting {val}
+ keeps the list sorted.
+ • "upper": returns the last position where inserting {val}
+ keeps the list sorted..
+
+ Return: ~
+ (`integer`) index serves as either the lower bound or the upper bound
+ position.
+
vim.list.unique({t}, {key}) *vim.list.unique()*
Removes duplicate values from a list-like table in-place.
Only the first occurrence of each value is kept. The operation is
performed in-place and the input table is modified.
- Accepts an optional `hash` argument that if provided is called for each
+ Accepts an optional `key` argument that if provided is called for each
value in the list to compute a hash key for uniqueness comparison. This is
useful for deduplicating table values or complex objects.
@@ -1778,7 +1831,7 @@ vim.list.unique({t}, {key}) *vim.list.unique()*
Parameters: ~
• {t} (`any[]`)
- • {key} (`fun(x: T): any??`) Optional hash function to determine
+ • {key} (`fun(x: T): any?`) Optional hash function to determine
uniqueness of values
Return: ~
diff --git a/runtime/doc/news.txt b/runtime/doc/news.txt
@@ -234,6 +234,7 @@ LUA
• |Iter:take()| and |Iter:skip()| now optionally accept predicates.
• Built-in plugin manager |vim.pack|
• |vim.list.unique()| to deduplicate lists.
+• |vim.list.bisect()| for binary search.
OPTIONS
diff --git a/runtime/lua/vim/lsp/semantic_tokens.lua b/runtime/lua/vim/lsp/semantic_tokens.lua
@@ -45,38 +45,6 @@ local STHighlighter = { name = 'Semantic Tokens', active = {} }
STHighlighter.__index = STHighlighter
setmetatable(STHighlighter, Capability)
---- Do a binary search of the tokens in the half-open range [lo, hi).
----
---- Return the index i in range such that tokens[j].line < line for all j < i, and
---- tokens[j].line >= line for all j >= i, or return hi if no such index is found.
-local function lower_bound(tokens, line, lo, hi)
- while lo < hi do
- local mid = bit.rshift(lo + hi, 1) -- Equivalent to floor((lo + hi) / 2).
- if tokens[mid].end_line < line then
- lo = mid + 1
- else
- hi = mid
- end
- end
- return lo
-end
-
---- Do a binary search of the tokens in the half-open range [lo, hi).
----
---- Return the index i in range such that tokens[j].line <= line for all j < i, and
---- tokens[j].line > line for all j >= i, or return hi if no such index is found.
-local function upper_bound(tokens, line, lo, hi)
- while lo < hi do
- local mid = bit.rshift(lo + hi, 1) -- Equivalent to floor((lo + hi) / 2).
- if line < tokens[mid].line then
- hi = mid
- else
- lo = mid + 1
- end
- end
- return lo
-end
-
--- Extracts modifier strings from the encoded number in the token array
---
---@param x integer
@@ -488,8 +456,18 @@ function STHighlighter:on_win(topline, botline)
local ft = vim.bo[self.bufnr].filetype
local highlights = assert(current_result.highlights)
- local first = lower_bound(highlights, topline, 1, #highlights + 1)
- local last = upper_bound(highlights, botline, first, #highlights + 1) - 1
+ local first = vim.list.bisect(highlights, { end_line = topline }, {
+ key = function(highlight)
+ return highlight.end_line
+ end,
+ })
+ local last = vim.list.bisect(highlights, { line = botline }, {
+ lo = first,
+ bound = 'upper',
+ key = function(highlight)
+ return highlight.line
+ end,
+ }) - 1
--- @type boolean?, integer?
local is_folded, foldend
@@ -761,7 +739,11 @@ function M.get_at_pos(bufnr, row, col)
for client_id, client in pairs(highlighter.client_state) do
local highlights = client.current_result.highlights
if highlights then
- local idx = lower_bound(highlights, row, 1, #highlights + 1)
+ local idx = vim.list.bisect(highlights, { end_line = row }, {
+ key = function(highlight)
+ return highlight.end_line
+ end,
+ })
for i = idx, #highlights do
local token = highlights[i]
--- @cast token STTokenRangeInspect
diff --git a/runtime/lua/vim/shared.lua b/runtime/lua/vim/shared.lua
@@ -350,12 +350,21 @@ end
vim.list = {}
+---TODO(ofseed): memoize, string value support, type alias.
+---@generic T
+---@param v T
+---@param key? fun(v: T): any
+---@return any
+local function key_fn(v, key)
+ return key and key(v) or v
+end
+
--- Removes duplicate values from a list-like table in-place.
---
--- Only the first occurrence of each value is kept.
--- The operation is performed in-place and the input table is modified.
---
---- Accepts an optional `hash` argument that if provided is called for each
+--- Accepts an optional `key` argument that if provided is called for each
--- value in the list to compute a hash key for uniqueness comparison.
--- This is useful for deduplicating table values or complex objects.
---
@@ -373,21 +382,18 @@ vim.list = {}
---
--- @generic T
--- @param t T[]
---- @param key? fun(x: T): any? Optional hash function to determine uniqueness of values
+--- @param key? fun(x: T): any Optional hash function to determine uniqueness of values
--- @return T[] : The deduplicated list
function vim.list.unique(t, key)
vim.validate('t', t, 'table')
local seen = {} --- @type table<any,boolean>
local finish = #t
- key = key or function(a)
- return a
- end
local j = 1
for i = 1, finish do
local v = t[i]
- local vh = key(v)
+ local vh = key_fn(v, key)
if not seen[vh] then
t[j] = v
if vh ~= nil then
@@ -404,6 +410,127 @@ function vim.list.unique(t, key)
return t
end
+---@class vim.list.bisect.Opts
+---@inlinedoc
+---
+--- Start index of the list.
+--- (default: `1`)
+---@field lo? integer
+---
+--- End index of the list, exclusive.
+--- (default: `#t + 1`)
+---@field hi? integer
+---
+--- Optional, compare the return value instead of the {val} itself if provided.
+---@field key? fun(val: any): any
+---
+--- Specifies the search variant.
+--- - "lower": returns the first position
+--- where inserting {val} keeps the list sorted.
+--- - "upper": returns the last position
+--- where inserting {val} keeps the list sorted..
+--- (default: `'lower'`)
+---@field bound? 'lower' | 'upper'
+
+---@generic T
+---@param t T[]
+---@param val T
+---@param key? fun(val: any): any
+---@param lo integer
+---@param hi integer
+---@return integer i in range such that `t[j]` < {val} for all j < i,
+--- and `t[j]` >= {val} for all j >= i,
+--- or return {hi} if no such index is found.
+local function lower_bound(t, val, lo, hi, key)
+ local bit = require('bit') -- Load bitop on demand
+ local val_key = key_fn(val, key)
+ while lo < hi do
+ local mid = bit.rshift(lo + hi, 1) -- Equivalent to floor((lo + hi) / 2)
+ if key_fn(t[mid], key) < val_key then
+ lo = mid + 1
+ else
+ hi = mid
+ end
+ end
+ return lo
+end
+
+---@generic T
+---@param t T[]
+---@param val T
+---@param key? fun(val: any): any
+---@param lo integer
+---@param hi integer
+---@return integer i in range such that `t[j]` <= {val} for all j < i,
+--- and `t[j]` > {val} for all j >= i,
+--- or return {hi} if no such index is found.
+local function upper_bound(t, val, lo, hi, key)
+ local bit = require('bit') -- Load bitop on demand
+ local val_key = key_fn(val, key)
+ while lo < hi do
+ local mid = bit.rshift(lo + hi, 1) -- Equivalent to floor((lo + hi) / 2)
+ if val_key < key_fn(t[mid], key) then
+ hi = mid
+ else
+ lo = mid + 1
+ end
+ end
+ return lo
+end
+
+--- Search for a position in a sorted list {t}
+--- where {val} can be inserted while keeping the list sorted.
+---
+--- Use {bound} to determine whether to return the first or the last position,
+--- defaults to "lower", i.e., the first position.
+---
+--- NOTE: Behavior is undefined on unsorted lists!
+---
+--- Example:
+--- ```lua
+---
+--- local t = { 1, 2, 2, 3, 3, 3 }
+--- local first = vim.list.bisect(t, 3)
+--- -- `first` is `val`'s first index if found,
+--- -- useful for existence checks.
+--- print(t[first]) -- 3
+---
+--- local last = vim.list.bisect(t, 3, { bound = 'upper' })
+--- -- Note that `last` is 7, not 6,
+--- -- this is suitable for insertion.
+---
+--- table.insert(t, last, 4)
+--- -- t is now { 1, 2, 2, 3, 3, 3, 4 }
+---
+--- -- You can use lower bound and upper bound together
+--- -- to obtain the range of occurrences of `val`.
+---
+--- -- 3 is in [first, last)
+--- for i = first, last - 1 do
+--- print(t[i]) -- { 3, 3, 3 }
+--- end
+--- ```
+---@generic T
+---@param t T[] A comparable list.
+---@param val T The value to search.
+---@param opts? vim.list.bisect.Opts
+---@return integer index serves as either the lower bound or the upper bound position.
+function vim.list.bisect(t, val, opts)
+ vim.validate('t', t, 'table')
+ vim.validate('opts', opts, 'table', true)
+
+ opts = opts or {}
+ local lo = opts.lo or 1
+ local hi = opts.hi or #t + 1
+ local key = opts.key
+
+ if opts.bound == 'upper' then
+ return upper_bound(t, val, lo, hi, key)
+ else
+ return lower_bound(t, val, lo, hi, key)
+ end
+end
+
--- Checks if a table is empty.
---
---@see https://github.com/premake/premake-core/blob/master/src/base/table.lua
diff --git a/test/functional/lua/list_spec.lua b/test/functional/lua/list_spec.lua
@@ -0,0 +1,65 @@
+-- Test suite for vim.list
+local t = require('test.testutil')
+local eq = t.eq
+
+describe('vim.list', function()
+ it('vim.list.unique()', function()
+ eq({ 1, 2, 3, 4, 5 }, vim.list.unique({ 1, 2, 2, 3, 4, 4, 5 }))
+ eq({ 1, 2, 3, 4, 5 }, vim.list.unique({ 1, 2, 3, 4, 4, 5, 1, 2, 3, 2, 1, 2, 3, 4, 5 }))
+ eq({ 1, 2, 3, 4, 5, field = 1 }, vim.list.unique({ 1, 2, 2, 3, 4, 4, 5, field = 1 }))
+
+ -- Not properly defined, but test anyway
+ -- luajit evaluates #t as 7, whereas Lua 5.1 evaluates it as 12
+ local r = vim.list.unique({ 1, 2, 2, 3, 4, 4, 5, nil, 6, 6, 7, 7 })
+ if jit then
+ eq({ 1, 2, 3, 4, 5, nil, nil, nil, 6, 6, 7, 7 }, r)
+ else
+ eq({ 1, 2, 3, 4, 5, nil, 6, 7 }, r)
+ end
+
+ eq(
+ { { 1 }, { 2 }, { 3 } },
+ vim.list.unique({ { 1 }, { 1 }, { 2 }, { 2 }, { 3 }, { 3 } }, function(x)
+ return x[1]
+ end)
+ )
+ end)
+
+ --- Generate a list like { 1, 2, 2, 3, 3, 3, 4, 4, 4, 4, ...}.
+ ---@param num integer
+ local function gen_list(num)
+ ---@type integer[]
+ local list = {}
+ for i = 1, num do
+ for _ = 1, i do
+ list[#list + 1] = i
+ end
+ end
+ return list
+ end
+
+ --- Index of the last {num}.
+ --- Mathematically, a triangular number.
+ ---@param num integer
+ local function index(num)
+ return math.floor((math.pow(num, 2) + num) / 2)
+ end
+
+ it("vim.list.bisect(..., { bound = 'lower' })", function()
+ local num = math.random(100)
+ local list = gen_list(num)
+
+ local target = math.random(num)
+ eq(vim.list.bisect(list, target, { bound = 'lower' }), index(target - 1) + 1)
+ eq(vim.list.bisect(list, num + 1, { bound = 'lower' }), index(num) + 1)
+ end)
+
+ it("vim.list.bisect(..., bound = { 'upper' })", function()
+ local num = math.random(100)
+ local list = gen_list(num)
+
+ local target = math.random(num)
+ eq(vim.list.bisect(list, target, { bound = 'upper' }), index(target) + 1)
+ eq(vim.list.bisect(list, num + 1, { bound = 'upper' }), index(num) + 1)
+ end)
+end)
diff --git a/test/functional/lua/vim_spec.lua b/test/functional/lua/vim_spec.lua
@@ -1260,28 +1260,6 @@ describe('lua stdlib', function()
eq({ 2 }, exec_lua [[ return vim.list_extend({}, {2;a=1}, -1, 2) ]])
end)
- it('vim.list.unique', function()
- eq({ 1, 2, 3, 4, 5 }, vim.list.unique({ 1, 2, 2, 3, 4, 4, 5 }))
- eq({ 1, 2, 3, 4, 5 }, vim.list.unique({ 1, 2, 3, 4, 4, 5, 1, 2, 3, 2, 1, 2, 3, 4, 5 }))
- eq({ 1, 2, 3, 4, 5, field = 1 }, vim.list.unique({ 1, 2, 2, 3, 4, 4, 5, field = 1 }))
-
- -- Not properly defined, but test anyway
- -- luajit evaluates #t as 7, whereas Lua 5.1 evaluates it as 12
- local r = vim.list.unique({ 1, 2, 2, 3, 4, 4, 5, nil, 6, 6, 7, 7 })
- if jit then
- eq({ 1, 2, 3, 4, 5, nil, nil, nil, 6, 6, 7, 7 }, r)
- else
- eq({ 1, 2, 3, 4, 5, nil, 6, 7 }, r)
- end
-
- eq(
- { { 1 }, { 2 }, { 3 } },
- vim.list.unique({ { 1 }, { 1 }, { 2 }, { 2 }, { 3 }, { 3 } }, function(x)
- return x[1]
- end)
- )
- end)
-
it('vim.tbl_add_reverse_lookup', function()
eq(
true,