tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

lstm.https.any.js (15605B)


      1 // META: title=validation tests for WebNN API lstm operation
      2 // META: global=window
      3 // META: variant=?cpu
      4 // META: variant=?gpu
      5 // META: variant=?npu
      6 // META: script=../resources/utils_validation.js
      7 
      8 'use strict';
      9 
     10 const steps = 10, batchSize = 5, inputSize = 3, hiddenSize = 8,
     11      numDirections = 1;
     12 
     13 // Dimensions required of required inputs.
     14 const kValidInputShape = [steps, batchSize, inputSize];
     15 const kValidWeightShape = [numDirections, 4 * hiddenSize, inputSize];
     16 const kValidRecurrentWeightShape = [numDirections, 4 * hiddenSize, hiddenSize];
     17 // Dimensions required of optional inputs.
     18 const kValidBiasShape = [numDirections, 4 * hiddenSize];
     19 const kValidPeepholeWeightShape = [numDirections, 3 * hiddenSize];
     20 const kValidInitialHiddenStateShape = [numDirections, batchSize, hiddenSize];
     21 
     22 // Example descriptors which are valid according to the above dimensions.
     23 const kExampleInputDescriptor = {
     24  dataType: 'float32',
     25  shape: kValidInputShape
     26 };
     27 const kExampleWeightDescriptor = {
     28  dataType: 'float32',
     29  shape: kValidWeightShape
     30 };
     31 const kExampleRecurrentWeightDescriptor = {
     32  dataType: 'float32',
     33  shape: kValidRecurrentWeightShape
     34 };
     35 const kExampleBiasDescriptor = {
     36  dataType: 'float32',
     37  shape: kValidBiasShape
     38 };
     39 const kExamplePeepholeWeightDescriptor = {
     40  dataType: 'float32',
     41  shape: kValidPeepholeWeightShape
     42 };
     43 const kExampleInitialHiddenStateDescriptor = {
     44  dataType: 'float32',
     45  shape: kValidInitialHiddenStateShape
     46 };
     47 
     48 const tests = [
     49  {
     50    name: '[lstm] Test with default options',
     51    input: {dataType: 'float16', shape: kValidInputShape},
     52    weight: {dataType: 'float16', shape: kValidWeightShape},
     53    recurrentWeight: {dataType: 'float16', shape: kValidRecurrentWeightShape},
     54    steps: steps,
     55    hiddenSize: hiddenSize,
     56    outputs: [
     57      {dataType: 'float16', shape: [numDirections, batchSize, hiddenSize]},
     58      {dataType: 'float16', shape: [numDirections, batchSize, hiddenSize]}
     59    ]
     60  },
     61  {
     62    name: '[lstm] Test with given options',
     63    input: kExampleInputDescriptor,
     64    weight: {
     65      dataType: 'float32',
     66      shape: [/*numDirections=*/ 2, 4 * hiddenSize, inputSize]
     67    },
     68    recurrentWeight: {
     69      dataType: 'float32',
     70      shape: [/*numDirections=*/ 2, 4 * hiddenSize, hiddenSize]
     71    },
     72    steps: steps,
     73    hiddenSize: hiddenSize,
     74    options: {
     75      bias:
     76          {dataType: 'float32', shape: [/*numDirections=*/ 2, 4 * hiddenSize]},
     77      recurrentBias:
     78          {dataType: 'float32', shape: [/*numDirections=*/ 2, 4 * hiddenSize]},
     79      peepholeWeight:
     80          {dataType: 'float32', shape: [/*numDirections=*/ 2, 3 * hiddenSize]},
     81      initialHiddenState: {
     82        dataType: 'float32',
     83        shape: [/*numDirections=*/ 2, batchSize, hiddenSize]
     84      },
     85      initialCellState: {
     86        dataType: 'float32',
     87        shape: [/*numDirections=*/ 2, batchSize, hiddenSize]
     88      },
     89      returnSequence: true,
     90      direction: 'both',
     91      layout: 'ifgo',
     92      activations: ['sigmoid', 'relu', 'tanh']
     93    },
     94    outputs: [
     95      {
     96        dataType: 'float32',
     97        shape: [/*numDirections=*/ 2, batchSize, hiddenSize]
     98      },
     99      {
    100        dataType: 'float32',
    101        shape: [/*numDirections=*/ 2, batchSize, hiddenSize]
    102      },
    103      {
    104        dataType: 'float32',
    105        shape: [steps, /*numDirections=*/ 2, batchSize, hiddenSize]
    106      }
    107    ]
    108  },
    109  {
    110    name: '[lstm] TypeError is expected if hiddenSize equals to zero',
    111    input: kExampleInputDescriptor,
    112    weight: kExampleWeightDescriptor,
    113    recurrentWeight: kExampleRecurrentWeightDescriptor,
    114    steps: steps,
    115    hiddenSize: 0
    116  },
    117  {
    118    name: '[lstm] TypeError is expected if hiddenSize is too large',
    119    input: kExampleInputDescriptor,
    120    weight: kExampleWeightDescriptor,
    121    recurrentWeight: kExampleRecurrentWeightDescriptor,
    122    steps: steps,
    123    hiddenSize: 4294967295,
    124  },
    125  {
    126    name: '[lstm] TypeError is expected if steps equals to zero',
    127    input: kExampleInputDescriptor,
    128    weight: kExampleWeightDescriptor,
    129    recurrentWeight: kExampleRecurrentWeightDescriptor,
    130    steps: 0,
    131    hiddenSize: hiddenSize,
    132  },
    133  {
    134    name:
    135        '[lstm] TypeError is expected if the data type is not one of the floating point types',
    136    input: {dataType: 'uint32', shape: kValidInputShape},
    137    weight: {dataType: 'uint32', shape: kValidWeightShape},
    138    recurrentWeight: {dataType: 'uint32', shape: kValidRecurrentWeightShape},
    139    steps: steps,
    140    hiddenSize: hiddenSize
    141  },
    142  {
    143    name: '[lstm] TypeError is expected if the rank of input is not 3',
    144    input: {dataType: 'float32', shape: [steps, batchSize]},
    145    weight: kExampleWeightDescriptor,
    146    recurrentWeight: kExampleRecurrentWeightDescriptor,
    147    steps: steps,
    148    hiddenSize: hiddenSize
    149  },
    150  {
    151    name:
    152        '[lstm] TypeError is expected if input.shape[0] is not equal to steps',
    153    input: {dataType: 'float32', shape: [1000, batchSize, inputSize]},
    154    weight: kExampleWeightDescriptor,
    155    recurrentWeight: kExampleRecurrentWeightDescriptor,
    156    steps: steps,
    157    hiddenSize: hiddenSize
    158  },
    159  {
    160    name: '[lstm] TypeError is expected if the shape of weight is incorrect',
    161    input: kExampleInputDescriptor,
    162    weight: {dataType: 'float32', shape: [numDirections, 4 * hiddenSize, 1000]},
    163    recurrentWeight: kExampleRecurrentWeightDescriptor,
    164    steps: steps,
    165    hiddenSize: hiddenSize
    166  },
    167  {
    168    name:
    169        '[lstm] TypeError is expected if the rank of recurrentWeight is not 3',
    170    input: kExampleInputDescriptor,
    171    weight: kExampleWeightDescriptor,
    172    recurrentWeight:
    173        {dataType: 'float32', shape: [numDirections, 4 * hiddenSize]},
    174    steps: steps,
    175    hiddenSize: hiddenSize
    176  },
    177  {
    178    name:
    179        '[lstm] TypeError is expected if the size of options.activations is not 3',
    180    input: kExampleInputDescriptor,
    181    weight: kExampleWeightDescriptor,
    182    recurrentWeight: kExampleRecurrentWeightDescriptor,
    183    steps: steps,
    184    hiddenSize: hiddenSize,
    185    options: {activations: ['sigmoid', 'tanh']}
    186  },
    187  {
    188    name: '[lstm] TypeError is expected if the rank of options.bias is not 2',
    189    input: {dataType: 'float16', shape: kValidInputShape},
    190    weight: {dataType: 'float16', shape: kValidWeightShape},
    191    recurrentWeight: {dataType: 'float16', shape: kValidRecurrentWeightShape},
    192    steps: steps,
    193    hiddenSize: hiddenSize,
    194    options: {bias: {dataType: 'float16', shape: [numDirections]}}
    195  },
    196  {
    197    name:
    198        '[lstm] TypeError is expected if the shape of options.recurrentBias.shape is incorrect',
    199    input: {dataType: 'float16', shape: kValidInputShape},
    200    weight: {dataType: 'float16', shape: kValidWeightShape},
    201    recurrentWeight: {dataType: 'float16', shape: kValidRecurrentWeightShape},
    202    steps: steps,
    203    hiddenSize: hiddenSize,
    204    options:
    205        {recurrentBias: {dataType: 'float16', shape: [numDirections, 1000]}}
    206  },
    207  {
    208    name:
    209        '[lstm] TypeError is expected if the dataType of options.peepholeWeight is incorrect',
    210    input: {dataType: 'float16', shape: kValidInputShape},
    211    weight: {dataType: 'float16', shape: kValidWeightShape},
    212    recurrentWeight: {dataType: 'float16', shape: kValidRecurrentWeightShape},
    213    steps: steps,
    214    hiddenSize: hiddenSize,
    215    options: {
    216      peepholeWeight:
    217          {dataType: 'float32', shape: [numDirections, 3 * hiddenSize]}
    218    }
    219  },
    220  {
    221    name:
    222        '[lstm] TypeError is expected if the dataType of options.initialHiddenState is incorrect',
    223    input: {dataType: 'float16', shape: kValidInputShape},
    224    weight: {dataType: 'float16', shape: kValidWeightShape},
    225    recurrentWeight: {dataType: 'float16', shape: kValidRecurrentWeightShape},
    226    steps: steps,
    227    hiddenSize: hiddenSize,
    228    options: {
    229      initialHiddenState:
    230          {dataType: 'uint64', shape: [numDirections, batchSize, hiddenSize]}
    231    }
    232  },
    233  {
    234    name:
    235        '[lstm] TypeError is expected if the shape of options.initialCellState is incorrect',
    236    input: kExampleInputDescriptor,
    237    weight: kExampleWeightDescriptor,
    238    recurrentWeight: kExampleRecurrentWeightDescriptor,
    239    steps: steps,
    240    hiddenSize: hiddenSize,
    241    options: {
    242      initialCellState:
    243          {dataType: 'float32', shape: [numDirections, batchSize, 1000]}
    244    }
    245  }
    246 ];
    247 
    248 tests.forEach(
    249    test => promise_test(async t => {
    250      const builder = new MLGraphBuilder(context);
    251      const input = builder.input('input', test.input);
    252      const weight = builder.input('weight', test.weight);
    253      const recurrentWeight =
    254          builder.input('recurrentWeight', test.recurrentWeight);
    255 
    256      const options = {};
    257      if (test.options) {
    258        if (test.options.bias) {
    259          options.bias = builder.input('bias', test.options.bias);
    260        }
    261        if (test.options.recurrentBias) {
    262          options.recurrentBias =
    263              builder.input('recurrentBias', test.options.recurrentBias);
    264        }
    265        if (test.options.peepholeWeight) {
    266          options.peepholeWeight =
    267              builder.input('peepholeWeight', test.options.peepholeWeight);
    268        }
    269        if (test.options.initialHiddenState) {
    270          options.initialHiddenState = builder.input(
    271              'initialHiddenState', test.options.initialHiddenState);
    272        }
    273        if (test.options.initialCellState) {
    274          options.initialCellState =
    275              builder.input('initialCellState', test.options.initialCellState);
    276        }
    277        if (test.options.returnSequence) {
    278          options.returnSequence = test.options.returnSequence;
    279        }
    280        if (test.options.direction) {
    281          options.direction = test.options.direction;
    282        }
    283        if (test.options.layout) {
    284          options.layout = test.options.layout;
    285        }
    286        if (test.options.activations) {
    287          options.activations = test.options.activations;
    288        }
    289      }
    290 
    291      if (test.outputs &&
    292          context.opSupportLimits().lstm.input.dataTypes.includes(
    293              test.input.dataType)) {
    294        const outputs = builder.lstm(
    295            input, weight, recurrentWeight, test.steps, test.hiddenSize,
    296            options);
    297        assert_equals(outputs.length, test.outputs.length);
    298        for (let i = 0; i < outputs.length; ++i) {
    299          assert_equals(outputs[i].dataType, test.outputs[i].dataType);
    300          assert_array_equals(outputs[i].shape, test.outputs[i].shape);
    301        }
    302      } else {
    303        const label = 'lstm_xxx';
    304        options.label = label;
    305        const regrexp = new RegExp('\\[' + label + '\\]');
    306        assert_throws_with_label(
    307            () => builder.lstm(
    308                input, weight, recurrentWeight, test.steps, test.hiddenSize,
    309                options),
    310            regrexp);
    311      }
    312    }, test.name));
    313 
    314 multi_builder_test(async (t, builder, otherBuilder) => {
    315  const inputFromOtherBuilder =
    316      otherBuilder.input('input', kExampleInputDescriptor);
    317  const weight = builder.input('weight', kExampleWeightDescriptor);
    318  const recurrentWeight =
    319      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    320 
    321  assert_throws_js(
    322      TypeError,
    323      () => builder.lstm(
    324          inputFromOtherBuilder, weight, recurrentWeight, steps, hiddenSize));
    325 }, '[lstm] throw if input is from another builder');
    326 
    327 multi_builder_test(async (t, builder, otherBuilder) => {
    328  const input = builder.input('input', kExampleInputDescriptor);
    329  const weightFromOtherBuilder =
    330      otherBuilder.input('weight', kExampleWeightDescriptor);
    331  const recurrentWeight =
    332      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    333 
    334  assert_throws_js(
    335      TypeError,
    336      () => builder.lstm(
    337          input, weightFromOtherBuilder, recurrentWeight, steps, hiddenSize));
    338 }, '[lstm] throw if weight is from another builder');
    339 
    340 
    341 multi_builder_test(async (t, builder, otherBuilder) => {
    342  const input = builder.input('input', kExampleInputDescriptor);
    343  const weight = builder.input('weight', kExampleWeightDescriptor);
    344  const recurrentWeightFromOtherBuilder =
    345      otherBuilder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    346 
    347  assert_throws_js(
    348      TypeError,
    349      () => builder.lstm(
    350          input, weight, recurrentWeightFromOtherBuilder, steps, hiddenSize));
    351 }, '[lstm] throw if recurrentWeight is from another builder');
    352 
    353 multi_builder_test(async (t, builder, otherBuilder) => {
    354  const biasFromOtherBuilder =
    355      otherBuilder.input('bias', kExampleBiasDescriptor);
    356  const options = {bias: biasFromOtherBuilder};
    357 
    358  const input = builder.input('input', kExampleInputDescriptor);
    359  const weight = builder.input('weight', kExampleWeightDescriptor);
    360  const recurrentWeight =
    361      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    362  assert_throws_js(
    363      TypeError,
    364      () => builder.lstm(
    365          input, weight, recurrentWeight, steps, hiddenSize, options));
    366 }, '[lstm] throw if bias option is from another builder');
    367 
    368 multi_builder_test(async (t, builder, otherBuilder) => {
    369  const recurrentBiasFromOtherBuilder =
    370      otherBuilder.input('bias', kExampleBiasDescriptor);
    371  const options = {recurrentBias: recurrentBiasFromOtherBuilder};
    372 
    373  const input = builder.input('input', kExampleInputDescriptor);
    374  const weight = builder.input('weight', kExampleWeightDescriptor);
    375  const recurrentWeight =
    376      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    377  assert_throws_js(
    378      TypeError,
    379      () => builder.lstm(
    380          input, weight, recurrentWeight, steps, hiddenSize, options));
    381 }, '[lstm] throw if recurrentBias option is from another builder');
    382 
    383 multi_builder_test(async (t, builder, otherBuilder) => {
    384  const peepholeWeightFromOtherBuilder =
    385      otherBuilder.input('peepholeWeight', kExamplePeepholeWeightDescriptor);
    386  const options = {peepholeWeight: peepholeWeightFromOtherBuilder};
    387 
    388  const input = builder.input('input', kExampleInputDescriptor);
    389  const weight = builder.input('weight', kExampleWeightDescriptor);
    390  const recurrentWeight =
    391      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    392  assert_throws_js(
    393      TypeError,
    394      () => builder.lstm(
    395          input, weight, recurrentWeight, steps, hiddenSize, options));
    396 }, '[lstm] throw if peepholeWeight option is from another builder');
    397 
    398 multi_builder_test(async (t, builder, otherBuilder) => {
    399  const initialHiddenStateFromOtherBuilder = otherBuilder.input(
    400      'initialHiddenState', kExampleInitialHiddenStateDescriptor);
    401  const options = {initialHiddenState: initialHiddenStateFromOtherBuilder};
    402 
    403  const input = builder.input('input', kExampleInputDescriptor);
    404  const weight = builder.input('weight', kExampleWeightDescriptor);
    405  const recurrentWeight =
    406      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    407  assert_throws_js(
    408      TypeError,
    409      () => builder.lstm(
    410          input, weight, recurrentWeight, steps, hiddenSize, options));
    411 }, '[lstm] throw if initialHiddenState option is from another builder');
    412 
    413 multi_builder_test(async (t, builder, otherBuilder) => {
    414  const initialCellStateFromOtherBuilder = otherBuilder.input(
    415      'initialCellState', kExampleInitialHiddenStateDescriptor);
    416  const options = {initialCellState: initialCellStateFromOtherBuilder};
    417 
    418  const input = builder.input('input', kExampleInputDescriptor);
    419  const weight = builder.input('weight', kExampleWeightDescriptor);
    420  const recurrentWeight =
    421      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    422  assert_throws_js(
    423      TypeError,
    424      () => builder.lstm(
    425          input, weight, recurrentWeight, steps, hiddenSize, options));
    426 }, '[lstm] throw if initialCellState option is from another builder');