utils_validation.js (20489B)
'use strict';

// https://www.w3.org/TR/webnn/#enumdef-mloperanddatatype
const allWebNNOperandDataTypes = [
  'float32',
  'float16',
  'int32',
  'uint32',
  'int64',
  'uint64',
  'int8',
  'uint8'
];

// https://webidl.spec.whatwg.org/#idl-unsigned-long
// The unsigned long type is an unsigned integer type that has values in the
// range [0, 4294967295].
// 4294967295 = 2 ** 32 - 1
const kMaxUnsignedLong = 2 ** 32 - 1;

const floatingPointTypes = ['float32', 'float16'];

const signedIntegerTypes = ['int32', 'int64', 'int8'];

const unsignedLongType = 'unsigned long';

const shape0D = [];
const shape1D = [2];
const shape2D = [2, 3];
const shape3D = [2, 3, 4];
const shape4D = [2, 3, 4, 5];
const shape5D = [2, 3, 4, 5, 6];

// Offsets added to a dimension size by generateUnbroadcastableShapes() to
// derive a different dimension size.
const adjustOffsetsArray = [
  // Decrease 1
  -1,
  // Increase 1
  1
];

// TODO
// Add more 5+ dimensions
const allWebNNShapesArray =
    [shape0D, shape1D, shape2D, shape3D, shape4D, shape5D];

// Values that are not integer numbers. validateOptionsAxes() expects each of
// them to be rejected with a TypeError when used as an options.axes element.
const notUnsignedLongAxisArray = [
  // String
  'abc',
  // BigInt
  BigInt(100),
  // Object
  {
    value: 1
  },
  // Array Object
  [0, 1],
  // Date Object
  new Date("2024-01-01"),
];

/**
 * Returns the rank (number of dimensions) of a shape.
 * @param {Array<number>} inputShape
 * @returns {number}
 */
function getRank(inputShape) {
  return inputShape.length;
}

/**
 * Returns every axis index of a shape, i.e. [0, 1, ..., rank - 1].
 * @param {Array<number>} inputShape
 * @returns {Array<number>}
 */
function getAxisArray(inputShape) {
  return Array.from({length: inputShape.length}, (_, i) => i);
}

/**
 * Returns axes sequences for `inputShape` in which two or more values are the
 * same. Returns an empty array for shapes of rank < 2.
 * @param {Array<number>} inputShape
 * @returns {Array<Array<number>>}
 */
function getAxesArrayContainSameValues(inputShape) {
  // TODO
  // Currently, for each axis this function only returns a sequence repeating
  // that axis twice and, when the rank is greater than 2, three times.
  // For example for axes [0, 1, 2] of a 3D input tensor
  // this function returns
  // [
  //   // two values are same
  //   [0, 0],
  //   [1, 1],
  //   [2, 2],
  //   // three values are same
  //   [0, 0, 0],
  //   [1, 1, 1],
  //   [2, 2, 2]
  // ]
  // (in the interleaved order [0, 0], [0, 0, 0], [1, 1], ...), while it
  // should return
  // [
  //   // two values are same
  //   [0, 0],
  //   [1, 1],
  //   [2, 2],
  //   [0, 0, 1],
  //   [0, 0, 2],
  //   [0, 1, 0],
  //   [0, 2, 0],
  //   [1, 0, 0],
  //   [2, 0, 0],
  //   [1, 1, 0],
  //   [1, 1, 2],
  //   [1, 0, 1],
  //   [1, 2, 1],
  //   [0, 1, 1],
  //   [2, 1, 1],
  //   [2, 2, 0],
  //   [2, 2, 1],
  //   [2, 0, 2],
  //   [2, 1, 2],
  //   [0, 2, 2],
  //   [1, 2, 2],
  //   // three (all) values are same
  //   [0, 0, 0],
  //   [1, 1, 1],
  //   [2, 2, 2]
  // ]
  const axesArrayContainSameValues = [];
  const length = inputShape.length;
  if (length >= 2) {
    const validAxesArrayFull = getAxisArray(inputShape);
    for (let index = 0; index < length; index++) {
      axesArrayContainSameValues.push(
          new Array(2).fill(validAxesArrayFull[index]));
      if (length > 2) {
        axesArrayContainSameValues.push(
            new Array(3).fill(validAxesArrayFull[index]));
      }
    }
  }
  return axesArrayContainSameValues;
}

