tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

gatherND.https.any.js (3407B)


      1 // META: title=validation tests for WebNN API gatherND operation
      2 // META: global=window
      3 // META: variant=?cpu
      4 // META: variant=?gpu
      5 // META: variant=?npu
      6 // META: script=../resources/utils_validation.js
      7 
      8 'use strict';
      9 
     10 const label = 'gatherND_';
     11 const regexp = new RegExp('\\[' + label + '\\]');
     12 const tests = [
     13  {
     14    name: '[gatherND] Test gatherND with 5D input 3D indices',
     15    input: {dataType: 'float32', shape: [2, 2, 3, 3, 4]},
     16    indices: {dataType: 'int32', shape: [5, 4, 3]},
     17    output: {dataType: 'float32', shape: [5, 4, 3, 4]}
     18  },
     19  {
     20    name: '[gatherND] Throw if input is a scalar',
     21    input: {dataType: 'float32', shape: []},
     22    indices: {dataType: 'int32', shape: [1, 1, 1]}
     23  },
     24  {
     25    name: '[gatherND] Throw if indices is a scalar',
     26    input: {dataType: 'float32', shape: [1, 1, 1]},
     27    indices: {dataType: 'int32', shape: []}
     28  },
     29  {
     30    name: '[gatherND] Throw if indices data type is float32',
     31    input: {dataType: 'float32', shape: [1, 2, 3]},
     32    indices: {dataType: 'float32', shape: [1, 1, 1]},
     33  },
     34  {
     35    name:
     36        '[gatherND] Throw if indices.shape[-1] is greater than the input rank',
     37    input: {dataType: 'float32', shape: [1, 2, 3]},
     38    indices: {dataType: 'int32', shape: [1, 1, 4]}
     39  }
     40 ];
     41 
     42 tests.forEach(test => promise_test(async t => {
     43                const builder = new MLGraphBuilder(context);
     44                const input = builder.input('input', test.input);
     45                const indices = builder.input('indices', test.indices);
     46 
     47                if (test.output &&
     48                    context.opSupportLimits().gatherND.input.dataTypes.includes(
     49                        test.input.dataType)) {
     50                  const output = builder.gatherND(input, indices);
     51                  assert_equals(output.dataType, test.output.dataType);
     52                  assert_array_equals(output.shape, test.output.shape);
     53                } else {
     54                  const options = {label: label};
     55                  assert_throws_with_label(
     56                      () => builder.gatherND(input, indices, options), regexp);
     57                }
     58              }, test.name));
     59 
     60 multi_builder_test(async (t, builder, otherBuilder) => {
     61  const inputFromOtherBuilder =
     62      otherBuilder.input('input', {dataType: 'float32', shape: [2, 2]});
     63 
     64  const indices = builder.input('indices', {dataType: 'int32', shape: [2, 1]});
     65  assert_throws_js(
     66      TypeError, () => builder.gatherND(inputFromOtherBuilder, indices));
     67 }, '[gatherND] Throw if input is from another builder');
     68 
     69 multi_builder_test(async (t, builder, otherBuilder) => {
     70  const indicesFromOtherBuilder =
     71      otherBuilder.input('indices', {dataType: 'int32', shape: [2, 2]});
     72 
     73  const input = builder.input('input', {dataType: 'float32', shape: [2, 1]});
     74  assert_throws_js(
     75      TypeError, () => builder.gatherND(input, indicesFromOtherBuilder));
     76 }, '[gatherND] Throw if indices is from another builder');
     77 
     78 promise_test(async t => {
     79  const builder = new MLGraphBuilder(context);
     80 
     81  const input = builder.input('input', {
     82      dataType: 'float32', shape: [2, 2, 3, 3, 4]});
     83  const indices = builder.input('indices', {
     84    dataType: 'int32',
     85    shape: [context.opSupportLimits().maxTensorByteLength / 4, 1, 1]});
     86 
     87  const options = {label};
     88  assert_throws_with_label(
     89      () => builder.gatherND(input, indices, options), regexp);
     90 }, '[gatherND] throw if the output tensor byte length exceeds limit');