gruCell.https.any.js (15302B)
// META: title=validation tests for WebNN API gruCell operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// Sizes shared by every test case below.
const batchSize = 3;
const inputSize = 4;
const hiddenSize = 5;

// Shapes of the required operands.
const kValidInputShape = [batchSize, inputSize];
const kValidWeightShape = [3 * hiddenSize, inputSize];
const kValidRecurrentWeightShape = [3 * hiddenSize, hiddenSize];
const kValidHiddenStateShape = [batchSize, hiddenSize];
// Shapes of the optional operands (options.bias and options.recurrentBias).
const kValidBiasShape = [3 * hiddenSize];
const kValidRecurrentBiasShape = [3 * hiddenSize];
// Shape of the output expected from operands with the shapes above.
const kValidOutputShape = [batchSize, hiddenSize];

// Example float32 descriptors that are valid for the shapes above.
const kExampleInputDescriptor = {
  dataType: 'float32',
  shape: kValidInputShape
};
const kExampleWeightDescriptor = {
  dataType: 'float32',
  shape: kValidWeightShape
};
const kExampleRecurrentWeightDescriptor = {
  dataType: 'float32',
  shape: kValidRecurrentWeightShape
};
const kExampleHiddenStateDescriptor = {
  dataType: 'float32',
  shape: kValidHiddenStateShape
};
const kExampleBiasDescriptor = {
  dataType: 'float32',
  shape: kValidBiasShape
};
const kExampleRecurrentBiasDescriptor = {
  dataType: 'float32',
  shape: kValidRecurrentBiasShape
};
const kExampleOutputDescriptor = {
  dataType: 'float32',
  shape: kValidOutputShape
};