/**
 * Returns some shapes that cannot be broadcast together with `shape`.
 *
 * For example, given the input shape [2, 3, 4] this function returns
 * [
 *   [3, 3, 4],
 *   [2, 2, 4],
 *   [2, 4, 4],
 *   [2, 3, 3],
 *   [2, 3, 5],
 *   [3],
 *   [5],
 *   [1, 3],
 *   [1, 5],
 *   [1, 1, 3],
 *   [1, 1, 5],
 *   [1, 1, 1, 3],
 *   [1, 1, 1, 5],
 * ]
 * @param {Array<number>} shape
 * @returns {Array<Array<number>>}
 * @throws {Error} if every dimension of `shape` is 1 (including the empty
 *     shape), since such a shape can always be broadcast.
 */
function generateUnbroadcastableShapes(shape) {
  if (shape.every(dimension => dimension === 1)) {
    throw new Error(`[${shape}] always can be broadcasted`);
  }
  const resultShapes = [];
  const length = shape.length;
  // When all leading dimensions are 1, the shapes produced by changing each
  // dimension in place are already produced by the last-dimension loop below.
  if (!shape.slice(0, length - 1).every(dimension => dimension === 1)) {
    for (let i = 0; i < length; i++) {
      if (shape[i] !== 1) {
        for (const offset of adjustOffsetsArray) {
          const shapeB = shape.slice();
          shapeB[i] += offset;
          // A dimension of size 1 is broadcastable, so skip it.
          if (shapeB[i] !== 1) {
            resultShapes.push(shapeB);
          }
        }
      }
    }
  }
  const lastDimensionSize = shape[length - 1];
  if (lastDimensionSize !== 1) {
    // Shapes of rank 1 .. length + 1 whose only non-1 (last) dimension
    // differs from the last dimension of `shape`.
    for (let j = 0; j <= length; j++) {
      // lastDimensionSize - 1 would be 1 (broadcastable) when it is 2.
      if (lastDimensionSize > 2) {
        resultShapes.push(Array(j).fill(1).concat([lastDimensionSize - 1]));
      }
      resultShapes.push(Array(j).fill(1).concat([lastDimensionSize + 1]));
    }
  }
  return resultShapes;
}

/**
 * Returns the two values just outside the range of the given IDL type.
 * @param {String} type - Only 'unsigned long' is supported.
 * @returns {Array<number>} [minimum - 1, maximum + 1]
 * @throws {Error} for any other type.
 */
function generateOutOfRangeValuesArray(type) {
  let range;
  switch (type) {
    case 'unsigned long':
      range = [0, kMaxUnsignedLong];
      break;
    default:
      throw new Error(`Unsupported ${type}`);
  }
  return [range[0] - 1, range[1] + 1];
}

/**
 * Normalizes an operation name or an array of operation names to an array.
 * @param {(String[]|String)} operationName
 * @returns {String[]}
 * @throws {Error} if `operationName` is neither a string nor an array.
 */
function getOperationNameArray(operationName) {
  if (typeof operationName === 'string') {
    return [operationName];
  }
  if (Array.isArray(operationName)) {
    return operationName;
  }
  throw new Error(`${operationName} should be an operation name string or an operation name string array`);
}

// Counters used to give builder inputs distinct names.
let inputIndex = 0;
let inputAIndex = 0;
let inputBIndex = 0;
// MLContext created in promise_setup; stays undefined without navigator.ml.
let context;

test(() => assert_not_equals(navigator.ml, undefined, "ml property is defined on navigator"));

promise_setup(async () => {
  if (navigator.ml === undefined) {
    return;
  }
  // The device type comes from `?device=<type>`, or from the bare query
  // string (e.g. `?gpu`).
  const deviceType = new URLSearchParams(location.search).get('device') ||
      location.search.substring(1);
  context = await navigator.ml.createContext({deviceType: deviceType});
}, {explicit_timeout: true});

