tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

instanceNormalization.https.any.js (6768B)


      1 // META: title=validation tests for WebNN API instanceNormalization operation
      2 // META: global=window
      3 // META: variant=?cpu
      4 // META: variant=?gpu
      5 // META: variant=?npu
      6 // META: script=../resources/utils_validation.js
      7 
      8 'use strict';
      9 
     10 const kExampleInputDescriptor = {
     11  dataType: 'float32',
     12  shape: [2, 2, 2, 2]
     13 };
     14 // 1D tensor descriptor which may be used for `scale`, or `bias` inputs.
     15 const kExample1DTensorDescriptor = {
     16  dataType: 'float32',
     17  shape: [2]
     18 };
     19 
     20 multi_builder_test(async (t, builder, otherBuilder) => {
     21  const inputFromOtherBuilder =
     22      otherBuilder.input('input', kExampleInputDescriptor);
     23 
     24  assert_throws_js(
     25      TypeError, () => builder.instanceNormalization(inputFromOtherBuilder));
     26 }, '[instanceNormalization] throw if input is from another builder');
     27 
     28 multi_builder_test(async (t, builder, otherBuilder) => {
     29  const scaleFromOtherBuilder =
     30      otherBuilder.input('scale', kExample1DTensorDescriptor);
     31  const options = {scale: scaleFromOtherBuilder};
     32 
     33  const input = builder.input('input', kExampleInputDescriptor);
     34  assert_throws_js(
     35      TypeError, () => builder.instanceNormalization(input, options));
     36 }, '[instanceNormalization] throw if scale option is from another builder');
     37 
     38 multi_builder_test(async (t, builder, otherBuilder) => {
     39  const biasFromOtherBuilder =
     40      otherBuilder.input('bias', kExample1DTensorDescriptor);
     41  const options = {bias: biasFromOtherBuilder};
     42 
     43  const input = builder.input('input', kExampleInputDescriptor);
     44  assert_throws_js(
     45      TypeError, () => builder.instanceNormalization(input, options));
     46 }, '[instanceNormalization] throw if bias option is from another builder');
     47 
     48 const label = 'instance_normalization';
     49 const tests = [
     50  {
     51    name: '[instanceNormalization] Test with default options for 4-D input.',
     52    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
     53    output: {dataType: 'float32', shape: [1, 2, 3, 4]}
     54  },
     55  {
     56    name:
     57        '[instanceNormalization] Test with scale, bias and default epsilon value.',
     58    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
     59    options: {
     60      scale: {dataType: 'float32', shape: [2]},
     61      bias: {dataType: 'float32', shape: [2]},
     62      epsilon: 1e-5,
     63    },
     64    output: {dataType: 'float32', shape: [1, 2, 3, 4]}
     65  },
     66  {
     67    name: '[instanceNormalization] Test with a non-default epsilon value.',
     68    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
     69    options: {
     70      epsilon: 1e-4,
     71    },
     72    output: {dataType: 'float32', shape: [1, 2, 3, 4]}
     73  },
     74  {
     75    name: '[instanceNormalization] Test with layout=nhwc.',
     76    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
     77    options: {
     78      layout: 'nhwc',
     79      scale: {dataType: 'float32', shape: [4]},
     80      bias: {dataType: 'float32', shape: [4]},
     81    },
     82    output: {dataType: 'float32', shape: [1, 2, 3, 4]}
     83  },
     84  {
     85    name: '[instanceNormalization] Test when the input data type is float16.',
     86    input: {dataType: 'float16', shape: [1, 2, 3, 4]},
     87    output: {dataType: 'float16', shape: [1, 2, 3, 4]},
     88    options: {label}
     89  },
     90  {
     91    name: '[instanceNormalization] Throw if the input is not a 4-D tensor.',
     92    input: {dataType: 'float32', shape: [1, 2, 5, 5, 2]},
     93    options: {label}
     94  },
     95  {
     96    name:
     97        '[instanceNormalization] Throw if the input data type is not one of floating point types.',
     98    input: {dataType: 'int32', shape: [1, 2, 5, 5]},
     99    options: {label}
    100  },
    101  {
    102    name:
    103        '[instanceNormalization] Throw if the scale data type is not the same as the input data type.',
    104    input: {dataType: 'float16', shape: [1, 2, 5, 5]},
    105    options: {
    106      scale: {dataType: 'float32', shape: [2]},
    107      label: label,
    108    },
    109  },
    110  {
    111    name:
    112        '[instanceNormalization] Throw if the scale operand is not a 1-D tensor.',
    113    input: {dataType: 'float32', shape: [1, 2, 5, 5]},
    114    options: {
    115      scale: {dataType: 'float32', shape: [2, 1]},
    116      label: label,
    117    },
    118  },
    119  {
    120    name:
    121        '[instanceNormalization] Throw if the size of scale operand is not equal to the size of the feature dimension of the input with layout=nhwc.',
    122    input: {dataType: 'float32', shape: [1, 2, 5, 5]},
    123    options: {
    124      layout: 'nhwc',
    125      scale: {dataType: 'float32', shape: [2]},
    126      label: label,
    127    },
    128  },
    129  {
    130    name:
    131        '[instanceNormalization] Throw if the size of scale operand is not equal to the size of the feature dimension of the input with layout=nchw.',
    132    input: {dataType: 'float32', shape: [1, 5, 5, 2]},
    133    options: {
    134      layout: 'nchw',
    135      scale: {dataType: 'float32', shape: [2]},
    136      label: label,
    137    },
    138  },
    139  {
    140    name:
    141        '[instanceNormalization] Throw if the bias data type is not the same as the input data type.',
    142    input: {dataType: 'float16', shape: [1, 2, 5, 5]},
    143    options: {
    144      bias: {dataType: 'float32', shape: [2]},
    145      label: label,
    146    },
    147  },
    148  {
    149    name:
    150        '[instanceNormalization] Throw if the bias operand is not a 1-D tensor.',
    151    input: {dataType: 'float32', shape: [1, 2, 5, 5]},
    152    options: {
    153      scale: {dataType: 'float32', shape: [2, 1]},
    154      label: label,
    155    },
    156  },
    157  {
    158    name:
    159        '[instanceNormalization] Throw if the size of bias operand is not equal to the size of the feature dimension of the input with layout=nhwc.',
    160    input: {dataType: 'float32', shape: [1, 2, 5, 5]},
    161    options: {
    162      layout: 'nhwc',
    163      bias: {dataType: 'float32', shape: [2]},
    164      label: label,
    165    },
    166  },
    167  {
    168    name:
    169        '[instanceNormalization] Throw if the size of bias operand is not equal to the size of the feature dimension of the input with layout=nchw.',
    170    input: {dataType: 'float32', shape: [1, 5, 5, 2]},
    171    options: {
    172      layout: 'nchw',
    173      bias: {dataType: 'float32', shape: [2]},
    174      label: label,
    175    },
    176  },
    177 ];
    178 
    179 tests.forEach(
    180    test => promise_test(async t => {
    181      const builder = new MLGraphBuilder(context);
    182      const input = builder.input('input', test.input);
    183 
    184      if (test.options && test.options.bias) {
    185        test.options.bias = builder.input('bias', test.options.bias);
    186      }
    187      if (test.options && test.options.scale) {
    188        test.options.scale = builder.input('scale', test.options.scale);
    189      }
    190 
    191      if (test.output &&
    192          context.opSupportLimits()
    193              .instanceNormalization.input.dataTypes.includes(
    194                  test.input.dataType)) {
    195        const output = builder.instanceNormalization(input, test.options);
    196        assert_equals(output.dataType, test.output.dataType);
    197        assert_array_equals(output.shape, test.output.shape);
    198      } else {
    199        const regrexp = new RegExp('\\[' + label + '\\]');
    200        assert_throws_with_label(
    201            () => builder.instanceNormalization(input, test.options), regrexp);
    202      }
    203    }, test.name));