expand.https.any.js (4440B)
// META: title=validation tests for WebNN API expand operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// An operand created by one builder must be rejected when passed to expand()
// on a different builder.
multi_builder_test(async (t, builder, otherBuilder) => {
  const inputFromOtherBuilder =
      otherBuilder.input('input', {dataType: 'float32', shape: [2, 1, 2]});

  const newShape = [2, 2, 2];
  assert_throws_js(
      TypeError, () => builder.expand(inputFromOtherBuilder, newShape));
}, '[expand] throw if input is from another builder');

// Label attached to operations expected to fail. `regexp` matches the literal
// text "[xxx_expand]" and is handed to assert_throws_with_label() (defined in
// utils_validation.js), presumably to check the error message names the label.
const label = 'xxx_expand';
const regexp = new RegExp(`\\[${label}\\]`);

// Table of cases. A case with `output` must succeed and produce that
// descriptor. A case without `output` must throw: with assert_throws_with_label
// when `options.label` is set, otherwise with assert_throws_js(TypeError).
const tests = [
  {
    name: '[expand] Test with 0-D scalar to 3-D tensor.',
    input: {dataType: 'float32', shape: []},
    newShape: [3, 4, 5],
    output: {dataType: 'float32', shape: [3, 4, 5]}
  },
  {
    name: '[expand] Test with the new shapes that are the same as input.',
    input: {dataType: 'float32', shape: [4]},
    newShape: [4],
    output: {dataType: 'float32', shape: [4]}
  },
  {
    name: '[expand] Test with the new shapes that are broadcastable.',
    input: {dataType: 'float32', shape: [3, 1, 5]},
    newShape: [3, 4, 5],
    output: {dataType: 'float32', shape: [3, 4, 5]}
  },
  {
    name:
        '[expand] Test with the new shapes that are broadcastable and the rank of new shapes is larger than input.',
    input: {dataType: 'float32', shape: [2, 5]},
    newShape: [3, 2, 5],
    output: {dataType: 'float32', shape: [3, 2, 5]}
  },
  {
    name:
        '[expand] Throw if the input shapes are the same rank but not broadcastable.',
    input: {dataType: 'float32', shape: [3, 6, 2]},
    newShape: [4, 3, 5],
    options: {label}
  },
  {
    name: '[expand] Throw if the input shapes are not broadcastable.',
    input: {dataType: 'float32', shape: [5, 4]},
    newShape: [5],
    options: {label}
  },
  {
    // No label here, so the data-driven loop below expects a plain TypeError.
    // kMaxUnsignedLong comes from utils_validation.js (presumably 2**32 - 1);
    // two such dimensions make the output element count overflow.
    name: '[expand] Throw if the number of new shapes is too large.',
    input: {dataType: 'float32', shape: [1, 2, 1, 1]},
    newShape: [1, 2, kMaxUnsignedLong, kMaxUnsignedLong],
  },
];

for (const test of tests) {
  promise_test(async t => {
    const builder = new MLGraphBuilder(context);
    const input = builder.input('input', test.input);

    if (test.output) {
      const output = builder.expand(input, test.newShape);
      assert_equals(output.dataType, test.output.dataType);
      assert_array_equals(output.shape, test.output.shape);
      return;
    }

    // Copy so a missing `options` becomes {} rather than undefined.
    const options = {...test.options};
    if (options.label) {
      assert_throws_with_label(
          () => builder.expand(input, test.newShape, options), regexp);
    } else {
      assert_throws_js(
          TypeError, () => builder.expand(input, test.newShape, options));
    }
  }, test.name);
}

// For every operand data type the context accepts as a graph input, expand()
// must succeed when the op supports that type and throw TypeError otherwise.
promise_test(async t => {
  // Loop-invariant; query once instead of twice per data type.
  const limits = context.opSupportLimits();
  for (const dataType of allWebNNOperandDataTypes) {
    if (!limits.input.dataTypes.includes(dataType)) {
      continue;
    }
    const builder = new MLGraphBuilder(context);
    const shape = [1];
    const newShape = [1, 2, 3];
    const input = builder.input('input', {dataType, shape});
    if (limits.expand.input.dataTypes.includes(dataType)) {
      const output = builder.expand(input, newShape);
      assert_equals(output.dataType, dataType);
      assert_array_equals(output.shape, newShape);
    } else {
      assert_throws_js(TypeError, () => builder.expand(input, newShape));
    }
  }
}, '[expand] Test expand with all of the data types.');

// Output of [1, 2, maxTensorByteLength, 1] float32 elements is
// 8 * maxTensorByteLength bytes, which is over the limit.
// NOTE(review): if maxTensorByteLength can exceed the unsigned-long range used
// for shape dimensions (cf. kMaxUnsignedLong above), the shape may be rejected
// before the byte-length check and without the label — confirm.
promise_test(async t => {
  const builder = new MLGraphBuilder(context);

  const input = builder.input('input', {
    dataType: 'float32', shape: [1, 2, 1, 1]});
  const newShape = [1, 2, context.opSupportLimits().maxTensorByteLength, 1];

  const options = {label};
  assert_throws_with_label(
      () => builder.expand(input, newShape, options), regexp);
}, '[expand] throw if the output tensor byte length exceeds limit');

// newShape has rank (expand.output.rankRange.max + 1): all ones except a
// trailing 2, so it stays broadcast-compatible with the [2] input and only the
// rank limit is violated.
promise_test(async t => {
  const builder = new MLGraphBuilder(context);

  const input = builder.input('input', {dataType: 'float32', shape: [2]});
  const newShape =
      new Array(context.opSupportLimits().expand.output.rankRange.max + 1)
          .fill(1);
  newShape[newShape.length - 1] = 2;

  const options = {label};
  assert_throws_with_label(
      () => builder.expand(input, newShape, options), regexp);
}, '[expand] throw if new shape rank exceeds limit');