lstmCell.https.any.js (24492B)
// META: title=validation tests for WebNN API lstmCell operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// Sizes from which every valid example shape below is derived.
const batchSize = 3;
const inputSize = 4;
const hiddenSize = 5;

// Dimensions required of required inputs.
const kValidInputShape = [batchSize, inputSize];
const kValidWeightShape = [4 * hiddenSize, inputSize];
const kValidRecurrentWeightShape = [4 * hiddenSize, hiddenSize];
const kValidHiddenStateShape = [batchSize, hiddenSize];
const kValidCellStateShape = [batchSize, hiddenSize];
// Dimensions required of optional inputs.
const kValidBiasShape = [4 * hiddenSize];
const kValidPeepholeWeightShape = [3 * hiddenSize];

// Example descriptors which are valid according to the above dimensions.
const kExampleInputDescriptor = {
  dataType: 'float32',
  shape: kValidInputShape
};
const kExampleWeightDescriptor = {
  dataType: 'float32',
  shape: kValidWeightShape
};
const kExampleRecurrentWeightDescriptor = {
  dataType: 'float32',
  shape: kValidRecurrentWeightShape
};
const kExampleHiddenStateDescriptor = {
  dataType: 'float32',
  shape: kValidHiddenStateShape
};
const kExampleCellStateDescriptor = {
  dataType: 'float32',
  shape: kValidCellStateShape
};
const kExampleBiasDescriptor = {
  dataType: 'float32',
  shape: kValidBiasShape
};
const kExamplePeepholeWeightDescriptor = {
  dataType: 'float32',
  shape: kValidPeepholeWeightShape
};

// Cross-builder tests: each test creates exactly one operand on
// `otherBuilder` and all remaining operands on `builder`. It then expects
// `builder.lstmCell()` to throw a TypeError because of the foreign operand.

multi_builder_test(async (t, builder, otherBuilder) => {
  const inputFromOtherBuilder =
      otherBuilder.input('input', kExampleInputDescriptor);

  const weight = builder.input('weight', kExampleWeightDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
  const hiddenState =
      builder.input('hiddenState', kExampleHiddenStateDescriptor);
  const cellState = builder.input('cellState', kExampleCellStateDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.lstmCell(
          inputFromOtherBuilder, weight, recurrentWeight, hiddenState,
          cellState, hiddenSize));
}, '[lstmCell] throw if input is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const weightFromOtherBuilder =
      otherBuilder.input('weight', kExampleWeightDescriptor);

  const input = builder.input('input', kExampleInputDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
  const hiddenState =
      builder.input('hiddenState', kExampleHiddenStateDescriptor);
  const cellState = builder.input('cellState', kExampleCellStateDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.lstmCell(
          input, weightFromOtherBuilder, recurrentWeight, hiddenState,
          cellState, hiddenSize));
}, '[lstmCell] throw if weight is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const recurrentWeightFromOtherBuilder =
      otherBuilder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);

  const input = builder.input('input', kExampleInputDescriptor);
  const weight = builder.input('weight', kExampleWeightDescriptor);
  const hiddenState =
      builder.input('hiddenState', kExampleHiddenStateDescriptor);
  const cellState = builder.input('cellState', kExampleCellStateDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.lstmCell(
          input, weight, recurrentWeightFromOtherBuilder, hiddenState,
          cellState, hiddenSize));
}, '[lstmCell] throw if recurrentWeight is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const hiddenStateFromOtherBuilder =
      otherBuilder.input('hiddenState', kExampleHiddenStateDescriptor);

  const input = builder.input('input', kExampleInputDescriptor);
  const weight = builder.input('weight', kExampleWeightDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
  const cellState = builder.input('cellState', kExampleCellStateDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.lstmCell(
          input, weight, recurrentWeight, hiddenStateFromOtherBuilder,
          cellState, hiddenSize));
}, '[lstmCell] throw if hiddenState is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const cellStateFromOtherBuilder =
      otherBuilder.input('cellState', kExampleCellStateDescriptor);

  const input = builder.input('input', kExampleInputDescriptor);
  const weight = builder.input('weight', kExampleWeightDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
  const hiddenState =
      builder.input('hiddenState', kExampleHiddenStateDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.lstmCell(
          input, weight, recurrentWeight, hiddenState,
          cellStateFromOtherBuilder, hiddenSize));
}, '[lstmCell] throw if cellState is from another builder');

