tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

reduction.https.any.js (2976B)


      1 // META: title=validation tests for WebNN API reduction operation
      2 // META: global=window
      3 // META: variant=?cpu
      4 // META: variant=?gpu
      5 // META: variant=?npu
      6 // META: script=../resources/utils_validation.js
      7 
      8 'use strict';
      9 
     10 const kReductionOperators = [
     11  'reduceL1',
     12  'reduceL2',
     13  'reduceLogSum',
     14  'reduceLogSumExp',
     15  'reduceMax',
     16  'reduceMean',
     17  'reduceMin',
     18  'reduceProduct',
     19  'reduceSum',
     20  'reduceSumSquare',
     21 ];
     22 
     23 const label = 'reduce_op_xxx';
     24 
     25 const allReductionOperatorsTests = [
     26  {
     27    name: '[reduce] Test reduce with keepDimensions=true.',
     28    input: {dataType: 'float32', shape: [1, 3, 4, 4]},
     29    options: {
     30      keepDimensions: true,
     31    },
     32    output: {dataType: 'float32', shape: [1, 1, 1, 1]}
     33  },
     34  {
     35    name: '[reduce] Test reduce with axes=[0, 1] and keep_dimensions=false.',
     36    input: {dataType: 'float32', shape: [1, 3, 5, 5]},
     37    options: {axes: [0, 1]},
     38    output: {dataType: 'float32', shape: [5, 5]}
     39  },
     40  {
     41    name: '[reduce] Throw if a value in axes is out of range of [0, N-1].',
     42    input: {dataType: 'float32', shape: [1, 2, 5, 5]},
     43    options: {
     44      axes: [4],
     45      label: label,
     46    },
     47  },
     48  {
     49    name: '[reduce] Throw if the two values are same in axes sequence.',
     50    input: {dataType: 'float32', shape: [1, 2, 5, 5]},
     51    options: {
     52      axes: [0, 1, 1],
     53      label: label,
     54    },
     55  },
     56 ];
     57 
     58 function runReductionTests(operatorName, tests) {
     59  tests.forEach(test => {
     60    promise_test(async t => {
     61      const builder = new MLGraphBuilder(context);
     62      const input = builder.input('input', test.input);
     63 
     64      if (test.output) {
     65        const output = builder[operatorName](input, test.options);
     66        assert_equals(output.dataType, test.output.dataType);
     67        assert_array_equals(output.shape, test.output.shape);
     68      } else {
     69        const regrexp = new RegExp('\\[' + label + '\\]');
     70        assert_throws_with_label(
     71            () => builder[operatorName](input, test.options), regrexp);
     72      }
     73    }, test.name.replace('[reduce]', `[${operatorName}]`));
     74  });
     75 }
     76 
     77 kReductionOperators.forEach((operatorName) => {
     78  validateInputFromAnotherBuilder(operatorName);
     79  runReductionTests(operatorName, allReductionOperatorsTests);
     80 });
     81 
     82 kReductionOperators.forEach((operatorName) => {
     83  promise_test(async t => {
     84    for (let dataType of allWebNNOperandDataTypes) {
     85      if (!context.opSupportLimits().input.dataTypes.includes(dataType)) {
     86        continue;
     87      }
     88      const builder = new MLGraphBuilder(context);
     89      const input = builder.input(`input`, {dataType, shape: shape3D});
     90      if (context.opSupportLimits()[operatorName].input.dataTypes.includes(
     91              dataType)) {
     92        const output = builder[operatorName](input);
     93        assert_equals(output.dataType, dataType);
     94        assert_array_equals(output.shape, []);
     95      } else {
     96        assert_throws_js(TypeError, () => builder[operatorName](input));
     97      }
     98    }
     99  }, `[${operatorName}] Test reduce with all of the data types.`);
    100 });