split.https.any.js (3348B)
// META: title=validation tests for WebNN API split operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// split() must reject an operand that was created by a different builder.
multi_builder_test(async (t, builder, otherBuilder) => {
  const inputFromOtherBuilder =
      otherBuilder.input('input', {dataType: 'float32', shape: [4, 4]});

  // 2 evenly divides the size-4 axis, so the only invalid thing here is the
  // foreign input.
  const splits = 2;
  assert_throws_js(
      TypeError, () => builder.split(inputFromOtherBuilder, splits));
}, '[split] throw if input is from another builder');

const label = 'xxx-split';

// Table of split() cases. Each entry has:
//   name    - the registered test name.
//   input   - descriptor passed to builder.input().
//   splits  - either a sequence of sizes or a single number (scalar form).
//   options - optional options dictionary passed to split().
//   outputs - when present, the call must succeed and produce exactly these
//             dataType/shape descriptors, in order; when absent, the call is
//             expected to throw an error carrying `label`.
const tests = [
  {
    name: '[split] Test with default options.',
    input: {dataType: 'float32', shape: [2, 6]},
    splits: [2],
    outputs: [
      {dataType: 'float32', shape: [2, 6]},
    ]
  },
  {
    name:
        '[split] Test with a sequence of unsigned long splits and with options.axis = 1.',
    input: {dataType: 'float32', shape: [2, 6]},
    splits: [1, 2, 3],
    options: {axis: 1},
    outputs: [
      {dataType: 'float32', shape: [2, 1]},
      {dataType: 'float32', shape: [2, 2]},
      {dataType: 'float32', shape: [2, 3]},
    ]
  },
  {
    name: '[split] Throw if splitting a scalar.',
    input: {dataType: 'float32', shape: []},
    splits: [1],
    options: {label}
  },
  {
    // axis 2 equals the rank of this 2-D input, which is already out of range.
    name: '[split] Throw if axis is larger than input rank.',
    input: {dataType: 'float32', shape: [2, 6]},
    splits: [2],
    options: {axis: 2, label}
  },
  {
    name: '[split] Throw if splits is equal to 0.',
    input: {dataType: 'float32', shape: [2, 6]},
    splits: [0],
    options: {axis: 0, label}
  },
  {
    name: '[split] Throw if splits (scalar) is equal to 0.',
    input: {dataType: 'float32', shape: [2, 6]},
    splits: 0,
    options: {axis: 0, label}
  },
  {
    name:
        '[split] Throw if the splits can not evenly divide the dimension size of input along options.axis.',
    input: {dataType: 'float32', shape: [2, 5]},
    splits: [2],
    options: {axis: 1, label}
  },
  {
    name:
        '[split] Throw if splits (scalar) can not evenly divide the dimension size of input along options.axis.',
    input: {dataType: 'float32', shape: [2, 5]},
    splits: 2,
    options: {axis: 1, label}
  },
  {
    name:
        '[split] Throw if the sum of splits sizes not equal to the dimension size of input along options.axis.',
    input: {dataType: 'float32', shape: [2, 6]},
    splits: [2, 2, 3],
    options: {axis: 1, label}
  },
];

// Matches the bracketed label, e.g. "[xxx-split]". It is the same for every
// failing case, so build it once. `label` is a fixed string with no regexp
// metacharacters other than '-', which is literal outside a character class.
const labelRegExp = new RegExp(`\\[${label}\\]`);

// `testCase` (not `test`) avoids shadowing testharness.js's global test().
tests.forEach(
    testCase => promise_test(async () => {
      const builder = new MLGraphBuilder(context);
      const input = builder.input('input', testCase.input);
      if (testCase.outputs) {
        const outputs =
            builder.split(input, testCase.splits, testCase.options);
        assert_equals(outputs.length, testCase.outputs.length);
        outputs.forEach((output, i) => {
          assert_equals(output.dataType, testCase.outputs[i].dataType);
          assert_array_equals(output.shape, testCase.outputs[i].shape);
        });
      } else {
        // assert_throws_with_label comes from utils_validation.js; it is
        // given the label regexp to check against the thrown error.
        assert_throws_with_label(
            () => builder.split(input, testCase.splits, testCase.options),
            labelRegExp);
      }
    }, testCase.name));