tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

gruCell.https.any.js (15302B)


      1 // META: title=validation tests for WebNN API gruCell operation
      2 // META: global=window
      3 // META: variant=?cpu
      4 // META: variant=?gpu
      5 // META: variant=?npu
      6 // META: script=../resources/utils_validation.js
      7 
      8 'use strict';
      9 
     10 const batchSize = 3, inputSize = 4, hiddenSize = 5;
     11 
     12 // Dimensions required of required inputs.
     13 const kValidInputShape = [batchSize, inputSize];
     14 const kValidWeightShape = [3 * hiddenSize, inputSize];
     15 const kValidRecurrentWeightShape = [3 * hiddenSize, hiddenSize];
     16 const kValidHiddenStateShape = [batchSize, hiddenSize];
     17 // Dimensions required of optional inputs.
     18 const kValidBiasShape = [3 * hiddenSize];
     19 const kValidRecurrentBiasShape = [3 * hiddenSize];
     20 // Dimensions required of required output.
     21 const kValidOutputShape = [batchSize, hiddenSize];
     22 
     23 // Example descriptors which are valid according to the above dimensions.
     24 const kExampleInputDescriptor = {
     25  dataType: 'float32',
     26  shape: kValidInputShape
     27 };
     28 const kExampleWeightDescriptor = {
     29  dataType: 'float32',
     30  shape: kValidWeightShape
     31 };
     32 const kExampleRecurrentWeightDescriptor = {
     33  dataType: 'float32',
     34  shape: kValidRecurrentWeightShape
     35 };
     36 const kExampleHiddenStateDescriptor = {
     37  dataType: 'float32',
     38  shape: kValidHiddenStateShape
     39 };
     40 const kExampleBiasDescriptor = {
     41  dataType: 'float32',
     42  shape: kValidBiasShape
     43 };
     44 const kExampleRecurrentBiasDescriptor = {
     45  dataType: 'float32',
     46  shape: kValidRecurrentBiasShape
     47 };
     48 const kExampleOutputDescriptor = {
     49  dataType: 'float32',
     50  shape: kValidOutputShape
     51 };
     52 
     53 const tests = [
     54  {
     55    name: '[gruCell] Test with default options',
     56    input: kExampleInputDescriptor,
     57    weight: kExampleWeightDescriptor,
     58    recurrentWeight: kExampleRecurrentWeightDescriptor,
     59    hiddenState: kExampleHiddenStateDescriptor,
     60    hiddenSize: hiddenSize,
     61    output: kExampleOutputDescriptor
     62  },
     63  {
     64    name: '[gruCell] Test with given options',
     65    input: kExampleInputDescriptor,
     66    weight: kExampleWeightDescriptor,
     67    recurrentWeight: kExampleRecurrentWeightDescriptor,
     68    hiddenState: kExampleHiddenStateDescriptor,
     69    hiddenSize: hiddenSize,
     70    options: {
     71      bias: kExampleBiasDescriptor,
     72      recurrentBias: kExampleRecurrentBiasDescriptor,
     73      restAfter: true,
     74      layout: 'rzn',
     75      activations: ['sigmoid', 'relu']
     76    },
     77    output: kExampleOutputDescriptor
     78  },
     79  {
     80    name: '[gruCell] Throw if hiddenSize equals to zero',
     81    input: kExampleInputDescriptor,
     82    weight: kExampleWeightDescriptor,
     83    recurrentWeight: kExampleRecurrentWeightDescriptor,
     84    hiddenState: kExampleHiddenStateDescriptor,
     85    hiddenSize: 0
     86  },
     87  {
     88    name: '[gruCell] Throw if hiddenSize is too large',
     89    input: kExampleInputDescriptor,
     90    weight: kExampleWeightDescriptor,
     91    recurrentWeight: kExampleRecurrentWeightDescriptor,
     92    hiddenState: kExampleHiddenStateDescriptor,
     93    hiddenSize: 4294967295,
     94  },
     95  {
     96    name:
     97        '[gruCell] Throw if the data type of the inputs is not one of the floating point types',
     98    input: {dataType: 'uint32', shape: kValidInputShape},
     99    weight: {dataType: 'uint32', shape: kValidWeightShape},
    100    recurrentWeight: {dataType: 'uint32', shape: kValidRecurrentWeightShape},
    101    hiddenState: {dataType: 'uint32', shape: kValidHiddenStateShape},
    102    hiddenSize: hiddenSize
    103  },
    104  {
    105    name: '[gruCell] Throw if the rank of input is not 2',
    106    input: {dataType: 'float32', shape: [batchSize]},
    107    weight: kExampleWeightDescriptor,
    108    recurrentWeight: kExampleRecurrentWeightDescriptor,
    109    hiddenState: kExampleHiddenStateDescriptor,
    110    hiddenSize: hiddenSize
    111  },
    112  {
    113    name: '[gruCell] Throw if the input.shape[1] is incorrect',
    114    input: {dataType: 'float32', shape: [inputSize, inputSize]},
    115    weight: kExampleWeightDescriptor,
    116    recurrentWeight: kExampleRecurrentWeightDescriptor,
    117    hiddenState: kExampleHiddenStateDescriptor,
    118    hiddenSize: hiddenSize
    119  },
    120  {
    121    name:
    122        '[gruCell] Throw if data type of weight is not one of the floating point types',
    123    input: kExampleInputDescriptor,
    124    weight: {dataType: 'int8', shape: [3 * hiddenSize, inputSize]},
    125    recurrentWeight: kExampleRecurrentWeightDescriptor,
    126    hiddenState: kExampleHiddenStateDescriptor,
    127    hiddenSize: hiddenSize
    128  },
    129  {
    130    name: '[gruCell] Throw if rank of weight is not 2',
    131    input: kExampleInputDescriptor,
    132    weight: {dataType: 'float32', shape: [3 * hiddenSize]},
    133    recurrentWeight: kExampleRecurrentWeightDescriptor,
    134    hiddenState: kExampleHiddenStateDescriptor,
    135    hiddenSize: hiddenSize
    136  },
    137  {
    138    name: '[gruCell] Throw if weight.shape[0] is not 3 * hiddenSize',
    139    input: kExampleInputDescriptor,
    140    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    141    recurrentWeight: kExampleRecurrentWeightDescriptor,
    142    hiddenState: kExampleHiddenStateDescriptor,
    143    hiddenSize: hiddenSize
    144  },
    145  {
    146    name:
    147        '[gruCell] Throw if data type of recurrentWeight is not one of the floating point types',
    148    input: kExampleInputDescriptor,
    149    weight: kExampleWeightDescriptor,
    150    recurrentWeight: {dataType: 'int32', shape: [3 * hiddenSize, hiddenSize]},
    151    hiddenState: kExampleHiddenStateDescriptor,
    152    hiddenSize: hiddenSize
    153  },
    154  {
    155    name: '[gruCell] Throw if the rank of recurrentWeight is not 2',
    156    input: kExampleInputDescriptor,
    157    weight: kExampleWeightDescriptor,
    158    recurrentWeight: {dataType: 'float32', shape: [3 * hiddenSize]},
    159    hiddenState: kExampleHiddenStateDescriptor,
    160    hiddenSize: hiddenSize
    161  },
    162  {
    163    name: '[gruCell] Throw if the recurrentWeight.shape is invalid',
    164    input: kExampleInputDescriptor,
    165    weight: kExampleWeightDescriptor,
    166    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    167    hiddenState: kExampleHiddenStateDescriptor,
    168    hiddenSize: hiddenSize
    169  },
    170  {
    171    name:
    172        '[gruCell] Throw if data type of hiddenState is not one of the floating point types',
    173    input: kExampleInputDescriptor,
    174    weight: kExampleWeightDescriptor,
    175    recurrentWeight: kExampleRecurrentWeightDescriptor,
    176    hiddenState: {dataType: 'uint32', shape: [batchSize, hiddenSize]},
    177    hiddenSize: hiddenSize
    178  },
    179  {
    180    name: '[gruCell] Throw if the rank of hiddenState is not 2',
    181    input: kExampleInputDescriptor,
    182    weight: kExampleWeightDescriptor,
    183    recurrentWeight: kExampleRecurrentWeightDescriptor,
    184    hiddenState: {dataType: 'float32', shape: [hiddenSize]},
    185    hiddenSize: hiddenSize
    186  },
    187  {
    188    name: '[gruCell] Throw if the hiddenState.shape is invalid',
    189    input: kExampleInputDescriptor,
    190    weight: kExampleWeightDescriptor,
    191    recurrentWeight: kExampleRecurrentWeightDescriptor,
    192    hiddenState: {dataType: 'float32', shape: [batchSize, 3 * hiddenSize]},
    193    hiddenSize: hiddenSize
    194  },
    195  {
    196    name: '[gruCell] Throw if the size of options.activations is not 2',
    197    input: kExampleInputDescriptor,
    198    weight: kExampleWeightDescriptor,
    199    recurrentWeight: kExampleRecurrentWeightDescriptor,
    200    hiddenState: kExampleHiddenStateDescriptor,
    201    hiddenSize: hiddenSize,
    202    options: {activations: ['sigmoid', 'tanh', 'relu']}
    203  },
    204  {
    205    name:
    206        '[gruCell] Throw if data type of options.bias is not one of the floating point types',
    207    input: kExampleInputDescriptor,
    208    weight: kExampleWeightDescriptor,
    209    recurrentWeight: kExampleRecurrentWeightDescriptor,
    210    hiddenState: kExampleHiddenStateDescriptor,
    211    hiddenSize: hiddenSize,
    212    options: {bias: {dataType: 'uint8', shape: [3 * hiddenSize]}}
    213  },
    214  {
    215    name: '[gruCell] Throw if the rank of options.bias is not 1',
    216    input: kExampleInputDescriptor,
    217    weight: kExampleWeightDescriptor,
    218    recurrentWeight: kExampleRecurrentWeightDescriptor,
    219    hiddenState: kExampleHiddenStateDescriptor,
    220    hiddenSize: hiddenSize,
    221    options: {bias: {dataType: 'float32', shape: [batchSize, 3 * hiddenSize]}}
    222  },
    223  {
    224    name: '[gruCell] Throw if options.bias.shape[0] is not 3 * hiddenSize',
    225    input: kExampleInputDescriptor,
    226    weight: kExampleWeightDescriptor,
    227    recurrentWeight: kExampleRecurrentWeightDescriptor,
    228    hiddenState: kExampleHiddenStateDescriptor,
    229    hiddenSize: hiddenSize,
    230    options: {bias: {dataType: 'float32', shape: [2 * hiddenSize]}}
    231  },
    232  {
    233    name:
    234        '[gruCell] Throw if data type of options.recurrentBias is not one of the floating point types',
    235    input: kExampleInputDescriptor,
    236    weight: kExampleWeightDescriptor,
    237    recurrentWeight: kExampleRecurrentWeightDescriptor,
    238    hiddenState: kExampleHiddenStateDescriptor,
    239    hiddenSize: hiddenSize,
    240    options: {recurrentBias: {dataType: 'int8', shape: [3 * hiddenSize]}}
    241  },
    242  {
    243    name: '[gruCell] Throw if the rank of options.recurrentBias is not 1',
    244    input: kExampleInputDescriptor,
    245    weight: kExampleWeightDescriptor,
    246    recurrentWeight: kExampleRecurrentWeightDescriptor,
    247    hiddenState: kExampleHiddenStateDescriptor,
    248    hiddenSize: hiddenSize,
    249    options: {
    250      recurrentBias: {dataType: 'float32', shape: [batchSize, 3 * hiddenSize]}
    251    }
    252  },
    253  {
    254    name:
    255        '[gruCell] Throw if options.recurrentBias.shape[0] is not 3 * hiddenSize',
    256    input: kExampleInputDescriptor,
    257    weight: kExampleWeightDescriptor,
    258    recurrentWeight: kExampleRecurrentWeightDescriptor,
    259    hiddenState: kExampleHiddenStateDescriptor,
    260    hiddenSize: hiddenSize,
    261    options: {recurrentBias: {dataType: 'float16', shape: [4 * hiddenSize]}}
    262  }
    263 ];
    264 
    265 tests.forEach(
    266    test =>
    267        promise_test(async t => {
    268          const builder = new MLGraphBuilder(context);
    269          const input = builder.input('input', test.input);
    270          const weight = builder.input('weight', test.weight);
    271          const recurrentWeight =
    272              builder.input('recurrentWeight', test.recurrentWeight);
    273          const hiddenState = builder.input('hiddenState', test.hiddenState);
    274 
    275          const options = {};
    276          if (test.options) {
    277            if (test.options.bias) {
    278              options.bias = builder.input('bias', test.options.bias);
    279            }
    280            if (test.options.recurrentBias) {
    281              options.recurrentBias =
    282                  builder.input('recurrentBias', test.options.recurrentBias);
    283            }
    284            if (test.options.resetAfter) {
    285              options.resetAfter = test.options.resetAfter;
    286            }
    287            if (test.options.layout) {
    288              options.layout = test.options.layout;
    289            }
    290            if (test.options.activations) {
    291              options.activations = test.options.activations;
    292            }
    293          }
    294 
    295          if (test.output &&
    296              context.opSupportLimits().gruCell.input.dataTypes.includes(
    297                  test.input.dataType)) {
    298            const output = builder.gruCell(
    299                input, weight, recurrentWeight, hiddenState, test.hiddenSize,
    300                options);
    301            assert_equals(output.dataType, test.output.dataType);
    302            assert_array_equals(output.shape, test.output.shape);
    303          } else {
    304            const label = 'gru_cell_xxx';
    305            options.label = label;
    306            const regrexp = new RegExp('\\[' + label + '\\]');
    307            assert_throws_with_label(
    308                () => builder.gruCell(
    309                    input, weight, recurrentWeight, hiddenState,
    310                    test.hiddenSize, options),
    311                regrexp);
    312          }
    313        }, test.name));
    314 
    315 multi_builder_test(async (t, builder, otherBuilder) => {
    316  const inputFromOtherBuilder =
    317      otherBuilder.input('input', kExampleInputDescriptor);
    318 
    319  const weight = builder.input('weight', kExampleWeightDescriptor);
    320  const recurrentWeight =
    321      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    322  const hiddenState =
    323      builder.input('hiddenState', kExampleHiddenStateDescriptor);
    324  assert_throws_js(
    325      TypeError,
    326      () => builder.gruCell(
    327          inputFromOtherBuilder, weight, recurrentWeight, hiddenState,
    328          hiddenSize));
    329 }, '[gruCell] throw if input is from another builder');
    330 
    331 multi_builder_test(async (t, builder, otherBuilder) => {
    332  const weightFromOtherBuilder =
    333      otherBuilder.input('weight', kExampleWeightDescriptor);
    334 
    335  const input = builder.input('input', kExampleInputDescriptor);
    336  const recurrentWeight =
    337      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    338  const hiddenState =
    339      builder.input('hiddenState', kExampleHiddenStateDescriptor);
    340  assert_throws_js(
    341      TypeError,
    342      () => builder.gruCell(
    343          input, weightFromOtherBuilder, recurrentWeight, hiddenState,
    344          hiddenSize));
    345 }, '[gruCell] throw if weight is from another builder');
    346 
    347 multi_builder_test(async (t, builder, otherBuilder) => {
    348  const recurrentWeightFromOtherBuilder =
    349      otherBuilder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    350 
    351  const input = builder.input('input', kExampleInputDescriptor);
    352  const weight = builder.input('weight', kExampleWeightDescriptor);
    353  const hiddenState =
    354      builder.input('hiddenState', kExampleHiddenStateDescriptor);
    355  assert_throws_js(
    356      TypeError,
    357      () => builder.gruCell(
    358          input, weight, recurrentWeightFromOtherBuilder, hiddenState,
    359          hiddenSize));
    360 }, '[gruCell] throw if recurrentWeight is from another builder');
    361 
    362 multi_builder_test(async (t, builder, otherBuilder) => {
    363  const hiddenStateFromOtherBuilder =
    364      otherBuilder.input('hiddenState', kExampleHiddenStateDescriptor);
    365 
    366  const input = builder.input('input', kExampleInputDescriptor);
    367  const weight = builder.input('weight', kExampleWeightDescriptor);
    368  const recurrentWeight =
    369      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    370  assert_throws_js(
    371      TypeError,
    372      () => builder.gruCell(
    373          input, weight, recurrentWeight, hiddenStateFromOtherBuilder,
    374          hiddenSize));
    375 }, '[gruCell] throw if hiddenState is from another builder');
    376 
    377 multi_builder_test(async (t, builder, otherBuilder) => {
    378  const biasFromOtherBuilder =
    379      otherBuilder.input('bias', kExampleBiasDescriptor);
    380  const options = {bias: biasFromOtherBuilder};
    381 
    382  const input = builder.input('input', kExampleInputDescriptor);
    383  const weight = builder.input('weight', kExampleWeightDescriptor);
    384  const recurrentWeight =
    385      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    386  const hiddenState =
    387      builder.input('hiddenState', kExampleHiddenStateDescriptor);
    388  assert_throws_js(
    389      TypeError,
    390      () => builder.gruCell(
    391          input, weight, recurrentWeight, hiddenState, hiddenSize, options));
    392 }, '[gruCell] throw if bias option is from another builder');
    393 
    394 multi_builder_test(async (t, builder, otherBuilder) => {
    395  const recurrentBiasFromOtherBuilder =
    396      otherBuilder.input('recurrentBias', kExampleRecurrentBiasDescriptor);
    397  const options = {recurrentBias: recurrentBiasFromOtherBuilder};
    398 
    399  const input = builder.input('input', kExampleInputDescriptor);
    400  const weight = builder.input('weight', kExampleWeightDescriptor);
    401  const recurrentWeight =
    402      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    403  const hiddenState =
    404      builder.input('hiddenState', kExampleHiddenStateDescriptor);
    405  assert_throws_js(
    406      TypeError,
    407      () => builder.gruCell(
    408          input, weight, recurrentWeight, hiddenState, hiddenSize, options));
    409 }, '[gruCell] throw if recurrentBias option is from another builder');