elementwise-binary.https.any.js (3475B)
// META: title=validation tests for WebNN API element-wise binary operations
// META: global=window
// META: variant=?op=add&device=cpu
// META: variant=?op=add&device=gpu
// META: variant=?op=add&device=npu
// META: variant=?op=sub&device=cpu
// META: variant=?op=sub&device=gpu
// META: variant=?op=sub&device=npu
// META: variant=?op=mul&device=cpu
// META: variant=?op=mul&device=gpu
// META: variant=?op=mul&device=npu
// META: variant=?op=div&device=cpu
// META: variant=?op=div&device=gpu
// META: variant=?op=div&device=npu
// META: variant=?op=max&device=cpu
// META: variant=?op=max&device=gpu
// META: variant=?op=max&device=npu
// META: variant=?op=min&device=cpu
// META: variant=?op=min&device=gpu
// META: variant=?op=min&device=npu
// META: variant=?op=pow&device=cpu
// META: variant=?op=pow&device=gpu
// META: variant=?op=pow&device=npu
// META: script=../resources/utils_validation.js

'use strict';

// The operator under test is selected by the `op` query parameter of each
// variant above; e.g. `?op=add` makes the tests call `builder.add(...)`.
const queryParams = new URLSearchParams(window.location.search);
const operatorName = queryParams.get('op');

// Label passed in the options of builder calls that are expected to fail.
// `regrexp` matches the literal text `[elementwise_binary_op]`.
const label = 'elementwise_binary_op';
const regrexp = new RegExp(`\\[${label}\\]`);

// Each case describes operands `a` and `b`. Cases with an `output` entry must
// build successfully and yield that data type and shape; cases without one
// must make the operator call throw.
const tests = [
  {
    name: '[binary] Test bidirectionally broadcastable dimensions.',
    // Both inputs have axes of length one which are expanded
    // during broadcasting.
    a: {dataType: 'float32', shape: [8, 1, 6, 1]},
    b: {dataType: 'float32', shape: [7, 1, 5]},
    output: {dataType: 'float32', shape: [8, 7, 6, 5]}
  },
  {
    name: '[binary] Test unidirectionally broadcastable dimensions.',
    // Input a has a single axis of length one which is
    // expanded during broadcasting.
    a: {dataType: 'float32', shape: [4, 2, 1]},
    b: {dataType: 'float32', shape: [4]},
    output: {dataType: 'float32', shape: [4, 2, 4]}
  },
  {
    name: '[binary] Test scalar broadcasting.',
    a: {dataType: 'float32', shape: [4, 2, 4]},
    b: {dataType: 'float32', shape: []},
    output: {dataType: 'float32', shape: [4, 2, 4]}
  },
  {
    name: '[binary] Throw if the input shapes are not broadcastable.',
    // Trailing axes 2 and 4 differ and neither is 1.
    a: {dataType: 'float32', shape: [4, 2]},
    b: {dataType: 'float32', shape: [4]},
  },
  {
    name: '[binary] Throw if the input types don\'t match.',
    a: {dataType: 'float32', shape: [4, 2]},
    b: {dataType: 'int32', shape: [1]},
  },
];

tests.forEach(test => {
  promise_test(async () => {
    const builder = new MLGraphBuilder(context);

    // Look up the context's supported graph-input data types once.
    const supportedInputTypes = context.opSupportLimits().input.dataTypes;

    // If either operand's data type cannot be used as a graph input here,
    // the operator cannot be exercised. Only check that declaring that input
    // is rejected with a TypeError, then stop. Operand `a` is checked first.
    for (const [name, descriptor] of [['a', test.a], ['b', test.b]]) {
      if (!supportedInputTypes.includes(descriptor.dataType)) {
        assert_throws_js(TypeError, () => builder.input(name, descriptor));
        return;
      }
    }

    const a = builder.input('a', test.a);
    const b = builder.input('b', test.b);

    if (test.output) {
      // Valid case: check the result operand's data type and shape.
      const output = builder[operatorName](a, b);
      assert_equals(output.dataType, test.output.dataType);
      assert_array_equals(output.shape, test.output.shape);
      return;
    }

    // Invalid case: the call carrying `label` must throw. The thrown error is
    // checked by assert_throws_with_label against `regrexp`.
    assert_throws_with_label(
        () => builder[operatorName](a, b, {label}), regrexp);
  }, test.name.replace('[binary]', `[${operatorName}]`));
});

// Generic two-input checks shared with other WebNN validation tests
// (presumably provided by ../resources/utils_validation.js, the only script
// this test loads).
validateTwoInputsOfSameDataType(operatorName, label);
validateTwoInputsBroadcastable(operatorName, label);
validateTwoInputsFromMultipleBuilders(operatorName);
validateTwoBroadcastableInputsTensorLimit(operatorName, label);