utils.js (54276B)
'use strict';

// Per-operator ULP tolerances, keyed by the expected output data type.
// Operators without an entry for a data type contribute no tolerance.
const operatorToleranceDict = {
  argMax: {float32: 0, float16: 0},
  argMin: {float32: 0, float16: 0},
  batchNormalization: {float32: 6, float16: 6},
  clamp: {float32: 0, float16: 0},

  // Element-wise binary operations
  add: {float32: 1, float16: 1},
  sub: {
    float32: 1,
    float16: 1,
    int8: 0,
    uint8: 0,
    int32: 0,
    uint32: 0,
    int64: 0,
    uint64: 0
  },
  mul: {float32: 1, float16: 1},
  max: {float32: 0, float16: 0},
  min: {float32: 0, float16: 0},
  // End of element-wise binary operations

  elu: {float32: 18, float16: 18},
  gelu: {float32: 18, float16: 18},
  hardSigmoid: {float32: 2, float16: 2},
  hardSwish: {float32: 4, float16: 4},
  leakyRelu: {float32: 1, float16: 2},
  linear: {float32: 2, float16: 2},
  prelu: {float32: 1, float16: 1},
  relu: {float32: 0, float16: 0, int8: 0, int32: 0},
  sigmoid: {float32: 34, float16: 10},
  softplus: {float32: 18, float16: 18},
  softsign: {float32: 3, float16: 3},
  tanh: {float32: 16, float16: 16},
};

// Operators whose results must match exactly (0 ULP) for every data type.
const zeroULPToleranceOperatorList = [
  // data movement operators
  'concat', 'expand', 'gather', 'gatherElements', 'gatherND', 'identity', 'pad',
  'reshape', 'reverse', 'scatterElements', 'scatterND', 'slice', 'split',
  'tile', 'transpose',

  // element-wise logical operators
  'equal', 'notEqual', 'greater', 'greaterOrEqual', 'lesser', 'lesserOrEqual',
  'logicalNot', 'logicalAnd', 'logicalOr', 'logicalXor'
];

const getZeroULPTolerance = () => {
  return {metricType: 'ULP', value: 0};
};

/**
 * ULP tolerance for a softmax operator: 3 * (size of the softmax axis) + 3.
 *
 * The input operand is looked up first in the graph inputs, then among the
 * intermediate operands. The axis is read from the second argument when
 * present, otherwise 1.
 *
 * @param {Object} op - The operator entry ({name, arguments, outputs}).
 * @param {Object} graphResources - Resources used for building the graph.
 * @param {Object} intermediateOperands - Operands built so far, by name.
 * @returns {{metricType: String, value: (Number|undefined)}} `value` is only
 *     defined for float32 and float16 expected outputs.
 */
const getSoftmaxPrecisionTolerance =
    (op, graphResources, intermediateOperands) => {
      const {inputs} = graphResources;
      const args = op.arguments;
      const inputIndex = args[0][Object.keys(args[0])[0]];
      const inputShape = inputs[inputIndex] ?
          inputs[inputIndex].descriptor.shape :
          intermediateOperands[inputIndex].shape;
      const axis = args.length === 2 ? args[1][Object.keys(args[1])[0]] : 1;
      const tolerance = inputShape[axis] * 3 + 3;
      const toleranceValueDict = {float32: tolerance, float16: tolerance};
      const expectedDataType =
          getExpectedDataTypeOfSingleOutput(graphResources.expectedOutputs);
      return {metricType: 'ULP', value: toleranceValueDict[expectedDataType]};
    };

/**
 * Sum the ULP tolerances of all operators in a graph.
 *
 * Operators with a dedicated tolerance helper defer to it. Data-movement and
 * logical operators contribute 0. All others use their entry in
 * `operatorToleranceDict` for the expected output data type, or contribute
 * nothing when there is no entry.
 *
 * @param {Object} graphResources - Resources used for building the graph.
 * @param {Object} intermediateOperands - Operands built so far, by name.
 * @returns {{metricType: String, value: Number}}
 */
const getPrecisionTolerance = (graphResources, intermediateOperands) => {
  const expectedDataType =
      getExpectedDataTypeOfSingleOutput(graphResources.expectedOutputs);
  let toleranceValue = 0;
  graphResources.operators.forEach(op => {
    switch (op.name) {
      case 'conv2d':
      case 'convTranspose2d':
        // Both convolution variants share the conv2d tolerance model.
        toleranceValue += getConv2dPrecisionTolerance(
                              op, graphResources, intermediateOperands)
                              .value;
        break;
      case 'gemm':
        toleranceValue += getGemmPrecisionTolerance(
                              op, graphResources, intermediateOperands)
                              .value;
        break;
      case 'matmul':
        toleranceValue += getMatmulPrecisionTolerance(
                              op, graphResources, intermediateOperands)
                              .value;
        break;
      case 'softmax':
        toleranceValue += getSoftmaxPrecisionTolerance(
                              op, graphResources, intermediateOperands)
                              .value;
        break;
      case 'averagePool2d':
      case 'maxPool2d':
      case 'l2Pool2d':
        toleranceValue += getPoolingOperatorsPrecisionTolerance(
                              op, graphResources, intermediateOperands)
                              .value;
        break;
      case 'reduceL1':
      case 'reduceL2':
      case 'reduceLogSum':
      case 'reduceLogSumExp':
      case 'reduceMax':
      case 'reduceMean':
      case 'reduceMin':
      case 'reduceProduct':
      case 'reduceSum':
      case 'reduceSumSquare':
        toleranceValue += getReductionOperatorsPrecisionTolerance(
                              op, graphResources, intermediateOperands)
                              .value;
        break;
      case 'resample2d':
        toleranceValue += getResample2dPrecisionTolerance(
                              op, graphResources, intermediateOperands)
                              .value;
        break;
      default:
        if (zeroULPToleranceOperatorList.includes(op.name)) {
          toleranceValue += getZeroULPTolerance().value;
        } else {
          const operatorTolerance =
              operatorToleranceDict[op.name]?.[expectedDataType];
          if (operatorTolerance !== undefined) {
            toleranceValue += operatorTolerance;
          }
        }
    }
  });
  return {metricType: 'ULP', value: toleranceValue};
};

// https://www.w3.org/TR/webnn/#enumdef-mloperanddatatype
const TypedArrayDict = {
  float32: Float32Array,

  // Proposal to add float16 TypedArrays to JavaScript.
  // URL: https://tc39.es/proposal-float16array/
  // Use workaround Uint16 for Float16
  float16: Uint16Array,

  int64: BigInt64Array,
  uint64: BigUint64Array,
  int32: Int32Array,
  uint32: Uint32Array,
  int8: Int8Array,
  uint8: Uint8Array,
  // int4/uint4 values are packed two per byte.
  int4: Uint8Array,
  uint4: Uint8Array,
};

// Ordered from narrowest to widest; findCompatibleType only moves rightwards.
const kIntTypes =
    ['uint4', 'int4', 'uint8', 'int8', 'uint32', 'int32', 'uint64', 'int64'];
const kFloatTypes = ['float16', 'float32'];

/**
 * Find a type from `supportedTypes` that can carry values of `dataType`, so
 * that the data can be passed in that type and then cast back to `dataType`
 * with the `cast` operator.
 *
 * Integer types may only move to a later entry of `kIntTypes`, and float types
 * to a later entry of `kFloatTypes`.
 *
 * @param {String} dataType - The desired operand data type.
 * @param {Array<String>} supportedTypes - Candidate types, tried in order.
 * @param {Object} castOpSupportLimits - The `cast` entry of
 *     `MLContext.opSupportLimits()`.
 * @returns {(String|null)} The first compatible type, or null if none.
 */