/**
 * Asserts that calling `func` throws a TypeError whose message matches
 * `regrexp` (used to check that an operator's label appears in the message).
 * @param {Function} func - Callback expected to throw.
 * @param {RegExp} regrexp - Pattern the error message must match.
 */
function assert_throws_with_label(func, regrexp) {
  let threw = false;
  let error;
  try {
    func();
  } catch (e) {
    threw = true;
    error = e;
  }
  // Kept outside the try block so this failure is reported as is, instead of
  // being caught and misreported as an unexpected error name.
  if (!threw) {
    assert_unreached('Graph builder method unexpectedly succeeded');
  }
  assert_equals(error.name, 'TypeError');
  const error_message = error.message;
  assert_not_equals(error_message.match(regrexp), null);
}

/**
 * Checks that the binary operation `operationName` throws a labelled
 * TypeError, in both argument orders, when its inputs are not broadcastable.
 * @param {String} operationName
 * @param {String} label - Label passed in the options and expected in the
 *     error message.
 */
function validateTwoInputsBroadcastable(operationName, label) {
  if (navigator.ml === undefined) {
    return;
  }
  promise_test(async t => {
    const builder = new MLGraphBuilder(context);
    const options = {label};
    const regrexp = new RegExp('\\[' + label + '\\]');
    for (const dataType of allWebNNOperandDataTypes) {
      if (!context.opSupportLimits().input.dataTypes.includes(dataType)) {
        assert_throws_js(
            TypeError,
            () => builder.input(
                `inputA${++inputAIndex}`, {dataType, shape: shape1D}));
        continue;
      }
      for (const shape of allWebNNShapesArray) {
        // A 0D shape is skipped: it always broadcasts.
        if (shape.length > 0) {
          const inputA =
              builder.input(`inputA${++inputAIndex}`, {dataType, shape});
          for (const shapeB of generateUnbroadcastableShapes(shape)) {
            const inputB = builder.input(
                `inputB${++inputBIndex}`, {dataType, shape: shapeB});
            assert_equals(typeof builder[operationName], 'function');
            assert_throws_with_label(
                () => builder[operationName](inputA, inputB, options), regrexp);
            assert_throws_with_label(
                () => builder[operationName](inputB, inputA, options), regrexp);
          }
        }
      }
    }
  }, `[${operationName}] TypeError is expected if two inputs aren't broadcastable`);
}

/**
 * Checks that the binary operation `operationName` throws a labelled TypeError
 * when the broadcast output would exceed maxTensorByteLength: `a` is a float32
 * tensor of exactly maxTensorByteLength bytes ([maxTensorByteLength / 4, 1])
 * and `b` is [1, 5], so the broadcast output is 5 times larger.
 * @param {String} operationName
 * @param {String} label
 */
function validateTwoBroadcastableInputsTensorLimit(operationName, label) {
  if (navigator.ml === undefined) {
    return;
  }
  promise_test(async t => {
    const builder = new MLGraphBuilder(context);

    const a = builder.input('a', {dataType: 'float32',
      shape: [context.opSupportLimits().maxTensorByteLength / 4, 1]});
    const b = builder.input('b', {dataType: 'float32', shape: [1, 5] });

    const options = {label};
    const regrexp = new RegExp('\\[' + label + '\\]');
    assert_throws_with_label(
        () => builder[operationName](a, b, options), regrexp);
  }, `[${operationName}] throw if the output tensor byte length exceeds limit`);
}

/**
 * Checks that each given binary operation throws a labelled TypeError when its
 * two inputs have different data types.
 * @param {(String[]|String)} operationName - An operation name array or an
 *     operation name
 * @param {String} label
 */
