tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

expand.https.any.js (4440B)


      1 // META: title=validation tests for WebNN API expand operation
      2 // META: global=window
      3 // META: variant=?cpu
      4 // META: variant=?gpu
      5 // META: variant=?npu
      6 // META: script=../resources/utils_validation.js
      7 
      8 'use strict';
      9 
     10 multi_builder_test(async (t, builder, otherBuilder) => {
     11  const inputFromOtherBuilder =
     12      otherBuilder.input('input', {dataType: 'float32', shape: [2, 1, 2]});
     13 
     14  const newShape = [2, 2, 2];
     15  assert_throws_js(
     16      TypeError, () => builder.expand(inputFromOtherBuilder, newShape));
     17 }, '[expand] throw if input is from another builder');
     18 
     19 const label = 'xxx_expand';
     20 const regexp = new RegExp('\\[' + label + '\\]');
     21 const tests = [
     22  {
     23    name: '[expand] Test with 0-D scalar to 3-D tensor.',
     24    input: {dataType: 'float32', shape: []},
     25    newShape: [3, 4, 5],
     26    output: {dataType: 'float32', shape: [3, 4, 5]}
     27  },
     28  {
     29    name: '[expand] Test with the new shapes that are the same as input.',
     30    input: {dataType: 'float32', shape: [4]},
     31    newShape: [4],
     32    output: {dataType: 'float32', shape: [4]}
     33  },
     34  {
     35    name: '[expand] Test with the new shapes that are broadcastable.',
     36    input: {dataType: 'float32', shape: [3, 1, 5]},
     37    newShape: [3, 4, 5],
     38    output: {dataType: 'float32', shape: [3, 4, 5]}
     39  },
     40  {
     41    name:
     42        '[expand] Test with the new shapes that are broadcastable and the rank of new shapes is larger than input.',
     43    input: {dataType: 'float32', shape: [2, 5]},
     44    newShape: [3, 2, 5],
     45    output: {dataType: 'float32', shape: [3, 2, 5]}
     46  },
     47  {
     48    name:
     49        '[expand] Throw if the input shapes are the same rank but not broadcastable.',
     50    input: {dataType: 'float32', shape: [3, 6, 2]},
     51    newShape: [4, 3, 5],
     52    options: {label}
     53  },
     54  {
     55    name: '[expand] Throw if the input shapes are not broadcastable.',
     56    input: {dataType: 'float32', shape: [5, 4]},
     57    newShape: [5],
     58    options: {label}
     59  },
     60  {
     61    name: '[expand] Throw if the number of new shapes is too large.',
     62    input: {dataType: 'float32', shape: [1, 2, 1, 1]},
     63    newShape: [1, 2, kMaxUnsignedLong, kMaxUnsignedLong],
     64  },
     65 ];
     66 
     67 tests.forEach(
     68    test => promise_test(async t => {
     69      const builder = new MLGraphBuilder(context);
     70      const input = builder.input('input', test.input);
     71 
     72      if (test.output) {
     73        const output = builder.expand(input, test.newShape);
     74        assert_equals(output.dataType, test.output.dataType);
     75        assert_array_equals(output.shape, test.output.shape);
     76      } else {
     77        const options = {...test.options};
     78        if (options.label) {
     79          assert_throws_with_label(
     80              () => builder.expand(input, test.newShape, options), regexp);
     81        } else {
     82          assert_throws_js(
     83              TypeError, () => builder.expand(input, test.newShape, options));
     84        }
     85      }
     86    }, test.name));
     87 
     88 promise_test(async t => {
     89  for (let dataType of allWebNNOperandDataTypes) {
     90    if (!context.opSupportLimits().input.dataTypes.includes(dataType)) {
     91      continue;
     92    }
     93    const builder = new MLGraphBuilder(context);
     94    const shape = [1];
     95    const newShape = [1, 2, 3];
     96    const input = builder.input(`input`, {dataType, shape});
     97    if (context.opSupportLimits().expand.input.dataTypes.includes(dataType)) {
     98      const output = builder.expand(input, newShape);
     99      assert_equals(output.dataType, dataType);
    100      assert_array_equals(output.shape, newShape);
    101    } else {
    102      assert_throws_js(TypeError, () => builder.expand(input, newShape));
    103    }
    104  }
    105 }, `[expand] Test expand with all of the data types.`);
    106 
    107 promise_test(async t => {
    108  const builder = new MLGraphBuilder(context);
    109 
    110  const input = builder.input('input', {
    111      dataType: 'float32', shape: [1, 2, 1, 1]});
    112  const newShape = [1, 2, context.opSupportLimits().maxTensorByteLength, 1];
    113 
    114  const options = {label};
    115  assert_throws_with_label(
    116      () => builder.expand(input, newShape, options), regexp);
    117 }, '[expand] throw if the output tensor byte length exceeds limit');
    118 
    119 promise_test(async t => {
    120  const builder = new MLGraphBuilder(context);
    121 
    122  const input = builder.input('input', {dataType: 'float32', shape: [2]});
    123  const newShape =
    124      new Array(context.opSupportLimits().expand.output.rankRange.max + 1)
    125          .fill(1);
    126  newShape[newShape.length - 1] = 2;
    127 
    128  const options = {label};
    129  assert_throws_with_label(
    130      () => builder.expand(input, newShape, options), regexp);
    131 }, '[expand] throw if new shape rank exceeds limit');