test_storage_ext_vec.js (6434B)
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

// This file tests support for the sqlite-vec extension.

/**
 * Packs a list of numbers as 32-bit floats and returns a byte view over the
 * resulting buffer, suitable for binding as a blob to a vec0 FLOAT[] column.
 *
 * @param {number[]} tensor - The vector components.
 * @returns {Uint8ClampedArray} A byte view over the packed float32 data.
 */
function tensorToBlob(tensor) {
  return new Uint8ClampedArray(new Float32Array(tensor).buffer);
}

/**
 * Removes every run of zeros that directly follows a digit, so the text
 * returned by vec_to_json() can be compared with short literals such as
 * "[0.2,0.2,0.2,0.2]" (e.g. "0.200000" becomes "0.2").
 *
 * NOTE(review): this is only safe for the values used in this file; it would
 * also strip significant zeros, e.g. "10.5" would become "1.5".
 *
 * @param {string} json - Text returned by vec_to_json().
 * @returns {string} The normalized text.
 */
function normalizeVecJson(json) {
  return json.replace(/(?<=[0-9])0+/g, "");
}

/**
 * Loads a SQLite extension into `conn`, adapting the callback-based
 * loadExtension() API to a Promise.
 *
 * @param {object} conn - A storage connection; this file uses it with both
 *   synchronous and asynchronous connections.
 * @param {string} [ext="vec"] - Name of the extension to load.
 * @returns {Promise<void>} Resolves once the extension is loaded; rejects
 *   with the (non-success) status code otherwise.
 */
function loadExtension(conn, ext = "vec") {
  return new Promise((resolve, reject) => {
    conn.loadExtension(ext, status => {
      if (Components.isSuccessCode(status)) {
        resolve();
      } else {
        reject(status);
      }
    });
  });
}

add_setup(async function () {
  cleanup();
});

add_task(async function test_synchronous() {
  info("Testing synchronous connection");
  let conn = getOpenedUnsharedDatabase();
  Assert.throws(
    () =>
      conn.executeSimpleSQL(
        `CREATE VIRTUAL TABLE test USING vec0(
          embedding FLOAT[4]
        );`
      ),
    /NS_ERROR_FAILURE/,
    "Should not be able to use vec without loading the extension"
  );

  await loadExtension(conn);

  conn.executeSimpleSQL(
    `
    CREATE VIRTUAL TABLE test USING vec0(
      embedding FLOAT[4]
    )
    `
  );

  let stmt = conn.createStatement(
    `
    INSERT INTO test(rowid, embedding)
    VALUES (1, :vector)
    `
  );
  stmt.bindBlobByName("vector", tensorToBlob([0.3, 0.3, 0.3, 0.3]));
  stmt.executeStep();
  stmt.reset();
  stmt.finalize();

  // A KNN query for the exact vector that was inserted must return that row
  // with a distance of zero.
  stmt = conn.createStatement(
    `
    SELECT
      rowid,
      distance
    FROM test
    WHERE embedding MATCH :vector
    ORDER BY distance
    LIMIT 1
    `
  );
  stmt.bindBlobByName("vector", tensorToBlob([0.3, 0.3, 0.3, 0.3]));
  Assert.ok(stmt.executeStep());
  Assert.equal(stmt.getInt32(0), 1);
  Assert.equal(stmt.getDouble(1), 0.0);
  stmt.reset();
  stmt.finalize();

  cleanup();
});

add_task(async function test_asynchronous() {
  info("Testing asynchronous connection");
  let conn = await openAsyncDatabase(getTestDB());

  await Assert.rejects(
    executeSimpleSQLAsync(
      conn,
      `
      CREATE VIRTUAL TABLE test USING vec0(
        embedding float[4]
      )
      `
    ),
    err => err.message.startsWith("no such module"),
    "Should not be able to use vec without loading the extension"
  );

  await loadExtension(conn);

  await executeSimpleSQLAsync(
    conn,
    `
    CREATE VIRTUAL TABLE test USING vec0(
      embedding float[4]
    )
    `
  );

  await asyncClose(conn);
  await IOUtils.remove(getTestDB().path, { ignoreAbsent: true });
});

add_task(async function test_clone() {
  info("Testing cloning synchronous connection loads extensions in clone");
  let conn1 = getOpenedUnsharedDatabase();
  await loadExtension(conn1);

  // The extension was only loaded on conn1; the clone must inherit it.
  let conn2 = conn1.clone(false);
  conn2.executeSimpleSQL(
    `
    CREATE VIRTUAL TABLE test USING vec0(
      embedding float[4]
    )
    `
  );

  conn2.close();
  cleanup();
});

