transpose.https.any.js (2894B)
// META: title=validation tests for WebNN API transpose operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

validateInputFromAnotherBuilder('transpose');

// Label attached to the options of the labelled invalid cases below; those
// cases assert that the thrown error's message contains `[transpose-2]`.
const label = 'transpose-2';

// Table of transpose cases:
//  - entries with `output` must build and produce exactly that descriptor;
//  - entries without `output` must throw. When `options.label` is set, the
//    message must contain `[<label>]` (checked via assert_throws_with_label);
//    otherwise only a TypeError is required.
const tests = [
  {
    name: '[transpose] Test building transpose with permutation=[0, 2, 3, 1].',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    options: {permutation: [0, 2, 3, 1]},
    output: {dataType: 'float32', shape: [1, 3, 4, 2]}
  },
  {
    name:
        '[transpose] Throw if permutation\'s size is not the same as input\'s rank.',
    input: {dataType: 'int32', shape: [1, 2, 4]},
    options: {
      permutation: [0, 2, 3, 1],
      label: label,
    },
  },
  {
    name: '[transpose] Throw if two values in permutation are same.',
    input: {dataType: 'int32', shape: [1, 2, 3, 4]},
    options: {
      permutation: [0, 2, 3, 2],
      label: label,
    },
  },
  {
    name:
        '[transpose] Throw if any value in permutation is not in the range [0,input\'s rank).',
    input: {dataType: 'int32', shape: [1, 2, 3, 4]},
    options: {
      permutation: [0, 1, 2, 4],
      label: label,
    },
  },
  {
    // NOTE(review): no label here; a negative entry is presumably rejected
    // while converting `permutation` to its IDL type, before label-aware
    // validation runs. Confirm against the WebNN spec IDL.
    name: '[transpose] Throw if any value in permutation is negative.',
    input: {dataType: 'int32', shape: [1, 2, 3, 4]},
    options: {
      permutation: [0, -1, 2, 3],
    },
  }
];

for (const test of tests) {
  promise_test(async () => {
    const builder = new MLGraphBuilder(context);
    const input = builder.input('input', test.input);

    if (test.output) {
      const output = builder.transpose(input, test.options);
      assert_equals(output.dataType, test.output.dataType);
      assert_array_equals(output.shape, test.output.shape);
      return;
    }

    const options = {...test.options};
    if (options.label) {
      // Build the pattern from this case's own label and escape RegExp
      // metacharacters so the label is always matched literally.
      const escapedLabel =
          options.label.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
      const labelPattern = new RegExp(`\\[${escapedLabel}\\]`);
      assert_throws_with_label(
          () => builder.transpose(input, options), labelPattern);
    } else {
      assert_throws_js(TypeError, () => builder.transpose(input, options));
    }
  }, test.name);
}

// For every operand data type the context accepts as a graph input, transpose
// must succeed when the type is listed in transpose's support limits and
// throw a TypeError otherwise.
promise_test(async () => {
  const limits = context.opSupportLimits();
  for (const dataType of allWebNNOperandDataTypes) {
    // Types that cannot even be declared as inputs are not exercised.
    if (!limits.input.dataTypes.includes(dataType)) {
      continue;
    }
    const builder = new MLGraphBuilder(context);
    const shape = [1, 2, 3, 4];
    const input = builder.input('input', {dataType, shape});
    if (limits.transpose.input.dataTypes.includes(dataType)) {
      // Without a permutation the expected result is the reversed shape.
      const output = builder.transpose(input);
      assert_equals(output.dataType, dataType);
      assert_array_equals(output.shape, [4, 3, 2, 1]);
    } else {
      assert_throws_js(TypeError, () => builder.transpose(input));
    }
  }
}, `[transpose] Test transpose with all of the data types.`);