function validateTwoInputsOfSameDataType(operationName, label) {
  if (navigator.ml === undefined) {
    return;
  }
  const operationNameArray = getOperationNameArray(operationName);
  for (const subOperationName of operationNameArray) {
    promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      const options = {label};
      const regrexp = new RegExp('\\[' + label + '\\]');
      for (const dataType of allWebNNOperandDataTypes) {
        if (!context.opSupportLimits().input.dataTypes.includes(dataType)) {
          assert_throws_js(
              TypeError,
              () => builder.input(
                  `inputA${++inputAIndex}`, {dataType, shape: shape1D}));
          continue;
        }
        for (const shape of allWebNNShapesArray) {
          const inputA =
              builder.input(`inputA${++inputAIndex}`, {dataType, shape});
          for (const dataTypeB of allWebNNOperandDataTypes) {
            if (!context.opSupportLimits().input.dataTypes.includes(
                    dataTypeB)) {
              // Explicit keys are required here: the shorthand
              // `{dataTypeB, shape1D}` yields a descriptor without
              // `dataType`/`shape`, so input() would throw for the wrong
              // reason.
              assert_throws_js(
                  TypeError,
                  () => builder.input(
                      `inputB${++inputBIndex}`,
                      {dataType: dataTypeB, shape: shape1D}));
              continue;
            }
            if (dataType !== dataTypeB) {
              const inputB = builder.input(
                  `inputB${++inputBIndex}`, {dataType: dataTypeB, shape});
              assert_equals(typeof builder[subOperationName], 'function');
              assert_throws_with_label(
                  () => builder[subOperationName](inputA, inputB, options),
                  regrexp);
            }
          }
        }
      }
    }, `[${subOperationName}] TypeError is expected if two inputs aren't of same data type`);
  }
}

/**
 * Validate options.axes by given operation and input rank for
 * argMin/Max / layerNormalization / Reduction operations
 * @param {(String[]|String)} operationName - An operation name array or an
 * operation name
 */
function validateOptionsAxes(operationName) {
  if (navigator.ml === undefined) {
    return;
  }
  const operationNameArray = getOperationNameArray(operationName);
  // [-1, 2 ** 32]: integers just outside the unsigned long range.
  const invalidAxisArray = generateOutOfRangeValuesArray(unsignedLongType);
  for (const subOperationName of operationNameArray) {
    // TypeError is expected if any of options.axes elements is not an unsigned long interger
    promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      for (const dataType of allWebNNOperandDataTypes) {
        if (!context.opSupportLimits().input.dataTypes.includes(dataType)) {
          assert_throws_js(
              TypeError,
              () => builder.input(
                  `inputA${++inputAIndex}`, {dataType, shape: shape1D}));
          continue;
        }
        for (const shape of allWebNNShapesArray) {
          const rank = getRank(shape);
          if (rank >= 1) {
            const input =
                builder.input(`input${++inputIndex}`, {dataType, shape});
            for (const invalidAxis of invalidAxisArray) {
              assert_equals(typeof builder[subOperationName], 'function');
              // Wrap in an array so the out-of-range value is tested as an
              // axes *element* rather than as a non-sequence axes value.
              assert_throws_js(
                  TypeError,
                  () => builder[subOperationName](input, {axes: [invalidAxis]}));
            }
            for (const axis of notUnsignedLongAxisArray) {
              assert_false(
                  typeof axis === 'number' && Number.isInteger(axis),
                  `[${subOperationName}] any of options.axes elements should be of 'unsigned long'`);
              assert_equals(typeof builder[subOperationName], 'function');
              assert_throws_js(
                  TypeError,
                  () => builder[subOperationName](input, {axes: [axis]}));
            }
          }
        }
      }
    }, `[${subOperationName}] TypeError is expected if any of options.axes elements is not an unsigned long interger`);

    // TypeError is expected if any of options.axes elements is greater or equal
    // to the size of input
    promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      for (const dataType of allWebNNOperandDataTypes) {
        if (!context.opSupportLimits().input.dataTypes.includes(dataType)) {
          assert_throws_js(
              TypeError,
              () => builder.input(
                  `inputA${++inputAIndex}`, {dataType, shape: shape1D}));
          continue;
        }
        for (const shape of allWebNNShapesArray) {
          const rank = getRank(shape);
          if (rank >= 1) {
            const input =
                builder.input(`input${++inputIndex}`, {dataType, shape});
            assert_equals(typeof builder[subOperationName], 'function');
            assert_throws_js(
                TypeError,
                () => builder[subOperationName](input, {axes: [rank]}));
            assert_throws_js(
                TypeError,
                () => builder[subOperationName](input, {axes: [rank + 1]}));
          }
        }
      }
    }, `[${subOperationName}] TypeError is expected if any of options.axes elements is greater or equal to the size of input`);

    // TypeError is expected if two or more values are same in the axes sequence
    promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      for (const dataType of allWebNNOperandDataTypes) {
        if (!context.opSupportLimits().input.dataTypes.includes(dataType)) {
          assert_throws_js(
              TypeError,
              () => builder.input(
                  `inputA${++inputAIndex}`, {dataType, shape: shape1D}));
          continue;
        }
        for (const shape of allWebNNShapesArray) {
          const rank = getRank(shape);
          if (rank >= 2) {
            const input =
                builder.input(`input${++inputIndex}`, {dataType, shape});
            const axesArrayContainSameValues =
                getAxesArrayContainSameValues(shape);
            for (const axes of axesArrayContainSameValues) {
              assert_equals(typeof builder[subOperationName], 'function');
              assert_throws_js(
                  TypeError, () => builder[subOperationName](input, {axes}));
            }
          }
        }
      }
    }, `[${subOperationName}] TypeError is expected if two or more values are same in the axes sequence`);
  }
}

