gatherElements.https.any.js (3304B)
// META: title=validation tests for WebNN API gatherElements operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// Each case describes one builder.gatherElements(input, indices, options)
// call. Cases that carry an `output` descriptor must build successfully and
// yield an operand with exactly that dataType and shape; cases without one
// must throw, and the thrown error's message must contain the operator label.
const tests = [
  {
    name: '[gatherElements] Test gatherElements with default options',
    input: {dataType: 'float32', shape: [1, 2, 3]},
    indices: {dataType: 'int32', shape: [2, 2, 3]},
    output: {dataType: 'float32', shape: [2, 2, 3]}
  },
  {
    name: '[gatherElements] Test gatherElements with axis = 2',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    indices: {dataType: 'int32', shape: [1, 2, 1, 4]},
    axis: 2,
    output: {dataType: 'float32', shape: [1, 2, 1, 4]}
  },
  {
    name: '[gatherElements] Throw if input is a scalar',
    input: {dataType: 'float32', shape: []},
    indices: {dataType: 'int32', shape: []}
  },
  {
    name:
        '[gatherElements] Throw if the axis is greater than the rank of input',
    input: {dataType: 'float32', shape: [1, 2, 3]},
    indices: {dataType: 'int32', shape: [1, 2, 3]},
    axis: 4
  },
  {
    name: '[gatherElements] Throw if indices data type is float32',
    input: {dataType: 'float32', shape: [1, 2, 3]},
    indices: {dataType: 'float32', shape: [1, 2, 3]}
  },
  {
    name: '[gatherElements] Throw if input rank is not equal to indices rank',
    input: {dataType: 'float32', shape: [1, 2, 3]},
    indices: {dataType: 'int32', shape: [1, 2]}
  },
  {
    name: '[gatherElements] Throw if indices shape is incorrect',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    indices: {dataType: 'int32', shape: [3, 2, 3, 4]},
    axis: 3
  }
];

tests.forEach(
    test => promise_test(async t => {
      // `context` and the assert_* helpers are globals; presumably set up by
      // ../resources/utils_validation.js (see META above) — confirm there.
      const builder = new MLGraphBuilder(context);
      const input = builder.input('input', test.input);
      const indices = builder.input('indices', test.indices);

      // Forward `axis` whenever the case specifies one. A truthiness check
      // would silently drop an explicit `axis: 0`.
      const options = {};
      if (test.axis !== undefined) {
        options.axis = test.axis;
      }

      if (test.output) {
        const output = builder.gatherElements(input, indices, options);
        assert_equals(output.dataType, test.output.dataType);
        assert_array_equals(output.shape, test.output.shape);
      } else {
        // The label is attached so the thrown error message can be checked
        // for it, bracketed as `[gatherElements_]`.
        const label = 'gatherElements_';
        options.label = label;
        const regexp = new RegExp(`\\[${label}\\]`);
        assert_throws_with_label(
            () => builder.gatherElements(input, indices, options), regexp);
      }
    }, test.name));

// Operands created by a different MLGraphBuilder must be rejected with a
// TypeError, whichever argument position they appear in.
multi_builder_test(async (t, builder, otherBuilder) => {
  const inputFromOtherBuilder =
      otherBuilder.input('input', {dataType: 'float32', shape: [2, 2]});

  const indices = builder.input('indices', {dataType: 'int32', shape: [2, 2]});
  assert_throws_js(
      TypeError, () => builder.gatherElements(inputFromOtherBuilder, indices));
}, '[gatherElements] Throw if input is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const indicesFromOtherBuilder =
      otherBuilder.input('indices', {dataType: 'int32', shape: [2, 2]});

  const input = builder.input('input', {dataType: 'float32', shape: [2, 2]});
  assert_throws_js(
      TypeError, () => builder.gatherElements(input, indicesFromOtherBuilder));
}, '[gatherElements] Throw if indices is from another builder');