scatterElements.https.any.js (5486B)
// META: title=validation tests for WebNN API scatterElements operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// Table-driven cases. Each entry gives the operand descriptors passed to
// builder.input() for `input`, `indices` and `updates`, plus an optional
// `axis` forwarded in the scatterElements() options.
//  - Entries WITH an `output` descriptor expect scatterElements() to succeed
//    and to return an operand with exactly that dataType and shape.
//  - Entries WITHOUT `output` expect scatterElements() to throw; the driver
//    below attaches a label and checks that it appears in the error.
const tests = [
  {
    name: '[scatterElements] Test scatterElements with default options',
    input: {dataType: 'float32', shape: [3, 3]},
    indices: {dataType: 'int32', shape: [2, 3]},
    updates: {dataType: 'float32', shape: [2, 3]},
    output: {dataType: 'float32', shape: [3, 3]}
  },
  {
    name: '[scatterElements] Test scatterElements with axis = 0',
    input: {dataType: 'float32', shape: [3, 3]},
    indices: {dataType: 'int32', shape: [2, 3]},
    updates: {dataType: 'float32', shape: [2, 3]},
    axis: 0,
    output: {dataType: 'float32', shape: [3, 3]}
  },
  {
    name: '[scatterElements] Test scatterElements with axis = 1',
    input: {dataType: 'float32', shape: [3, 3]},
    indices: {dataType: 'int32', shape: [3, 2]},
    updates: {dataType: 'float32', shape: [3, 2]},
    axis: 1,
    output: {dataType: 'float32', shape: [3, 3]}
  },
  {
    // NOTE: axis 2 is equal to (not greater than) the rank of the 2-D input;
    // this case expects an axis equal to the rank to be rejected.
    name: '[scatterElements] Throw if axis is greater than input rank',
    input: {dataType: 'float32', shape: [3, 3]},
    indices: {dataType: 'int32', shape: [2, 3]},
    updates: {dataType: 'float32', shape: [2, 3]},
    axis: 2
  },
  {
    name:
        '[scatterElements] Throw if updates tensor data type is not the same as input data type',
    input: {dataType: 'float32', shape: [3, 3]},
    indices: {dataType: 'int32', shape: [2, 3]},
    updates: {dataType: 'float16', shape: [2, 3]},
  },
  {
    name: '[scatterElements] Throw if input, indices and updates are scalar',
    input: {dataType: 'float32', shape: []},
    indices: {dataType: 'int32', shape: []},
    updates: {dataType: 'float32', shape: []},
  },
  {
    name:
        '[scatterElements] Throw if indices rank is not the same as input rank',
    input: {dataType: 'float32', shape: [3, 3]},
    indices: {dataType: 'int32', shape: [2, 3, 3]},
    updates: {dataType: 'float32', shape: [2, 3, 3]},
  },
  {
    name:
        '[scatterElements] Throw if indices size is not the same as input size along axis 1',
    input: {dataType: 'float32', shape: [3, 3]},
    indices: {dataType: 'int32', shape: [2, 4]},
    updates: {dataType: 'float32', shape: [2, 4]},
    axis: 0
  },
  {
    name:
        '[scatterElements] Throw if indices size is not the same as input size along axis 0',
    input: {dataType: 'float32', shape: [3, 3]},
    indices: {dataType: 'int32', shape: [2, 2]},
    updates: {dataType: 'float32', shape: [2, 2]},
    axis: 1
  },
  {
    name:
        '[scatterElements] Throw if indices rank is not the same as updates rank',
    input: {dataType: 'float32', shape: [3, 3]},
    indices: {dataType: 'int32', shape: [2, 3]},
    updates: {dataType: 'float32', shape: [2, 3, 3]},
  },
  {
    name:
        '[scatterElements] Throw if indices shape is not the same as updates shape',
    input: {dataType: 'float32', shape: [3, 3]},
    indices: {dataType: 'int32', shape: [2, 3]},
    updates: {dataType: 'float32', shape: [2, 4]},
  }
];

// Register one promise_test per table entry.
tests.forEach(
    test => promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      const input = builder.input('input', test.input);
      const indices = builder.input('indices', test.indices);
      const updates = builder.input('updates', test.updates);

      const options = {};
      // Compare against undefined instead of testing truthiness: `axis: 0`
      // is falsy, and a plain `if (test.axis)` would silently drop it, so the
      // axis-0 cases would only exercise the default rather than an explicit
      // axis.
      if (test.axis !== undefined) {
        options.axis = test.axis;
      }

      if (test.output) {
        const output =
            builder.scatterElements(input, indices, updates, options);
        assert_equals(output.dataType, test.output.dataType);
        assert_array_equals(output.shape, test.output.shape);
      } else {
        const label = 'a_scatter_elements';
        options.label = label;
        // Pattern matching the literal text "[a_scatter_elements]"; handed to
        // assert_throws_with_label (from utils_validation.js).
        const regexp = new RegExp(`\\[${label}\\]`);
        assert_throws_with_label(
            () => builder.scatterElements(input, indices, updates, options),
            regexp);
      }
    }, test.name));

// The remaining cases pass one operand created by a different MLGraphBuilder
// and expect scatterElements() to throw a TypeError.

multi_builder_test(async (t, builder, otherBuilder) => {
  const input =
      otherBuilder.input('input', {dataType: 'float32', shape: [3, 3]});
  const indices = builder.input('indices', {dataType: 'int32', shape: [2, 3]});
  const updates =
      builder.input('updates', {dataType: 'float32', shape: [2, 3]});

  assert_throws_js(
      TypeError, () => builder.scatterElements(input, indices, updates));
}, '[scatterElements] Throw if input is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const input = builder.input('input', {dataType: 'float32', shape: [3, 3]});
  const indices =
      otherBuilder.input('indices', {dataType: 'int32', shape: [2, 3]});
  const updates =
      builder.input('updates', {dataType: 'float32', shape: [2, 3]});

  assert_throws_js(
      TypeError, () => builder.scatterElements(input, indices, updates));
}, '[scatterElements] Throw if indices is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const input = builder.input('input', {dataType: 'float32', shape: [3, 3]});
  const indices = builder.input('indices', {dataType: 'int32', shape: [2, 3]});
  const updates =
      otherBuilder.input('updates', {dataType: 'float32', shape: [2, 3]});

  assert_throws_js(
      TypeError, () => builder.scatterElements(input, indices, updates));
}, '[scatterElements] Throw if updates is from another builder');