add_task(async function test_asyncClone() {
  info("Testing asynchronously cloning connection loads extensions in clone");
  let conn1 = getOpenedUnsharedDatabase();
  await loadExtension(conn1);

  // The extension was only loaded on conn1; the async clone must inherit it.
  let conn2 = await asyncClone(conn1, false);
  await executeSimpleSQLAsync(
    conn2,
    `
    CREATE VIRTUAL TABLE test USING vec0(
      embedding float[4]
    )
    `
  );

  await asyncClose(conn2);
  await asyncCleanup();
});

add_task(async function test_invariants() {
  // Test some invariants of the vec extension that we rely upon, so that if
  // the behavior changes we can catch it.
  let conn = getOpenedUnsharedDatabase();
  await loadExtension(conn);

  conn.executeSimpleSQL(
    `
    CREATE VIRTUAL TABLE vectors USING vec0(
      embedding FLOAT[4]
    )
    `
  );
  conn.executeSimpleSQL(
    `
    CREATE TABLE relations (
      rowid INTEGER PRIMARY KEY,
      content TEXT
    )
    `
  );

  // Rowids for the vec table are generated by inserting into a regular table.
  // The string literal is single-quoted: a double-quoted "test" is an
  // identifier in SQL and only works through SQLite's DQS compatibility quirk.
  let rowids = [];
  let insertRelStmt = conn.createStatement(
    `
    INSERT INTO relations (rowid, content)
    VALUES (NULL, 'test')
    RETURNING rowid
    `
  );
  // Inserts one relation and records its (64-bit) rowid.
  const insertRelation = () => {
    Assert.ok(insertRelStmt.executeStep());
    rowids.push(insertRelStmt.getInt64(0));
    insertRelStmt.reset();
  };
  insertRelation();
  insertRelation();

  let insertVecStmt = conn.createStatement(
    `
    INSERT INTO vectors (rowid, embedding)
    VALUES (:rowid, :vector)
    `
  );
  // Inserts `tensor` into the vec table under `rowid`.
  const insertVector = (rowid, tensor) => {
    insertVecStmt.bindByName("rowid", rowid);
    insertVecStmt.bindBlobByName("vector", tensorToBlob(tensor));
    insertVecStmt.executeStep();
    insertVecStmt.reset();
  };

  // A rowid must be reusable once its row has been deleted: insert a vector,
  // delete it, then insert a different vector under the same rowid.
  insertVector(rowids[0], [0.1, 0.1, 0.1, 0.1]);

  let deleteStmt = conn.createStatement(
    `
    DELETE FROM vectors WHERE rowid = :rowid
    `
  );
  deleteStmt.bindByName("rowid", rowids[0]);
  deleteStmt.executeStep();
  deleteStmt.finalize();

  insertVector(rowids[0], [0.2, 0.2, 0.2, 0.2]);

  let selectStmt = conn.createStatement(
    `
    SELECT
      rowid,
      vec_to_json(embedding)
    FROM vectors
    `
  );
  // Returns every row of the vec table as { rowid, vector }, then resets the
  // statement so it can be executed again.
  const readVectors = () => {
    let rows = [];
    while (selectStmt.executeStep()) {
      rows.push({
        rowid: selectStmt.getInt64(0),
        vector: normalizeVecJson(selectStmt.getUTF8String(1)),
      });
    }
    selectStmt.reset();
    return rows;
  };

  let rows = readVectors();
  Assert.equal(rows.length, 1, "Should have one row in the vec table");
  Assert.deepEqual(rows, [{ rowid: rowids[0], vector: "[0.2,0.2,0.2,0.2]" }]);

  // Add a second vector under a fresh rowid; rows are expected back with the
  // reused rowid first.
  insertRelation();
  insertRelStmt.finalize();
  insertVector(rowids[2], [0.3, 0.3, 0.3, 0.3]);
  insertVecStmt.finalize();

  rows = readVectors();
  Assert.equal(rows.length, 2, "Should have two rows in the vec table");
  Assert.deepEqual(rows, [
    { rowid: rowids[0], vector: "[0.2,0.2,0.2,0.2]" },
    { rowid: rowids[2], vector: "[0.3,0.3,0.3,0.3]" },
  ]);
  selectStmt.finalize();

  // TODO: In the future add testing for RETURNING and UPSERT as those are
  // currently broken. See:
  // https://github.com/asg017/sqlite-vec/issues/127
  // https://github.com/asg017/sqlite-vec/issues/229

  cleanup();
});