// TODO: remove this method once all the data type limits of the unary
// operations are specified in context.opSupportLimits().
/**
 * Validate a unary operation
 * @param {String} operationName - An operation name
 * @param {Array} supportedDataTypes - Test building with these data types
 *     succeeds and test building with all other data types fails
 * @param {String} label - Label passed in the options of the failing calls and
 *     expected in the error message.
 */
function validateUnaryOperation(operationName, supportedDataTypes, label) {
  if (navigator.ml === undefined) {
    return;
  }
  promise_test(async t => {
    const builder = new MLGraphBuilder(context);
    for (const dataType of supportedDataTypes) {
      if (!context.opSupportLimits().input.dataTypes.includes(dataType)) {
        assert_throws_js(
            TypeError,
            () => builder.input(
                `inputA${++inputAIndex}`, {dataType, shape: shape1D}));
        continue;
      }
      for (const shape of allWebNNShapesArray) {
        const input = builder.input(`input`, {dataType, shape});
        assert_equals(typeof builder[operationName], 'function');
        const output = builder[operationName](input);
        assert_equals(output.dataType, dataType);
        assert_array_equals(output.shape, shape);
      }
    }
  }, `[${operationName}] Test building an unary operator with supported type.`);

  const unsupportedDataTypes =
      new Set(allWebNNOperandDataTypes).difference(new Set(supportedDataTypes));
  promise_test(async t => {
    const builder = new MLGraphBuilder(context);
    const options = {label};
    const regrexp = new RegExp('\\[' + label + '\\]');
    for (const dataType of unsupportedDataTypes) {
      if (!context.opSupportLimits().input.dataTypes.includes(dataType)) {
        assert_throws_js(
            TypeError,
            () => builder.input(
                `inputA${++inputAIndex}`, {dataType, shape: shape1D}));
        continue;
      }
      for (const shape of allWebNNShapesArray) {
        const input = builder.input(`input`, {dataType, shape});
        assert_equals(typeof builder[operationName], 'function');
        assert_throws_with_label(
            () => builder[operationName](input, options), regrexp);
      }
    }
  }, `[${operationName}] Throw if the dataType is not supported for an unary operator.`);
}

/**
 * Validate a single input operation against the input data types listed in
 * context.opSupportLimits()[operationName].
 * @param {String} operationName - An operation name
 * @param {String} label - Label passed in the options of the failing calls and
 *     expected in the error message.
 */
