tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

argMinMax.https.any.js (3287B)


      1 // META: title=validation tests for WebNN API argMin/Max operations
      2 // META: global=window
      3 // META: variant=?cpu
      4 // META: variant=?gpu
      5 // META: variant=?npu
      6 // META: script=../resources/utils_validation.js
      7 
      8 'use strict';
      9 
     10 const kArgMinMaxOperators = [
     11  'argMin',
     12  'argMax',
     13 ];
     14 
     15 const label = 'arg_min_max_1_!';
     16 
     17 const tests = [
     18  {
     19    name: '[argMin/Max] Test with default options.',
     20    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
     21    axis: 0,
     22    output: {shape: [2, 3, 4]}
     23  },
     24  {
     25    name: '[argMin/Max] Test with axes=1.',
     26    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
     27    axis: 1,
     28    output: {shape: [1, 3, 4]}
     29  },
     30  {
     31    name: '[argMin/Max] Test with outputDataType=int32',
     32    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
     33    axis: 1,
     34    options: {
     35      outputDataType: 'int32',
     36    },
     37    output: {shape: [1, 3, 4]}
     38  },
     39  {
     40    name: '[argMin/Max] Test with outputDataType=int64',
     41    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
     42    axis: 1,
     43    options: {
     44      outputDataType: 'int64',
     45    },
     46    output: {shape: [1, 3, 4]}
     47  },
     48  {
     49    name:
     50        '[argMin/Max] Throw if the value in axis is greater than or equal to input rank.',
     51    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
     52    axis: 4,
     53    options: {
     54      label: label,
     55    },
     56  },
     57  {
     58    name: '[argMin/Max] Throw if input is a scalar and axis=0.',
     59    input: {dataType: 'float32', shape: []},
     60    axis: 0,
     61    options: {
     62      label: label,
     63    },
     64  },
     65  {
     66    name: '[argMin/Max] Throw if outputDataType=float32',
     67    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
     68    axis: 1,
     69    options: {outputDataType: 'float32', label: label}
     70  }
     71 ];
     72 
     73 function runTests(operatorName, tests) {
     74  tests.forEach(test => {
     75    promise_test(async t => {
     76      const builder = new MLGraphBuilder(context);
     77      const input = builder.input('input', test.input);
     78      const axis = test.axis;
     79      if (!context.opSupportLimits()[operatorName].input.dataTypes.includes(test.input.dataType)){
     80        assert_throws_js(
     81          TypeError, () => builder[operatorName](input, axis, test.options));
     82        return;
     83      }
     84      if (test.options && test.options.outputDataType !== undefined) {
     85        if (context.opSupportLimits()[operatorName].output.dataTypes.includes(
     86          test.options.outputDataType)) {
     87          const output = builder[operatorName](input, axis, test.options);
     88          assert_equals(output.dataType, test.options.outputDataType);
     89          assert_array_equals(output.shape, test.output.shape);
     90        } else {
     91          assert_throws_js(
     92            TypeError, () => builder[operatorName](input, axis, test.options));
     93        }
     94        return;
     95      }
     96      if (test.output) {
     97        const output = builder[operatorName](input, axis, test.options);
     98        assert_equals(output.dataType, 'int32');
     99        assert_array_equals(output.shape, test.output.shape);
    100      } else {
    101        const regrexp = /\[arg_min_max_1_\!\]/;
    102        assert_throws_with_label(
    103            () => builder[operatorName](input, axis, test.options), regrexp);
    104      }
    105    }, test.name.replace('[argMin/Max]', `[${operatorName}]`));
    106  });
    107 }
    108 
    109 kArgMinMaxOperators.forEach((operatorName) => {
    110  validateInputFromAnotherBuilder(operatorName);
    111  runTests(operatorName, tests);
    112 });