const findCompatibleType = (dataType, supportedTypes, castOpSupportLimits) => {
  if (!castOpSupportLimits.input.dataTypes.includes(dataType)) {
    // Cannot cast from `dataType` to any other type.
    return null;
  }
  // The carried value is cast back to `dataType`, so `cast` must be able to
  // produce it. This applies to float types exactly as it does to int types.
  if (!castOpSupportLimits.output.dataTypes.includes(dataType)) {
    return null;
  }

  const isIntType = kIntTypes.includes(dataType);
  const isFloatType = kFloatTypes.includes(dataType);
  for (const supportedType of supportedTypes) {
    if (isIntType &&
        kIntTypes.indexOf(supportedType) > kIntTypes.indexOf(dataType)) {
      return supportedType;
    }
    if (isFloatType &&
        kFloatTypes.indexOf(supportedType) > kFloatTypes.indexOf(dataType)) {
      return supportedType;
    }
  }
  return null;
};

// The maximum index to validate for the output's expected value.
const kMaximumIndexToValidate = 1000;

const kContextOptionsForVariant = {
  cpu: {
    deviceType: 'cpu',
  },
  gpu: {
    deviceType: 'gpu',
  },
  npu: {
    deviceType: 'npu',
  },
};

// The variant comes from `?device=<variant>` or a bare `?<variant>` query.
// `location` is read through `globalThis` so these helpers can also be loaded
// where no `location` exists; in a browsing context nothing changes.
const searchParams = new URLSearchParams(globalThis.location?.search ?? '');
const variant = searchParams.get('device') ||
    (globalThis.location?.search ?? '').substring(1);
const contextOptions = kContextOptionsForVariant[variant];

/**
 * Create an MLContext with the options of the variant selected by the URL.
 * @returns {Promise<MLContext>}
 * @throws {AssertionError} If the context cannot be created.
 */
async function getContext() {
  try {
    return await navigator.ml.createContext(contextOptions);
  } catch (e) {
    throw new AssertionError(
        `Unable to create context for ${variant} variant. ${e}`);
  }
}

// Test case names selected with `?tc=<name>`; an empty list selects all.
const tcNameArray = searchParams.getAll('tc');

function isTargetTest(test) {
  return tcNameArray.length === 0 || tcNameArray.includes(test.name);
}

const assertDescriptorsEquals = (outputOperand, expected) => {
  const dataType =
      expected.castedType ? expected.castedType : expected.dataType;
  assert_equals(
      outputOperand.dataType, dataType,
      'actual output dataType should be equal to expected output dataType');
  assert_array_equals(
      outputOperand.shape, expected.shape,
      'actual output shape should be equal to expected output shape');
};

// ref:
// http://stackoverflow.com/questions/32633585/how-do-you-convert-to-half-floats-in-javascript
/**
 * Convert a Number to float16 bits, returned as an unsigned 16-bit integer.
 * The value is first stored as float32, then its bits are rounded to half
 * precision.
 * @param {Number} value
 * @returns {Number} The float16 bit pattern.
 */
const toHalf = (value) => {
  const floatView = new Float32Array(1);
  const int32View = new Int32Array(floatView.buffer);

  /* This method is faster than the OpenEXR implementation (very often
   * used, eg. in Ogre), with the additional benefit of rounding, inspired
   * by James Tursa's half-precision code. */

  floatView[0] = value;
  const x = int32View[0];

  let bits = (x >> 16) & 0x8000; /* Get the sign */
  let m = (x >> 12) & 0x07ff; /* Keep one extra bit for rounding */
  const e = (x >> 23) & 0xff; /* Using int is faster here */

  /* If zero, or denormal, or exponent underflows too much for a denormal
   * half, return signed zero. */
  if (e < 103) {
    return bits;
  }

  /* If NaN, return NaN. If Inf or exponent overflow, return Inf. */
  if (e > 142) {
    bits |= 0x7c00;
    /* If exponent was 0xff and one mantissa bit was set, it means NaN,
     * not Inf, so make sure we set one mantissa bit too. */
    if (e === 255 && (x & 0x007fffff)) {
      bits |= 1;
    }
    return bits;
  }

  /* If exponent underflows but not too much, return a denormal */
  if (e < 113) {
    m |= 0x0800;
    /* Extra rounding may overflow and set mantissa to 0 and exponent
     * to 1, which is OK. */
    bits |= (m >> (114 - e)) + ((m >> (113 - e)) & 1);
    return bits;
  }

  bits |= ((e - 112) << 10) | (m >> 1);
  /* Extra rounding. An overflow will set mantissa to 0 and increment
   * the exponent, which is OK. */
  bits += m & 1;
  return bits;
};

/**
 * Convert test data into the typed array that carries values of `type`.
 *
 * Storage per type:
 * - float16: bit patterns in a Uint16Array (see toHalf).
 * - int64/uint64: BigInts.
 * - int4/uint4: packed two per byte into a Uint8Array.
 *
 * @param {String} type - An MLOperandDataType string.
 * @param {Number} size - Number of elements of the operand (product of its
 *     shape).
 * @param {(Number|Array<Number>)} data - The element values, or a single
 *     Number that is broadcast to all `size` elements.
 * @returns {ArrayBufferView}
 */
const getTypedArrayData = (type, size, data) => {
  const TypedArray = TypedArrayDict[type];
  // A Number fills every element. This must also hold for size <= 1 (e.g.
  // shape []): passing the Number on as an array-like would produce an array
  // of the wrong length.
  const isScalar = typeof (data) === 'number';

  if (type === 'uint4' || type === 'int4') {
    // The first nybble is stored in the first bits 0-3, and later bits 4-7
    // store the later nybble. The data is packed, without any padding between
    // dimensions. For example: an array of uint4:
    //   size = [2,5]
    //   values = [1,2,3,4,5,6,7,8,9,10]
    // Would yield 5 hex bytes:
    //   Uint8Array.of(0x21, 0x43, 0x65, 0x87, 0xA9);
    // An odd size leaves the high nybble of the last byte as 0.
    const packed = new TypedArray(Math.ceil(size / 2));
    for (let i = 0; i < size; i++) {
      const value = isScalar ? data : data[i];
      packed[i >> 1] |= (value & 0xF) << ((i & 1) << 2);
    }
    return packed;
  }

  if (type === 'float16' || type === 'int64' || type === 'uint64') {
    // float16 uses Uint16 bits as a workaround; 64-bit integer arrays need
    // BigInt elements.
    const convert = type === 'float16' ? toHalf : BigInt;
    return isScalar ? new TypedArray(size).fill(convert(data)) :
                      TypedArray.from(data, (value) => convert(value));
  }

  return isScalar ? new TypedArray(size).fill(data) : new TypedArray(data);
};

const sizeOfShape = (array) => {
  return array.reduce((accumulator, currentValue) => accumulator * currentValue, 1);
};

/**
 * Get bitwise of the given value.
 * @param {Number} value
 * @param {String} dataType - A data type string; currently only "float32" is
 *     supported by this function.
 * @return {BigInt} A 64-bit signed integer: the float32 bits of |value|,
 *     negated when value is negative.
 */
const getBitwise = (value, dataType) => {
  const buffer = new ArrayBuffer(8);
  const int64Array = new BigInt64Array(buffer);
  let typedArray;
  if (dataType === "float32") {
    typedArray = new Float32Array(buffer);
  } else {
    throw new AssertionError(`Data type ${dataType} is not supported`);
  }
  typedArray[0] = Math.abs(value);
  const int64 = int64Array[0];
  return (value < 0) ? -int64 : int64;
};

