prelu.https.any.js (3158B)
// META: title=validation tests for WebNN API prelu operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// Shared helper loaded via the META script directive above. Presumably it
// checks that prelu() rejects operands created by a different
// MLGraphBuilder -- confirm against utils_validation.js.
validateTwoInputsFromMultipleBuilders('prelu');

// Table of prelu() validation cases.
//   name   - title passed to promise_test.
//   input  - descriptor of the input operand.
//   slope  - descriptor of the slope operand.
//   output - expected descriptor of the result; when omitted, prelu() is
//            expected to throw.
const tests = [
  {
    name:
        '[prelu] Test slope\'s shape = [3, 2, 5] which is the same as input\'s shape.',
    input: {dataType: 'float32', shape: [3, 2, 5]},
    slope: {dataType: 'float32', shape: [3, 2, 5]},
    output: {dataType: 'float32', shape: [3, 2, 5]},
  },
  {
    name:
        '[prelu] Test slope\'s shape = [5] which is unidirectionally broadcastable to input\'s shape.',
    input: {dataType: 'float32', shape: [3, 2, 5]},
    slope: {dataType: 'float32', shape: [5]},
    output: {dataType: 'float32', shape: [3, 2, 5]},
  },
  {
    name:
        '[prelu] Test slope\'s shape = [] which is unidirectionally broadcastable to input\'s shape.',
    input: {dataType: 'float32', shape: [3, 2, 5]},
    slope: {dataType: 'float32', shape: []},
    output: {dataType: 'float32', shape: [3, 2, 5]},
  },
  {
    name:
        '[prelu] Test slope\'s shape = [2, 5] which is unidirectionally broadcastable to input\'s shape.',
    input: {dataType: 'float32', shape: [3, 2, 5]},
    slope: {dataType: 'float32', shape: [2, 5]},
    output: {dataType: 'float32', shape: [3, 2, 5]},
  },
  {
    name:
        '[prelu] Throw if the shape of slope is not broadcastable to the shape of input.',
    input: {dataType: 'float32', shape: [3, 2, 5]},
    slope: {dataType: 'float32', shape: [2]},
  },
  {
    name:
        '[prelu] Throw if the data type of slope does not match the data type of input.',
    input: {dataType: 'float32', shape: [3, 2, 5]},
    slope: {dataType: 'int32', shape: [3, 2, 5]},
  },
];

// Register one promise_test per table entry. `context` is not defined in this
// file; presumably it is the MLContext set up by utils_validation.js.
for (const test of tests) {
  promise_test(async t => {
    const builder = new MLGraphBuilder(context);
    const input = builder.input('input', test.input);
    // Each graph input gets its own name so that creating the operands can
    // never be what fails; only prelu() itself is under test.
    const slope = builder.input('slope', test.slope);

    if (test.output) {
      const output = builder.prelu(input, slope);
      assert_equals(output.dataType, test.output.dataType);
      assert_array_equals(output.shape, test.output.shape);
      return;
    }

    // Failure case: the thrown error is matched against the bracketed label
    // supplied through the options dictionary.
    const label = 'prelu_123';
    const options = {label};
    const labelPattern = new RegExp(`\\[${label}\\]`);
    assert_throws_with_label(
        () => builder.prelu(input, slope, options), labelPattern);
  }, test.name);
}

// For every operand data type the context accepts as a graph input, prelu()
// must either succeed (type listed in the prelu support limits) or throw a
// TypeError (type not listed).
promise_test(async t => {
  const limits = context.opSupportLimits();
  for (const dataType of allWebNNOperandDataTypes) {
    // Types that cannot even be used for builder.input() are skipped.
    if (!limits.input.dataTypes.includes(dataType)) {
      continue;
    }
    const builder = new MLGraphBuilder(context);
    const shape = [1];
    const input = builder.input('input', {dataType, shape});
    if (limits.prelu.input.dataTypes.includes(dataType)) {
      const output = builder.prelu(input, input);
      assert_equals(output.dataType, dataType);
      assert_array_equals(output.shape, shape);
    } else {
      assert_throws_js(TypeError, () => builder.prelu(input, input));
    }
  }
}, '[prelu] Test prelu with all of the data types.');