tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

softmax.https.any.js (2064B)


      1 // META: title=validation tests for WebNN API softmax operation
      2 // META: global=window
      3 // META: variant=?cpu
      4 // META: variant=?gpu
      5 // META: variant=?npu
      6 // META: script=../resources/utils_validation.js
      7 
      8 'use strict';
      9 
     10 const tests = [
     11  {
     12    name: '[softmax] Test building Softmax with float32 input.',
     13    input: {dataType: 'float32', shape: [4, 4, 3]},
     14    axis: 1,
     15    output: {dataType: 'float32', shape: [4, 4, 3]}
     16  },
     17  {
     18    name: '[softmax] Test building Softmax with float16 input.',
     19    input: {dataType: 'float16', shape: [3, 1, 5, 2]},
     20    axis: 2,
     21    output: {dataType: 'float16', shape: [3, 1, 5, 2]}
     22  },
     23  {
     24    name: '[softmax] Throw if the input is not a non-floating-point data.',
     25    input: {dataType: 'int32', shape: [3, 1, 5, 2]},
     26    axis: 3
     27  },
     28  {
     29    name: '[softmax] Throw if the axis is greater than input rank - 1.',
     30    input: {dataType: 'float16', shape: [3, 1, 5, 2]},
     31    axis: 4
     32  },
     33  {
     34    name: '[softmax] Throw if the input is a scalar.',
     35    input: {dataType: 'float32', shape: []},
     36    axis: 0
     37  }
     38 ];
     39 
     40 tests.forEach(
     41    test => promise_test(async t => {
     42      const builder = new MLGraphBuilder(context);
     43      let input = builder.input(`input`, test.input);
     44      if (test.output) {
     45        const output = builder.softmax(input, test.axis);
     46        assert_equals(output.dataType, test.output.dataType);
     47        assert_array_equals(output.shape, test.output.shape);
     48      } else {
     49        const label = 'softmax_xxx';
     50        const options = {label};
     51        const regrexp = new RegExp('\\[' + label + '\\]');
     52        assert_throws_with_label(
     53            () => builder.softmax(input, test.axis, options), regrexp);
     54      }
     55    }, test.name));
     56 
     57 multi_builder_test(async (t, builder, otherBuilder) => {
     58  const operandDescriptor = {dataType: 'float32', shape: [1, 2, 3]};
     59  const inputFromOtherBuilder = otherBuilder.input('input', operandDescriptor);
     60  const axis = 1;
     61 
     62  assert_throws_js(
     63      TypeError, () => builder.softmax(inputFromOtherBuilder, axis));
     64 }, '[softmax] throw if any input is from another builder');