/**
 * Assert that each array property in ``actual`` is a number being close enough
 * to the corresponding property in ``expected`` by the acceptable ULP distance
 * ``nulp`` with given ``dataType`` data type.
 *
 * @param {Array} actual - Array of test values.
 * @param {Array} expected - Array of values expected to be close to the values
 *     in ``actual``.
 * @param {(Number|BigInt)} nulp - A value indicates acceptable ULP distance.
 * @param {String} dataType - A data type string, value: "float32",
 *     more types, please see:
 *     https://www.w3.org/TR/webnn/#enumdef-mloperanddatatype
 * @param {String} description - Description of the condition being tested.
 */
const assert_array_approx_equals_ulp = (actual, expected, nulp, dataType, description) => {
  /*
   * Test if two primitive arrays are equal within acceptable ULP distance
   */
  assert_equals(
      actual.length, expected.length,
      `assert_array_approx_equals_ulp: ${description} lengths differ`);
  for (let i = 0; i < actual.length; i++) {
    if (actual[i] === expected[i]) {
      continue;
    }
    const distance = ulpDistance(actual[i], expected[i], dataType);

    // TODO: See if callers can be updated to pass matching type.
    // ulpDistance returns a BigInt for some data types; compare like with like.
    const tolerance =
        typeof distance === 'bigint' ? BigInt(nulp) : Number(nulp);

    assert_less_than_equal(distance, tolerance,
        `assert_array_approx_equals_ulp: ${description} actual ` +
            `${
                dataType === 'float16' ?
                    float16AsUint16ToNumber(actual[i]) :
                    actual[i]} should be close enough to expected ` +
            `${expected[i]} by ULP distance:`);
  }
};

/**
 * Compute the ULP distance between ``a`` and ``b`` for the given ``dataType``.
 *
 * For float16, ``a`` is expected to already hold float16 bits (as stored in a
 * Uint16Array) while ``b`` is a Number.
 *
 * @param {(Number|BigInt)} a - First value.
 * @param {(Number|BigInt)} b - Second value.
 * @param {String} dataType - A data type string, value: "float32",
 *     more types, please see:
 *     https://www.w3.org/TR/webnn/#enumdef-mloperanddatatype
 * @returns {(Number|BigInt)} The non-negative distance.
 */
const ulpDistance = (a, b, dataType) => {
  let aBitwise, bBitwise;
  // measure the ULP distance
  if (dataType === 'float32') {
    aBitwise = getBitwise(a, dataType);
    bBitwise = getBitwise(b, dataType);
  } else if (dataType === 'float16') {
    // float16 bits are sign-magnitude. Map them onto a signed integer line, as
    // getBitwise does for float32, so that:
    // - values on opposite sides of zero are measured correctly;
    // - -0.0 (0x8000) and +0.0 are 0 ULP apart.
    // `0 - x` (rather than `-x`) avoids producing -0.
    const toSignedOrdinal = (bits) =>
        (bits & 0x8000) ? 0 - (bits & 0x7FFF) : bits;
    aBitwise = toSignedOrdinal(a);
    // convert b data of Float16 to Uint16
    bBitwise = toSignedOrdinal(toHalf(b));
  } else if (dataType === 'int64' || dataType === 'uint64') {
    aBitwise = BigInt(a);
    bBitwise = BigInt(b);
  } else if (
      dataType === 'int8' || dataType === 'uint8' || dataType === 'int32' ||
      dataType === 'uint32' || dataType === 'int4' || dataType === 'uint4') {
    aBitwise = a;
    bBitwise = b;
  } else {
    throw new AssertionError(`Data type ${dataType} is not supported`);
  }
  const distance = aBitwise - bBitwise;
  return distance >= 0 ? distance : -distance;
};

/**
 * This function converts a Float16 stored as the bits of a Uint16 into a
 * JavaScript Number.
 * @param {Number} uint16 - a Float16 stored as the bits of a Uint16
 * @returns An emulated Float16 number.
 */
function float16AsUint16ToNumber(uint16) {
  const sign = (uint16 >> 15) & 0x1;
  const exponent = (uint16 >> 10) & 0x1F;
  const mantissa = uint16 & 0x3FF;
  let float16;

  if (exponent === 0) {
    // Subnormal number
    float16 = (mantissa / 1024) * Math.pow(2, -14);
  } else if (exponent === 0x1F) {
    // NaN or Infinity
    float16 = mantissa ? NaN : Infinity;
  } else {
    // Normalized number
    float16 = (1 + mantissa / 1024) * Math.pow(2, exponent - 15);
  }

  // Apply the sign
  return sign ? -float16 : float16;
}

/**
 * Assert actual results with expected results.
 * @param {String} operatorName
 * @param {(Number[]|Number)} actual
 * @param {(Number[]|Number)} expected
 * @param {String} metricType - Value: 'ULP', 'ATOL'
 * @param {Number} toleranceValue
 * @param {String} dataType - An operand type string, value: "float32",
 *     more types, please see:
 *     https://www.w3.org/TR/webnn/#enumdef-mloperanddatatype
 */
const doAssert =
    (operatorName, actual, expected, metricType, toleranceValue, dataType) => {
      const description = `test ${operatorName} ${dataType}`;
      if (typeof expected === 'number') {
        expected = [expected];
        actual = [actual];
      }
      if (metricType === 'ULP') {
        assert_array_approx_equals_ulp(
            actual, expected, toleranceValue, dataType, description);
      } else if (metricType === 'ATOL') {
        let actualData;
        if (dataType === 'float16') {
          // workaround for float16
          actualData = new Array(actual.length);
          actual.forEach(
              (x, index) => actualData[index] = float16AsUint16ToNumber(x));
        } else {
          actualData = actual;
        }
        assert_array_approx_equals(
            actualData, expected, toleranceValue, description);
      } else {
        throw new AssertionError(
            `Tolerance Metric type '${metricType}' is not supported`);
      }
    };

/**
 * Assert computed results be equal to expected data.
 * @param {Object} toleranceFunc
 * @param {Map<String, ArrayBufferView> |
 *     Array[Map<String, ArrayBufferView>]} actual
 * @param {Object} graphResources - Resources used for building a graph
 * @param {Object} intermediateOperands - Operands built so far, by name;
 *     forwarded to `toleranceFunc`.
 */
const assertResultsEquals =
    (toleranceFunc, actual, graphResources, intermediateOperands) => {
      const operatorName =
          graphResources.operators.map(operator => operator.name).join(' ');
      const expectedOutputs = graphResources.expectedOutputs;
      const toleranceInfo = toleranceFunc(graphResources, intermediateOperands);
      const metricType = toleranceInfo.metricType;
      const toleranceValue = toleranceInfo.value;

      for (const operandName in actual) {
        const expectedSuboutput = expectedOutputs[operandName];
        const expectedDescriptor = expectedSuboutput.descriptor;
        const outputDataType = expectedDescriptor.dataType;
        let expectedData = expectedSuboutput.data;
        let outputData = actual[operandName];

        if (outputDataType === 'uint4' || outputDataType === 'int4') {
          // Unpacking happens first so that a scalar expected value (below)
          // is compared against element values, not packed bytes.
          //
          // The int4/uint4 data were packed in Uint8Array. The first nybble
          // and later nybble of one int8/uint8 value store two consecutive
          // 4-bits values separately. After unpacking, each int4 value is
          // stored in an element of Int8Array, and each uint4 value in an
          // element of Uint8Array.
          //
          // For example: an array of uint4:
          //   size = [1, 5]
          //   Uint8Array.of(0x21, 0x43, 0x65, 0x87, 0xA9)
          // Would yield 5 * 2 uint4 data:
          //   Uint8Array.of(1,2,3,4,5,6,7,8,9,10);
          // Another example: an array of int4:
          //   size = [1, 5]
          //   Uint8Array.of(0xA9, 0xCB, 0xED, 0x0F, 0x21)
          // Would yield 5 * 2 int4 data:
          //   Int8Array.of(-7, -6, -5, -4, -3, -2, -1, 0, 1, 2);
          const elementCount = sizeOfShape(expectedDescriptor.shape);
          const unpacked = outputDataType === 'uint4' ?
              new Uint8Array(elementCount) :
              new Int8Array(elementCount);
          const signMask = outputDataType === 'int4' ? 0x08 : 0x00;
          for (let i = 0; i < elementCount; i++) {
            let value = (outputData[i >> 1] >> ((i & 1) << 2)) & 0xF;
            // Handle the negative numbers (sign-extend int4).
            if (value & signMask) {
              value |= 0xF0;
            }
            unpacked[i] = value;
          }
          outputData = unpacked;
        }

        // If data is scalar and shape is not, it means it's expecting to be
        // filled by the scalar value. Also limit the array size so it doesn't
        // timeout.
        if (typeof (expectedData) === 'number' && expectedDescriptor.shape &&
            sizeOfShape(expectedDescriptor.shape) > 1) {
          const size = Math.min(
              kMaximumIndexToValidate, sizeOfShape(expectedDescriptor.shape));
          expectedData = new Array(size).fill(expectedData);
          outputData = outputData.subarray(0, kMaximumIndexToValidate);
        }

        doAssert(
            operatorName, outputData, expectedData, metricType, toleranceValue,
            outputDataType);
      }
    };

