tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

gather.https.any.js (4116B)


      1 // META: title=validation tests for WebNN API gather operation
      2 // META: global=window
      3 // META: variant=?cpu
      4 // META: variant=?gpu
      5 // META: variant=?npu
      6 // META: script=../resources/utils_validation.js
      7 
      8 'use strict';
      9 
     10 const label = 'gather_'
     11 const regrexp = new RegExp('\\[' + label + '\\]');
     12 const tests = [
     13  {
     14    name: '[gather] Test gather with default options and 0-D indices',
     15    input: {dataType: 'int32', shape: [3]},
     16    indices: {dataType: 'int64', shape: []},
     17    output: {dataType: 'int32', shape: []}
     18  },
     19  {
     20    name: '[gather] Test gather with axis = 2',
     21    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
     22    indices: {dataType: 'int64', shape: [5, 6]},
     23    axis: 2,
     24    output: {dataType: 'float32', shape: [1, 2, 5, 6, 4]}
     25  },
     26  {
     27    name: '[gather] Test gather with indices\'s dataType = uint32',
     28    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
     29    indices: {dataType: 'uint32', shape: [5, 6]},
     30    axis: 2,
     31    output: {dataType: 'float32', shape: [1, 2, 5, 6, 4]}
     32  },
     33  {
     34    name: '[gather] Test gather with indices\'s dataType = int32',
     35    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
     36    indices: {dataType: 'int32', shape: [5, 6]},
     37    axis: 2,
     38    output: {dataType: 'float32', shape: [1, 2, 5, 6, 4]}
     39  },
     40  {
     41    name: '[gather] TypeError is expected if the input is a scalar',
     42    input: {dataType: 'float16', shape: []},
     43    indices: {dataType: 'int64', shape: [1]},
     44  },
     45  {
     46    name:
     47        '[gather] TypeError is expected if the axis is greater than the rank of input',
     48    input: {dataType: 'float16', shape: [1, 2, 3]},
     49    indices: {dataType: 'int32', shape: [5, 6]},
     50    axis: 4,
     51  },
     52  {
     53    name:
     54        '[gather] TypeError is expected if the data type of indices is float32 which is invalid',
     55    input: {dataType: 'float16', shape: [1, 2, 3, 4]},
     56    indices: {dataType: 'float32', shape: [5, 6]},
     57  },
     58  {
     59    name:
     60        '[gather] TypeError is expected if the data type of indices is uint64 which is invalid',
     61    input: {dataType: 'float16', shape: [1, 2, 3, 4]},
     62    indices: {dataType: 'uint64', shape: [5, 6]},
     63  },
     64 ];
     65 
     66 tests.forEach(
     67    test => promise_test(async t => {
     68      const builder = new MLGraphBuilder(context);
     69      const input = builder.input('input', test.input);
     70      const indices = builder.input('indices', test.indices);
     71 
     72      const options = {};
     73      if (test.axis) {
     74        options.axis = test.axis;
     75      }
     76 
     77      if (test.output) {
     78        const output = builder.gather(input, indices, options);
     79        assert_equals(output.dataType, test.output.dataType);
     80        assert_array_equals(output.shape, test.output.shape);
     81      } else {
     82        options.label = label;
     83        assert_throws_with_label(
     84            () => builder.gather(input, indices, options), regrexp);
     85      }
     86    }, test.name));
     87 
     88 multi_builder_test(async (t, builder, otherBuilder) => {
     89  const inputFromOtherBuilder =
     90      otherBuilder.input('input', {dataType: 'float32', shape: [2, 2]});
     91 
     92  const indices = builder.input('indices', {dataType: 'int64', shape: [2, 2]});
     93  assert_throws_js(
     94      TypeError, () => builder.gather(inputFromOtherBuilder, indices));
     95 }, '[gather] throw if input is from another builder');
     96 
     97 multi_builder_test(async (t, builder, otherBuilder) => {
     98  const indicesFromOtherBuilder =
     99      otherBuilder.input('indices', {dataType: 'int64', shape: [2, 2]});
    100 
    101  const input = builder.input('input', {dataType: 'float32', shape: [2, 2]});
    102  assert_throws_js(
    103      TypeError, () => builder.gather(input, indicesFromOtherBuilder));
    104 }, '[gather] throw if indices is from another builder');
    105 
    106 promise_test(async t => {
    107  const builder = new MLGraphBuilder(context);
    108 
    109  const input = builder.input('input', {
    110      dataType: 'float32', shape: [1, 3, 3, 4]});
    111  const indices = builder.input('indices', {
    112      dataType: 'int32',
    113      shape: [context.opSupportLimits().maxTensorByteLength / 4] });
    114 
    115  const options = {};
    116  options.label = label;
    117  options.axis = 2;
    118  assert_throws_with_label(
    119      () => builder.gather(input, indices, options), regrexp);
    120 }, '[gather] throw if the output tensor byte length exceeds limit');