tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

scatterElements.https.any.js (5486B)


      1 // META: title=validation tests for WebNN API scatterElements operation
      2 // META: global=window
      3 // META: variant=?cpu
      4 // META: variant=?gpu
      5 // META: variant=?npu
      6 // META: script=../resources/utils_validation.js
      7 
      8 'use strict';
      9 
     10 const tests = [
     11  {
     12    name: '[scatterElements] Test scatterElements with default options',
     13    input: {dataType: 'float32', shape: [3, 3]},
     14    indices: {dataType: 'int32', shape: [2, 3]},
     15    updates: {dataType: 'float32', shape: [2, 3]},
     16    output: {dataType: 'float32', shape: [3, 3]}
     17  },
     18  {
     19    name: '[scatterElements] Test scatterElements with axis = 0',
     20    input: {dataType: 'float32', shape: [3, 3]},
     21    indices: {dataType: 'int32', shape: [2, 3]},
     22    updates: {dataType: 'float32', shape: [2, 3]},
     23    axis: 0,
     24    output: {dataType: 'float32', shape: [3, 3]}
     25  },
     26  {
     27    name: '[scatterElements] Test scatterElements with axis = 1',
     28    input: {dataType: 'float32', shape: [3, 3]},
     29    indices: {dataType: 'int32', shape: [3, 2]},
     30    updates: {dataType: 'float32', shape: [3, 2]},
     31    axis: 1,
     32    output: {dataType: 'float32', shape: [3, 3]}
     33  },
     34  {
     35    name: '[scatterElements] Throw if axis is greater than input rank',
     36    input: {dataType: 'float32', shape: [3, 3]},
     37    indices: {dataType: 'int32', shape: [2, 3]},
     38    updates: {dataType: 'float32', shape: [2, 3]},
     39    axis: 2
     40  },
     41  {
     42    name:
     43        '[scatterElements] Throw if updates tensor data type is not the same as input data type',
     44    input: {dataType: 'float32', shape: [3, 3]},
     45    indices: {dataType: 'int32', shape: [2, 3]},
     46    updates: {dataType: 'float16', shape: [2, 3]},
     47  },
     48  {
     49    name: '[scatterElements] Throw if input, indices and updates are scalar',
     50    input: {dataType: 'float32', shape: []},
     51    indices: {dataType: 'int32', shape: []},
     52    updates: {dataType: 'float32', shape: []},
     53  },
     54  {
     55    name:
     56        '[scatterElements] Throw if indices rank is not the same as input rank',
     57    input: {dataType: 'float32', shape: [3, 3]},
     58    indices: {dataType: 'int32', shape: [2, 3, 3]},
     59    updates: {dataType: 'float32', shape: [2, 3, 3]},
     60  },
     61  {
     62    name:
     63        '[scatterElements] Throw if indices size is not the same as input size along axis 1',
     64    input: {dataType: 'float32', shape: [3, 3]},
     65    indices: {dataType: 'int32', shape: [2, 4]},
     66    updates: {dataType: 'float32', shape: [2, 4]},
     67    axis: 0
     68  },
     69  {
     70    name:
     71        '[scatterElements] Throw if indices size is not the same as input size along axis 0',
     72    input: {dataType: 'float32', shape: [3, 3]},
     73    indices: {dataType: 'int32', shape: [2, 2]},
     74    updates: {dataType: 'float32', shape: [2, 2]},
     75    axis: 1
     76  },
     77  {
     78    name:
     79        '[scatterElements] Throw if indices rank is not the same as updates rank',
     80    input: {dataType: 'float32', shape: [3, 3]},
     81    indices: {dataType: 'int32', shape: [2, 3]},
     82    updates: {dataType: 'float32', shape: [2, 3, 3]},
     83  },
     84  {
     85    name:
     86        '[scatterElements] Throw if indices shape is not the same as updates shape',
     87    input: {dataType: 'float32', shape: [3, 3]},
     88    indices: {dataType: 'int32', shape: [2, 3]},
     89    updates: {dataType: 'float32', shape: [2, 4]},
     90  }
     91 ];
     92 
     93 tests.forEach(
     94    test => promise_test(async t => {
     95      const builder = new MLGraphBuilder(context);
     96      const input = builder.input('input', test.input);
     97      const indices = builder.input('indices', test.indices);
     98      const updates = builder.input('updates', test.updates);
     99 
    100      const options = {};
    101      if (test.axis) {
    102        options.axis = test.axis;
    103      }
    104 
    105      if (test.output) {
    106        const output =
    107            builder.scatterElements(input, indices, updates, options);
    108        assert_equals(output.dataType, test.output.dataType);
    109        assert_array_equals(output.shape, test.output.shape);
    110      } else {
    111        const label = 'a_scatter_elements'
    112        options.label = label;
    113        const regexp = new RegExp('\\[' + label + '\\]');
    114        assert_throws_with_label(
    115            () => builder.scatterElements(input, indices, updates, options),
    116            regexp);
    117      }
    118    }, test.name));
    119 
    120 multi_builder_test(async (t, builder, otherBuilder) => {
    121  const input =
    122      otherBuilder.input('input', {dataType: 'float32', shape: [3, 3]});
    123  const indices = builder.input('indices', {dataType: 'int32', shape: [2, 3]});
    124  const updates =
    125      builder.input('updates', {dataType: 'float32', shape: [2, 3]});
    126 
    127  assert_throws_js(
    128      TypeError, () => builder.scatterElements(input, indices, updates));
    129 }, '[scatterElements] Throw if input is from another builder');
    130 
    131 multi_builder_test(async (t, builder, otherBuilder) => {
    132  const input = builder.input('input', {dataType: 'float32', shape: [3, 3]});
    133  const indices =
    134      otherBuilder.input('indices', {dataType: 'int32', shape: [2, 3]});
    135  const updates =
    136      builder.input('updates', {dataType: 'float32', shape: [2, 3]});
    137 
    138  assert_throws_js(
    139      TypeError, () => builder.scatterElements(input, indices, updates));
    140 }, '[scatterElements] Throw if indices is from another builder');
    141 
    142 multi_builder_test(async (t, builder, otherBuilder) => {
    143  const input = builder.input('input', {dataType: 'float32', shape: [3, 3]});
    144  const indices = builder.input('indices', {dataType: 'int32', shape: [2, 3]});
    145  const updates =
    146      otherBuilder.input('updates', {dataType: 'float32', shape: [2, 3]});
    147 
    148  assert_throws_js(
    149      TypeError, () => builder.scatterElements(input, indices, updates));
    150 }, '[scatterElements] Throw if updates is from another builder');