reshape.https.any.js (3337B)
// META: title=validation tests for WebNN API reshape operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// NOTE(review): `context`, `multi_builder_test` and `assert_throws_with_label`
// are not defined in this file; presumably they come from
// ../resources/utils_validation.js (loaded via the META script line above).

// Label attached to reshape() calls that are expected to fail, and the pattern
// handed to assert_throws_with_label to look for it as "[reshape_xxx]".
const kReshapeLabel = 'reshape_xxx';
const kReshapeLabelRegExp = new RegExp(`\\[${kReshapeLabel}\\]`);

/**
 * Asserts that `builder.reshape(input, newShape)` throws when called with the
 * shared operator label, delegating the check to assert_throws_with_label with
 * a RegExp matching the bracketed label.
 *
 * @param {MLGraphBuilder} builder - Builder that created `input`.
 * @param {*} input - Operand returned by `builder.input()`.
 * @param {number[]} newShape - Target shape that reshape() must reject.
 */
function assertReshapeThrowsWithLabel(builder, input, newShape) {
  assert_throws_with_label(
      () => builder.reshape(input, newShape, {label: kReshapeLabel}),
      kReshapeLabelRegExp);
}

multi_builder_test(async (t, builder, otherBuilder) => {
  const inputFromOtherBuilder =
      otherBuilder.input('input', {dataType: 'float32', shape: [1, 2, 3]});

  const newShape = [3, 2, 1];
  assert_throws_js(
      TypeError, () => builder.reshape(inputFromOtherBuilder, newShape));
}, '[reshape] throw if input is from another builder');

// Table-driven cases: entries with `output` must reshape successfully and
// produce exactly that dataType/shape; entries without `output` must throw.
const tests = [
  {
    name: '[reshape] Test with new shape=[3, 8].',
    input: {dataType: 'float32', shape: [2, 3, 4]},
    newShape: [3, 8],
    output: {dataType: 'float32', shape: [3, 8]}
  },
  {
    name: '[reshape] Test with new shape=[24], src shape=[2, 3, 4].',
    input: {dataType: 'float32', shape: [2, 3, 4]},
    newShape: [24],
    output: {dataType: 'float32', shape: [24]}
  },
  {
    name: '[reshape] Test with new shape=[1], src shape=[1].',
    input: {dataType: 'float32', shape: [1]},
    newShape: [1],
    output: {dataType: 'float32', shape: [1]}
  },
  {
    name: '[reshape] Test reshaping a 1-D 1-element tensor to scalar.',
    input: {dataType: 'float32', shape: [1]},
    newShape: [],
    output: {dataType: 'float32', shape: []}
  },
  {
    name: '[reshape] Test reshaping a scalar to 1-D 1-element tensor.',
    input: {dataType: 'float32', shape: []},
    newShape: [1],
    output: {dataType: 'float32', shape: [1]}
  },
  {
    name: '[reshape] Throw if one value of new shape is 0.',
    input: {dataType: 'float32', shape: [2, 4]},
    newShape: [2, 4, 0],
  },
  {
    name:
        '[reshape] Throw if the number of elements implied by new shape is not equal to the number of elements in the input tensor when new shape=[].',
    input: {dataType: 'float32', shape: [2, 3, 4]},
    newShape: [],
  },
  {
    name:
        '[reshape] Throw if the number of elements implied by new shape is not equal to the number of elements in the input tensor.',
    input: {dataType: 'float32', shape: [2, 3, 4]},
    newShape: [3, 9],
  },
];

tests.forEach(
    test => promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      const input = builder.input('input', test.input);
      if (test.output) {
        const output = builder.reshape(input, test.newShape);
        assert_equals(output.dataType, test.output.dataType);
        assert_array_equals(output.shape, test.output.shape);
      } else {
        assertReshapeThrowsWithLabel(builder, input, test.newShape);
      }
    }, test.name));

promise_test(async t => {
  const builder = new MLGraphBuilder(context);

  const input = builder.input('input', {dataType: 'float32', shape: [2]});
  // Use reshape's own output rank limit (not another operator's). The shape is
  // one rank past that limit, with a leading 2 and trailing 1s so the element
  // count still matches the input's; only the rank check can reject it.
  const maxRank = context.opSupportLimits().reshape.output.rankRange.max;
  const newShape = new Array(maxRank + 1).fill(1);
  newShape[0] = 2;

  assertReshapeThrowsWithLabel(builder, input, newShape);
}, '[reshape] throw if new shape rank exceeds limit');