commit 6b96122453fda22dc44a581af1d536988c1adf41
parent 0a3645a72307afa563683a6e06c544810e0b65eb
Author: Gregory Anders <greg@gpanders.com>
Date: Wed, 19 Apr 2023 06:45:56 -0600
fix(iter): add tag to packed table
If pack() is called with a single value, it does not create a table; it
simply returns the value it is passed. When unpack is called with a
table argument, it interprets that table as a list of values that were
packed together into a table.
This causes a problem when the single value being packed is _itself_ a
table. pack() will not place it into another table, but unpack() sees
the table argument and tries to unpack it.
To fix this, we add a simple "tag" to packed table values so that
unpack() only attempts to unpack tables that have this tag. Other tables
are left alone. The tag is simply the length of the table.
Diffstat:
2 files changed, 47 insertions(+), 5 deletions(-)
diff --git a/runtime/lua/vim/iter.lua b/runtime/lua/vim/iter.lua
@@ -28,16 +28,17 @@ end
---@private
local function unpack(t)
- if type(t) == 'table' then
- return _G.unpack(t)
+ if type(t) == 'table' and t.__n ~= nil then
+ return _G.unpack(t, 1, t.__n)
end
return t
end
---@private
local function pack(...)
- if select('#', ...) > 1 then
- return { ... }
+ local n = select('#', ...)
+ if n > 1 then
+ return { __n = n, ... }
end
return ...
end
@@ -210,6 +211,12 @@ function Iter.totable(self)
if args == nil then
break
end
+
+ if type(args) == 'table' then
+ -- Removed packed table tag if it exists
+ args.__n = nil
+ end
+
t[#t + 1] = args
end
return t
@@ -218,6 +225,14 @@ end
---@private
function ListIter.totable(self)
if self._head == 1 and self._tail == #self._table + 1 and self.next == ListIter.next then
+ -- Remove any packed table tags
+ for i = 1, #self._table do
+ local v = self._table[i]
+ if type(v) == 'table' then
+ v.__n = nil
+ self._table[i] = v
+ end
+ end
return self._table
end
@@ -747,7 +762,7 @@ function ListIter.enumerate(self)
local inc = self._head < self._tail and 1 or -1
for i = self._head, self._tail - inc, inc do
local v = self._table[i]
- self._table[i] = { i, v }
+ self._table[i] = pack(i, v)
end
return self
end
diff --git a/test/functional/lua/vim_spec.lua b/test/functional/lua/vim_spec.lua
@@ -3381,6 +3381,33 @@ describe('lua stdlib', function()
end
end)
eq({ A = 2, C = 6 }, it:totable())
+
+ it('handles table values mid-pipeline', function()
+ local map = {
+ item = {
+ file = 'test',
+ },
+ item_2 = {
+ file = 'test',
+ },
+ item_3 = {
+ file = 'test',
+ },
+ }
+
+ local output = vim.iter(map):map(function(key, value)
+ return { [key] = value.file }
+ end):totable()
+
+ table.sort(output, function(a, b)
+ return next(a) < next(b)
+ end)
+
+ eq({
+ { item = 'test' },
+ { item_2 = 'test' },
+ { item_3 = 'test' },
+ }, output)
end)
end)
end)