tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

matmul.https.any.js (3992B)


      1 // META: title=validation tests for WebNN API matmul operation
      2 // META: global=window
      3 // META: variant=?cpu
      4 // META: variant=?gpu
      5 // META: variant=?npu
      6 // META: script=../resources/utils_validation.js
      7 
      8 'use strict';
      9 
// Label attached to matmul() calls that are expected to fail; the test loop
// at the bottom of this file checks that the thrown error message contains
// it wrapped in square brackets.
const label = 'matmul_123';
// Shared validation checks, parameterized by operation name. These helpers are
// not defined in this file — presumably they come from
// ../resources/utils_validation.js (see the META script line); confirm there.
validateTwoInputsFromMultipleBuilders('matmul');
validateTwoBroadcastableInputsTensorLimit('matmul', label);
     13 
     14 const tests = [
     15  {
     16    name: '[matmul] Throw if first input\'s rank is less than 2',
     17    inputs: {
     18      a: {dataType: 'float32', shape: [2]},
     19      b: {dataType: 'float32', shape: [2, 2]}
     20    }
     21  },
     22  {
     23    name: '[matmul] Throw if second input\'s rank is less than 2',
     24    inputs: {
     25      a: {dataType: 'float32', shape: [2, 2]},
     26      b: {dataType: 'float32', shape: [2]}
     27    }
     28  },
     29  {
     30    name: '[matmul] Test with 2-D input and 4-D input',
     31    inputs: {
     32      a: {dataType: 'float32', shape: [1, 4]},
     33      b: {dataType: 'float32', shape: [2, 2, 4, 2]}
     34    },
     35    output: {dataType: 'float32', shape: [2, 2, 1, 2]}
     36  },
     37  {
     38    name: '[matmul] Test with 2-D input and 2-D input',
     39    inputs: {
     40      a: {dataType: 'float32', shape: [4, 2]},
     41      b: {dataType: 'float32', shape: [2, 3]}
     42    },
     43    output: {dataType: 'float32', shape: [4, 3]}
     44  },
     45  {
     46    // batchShape is a clone of inputShape with the spatial dimensions
     47    // (last 2 items) removed.
     48    name:
     49        '[matmul] Test with 3-D input and 3-D input of broadcastable batchShape',
     50    inputs: {
     51      a: {dataType: 'float32', shape: [2, 3, 4]},
     52      b: {dataType: 'float32', shape: [1, 4, 1]}
     53    },
     54    output: {dataType: 'float32', shape: [2, 3, 1]}
     55  },
     56  {
     57    // batchShape is a clone of inputShape with the spatial dimensions
     58    // (last 2 items) removed.
     59    name:
     60        '[matmul] Test with 4-D input and 3-D input of broadcastable batchShape',
     61    inputs: {
     62      a: {dataType: 'float32', shape: [2, 2, 3, 4]},
     63      b: {dataType: 'float32', shape: [1, 4, 5]}
     64    },
     65    output: {dataType: 'float32', shape: [2, 2, 3, 5]}
     66  },
     67  {
     68    name: '[matmul] Test with 3-D input and 3-D input',
     69    inputs: {
     70      a: {dataType: 'float32', shape: [2, 3, 4]},
     71      b: {dataType: 'float32', shape: [2, 4, 5]}
     72    },
     73    output: {dataType: 'float32', shape: [2, 3, 5]}
     74  },
     75  {
     76    name: '[matmul] Throw if the input data type is not floating point',
     77    inputs: {
     78      a: {dataType: 'uint32', shape: [2, 3, 4]},
     79      b: {dataType: 'uint32', shape: [2, 4, 5]}
     80    }
     81  },
     82  {
     83    name: '[matmul] Throw if data type of two inputs don\'t match',
     84    inputs: {
     85      a: {dataType: 'float32', shape: [2, 3, 4]},
     86      b: {dataType: 'float16', shape: [2, 4, 5]}
     87    }
     88  },
     89  {
     90    name:
     91        '[matmul] Throw if columns of first input\'s shape doesn\'t match the rows of second input\'s shape',
     92    inputs: {
     93      a: {dataType: 'float32', shape: /* [rows, columns] */[2, 3]},
     94      b: {dataType: 'float32', shape: /* [rows, columns] */[2, 4]}
     95    },
     96  },
     97  {
     98    // batchShape is a clone of inputShape with the spatial dimensions
     99    // (last 2 items) removed.
    100    name: '[matmul] Throw if batchShapes aren\'t bidirectionally broadcastable',
    101    inputs: {
    102      a: {dataType: 'float32', shape: [3, 3, 4]},
    103      b: {dataType: 'float32', shape: [2, 4, 1]}
    104    },
    105  },
    106 ];
    107 
    108 tests.forEach(test => promise_test(async t => {
    109                const builder = new MLGraphBuilder(context);
    110                const inputA = builder.input('a', test.inputs.a);
    111                const inputB = builder.input('b', test.inputs.b);
    112                if (test.output) {
    113                  const output = builder.matmul(inputA, inputB);
    114                  assert_equals(output.dataType, test.output.dataType);
    115                  assert_array_equals(output.shape, test.output.shape);
    116                } else {
    117                  const options = {label};
    118                  const regrexp = new RegExp('\\[' + label + '\\]');
    119                  assert_throws_with_label(
    120                      () => builder.matmul(inputA, inputB, options), regrexp);
    121                }
    122              }, test.name));