gemm.https.any.js (5279B)
// META: title=validation tests for WebNN API gemm operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// Label passed in the options of every expected-failure case; the driver at
// the bottom checks for it (wrapped in square brackets) via
// assert_throws_with_label.
const label = 'gemm_xxx';
const kExampleInputDescriptor = {
  dataType: 'float32',
  shape: [2, 2]
};

// Shared validation helpers; presumably provided by
// ../resources/utils_validation.js (see the META script line above).
validateTwoInputsFromMultipleBuilders('gemm');
validateTwoBroadcastableInputsTensorLimit('gemm', label);

// The optional `c` operand must come from the same builder as `a` and `b`.
multi_builder_test(async (t, builder, otherBuilder) => {
  const cFromOtherBuilder = otherBuilder.input('c', kExampleInputDescriptor);
  const options = {c: cFromOtherBuilder};

  const a = builder.input('a', kExampleInputDescriptor);
  const b = builder.input('b', kExampleInputDescriptor);
  assert_throws_js(TypeError, () => builder.gemm(a, b, options));
}, '[gemm] throw if c option is from another builder');

// Table of cases consumed by the driver below.
//   name:    test description.
//   a, b:    descriptors for the two required inputs.
//   options: optional gemm options; `options.c` is a *descriptor* here and is
//            turned into an operand by the driver for each run.
//   output:  expected output descriptor. A case without `output` is expected
//            to throw an error carrying `label`.
const tests = [
  {
    name: '[gemm] Test building gemm with default option.',
    a: {dataType: 'float32', shape: [2, 3]},
    b: {dataType: 'float32', shape: [3, 4]},
    output: {dataType: 'float32', shape: [2, 4]}
  },
  {
    name:
        '[gemm] Throw if inputShapeA[1] is not equal to inputShapeB[0] default options.',
    a: {dataType: 'float32', shape: [2, 3]},
    b: {dataType: 'float32', shape: [2, 4]},
    options: {label}
  },
  {
    name: '[gemm] Test building gemm with aTranspose=true.',
    a: {dataType: 'float32', shape: [2, 3]},
    b: {dataType: 'float32', shape: [2, 4]},
    options: {
      aTranspose: true,
    },
    output: {dataType: 'float32', shape: [3, 4]}
  },
  {
    name:
        '[gemm] Throw if inputShapeA[0] is not equal to inputShapeB[0] with aTranspose=true.',
    a: {dataType: 'float32', shape: [2, 3]},
    b: {dataType: 'float32', shape: [3, 4]},
    options: {
      aTranspose: true,
      label: label,
    },
  },
  {
    name: '[gemm] Test building gemm with bTranspose=true.',
    a: {dataType: 'float32', shape: [2, 3]},
    b: {dataType: 'float32', shape: [4, 3]},
    options: {
      bTranspose: true,
    },
    output: {dataType: 'float32', shape: [2, 4]}
  },
  {
    // With bTranspose the dimensions that must agree are a[1] and b[1]
    // (3 vs. 4 for the shapes below).
    name:
        '[gemm] Throw if inputShapeA[1] is not equal to inputShapeB[1] with bTranspose=true.',
    a: {dataType: 'float32', shape: [2, 3]},
    b: {dataType: 'float32', shape: [3, 4]},
    options: {
      bTranspose: true,
      label: label,
    },
  },
  {
    name: '[gemm] Throw if the rank of inputA is not 2.',
    a: {dataType: 'float32', shape: [2, 3, 1]},
    b: {dataType: 'float32', shape: [2, 4]},
    options: {label}
  },
  {
    name: '[gemm] Throw if the rank of inputB is not 2.',
    a: {dataType: 'float32', shape: [2, 4]},
    b: {dataType: 'float32', shape: [2, 3, 1]},
    options: {label}
  },
  {
    name: '[gemm] Throw if data types of two inputs do not match.',
    a: {dataType: 'float32', shape: [2, 3]},
    b: {dataType: 'float16', shape: [3, 4]},
    options: {label}
  },
  {
    name: '[gemm] Test building gemm with inputC.',
    a: {dataType: 'float32', shape: [2, 3]},
    b: {dataType: 'float32', shape: [3, 4]},
    options: {
      c: {dataType: 'float32', shape: [4]},
    },
    output: {dataType: 'float32', shape: [2, 4]}
  },
  {
    name: '[gemm] Test building gemm with scalar inputC.',
    a: {dataType: 'float32', shape: [2, 3]},
    b: {dataType: 'float32', shape: [3, 4]},
    options: {
      c: {dataType: 'float32', shape: []},
    },
    output: {dataType: 'float32', shape: [2, 4]}
  },
  {
    name:
        '[gemm] Throw if inputShapeC is not unidirectionally broadcastable to the output shape [inputShapeA[0], inputShapeB[1]].',
    a: {dataType: 'float32', shape: [2, 3]},
    b: {dataType: 'float32', shape: [3, 4]},
    options: {
      c: {dataType: 'float32', shape: [2, 3]},
      label: label,
    },
  },
  {
    name: '[gemm] Throw if the input data type is not floating point.',
    a: {dataType: 'int32', shape: [2, 3]},
    b: {dataType: 'int32', shape: [3, 4]},
    options: {label}
  },
  {
    name:
        '[gemm] Throw if data type of inputC does not match ones of inputA and inputB.',
    a: {dataType: 'float32', shape: [3, 2]},
    b: {dataType: 'float32', shape: [4, 3]},
    options: {
      c: {dataType: 'float16', shape: [2, 4]},
      aTranspose: true,
      bTranspose: true,
      label: label,
    },
  },
  {
    name: '[gemm] Throw if the rank of inputC is 3.',
    a: {dataType: 'float32', shape: [3, 2]},
    b: {dataType: 'float32', shape: [4, 3]},
    options: {
      c: {dataType: 'float32', shape: [2, 3, 4]},
      aTranspose: true,
      bTranspose: true,
      label: label,
    },
  },
];

// Matches the bracketed label, e.g. "[gemm_xxx]". Built once because it is
// the same for every failure case.
const kLabelRegExp = new RegExp(`\\[${label}\\]`);

// Registers one promise_test per table entry. `context` and `MLGraphBuilder`
// are globals expected from the test environment (not defined in this file).
tests.forEach(
    test => promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      const a = builder.input('a', test.a);
      const b = builder.input('b', test.b);

      // Swap the `c` descriptor for a real operand on a per-run copy, so the
      // shared `tests` table is never mutated.
      const options = test.options?.c ?
          {...test.options, c: builder.input('c', test.options.c)} :
          test.options;

      if (test.output) {
        const output = builder.gemm(a, b, options);
        assert_equals(output.dataType, test.output.dataType);
        assert_array_equals(output.shape, test.output.shape);
      } else {
        assert_throws_with_label(
            () => builder.gemm(a, b, options), kLabelRegExp);
      }
    }, test.name));