concat.https.any.js (4153B)
// META: title=validation tests for WebNN API concat operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// Label passed via MLOperatorOptions to every concat() call that is expected
// to fail; the thrown error's message must mention it as `[label]`.
const label = `concate_123`;

// Table of concat() cases. Entries with an `output` field must build
// successfully and yield that dataType/shape; entries without `output` must
// throw an error carrying `label`. An entry without `inputs` exercises the
// empty-sequence case.
const tests = [
  {
    name: '[concat] Test building Concat with one input.',
    inputs: [{dataType: 'float32', shape: [4, 4, 3]}],
    axis: 2,
    output: {dataType: 'float32', shape: [4, 4, 3]}
  },
  {
    name: '[concat] Test building Concat with two inputs',
    inputs: [
      {dataType: 'float32', shape: [3, 1, 5]},
      {dataType: 'float32', shape: [3, 2, 5]}
    ],
    axis: 1,
    output: {dataType: 'float32', shape: [3, 3, 5]}
  },
  {
    name: '[concat] Test building Concat with three inputs',
    inputs: [
      {dataType: 'float32', shape: [3, 5, 1]},
      {dataType: 'float32', shape: [3, 5, 2]},
      {dataType: 'float32', shape: [3, 5, 3]}
    ],
    axis: 2,
    output: {dataType: 'float32', shape: [3, 5, 6]}
  },
  {
    name: '[concat] Test building Concat with two 1D inputs.',
    inputs:
        [{dataType: 'float32', shape: [1]}, {dataType: 'float32', shape: [1]}],
    axis: 0,
    output: {dataType: 'float32', shape: [2]}
  },
  {
    name: '[concat] Throw if the inputs are empty.',
    axis: 0,
  },
  {
    name: '[concat] Throw if the argument types are inconsistent.',
    inputs: [
      {dataType: 'float32', shape: [1, 1]}, {dataType: 'int32', shape: [1, 1]}
    ],
    axis: 0,
  },
  {
    name: '[concat] Throw if the inputs have different ranks.',
    inputs: [
      {dataType: 'float32', shape: [1, 1]},
      {dataType: 'float32', shape: [1, 1, 1]}
    ],
    axis: 0,
  },
  {
    name:
        '[concat] Throw if the axis is equal to or greater than the size of ranks',
    inputs: [
      {dataType: 'float32', shape: [1, 1]}, {dataType: 'float32', shape: [1, 1]}
    ],
    axis: 2,
  },
  {
    name: '[concat] Throw if concat with two 0-D scalars.',
    inputs:
        [{dataType: 'float32', shape: []}, {dataType: 'float32', shape: []}],
    axis: 0,
  },
  {
    name:
        '[concat] Throw if the inputs have other axes with different sizes except on the axis.',
    inputs: [
      {dataType: 'float32', shape: [1, 1, 1]},
      {dataType: 'float32', shape: [1, 2, 3]}
    ],
    axis: 1,
  },
];

// Register one promise_test per table entry.
tests.forEach(
    test => promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      // Entries without `inputs` deliberately concat an empty sequence.
      const inputs = (test.inputs ?? [])
                         .map((descriptor, i) =>
                                  builder.input(`inputs[${i}]`, descriptor));
      if (test.output) {
        const output = builder.concat(inputs, test.axis);
        assert_equals(output.dataType, test.output.dataType);
        assert_array_equals(output.shape, test.output.shape);
      } else {
        const options = {label};
        // The label is matched literally inside square brackets.
        const regexp = new RegExp(`\\[${label}\\]`);
        assert_throws_with_label(
            () => builder.concat(inputs, test.axis, options), regexp);
      }
    }, test.name));

// Mixing in an operand created by a different builder must throw. The axis is
// passed explicitly (it is a required argument), so the TypeError can only
// come from the cross-builder check, not from a missing argument.
multi_builder_test(async (t, builder, otherBuilder) => {
  const operandDescriptor = {dataType: 'float32', shape: [2, 2]};

  const inputFromOtherBuilder = otherBuilder.input('input', operandDescriptor);

  const input1 = builder.input('input', operandDescriptor);
  const input2 = builder.input('input', operandDescriptor);
  const input3 = builder.input('input', operandDescriptor);

  assert_throws_js(
      TypeError,
      () => builder.concat([input1, input2, inputFromOtherBuilder, input3], 0));
}, '[concat] throw if any input is from another builder');

// Each input alone is as large as the context allows (float32 = 4 bytes per
// element), so concatenating three of them along axis 0 must exceed
// maxTensorByteLength. The inputs must be passed as a sequence together with
// an axis; otherwise argument conversion throws a TypeError before the
// byte-length limit is ever checked.
promise_test(async t => {
  const builder = new MLGraphBuilder(context);

  const operandDescriptor = {
    dataType: 'float32',
    shape: [context.opSupportLimits().maxTensorByteLength / 4]
  };
  const input1 = builder.input('input1', operandDescriptor);
  const input2 = builder.input('input2', operandDescriptor);
  const input3 = builder.input('input3', operandDescriptor);

  assert_throws_js(
      TypeError, () => builder.concat([input1, input2, input3], 0));
}, '[concat] throw if the output tensor byte length exceeds limit');