tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

prelu.https.any.js (3158B)


      1 // META: title=validation tests for WebNN API prelu operation
      2 // META: global=window
      3 // META: variant=?cpu
      4 // META: variant=?gpu
      5 // META: variant=?npu
      6 // META: script=../resources/utils_validation.js
      7 
'use strict';

// Shared validation for two-input operators. The helper is not defined in this
// file; presumably it comes from ../resources/utils_validation.js (loaded by the
// META script line). Judging by its name it checks operands created by
// different MLGraphBuilder instances — confirm in utils_validation.js.
validateTwoInputsFromMultipleBuilders('prelu');
     11 
     12 const tests = [
     13  {
     14    name:
     15        '[prelu] Test slope\'s shape = [3, 2, 5] which is the same as input\'s shape.',
     16    input: {dataType: 'float32', shape: [3, 2, 5]},
     17    slope: {dataType: 'float32', shape: [3, 2, 5]},
     18    output: {dataType: 'float32', shape: [3, 2, 5]},
     19  },
     20  {
     21    name:
     22        '[prelu] Test slope\'s shape = [5] which is unidirectionally broadcastable to input\'s shape.',
     23    input: {dataType: 'float32', shape: [3, 2, 5]},
     24    slope: {dataType: 'float32', shape: [5]},
     25    output: {dataType: 'float32', shape: [3, 2, 5]},
     26  },
     27  {
     28    name:
     29        '[prelu] Test slope\'s shape = [] which is unidirectionally broadcastable to input\'s shape.',
     30    input: {dataType: 'float32', shape: [3, 2, 5]},
     31    slope: {dataType: 'float32', shape: []},
     32    output: {dataType: 'float32', shape: [3, 2, 5]},
     33  },
     34  {
     35    name:
     36        '[prelu] Test slope\'s shape = [2, 5] which is unidirectionally broadcastable to input\'s shape.',
     37    input: {dataType: 'float32', shape: [3, 2, 5]},
     38    slope: {dataType: 'float32', shape: [2, 5]},
     39    output: {dataType: 'float32', shape: [3, 2, 5]},
     40  },
     41  {
     42    name:
     43        '[prelu] Throw if the shape of slope is not broadcastable to the shape of input.',
     44    input: {dataType: 'float32', shape: [3, 2, 5]},
     45    slope: {dataType: 'float32', shape: [2]},
     46  },
     47  {
     48    name:
     49        '[prelu] Throw if the data type of slope does not match the data type of input.',
     50    input: {dataType: 'float32', shape: [3, 2, 5]},
     51    slope: {dataType: 'int32', shape: [3, 2, 5]},
     52  },
     53 ];
     54 
     55 tests.forEach(
     56    test => promise_test(async t => {
     57      const builder = new MLGraphBuilder(context);
     58      const input = builder.input('input', test.input);
     59      const slope = builder.input('input', test.slope);
     60      if (test.output) {
     61        const output = builder.prelu(input, slope);
     62        assert_equals(output.dataType, test.output.dataType);
     63        assert_array_equals(output.shape, test.output.shape);
     64      } else {
     65        const label = 'prelu_123';
     66        const options = {label};
     67        const regrexp = new RegExp('\\[' + label + '\\]');
     68        assert_throws_with_label(
     69            () => builder.prelu(input, slope, options), regrexp);
     70      }
     71    }, test.name));
     72 
     73 promise_test(async t => {
     74  for (let dataType of allWebNNOperandDataTypes) {
     75    if (!context.opSupportLimits().input.dataTypes.includes(dataType)) {
     76      continue;
     77    }
     78    const builder = new MLGraphBuilder(context);
     79    const shape = [1];
     80    const input = builder.input(`input`, {dataType, shape});
     81    if (context.opSupportLimits().prelu.input.dataTypes.includes(dataType)) {
     82      const output = builder.prelu(input, input);
     83      assert_equals(output.dataType, dataType);
     84      assert_array_equals(output.shape, shape);
     85    } else {
     86      assert_throws_js(TypeError, () => builder.prelu(input, input));
     87    }
     88  }
     89 }, `[prelu] Test prelu with all of the data types.`);