gru.https.any.js (14272B)
// META: title=validation tests for WebNN API gru operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

const steps = 2;
const batchSize = 3;
const inputSize = 4;
const hiddenSize = 5;
const oneDirection = 1;
const bothDirections = 2;

// Dimensions required of required inputs.
const kValidInputShape = [steps, batchSize, inputSize];
const kValidWeightShape = [oneDirection, 3 * hiddenSize, inputSize];
const kValidRecurrentWeightShape = [oneDirection, 3 * hiddenSize, hiddenSize];
// Dimensions required of optional inputs.
const kValidBiasShape = [oneDirection, 3 * hiddenSize];
const kValidRecurrentBiasShape = [oneDirection, 3 * hiddenSize];
const kValidInitialHiddenStateShape = [oneDirection, batchSize, hiddenSize];

// Example descriptors which are valid according to the above dimensions.
const kExampleInputDescriptor = {
  dataType: 'float32',
  shape: kValidInputShape
};
const kExampleWeightDescriptor = {
  dataType: 'float32',
  shape: kValidWeightShape
};
const kExampleRecurrentWeightDescriptor = {
  dataType: 'float32',
  shape: kValidRecurrentWeightShape
};
const kExampleBiasDescriptor = {
  dataType: 'float32',
  shape: kValidBiasShape
};
const kExampleRecurrentBiasDescriptor = {
  dataType: 'float32',
  shape: kValidRecurrentBiasShape
};
const kExampleInitialHiddenStateDescriptor = {
  dataType: 'float32',
  shape: kValidInitialHiddenStateShape
};

