scatterND.https.any.js (4045B)
// META: title=validation tests for WebNN API scatterND operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// Table-driven scatterND validation cases. Every case supplies descriptors for
// the three graph inputs (input, indices, updates). A case with an `output`
// descriptor must build and yield exactly that data type and shape; a case
// without one must make builder.scatterND() throw (see the forEach below).
const tests = [
  {
    name: '[scatterND] Test scatterND with valid tensors',
    input: {dataType: 'float32', shape: [4, 4, 4]},
    indices: {dataType: 'int32', shape: [2, 1]},
    updates: {dataType: 'float32', shape: [2, 4, 4]},
    output: {dataType: 'float32', shape: [4, 4, 4]}
  },
  {
    name:
        '[scatterND] Throw if updates tensor data type is not the same as input data type',
    input: {dataType: 'float32', shape: [4, 4, 4]},
    indices: {dataType: 'int32', shape: [2, 1]},
    updates: {dataType: 'float16', shape: [2, 4, 4]},
  },
  {
    name: '[scatterND] Throw if input is a scalar',
    input: {dataType: 'float32', shape: []},
    indices: {dataType: 'int32', shape: [2, 1]},
    updates: {dataType: 'float32', shape: [2, 4, 4]},
  },
  {
    name: '[scatterND] Throw if indices is a scalar',
    input: {dataType: 'float32', shape: [4, 4, 4]},
    indices: {dataType: 'int32', shape: []},
    updates: {dataType: 'float32', shape: [2, 4, 4]},
  },
  {
    name:
        '[scatterND] Throw if the size of last dimension of indices tensor is greater than input rank',
    input: {dataType: 'float32', shape: [4, 4, 4]},
    indices: {dataType: 'int32', shape: [2, 4]},
    updates: {dataType: 'float32', shape: [2, 4, 4]},
  },
  {
    name: '[scatterND] Throw if updates tensor shape is invalid.',
    input: {dataType: 'float32', shape: [4, 4, 4]},
    indices: {dataType: 'int32', shape: [2, 1]},
    // Updates tensor shape should be [2, 4, 4].
    updates: {dataType: 'float32', shape: [2, 3, 4]},
  }
];

tests.forEach(test => promise_test(async t => {
  const builder = new MLGraphBuilder(context);
  const input = builder.input('input', test.input);
  const indices = builder.input('indices', test.indices);
  const updates = builder.input('updates', test.updates);

  if (test.output) {
    const output = builder.scatterND(input, indices, updates);
    assert_equals(output.dataType, test.output.dataType);
    assert_array_equals(output.shape, test.output.shape);
  } else {
    // The failure is checked against a regexp matching the operator label in
    // square brackets, e.g. "[a_scatter_nd]". assert_throws_with_label is
    // presumably provided by ../resources/utils_validation.js (see META).
    const label = 'a_scatter_nd';
    const options = {label};
    const regexp = new RegExp(`\\[${label}\\]`);
    assert_throws_with_label(
        () => builder.scatterND(input, indices, updates, options), regexp);
  }
}, test.name));

// Cross-builder cases. In each test exactly one operand is created on
// `otherBuilder`. The other operands come from `builder`, use distinct input
// names and otherwise-valid descriptors: float32 input/updates (matching data
// types), int32 indices of shape [4, 1], and updates of shape [4]. That leaves
// the cross-builder check as the only reason scatterND() may throw.
multi_builder_test(async (t, builder, otherBuilder) => {
  const input = otherBuilder.input('input', {dataType: 'float32', shape: [8]});
  const indices = builder.input('indices', {dataType: 'int32', shape: [4, 1]});
  const updates = builder.input('updates', {dataType: 'float32', shape: [4]});

  assert_throws_js(TypeError, () => builder.scatterND(input, indices, updates));
}, '[scatterND] Throw if input is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const input = builder.input('input', {dataType: 'float32', shape: [8]});
  const indices =
      otherBuilder.input('indices', {dataType: 'int32', shape: [4, 1]});
  const updates = builder.input('updates', {dataType: 'float32', shape: [4]});

  assert_throws_js(TypeError, () => builder.scatterND(input, indices, updates));
}, '[scatterND] Throw if indices is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const input = builder.input('input', {dataType: 'float32', shape: [8]});
  const indices = builder.input('indices', {dataType: 'int32', shape: [4, 1]});
  const updates =
      otherBuilder.input('updates', {dataType: 'float32', shape: [4]});

  assert_throws_js(TypeError, () => builder.scatterND(input, indices, updates));
}, '[scatterND] Throw if updates is from another builder');