/**
 * Create a graph input or constant operand from test resources.
 *
 * If the data type is not supported for inputs/constants, the data is passed
 * in a compatible type (see findCompatibleType) and cast back to the original
 * type. In that case `resources.descriptor` is modified in place:
 * `castedType` and `dataType` are set to the compatible type.
 *
 * @param {MLContext} context
 * @param {MLGraphBuilder} builder
 * @param {String} operandName
 * @param {Object} resources - {descriptor, data, constant?}.
 * @returns {MLOperand}
 */
const createOperand = (context, builder, operandName, resources) => {
  const descriptor = resources.descriptor;
  const dataType = descriptor.dataType;
  const supportLimits = context.opSupportLimits();

  const supportedDataTypes = resources.constant ?
      supportLimits.constant.dataTypes :
      supportLimits.input.dataTypes;

  // If input data type is not supported on current platform, attempt to use
  // a supported type to pass the data, then cast back to original type.
  if (!supportedDataTypes.includes(dataType)) {
    const compatibleType = findCompatibleType(
        dataType, supportedDataTypes, supportLimits.cast);
    if (compatibleType) {
      // prepareInputsForGraph reads `castedType` to create the input tensor
      // in the same type.
      descriptor.castedType = compatibleType;
      descriptor.dataType = compatibleType;
    }
  }

  let operand = resources.constant ?
      builder.constant(
          descriptor,
          getTypedArrayData(
              descriptor.dataType, sizeOfShape(descriptor.shape),
              resources.data)) :
      builder.input(operandName, descriptor);

  if (descriptor.castedType) {
    operand = builder.cast(operand, dataType);
  }

  return operand;
};

/**
 * Create inputs or outputs tensor.
 * @param {MLContext} context - the context used to create inputs or outputs
 *     tensor.
 * @param {String} dataType - dataType of inputs / outputs operands
 * @param {Array} shape - dimensions of inputs / outputs operands
 * @param {Object} [data] - optional data for inputs tensor; when given the
 *     tensor is created writable and filled with it, otherwise readable.
 * @returns {MLTensor}
 */
async function createTensorWithData(context, dataType, shape, data) {
  const tensorDesc = {dataType, shape};
  if (data) {
    tensorDesc.writable = true;
  } else {
    tensorDesc.readable = true;
  }
  const tensor = await context.createTensor(tensorDesc);
  if (data) {
    context.writeTensor(tensor, data);
  }
  return tensor;
}

/**
 * Create a writable tensor, filled with its data, for every non-constant
 * graph input.
 * @returns {Promise<Object>} Tensors keyed by input operand name.
 */
async function prepareInputsForGraph(context, resources) {
  const inputOperandNameArray = Object.keys(resources).filter(
      operandName => !resources[operandName].constant);
  const tensors = await Promise.all(inputOperandNameArray.map((operandName) => {
    const inputOperandResources = resources[operandName];
    const descriptor = inputOperandResources.descriptor;
    const targetDataType =
        descriptor.castedType ? descriptor.castedType : descriptor.dataType;
    const inputBuffer = getTypedArrayData(
        targetDataType, sizeOfShape(descriptor.shape),
        inputOperandResources.data);
    return createTensorWithData(
        context, targetDataType, descriptor.shape, inputBuffer);
  }));

  const inputs = {};
  inputOperandNameArray.forEach((name, index) => inputs[name] = tensors[index]);
  return inputs;
}

/**
 * Create a readable tensor for every expected output.
 * @returns {Promise<Object>} Tensors keyed by output operand name, in the
 *     order of Object.keys(resources).
 */
async function prepareOutputsForGraph(context, resources) {
  const outputOperandNameArray = Object.keys(resources);
  const tensors =
      await Promise.all(outputOperandNameArray.map((operandName) => {
        const descriptor = resources[operandName].descriptor;
        const dataType =
            descriptor.castedType ? descriptor.castedType : descriptor.dataType;
        return createTensorWithData(context, dataType, descriptor.shape);
      }));

  const outputs = {};
  outputOperandNameArray.forEach(
      (name, index) => outputs[name] = tensors[index]);
  return outputs;
}

/**
 * Find the value of argument `operandName` in an operator's argument list,
 * looking at top-level arguments and inside an `options` argument.
 * @returns {*} The argument value, or null when it is not present.
 */
function getInputName(operatorArguments, operandName) {
  for (const argument of operatorArguments) {
    const name = Object.keys(argument)[0];
    if (name === operandName) {
      return argument[operandName];
    } else if (name === 'options') {
      if (Object.keys(argument[name]).includes(operandName)) {
        return argument[name][operandName];
      }
    }
  }
  return null;
}