function validateSingleInputOperation(operationName, label) {
  if (navigator.ml === undefined) {
    return;
  }
  promise_test(async t => {
    const builder = new MLGraphBuilder(context);
    const supportedDataTypes =
        context.opSupportLimits()[operationName].input.dataTypes;
    for (const dataType of supportedDataTypes) {
      // Data types the context cannot use as graph inputs are skipped.
      if (!context.opSupportLimits().input.dataTypes.includes(dataType)) {
        continue;
      }
      for (const shape of allWebNNShapesArray) {
        const input = builder.input(`input`, {dataType, shape});
        assert_equals(typeof builder[operationName], 'function');
        const output = builder[operationName](input);
        assert_equals(output.dataType, dataType);
        assert_array_equals(output.shape, shape);
      }
    }
  }, `[${operationName}] Test building the operator with supported data type.`);

  promise_test(async t => {
    const builder = new MLGraphBuilder(context);
    const options = {label};
    const regrexp = new RegExp('\\[' + label + '\\]');
    const unsupportedDataTypes =
        new Set(allWebNNOperandDataTypes)
            .difference(new Set(
                context.opSupportLimits()[operationName].input.dataTypes));
    for (const dataType of unsupportedDataTypes) {
      if (!context.opSupportLimits().input.dataTypes.includes(dataType)) {
        assert_throws_js(
            TypeError,
            () => builder.input(
                `inputA${++inputAIndex}`, {dataType, shape: shape1D}));
        continue;
      }
      for (const shape of allWebNNShapesArray) {
        const input = builder.input(`input`, {dataType, shape});
        assert_equals(typeof builder[operationName], 'function');
        assert_throws_with_label(
            () => builder[operationName](input, options), regrexp);
      }
    }
  }, `[${operationName}] Throw if the data type is not supported for the operator.`);
}

/**
 * Basic test that the builder method specified by `operatorName` throws if
 * given an input from another builder. Operators which do not accept a float32
 * square 2D input should pass their own `operatorDescriptor`.
 * @param {String} operatorName
 * @param {{dataType: String, shape: Array<number>}} operatorDescriptor -
 *     Descriptor of the input created on the other builder.
 */
function validateInputFromAnotherBuilder(operatorName, operatorDescriptor = {
  dataType: 'float32',
  shape: [2, 2]
}) {
  multi_builder_test(async (t, builder, otherBuilder) => {
    const inputFromOtherBuilder =
        otherBuilder.input('input', operatorDescriptor);
    assert_equals(typeof builder[operatorName], 'function');
    assert_throws_js(
        TypeError, () => builder[operatorName](inputFromOtherBuilder));
  }, `[${operatorName}] throw if input is from another builder`);
}

/**
 * Basic test that the builder method specified by `operatorName` throws if one
 * of its inputs is from another builder. This helper may only be used by
 * operators which accept float32 square 2D inputs.
 * @param {String} operatorName
 */
function validateTwoInputsFromMultipleBuilders(operatorName) {
  const opDescriptor = {dataType: 'float32', shape: [2, 2]};

  multi_builder_test(async (t, builder, otherBuilder) => {
    const inputFromOtherBuilder = otherBuilder.input('other', opDescriptor);

    const input = builder.input('input', opDescriptor);
    assert_equals(typeof builder[operatorName], 'function');
    assert_throws_js(
        TypeError, () => builder[operatorName](inputFromOtherBuilder, input));
  }, `[${operatorName}] throw if first input is from another builder`);

  multi_builder_test(async (t, builder, otherBuilder) => {
    const inputFromOtherBuilder = otherBuilder.input('other', opDescriptor);

    const input = builder.input('input', opDescriptor);
    assert_equals(typeof builder[operatorName], 'function');
    assert_throws_js(
        TypeError, () => builder[operatorName](input, inputFromOtherBuilder));
  }, `[${operatorName}] throw if second input is from another builder`);
}

/**
 * Registers a promise_test that calls `func(t, builder, otherBuilder)` with
 * two distinct MLGraphBuilders created on the shared context.
 * @param {Function} func
 * @param {String} description - Test name.
 */
function multi_builder_test(func, description) {
  if (navigator.ml === undefined) {
    return;
  }
  promise_test(async t => {
    const builder = new MLGraphBuilder(context);
    const otherBuilder = new MLGraphBuilder(context);

    await func(t, builder, otherBuilder);
  }, description);
}