tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

elementwise-binary.https.any.js (3475B)


      1 // META: title=validation tests for WebNN API element-wise binary operations
      2 // META: global=window
      3 // META: variant=?op=add&device=cpu
      4 // META: variant=?op=add&device=gpu
      5 // META: variant=?op=add&device=npu
      6 // META: variant=?op=sub&device=cpu
      7 // META: variant=?op=sub&device=gpu
      8 // META: variant=?op=sub&device=npu
      9 // META: variant=?op=mul&device=cpu
     10 // META: variant=?op=mul&device=gpu
     11 // META: variant=?op=mul&device=npu
     12 // META: variant=?op=div&device=cpu
     13 // META: variant=?op=div&device=gpu
     14 // META: variant=?op=div&device=npu
     15 // META: variant=?op=max&device=cpu
     16 // META: variant=?op=max&device=gpu
     17 // META: variant=?op=max&device=npu
     18 // META: variant=?op=min&device=cpu
     19 // META: variant=?op=min&device=gpu
     20 // META: variant=?op=min&device=npu
     21 // META: variant=?op=pow&device=cpu
     22 // META: variant=?op=pow&device=gpu
     23 // META: variant=?op=pow&device=npu
     24 // META: script=../resources/utils_validation.js
     25 
     26 'use strict';
     27 
     28 const queryParams = new URLSearchParams(window.location.search);
     29 const operatorName = queryParams.get('op');
     30 
     31 const label = 'elementwise_binary_op';
     32 const regrexp = new RegExp('\\[' + label + '\\]');
     33 const tests = [
     34  {
     35    name: '[binary] Test bidirectionally broadcastable dimensions.',
     36    //  Both inputs have axes of length one which are expanded
     37    //  during broadcasting.
     38    a: {dataType: 'float32', shape: [8, 1, 6, 1]},
     39    b: {dataType: 'float32', shape: [7, 1, 5]},
     40    output: {dataType: 'float32', shape: [8, 7, 6, 5]}
     41  },
     42  {
     43    name: '[binary] Test unidirectionally broadcastable dimensions.',
     44    // Input a has a single axis of length one which is
     45    // expanded during broadcasting.
     46    a: {dataType: 'float32', shape: [4, 2, 1]},
     47    b: {dataType: 'float32', shape: [4]},
     48    output: {dataType: 'float32', shape: [4, 2, 4]}
     49  },
     50  {
     51    name: '[binary] Test scalar broadcasting.',
     52    a: {dataType: 'float32', shape: [4, 2, 4]},
     53    b: {dataType: 'float32', shape: []},
     54    output: {dataType: 'float32', shape: [4, 2, 4]}
     55  },
     56  {
     57    name: '[binary] Throw if the input shapes are not broadcastable.',
     58    a: {dataType: 'float32', shape: [4, 2]},
     59    b: {dataType: 'float32', shape: [4]},
     60  },
     61  {
     62    name: '[binary] Throw if the input types don\'t match.',
     63    a: {dataType: 'float32', shape: [4, 2]},
     64    b: {dataType: 'int32', shape: [1]},
     65  },
     66 ];
     67 
     68 tests.forEach(test => {
     69  promise_test(async t => {
     70    const builder = new MLGraphBuilder(context);
     71    if (!context.opSupportLimits().input.dataTypes.includes(
     72            test.a.dataType)) {
     73      assert_throws_js(TypeError, () => builder.input('a', test.a));
     74      return;
     75    }
     76    if (!context.opSupportLimits().input.dataTypes.includes(
     77            test.b.dataType)) {
     78      assert_throws_js(TypeError, () => builder.input('b', test.b));
     79      return;
     80    }
     81    const a = builder.input('a', test.a);
     82    const b = builder.input('b', test.b);
     83 
     84    if (test.output) {
     85      const output = builder[operatorName](a, b);
     86      assert_equals(output.dataType, test.output.dataType);
     87      assert_array_equals(output.shape, test.output.shape);
     88    } else {
     89      const options = {label};
     90      assert_throws_with_label(
     91          () => builder[operatorName](a, b, options), regrexp);
     92    }
     93  }, test.name.replace('[binary]', `[${operatorName}]`));
     94 });
     95 
     96 validateTwoInputsOfSameDataType(operatorName, label);
     97 validateTwoInputsBroadcastable(operatorName, label);
     98 validateTwoInputsFromMultipleBuilders(operatorName);
     99 validateTwoBroadcastableInputsTensorLimit(operatorName, label);