// Each test case describes one builder.gru() call:
//   name            - the promise_test name.
//   input, weight, recurrentWeight
//                   - operand descriptors passed to builder.input().
//   steps, hiddenSize
//                   - the positional gru() arguments.
//   options         - (optional) gru options; bias, recurrentBias and
//                     initialHiddenState are descriptors that the driver
//                     turns into builder inputs, the rest are forwarded as-is.
//   outputs         - (optional) expected output descriptors. When present,
//                     gru() is expected to succeed if the context supports
//                     the input data type; when absent, gru() is expected to
//                     throw.
const tests = [
  {
    name: '[gru] Test with default options',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: steps,
    hiddenSize: hiddenSize,
    outputs:
        [{dataType: 'float32', shape: [oneDirection, batchSize, hiddenSize]}]
  },
  {
    name: '[gru] Test with given options',
    input: kExampleInputDescriptor,
    weight: {
      dataType: 'float32',
      shape: [bothDirections, 3 * hiddenSize, inputSize]
    },
    recurrentWeight: {
      dataType: 'float32',
      shape: [bothDirections, 3 * hiddenSize, hiddenSize]
    },
    steps: steps,
    hiddenSize: hiddenSize,
    options: {
      bias: {dataType: 'float32', shape: [bothDirections, 3 * hiddenSize]},
      recurrentBias:
          {dataType: 'float32', shape: [bothDirections, 3 * hiddenSize]},
      initialHiddenState:
          {dataType: 'float32', shape: [bothDirections, batchSize, hiddenSize]},
      // Must be spelled `resetAfter`: the driver below only forwards that key,
      // so any other spelling would silently leave the option untested.
      resetAfter: true,
      returnSequence: true,
      direction: 'both',
      layout: 'rzn',
      activations: ['sigmoid', 'relu']
    },
    // With returnSequence the second output carries a leading `steps` axis.
    outputs: [
      {dataType: 'float32', shape: [bothDirections, batchSize, hiddenSize]}, {
        dataType: 'float32',
        shape: [steps, bothDirections, batchSize, hiddenSize]
      }
    ]
  },
  {
    name: '[gru] TypeError is expected if steps equals to zero',
    input: kExampleInputDescriptor,
    // Use otherwise-valid gru weights (3 * hiddenSize) so that the failure is
    // not caused by a malformed weight shape.
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: 0,
    hiddenSize: hiddenSize,
  },
  {
    name: '[gru] TypeError is expected if hiddenSize equals to zero',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: steps,
    hiddenSize: 0
  },
  {
    name: '[gru] TypeError is expected if hiddenSize is too large',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: steps,
    hiddenSize: 4294967295,
  },
  {
    name:
        '[gru] TypeError is expected if the data type of the inputs is not one of the floating point types',
    input: {dataType: 'uint32', shape: kValidInputShape},
    weight: {dataType: 'uint32', shape: kValidWeightShape},
    recurrentWeight: {dataType: 'uint32', shape: kValidRecurrentWeightShape},
    steps: steps,
    hiddenSize: hiddenSize
  },
  {
    name: '[gru] TypeError is expected if the rank of input is not 3',
    input: {dataType: 'float32', shape: [steps, batchSize]},
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: steps,
    hiddenSize: hiddenSize
  },
  {
    name: '[gru] TypeError is expected if input.shape[0] is not equal to steps',
    input: {dataType: 'float32', shape: [1000, batchSize, inputSize]},
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: steps,
    hiddenSize: hiddenSize
  },
  {
    name:
        '[gru] TypeError is expected if weight.shape[1] is not 3 * hiddenSize',
    input: kExampleInputDescriptor,
    weight:
        {dataType: 'float32', shape: [oneDirection, 4 * hiddenSize, inputSize]},
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: steps,
    hiddenSize: hiddenSize
  },
  {
    name: '[gru] TypeError is expected if the rank of recurrentWeight is not 3',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight:
        {dataType: 'float32', shape: [oneDirection, 3 * hiddenSize]},
    steps: steps,
    hiddenSize: hiddenSize
  },
  {
    name: '[gru] TypeError is expected if the recurrentWeight.shape is invalid',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight:
        {dataType: 'float32', shape: [oneDirection, 4 * hiddenSize, inputSize]},
    steps: steps,
    hiddenSize: hiddenSize
  },
  {
    name:
        '[gru] TypeError is expected if the size of options.activations is not 2',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: steps,
    hiddenSize: hiddenSize,
    options: {activations: ['sigmoid', 'tanh', 'relu']}
  },
  {
    name: '[gru] TypeError is expected if the rank of options.bias is not 2',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: steps,
    hiddenSize: hiddenSize,
    options: {bias: {dataType: 'float32', shape: [oneDirection]}}
  },
  {
    name:
        '[gru] TypeError is expected if options.bias.shape[1] is not 3 * hiddenSize',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    steps: steps,
    hiddenSize: hiddenSize,
    options: {bias: {dataType: 'float32', shape: [oneDirection, hiddenSize]}}
  },
  {
    name:
        '[gru] TypeError is expected if options.recurrentBias.shape[1] is not 3 * hiddenSize',
    input: {dataType: 'float16', shape: kValidInputShape},
    weight: {dataType: 'float16', shape: kValidWeightShape},
    recurrentWeight: {dataType: 'float16', shape: kValidRecurrentWeightShape},
    steps: steps,
    hiddenSize: hiddenSize,
    options: {
      recurrentBias:
          {dataType: 'float16', shape: [oneDirection, 4 * hiddenSize]}
    }
  },
  {
    name:
        '[gru] TypeError is expected if the rank of options.initialHiddenState is not 3',
    input: {dataType: 'float16', shape: kValidInputShape},
    weight: {dataType: 'float16', shape: kValidWeightShape},
    recurrentWeight: {dataType: 'float16', shape: kValidRecurrentWeightShape},
    steps: steps,
    hiddenSize: hiddenSize,
    options: {
      initialHiddenState:
          {dataType: 'float16', shape: [oneDirection, batchSize]}
    }
  },
  {
    // The required trailing dimension of initialHiddenState is hiddenSize
    // (see kValidInitialHiddenStateShape); here it is 3 * hiddenSize.
    name:
        '[gru] TypeError is expected if options.initialHiddenState.shape[2] is not hiddenSize',
    input: {dataType: 'float16', shape: kValidInputShape},
    weight: {dataType: 'float16', shape: kValidWeightShape},
    recurrentWeight: {dataType: 'float16', shape: kValidRecurrentWeightShape},
    steps: steps,
    hiddenSize: hiddenSize,
    options: {
      initialHiddenState: {
        dataType: 'float16',
        shape: [oneDirection, batchSize, 3 * hiddenSize]
      }
    }
  },
  {
    name:
        '[gru] TypeError is expected if the dataType of options.initialHiddenState is incorrect',
    input: {dataType: 'float16', shape: kValidInputShape},
    weight: {dataType: 'float16', shape: kValidWeightShape},
    recurrentWeight: {dataType: 'float16', shape: kValidRecurrentWeightShape},
    steps: steps,
    hiddenSize: hiddenSize,
    options: {
      initialHiddenState:
          {dataType: 'uint64', shape: [oneDirection, batchSize, hiddenSize]}
    }
  }
];

