lstm.https.any.js (15605B)
// META: title=validation tests for WebNN API lstm operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// Dimensions used by the valid examples below. Cases that leave
// `direction` unset use numDirections = 1 shapes; the `direction: 'both'`
// case spells out its own numDirections = 2 shapes.
const steps = 10;
const batchSize = 5;
const inputSize = 3;
const hiddenSize = 8;
const numDirections = 1;

// Dimensions required of required inputs.
const kValidInputShape = [steps, batchSize, inputSize];
const kValidWeightShape = [numDirections, 4 * hiddenSize, inputSize];
const kValidRecurrentWeightShape = [numDirections, 4 * hiddenSize, hiddenSize];
// Dimensions required of optional inputs.
const kValidBiasShape = [numDirections, 4 * hiddenSize];
const kValidPeepholeWeightShape = [numDirections, 3 * hiddenSize];
const kValidInitialHiddenStateShape = [numDirections, batchSize, hiddenSize];

// Example descriptors which are valid according to the above dimensions.
const kExampleInputDescriptor = {
  dataType: 'float32',
  shape: kValidInputShape
};
const kExampleWeightDescriptor = {
  dataType: 'float32',
  shape: kValidWeightShape
};
const kExampleRecurrentWeightDescriptor = {
  dataType: 'float32',
  shape: kValidRecurrentWeightShape
};
const kExampleBiasDescriptor = {
  dataType: 'float32',
  shape: kValidBiasShape
};
const kExamplePeepholeWeightDescriptor = {
  dataType: 'float32',
  shape: kValidPeepholeWeightShape
};
const kExampleInitialHiddenStateDescriptor = {
  dataType: 'float32',
  shape: kValidInitialHiddenStateShape
};

