set.lua (2880B)
-- A set class for fast union/diff that remembers insertion order: calling
-- to_table() always returns the elements in the same relative order in which
-- they were added. It does this by keeping two Lua tables that mirror each
-- other:
--   1) tbl:   index => item   (may contain holes after removals)
--   2) items: item  => index
-- Indices come from a monotonically increasing counter (last_idx) rather than
-- from the length operator. `#` is unspecified on a table with holes, so using
-- it could hand a new item an index lower than older items.
--- @class Set
--- @field nelem integer number of elements currently in the set
--- @field items table<string, integer> item => insertion index
--- @field tbl table<integer, string> insertion index => item (may have holes)
--- @field last_idx integer highest insertion index handed out so far
local Set = {}

--- Create a new Set, optionally seeded with the values of `items`.
--- @param items? string[]
--- @return Set
function Set:new(items)
  local obj = { nelem = 0, items = {}, tbl = {}, last_idx = 0 } --- @type Set
  setmetatable(obj, self)
  self.__index = self

  if type(items) == 'table' then
    obj:union_table(items)
  end

  return obj
end

--- Return an independent copy of this Set (same elements, same order,
--- same class as the receiver).
--- @return Set
function Set:copy()
  local obj = {
    nelem = self.nelem,
    last_idx = self.last_idx,
    tbl = {},
    items = {},
  } --- @type Set
  for idx, item in pairs(self.tbl) do
    obj.tbl[idx] = item
  end
  for item, idx in pairs(self.items) do
    obj.items[item] = idx
  end
  -- reuse the receiver's metatable so copies of subclasses keep their class
  return setmetatable(obj, getmetatable(self))
end

--- Add every element of `other` to this Set.
--- `other` is walked in its insertion order, so the relative order of its
--- elements is preserved in this Set.
--- @param other Set
function Set:union(other)
  for _, e in ipairs(other:to_table()) do
    self:add(e)
  end
end

--- Add every value of table `t` to this Set.
--- The sequence part t[1..n] is added in index order first. A second pass with
--- pairs() picks up values stored under any other keys; values that are
--- already present are ignored by add().
--- @param t table
function Set:union_table(t)
  for _, v in ipairs(t) do
    self:add(v)
  end
  for _, v in pairs(t) do
    self:add(v)
  end
end

--- Remove every element of `other` from this Set.
--- Iterates over whichever of the two sets is smaller. Clearing existing keys
--- during a pairs() traversal is permitted by Lua.
--- @param other Set
function Set:diff(other)
  if other:size() > self:size() then
    -- this set is the smaller one: test each of our elements against other
    for e in self:iterator() do
      if other:contains(e) then
        self:remove(e)
      end
    end
  else
    -- other is the smaller (or equal) one: remove each of its elements;
    -- remove() is a no-op for elements we do not hold
    for e in other:iterator() do
      self:remove(e)
    end
  end
end

--- Add `it` to the Set if it is not already present.
--- The new item gets an index larger than any handed out before, so it sorts
--- after every existing item in to_table(), even when tbl has holes.
--- @param it string
function Set:add(it)
  if not self:contains(it) then
    local idx = self.last_idx + 1
    self.tbl[idx] = it
    self.items[it] = idx
    -- bump the counter only after the assignments succeeded
    self.last_idx = idx
    self.nelem = self.nelem + 1
  end
end

--- Remove `it` from the Set if present. This leaves a hole in tbl, which
--- to_table() handles.
--- @param it string
function Set:remove(it)
  local idx = self.items[it]
  if idx ~= nil then
    self.tbl[idx] = nil
    self.items[it] = nil
    self.nelem = self.nelem - 1
  end
end

--- @param it string
--- @return boolean true when `it` is in the Set
function Set:contains(it)
  return self.items[it] ~= nil
end

--- @return integer number of elements in the Set
function Set:size()
  return self.nelem
end

--- @return table<integer, string> the internal index => item table (not a copy)
function Set:raw_tbl()
  return self.tbl
end

--- @return table<string, integer> the internal item => index table (not a copy)
function Set:raw_items()
  return self.items
end

--- Iterate (item, index) pairs in unspecified order.
function Set:iterator()
  return pairs(self.items)
end

--- Return the elements as a fresh sequence, in insertion order.
--- @return string[]
function Set:to_table()
  -- tbl may have holes (see remove), so collect and sort its indices first
  local indices = {} --- @type integer[]
  for idx in pairs(self.tbl) do
    indices[#indices + 1] = idx
  end
  table.sort(indices)

  local out = {} --- @type string[]
  for i, idx in ipairs(indices) do
    out[i] = self.tbl[idx]
  end
  return out
end

return Set