tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

lstmCell.https.any.js (24492B)


      1 // META: title=validation tests for WebNN API lstmCell operation
      2 // META: global=window
      3 // META: variant=?cpu
      4 // META: variant=?gpu
      5 // META: variant=?npu
      6 // META: script=../resources/utils_validation.js
      7 
      8 'use strict';
      9 
     10 const batchSize = 3, inputSize = 4, hiddenSize = 5;
     11 
     12 // Dimensions required of required inputs.
     13 const kValidInputShape = [batchSize, inputSize];
     14 const kValidWeightShape = [4 * hiddenSize, inputSize];
     15 const kValidRecurrentWeightShape = [4 * hiddenSize, hiddenSize];
     16 const kValidHiddenStateShape = [batchSize, hiddenSize];
     17 const kValidCellStateShape = [batchSize, hiddenSize];
     18 // Dimensions required of optional inputs.
     19 const kValidBiasShape = [4 * hiddenSize];
     20 const kValidPeepholeWeightShape = [3 * hiddenSize];
     21 
     22 // Example descriptors which are valid according to the above dimensions.
     23 const kExampleInputDescriptor = {
     24  dataType: 'float32',
     25  shape: kValidInputShape
     26 };
     27 const kExampleWeightDescriptor = {
     28  dataType: 'float32',
     29  shape: kValidWeightShape
     30 };
     31 const kExampleRecurrentWeightDescriptor = {
     32  dataType: 'float32',
     33  shape: kValidRecurrentWeightShape
     34 };
     35 const kExampleHiddenStateDescriptor = {
     36  dataType: 'float32',
     37  shape: kValidHiddenStateShape
     38 };
     39 const kExampleCellStateDescriptor = {
     40  dataType: 'float32',
     41  shape: kValidCellStateShape
     42 };
     43 const kExampleBiasDescriptor = {
     44  dataType: 'float32',
     45  shape: kValidBiasShape
     46 };
     47 const kExamplePeepholeWeightDescriptor = {
     48  dataType: 'float32',
     49  shape: kValidPeepholeWeightShape
     50 };
     51 
     52 multi_builder_test(async (t, builder, otherBuilder) => {
     53  const inputFromOtherBuilder =
     54      otherBuilder.input('input', kExampleInputDescriptor);
     55 
     56  const weight = builder.input('weight', kExampleWeightDescriptor);
     57  const recurrentWeight =
     58      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
     59  const hiddenState =
     60      builder.input('hiddenState', kExampleHiddenStateDescriptor);
     61  const cellState = builder.input('cellState', kExampleCellStateDescriptor);
     62  assert_throws_js(
     63      TypeError,
     64      () => builder.lstmCell(
     65          inputFromOtherBuilder, weight, recurrentWeight, hiddenState,
     66          cellState, hiddenSize));
     67 }, '[lstmCell] throw if input is from another builder');
     68 
     69 multi_builder_test(async (t, builder, otherBuilder) => {
     70  const weightFromOtherBuilder =
     71      otherBuilder.input('weight', kExampleWeightDescriptor);
     72 
     73  const input = builder.input('input', kExampleInputDescriptor);
     74  const recurrentWeight =
     75      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
     76  const hiddenState =
     77      builder.input('hiddenState', kExampleHiddenStateDescriptor);
     78  const cellState = builder.input('cellState', kExampleCellStateDescriptor);
     79  assert_throws_js(
     80      TypeError,
     81      () => builder.lstmCell(
     82          input, weightFromOtherBuilder, recurrentWeight, hiddenState,
     83          cellState, hiddenSize));
     84 }, '[lstmCell] throw if weight is from another builder');
     85 
     86 multi_builder_test(async (t, builder, otherBuilder) => {
     87  const recurrentWeightFromOtherBuilder =
     88      otherBuilder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
     89 
     90  const input = builder.input('input', kExampleInputDescriptor);
     91  const weight = builder.input('weight', kExampleWeightDescriptor);
     92  const hiddenState =
     93      builder.input('hiddenState', kExampleHiddenStateDescriptor);
     94  const cellState = builder.input('cellState', kExampleCellStateDescriptor);
     95  assert_throws_js(
     96      TypeError,
     97      () => builder.lstmCell(
     98          input, weight, recurrentWeightFromOtherBuilder, hiddenState,
     99          cellState, hiddenSize));
    100 }, '[lstmCell] throw if recurrentWeight is from another builder');
    101 
    102 multi_builder_test(async (t, builder, otherBuilder) => {
    103  const hiddenStateFromOtherBuilder =
    104      otherBuilder.input('hiddenState', kExampleHiddenStateDescriptor);
    105 
    106  const input = builder.input('input', kExampleInputDescriptor);
    107  const weight = builder.input('weight', kExampleWeightDescriptor);
    108  const recurrentWeight =
    109      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    110  const cellState = builder.input('cellState', kExampleCellStateDescriptor);
    111  assert_throws_js(
    112      TypeError,
    113      () => builder.lstmCell(
    114          input, weight, recurrentWeight, hiddenStateFromOtherBuilder,
    115          cellState, hiddenSize));
    116 }, '[lstmCell] throw if hiddenState is from another builder');
    117 
    118 multi_builder_test(async (t, builder, otherBuilder) => {
    119  const cellStateFromOtherBuilder =
    120      otherBuilder.input('cellState', kExampleCellStateDescriptor);
    121 
    122  const input = builder.input('input', kExampleInputDescriptor);
    123  const weight = builder.input('weight', kExampleWeightDescriptor);
    124  const recurrentWeight =
    125      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    126  const hiddenState =
    127      builder.input('hiddenState', kExampleHiddenStateDescriptor);
    128  assert_throws_js(
    129      TypeError,
    130      () => builder.lstmCell(
    131          input, weight, recurrentWeight, hiddenState,
    132          cellStateFromOtherBuilder, hiddenSize));
    133 }, '[lstmCell] throw if cellState is from another builder');
    134 
    135 multi_builder_test(async (t, builder, otherBuilder) => {
    136  const biasFromOtherBuilder =
    137      otherBuilder.input('bias', kExampleBiasDescriptor);
    138  const options = {bias: biasFromOtherBuilder};
    139 
    140  const input = builder.input('input', kExampleInputDescriptor);
    141  const weight = builder.input('weight', kExampleWeightDescriptor);
    142  const recurrentWeight =
    143      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    144  const hiddenState =
    145      builder.input('hiddenState', kExampleHiddenStateDescriptor);
    146  const cellState = builder.input('cellState', kExampleCellStateDescriptor);
    147  assert_throws_js(
    148      TypeError,
    149      () => builder.lstmCell(
    150          input, weight, recurrentWeight, hiddenState, cellState, hiddenSize,
    151          options));
    152 }, '[lstmCell] throw if bias option is from another builder');
    153 
    154 multi_builder_test(async (t, builder, otherBuilder) => {
    155  const recurrentBiasFromOtherBuilder =
    156      otherBuilder.input('bias', kExampleBiasDescriptor);
    157  const options = {recurrentBias: recurrentBiasFromOtherBuilder};
    158 
    159  const input = builder.input('input', kExampleInputDescriptor);
    160  const weight = builder.input('weight', kExampleWeightDescriptor);
    161  const recurrentWeight =
    162      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    163  const hiddenState =
    164      builder.input('hiddenState', kExampleHiddenStateDescriptor);
    165  const cellState = builder.input('cellState', kExampleCellStateDescriptor);
    166  assert_throws_js(
    167      TypeError,
    168      () => builder.lstmCell(
    169          input, weight, recurrentWeight, hiddenState, cellState, hiddenSize,
    170          options));
    171 }, '[lstmCell] throw if recurrentBias option is from another builder');
    172 
    173 multi_builder_test(async (t, builder, otherBuilder) => {
    174  const peepholeWeightFromOtherBuilder =
    175      otherBuilder.input('peepholeWeight', kExamplePeepholeWeightDescriptor);
    176  const options = {peepholeWeight: peepholeWeightFromOtherBuilder};
    177 
    178  const input = builder.input('input', kExampleInputDescriptor);
    179  const weight = builder.input('weight', kExampleWeightDescriptor);
    180  const recurrentWeight =
    181      builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
    182  const hiddenState =
    183      builder.input('hiddenState', kExampleHiddenStateDescriptor);
    184  const cellState = builder.input('cellState', kExampleCellStateDescriptor);
    185  assert_throws_js(
    186      TypeError,
    187      () => builder.lstmCell(
    188          input, weight, recurrentWeight, hiddenState, cellState, hiddenSize,
    189          options));
    190 }, '[lstmCell] throw if peepholeWeight option is from another builder');
    191 
    192 const tests = [
    193  {
    194    name: '[lstmCell] Test with default options',
    195    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    196    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    197    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    198    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    199    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    200    hiddenSize: hiddenSize,
    201    outputs: [
    202      {dataType: 'float16', shape: [batchSize, hiddenSize]},
    203      {dataType: 'float16', shape: [batchSize, hiddenSize]}
    204    ]
    205  },
    206  {
    207    name: '[lstmCell] Test with given options',
    208    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    209    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    210    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    211    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    212    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    213    hiddenSize: hiddenSize,
    214    options: {
    215      bias: {dataType: 'float32', shape: [4 * hiddenSize]},
    216      recurrentBias: {dataType: 'float32', shape: [4 * hiddenSize]},
    217      peepholeWeight: {dataType: 'float32', shape: [3 * hiddenSize]},
    218      layout: 'ifgo',
    219      activations: ['sigmoid', 'relu', 'tanh']
    220    },
    221    outputs: [
    222      {dataType: 'float32', shape: [batchSize, hiddenSize]},
    223      {dataType: 'float32', shape: [batchSize, hiddenSize]}
    224    ]
    225  },
    226  {
    227    name: '[lstmCell] Throw if hiddenSize is equal to zero',
    228    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    229    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    230    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    231    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    232    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    233    hiddenSize: 0
    234  },
    235  {
    236    name: '[lstmCell] Throw if hiddenSize is too large',
    237    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    238    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    239    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    240    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    241    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    242    hiddenSize: 4294967295
    243  },
    244  {
    245    name:
    246        '[lstmCell] Throw if the input data type is not one of the floating point types',
    247    input: {dataType: 'uint32', shape: [batchSize, inputSize]},
    248    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    249    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    250    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    251    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    252    hiddenSize: hiddenSize
    253  },
    254  {
    255    name: '[lstmCell] Throw if the rank of input is not 2',
    256    input: {dataType: 'float32', shape: [batchSize]},
    257    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    258    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    259    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    260    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    261    hiddenSize: hiddenSize
    262  },
    263  {
    264    name: '[lstmCell] Throw if the shape of input is incorrect',
    265    input: {dataType: 'float32', shape: [batchSize, 1000]},
    266    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    267    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    268    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    269    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    270    hiddenSize: hiddenSize
    271  },
    272  {
    273    name: '[lstmCell] Throw if the data type of weight is incorrect',
    274    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    275    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    276    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    277    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    278    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    279    hiddenSize: hiddenSize
    280  },
    281  {
    282    name: '[lstmCell] Throw if the rank of weight is not 2',
    283    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    284    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize, 1000]},
    285    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    286    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    287    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    288    hiddenSize: hiddenSize
    289  },
    290  {
    291    name: '[lstmCell] Throw if the shape of weight is incorrect',
    292    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    293    weight: {dataType: 'float32', shape: [1000, inputSize]},
    294    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    295    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    296    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    297    hiddenSize: hiddenSize
    298  },
    299  {
    300    name: '[lstmCell] Throw if the data type of recurrentWeight is incorrect',
    301    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    302    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    303    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    304    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    305    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    306    hiddenSize: hiddenSize
    307  },
    308  {
    309    name: '[lstmCell] Throw if the rank of recurrentWeight is not 2',
    310    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    311    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    312    recurrentWeight:
    313        {dataType: 'float32', shape: [1000, 4 * hiddenSize, hiddenSize]},
    314    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    315    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    316    hiddenSize: hiddenSize
    317  },
    318  {
    319    name: '[lstmCell] Throw if the shape of recurrentWeight is incorrect',
    320    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    321    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    322    recurrentWeight: {dataType: 'float32', shape: [1000, hiddenSize]},
    323    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    324    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    325    hiddenSize: hiddenSize
    326  },
    327  {
    328    name: '[lstmCell] Throw if the data type of hiddenState is incorrect',
    329    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    330    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    331    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    332    hiddenState: {dataType: 'int64', shape: [batchSize, hiddenSize]},
    333    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    334    hiddenSize: hiddenSize
    335  },
    336  {
    337    name: '[lstmCell] Throw if the rank of hiddenState is not 2',
    338    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    339    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    340    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    341    hiddenState: {dataType: 'float32', shape: [batchSize]},
    342    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    343    hiddenSize: hiddenSize
    344  },
    345  {
    346    name: '[lstmCell] Throw if the shape of hiddenState is incorrect',
    347    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    348    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    349    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    350    hiddenState: {dataType: 'float32', shape: [batchSize, 1000]},
    351    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    352    hiddenSize: hiddenSize
    353  },
    354  {
    355    name: '[lstmCell] Throw if the data type of cellState is incorrect',
    356    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    357    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    358    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    359    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    360    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    361    hiddenSize: hiddenSize
    362  },
    363  {
    364    name: '[lstmCell] Throw if the rank of cellState is not 2',
    365    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    366    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    367    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    368    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    369    cellState: {dataType: 'float32', shape: [batchSize]},
    370    hiddenSize: hiddenSize
    371  },
    372  {
    373    name: '[lstmCell] Throw if the shape of cellState is incorrect',
    374    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    375    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    376    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    377    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    378    cellState: {dataType: 'float16', shape: [batchSize, 1000]},
    379    hiddenSize: hiddenSize
    380  },
    381  {
    382    name: '[lstmCell] Throw if the data type of options.bias is incorrect',
    383    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    384    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    385    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    386    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    387    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    388    hiddenSize: hiddenSize,
    389    options: {bias: {dataType: 'int8', shape: [4 * hiddenSize]}}
    390  },
    391  {
    392    name: '[lstmCell] Throw if the rank of options.bias is not 1',
    393    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    394    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    395    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    396    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    397    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    398    hiddenSize: hiddenSize,
    399    options: {bias: {dataType: 'float16', shape: [4 * hiddenSize, 1000]}}
    400  },
    401  {
    402    name: '[lstmCell] Throw if the shape of options.bias is incorrect',
    403    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    404    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    405    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    406    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    407    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    408    hiddenSize: hiddenSize,
    409    options: {bias: {dataType: 'float16', shape: [1000]}}
    410  },
    411  {
    412    name:
    413        '[lstmCell] Throw if the data type of options.recurrentBias is incorrect',
    414    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    415    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    416    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    417    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    418    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    419    hiddenSize: hiddenSize,
    420    options: {recurrentBias: {dataType: 'uint8', shape: [4 * hiddenSize]}}
    421  },
    422  {
    423    name: '[lstmCell] Throw if the rank of options.recurrentBias is not 1',
    424    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    425    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    426    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    427    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    428    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    429    hiddenSize: hiddenSize,
    430    options:
    431        {recurrentBias: {dataType: 'float16', shape: [4 * hiddenSize, 1000]}}
    432  },
    433  {
    434    name: '[lstmCell] Throw if the shape of options.recurrentBias is incorrect',
    435    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    436    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    437    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    438    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    439    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    440    hiddenSize: hiddenSize,
    441    options: {recurrentBias: {dataType: 'float16', shape: [1000]}}
    442  },
    443  {
    444    name:
    445        '[lstmCell] Throw if the data type of options.peepholeWeight is incorrect',
    446    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    447    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    448    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    449    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    450    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    451    hiddenSize: hiddenSize,
    452    options: {peepholeWeight: {dataType: 'float32', shape: [3 * hiddenSize]}}
    453  },
    454  {
    455    name: '[lstmCell] Throw if the rank of options.peepholeWeight is not 1',
    456    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    457    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    458    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    459    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    460    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    461    hiddenSize: hiddenSize,
    462    options: {peepholeWeight: {dataType: 'float16', shape: []}}
    463  },
    464  {
    465    name:
    466        '[lstmCell] Throw if the shape of options.peepholeWeight is incorrect',
    467    input: {dataType: 'float16', shape: [batchSize, inputSize]},
    468    weight: {dataType: 'float16', shape: [4 * hiddenSize, inputSize]},
    469    recurrentWeight: {dataType: 'float16', shape: [4 * hiddenSize, hiddenSize]},
    470    hiddenState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    471    cellState: {dataType: 'float16', shape: [batchSize, hiddenSize]},
    472    hiddenSize: hiddenSize,
    473    options: {peepholeWeight: {dataType: 'float16', shape: [1000]}}
    474  },
    475  {
    476    name: '[lstmCell] Throw if the size of options.activations is not 3',
    477    input: {dataType: 'float32', shape: [batchSize, inputSize]},
    478    weight: {dataType: 'float32', shape: [4 * hiddenSize, inputSize]},
    479    recurrentWeight: {dataType: 'float32', shape: [4 * hiddenSize, hiddenSize]},
    480    hiddenState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    481    cellState: {dataType: 'float32', shape: [batchSize, hiddenSize]},
    482    hiddenSize: hiddenSize,
    483    options: {activations: ['sigmoid', 'tanh', 'sigmoid', 'tanh']}
    484  }
    485 ];
    486 
    487 tests.forEach(
    488    test => promise_test(async t => {
    489      const builder = new MLGraphBuilder(context);
    490      const input = builder.input('input', test.input);
    491      const weight = builder.input('weight', test.weight);
    492      const recurrentWeight =
    493          builder.input('recurrentWeight', test.recurrentWeight);
    494      const hiddenState = builder.input('hiddenState', test.hiddenState);
    495      const cellState = builder.input('cellState', test.cellState);
    496 
    497      const options = {};
    498      if (test.options) {
    499        if (test.options.bias) {
    500          options.bias = builder.input('bias', test.options.bias);
    501        }
    502        if (test.options.recurrentBias) {
    503          options.recurrentBias =
    504              builder.input('recurrentBias', test.options.recurrentBias);
    505        }
    506        if (test.options.peepholeWeight) {
    507          options.peepholeWeight =
    508              builder.input('peepholeWeight', test.options.peepholeWeight);
    509        }
    510        if (test.options.layout) {
    511          options.layout = test.options.layout;
    512        }
    513        if (test.options.activations) {
    514          options.activations = test.options.activations;
    515        }
    516      }
    517 
    518      if (test.outputs &&
    519          context.opSupportLimits().lstmCell.input.dataTypes.includes(
    520              test.input.dataType)) {
    521        const outputs = builder.lstmCell(
    522            input, weight, recurrentWeight, hiddenState, cellState,
    523            test.hiddenSize, options);
    524        assert_equals(outputs.length, test.outputs.length);
    525        for (let i = 0; i < outputs.length; ++i) {
    526          assert_equals(outputs[i].dataType, test.outputs[i].dataType);
    527          assert_array_equals(outputs[i].shape, test.outputs[i].shape);
    528        }
    529      } else {
    530        const label = 'lstm_cell_xxx';
    531        options.label = label;
    532        const regrexp = new RegExp('\\[' + label + '\\]');
    533        assert_throws_with_label(
    534            () => builder.lstmCell(
    535                input, weight, recurrentWeight, hiddenState, cellState,
    536                test.hiddenSize, options),
    537            regrexp);
    538      }
    539    }, test.name));