// Validation cases consumed by the tests.forEach() loop below. Fields:
//  - name: the test name.
//  - input, weight, recurrentWeight, hiddenState: operand descriptors that the
//    loop turns into graph inputs.
//  - hiddenSize: passed unchanged as the hiddenSize argument of gruCell().
//  - options (optional): gruCell() options. `bias` and `recurrentBias` are
//    operand descriptors (turned into graph inputs); `resetAfter`, `layout`
//    and `activations` are forwarded as-is.
//  - output (optional): the expected output descriptor. A case without it
//    expects gruCell() to throw.
// Each failure case changes one thing relative to the valid example
// descriptors so that it targets the check named in its title.
const tests = [
  {
    name: '[gruCell] Test with default options',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: hiddenSize,
    output: kExampleOutputDescriptor
  },
  {
    name: '[gruCell] Test with given options',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: hiddenSize,
    options: {
      bias: kExampleBiasDescriptor,
      recurrentBias: kExampleRecurrentBiasDescriptor,
      // Must be spelled `resetAfter`: the loop below only forwards that key.
      resetAfter: true,
      layout: 'rzn',
      activations: ['sigmoid', 'relu']
    },
    output: kExampleOutputDescriptor
  },
  {
    name: '[gruCell] Throw if hiddenSize equals to zero',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: 0
  },
  {
    name: '[gruCell] Throw if hiddenSize is too large',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: kExampleHiddenStateDescriptor,
    // 2^32 - 1, the largest unsigned 32-bit value.
    hiddenSize: 4294967295,
  },
  {
    name:
        '[gruCell] Throw if the data type of the inputs is not one of the floating point types',
    input: {dataType: 'uint32', shape: kValidInputShape},
    weight: {dataType: 'uint32', shape: kValidWeightShape},
    recurrentWeight: {dataType: 'uint32', shape: kValidRecurrentWeightShape},
    hiddenState: {dataType: 'uint32', shape: kValidHiddenStateShape},
    hiddenSize: hiddenSize
  },
  {
    name: '[gruCell] Throw if the rank of input is not 2',
    input: {dataType: 'float32', shape: [batchSize]},
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: hiddenSize
  },
  {
    name: '[gruCell] Throw if the input.shape[1] is incorrect',
    // Only shape[1] is wrong: it no longer matches weight.shape[1]
    // (inputSize), while shape[0] still matches hiddenState.shape[0].
    input: {dataType: 'float32', shape: [batchSize, 2 * inputSize]},
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: hiddenSize
  },
  {
    name:
        '[gruCell] Throw if data type of weight is not one of the floating point types',
    input: kExampleInputDescriptor,
    weight: {dataType: 'int8', shape: [3 * hiddenSize, inputSize]},
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: hiddenSize
  },
  {
    name: '[gruCell] Throw if rank of weight is not 2',
    input: kExampleInputDescriptor,
    weight: {dataType: 'float32', shape: [3 * hiddenSize]},
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: hiddenSize
  },
  {
    name: '[gruCell] Throw if weight.shape[0] is not 3 * hiddenSize',
    input: kExampleInputDescriptor,
    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: hiddenSize
  },
  {
    name:
        '[gruCell] Throw if data type of recurrentWeight is not one of the floating point types',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: {dataType: 'int32', shape: [3 * hiddenSize, hiddenSize]},
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: hiddenSize
  },
  {
    name: '[gruCell] Throw if the rank of recurrentWeight is not 2',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: {dataType: 'float32', shape: [3 * hiddenSize]},
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: hiddenSize
  },
  {
    name: '[gruCell] Throw if the recurrentWeight.shape is invalid',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: hiddenSize
  },
  {
    name:
        '[gruCell] Throw if data type of hiddenState is not one of the floating point types',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: {dataType: 'uint32', shape: [batchSize, hiddenSize]},
    hiddenSize: hiddenSize
  },
  {
    name: '[gruCell] Throw if the rank of hiddenState is not 2',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: {dataType: 'float32', shape: [hiddenSize]},
    hiddenSize: hiddenSize
  },
  {
    name: '[gruCell] Throw if the hiddenState.shape is invalid',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: {dataType: 'float32', shape: [batchSize, 3 * hiddenSize]},
    hiddenSize: hiddenSize
  },
  {
    name: '[gruCell] Throw if the size of options.activations is not 2',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: hiddenSize,
    options: {activations: ['sigmoid', 'tanh', 'relu']}
  },
  {
    name:
        '[gruCell] Throw if data type of options.bias is not one of the floating point types',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: hiddenSize,
    options: {bias: {dataType: 'uint8', shape: [3 * hiddenSize]}}
  },
  {
    name: '[gruCell] Throw if the rank of options.bias is not 1',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: hiddenSize,
    options: {bias: {dataType: 'float32', shape: [batchSize, 3 * hiddenSize]}}
  },
  {
    name: '[gruCell] Throw if options.bias.shape[0] is not 3 * hiddenSize',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: hiddenSize,
    options: {bias: {dataType: 'float32', shape: [2 * hiddenSize]}}
  },
  {
    name:
        '[gruCell] Throw if data type of options.recurrentBias is not one of the floating point types',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: hiddenSize,
    options: {recurrentBias: {dataType: 'int8', shape: [3 * hiddenSize]}}
  },
  {
    name: '[gruCell] Throw if the rank of options.recurrentBias is not 1',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: hiddenSize,
    options: {
      recurrentBias: {dataType: 'float32', shape: [batchSize, 3 * hiddenSize]}
    }
  },
  {
    name:
        '[gruCell] Throw if options.recurrentBias.shape[0] is not 3 * hiddenSize',
    input: kExampleInputDescriptor,
    weight: kExampleWeightDescriptor,
    recurrentWeight: kExampleRecurrentWeightDescriptor,
    hiddenState: kExampleHiddenStateDescriptor,
    hiddenSize: hiddenSize,
    // float32 like every other operand, so that the shape is the only defect.
    options: {recurrentBias: {dataType: 'float32', shape: [4 * hiddenSize]}}
  }
];