// Each test case describes one builder.lstm() call:
//   - input / weight / recurrentWeight: operand descriptors for the required
//     operands; each becomes a graph input of the same name.
//   - steps / hiddenSize: passed straight through to builder.lstm().
//   - options (optional): lstm options in which the operand-valued members are
//     given as operand descriptors (the runner below turns them into inputs).
//   - outputs (optional): expected descriptors of the returned operands. A case
//     without `outputs`, or whose input data type the context does not list
//     for lstm, is expected to make builder.lstm() throw.
const tests = [
  {
    name: '[lstm] Test with default options',
    input: {dataType: 'float16', shape: kValidInputShape},
    weight: {dataType: 'float16', shape: kValidWeightShape},
    recurrentWeight: {dataType: 'float16', shape: kValidRecurrentWeightShape},
    steps: steps,
    hiddenSize: hiddenSize,
    outputs: [
      {dataType: 'float16', shape: [numDirections, batchSize, hiddenSize]},
      {dataType: 'float16', shape: [numDirections, batchSize, hiddenSize]}
    ]
  },
  {
    name: '[lstm] Test with given options',
    input: kExampleInputDescriptor,
    weight: {
      dataType: 'float32',
      shape: [/*numDirections=*/ 2, 4 * hiddenSize, inputSize]
    },
    recurrentWeight: {
      dataType: 'float32',
      shape: [/*numDirections=*/ 2, 4 * hiddenSize, hiddenSize]
    },
    steps: steps,
    hiddenSize: hiddenSize,
    options: {
      bias:
          {dataType: 'float32', shape: [/*numDirections=*/ 2, 4 * hiddenSize]},
      recurrentBias:
          {dataType: 'float32', shape: [/*numDirections=*/ 2, 4 * hiddenSize]},
      peepholeWeight:
          {dataType: 'float32', shape: [/*numDirections=*/ 2, 3 * hiddenSize]},
      initialHiddenState: {
        dataType: 'float32',
        shape: [/*numDirections=*/ 2, batchSize, hiddenSize]
      },
      initialCellState: {
        dataType: 'float32',
        shape: [/*numDirections=*/ 2, batchSize, hiddenSize]
      },
      returnSequence: true,
      direction: 'both',
      layout: 'ifgo',
      activations: ['sigmoid', 'relu', 'tanh']
    },
    outputs: [
      {
        dataType: 'float32',
        shape: [/*numDirections=*/ 2, batchSize, hiddenSize]
      },
      {
        dataType: 'float32',
        shape: [/*numDirections=*/ 2, batchSize, hiddenSize]
      },
      {
        dataType: 'float32',
        shape: [steps, /*numDirections=*/ 2, batchSize, hiddenSize]
      }
    ]
  },
  {
    name: '[lstm] TypeError is expected if hiddenSize equals to zero',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: steps,
    hiddenSize: 0
  },
  {
    name: '[lstm] TypeError is expected if hiddenSize is too large',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: steps,
    hiddenSize: 4294967295,
  },
  {
    name: '[lstm] TypeError is expected if steps equals to zero',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: 0,
    hiddenSize: hiddenSize,
  },
  {
    name:
        '[lstm] TypeError is expected if the data type is not one of the floating point types',
    input: {dataType: 'uint32', shape: kValidInputShape},
    weight: {dataType: 'uint32', shape: kValidWeightShape},
    recurrentWeight: {dataType: 'uint32', shape: kValidRecurrentWeightShape},
    steps: steps,
    hiddenSize: hiddenSize
  },
  {
    name: '[lstm] TypeError is expected if the rank of input is not 3',
    input: {dataType: 'float32', shape: [steps, batchSize]},
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: steps,
    hiddenSize: hiddenSize
  },
  {
    name:
        '[lstm] TypeError is expected if input.shape[0] is not equal to steps',
    input: {dataType: 'float32', shape: [1000, batchSize, inputSize]},
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: steps,
    hiddenSize: hiddenSize
  },
  {
    name: '[lstm] TypeError is expected if the shape of weight is incorrect',
    input: kExampleInputDescriptor,
    weight: {dataType: 'float32', shape: [numDirections, 4 * hiddenSize, 1000]},
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: steps,
    hiddenSize: hiddenSize
  },
  {
    name:
        '[lstm] TypeError is expected if the rank of recurrentWeight is not 3',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight:
        {dataType: 'float32', shape: [numDirections, 4 * hiddenSize]},
    steps: steps,
    hiddenSize: hiddenSize
  },
  {
    name:
        '[lstm] TypeError is expected if the size of options.activations is not 3',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: steps,
    hiddenSize: hiddenSize,
    options: {activations: ['sigmoid', 'tanh']}
  },
  {
    name: '[lstm] TypeError is expected if the rank of options.bias is not 2',
    input: {dataType: 'float16', shape: kValidInputShape},
    weight: {dataType: 'float16', shape: kValidWeightShape},
    recurrentWeight: {dataType: 'float16', shape: kValidRecurrentWeightShape},
    steps: steps,
    hiddenSize: hiddenSize,
    options: {bias: {dataType: 'float16', shape: [numDirections]}}
  },
  {
    name:
        '[lstm] TypeError is expected if the shape of options.recurrentBias.shape is incorrect',
    input: {dataType: 'float16', shape: kValidInputShape},
    weight: {dataType: 'float16', shape: kValidWeightShape},
    recurrentWeight: {dataType: 'float16', shape: kValidRecurrentWeightShape},
    steps: steps,
    hiddenSize: hiddenSize,
    options:
        {recurrentBias: {dataType: 'float16', shape: [numDirections, 1000]}}
  },
  {
    name:
        '[lstm] TypeError is expected if the dataType of options.peepholeWeight is incorrect',
    input: {dataType: 'float16', shape: kValidInputShape},
    weight: {dataType: 'float16', shape: kValidWeightShape},
    recurrentWeight: {dataType: 'float16', shape: kValidRecurrentWeightShape},
    steps: steps,
    hiddenSize: hiddenSize,
    options: {
      peepholeWeight:
          {dataType: 'float32', shape: [numDirections, 3 * hiddenSize]}
    }
  },
  {
    name:
        '[lstm] TypeError is expected if the dataType of options.initialHiddenState is incorrect',
    input: {dataType: 'float16', shape: kValidInputShape},
    weight: {dataType: 'float16', shape: kValidWeightShape},
    recurrentWeight: {dataType: 'float16', shape: kValidRecurrentWeightShape},
    steps: steps,
    hiddenSize: hiddenSize,
    options: {
      initialHiddenState:
          {dataType: 'uint64', shape: [numDirections, batchSize, hiddenSize]}
    }
  },
  {
    name:
        '[lstm] TypeError is expected if the shape of options.initialCellState is incorrect',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: steps,
    hiddenSize: hiddenSize,
    options: {
      initialCellState:
          {dataType: 'float32', shape: [numDirections, batchSize, 1000]}
    }
  }
];

