tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

utils.js (54276B)


      1 'use strict';
      2 
      3 const operatorToleranceDict = {
      4  argMax: {float32: 0, float16: 0},
      5  argMin: {float32: 0, float16: 0},
      6  batchNormalization: {float32: 6, float16: 6},
      7  clamp: {float32: 0, float16: 0},
      8 
      9  // Element-wise binary operations
     10  add: {float32: 1, float16: 1},
     11  sub: {
     12    float32: 1,
     13    float16: 1,
     14    int8: 0,
     15    uint8: 0,
     16    int32: 0,
     17    uint32: 0,
     18    int64: 0,
     19    uint64: 0
     20  },
     21  mul: {float32: 1, float16: 1},
     22  max: {float32: 0, float16: 0},
     23  min: {float32: 0, float16: 0},
     24  // Element-wise binary operations
     25 
     26  elu: {float32: 18, float16: 18},
     27  gelu: {float32: 18, float16: 18},
     28  hardSigmoid: {float32: 2, float16: 2},
     29  hardSwish: {float32: 4, float16: 4},
     30  leakyRelu: {float32: 1, float16: 2},
     31  linear: {float32: 2, float16: 2},
     32  prelu: {float32: 1, float16: 1},
     33  relu: {float32: 0, float16: 0, int8: 0, int32: 0},
     34  sigmoid: {float32: 34, float16: 10},
     35  softplus: {float32: 18, float16: 18},
     36  softsign: {float32: 3, float16: 3},
     37  tanh: {float32: 16, float16: 16},
     38 };
     39 
     40 const zeroULPToleranceOperatorList = [
     41  // data movement operators
     42  'concat', 'expand', 'gather', 'gatherElements', 'gatherND', 'identity', 'pad',
     43  'reshape', 'reverse', 'scatterElements', 'scatterND', 'slice', 'split',
     44  'tile', 'transpose',
     45 
     46  // element-wise logical operators
     47  'equal', 'notEqual', 'greater', 'greaterOrEqual', 'lesser', 'lesserOrEqual',
     48  'logicalNot', 'logicalAnd', 'logicalOr', 'logicalXor'
     49 ];
     50 
     51 const getZeroULPTolerance = () => {
     52  return {metricType: 'ULP', value: 0};
     53 };
     54 
     55 const getSoftmaxPrecisionTolerance =
     56    (op, graphResources, intermediateOperands) => {
     57      const {inputs} = graphResources;
     58      const args = op.arguments;
     59      let inputShape;
     60      const inputIndex = args[0][Object.keys(args[0])[0]];
     61      if (inputs[inputIndex]) {
     62        inputShape = inputs[inputIndex].descriptor.shape;
     63      } else {
     64        inputShape = intermediateOperands[inputIndex].shape;
     65      }
     66      const axis = args.length === 2 ? args[1][Object.keys(args[1])[0]] : 1;
     67      const tolerance = inputShape[axis] * 3 + 3;
     68      const toleranceValueDict = {float32: tolerance, float16: tolerance};
     69      const expectedDataType =
     70          getExpectedDataTypeOfSingleOutput(graphResources.expectedOutputs);
     71      return {metricType: 'ULP', value: toleranceValueDict[expectedDataType]};
     72    };
     73 
     74 const getPrecisionTolerance = (graphResources, intermediateOperands) => {
     75  const expectedDataType =
     76      getExpectedDataTypeOfSingleOutput(graphResources.expectedOutputs);
     77  let toleranceValue = 0;
     78  graphResources.operators.forEach(op => {
     79    switch (op.name) {
     80      case 'conv2d':
     81        toleranceValue += getConv2dPrecisionTolerance(op, graphResources,
     82            intermediateOperands).value;
     83        break;
     84      case 'convTranspose2d':
     85        toleranceValue += getConv2dPrecisionTolerance(op, graphResources,
     86            intermediateOperands).value;
     87        break;
     88      case 'gemm':
     89        toleranceValue += getGemmPrecisionTolerance(op, graphResources,
     90            intermediateOperands).value;
     91        break;
     92      case 'matmul':
     93        toleranceValue += getMatmulPrecisionTolerance(op, graphResources,
     94            intermediateOperands).value;
     95        break;
     96      case 'softmax':
     97        toleranceValue += getSoftmaxPrecisionTolerance(
     98                              op, graphResources, intermediateOperands)
     99                              .value;
    100        break;
    101      case 'averagePool2d':
    102      case 'maxPool2d':
    103      case 'l2Pool2d':
    104        toleranceValue += getPoolingOperatorsPrecisionTolerance(
    105                              op, graphResources, intermediateOperands)
    106                              .value;
    107        break;
    108      case 'reduceL1':
    109      case 'reduceL2':
    110      case 'reduceLogSum':
    111      case 'reduceLogSumExp':
    112      case 'reduceMax':
    113      case 'reduceMean':
    114      case 'reduceMin':
    115      case 'reduceProduct':
    116      case 'reduceSum':
    117      case 'reduceSumSquare':
    118        toleranceValue += getReductionOperatorsPrecisionTolerance(
    119                              op, graphResources, intermediateOperands)
    120                              .value;
    121        break;
    122      case 'resample2d':
    123        toleranceValue += getResample2dPrecisionTolerance(
    124                              op, graphResources, intermediateOperands)
    125                              .value;
    126        break;
    127      default:
    128        if (zeroULPToleranceOperatorList.includes(op.name)) {
    129          toleranceValue += getZeroULPTolerance().value;
    130        } else {
    131          const operatorTolerance =
    132              operatorToleranceDict[op.name]?.[expectedDataType];
    133          if (operatorTolerance !== undefined) {
    134            toleranceValue += operatorTolerance;
    135          }
    136        }
    137    }
    138  });
    139  return {metricType: 'ULP', value: toleranceValue};
    140 };
    141 
    142 // https://www.w3.org/TR/webnn/#enumdef-mloperanddatatype
    143 const TypedArrayDict = {
    144  float32: Float32Array,
    145 
    146  // Proposal to add float16 TypedArrays to JavaScript.
    147  // URL: https://tc39.es/proposal-float16array/
    148  // Use workaround Uint16 for Float16
    149  float16: Uint16Array,
    150 
    151  int64: BigInt64Array,
    152  uint64: BigUint64Array,
    153  int32: Int32Array,
    154  uint32: Uint32Array,
    155  int8: Int8Array,
    156  uint8: Uint8Array,
    157  int4: Uint8Array,
    158  uint4: Uint8Array,
    159 };
    160 
    161 const kIntTypes =
    162    ['uint4', 'int4', 'uint8', 'int8', 'uint32', 'int32', 'uint64', 'int64'];
    163 const kFloatTypes = ['float16', 'float32'];
    164 
    165 const findCompatibleType = (dataType, supportedTypes, castOpSupportLimits) => {
    166  if (!castOpSupportLimits.input.dataTypes.includes(dataType)) {
    167    // Cannot cast from `dataType` to any other type.
    168    return null;
    169  }
    170 
    171  for (let supportedType of supportedTypes) {
    172    if (kIntTypes.includes(dataType) &&
    173        castOpSupportLimits.output.dataTypes.includes(dataType) &&
    174        kIntTypes.indexOf(supportedType) > kIntTypes.indexOf(dataType)) {
    175      return supportedType;
    176    }
    177 
    178    if (kFloatTypes.includes(dataType)) {
    179      if (kFloatTypes.indexOf(supportedType) > kFloatTypes.indexOf(dataType)) {
    180        return supportedType;
    181      }
    182    }
    183  }
    184  return null;
    185 };
    186 
    187 // The maximum index to validate for the output's expected value.
    188 const kMaximumIndexToValidate = 1000;
    189 
    190 const kContextOptionsForVariant = {
    191  cpu: {
    192    deviceType: 'cpu',
    193  },
    194  gpu: {
    195    deviceType: 'gpu',
    196  },
    197  npu: {
    198    deviceType: 'npu',
    199  },
    200 };
    201 
    202 const searchParams = new URLSearchParams(location.search);
    203 const variant = searchParams.get('device') || location.search.substring(1);
    204 const contextOptions = kContextOptionsForVariant[variant];
    205 
    206 async function getContext() {
    207  let context;
    208  try {
    209    context = await navigator.ml.createContext(contextOptions);
    210  } catch (e) {
    211    throw new AssertionError(
    212        `Unable to create context for ${variant} variant. ${e}`);
    213  }
    214  return context;
    215 }
    216 
    217 const tcNameArray = searchParams.getAll('tc');
    218 
    219 function isTargetTest(test) {
    220  return tcNameArray.length === 0 || tcNameArray.includes(test.name);
    221 }
    222 
    223 const assertDescriptorsEquals = (outputOperand, expected) => {
    224  const dataType =
    225      expected.castedType ? expected.castedType : expected.dataType;
    226  assert_equals(
    227      outputOperand.dataType, dataType,
    228      'actual output dataType should be equal to expected output dataType');
    229  assert_array_equals(
    230      outputOperand.shape, expected.shape,
    231      'actual output shape should be equal to expected output shape');
    232 };
    233 
    234 // ref:
    235 // http://stackoverflow.com/questions/32633585/how-do-you-convert-to-half-floats-in-javascript
    236 const toHalf = (value) => {
    237  let floatView = new Float32Array(1);
    238  let int32View = new Int32Array(floatView.buffer);
    239 
    240  /* This method is faster than the OpenEXR implementation (very often
    241   * used, eg. in Ogre), with the additional benefit of rounding, inspired
    242   * by James Tursa's half-precision code. */
    243 
    244  floatView[0] = value;
    245  let x = int32View[0];
    246 
    247  let bits = (x >> 16) & 0x8000; /* Get the sign */
    248  let m = (x >> 12) & 0x07ff;    /* Keep one extra bit for rounding */
    249  let e = (x >> 23) & 0xff;      /* Using int is faster here */
    250 
    251  /* If zero, or denormal, or exponent underflows too much for a denormal
    252   * half, return signed zero. */
    253  if (e < 103) {
    254    return bits;
    255  }
    256 
    257  /* If NaN, return NaN. If Inf or exponent overflow, return Inf. */
    258  if (e > 142) {
    259    bits |= 0x7c00;
    260    /* If exponent was 0xff and one mantissa bit was set, it means NaN,
    261     * not Inf, so make sure we set one mantissa bit too. */
    262    if (e == 255 && (x & 0x007fffff)) {
    263      bits |= 1;
    264    }
    265    return bits;
    266  }
    267 
    268  /* If exponent underflows but not too much, return a denormal */
    269  if (e < 113) {
    270    m |= 0x0800;
    271    /* Extra rounding may overflow and set mantissa to 0 and exponent
    272     * to 1, which is OK. */
    273    bits |= (m >> (114 - e)) + ((m >> (113 - e)) & 1);
    274    return bits;
    275  }
    276 
    277  bits |= ((e - 112) << 10) | (m >> 1);
    278  /* Extra rounding. An overflow will set mantissa to 0 and increment
    279   * the exponent, which is OK. */
    280  bits += m & 1;
    281  return bits;
    282 };
    283 
    284 const getTypedArrayData = (type, size, data) => {
    285  let outData;
    286 
    287  if (type === 'float16') {
    288    if (typeof (data) === 'number' && size > 1) {
    289      return new TypedArrayDict[type](size).fill(toHalf(data));
    290    }
    291    // workaround to convert Float16 to Uint16
    292    outData = new TypedArrayDict[type](data.length);
    293    for (let i = 0; i < data.length; i++) {
    294      outData[i] = toHalf(data[i]);
    295    }
    296  } else if (type === 'int64' || type === 'uint64') {
    297    if (typeof (data) === 'number' && size > 1) {
    298      return new TypedArrayDict[type](size).fill(BigInt(data));
    299    }
    300    outData = new TypedArrayDict[type](data.length);
    301    for (let i = 0; i < data.length; i++) {
    302      outData[i] = BigInt(data[i]);
    303    }
    304  } else if (type === 'uint4' || type === 'int4') {
    305    // The first nybble is stored in the first bits 0-3, and later bits 4-7
    306    // store the later nybble. The data is packed, without any padding between
    307    // dimensions. For example: an array of uint4:
    308    //   size = [2,5]
    309    //   values = [1,2,3,4,5,6,7,8,9,10]
    310    // Would yield 5 hex bytes:
    311    //   Uint8Array.of(0x21, 0x43, 0x65, 0x87, 0xA9);
    312    const array = new TypedArrayDict[type](Math.ceil(size / 2));
    313    let i = 0;
    314    while (i < size - 1) {
    315      const packedByte = ((data[i + 1] & 0xF) << 4) | (data[i] & 0xF);
    316      array[Math.floor(i / 2)] = packedByte;
    317      i = i + 2;
    318    }
    319    // Handle the odd size.
    320    if (i === size - 1) {
    321      const packedByte = data[i] & 0xF;
    322      array[Math.floor(i / 2)] = packedByte;
    323    }
    324    return array;
    325  } else {
    326    if (typeof (data) === 'number' && size > 1) {
    327      return new TypedArrayDict[type](size).fill(data);
    328    }
    329    outData = new TypedArrayDict[type](data);
    330  }
    331  return outData;
    332 };
    333 
    334 const sizeOfShape = (array) => {
    335  return array.reduce((accumulator, currentValue) => accumulator * currentValue, 1);
    336 };
    337 
    338 /**
    339 * Get bitwise of the given value.
    340 * @param {Number} value
    341 * @param {String} dataType - A data type string; currently only "float32" is
    342 *     supported by this function.
    343 * @return {BigInt} A 64-bit signed integer.
    344 */
    345 const getBitwise = (value, dataType) => {
    346  const buffer = new ArrayBuffer(8);
    347  const int64Array = new BigInt64Array(buffer);
    348  let typedArray;
    349  if (dataType === "float32") {
    350    typedArray = new Float32Array(buffer);
    351  } else {
    352    throw new AssertionError(`Data type ${dataType} is not supported`);
    353  }
    354  typedArray[0] = Math.abs(value);
    355  const int64 = int64Array[0];
    356  return (value < 0) ? -int64 : int64;
    357 };
    358 
    359 /**
    360 * Assert that each array property in ``actual`` is a number being close enough
    361 * to the corresponding property in ``expected`` by the acceptable ULP distance
    362 * ``nulp`` with given ``dataType`` data type.
    363 *
    364 * @param {Array} actual - Array of test values.
    365 * @param {Array} expected - Array of values expected to be close to the values
    366 *     in ``actual``.
    367 * @param {(Number|BigInt)} nulp - A value indicates acceptable ULP distance.
    368 * @param {String} dataType - A data type string, value: "float32",
    369 *     more types, please see:
    370 *     https://www.w3.org/TR/webnn/#enumdef-mloperanddatatype
    371 * @param {String} description - Description of the condition being tested.
    372 */
    373 const assert_array_approx_equals_ulp = (actual, expected, nulp, dataType, description) => {
    374  /*
    375    * Test if two primitive arrays are equal within acceptable ULP distance
    376    */
    377  assert_equals(
    378      actual.length, expected.length,
    379      `assert_array_approx_equals_ulp: ${description} lengths differ`);
    380  for (let i = 0; i < actual.length; i++) {
    381    if (actual[i] === expected[i]) {
    382      continue;
    383    } else {
    384      let distance = ulpDistance(actual[i], expected[i], dataType);
    385 
    386      // TODO: See if callers can be updated to pass matching type.
    387      nulp = typeof distance === 'bigint' ? BigInt(nulp) : Number(nulp);
    388 
    389      assert_less_than_equal(distance, nulp,
    390            `assert_array_approx_equals_ulp: ${description} actual ` +
    391                `${
    392                    dataType === 'float16' ?
    393                        float16AsUint16ToNumber(actual[i]) :
    394                        actual[i]} should be close enough to expected ` +
    395                `${expected[i]} by ULP distance:`);
    396    }
    397  }
    398 };
    399 
    400 /**
    401 * Compute the ULP distance between ``a`` and ``b`` for the given ``dataType``.
    402 *
    403 * @param {(Number|BigInt)} a - First value.
    404 * @param {(Number|BigInt)} b - Second value.
    405 * @param {String} dataType - A data type string, value: "float32",
    406 *     more types, please see:
    407 *     https://www.w3.org/TR/webnn/#enumdef-mloperanddatatype
    408 */
    409 const ulpDistance = (a, b, dataType) => {
    410  let aBitwise, bBitwise;
    411  // measure the ULP distance
    412  if (dataType === 'float32') {
    413    aBitwise = getBitwise(a, dataType);
    414    bBitwise = getBitwise(b, dataType);
    415  } else if (dataType === 'float16') {
    416    aBitwise = a;
    417    // convert b data of Float16 to Uint16
    418    bBitwise = toHalf(b);
    419 
    420    // Workaround to use mask to check returned special float16 value -0.0 which
    421    // is 32768 (1000 0000 0000 0000) of uint16
    422    const signExclusionMask = 0x00007FFF;
    423    if ((aBitwise & signExclusionMask) === 0 &&
    424        (bBitwise & signExclusionMask) === 0) {
    425      return 0;
    426    }
    427  } else if (dataType === 'int64' || dataType === 'uint64') {
    428    aBitwise = BigInt(a);
    429    bBitwise = BigInt(b);
    430  } else if (
    431      dataType === 'int8' || dataType === 'uint8' || dataType === 'int32' ||
    432      dataType === 'uint32' || dataType === 'int4' || dataType === 'uint4') {
    433    aBitwise = a;
    434    bBitwise = b;
    435  } else {
    436    throw new AssertionError(`Data type ${dataType} is not supported`);
    437  }
    438  const distance = aBitwise - bBitwise;
    439  return distance >= 0 ? distance : -distance;
    440 };
    441 
    442 /**
    443 * This function converts a Float16 stored as the bits of a Uint16 into a
    444 * JavaScript Number.
    445 * @param {Number} uint16 - a Float16 stored as the bits of a Uint16
    446 * @returns An emulated Float16 number.
    447 */
    448 function float16AsUint16ToNumber(uint16) {
    449  const sign = (uint16 >> 15) & 0x1;
    450  const exponent = (uint16 >> 10) & 0x1F;
    451  const mantissa = uint16 & 0x3FF;
    452  let float16;
    453 
    454  if (exponent === 0) {
    455    // Subnormal number
    456    float16 = (mantissa / 1024) * Math.pow(2, -14);
    457  } else if (exponent === 0x1F) {
    458    // NaN or Infinity
    459    float16 = mantissa ? NaN : Infinity;
    460  } else {
    461    // Normalized number
    462    float16 = (1 + mantissa / 1024) * Math.pow(2, exponent - 15);
    463  }
    464 
    465  // Apply the sign
    466  return sign ? -float16 : float16;
    467 }
    468 
    469 /**
    470 * Assert actual results with expected results.
    471 * @param {String} operatorName
    472 * @param {(Number[]|Number)} actual
    473 * @param {(Number[]|Number)} expected
    474 * @param {String} metricType - Value: 'ULP', 'ATOL'
    475 * @param {Number} toleranceValue
    476 * @param {String} dataType  - An operand type string, value: "float32",
    477 *     more types, please see:
    478 *     https://www.w3.org/TR/webnn/#enumdef-mloperanddatatype
    479 */
    480 const doAssert =
    481    (operatorName, actual, expected, metricType, toleranceValue, dataType) => {
    482      const description = `test ${operatorName} ${dataType}`;
    483      if (typeof expected === 'number') {
    484        expected = [expected];
    485        actual = [actual];
    486      }
    487      if (metricType === 'ULP') {
    488        assert_array_approx_equals_ulp(
    489            actual, expected, toleranceValue, dataType, description);
    490      } else if (metricType === 'ATOL') {
    491        let actualData;
    492        if (dataType === 'float16') {
    493          // workaround for float16
    494          actualData = new Array(actual.length);
    495          actual.forEach(
    496              (x, index) => actualData[index] = float16AsUint16ToNumber(x));
    497        } else {
    498          actualData = actual;
    499        }
    500        assert_array_approx_equals(
    501            actualData, expected, toleranceValue, description);
    502      } else {
    503        throw new AssertionError(
    504            `Tolerance Metric type '${metricType}' is not supported`);
    505      }
    506    };
    507 
    508 /**
    509 * Assert computed results be equal to expected data.
    510 * @param {Object} toleranceFunc
    511 * @param {Map<String, ArrayBufferView> |
    512 *     Array[Map<String, ArrayBufferView>]} actual
    513 * @param {Object} graphResources - Resources used for building a graph
    514 */
    515 const assertResultsEquals =
    516    (toleranceFunc, actual, graphResources, intermediateOperands) => {
    517      const operatorName =
    518          graphResources.operators.map(operator => operator.name).join(' ');
    519      const expectedOutputs = graphResources.expectedOutputs;
    520      const toleranceInfo = toleranceFunc(graphResources, intermediateOperands);
    521      const metricType = toleranceInfo.metricType;
    522      const toleranceValue = toleranceInfo.value;
    523      let outputData;
    524 
    525      for (let operandName in actual) {
    526        const expectedSuboutput = expectedOutputs[operandName];
    527        const expectedDescriptor = expectedSuboutput.descriptor;
    528        let expectedData = expectedSuboutput.data;
    529        outputData = actual[operandName];
    530        // If data is scalar and shape is not, it means it's expecting to be
    531        // filled by the scalar value. Also limit the array size so it doesn't
    532        // timeout.
    533        if (typeof (expectedData) === 'number' && expectedDescriptor.shape &&
    534            sizeOfShape(expectedDescriptor.shape) > 1) {
    535          const size = Math.min(
    536              kMaximumIndexToValidate, sizeOfShape(expectedDescriptor.shape));
    537          expectedData = new Array(size).fill(expectedData);
    538          outputData = outputData.subarray(0, kMaximumIndexToValidate);
    539        } else if (
    540            expectedDescriptor.dataType === 'uint4' ||
    541            expectedDescriptor.dataType === 'int4') {
    542          // The int4/uint4 data were packed in Uint8Array.
    543          // The first nybble and later nybble of one int8/uint8 value store two
    544          // consecutive 4-bits values separately. After unpacking each 4-bits
    545          // value, the unpacked int4 value is stored in an element of
    546          // Int8Array, and the unpacked uint4 value is stored in an element of
    547          // Uint8Array. For example: an array of uint4:
    548          //   size = [1, 5]
    549          //   Uint8Array.of(0x21, 0x43, 0x65, 0x87, 0xA9)
    550          // Would yield 5 * 2 uint4 data:
    551          //   Uint8Array.of(1,2,3,4,5,6,7,8,9,10);
    552          // Another example: an array of int4:
    553          //   size = [1, 5]
    554          //   Uint8Array.of(0xA9, 0xCB, 0xED, 0x0F, 0x21)
    555          // Would yield 5 * 2 int4 data:
    556          //   Int8Array.of(-7, -6, -5, -4, -3, -2, -1, 0, 1, 2);
    557          let newOutputData;
    558          if (expectedDescriptor.dataType === 'uint4') {
    559            newOutputData =
    560                new Uint8Array(sizeOfShape(expectedDescriptor.shape));
    561          } else {
    562            newOutputData =
    563                new Int8Array(sizeOfShape(expectedDescriptor.shape));
    564          }
    565          const signMask =
    566              (expectedDescriptor.dataType === 'int4') ? 0x08 : 0x00;
    567          for (let i = 0; i < sizeOfShape(expectedDescriptor.shape); i++) {
    568            const byteIndex = Math.floor(i / 2);
    569            let value = (outputData[byteIndex] >> ((i & 1) << 2)) & 0xF;
    570            // Handle the negative numbers.
    571            if (value & signMask) {
    572              value |= 0xF0;
    573            }
    574            newOutputData[i] = value;
    575          }
    576          outputData = newOutputData;
    577        }
    578        doAssert(
    579            operatorName, outputData, expectedData, metricType, toleranceValue,
    580            expectedDescriptor.dataType);
    581      }
    582    };
    583 
    584 const createOperand = (context, builder, operandName, resources) => {
    585  let operand;
    586  const descriptor = resources.descriptor;
    587  const dataType = descriptor.dataType;
    588 
    589  const supportedDataTypes = resources.constant ?
    590      context.opSupportLimits().constant.dataTypes :
    591      context.opSupportLimits().input.dataTypes;
    592 
    593  // If input data type is not supported on current platform, attempt to use
    594  // a supported type to pass the data, then cast back to original type.
    595  if (!supportedDataTypes.includes(dataType)) {
    596    const compatibleType = findCompatibleType(
    597        dataType, supportedDataTypes, context.opSupportLimits().cast);
    598    if (compatibleType) {
    599      descriptor.castedType = compatibleType;
    600      descriptor.dataType = compatibleType;
    601    }
    602  }
    603 
    604  operand = resources.constant ?
    605      builder.constant(
    606          descriptor,
    607          getTypedArrayData(
    608              descriptor.dataType, sizeOfShape(descriptor.shape),
    609              resources.data)) :
    610      builder.input(operandName, descriptor);
    611 
    612  if (descriptor.castedType) {
    613    operand = builder.cast(operand, dataType);
    614  }
    615 
    616  return operand;
    617 };
    618 
    619 /**
    620 * Create inputs or outputs tensor.
    621 * @param {MLContext} context - the context used to create inputs or outputs
    622 *     tensor.
    623 * @param {String} dataType - dataType of inputs / outputs operands
    624 * @param {Array} shape - dimensions of inputs / outputs operands
    625 * @param {Object} [data] - optional data for inputs tensor
    626 * @returns {MLTensor}
    627 */
    628 async function createTensorWithData(context, dataType, shape, data) {
    629  const tensorDesc = {dataType, shape};
    630  if (data) {
    631    tensorDesc.writable = true;
    632  } else {
    633    tensorDesc.readable = true;
    634  }
    635  let tensor = await context.createTensor(tensorDesc);
    636  if (data) {
    637    context.writeTensor(tensor, data);
    638  }
    639  return tensor;
    640 }
    641 
    642 async function prepareInputsForGraph(context, resources) {
    643  const inputOperandNameArray = Object.keys(resources).filter(
    644      operandName => !resources[operandName].constant);
    645  const tensors = await Promise.all(inputOperandNameArray.map((operandName) => {
    646    const inputOperandResources = resources[operandName];
    647    const descriptor = inputOperandResources.descriptor;
    648    const targetDataType =
    649        descriptor.castedType ? descriptor.castedType : descriptor.dataType;
    650    const inputBuffer = getTypedArrayData(
    651        targetDataType, sizeOfShape(descriptor.shape),
    652        inputOperandResources.data);
    653    return createTensorWithData(
    654        context, targetDataType, descriptor.shape, inputBuffer);
    655  }));
    656 
    657  const inputs = {};
    658  inputOperandNameArray.forEach((name, index) => inputs[name] = tensors[index]);
    659  return inputs;
    660 }
    661 
    662 async function prepareOutputsForGraph(context, resources) {
    663  const outputOperandNameArray = Object.keys(resources);
    664  const tensors =
    665      await Promise.all(outputOperandNameArray.map((operandName) => {
    666        const descriptor = resources[operandName].descriptor;
    667        const dataType =
    668            descriptor.castedType ? descriptor.castedType : descriptor.dataType;
    669        return createTensorWithData(context, dataType, descriptor.shape);
    670      }));
    671 
    672  const outputs = {};
    673  outputOperandNameArray.forEach(
    674      (name, index) => outputs[name] = tensors[index]);
    675  return outputs;
    676 }
    677 
    678 function getInputName(operatorArguments, operandName) {
    679  for (let argument of operatorArguments) {
    680    const name = Object.keys(argument)[0];
    681    if (name === operandName) {
    682      return argument[operandName];
    683    } else if (name === 'options') {
    684      if (Object.keys(argument[name]).includes(operandName)) {
    685        return argument[name][operandName];
    686      }
    687    }
    688  }
    689  return null;
    690 }
    691 
    692 // This assert() function is to check whether configurations of test case are
    693 // set correctly.
    694 function assert(condition, message) {
    695  if (!condition) {
    696    throw new Error(`Wrong test case, ${message}`);
    697  }
    698 }
    699 
    700 function validateContextSupportsGraph(context, graph) {
    701  const supportLimits = context.opSupportLimits();
    702  const castOpSupportLimits = supportLimits.cast;
    703  const inputDataTypes = supportLimits.input.dataTypes;
    704  const inputRankRange = supportLimits.input.rankRange;
    705  const constantDataTypes = supportLimits.constant.dataTypes;
    706  const constantRankRange = supportLimits.constant.rankRange;
    707  const outputDataTypes = supportLimits.output.dataTypes;
    708  const outputRankRange = supportLimits.output.rankRange;
    709 
    710  function validateInputOrConstantDataTypeAndRank(
    711      inputName, operatorSupportLimits, operand) {
    712    const inputDescriptor = graph.inputs[inputName].descriptor;
    713    const inputDataType = inputDescriptor.dataType;
    714    const inputRank = inputDescriptor.shape.length;
    715    if (inputDescriptor.constant) {
    716      // Check graph constant data type
    717      if (!constantDataTypes.includes(inputDataType) &&
    718          !findCompatibleType(
    719              inputDataType, constantDataTypes, castOpSupportLimits)) {
    720        throw new TypeError(
    721            `Unsupported data type, constant '${operand}' data type ${
    722                inputDataType} must be one of [${constantDataTypes}].`);
    723      }
    724 
    725      // Check graph constant rank
    726      if (inputRank < constantRankRange.min) {
    727        throw new TypeError(`Unsupported rank ${inputRank} for constant '${
    728            operand}' (must be at least ${constantRankRange.min}).`);
    729      }
    730      if (inputRank > constantRankRange.max) {
    731        throw new TypeError(`Unsupported rank ${inputRank} for constant '${
    732            operand}' (must be at most ${constantRankRange.max}).`);
    733      }
    734    } else {
    735      // Check graph input data type
    736      if (!inputDataTypes.includes(inputDataType) &&
    737          !findCompatibleType(
    738              inputDataType, inputDataTypes, castOpSupportLimits)) {
    739        throw new TypeError(
    740            `Unsupported data type, input '${operand}' data type ${
    741                inputDataType} must be one of [${inputDataTypes}].`);
    742      }
    743 
    744      // Check graph input rank
    745      if (inputRank < inputRankRange.min) {
    746        throw new TypeError(`Unsupported rank ${inputRank} for input '${
    747            operand}' (must be at least ${inputRankRange.min}).`);
    748      }
    749      if (inputRank > inputRankRange.max) {
    750        throw new TypeError(`Unsupported rank ${inputRank} for input '${
    751            operand}' (must be at most ${inputRankRange.max}).`);
    752      }
    753    }
    754 
    755    const operandSupportLimits = operatorSupportLimits[operand];
    756    // Check operand data type
    757    const inputOperandDataTypes = operandSupportLimits.dataTypes;
    758    if (!inputOperandDataTypes.includes(inputDataType) &&
    759        !findCompatibleType(
    760            inputDataType, inputDataTypes, castOpSupportLimits)) {
    761      throw new TypeError(
    762          `Unsupported data type, input '${operand}' data type ${
    763              inputDataType} must be one of [${inputOperandDataTypes}].`);
    764    }
    765 
    766    // Check operand rank
    767    const limitsRankRange = operandSupportLimits.rankRange;
    768    if (inputRank < limitsRankRange.min) {
    769      throw new TypeError(`Unsupported rank ${inputRank} for argument ${
    770          operand} (must be at least ${limitsRankRange.min}).`);
    771    }
    772 
    773    if (inputRank > limitsRankRange.max) {
    774      throw new TypeError(`Unsupported rank ${inputRank} for argument ${
    775          operand} (must be at most ${limitsRankRange.max}).`);
    776    }
    777  }
    778 
    779  function validateOutputDataTypeAndRank(
    780      outputName, operatorSupportLimits, operand) {
    781    const outputDataType =
    782        graph.expectedOutputs[outputName].descriptor.dataType;
    783    const outputRank =
    784        graph.expectedOutputs[outputName].descriptor.shape.length;
    785    // Check graph output data type
    786    if (!outputDataTypes.includes(outputDataType) &&
    787        !findCompatibleType(
    788            outputDataType, outputDataTypes, castOpSupportLimits)) {
    789      throw new TypeError(
    790          `Unsupported data type, output '${operand}' data type ${
    791              outputDataType} must be one of [${outputDataTypes}].`);
    792    }
    793 
    794    // Check graph output rank
    795    if (outputRank < outputRankRange.min) {
    796      throw new TypeError(`Unsupported rank ${outputRank} for output '${
    797          operand}' (must be at least ${outputRankRange.min}).`);
    798    }
    799    if (outputRank > outputRankRange.max) {
    800      throw new TypeError(`Unsupported rank ${outputRank} for output '${
    801          operand}' (must be at most ${outputRankRange.max}).`);
    802    }
    803 
    804    // Check output operand data type
    805    const outputOperandDataTypes = operatorSupportLimits[operand].dataTypes;
    806    if (!outputOperandDataTypes.includes(outputDataType) &&
    807        !findCompatibleType(
    808            outputOperandDataTypes, outputDataTypes, castOpSupportLimits)) {
    809      throw new TypeError(
    810          `Unsupported data type, output '${operand}' data type ${
    811              outputDataType} must be one of [${outputOperandDataTypes}].`);
    812    }
    813 
    814    // Check output operand rank
    815    const outputOperandRankRange = operatorSupportLimits[operand].rankRange;
    816    if (outputRank < outputOperandRankRange.min) {
    817      throw new TypeError(`Unsupported rank ${outputRank} for output '${
    818          operand}' (must be at least ${outputOperandRankRange.min}).`);
    819    }
    820    if (outputRank > outputOperandRankRange.max) {
    821      throw new TypeError(`Unsupported rank ${outputRank} for output '${
    822          operand}' (must be at most ${outputOperandRankRange.max}).`);
    823    }
    824  }
    825 
    826  try {
    827    for (let operator of graph.operators) {
    828      const operatorName = operator.name;
    829      const operatorSupportLimits = supportLimits[operatorName];
    830      for (let operand of Object.keys(operatorSupportLimits)) {
    831        if (operand === 'output') {
    832          // single output operand
    833          assert(
    834              typeof operator.outputs === 'string',
    835              `the outputs of ${operatorName} should be a string.`);
    836          if (!graph.expectedOutputs[operator.outputs]) {
    837            // intermediate output
    838            continue;
    839          }
    840          validateOutputDataTypeAndRank(
    841              operator.outputs, operatorSupportLimits, 'output');
    842        } else if (operand === 'outputs') {
    843          // multiple output operands of split operator
    844          assert(
    845              Array.isArray(operator.outputs),
    846              `the outputs of ${operatorName} should be a string array.`);
    847          for (const outputName of operator.outputs) {
    848            assert(
    849                typeof outputName === 'string',
    850                `the outputs' item of ${operatorName} should be a string.`);
    851            if (!graph.expectedOutputs[outputName]) {
    852              // intermediate output
    853              continue;
    854            }
    855            validateOutputDataTypeAndRank(
    856                outputName, operatorSupportLimits, 'outputs');
    857          }
    858        } else if (/output[0-2]/.test(operand)) {
    859          // multiple output operands of gru/lstm/lstmCell operators
    860          assert(
    861              Array.isArray(operator.outputs),
    862              `the outputs of ${operatorName} should be a string array.`);
    863          const index = parseInt(operand.match(/output([0-2])/)[1]);
    864          if (index < operator.outputs.length) {
    865            validateOutputDataTypeAndRank(
    866                operator.outputs[index], operatorSupportLimits, operand);
    867          }
    868        } else {
    869          // input operand(s)
    870          if (operatorName === 'concat') {
    871            const inputNameArray = operator.arguments[0][operand];
    872            assert(
    873                Array.isArray(inputNameArray),
    874                `the inputs of ${operatorName} should be a string array.`);
    875            for (const inputName of inputNameArray) {
    876              assert(
    877                  typeof inputName === 'string',
    878                  `the inputs' item of ${operatorName} should be a string.`);
    879              if (!graph.inputs[inputName]) {
    880                // intermediate input
    881                continue;
    882              }
    883              validateInputOrConstantDataTypeAndRank(
    884                  inputName, operatorSupportLimits, 'inputs');
    885            }
    886          } else {
    887            const inputName = getInputName(operator.arguments, operand);
    888            if (inputName === null || !graph.inputs[inputName]) {
    889              // default options argument or intermediate input
    890              continue;
    891            }
    892            validateInputOrConstantDataTypeAndRank(
    893                inputName, operatorSupportLimits, operand);
    894          }
    895        }
    896      }
    897    }
    898    return /*supported*/ true;
    899  } catch (error) {
    900    return /*not supported*/ false;
    901  }
    902 }
    903 
    904 /**
    905 * This function is to execute the compiled graph.
    906 * @param {MLContext} context
    907 * @param {MLGraph} graph
    908 * @param {Map<String, {
    909 *                       data: Array.<Number>|Number,
    910 *                       descriptor: MLOperandDescriptor,
    911 *                       constant?: Boolean
    912 *                     }>} graphInputs
    913 * @param {Map<String, {
    914 *                      data: Array.<Number>|Number,
    915 *                      descriptor: MLOperandDescriptor,
    916 *                     }>} expectedOutputs
    917 * @returns A result object.
    918 */
    919 async function computeGraph(context, graph, graphInputs, expectedOutputs) {
    920  const inputs = await prepareInputsForGraph(context, graphInputs);
    921  const outputs = await prepareOutputsForGraph(context, expectedOutputs);
    922 
    923  // Execute the compiled graph.
    924  context.dispatch(graph, inputs, outputs);
    925 
    926  const result = {};
    927  const outputNameArray = Object.keys(expectedOutputs);
    928  const outputBuffers = await Promise.all(Object.values(outputs).map(
    929      (tensor) => {return context.readTensor(tensor)}));
    930  outputNameArray.forEach((name, index) => {
    931    const dataType = expectedOutputs[name].descriptor.castedType ?
    932        expectedOutputs[name].descriptor.castedType :
    933        expectedOutputs[name].descriptor.dataType;
    934    result[name] = new TypedArrayDict[dataType](outputBuffers[index])
    935  });
    936 
    937  return result;
    938 }
    939 
    940 /**
    941 * This function is to compile and execute the constructed graph.
    942 * @param {MLContext} context
    943 * @param {MLGraphBuilder} builder
    944 * @param {{
    945 *           inputs: Map<String, {
    946 *                                 data: Array.<Number>|Number,
    947 *                                 descriptor: MLOperandDescriptor,
    948 *                                 constant?: Boolean
    949 *                               }>,
    950 *           operators: Array.<{
    951 *                               name: String,
    952 *                               arguments: Array.<Map<String, Object>> ,
    953 *                               outputs: Array.<String>|String
    954 *                             }>,
    955 *           expectedOutputs: Map<String, {
    956 *                                          data: Array.<Number>|Number,
    957 *                                          descriptor: MLOperandDescriptor,
    958 *                                        }>
    959 *        }} graphResources - Resources used for building a graph
    960 * @returns A Promise of MLComputeResult.
    961 */
    962 const buildAndExecuteGraph = async (context, builder, graphResources) => {
    963  const outputOperands = [];
    964  const graphInputs = graphResources.inputs;
    965  const graphOperators = graphResources.operators;
    966  const intermediateOperands = {};
    967  for (const operator of graphOperators) {
    968    const argumentArray = [];
    969    for (const argument of operator.arguments) {
    970      for (const argumentName in argument) {
    971        if (argumentName !== 'options') {
    972          if (operator.name === 'concat' && argumentName === 'inputs') {
    973            const concatInputs = [];
    974            for (const inputName of argument[argumentName]) {
    975              if (graphInputs.hasOwnProperty(inputName)) {
    976                const operandName = inputName;
    977                const operand = createOperand(
    978                    context, builder, operandName, graphInputs[operandName]);
    979                concatInputs.push(operand);
    980              } else if (intermediateOperands.hasOwnProperty(inputName)) {
    981                concatInputs.push(intermediateOperands[inputName]);
    982              }
    983              // concatInputs.push(intermediateOperands[inputName]);
    984            }
    985            argumentArray.push(concatInputs);
    986          } else if (graphInputs.hasOwnProperty(argument[argumentName])) {
    987            const operandName = argument[argumentName];
    988            const operand = createOperand(
    989                context, builder, operandName, graphInputs[operandName]);
    990            argumentArray.push(operand);
    991          } else if (intermediateOperands.hasOwnProperty(
    992                         argument[argumentName])) {
    993            argumentArray.push(intermediateOperands[argument[argumentName]]);
    994          } else {
    995            argumentArray.push(argument[argumentName]);
    996          }
    997        } else {
    998          for (const [optionalArgumentName, value] of Object.entries(
    999                   argument['options'])) {
   1000            if (typeof value === 'string' &&
   1001                graphInputs.hasOwnProperty(value)) {
   1002              const operandName = value;
   1003              const operand = createOperand(
   1004                  context, builder, operandName, graphInputs[operandName]);
   1005              argument['options'][optionalArgumentName] = operand;
   1006            } else if (
   1007                typeof value === 'string' &&
   1008                intermediateOperands.hasOwnProperty(value)) {
   1009              argument['options'][optionalArgumentName] =
   1010                  intermediateOperands[value];
   1011            }
   1012          }
   1013          argumentArray.push(argument['options']);
   1014        }
   1015      }
   1016    }
   1017 
   1018    const currentOutput = builder[operator.name](...argumentArray);
   1019    if (Array.isArray(operator.outputs)) {
   1020      operator.outputs.forEach((outputName, index) => {
   1021        intermediateOperands[outputName] = currentOutput[index];
   1022      });
   1023    } else {
   1024      intermediateOperands[operator.outputs] = currentOutput;
   1025    }
   1026  }
   1027 
   1028  const outputNames = Object.keys(graphResources.expectedOutputs);
   1029  outputNames.forEach(outputName => {
   1030    if (intermediateOperands.hasOwnProperty(outputName)) {
   1031      outputOperands.push(intermediateOperands[outputName]);
   1032    }
   1033  });
   1034 
   1035  if (outputOperands.length !== outputNames.length) {
   1036    throw new Error('Graph outputs are not properly defined');
   1037  }
   1038 
   1039  for (let i = 0; i < outputOperands.length; ++i) {
   1040    const expectedDescriptor =
   1041        graphResources
   1042            .expectedOutputs[Object.keys(graphResources.expectedOutputs)[i]]
   1043            .descriptor;
   1044    if (!context.opSupportLimits().output.dataTypes.includes(
   1045            expectedDescriptor.dataType)) {
   1046      const compatibleType = findCompatibleType(
   1047          expectedDescriptor.dataType,
   1048          context.opSupportLimits().output.dataTypes,
   1049          context.opSupportLimits().cast);
   1050      outputOperands[i] = builder.cast(outputOperands[i], compatibleType);
   1051      expectedDescriptor.castedType = compatibleType;
   1052    }
   1053  }
   1054 
   1055  const outputNameArray = Object.keys(graphResources.expectedOutputs);
   1056  for (let i = 0; i < outputOperands.length; ++i) {
   1057    assertDescriptorsEquals(
   1058        outputOperands[i],
   1059        graphResources.expectedOutputs[outputNameArray[i]].descriptor);
   1060  }
   1061 
   1062  const namedOutputOperand = {};
   1063  outputNameArray.forEach(
   1064      (name, index) => namedOutputOperand[name] = outputOperands[index]);
   1065 
   1066  // Compile the constructed graph.
   1067  const graph = await builder.build(namedOutputOperand);
   1068 
   1069  // Execute the compiled graph.
   1070  const result = await computeGraph(
   1071      context, graph, graphInputs, graphResources.expectedOutputs);
   1072 
   1073  return {result, intermediateOperands};
   1074 };
   1075 
   1076 const getGemmPrecisionTolerance =
   1077    (op, graphResources, intermediateOperands) => {
   1078  // GEMM : alpha * (A x B) + beta * C
   1079  // An upper bound for the worst serial ordering is bounded by
   1080  // the number of lossy operations, where matrix multiplication
   1081  // is a dot product (mul and add times the number of elements)
   1082  // plus bias operations.
   1083  const {inputs} = graphResources;
   1084  const args = op.arguments;
   1085  let ShapeA;
   1086  const indexA = args[0][Object.keys(args[0])[0]];
   1087  if (inputs[indexA]) {
   1088    ShapeA = inputs[indexA].descriptor.shape;
   1089  } else {
   1090    ShapeA = intermediateOperands[indexA].shape;
   1091  }
   1092  const options =
   1093      args.length === 3 ? {...args[2][Object.keys(args[2])[0]]} : {};
   1094  const width = options.aTranspose ? ShapeA[0] : ShapeA[1];
   1095  let tolerance = width * 2;
   1096  // default options.alpha is 1.0
   1097  if (options.alpha !== undefined && options.alpha !== 1.0) {
   1098    tolerance++;
   1099  }
   1100  if (options.c && options.beta !== 0.0) {
   1101    // default options.beta is 1.0
   1102    if (options.beta !== undefined && options.beta !== 1.0) {
   1103      tolerance++;
   1104    }
   1105    tolerance++;
   1106  }
   1107 
   1108  const toleranceValueDict = {float32: tolerance, float16: tolerance};
   1109  const expectedDataType =
   1110      getExpectedDataTypeOfSingleOutput(graphResources.expectedOutputs);
   1111  return {metricType: 'ULP', value: toleranceValueDict[expectedDataType]};
   1112 };
   1113 
   1114 const getMatmulPrecisionTolerance =
   1115    (op, graphResources, intermediateOperands) => {
   1116  const {inputs} = graphResources;
   1117  const args = op.arguments;
   1118  let shapeA;
   1119  const indexA = args[0][Object.keys(args[0])[0]];
   1120  if (inputs[indexA]) {
   1121    shapeA = inputs[indexA].descriptor.shape;
   1122  } else {
   1123    shapeA = intermediateOperands[indexA].shape;
   1124  }
   1125  const tolerance = shapeA[shapeA.length - 1] * 2;
   1126  const toleranceValueDict = {float32: tolerance, float16: tolerance};
   1127  const expectedDataType =
   1128      getExpectedDataTypeOfSingleOutput(graphResources.expectedOutputs);
   1129  return {metricType: 'ULP', value: toleranceValueDict[expectedDataType]};
   1130 };
   1131 
   1132 const getConv2dPrecisionTolerance =
   1133    (op, graphResources, intermediateOperands) => {
   1134  // number of reduced input elements multiplied by filter and summed (a sliding
   1135  // dot product like pooling)
   1136  const {inputs} = graphResources;
   1137  const operatorName = op.name;
   1138  const args = op.arguments;
   1139  let inputShape;
   1140  const inputIndex = args[0][Object.keys(args[0])[0]];
   1141  const filterIndex = args[1][Object.keys(args[1])[0]];
   1142  if (inputs[inputIndex]) {
   1143    inputShape = inputs[inputIndex].descriptor.shape;
   1144  } else {
   1145    inputShape = intermediateOperands[inputIndex].shape;
   1146  }
   1147  let filterShape;
   1148  if (inputs[filterIndex]) {
   1149    filterShape = inputs[filterIndex].descriptor.shape;
   1150  } else {
   1151    filterShape = intermediateOperands[filterIndex].shape;
   1152  }
   1153  const options =
   1154      args.length === 3 ? {...args[2][Object.keys(args[2])[0]]} : {};
   1155  let inputChannels = inputShape[1];  // default nchw inputLayout
   1156  // default oihw filterLayout for conv2d or default iohw filterLayout for
   1157  // convTranspose2d
   1158  let filterWidth = filterShape[3];
   1159  let filterHeight = filterShape[2];
   1160  const groups = options.groups ? options.groups : 1;
   1161 
   1162  if (options.inputLayout) {
   1163    if (!['nchw', 'nhwc'].includes(options.inputLayout)) {
   1164      throw new Error(`Unknown inputLayout ${options.inputLayout}`);
   1165    }
   1166    inputChannels =
   1167        options.inputLayout === 'nchw' ? inputChannels : inputShape[3];
   1168  }
   1169  if (options.filterLayout) {
   1170    let filterLayouts = ['oihw', 'hwio', 'ohwi', 'ihwo'];  // default for conv2d
   1171    if (operatorName === 'convTranspose2d') {
   1172      filterLayouts = ['iohw', 'hwoi', 'ohwi'];
   1173    }
   1174    if (!filterLayouts.includes(options.filterLayout)) {
   1175      throw new Error(`Unknown filterLayout ${options.filterLayout}`);
   1176    }
   1177    switch (options.filterLayout) {
   1178      case 'oihw':
   1179      case 'iohw':
   1180        // Just use the existing filterWidth and filterHeight above.
   1181        break;
   1182      case 'hwio':
   1183      case 'hwoi':
   1184        filterWidth = filterShape[1];
   1185        filterHeight = filterShape[0];
   1186        break;
   1187      case 'ohwi':
   1188      case 'ihwo':
   1189        filterWidth = filterShape[2];
   1190        filterHeight = filterShape[1];
   1191        break;
   1192      default:
   1193        break;
   1194    }
   1195  }
   1196 
   1197  const tolerance = filterWidth * filterHeight * (inputChannels / groups) * 2;
   1198  const toleranceValueDict = {float32: tolerance, float16: tolerance};
   1199  const expectedDataType =
   1200      getExpectedDataTypeOfSingleOutput(graphResources.expectedOutputs);
   1201  return {metricType: 'ULP', value: toleranceValueDict[expectedDataType]};
   1202 };
   1203 
   1204 const getPoolingOperatorsPrecisionTolerance =
   1205    (op, graphResources, intermediateOperands) => {
   1206  const args = op.arguments;
   1207  const operatorName = op.name;
   1208  const {inputs} = graphResources;
   1209  let inputShape;
   1210  const inputIndex = args[0][Object.keys(args[0])[0]];
   1211  if (inputs[inputIndex]) {
   1212    inputShape = inputs[inputIndex].descriptor.shape;
   1213  } else {
   1214    inputShape = intermediateOperands[inputIndex].shape;
   1215  }
   1216  const options =
   1217      args.length === 2 ? {...args[1][Object.keys(args[1])[0]]} : {};
   1218  let height;
   1219  let width;
   1220 
   1221  if (options.windowDimensions) {
   1222    height = options.windowDimensions[0];
   1223    width = options.windowDimensions[1];
   1224  } else {
   1225    // If not present, the window dimensions are assumed to be the height
   1226    // and width dimensions of the input shape
   1227    if (options.layout && options.layout === 'nhwc') {
   1228      height = inputShape[1];
   1229      width = inputShape[2];
   1230    } else {
   1231      // nhwc layout of input
   1232      height = inputShape[2];
   1233      width = inputShape[3];
   1234    }
   1235  }
   1236 
   1237  const tolerance = height * width + 2;
   1238  const toleranceDict = {
   1239    averagePool2d: {float32: tolerance, float16: tolerance},
   1240    l2Pool2d: {float32: tolerance, float16: tolerance},
   1241    maxPool2d: {float32: 0, float16: 0},
   1242  };
   1243  const expectedDataType =
   1244      getExpectedDataTypeOfSingleOutput(graphResources.expectedOutputs);
   1245  return {
   1246    metricType: 'ULP',
   1247    value: toleranceDict[operatorName][expectedDataType]
   1248  };
   1249 };
   1250 
   1251 const getInstanceNormPrecisionTolerance = (graphResources) => {
   1252  // according to
   1253  // https://github.com/web-platform-tests/wpt/pull/43891#discussion_r1457026316
   1254  const toleranceValueDict = {float32: 840, float16: 8400};
   1255  const expectedDataType =
   1256      getExpectedDataTypeOfSingleOutput(graphResources.expectedOutputs);
   1257  return {metricType: 'ULP', value: toleranceValueDict[expectedDataType]};
   1258 };
   1259 
   1260 const getExpectedDataTypeOfSingleOutput = (expectedOutput) => {
   1261  const expectedDescriptor =
   1262      expectedOutput[Object.keys(expectedOutput)[0]].descriptor;
   1263  const dataType = expectedDescriptor.castedType ?
   1264      expectedDescriptor.castedType :
   1265      expectedDescriptor.dataType;
   1266  return dataType;
   1267 };
   1268 
   1269 const getReductionOperatorsPrecisionTolerance =
   1270    (op, graphResources, intermediateOperands) => {
   1271      let tolerance;
   1272      const operatorName = op.name;
   1273      if (op.name === 'reduceMax' || op.name === 'reduceMin') {
   1274        tolerance = 0;
   1275      } else {
   1276        // other reduction operators
   1277        const args = op.arguments;
   1278        const {inputs} = graphResources;
   1279        let inputShape;
   1280        const inputIndex = args[0][Object.keys(args[0])[0]];
   1281        if (inputs[inputIndex]) {
   1282          inputShape = inputs[inputIndex].descriptor.shape;
   1283        } else {
   1284          inputShape = intermediateOperands[inputIndex].shape;
   1285        }
   1286 
   1287        const rank = inputShape.length;
   1288        const options =
   1289            args.length === 2 ? {...args[1][Object.keys(args[1])[0]]} : {};
   1290        let sizes;
   1291 
   1292        if (options && options.axes) {
   1293          sizes = options.axes.map(
   1294              (axis) => axis < 0 ? inputShape[axis + rank] : inputShape[axis]);
   1295        } else {
   1296          sizes = inputShape;
   1297        }
   1298 
   1299        const elementCount = sizes.reduce(
   1300            (accumulator, currentValue) => accumulator * currentValue, 1);
   1301        tolerance = elementCount;
   1302      }
   1303 
   1304      const toleranceDict = {
   1305        reduceL1: tolerance,
   1306        reduceL2: tolerance * 2 + 2,
   1307        reduceLogSum: tolerance + 18,
   1308        reduceLogSumExp: tolerance * 2 + 18,
   1309        reduceMax: tolerance,
   1310        reduceMean: tolerance + 2,
   1311        reduceMin: tolerance,
   1312        reduceProduct: tolerance,
   1313        reduceSum: tolerance,
   1314        reduceSumSquare: tolerance * 2
   1315      };
   1316      return {metricType: 'ULP', value: toleranceDict[operatorName]};
   1317    };
   1318 
   1319 const getResample2dPrecisionTolerance =
   1320    (op, graphResources, intermediateOperands) => {
   1321      const args = op.arguments;
   1322      const options =
   1323          args.length === 2 ? {...args[1][Object.keys(args[1])[0]]} : {};
   1324      const expectedOutputs = graphResources.expectedOutputs;
   1325      const dataType =
   1326          expectedOutputs[Object.keys(expectedOutputs)[0]].descriptor.dataType;
   1327      let tolerance;
   1328 
   1329      if (options.mode && options.mode === 'linear') {
   1330        // interpolation mode is linear
   1331        if (dataType === 'float32') {
   1332          tolerance = 84;
   1333        } else if (dataType === 'float16') {
   1334          tolerance = 10;
   1335        } else {
   1336          tolerance = 1;
   1337        }
   1338      } else {
   1339        // interpolation mode is nearest-neighbor
   1340        tolerance = 0;
   1341      }
   1342 
   1343      return {metricType: 'ULP', value: tolerance};
   1344    };
   1345 
