commit f01419f3d56390af4bce0bd36a89a5f52e7c81e2
parent 715c28d67f04b23c0f9ebc48db5ea3b967c0a16d
Author: Mart-Mihkel Aun <73405010+mart-mihkel@users.noreply.github.com>
Date: Thu, 3 Jul 2025 16:12:24 +0300
feat(runtime): accept predicates in take and skip (#34657)
Make `vim.iter():take()` and `vim.iter():skip()`
optionally accept predicates to enable takewhile
and skipwhile patterns used in functional
programming.
Diffstat:
4 files changed, 148 insertions(+), 12 deletions(-)
diff --git a/runtime/doc/lua.txt b/runtime/doc/lua.txt
@@ -4545,17 +4545,24 @@ Iter:rskip({n}) *Iter:rskip()*
(`Iter`)
Iter:skip({n}) *Iter:skip()*
- Skips `n` values of an iterator pipeline.
+ Skips `n` values of an iterator pipeline, or all values satisfying a
+ predicate of a |list-iterator|.
Example: >lua
local it = vim.iter({ 3, 6, 9, 12 }):skip(2)
it:next()
-- 9
+
+ local function pred(x) return x < 10 end
+ local it2 = vim.iter({ 3, 6, 9, 12 }):skip(pred)
+ it2:next()
+ -- 12
<
Parameters: ~
- • {n} (`number`) Number of values to skip.
+ • {n} (`integer|fun(...):boolean`) Number of values to skip or a
+ predicate.
Return: ~
(`Iter`)
@@ -4573,7 +4580,8 @@ Iter:slice({first}, {last}) *Iter:slice()*
(`Iter`)
Iter:take({n}) *Iter:take()*
- Transforms an iterator to yield only the first n values.
+ Transforms an iterator to yield only the first n values, or all values
+ satisfying a predicate.
Example: >lua
local it = vim.iter({ 1, 2, 3, 4 }):take(2)
@@ -4583,10 +4591,18 @@ Iter:take({n}) *Iter:take()*
-- 2
it:next()
-- nil
+
+ local function pred(x) return x < 2 end
+ local it2 = vim.iter({ 1, 2, 3, 4 }):take(pred)
+ it2:next()
+ -- 1
+ it2:next()
+ -- nil
<
Parameters: ~
- • {n} (`integer`)
+ • {n} (`integer|fun(...):boolean`) Number of values to take or a
+ predicate.
Return: ~
(`Iter`)
diff --git a/runtime/doc/news.txt b/runtime/doc/news.txt
@@ -201,6 +201,7 @@ LUA
• |vim.fs.root()| can define "equal priority" via nested lists.
• |vim.version.range()| output can be converted to human-readable string with |tostring()|.
• |vim.version.intersect()| computes intersection of two version ranges.
+• |Iter:take()| and |Iter:skip()| now optionally accept predicates.
OPTIONS
diff --git a/runtime/lua/vim/iter.lua b/runtime/lua/vim/iter.lua
@@ -681,7 +681,8 @@ function ArrayIter:rfind(f)
self._head = self._tail
end
---- Transforms an iterator to yield only the first n values.
+--- Transforms an iterator to yield only the first n values, or all values
+--- satisfying a predicate.
---
--- Example:
---
@@ -693,24 +694,56 @@ end
--- -- 2
--- it:next()
--- -- nil
+---
+--- local function pred(x) return x < 2 end
+--- local it2 = vim.iter({ 1, 2, 3, 4 }):take(pred)
+--- it2:next()
+--- -- 1
+--- it2:next()
+--- -- nil
--- ```
---
----@param n integer
+---@param n integer|fun(...):boolean Number of values to take or a predicate.
---@return Iter
function Iter:take(n)
- local next = self.next
local i = 0
- self.next = function()
- if i < n then
+ local f = n
+ if type(n) ~= 'function' then
+ f = function()
+ return i < n
+ end
+ end
+
+ local stop = false
+ local function fn(...)
+ if not stop and select(1, ...) ~= nil and f(...) then
i = i + 1
- return next(self)
+ return ...
+ else
+ stop = true
end
end
+
+ local next = self.next
+ self.next = function()
+ return fn(next(self))
+ end
return self
end
---@private
function ArrayIter:take(n)
+ if type(n) == 'function' then
+ local inc = self._head < self._tail and 1 or -1
+ for i = self._head, self._tail, inc do
+ if not n(unpack(self._table[i])) then
+ self._tail = i
+ break
+ end
+ end
+ return self
+ end
+
local inc = self._head < self._tail and n or -n
local cmp = self._head < self._tail and math.min or math.max
self._tail = cmp(self._tail, self._head + inc)
@@ -772,7 +805,8 @@ function ArrayIter:rpeek()
end
end
---- Skips `n` values of an iterator pipeline.
+--- Skips `n` values of an iterator pipeline, or all values satisfying a
+--- predicate of a |list-iterator|.
---
--- Example:
---
@@ -782,11 +816,20 @@ end
--- it:next()
--- -- 9
---
+--- local function pred(x) return x < 10 end
+--- local it2 = vim.iter({ 3, 6, 9, 12 }):skip(pred)
+--- it2:next()
+--- -- 12
--- ```
---
----@param n number Number of values to skip.
+---@param n integer|fun(...):boolean Number of values to skip or a predicate.
---@return Iter
function Iter:skip(n)
+ if type(n) == 'function' then
+ -- We would need to evaluate the perdicate without advancing iterator
+ error('skip() with predicate requires an array-like table')
+ end
+
for _ = 1, n do
local _ = self:next()
end
@@ -795,6 +838,16 @@ end
---@private
function ArrayIter:skip(n)
+ if type(n) == 'function' then
+ local inc = self._head < self._tail and 1 or -1
+ local i = self._head
+ while n(unpack(self:peek())) and i ~= self._tail do
+ self:next()
+ i = i + inc
+ end
+ return self
+ end
+
local inc = self._head < self._tail and n or -n
self._head = self._head + inc
if (inc > 0 and self._head > self._tail) or (inc < 0 and self._head < self._tail) then
diff --git a/test/functional/lua/iter_spec.lua b/test/functional/lua/iter_spec.lua
@@ -160,6 +160,30 @@ describe('vim.iter', function()
end
do
+ local function wrong()
+ return false
+ end
+
+ local function correct()
+ return true
+ end
+
+ local q = { 4, 3, 2, 1 }
+
+ eq({ 4, 3, 2, 1 }, vim.iter(q):skip(wrong):totable())
+ eq(
+ { 2, 1 },
+ vim
+ .iter(q)
+ :skip(function(x)
+ return x > 2
+ end)
+ :totable()
+ )
+ eq({}, vim.iter(q):skip(correct):totable())
+ end
+
+ do
local function skip(n)
return vim.iter(vim.gsplit('a|b|c|d', '|')):skip(n):totable()
end
@@ -241,6 +265,14 @@ describe('vim.iter', function()
end)
it('take()', function()
+ local function correct()
+ return true
+ end
+
+ local function wrong()
+ return false
+ end
+
do
local q = { 4, 3, 2, 1 }
eq({}, vim.iter(q):take(0):totable())
@@ -253,6 +285,22 @@ describe('vim.iter', function()
do
local q = { 4, 3, 2, 1 }
+
+ eq({}, vim.iter(q):take(wrong):totable())
+ eq(
+ { 4, 3 },
+ vim
+ .iter(q)
+ :take(function(x)
+ return x > 2
+ end)
+ :totable()
+ )
+ eq({ 4, 3, 2, 1 }, vim.iter(q):take(correct):totable())
+ end
+
+ do
+ local q = { 4, 3, 2, 1 }
eq({ 1, 2, 3 }, vim.iter(q):rev():take(3):totable())
eq({ 2, 3, 4 }, vim.iter(q):take(3):rev():totable())
end
@@ -271,6 +319,24 @@ describe('vim.iter', function()
-- non-array iterators are consumed by take()
eq({}, it:take(2):totable())
end
+
+ do
+ eq({ 'a', 'b', 'c', 'd' }, vim.iter(vim.gsplit('a|b|c|d', '|')):take(correct):totable())
+ eq(
+ { 'a', 'b', 'c' },
+ vim
+ .iter(vim.gsplit('a|b|c|d', '|'))
+ :enumerate()
+ :take(function(i, x)
+ return i < 3 or x == 'c'
+ end)
+ :map(function(_, x)
+ return x
+ end)
+ :totable()
+ )
+ eq({}, vim.iter(vim.gsplit('a|b|c|d', '|')):take(wrong):totable())
+ end
end)
it('any()', function()