// Optional operands passed through `options` must also come from `builder`.
multi_builder_test(async (t, builder, otherBuilder) => {
  const biasFromOtherBuilder =
      otherBuilder.input('bias', kExampleBiasDescriptor);
  const options = {bias: biasFromOtherBuilder};

  const input = builder.input('input', kExampleInputDescriptor);
  const weight = builder.input('weight', kExampleWeightDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
  const hiddenState =
      builder.input('hiddenState', kExampleHiddenStateDescriptor);
  const cellState = builder.input('cellState', kExampleCellStateDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.lstmCell(
          input, weight, recurrentWeight, hiddenState, cellState, hiddenSize,
          options));
}, '[lstmCell] throw if bias option is from another builder');
multi_builder_test(async (t, builder, otherBuilder) => {
  // recurrentBias uses the same descriptor as bias. The foreign operand is
  // named after the option it is passed as.
  const recurrentBiasFromOtherBuilder =
      otherBuilder.input('recurrentBias', kExampleBiasDescriptor);
  const options = {recurrentBias: recurrentBiasFromOtherBuilder};

  const input = builder.input('input', kExampleInputDescriptor);
  const weight = builder.input('weight', kExampleWeightDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
  const hiddenState =
      builder.input('hiddenState', kExampleHiddenStateDescriptor);
  const cellState = builder.input('cellState', kExampleCellStateDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.lstmCell(
          input, weight, recurrentWeight, hiddenState, cellState, hiddenSize,
          options));
}, '[lstmCell] throw if recurrentBias option is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const peepholeWeightFromOtherBuilder =
      otherBuilder.input('peepholeWeight', kExamplePeepholeWeightDescriptor);
  const options = {peepholeWeight: peepholeWeightFromOtherBuilder};

  const input = builder.input('input', kExampleInputDescriptor);
  const weight = builder.input('weight', kExampleWeightDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
  const hiddenState =
      builder.input('hiddenState', kExampleHiddenStateDescriptor);
  const cellState = builder.input('cellState', kExampleCellStateDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.lstmCell(
          input, weight, recurrentWeight, hiddenState, cellState, hiddenSize,
          options));
}, '[lstmCell] throw if peepholeWeight option is from another builder');

// Table-driven cases. Entries with `outputs` are valid and list the expected
// output descriptors. Entries without `outputs` are expected to throw.
const tests = [
  {
    name: '[lstmCell] Test with default options',
    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight:
        {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize,
    outputs: [
      {dataType: 'float16', shape: [batchSize, hiddenSize]},
      {dataType: 'float16', shape: [batchSize, hiddenSize]}
    ]
  },
  {
    name: '[lstmCell] Test with given options',
    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize,
    options: {
      bias: {dataType: 'float32', shape: [4 * hiddenSize]},
      recurrentBias: {dataType: 'float32', shape: [4 * hiddenSize]},
      peepholeWeight: {dataType: 'float32', shape: [3 * hiddenSize]},
      layout: 'ifgo',
      activations: ['sigmoid', 'relu', 'tanh']
    },
    outputs: [
      {dataType: 'float32', shape: [batchSize, hiddenSize]},
      {dataType: 'float32', shape: [batchSize, hiddenSize]}
    ]
  },
  {
    name: '[lstmCell] Throw if hiddenSize is equal to zero',
    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    hiddenSize: 0
  },
  {
    name: '[lstmCell] Throw if hiddenSize is too large',
    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    hiddenSize: 4294967295
  },
  {
    name:
        '[lstmCell] Throw if the input data type is not one of the floating point types',
    input: {dataType: 'uint32', shape: [batchSize, inputSize]},
    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize
  },
  {
    name: '[lstmCell] Throw if the rank of input is not 2',
    input: {dataType: 'float32', shape: [batchSize]},
    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize
  },
  {
    name: '[lstmCell] Throw if the shape of input is incorrect',
    input: {dataType: 'float32', shape: [batchSize, 1000]},
    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize
  },
  {
    name: '[lstmCell] Throw if the data type of weight is incorrect',
    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize
  },
  {
    name: '[lstmCell] Throw if the rank of weight is not 2',
    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize, 1000]},
    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize
  },
  {
    name: '[lstmCell] Throw if the shape of weight is incorrect',
    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    weight: {dataType: 'float32', shape: [1000, inputSize]},
    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize
  },
  {
    name: '[lstmCell] Throw if the data type of recurrentWeight is incorrect',
    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize
  },
  {
    name: '[lstmCell] Throw if the rank of recurrentWeight is not 2',
    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight:
        {dataType: 'float32', shape: [1000, 4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize
  },
  {
    name: '[lstmCell] Throw if the shape of recurrentWeight is incorrect',
    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float32', shape: [1000, hiddenSize]},
    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize
  },
  {
    name: '[lstmCell] Throw if the data type of hiddenState is incorrect',
    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'int64', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize
  },
  {
    name: '[lstmCell] Throw if the rank of hiddenState is not 2',
    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float32', shape: [batchSize]},
    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize
  },
  {
    name: '[lstmCell] Throw if the shape of hiddenState is incorrect',
    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float32', shape: [batchSize, 1000]},
    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize
  },
  {
    name: '[lstmCell] Throw if the data type of cellState is incorrect',
    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize
  },
  {
    name: '[lstmCell] Throw if the rank of cellState is not 2',
    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float32', shape: [batchSize]},
    hiddenSize: hiddenSize
  },
  {
    name: '[lstmCell] Throw if the shape of cellState is incorrect',
    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float16', shape: [batchSize, 1000]},
    hiddenSize: hiddenSize
  },
  {
    name: '[lstmCell] Throw if the data type of options.bias is incorrect',
    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize,
    options: {bias: {dataType: 'int8', shape: [4 * hiddenSize]}}
  },
  {
    name: '[lstmCell] Throw if the rank of options.bias is not 1',
    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize,
    options: {bias: {dataType: 'float16', shape: [4 * hiddenSize, 1000]}}
  },
  {
    name: '[lstmCell] Throw if the shape of options.bias is incorrect',
    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize,
    options: {bias: {dataType: 'float16', shape: [1000]}}
  },
  {
    name:
        '[lstmCell] Throw if the data type of options.recurrentBias is incorrect',
    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize,
    options: {recurrentBias: {dataType: 'uint8', shape: [4 * hiddenSize]}}
  },
  {
    name: '[lstmCell] Throw if the rank of options.recurrentBias is not 1',
    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize,
    options:
        {recurrentBias: {dataType: 'float16', shape: [4 * hiddenSize, 1000]}}
  },
  {
    name: '[lstmCell] Throw if the shape of options.recurrentBias is incorrect',
    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize,
    options: {recurrentBias: {dataType: 'float16', shape: [1000]}}
  },
  {
    name:
        '[lstmCell] Throw if the data type of options.peepholeWeight is incorrect',
    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize,
    options: {peepholeWeight: {dataType: 'float32', shape: [3 * hiddenSize]}}
  },
  {
    name: '[lstmCell] Throw if the rank of options.peepholeWeight is not 1',
    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize,
    options: {peepholeWeight: {dataType: 'float16', shape: []}}
  },
  {
    name:
        '[lstmCell] Throw if the shape of options.peepholeWeight is incorrect',
    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize,
    options: {peepholeWeight: {dataType: 'float16', shape: [1000]}}
  },
  {
    name: '[lstmCell] Throw if the size of options.activations is not 3',
    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize,
    options: {activations: ['sigmoid', 'tanh', 'sigmoid', 'tanh']}
  }
];

// Registers one promise_test per table entry.
// - A valid case (one with `outputs`) whose input data type the context
//   supports is built, and its output descriptors are checked.
// - Every other case, including a valid case with an unsupported input data
//   type, must make lstmCell() throw. The thrown error is checked against the
//   `[label]` pattern via assert_throws_with_label.
tests.forEach(
    test => promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      const input = builder.input('input', test.input);
      const weight = builder.input('weight', test.weight);
      const recurrentWeight =
          builder.input('recurrentWeight', test.recurrentWeight);
      const hiddenState = builder.input('hiddenState', test.hiddenState);
      const cellState = builder.input('cellState', test.cellState);

      // Operand-valued options become graph inputs named after the option.
      // Plain-valued options (layout, activations) are copied through as-is.
      // Keys are inserted in the same order as the descriptors are listed.
      const options = {};
      if (test.options) {
        for (const name of ['bias', 'recurrentBias', 'peepholeWeight']) {
          if (test.options[name]) {
            options[name] = builder.input(name, test.options[name]);
          }
        }
        for (const name of ['layout', 'activations']) {
          if (test.options[name]) {
            options[name] = test.options[name];
          }
        }
      }

      if (test.outputs &&
          context.opSupportLimits().lstmCell.input.dataTypes.includes(
              test.input.dataType)) {
        const outputs = builder.lstmCell(
            input, weight, recurrentWeight, hiddenState, cellState,
            test.hiddenSize, options);
        assert_equals(outputs.length, test.outputs.length);
        for (let i = 0; i < outputs.length; ++i) {
          assert_equals(outputs[i].dataType, test.outputs[i].dataType);
          assert_array_equals(outputs[i].shape, test.outputs[i].shape);
        }
      } else {
        const label = 'lstm_cell_xxx';
        options.label = label;
        // Matches the literal text "[lstm_cell_xxx]".
        const regexp = new RegExp(`\\[${label}\\]`);
        assert_throws_with_label(
            () => builder.lstmCell(
                input, weight, recurrentWeight, hiddenState, cellState,
                test.hiddenSize, options),
            regexp);
      }
    }, test.name));