// This assert() function checks whether a test case is configured correctly.
694 function assert(condition, message) { 695 if (!condition) { 696 throw new Error(`Wrong test case, ${message}`); 697 } 698 } 699 700 function validateContextSupportsGraph(context, graph) { 701 const supportLimits = context.opSupportLimits(); 702 const castOpSupportLimits = supportLimits.cast; 703 const inputDataTypes = supportLimits.input.dataTypes; 704 const inputRankRange = supportLimits.input.rankRange; 705 const constantDataTypes = supportLimits.constant.dataTypes; 706 const constantRankRange = supportLimits.constant.rankRange; 707 const outputDataTypes = supportLimits.output.dataTypes; 708 const outputRankRange = supportLimits.output.rankRange; 709 710 function validateInputOrConstantDataTypeAndRank( 711 inputName, operatorSupportLimits, operand) { 712 const inputDescriptor = graph.inputs[inputName].descriptor; 713 const inputDataType = inputDescriptor.dataType; 714 const inputRank = inputDescriptor.shape.length; 715 if (inputDescriptor.constant) { 716 // Check graph constant data type 717 if (!constantDataTypes.includes(inputDataType) && 718 !findCompatibleType( 719 inputDataType, constantDataTypes, castOpSupportLimits)) { 720 throw new TypeError( 721 `Unsupported data type, constant '${operand}' data type ${ 722 inputDataType} must be one of [${constantDataTypes}].`); 723 } 724 725 // Check graph constant rank 726 if (inputRank < constantRankRange.min) { 727 throw new TypeError(`Unsupported rank ${inputRank} for constant '${ 728 operand}' (must be at least ${constantRankRange.min}).`); 729 } 730 if (inputRank > constantRankRange.max) { 731 throw new TypeError(`Unsupported rank ${inputRank} for constant '${ 732 operand}' (must be at most ${constantRankRange.max}).`); 733 } 734 } else { 735 // Check graph input data type 736 if (!inputDataTypes.includes(inputDataType) && 737 !findCompatibleType( 738 inputDataType, inputDataTypes, castOpSupportLimits)) { 739 throw new TypeError( 740 `Unsupported data type, input '${operand}' data type ${ 741 inputDataType} 
must be one of [${inputDataTypes}].`); 742 } 743 744 // Check graph input rank 745 if (inputRank < inputRankRange.min) { 746 throw new TypeError(`Unsupported rank ${inputRank} for input '${ 747 operand}' (must be at least ${inputRankRange.min}).`); 748 } 749 if (inputRank > inputRankRange.max) { 750 throw new TypeError(`Unsupported rank ${inputRank} for input '${ 751 operand}' (must be at most ${inputRankRange.max}).`); 752 } 753 } 754 755 const operandSupportLimits = operatorSupportLimits[operand]; 756 // Check operand data type 757 const inputOperandDataTypes = operandSupportLimits.dataTypes; 758 if (!inputOperandDataTypes.includes(inputDataType) && 759 !findCompatibleType( 760 inputDataType, inputDataTypes, castOpSupportLimits)) { 761 throw new TypeError( 762 `Unsupported data type, input '${operand}' data type ${ 763 inputDataType} must be one of [${inputOperandDataTypes}].`); 764 } 765 766 // Check operand rank 767 const limitsRankRange = operandSupportLimits.rankRange; 768 if (inputRank < limitsRankRange.min) { 769 throw new TypeError(`Unsupported rank ${inputRank} for argument ${ 770 operand} (must be at least ${limitsRankRange.min}).`); 771 } 772 773 if (inputRank > limitsRankRange.max) { 774 throw new TypeError(`Unsupported rank ${inputRank} for argument ${ 775 operand} (must be at most ${limitsRankRange.max}).`); 776 } 777 } 778 779 function validateOutputDataTypeAndRank( 780 outputName, operatorSupportLimits, operand) { 781 const outputDataType = 782 graph.expectedOutputs[outputName].descriptor.dataType; 783 const outputRank = 784 graph.expectedOutputs[outputName].descriptor.shape.length; 785 // Check graph output data type 786 if (!outputDataTypes.includes(outputDataType) && 787 !findCompatibleType( 788 outputDataType, outputDataTypes, castOpSupportLimits)) { 789 throw new TypeError( 790 `Unsupported data type, output '${operand}' data type ${ 791 outputDataType} must be one of [${outputDataTypes}].`); 792 } 793 794 // Check graph output rank 795 if 
(outputRank < outputRankRange.min) { 796 throw new TypeError(`Unsupported rank ${outputRank} for output '${ 797 operand}' (must be at least ${outputRankRange.min}).`); 798 } 799 if (outputRank > outputRankRange.max) { 800 throw new TypeError(`Unsupported rank ${outputRank} for output '${ 801 operand}' (must be at most ${outputRankRange.max}).`); 802 } 803 804 // Check output operand data type 805 const outputOperandDataTypes = operatorSupportLimits[operand].dataTypes; 806 if (!outputOperandDataTypes.includes(outputDataType) && 807 !findCompatibleType( 808 outputOperandDataTypes, outputDataTypes, castOpSupportLimits)) { 809 throw new TypeError( 810 `Unsupported data type, output '${operand}' data type ${ 811 outputDataType} must be one of [${outputOperandDataTypes}].`); 812 } 813 814 // Check output operand rank 815 const outputOperandRankRange = operatorSupportLimits[operand].rankRange; 816 if (outputRank < outputOperandRankRange.min) { 817 throw new TypeError(`Unsupported rank ${outputRank} for output '${ 818 operand}' (must be at least ${outputOperandRankRange.min}).`); 819 } 820 if (outputRank > outputOperandRankRange.max) { 821 throw new TypeError(`Unsupported rank ${outputRank} for output '${ 822 operand}' (must be at most ${outputOperandRankRange.max}).`); 823 } 824 } 825 826 try { 827 for (let operator of graph.operators) { 828 const operatorName = operator.name; 829 const operatorSupportLimits = supportLimits[operatorName]; 830 for (let operand of Object.keys(operatorSupportLimits)) { 831 if (operand === 'output') { 832 // single output operand 833 assert( 834 typeof operator.outputs === 'string', 835 `the outputs of ${operatorName} should be a string.`); 836 if (!graph.expectedOutputs[operator.outputs]) { 837 // intermediate output 838 continue; 839 } 840 validateOutputDataTypeAndRank( 841 operator.outputs, operatorSupportLimits, 'output'); 842 } else if (operand === 'outputs') { 843 // multiple output operands of split operator 844 assert( 845 
Array.isArray(operator.outputs), 846 `the outputs of ${operatorName} should be a string array.`); 847 for (const outputName of operator.outputs) { 848 assert( 849 typeof outputName === 'string', 850 `the outputs' item of ${operatorName} should be a string.`); 851 if (!graph.expectedOutputs[outputName]) { 852 // intermediate output 853 continue; 854 } 855 validateOutputDataTypeAndRank( 856 outputName, operatorSupportLimits, 'outputs'); 857 } 858 } else if (/output[0-2]/.test(operand)) { 859 // multiple output operands of gru/lstm/lstmCell operators 860 assert( 861 Array.isArray(operator.outputs), 862 `the outputs of ${operatorName} should be a string array.`); 863 const index = parseInt(operand.match(/output([0-2])/)[1]); 864 if (index < operator.outputs.length) { 865 validateOutputDataTypeAndRank( 866 operator.outputs[index], operatorSupportLimits, operand); 867 } 868 } else { 869 // input operand(s) 870 if (operatorName === 'concat') { 871 const inputNameArray = operator.arguments[0][operand]; 872 assert( 873 Array.isArray(inputNameArray), 874 `the inputs of ${operatorName} should be a string array.`); 875 for (const inputName of inputNameArray) { 876 assert( 877 typeof inputName === 'string', 878 `the inputs' item of ${operatorName} should be a string.`); 879 if (!graph.inputs[inputName]) { 880 // intermediate input 881 continue; 882 } 883 validateInputOrConstantDataTypeAndRank( 884 inputName, operatorSupportLimits, 'inputs'); 885 } 886 } else { 887 const inputName = getInputName(operator.arguments, operand); 888 if (inputName === null || !graph.inputs[inputName]) { 889 // default options argument or intermediate input 890 continue; 891 } 892 validateInputOrConstantDataTypeAndRank( 893 inputName, operatorSupportLimits, operand); 894 } 895 } 896 } 897 } 898 return /*supported*/ true; 899 } catch (error) { 900 return /*not supported*/ false; 901 } 902 } 903 904 /** 905 * This function is to execute the compiled graph. 
906 * @param {MLContext} context 907 * @param {MLGraph} graph 908 * @param {Map<String, { 909 * data: Array.<Number>|Number, 910 * descriptor: MLOperandDescriptor, 911 * constant?: Boolean 912 * }>} graphInputs 913 * @param {Map<String, { 914 * data: Array.<Number>|Number, 915 * descriptor: MLOperandDescriptor, 916 * }>} expectedOutputs 917 * @returns A result object. 918 */ 919 async function computeGraph(context, graph, graphInputs, expectedOutputs) { 920 const inputs = await prepareInputsForGraph(context, graphInputs); 921 const outputs = await prepareOutputsForGraph(context, expectedOutputs); 922 923 // Execute the compiled graph. 924 context.dispatch(graph, inputs, outputs); 925 926 const result = {}; 927 const outputNameArray = Object.keys(expectedOutputs); 928 const outputBuffers = await Promise.all(Object.values(outputs).map( 929 (tensor) => {return context.readTensor(tensor)})); 930 outputNameArray.forEach((name, index) => { 931 const dataType = expectedOutputs[name].descriptor.castedType ? 932 expectedOutputs[name].descriptor.castedType : 933 expectedOutputs[name].descriptor.dataType; 934 result[name] = new TypedArrayDict[dataType](outputBuffers[index]) 935 }); 936 937 return result; 938 } 939 940 /** 941 * This function is to compile and execute the constructed graph. 942 * @param {MLContext} context 943 * @param {MLGraphBuilder} builder 944 * @param {{ 945 * inputs: Map<String, { 946 * data: Array.<Number>|Number, 947 * descriptor: MLOperandDescriptor, 948 * constant?: Boolean 949 * }>, 950 * operators: Array.<{ 951 * name: String, 952 * arguments: Array.<Map<String, Object>> , 953 * outputs: Array.<String>|String 954 * }>, 955 * expectedOutputs: Map<String, { 956 * data: Array.<Number>|Number, 957 * descriptor: MLOperandDescriptor, 958 * }> 959 * }} graphResources - Resources used for building a graph 960 * @returns A Promise of MLComputeResult. 
961 */ 962 const buildAndExecuteGraph = async (context, builder, graphResources) => { 963 const outputOperands = []; 964 const graphInputs = graphResources.inputs; 965 const graphOperators = graphResources.operators; 966 const intermediateOperands = {}; 967 for (const operator of graphOperators) { 968 const argumentArray = []; 969 for (const argument of operator.arguments) { 970 for (const argumentName in argument) { 971 if (argumentName !== 'options') { 972 if (operator.name === 'concat' && argumentName === 'inputs') { 973 const concatInputs = []; 974 for (const inputName of argument[argumentName]) { 975 if (graphInputs.hasOwnProperty(inputName)) { 976 const operandName = inputName; 977 const operand = createOperand( 978 context, builder, operandName, graphInputs[operandName]); 979 concatInputs.push(operand); 980 } else if (intermediateOperands.hasOwnProperty(inputName)) { 981 concatInputs.push(intermediateOperands[inputName]); 982 } 983 // concatInputs.push(intermediateOperands[inputName]); 984 } 985 argumentArray.push(concatInputs); 986 } else if (graphInputs.hasOwnProperty(argument[argumentName])) { 987 const operandName = argument[argumentName]; 988 const operand = createOperand( 989 context, builder, operandName, graphInputs[operandName]); 990 argumentArray.push(operand); 991 } else if (intermediateOperands.hasOwnProperty( 992 argument[argumentName])) { 993 argumentArray.push(intermediateOperands[argument[argumentName]]); 994 } else { 995 argumentArray.push(argument[argumentName]); 996 } 997 } else { 998 for (const [optionalArgumentName, value] of Object.entries( 999 argument['options'])) { 1000 if (typeof value === 'string' && 1001 graphInputs.hasOwnProperty(value)) { 1002 const operandName = value; 1003 const operand = createOperand( 1004 context, builder, operandName, graphInputs[operandName]); 1005 argument['options'][optionalArgumentName] = operand; 1006 } else if ( 1007 typeof value === 'string' && 1008 intermediateOperands.hasOwnProperty(value)) { 1009 
argument['options'][optionalArgumentName] = 1010 intermediateOperands[value]; 1011 } 1012 } 1013 argumentArray.push(argument['options']); 1014 } 1015 } 1016 } 1017 1018 const currentOutput = builder[operator.name](...argumentArray); 1019 if (Array.isArray(operator.outputs)) { 1020 operator.outputs.forEach((outputName, index) => { 1021 intermediateOperands[outputName] = currentOutput[index]; 1022 }); 1023 } else { 1024 intermediateOperands[operator.outputs] = currentOutput; 1025 } 1026 } 1027 1028 const outputNames = Object.keys(graphResources.expectedOutputs); 1029 outputNames.forEach(outputName => { 1030 if (intermediateOperands.hasOwnProperty(outputName)) { 1031 outputOperands.push(intermediateOperands[outputName]); 1032 } 1033 }); 1034 1035 if (outputOperands.length !== outputNames.length) { 1036 throw new Error('Graph outputs are not properly defined'); 1037 } 1038 1039 for (let i = 0; i < outputOperands.length; ++i) { 1040 const expectedDescriptor = 1041 graphResources 1042 .expectedOutputs[Object.keys(graphResources.expectedOutputs)[i]] 1043 .descriptor; 1044 if (!context.opSupportLimits().output.dataTypes.includes( 1045 expectedDescriptor.dataType)) { 1046 const compatibleType = findCompatibleType( 1047 expectedDescriptor.dataType, 1048 context.opSupportLimits().output.dataTypes, 1049 context.opSupportLimits().cast); 1050 outputOperands[i] = builder.cast(outputOperands[i], compatibleType); 1051 expectedDescriptor.castedType = compatibleType; 1052 } 1053 } 1054 1055 const outputNameArray = Object.keys(graphResources.expectedOutputs); 1056 for (let i = 0; i < outputOperands.length; ++i) { 1057 assertDescriptorsEquals( 1058 outputOperands[i], 1059 graphResources.expectedOutputs[outputNameArray[i]].descriptor); 1060 } 1061 1062 const namedOutputOperand = {}; 1063 outputNameArray.forEach( 1064 (name, index) => namedOutputOperand[name] = outputOperands[index]); 1065 1066 // Compile the constructed graph. 
1067 const graph = await builder.build(namedOutputOperand); 1068 1069 // Execute the compiled graph. 1070 const result = await computeGraph( 1071 context, graph, graphInputs, graphResources.expectedOutputs); 1072 1073 return {result, intermediateOperands}; 1074 }; 1075 1076 const getGemmPrecisionTolerance = 1077 (op, graphResources, intermediateOperands) => { 1078 // GEMM : alpha * (A x B) + beta * C 1079 // An upper bound for the worst serial ordering is bounded by 1080 // the number of lossy operations, where matrix multiplication 1081 // is a dot product (mul and add times the number of elements) 1082 // plus bias operations. 1083 const {inputs} = graphResources; 1084 const args = op.arguments; 1085 let ShapeA; 1086 const indexA = args[0][Object.keys(args[0])[0]]; 1087 if (inputs[indexA]) { 1088 ShapeA = inputs[indexA].descriptor.shape; 1089 } else { 1090 ShapeA = intermediateOperands[indexA].shape; 1091 } 1092 const options = 1093 args.length === 3 ? {...args[2][Object.keys(args[2])[0]]} : {}; 1094 const width = options.aTranspose ? 
ShapeA[0] : ShapeA[1]; 1095 let tolerance = width * 2; 1096 // default options.alpha is 1.0 1097 if (options.alpha !== undefined && options.alpha !== 1.0) { 1098 tolerance++; 1099 } 1100 if (options.c && options.beta !== 0.0) { 1101 // default options.beta is 1.0 1102 if (options.beta !== undefined && options.beta !== 1.0) { 1103 tolerance++; 1104 } 1105 tolerance++; 1106 } 1107 1108 const toleranceValueDict = {float32: tolerance, float16: tolerance}; 1109 const expectedDataType = 1110 getExpectedDataTypeOfSingleOutput(graphResources.expectedOutputs); 1111 return {metricType: 'ULP', value: toleranceValueDict[expectedDataType]}; 1112 }; 1113 1114 const getMatmulPrecisionTolerance = 1115 (op, graphResources, intermediateOperands) => { 1116 const {inputs} = graphResources; 1117 const args = op.arguments; 1118 let shapeA; 1119 const indexA = args[0][Object.keys(args[0])[0]]; 1120 if (inputs[indexA]) { 1121 shapeA = inputs[indexA].descriptor.shape; 1122 } else { 1123 shapeA = intermediateOperands[indexA].shape; 1124 } 1125 const tolerance = shapeA[shapeA.length - 1] * 2; 1126 const toleranceValueDict = {float32: tolerance, float16: tolerance}; 1127 const expectedDataType = 1128 getExpectedDataTypeOfSingleOutput(graphResources.expectedOutputs); 1129 return {metricType: 'ULP', value: toleranceValueDict[expectedDataType]}; 1130 }; 1131 1132 const getConv2dPrecisionTolerance = 1133 (op, graphResources, intermediateOperands) => { 1134 // number of reduced input elements multiplied by filter and summed (a sliding 1135 // dot product like pooling) 1136 const {inputs} = graphResources; 1137 const operatorName = op.name; 1138 const args = op.arguments; 1139 let inputShape; 1140 const inputIndex = args[0][Object.keys(args[0])[0]]; 1141 const filterIndex = args[1][Object.keys(args[1])[0]]; 1142 if (inputs[inputIndex]) { 1143 inputShape = inputs[inputIndex].descriptor.shape; 1144 } else { 1145 inputShape = intermediateOperands[inputIndex].shape; 1146 } 1147 let filterShape; 1148 if 
(inputs[filterIndex]) { 1149 filterShape = inputs[filterIndex].descriptor.shape; 1150 } else { 1151 filterShape = intermediateOperands[filterIndex].shape; 1152 } 1153 const options = 1154 args.length === 3 ? {...args[2][Object.keys(args[2])[0]]} : {}; 1155 let inputChannels = inputShape[1]; // default nchw inputLayout 1156 // default oihw filterLayout for conv2d or default iohw filterLayout for 1157 // convTranspose2d 1158 let filterWidth = filterShape[3]; 1159 let filterHeight = filterShape[2]; 1160 const groups = options.groups ? options.groups : 1; 1161 1162 if (options.inputLayout) { 1163 if (!['nchw', 'nhwc'].includes(options.inputLayout)) { 1164 throw new Error(`Unknown inputLayout ${options.inputLayout}`); 1165 } 1166 inputChannels = 1167 options.inputLayout === 'nchw' ? inputChannels : inputShape[3]; 1168 } 1169 if (options.filterLayout) { 1170 let filterLayouts = ['oihw', 'hwio', 'ohwi', 'ihwo']; // default for conv2d 1171 if (operatorName === 'convTranspose2d') { 1172 filterLayouts = ['iohw', 'hwoi', 'ohwi']; 1173 } 1174 if (!filterLayouts.includes(options.filterLayout)) { 1175 throw new Error(`Unknown filterLayout ${options.filterLayout}`); 1176 } 1177 switch (options.filterLayout) { 1178 case 'oihw': 1179 case 'iohw': 1180 // Just use the existing filterWidth and filterHeight above. 
1181 break; 1182 case 'hwio': 1183 case 'hwoi': 1184 filterWidth = filterShape[1]; 1185 filterHeight = filterShape[0]; 1186 break; 1187 case 'ohwi': 1188 case 'ihwo': 1189 filterWidth = filterShape[2]; 1190 filterHeight = filterShape[1]; 1191 break; 1192 default: 1193 break; 1194 } 1195 } 1196 1197 const tolerance = filterWidth * filterHeight * (inputChannels / groups) * 2; 1198 const toleranceValueDict = {float32: tolerance, float16: tolerance}; 1199 const expectedDataType = 1200 getExpectedDataTypeOfSingleOutput(graphResources.expectedOutputs); 1201 return {metricType: 'ULP', value: toleranceValueDict[expectedDataType]}; 1202 }; 1203 1204 const getPoolingOperatorsPrecisionTolerance = 1205 (op, graphResources, intermediateOperands) => { 1206 const args = op.arguments; 1207 const operatorName = op.name; 1208 const {inputs} = graphResources; 1209 let inputShape; 1210 const inputIndex = args[0][Object.keys(args[0])[0]]; 1211 if (inputs[inputIndex]) { 1212 inputShape = inputs[inputIndex].descriptor.shape; 1213 } else { 1214 inputShape = intermediateOperands[inputIndex].shape; 1215 } 1216 const options = 1217 args.length === 2 ? 
{...args[1][Object.keys(args[1])[0]]} : {}; 1218 let height; 1219 let width; 1220 1221 if (options.windowDimensions) { 1222 height = options.windowDimensions[0]; 1223 width = options.windowDimensions[1]; 1224 } else { 1225 // If not present, the window dimensions are assumed to be the height 1226 // and width dimensions of the input shape 1227 if (options.layout && options.layout === 'nhwc') { 1228 height = inputShape[1]; 1229 width = inputShape[2]; 1230 } else { 1231 // nhwc layout of input 1232 height = inputShape[2]; 1233 width = inputShape[3]; 1234 } 1235 } 1236 1237 const tolerance = height * width + 2; 1238 const toleranceDict = { 1239 averagePool2d: {float32: tolerance, float16: tolerance}, 1240 l2Pool2d: {float32: tolerance, float16: tolerance}, 1241 maxPool2d: {float32: 0, float16: 0}, 1242 }; 1243 const expectedDataType = 1244 getExpectedDataTypeOfSingleOutput(graphResources.expectedOutputs); 1245 return { 1246 metricType: 'ULP', 1247 value: toleranceDict[operatorName][expectedDataType] 1248 }; 1249 }; 1250 1251 const getInstanceNormPrecisionTolerance = (graphResources) => { 1252 // according to 1253 // https://github.com/web-platform-tests/wpt/pull/43891#discussion_r1457026316 1254 const toleranceValueDict = {float32: 840, float16: 8400}; 1255 const expectedDataType = 1256 getExpectedDataTypeOfSingleOutput(graphResources.expectedOutputs); 1257 return {metricType: 'ULP', value: toleranceValueDict[expectedDataType]}; 1258 }; 1259 1260 const getExpectedDataTypeOfSingleOutput = (expectedOutput) => { 1261 const expectedDescriptor = 1262 expectedOutput[Object.keys(expectedOutput)[0]].descriptor; 1263 const dataType = expectedDescriptor.castedType ? 
1264 expectedDescriptor.castedType : 1265 expectedDescriptor.dataType; 1266 return dataType; 1267 }; 1268 1269 const getReductionOperatorsPrecisionTolerance = 1270 (op, graphResources, intermediateOperands) => { 1271 let tolerance; 1272 const operatorName = op.name; 1273 if (op.name === 'reduceMax' || op.name === 'reduceMin') { 1274 tolerance = 0; 1275 } else { 1276 // other reduction operators 1277 const args = op.arguments; 1278 const {inputs} = graphResources; 1279 let inputShape; 1280 const inputIndex = args[0][Object.keys(args[0])[0]]; 1281 if (inputs[inputIndex]) { 1282 inputShape = inputs[inputIndex].descriptor.shape; 1283 } else { 1284 inputShape = intermediateOperands[inputIndex].shape; 1285 } 1286 1287 const rank = inputShape.length; 1288 const options = 1289 args.length === 2 ? {...args[1][Object.keys(args[1])[0]]} : {}; 1290 let sizes; 1291 1292 if (options && options.axes) { 1293 sizes = options.axes.map( 1294 (axis) => axis < 0 ? inputShape[axis + rank] : inputShape[axis]); 1295 } else { 1296 sizes = inputShape; 1297 } 1298 1299 const elementCount = sizes.reduce( 1300 (accumulator, currentValue) => accumulator * currentValue, 1); 1301 tolerance = elementCount; 1302 } 1303 1304 const toleranceDict = { 1305 reduceL1: tolerance, 1306 reduceL2: tolerance * 2 + 2, 1307 reduceLogSum: tolerance + 18, 1308 reduceLogSumExp: tolerance * 2 + 18, 1309 reduceMax: tolerance, 1310 reduceMean: tolerance + 2, 1311 reduceMin: tolerance, 1312 reduceProduct: tolerance, 1313 reduceSum: tolerance, 1314 reduceSumSquare: tolerance * 2 1315 }; 1316 return {metricType: 'ULP', value: toleranceDict[operatorName]}; 1317 }; 1318 1319 const getResample2dPrecisionTolerance = 1320 (op, graphResources, intermediateOperands) => { 1321 const args = op.arguments; 1322 const options = 1323 args.length === 2 ? 
{...args[1][Object.keys(args[1])[0]]} : {}; 1324 const expectedOutputs = graphResources.expectedOutputs; 1325 const dataType = 1326 expectedOutputs[Object.keys(expectedOutputs)[0]].descriptor.dataType; 1327 let tolerance; 1328 1329 if (options.mode && options.mode === 'linear') { 1330 // interpolation mode is linear 1331 if (dataType === 'float32') { 1332 tolerance = 84; 1333 } else if (dataType === 'float16') { 1334 tolerance = 10; 1335 } else { 1336 tolerance = 1; 1337 } 1338 } else { 1339 // interpolation mode is nearest-neighbor 1340 tolerance = 0; 1341 } 1342 1343 return {metricType: 'ULP', value: tolerance}; 1344 }; 1345 1346 let minimumDataTypeSet; 1347 1348 function checkMinimum(descriptor, operandMinimumLimits) { 1349 const targetRank = descriptor.shape.length; 1350 const targetDataType = descriptor.dataType; 1351 let isMinimum = operandMinimumLimits.dataTypes.includes(targetDataType); 1352 1353 if (isMinimum) { 1354 isMinimum = operandMinimumLimits.rankRange.min <= targetRank && 1355 targetRank <= operandMinimumLimits.rankRange.max; 1356 } 1357 1358 return isMinimum; 1359 } 1360 1361 function getOutputMinimumLimits(operatorsResources, outputOperandName) { 1362 let operatorName; 1363 let outputName; 1364 for (let operator of operatorsResources) { 1365 if (typeof operator.outputs === 'string' && 1366 operator.outputs === outputOperandName) { 1367 operatorName = operator.name; 1368 outputName = 'output'; 1369 break; 1370 } else if ( 1371 Array.isArray(operator.outputs) && 1372 operator.outputs.includes(outputOperandName)) { 1373 // Current gru, lstm, lstmCell and split operators have multiple outputs 1374 operatorName = operator.name; 1375 if (minimumDataTypeSet[operatorName].hasOwnProperty('outputs')) { 1376 // for split operator 1377 outputName = 'outputs'; 1378 } else { 1379 // for gru, lstm, lstmCell operators 1380 outputName = `output${operator.outputs.indexOf(outputOperandName)}`; 1381 } 1382 break; 1383 } 1384 } 1385 1386 return 
minimumDataTypeSet[operatorName][outputName]; 1387 } 1388 1389 async function getMinimumDataTypeSetJson() { 1390 try { 1391 const response = await fetch('/webnn/resources/minimum_datatype_set.json'); 1392 1393 if (!response.ok) { 1394 throw new Error(`HTTP error! Status: ${response.status}`); 1395 } 1396 1397 const text = await response.text(); 1398 const jsonText = 1399 text.replace(/\/\/.*|\/\*[\s\S]*?\*\//g, ''); // Remove comments 1400 minimumDataTypeSet = JSON.parse(jsonText); 1401 } catch (error) { 1402 throw new Error(`Error fetching and parsing JSON: ${error.message}`); 1403 } 1404 return minimumDataTypeSet; 1405 } 1406 1407 function isMinimumTest(test) { 1408 let isMinimum = false; 1409 const graphResources = test.graph; 1410 const inputsResources = graphResources.inputs; 1411 1412 // check inputs 1413 for (let operator of graphResources.operators) { 1414 const minimumLimits = minimumDataTypeSet[operator.name]; 1415 for (let argument of operator.arguments) { 1416 for (let [operandName, value] of Object.entries(argument)) { 1417 if (operandName !== 'options') { 1418 if (typeof value === 'string' && 1419 inputsResources.hasOwnProperty(value)) { 1420 isMinimum = checkMinimum( 1421 inputsResources[value].descriptor, minimumLimits[operandName]); 1422 if (!isMinimum) { 1423 return isMinimum; 1424 } 1425 } else if (Array.isArray(value)) { 1426 for (let subValue of value) { 1427 if (typeof subValue === 'string' && 1428 inputsResources.hasOwnProperty(subValue)) { 1429 isMinimum = checkMinimum( 1430 inputsResources[subValue].descriptor, 1431 minimumLimits[operandName]); 1432 if (!isMinimum) { 1433 return isMinimum; 1434 } 1435 } 1436 } 1437 } 1438 } else { 1439 for (let [optionOperandName, optionValue] of Object.entries( 1440 argument['options'])) { 1441 if (typeof value === 'string' && 1442 inputsResources.hasOwnProperty(optionValue)) { 1443 isMinimum = checkMinimum( 1444 inputsResources[optionValue].descriptor, 1445 minimumLimits[optionOperandName]); 1446 if 
(!isMinimum) { 1447 return isMinimum; 1448 } 1449 } 1450 } 1451 } 1452 } 1453 } 1454 } 1455 1456 // check outputs 1457 const outputsResources = graphResources.expectedOutputs; 1458 for (let [outputOperandName, value] of Object.entries(outputsResources)) { 1459 const outputMinimumLimits = 1460 getOutputMinimumLimits(graphResources.operators, outputOperandName) 1461 isMinimum = checkMinimum(value.descriptor, outputMinimumLimits); 1462 if (!isMinimum) { 1463 return isMinimum; 1464 } 1465 } 1466 1467 return isMinimum; 1468 } 1469 1470 // This array is to save skipped tests which are optional tests unsupported by 1471 // the context. It's helpful to debug to get detail skipped tests in browser 1472 // console by typing testsToSkip after running tests. 1473 const testsToSkip = []; 1474 1475 async function webnn_conformance_test( 1476 tests, buildAndExecuteGraphFunc, toleranceFunc) { 1477 if (navigator.ml === undefined) { 1478 test(() => assert_implements(navigator.ml, 'missing navigator.ml')); 1479 } else { 1480 const testsToRun = []; 1481 promise_setup(async () => { 1482 // Create a context for checking whether tests are supported. 1483 const context = await getContext(); 1484 minimumDataTypeSet = await getMinimumDataTypeSetJson(); 1485 tests.filter(isTargetTest).forEach((test) => { 1486 if (validateContextSupportsGraph(context, test.graph) || 1487 isMinimumTest(test)) { 1488 testsToRun.push(test); 1489 } else { 1490 // This test is optional so it can be skipped. 1491 testsToSkip.push(test); 1492 } 1493 }); 1494 }); 1495 1496 promise_test(async () => { 1497 testsToRun.map((test) => { 1498 promise_test(async () => { 1499 // Create a context for each test. 
1500 const context = await getContext(); 1501 const builder = new MLGraphBuilder(context); 1502 const {result, intermediateOperands} = 1503 await buildAndExecuteGraphFunc(context, builder, test.graph); 1504 assertResultsEquals( 1505 toleranceFunc, result, test.graph, intermediateOperands); 1506 }, `${isMinimumTest(test) ? '[required]' : '[optional]'} ${test.name}`); 1507 }); 1508 }); 1509 } 1510 }