argMinMax.https.any.js (3287B)
// META: title=validation tests for WebNN API argMin/Max operations
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// Both operators are exercised with the same table of test cases.
const kArgMinMaxOperators = [
  'argMin',
  'argMax',
];

// Label placed in the options of the error-case tests. Those tests expect it
// to show up in the thrown error (checked via assert_throws_with_label,
// presumably defined in ../resources/utils_validation.js).
const label = 'arg_min_max_1_!';

// `label` with regular-expression metacharacters escaped, so the pattern below
// matches it literally and is derived from the constant instead of being a
// hand-copied duplicate that could drift out of sync.
const kEscapedLabel = label.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
// Matches the bracketed label, e.g. "[arg_min_max_1_!]".
const kLabelRegExp = new RegExp(`\\[${kEscapedLabel}\\]`);

// Each case describes one builder call:
//   name    - test title; '[argMin/Max]' is replaced by the operator name.
//   input   - descriptor passed to builder.input().
//   axis    - axis argument passed to argMin/argMax.
//   options - optional third argument (outputDataType and/or label).
//   output  - expected output shape; omitted when the call must throw.
const tests = [
  {
    name: '[argMin/Max] Test with default options.',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    axis: 0,
    output: {shape: [2, 3, 4]}
  },
  {
    name: '[argMin/Max] Test with axes=1.',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    axis: 1,
    output: {shape: [1, 3, 4]}
  },
  {
    name: '[argMin/Max] Test with outputDataType=int32',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    axis: 1,
    options: {
      outputDataType: 'int32',
    },
    output: {shape: [1, 3, 4]}
  },
  {
    name: '[argMin/Max] Test with outputDataType=int64',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    axis: 1,
    options: {
      outputDataType: 'int64',
    },
    output: {shape: [1, 3, 4]}
  },
  {
    name:
        '[argMin/Max] Throw if the value in axis is greater than or equal to input rank.',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    axis: 4,
    options: {
      label: label,
    },
  },
  {
    name: '[argMin/Max] Throw if input is a scalar and axis=0.',
    input: {dataType: 'float32', shape: []},
    axis: 0,
    options: {
      label: label,
    },
  },
  {
    name: '[argMin/Max] Throw if outputDataType=float32',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    axis: 1,
    options: {outputDataType: 'float32', label: label}
  }
];

/**
 * Registers one promise_test per entry of `tests` for the given operator.
 *
 * Expectations, checked in this order:
 *   1. The input data type is not in the operator's supported input data
 *      types (context.opSupportLimits()): the call must throw a TypeError.
 *   2. An explicit outputDataType is not in the operator's supported output
 *      data types: the call must throw a TypeError.
 *   3. `test.output` is present: the call must succeed, producing the expected
 *      shape and `options.outputDataType` (or 'int32' when none is given).
 *   4. Otherwise the call must throw with the label in the error, regardless
 *      of whether a (supported) outputDataType was also given.
 *
 * @param {string} operatorName - Builder method to exercise, e.g. 'argMin'.
 * @param {Array<Object>} tests - Cases shaped like the `tests` array above.
 */
function runTests(operatorName, tests) {
  tests.forEach(test => {
    promise_test(async () => {
      const builder = new MLGraphBuilder(context);
      const input = builder.input('input', test.input);
      const buildOp = () =>
          builder[operatorName](input, test.axis, test.options);
      // Query the support limits once per test rather than once per check.
      const limits = context.opSupportLimits()[operatorName];

      if (!limits.input.dataTypes.includes(test.input.dataType)) {
        assert_throws_js(TypeError, buildOp);
        return;
      }

      const outputDataType =
          test.options ? test.options.outputDataType : undefined;
      if (outputDataType !== undefined &&
          !limits.output.dataTypes.includes(outputDataType)) {
        assert_throws_js(TypeError, buildOp);
        return;
      }

      if (test.output) {
        const output = buildOp();
        // With no explicit outputDataType the indices are expected as int32.
        assert_equals(
            output.dataType,
            outputDataType !== undefined ? outputDataType : 'int32');
        assert_array_equals(output.shape, test.output.shape);
      } else {
        // Error case: only reached when any outputDataType is supported, so
        // the failure must come from another argument and carry the label.
        assert_throws_with_label(buildOp, kLabelRegExp);
      }
    }, test.name.replace('[argMin/Max]', `[${operatorName}]`));
  });
}

// For each operator, register the shared cross-builder input check (helper
// presumably from ../resources/utils_validation.js) and the table-driven cases.
kArgMinMaxOperators.forEach((operatorName) => {
  validateInputFromAnotherBuilder(operatorName);
  runTests(operatorName, tests);
});