// Registers one promise_test per case: builds the inputs, translates
// test.options, then either checks the returned output descriptors or
// asserts that builder.lstm() throws with the operator label in the error
// (assert_throws_with_label comes from utils_validation.js).
tests.forEach(
    test => promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      const input = builder.input('input', test.input);
      const weight = builder.input('weight', test.weight);
      const recurrentWeight =
          builder.input('recurrentWeight', test.recurrentWeight);

      // Translate test.options into lstm options. Operand-valued options are
      // written as operand descriptors and become graph inputs named after the
      // option; all other options are forwarded verbatim. Presence is checked
      // with `!== undefined` rather than truthiness so that falsy values a
      // case may want to exercise (e.g. `returnSequence: false`, or an invalid
      // empty-string enum) are not silently dropped.
      const options = {};
      if (test.options) {
        const operandOptionNames = [
          'bias', 'recurrentBias', 'peepholeWeight', 'initialHiddenState',
          'initialCellState'
        ];
        for (const name of operandOptionNames) {
          if (test.options[name] !== undefined) {
            options[name] = builder.input(name, test.options[name]);
          }
        }
        const plainOptionNames =
            ['returnSequence', 'direction', 'layout', 'activations'];
        for (const name of plainOptionNames) {
          if (test.options[name] !== undefined) {
            options[name] = test.options[name];
          }
        }
      }

      if (test.outputs &&
          context.opSupportLimits().lstm.input.dataTypes.includes(
              test.input.dataType)) {
        const outputs = builder.lstm(
            input, weight, recurrentWeight, test.steps, test.hiddenSize,
            options);
        assert_equals(outputs.length, test.outputs.length);
        for (let i = 0; i < outputs.length; ++i) {
          assert_equals(outputs[i].dataType, test.outputs[i].dataType);
          assert_array_equals(outputs[i].shape, test.outputs[i].shape);
        }
      } else {
        // The label is attached to the options so the thrown error can be
        // checked for it (matched as "[lstm_xxx]").
        const label = 'lstm_xxx';
        options.label = label;
        const regexp = new RegExp(`\\[${label}\\]`);
        assert_throws_with_label(
            () => builder.lstm(
                input, weight, recurrentWeight, test.steps, test.hiddenSize,
                options),
            regexp);
      }
    }, test.name));

// Each required operand passed to builder.lstm() that was created by a
// different MLGraphBuilder is expected to throw a TypeError.
multi_builder_test(async (t, builder, otherBuilder) => {
  const inputFromOtherBuilder =
      otherBuilder.input('input', kExampleInputDescriptor);
  const weight = builder.input('weight', kExampleWeightDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);

  assert_throws_js(
      TypeError,
      () => builder.lstm(
          inputFromOtherBuilder, weight, recurrentWeight, steps, hiddenSize));
}, '[lstm] throw if input is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const input = builder.input('input', kExampleInputDescriptor);
  const weightFromOtherBuilder =
      otherBuilder.input('weight', kExampleWeightDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);

  assert_throws_js(
      TypeError,
      () => builder.lstm(
          input, weightFromOtherBuilder, recurrentWeight, steps, hiddenSize));
}, '[lstm] throw if weight is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const input = builder.input('input', kExampleInputDescriptor);
  const weight = builder.input('weight', kExampleWeightDescriptor);
  const recurrentWeightFromOtherBuilder =
      otherBuilder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);

  assert_throws_js(
      TypeError,
      () => builder.lstm(
          input, weight, recurrentWeightFromOtherBuilder, steps, hiddenSize));
}, '[lstm] throw if recurrentWeight is from another builder');
// Cross-builder checks for the bias-like and peephole options: an operand
// created by a different MLGraphBuilder and passed as one of these options is
// expected to make builder.lstm() throw a TypeError. The foreign operand is
// named after the option it fills, like the other cross-builder tests here
// (the recurrentBias case previously reused the name 'bias').
for (const {optionName, descriptor} of [
       {optionName: 'bias', descriptor: kExampleBiasDescriptor},
       {optionName: 'recurrentBias', descriptor: kExampleBiasDescriptor},
       {
         optionName: 'peepholeWeight',
         descriptor: kExamplePeepholeWeightDescriptor
       },
     ]) {
  multi_builder_test(async (t, builder, otherBuilder) => {
    const optionFromOtherBuilder = otherBuilder.input(optionName, descriptor);
    const options = {[optionName]: optionFromOtherBuilder};

    const input = builder.input('input', kExampleInputDescriptor);
    const weight = builder.input('weight', kExampleWeightDescriptor);
    const recurrentWeight =
        builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    assert_throws_js(
        TypeError,
        () => builder.lstm(
            input, weight, recurrentWeight, steps, hiddenSize, options));
  }, `[lstm] throw if ${optionName} option is from another builder`);
}
// Cross-builder checks for the two state options: an initial hidden or cell
// state operand created by a different MLGraphBuilder is expected to make
// builder.lstm() throw a TypeError. Both states use the same example
// descriptor, kExampleInitialHiddenStateDescriptor.
for (const stateName of ['initialHiddenState', 'initialCellState']) {
  multi_builder_test(async (t, builder, otherBuilder) => {
    const foreignState =
        otherBuilder.input(stateName, kExampleInitialHiddenStateDescriptor);
    const lstmOptions = {[stateName]: foreignState};

    const input = builder.input('input', kExampleInputDescriptor);
    const weight = builder.input('weight', kExampleWeightDescriptor);
    const recurrentWeight =
        builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    assert_throws_js(
        TypeError,
        () => builder.lstm(
            input, weight, recurrentWeight, steps, hiddenSize, lstmOptions));
  }, `[lstm] throw if ${stateName} option is from another builder`);
}