// Minimum data type set parsed from minimum_datatype_set.json by
// getMinimumDataTypeSetJson(). Indexed as
// minimumDataTypeSet[operatorName][operandName], each entry holding the
// `dataTypes` and `rankRange` ({min, max}) that checkMinimum() reads.
let minimumDataTypeSet;
   1347 
   1348 function checkMinimum(descriptor, operandMinimumLimits) {
   1349  const targetRank = descriptor.shape.length;
   1350  const targetDataType = descriptor.dataType;
   1351  let isMinimum = operandMinimumLimits.dataTypes.includes(targetDataType);
   1352 
   1353  if (isMinimum) {
   1354    isMinimum = operandMinimumLimits.rankRange.min <= targetRank &&
   1355        targetRank <= operandMinimumLimits.rankRange.max;
   1356  }
   1357 
   1358  return isMinimum;
   1359 }
   1360 
   1361 function getOutputMinimumLimits(operatorsResources, outputOperandName) {
   1362  let operatorName;
   1363  let outputName;
   1364  for (let operator of operatorsResources) {
   1365    if (typeof operator.outputs === 'string' &&
   1366        operator.outputs === outputOperandName) {
   1367      operatorName = operator.name;
   1368      outputName = 'output';
   1369      break;
   1370    } else if (
   1371        Array.isArray(operator.outputs) &&
   1372        operator.outputs.includes(outputOperandName)) {
   1373      // Current gru, lstm, lstmCell and split operators have multiple outputs
   1374      operatorName = operator.name;
   1375      if (minimumDataTypeSet[operatorName].hasOwnProperty('outputs')) {
   1376        // for split operator
   1377        outputName = 'outputs';
   1378      } else {
   1379        // for gru, lstm, lstmCell operators
   1380        outputName = `output${operator.outputs.indexOf(outputOperandName)}`;
   1381      }
   1382      break;
   1383    }
   1384  }
   1385 
   1386  return minimumDataTypeSet[operatorName][outputName];
   1387 }
   1388 
   1389 async function getMinimumDataTypeSetJson() {
   1390  try {
   1391    const response = await fetch('/webnn/resources/minimum_datatype_set.json');
   1392 
   1393    if (!response.ok) {
   1394      throw new Error(`HTTP error! Status: ${response.status}`);
   1395    }
   1396 
   1397    const text = await response.text();
   1398    const jsonText =
   1399        text.replace(/\/\/.*|\/\*[\s\S]*?\*\//g, '');  // Remove comments
   1400    minimumDataTypeSet = JSON.parse(jsonText);
   1401  } catch (error) {
   1402    throw new Error(`Error fetching and parsing JSON: ${error.message}`);
   1403  }
   1404  return minimumDataTypeSet;
   1405 }
   1406 
   1407 function isMinimumTest(test) {
   1408  let isMinimum = false;
   1409  const graphResources = test.graph;
   1410  const inputsResources = graphResources.inputs;
   1411 
   1412  // check inputs
   1413  for (let operator of graphResources.operators) {
   1414    const minimumLimits = minimumDataTypeSet[operator.name];
   1415    for (let argument of operator.arguments) {
   1416      for (let [operandName, value] of Object.entries(argument)) {
   1417        if (operandName !== 'options') {
   1418          if (typeof value === 'string' &&
   1419              inputsResources.hasOwnProperty(value)) {
   1420            isMinimum = checkMinimum(
   1421                inputsResources[value].descriptor, minimumLimits[operandName]);
   1422            if (!isMinimum) {
   1423              return isMinimum;
   1424            }
   1425          } else if (Array.isArray(value)) {
   1426            for (let subValue of value) {
   1427              if (typeof subValue === 'string' &&
   1428                  inputsResources.hasOwnProperty(subValue)) {
   1429                isMinimum = checkMinimum(
   1430                    inputsResources[subValue].descriptor,
   1431                    minimumLimits[operandName]);
   1432                if (!isMinimum) {
   1433                  return isMinimum;
   1434                }
   1435              }
   1436            }
   1437          }
   1438        } else {
   1439          for (let [optionOperandName, optionValue] of Object.entries(
   1440                   argument['options'])) {
   1441            if (typeof value === 'string' &&
   1442                inputsResources.hasOwnProperty(optionValue)) {
   1443              isMinimum = checkMinimum(
   1444                  inputsResources[optionValue].descriptor,
   1445                  minimumLimits[optionOperandName]);
   1446              if (!isMinimum) {
   1447                return isMinimum;
   1448              }
   1449            }
   1450          }
   1451        }
   1452      }
   1453    }
   1454  }
   1455 
   1456  // check outputs
   1457  const outputsResources = graphResources.expectedOutputs;
   1458  for (let [outputOperandName, value] of Object.entries(outputsResources)) {
   1459    const outputMinimumLimits =
   1460        getOutputMinimumLimits(graphResources.operators, outputOperandName)
   1461    isMinimum = checkMinimum(value.descriptor, outputMinimumLimits);
   1462    if (!isMinimum) {
   1463      return isMinimum;
   1464    }
   1465  }
   1466 
   1467  return isMinimum;
   1468 }
   1469 
// Optional tests that the context does not support are collected here instead
// of being run. For debugging, type `testsToSkip` in the browser console after
// the tests have run to see exactly which tests were skipped.
const testsToSkip = [];
   1474 
   1475 async function webnn_conformance_test(
   1476    tests, buildAndExecuteGraphFunc, toleranceFunc) {
   1477  if (navigator.ml === undefined) {
   1478    test(() => assert_implements(navigator.ml, 'missing navigator.ml'));
   1479  } else {
   1480    const testsToRun = [];
   1481    promise_setup(async () => {
   1482      // Create a context for checking whether tests are supported.
   1483      const context = await getContext();
   1484      minimumDataTypeSet = await getMinimumDataTypeSetJson();
   1485      tests.filter(isTargetTest).forEach((test) => {
   1486        if (validateContextSupportsGraph(context, test.graph) ||
   1487            isMinimumTest(test)) {
   1488          testsToRun.push(test);
   1489        } else {
   1490          // This test is optional so it can be skipped.
   1491          testsToSkip.push(test);
   1492        }
   1493      });
   1494    });
   1495 
   1496    promise_test(async () => {
   1497      testsToRun.map((test) => {
   1498        promise_test(async () => {
   1499          // Create a context for each test.
   1500          const context = await getContext();
   1501          const builder = new MLGraphBuilder(context);
   1502          const {result, intermediateOperands} =
   1503              await buildAndExecuteGraphFunc(context, builder, test.graph);
   1504          assertResultsEquals(
   1505              toleranceFunc, result, test.graph, intermediateOperands);
   1506        }, `${isMinimumTest(test) ? '[required]' : '[optional]'} ${test.name}`);
   1507      });
   1508    });
   1509  }
   1510 }