// Runs every case in `tests` as its own promise_test. Cases with `outputs`
// whose input data type is listed in opSupportLimits().gru.input.dataTypes
// must build successfully and produce the expected output descriptors; every
// other case must make builder.gru() throw.
tests.forEach(
    test => promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      const input = builder.input('input', test.input);
      const weight = builder.input('weight', test.weight);
      const recurrentWeight =
          builder.input('recurrentWeight', test.recurrentWeight);

      // Operand-valued options are created as inputs on this builder; the
      // remaining options are forwarded as plain values.
      const options = {};
      if (test.options) {
        if (test.options.bias) {
          options.bias = builder.input('bias', test.options.bias);
        }
        if (test.options.recurrentBias) {
          options.recurrentBias =
              builder.input('recurrentBias', test.options.recurrentBias);
        }
        if (test.options.initialHiddenState) {
          options.initialHiddenState = builder.input(
              'initialHiddenState', test.options.initialHiddenState);
        }
        if (test.options.resetAfter) {
          options.resetAfter = test.options.resetAfter;
        }
        if (test.options.returnSequence) {
          options.returnSequence = test.options.returnSequence;
        }
        if (test.options.direction) {
          options.direction = test.options.direction;
        }
        if (test.options.layout) {
          options.layout = test.options.layout;
        }
        if (test.options.activations) {
          options.activations = test.options.activations;
        }
      }

      if (test.outputs &&
          context.opSupportLimits().gru.input.dataTypes.includes(
              test.input.dataType)) {
        const outputs = builder.gru(
            input, weight, recurrentWeight, test.steps, test.hiddenSize,
            options);
        assert_equals(outputs.length, test.outputs.length);
        for (let i = 0; i < outputs.length; ++i) {
          assert_equals(outputs[i].dataType, test.outputs[i].dataType);
          assert_array_equals(outputs[i].shape, test.outputs[i].shape);
        }
      } else {
        // Either the case is invalid by construction, or the context does
        // not support the input data type. In both situations gru() is
        // expected to throw; the regexp matches the `[gru_xxx]` label that
        // the error message is expected to carry.
        const label = 'gru_xxx';
        options.label = label;
        const regexp = new RegExp('\\[' + label + '\\]');
        assert_throws_with_label(
            () => builder.gru(
                input, weight, recurrentWeight, test.steps, test.hiddenSize,
                options),
            regexp);
      }
    }, test.name));

multi_builder_test(async (t, builder, otherBuilder) => {
  const inputFromOtherBuilder =
      otherBuilder.input('input', kExampleInputDescriptor);

  const weight = builder.input('weight', kExampleWeightDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.gru(
          inputFromOtherBuilder, weight, recurrentWeight, steps, hiddenSize));
}, '[gru] throw if input is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const weightFromOtherBuilder =
      otherBuilder.input('weight', kExampleWeightDescriptor);

  const input = builder.input('input', kExampleInputDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.gru(
          input, weightFromOtherBuilder, recurrentWeight, steps, hiddenSize));
}, '[gru] throw if weight is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const recurrentWeightFromOtherBuilder =
      otherBuilder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);

  const input = builder.input('input', kExampleInputDescriptor);
  const weight = builder.input('weight', kExampleWeightDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.gru(
          input, weight, recurrentWeightFromOtherBuilder, steps, hiddenSize));
}, '[gru] throw if recurrentWeight is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const biasFromOtherBuilder =
      otherBuilder.input('bias', kExampleBiasDescriptor);
  const options = {bias: biasFromOtherBuilder};

  const input = builder.input('input', kExampleInputDescriptor);
  const weight = builder.input('weight', kExampleWeightDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.gru(
          input, weight, recurrentWeight, steps, hiddenSize, options));
}, '[gru] throw if bias option is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const recurrentBiasFromOtherBuilder =
      otherBuilder.input('recurrentBias', kExampleRecurrentBiasDescriptor);
  const options = {recurrentBias: recurrentBiasFromOtherBuilder};

  const input = builder.input('input', kExampleInputDescriptor);
  const weight = builder.input('weight', kExampleWeightDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.gru(
          input, weight, recurrentWeight, steps, hiddenSize, options));
}, '[gru] throw if recurrentBias option is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const initialHiddenStateFromOtherBuilder = otherBuilder.input(
      'initialHiddenState', kExampleInitialHiddenStateDescriptor);
  const options = {initialHiddenState: initialHiddenStateFromOtherBuilder};

  const input = builder.input('input', kExampleInputDescriptor);
  const weight = builder.input('weight', kExampleWeightDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.gru(
          input, weight, recurrentWeight, steps, hiddenSize, options));
}, '[gru] throw if initialHiddenState option is from another builder');