tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

gemm.https.any.js (5279B)


      1 // META: title=validation tests for WebNN API gemm operation
      2 // META: global=window
      3 // META: variant=?cpu
      4 // META: variant=?gpu
      5 // META: variant=?npu
      6 // META: script=../resources/utils_validation.js
      7 
      8 'use strict';
      9 
     10 const label = 'gemm_xxx';
     11 const kExampleInputDescriptor = {
     12  dataType: 'float32',
     13  shape: [2, 2]
     14 };
     15 
     16 validateTwoInputsFromMultipleBuilders('gemm');
     17 validateTwoBroadcastableInputsTensorLimit('gemm', label);
     18 
     19 multi_builder_test(async (t, builder, otherBuilder) => {
     20  const cFromOtherBuilder = otherBuilder.input('c', kExampleInputDescriptor);
     21  const options = {c: cFromOtherBuilder};
     22 
     23  const a = builder.input('a', kExampleInputDescriptor);
     24  const b = builder.input('b', kExampleInputDescriptor);
     25  assert_throws_js(TypeError, () => builder.gemm(a, b, options));
     26 }, '[gemm] throw if c option is from another builder');
     27 
     28 const tests = [
     29  {
     30    name: '[gemm] Test building gemm with default option.',
     31    a: {dataType: 'float32', shape: [2, 3]},
     32    b: {dataType: 'float32', shape: [3, 4]},
     33    output: {dataType: 'float32', shape: [2, 4]}
     34  },
     35  {
     36    name:
     37        '[gemm] Throw if inputShapeA[1] is not equal to inputShapeB[0] default options.',
     38    a: {dataType: 'float32', shape: [2, 3]},
     39    b: {dataType: 'float32', shape: [2, 4]},
     40    options: {label}
     41  },
     42  {
     43    name: '[gemm] Test building gemm with aTranspose=true.',
     44    a: {dataType: 'float32', shape: [2, 3]},
     45    b: {dataType: 'float32', shape: [2, 4]},
     46    options: {
     47      aTranspose: true,
     48    },
     49    output: {dataType: 'float32', shape: [3, 4]}
     50  },
     51  {
     52    name:
     53        '[gemm] Throw if inputShapeA[0] is not equal to inputShapeB[0] with aTranspose=true.',
     54    a: {dataType: 'float32', shape: [2, 3]},
     55    b: {dataType: 'float32', shape: [3, 4]},
     56    options: {
     57      aTranspose: true,
     58      label: label,
     59    },
     60  },
     61  {
     62    name: '[gemm] Test building gemm with bTranspose=true.',
     63    a: {dataType: 'float32', shape: [2, 3]},
     64    b: {dataType: 'float32', shape: [4, 3]},
     65    options: {
     66      bTranspose: true,
     67    },
     68    output: {dataType: 'float32', shape: [2, 4]}
     69  },
     70  {
     71    name:
     72        '[gemm] Throw if inputShapeA[0] is not equal to inputShapeB[0] with bTranspose=true.',
     73    a: {dataType: 'float32', shape: [2, 3]},
     74    b: {dataType: 'float32', shape: [3, 4]},
     75    options: {
     76      bTranspose: true,
     77      label: label,
     78    },
     79  },
     80  {
     81    name: '[gemm] Throw if the rank of inputA is not 2.',
     82    a: {dataType: 'float32', shape: [2, 3, 1]},
     83    b: {dataType: 'float32', shape: [2, 4]},
     84    options: {label}
     85  },
     86  {
     87    name: '[gemm] Throw if the rank of inputB is not 2.',
     88    a: {dataType: 'float32', shape: [2, 4]},
     89    b: {dataType: 'float32', shape: [2, 3, 1]},
     90    options: {label}
     91  },
     92  {
     93    name: '[gemm] Throw if data types of two inputs do not match.',
     94    a: {dataType: 'float32', shape: [2, 3]},
     95    b: {dataType: 'float16', shape: [3, 4]},
     96    options: {label}
     97  },
     98  {
     99    name: '[gemm] Test building gemm with inputC.',
    100    a: {dataType: 'float32', shape: [2, 3]},
    101    b: {dataType: 'float32', shape: [3, 4]},
    102    options: {
    103      c: {dataType: 'float32', shape: [4]},
    104    },
    105    output: {dataType: 'float32', shape: [2, 4]}
    106  },
    107  {
    108    name: '[gemm] Test building gemm with scalar inputC.',
    109    a: {dataType: 'float32', shape: [2, 3]},
    110    b: {dataType: 'float32', shape: [3, 4]},
    111    options: {
    112      c: {dataType: 'float32', shape: []},
    113    },
    114    output: {dataType: 'float32', shape: [2, 4]}
    115  },
    116  {
    117    name:
    118        '[gemm] Throw if inputShapeC is not unidirectionally broadcastable to the output shape [inputShapeA[0], inputShapeB[1]].',
    119    a: {dataType: 'float32', shape: [2, 3]},
    120    b: {dataType: 'float32', shape: [3, 4]},
    121    options: {
    122      c: {dataType: 'float32', shape: [2, 3]},
    123      label: label,
    124    },
    125  },
    126  {
    127    name: '[gemm] Throw if the input data type is not floating point.',
    128    a: {dataType: 'int32', shape: [2, 3]},
    129    b: {dataType: 'int32', shape: [3, 4]},
    130    options: {label}
    131  },
    132  {
    133    name:
    134        '[gemm] Throw if data type of inputC does not match ones of inputA and inputB.',
    135    a: {dataType: 'float32', shape: [3, 2]},
    136    b: {dataType: 'float32', shape: [4, 3]},
    137    options: {
    138      c: {dataType: 'float16', shape: [2, 4]},
    139      aTranspose: true,
    140      bTranspose: true,
    141      label: label,
    142    },
    143  },
    144  {
    145    name: '[gemm] Throw if the rank of inputC is 3.',
    146    a: {dataType: 'float32', shape: [3, 2]},
    147    b: {dataType: 'float32', shape: [4, 3]},
    148    options: {
    149      c: {dataType: 'float32', shape: [2, 3, 4]},
    150      aTranspose: true,
    151      bTranspose: true,
    152      label: label,
    153    },
    154  },
    155 ];
    156 
    157 tests.forEach(
    158    test => promise_test(async t => {
    159      const builder = new MLGraphBuilder(context);
    160      const a = builder.input('a', test.a);
    161      const b = builder.input('b', test.b);
    162      if (test.options && test.options.c) {
    163        test.options.c = builder.input('c', test.options.c);
    164      }
    165      if (test.output) {
    166        const output = builder.gemm(a, b, test.options);
    167        assert_equals(output.dataType, test.output.dataType);
    168        assert_array_equals(output.shape, test.output.shape);
    169      } else {
    170        const regrexp = new RegExp('\\[' + label + '\\]');
    171        assert_throws_with_label(
    172            () => builder.gemm(a, b, test.options), regrexp);
    173      }
    174    }, test.name));