// Registers one promise_test per case, each using a fresh MLGraphBuilder.
// gruCell() is expected to succeed only when the case declares an `output`
// AND context.opSupportLimits() lists the input data type for gruCell; it then
// checks the output's dataType and shape. Every other case expects gruCell()
// to throw an error tied to the operator label.
tests.forEach(
    test =>
        promise_test(async t => {
          const builder = new MLGraphBuilder(context);
          const input = builder.input('input', test.input);
          const weight = builder.input('weight', test.weight);
          const recurrentWeight =
              builder.input('recurrentWeight', test.recurrentWeight);
          const hiddenState = builder.input('hiddenState', test.hiddenState);

          // Translate the case's option descriptors into real gruCell()
          // options; bias operands become graph inputs, the rest is copied.
          const options = {};
          if (test.options) {
            if (test.options.bias) {
              options.bias = builder.input('bias', test.options.bias);
            }
            if (test.options.recurrentBias) {
              options.recurrentBias =
                  builder.input('recurrentBias', test.options.recurrentBias);
            }
            if (test.options.resetAfter) {
              options.resetAfter = test.options.resetAfter;
            }
            if (test.options.layout) {
              options.layout = test.options.layout;
            }
            if (test.options.activations) {
              options.activations = test.options.activations;
            }
          }

          if (test.output &&
              context.opSupportLimits().gruCell.input.dataTypes.includes(
                  test.input.dataType)) {
            const output = builder.gruCell(
                input, weight, recurrentWeight, hiddenState, test.hiddenSize,
                options);
            assert_equals(output.dataType, test.output.dataType);
            assert_array_equals(output.shape, test.output.shape);
          } else {
            const label = 'gru_cell_xxx';
            options.label = label;
            // Matches the label wrapped in square brackets. NOTE(review):
            // assert_throws_with_label() is not defined in this file;
            // presumably it comes from ../resources/utils_validation.js and
            // checks the thrown error's message against this pattern.
            const regexp = new RegExp(`\\[${label}\\]`);
            assert_throws_with_label(
                () => builder.gruCell(
                    input, weight, recurrentWeight, hiddenState,
                    test.hiddenSize, options),
                regexp);
          }
        }, test.name));

// Every operand passed to gruCell(), including the optional bias operands,
// is created here on `builder` except one, which comes from `otherBuilder`;
// gruCell() is expected to reject that mix with a TypeError.
multi_builder_test(async (t, builder, otherBuilder) => {
  const inputFromOtherBuilder =
      otherBuilder.input('input', kExampleInputDescriptor);

  const weight = builder.input('weight', kExampleWeightDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
  const hiddenState =
      builder.input('hiddenState', kExampleHiddenStateDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.gruCell(
          inputFromOtherBuilder, weight, recurrentWeight, hiddenState,
          hiddenSize));
}, '[gruCell] throw if input is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const weightFromOtherBuilder =
      otherBuilder.input('weight', kExampleWeightDescriptor);

  const input = builder.input('input', kExampleInputDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
  const hiddenState =
      builder.input('hiddenState', kExampleHiddenStateDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.gruCell(
          input, weightFromOtherBuilder, recurrentWeight, hiddenState,
          hiddenSize));
}, '[gruCell] throw if weight is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const recurrentWeightFromOtherBuilder =
      otherBuilder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);

  const input = builder.input('input', kExampleInputDescriptor);
  const weight = builder.input('weight', kExampleWeightDescriptor);
  const hiddenState =
      builder.input('hiddenState', kExampleHiddenStateDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.gruCell(
          input, weight, recurrentWeightFromOtherBuilder, hiddenState,
          hiddenSize));
}, '[gruCell] throw if recurrentWeight is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const hiddenStateFromOtherBuilder =
      otherBuilder.input('hiddenState', kExampleHiddenStateDescriptor);

  const input = builder.input('input', kExampleInputDescriptor);
  const weight = builder.input('weight', kExampleWeightDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.gruCell(
          input, weight, recurrentWeight, hiddenStateFromOtherBuilder,
          hiddenSize));
}, '[gruCell] throw if hiddenState is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const biasFromOtherBuilder =
      otherBuilder.input('bias', kExampleBiasDescriptor);
  const options = {bias: biasFromOtherBuilder};

  const input = builder.input('input', kExampleInputDescriptor);
  const weight = builder.input('weight', kExampleWeightDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
  const hiddenState =
      builder.input('hiddenState', kExampleHiddenStateDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.gruCell(
          input, weight, recurrentWeight, hiddenState, hiddenSize, options));
}, '[gruCell] throw if bias option is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const recurrentBiasFromOtherBuilder =
      otherBuilder.input('recurrentBias', kExampleRecurrentBiasDescriptor);
  const options = {recurrentBias: recurrentBiasFromOtherBuilder};

  const input = builder.input('input', kExampleInputDescriptor);
  const weight = builder.input('weight', kExampleWeightDescriptor);
  const recurrentWeight =
      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
  const hiddenState =
      builder.input('hiddenState', kExampleHiddenStateDescriptor);
  assert_throws_js(
      TypeError,
      () => builder.gruCell(
          input, weight, recurrentWeight, hiddenState, hiddenSize, options));
}, '[gruCell] throw if recurrentBias option is from another builder');