matmul.https.any.js (3992B)
// META: title=validation tests for WebNN API matmul operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// Label passed as the matmul option in the throwing cases below; the thrown
// error is matched against a pattern containing "[<label>]".
const label = 'matmul_123';

// Shared validation helpers loaded via the META script directive above.
validateTwoInputsFromMultipleBuilders('matmul');
validateTwoBroadcastableInputsTensorLimit('matmul', label);

// Table of matmul validation cases.
//  - A case WITH `output` builds matmul(a, b) and checks the resulting
//    operand's dataType and shape against it.
//  - A case WITHOUT `output` expects builder.matmul() to throw.
const tests = [
  // ---- rank checks ----
  {
    name: '[matmul] Throw if first input\'s rank is less than 2',
    inputs: {
      a: {dataType: 'float32', shape: [2]},
      b: {dataType: 'float32', shape: [2, 2]},
    },
  },
  {
    name: '[matmul] Throw if second input\'s rank is less than 2',
    inputs: {
      a: {dataType: 'float32', shape: [2, 2]},
      b: {dataType: 'float32', shape: [2]},
    },
  },

  // ---- valid shape combinations ----
  {
    name: '[matmul] Test with 2-D input and 4-D input',
    inputs: {
      a: {dataType: 'float32', shape: [1, 4]},
      b: {dataType: 'float32', shape: [2, 2, 4, 2]},
    },
    output: {dataType: 'float32', shape: [2, 2, 1, 2]},
  },
  {
    name: '[matmul] Test with 2-D input and 2-D input',
    inputs: {
      a: {dataType: 'float32', shape: [4, 2]},
      b: {dataType: 'float32', shape: [2, 3]},
    },
    output: {dataType: 'float32', shape: [4, 3]},
  },
  {
    // The batchShape of an input is its shape with the trailing two
    // (spatial) dimensions dropped; here [2] and [1] broadcast to [2].
    name:
        '[matmul] Test with 3-D input and 3-D input of broadcastable batchShape',
    inputs: {
      a: {dataType: 'float32', shape: [2, 3, 4]},
      b: {dataType: 'float32', shape: [1, 4, 1]},
    },
    output: {dataType: 'float32', shape: [2, 3, 1]},
  },
  {
    // The batchShape of an input is its shape with the trailing two
    // (spatial) dimensions dropped; here [2, 2] and [1] broadcast to [2, 2].
    name:
        '[matmul] Test with 4-D input and 3-D input of broadcastable batchShape',
    inputs: {
      a: {dataType: 'float32', shape: [2, 2, 3, 4]},
      b: {dataType: 'float32', shape: [1, 4, 5]},
    },
    output: {dataType: 'float32', shape: [2, 2, 3, 5]},
  },
  {
    name: '[matmul] Test with 3-D input and 3-D input',
    inputs: {
      a: {dataType: 'float32', shape: [2, 3, 4]},
      b: {dataType: 'float32', shape: [2, 4, 5]},
    },
    output: {dataType: 'float32', shape: [2, 3, 5]},
  },

  // ---- data type checks ----
  {
    name: '[matmul] Throw if the input data type is not floating point',
    inputs: {
      a: {dataType: 'uint32', shape: [2, 3, 4]},
      b: {dataType: 'uint32', shape: [2, 4, 5]},
    },
  },
  {
    name: '[matmul] Throw if data type of two inputs don\'t match',
    inputs: {
      a: {dataType: 'float32', shape: [2, 3, 4]},
      b: {dataType: 'float16', shape: [2, 4, 5]},
    },
  },

  // ---- shape compatibility checks ----
  {
    name:
        '[matmul] Throw if columns of first input\'s shape doesn\'t match the rows of second input\'s shape',
    inputs: {
      // Both shapes are laid out as [rows, columns].
      a: {dataType: 'float32', shape: [2, 3]},
      b: {dataType: 'float32', shape: [2, 4]},
    },
  },
  {
    // The batchShape of an input is its shape with the trailing two
    // (spatial) dimensions dropped; here [3] and [2] cannot be broadcast.
    name: '[matmul] Throw if batchShapes aren\'t bidirectionally broadcastable',
    inputs: {
      a: {dataType: 'float32', shape: [3, 3, 4]},
      b: {dataType: 'float32', shape: [2, 4, 1]},
    },
  },
];

// Register one promise_test per table entry, in table order.
for (const {name, inputs, output: expected} of tests) {
  promise_test(async t => {
    const builder = new MLGraphBuilder(context);
    const inputA = builder.input('a', inputs.a);
    const inputB = builder.input('b', inputs.b);

    // No expected descriptor: the call must throw. The label is supplied as
    // an option and a "[<label>]" pattern is handed to the assertion helper.
    if (!expected) {
      const labelPattern = new RegExp(`\\[${label}\\]`);
      assert_throws_with_label(
          () => builder.matmul(inputA, inputB, {label}), labelPattern);
      return;
    }

    // Expected descriptor present: the call must succeed and the resulting
    // operand must report the expected data type and shape.
    const result = builder.matmul(inputA, inputB);
    assert_equals(result.dataType, expected.dataType);
    assert_array_equals(result.shape, expected.shape);
  }, name);
}