gather.https.any.js (4116B)
// META: title=validation tests for WebNN API gather operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// Label attached to the gather options in the failure cases; the thrown
// error's message is expected to mention it as "[gather_]".
const label = 'gather_';
const regrexp = new RegExp(`\\[${label}\\]`);

/**
 * Table of data-driven validation cases.
 *
 * Each entry has:
 *   - name:    test title passed to promise_test.
 *   - input:   operand descriptor ({dataType, shape}) for the gathered tensor.
 *   - indices: operand descriptor for the indices tensor.
 *   - axis:    (optional) axis forwarded to gather() via its options.
 *   - output:  (optional) expected descriptor of the result. When it is
 *              absent the gather() call is expected to throw instead.
 */
const tests = [
  {
    name: '[gather] Test gather with default options and 0-D indices',
    input: {dataType: 'int32', shape: [3]},
    indices: {dataType: 'int64', shape: []},
    output: {dataType: 'int32', shape: []}
  },
  {
    name: '[gather] Test gather with axis = 2',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    indices: {dataType: 'int64', shape: [5, 6]},
    axis: 2,
    output: {dataType: 'float32', shape: [1, 2, 5, 6, 4]}
  },
  {
    name: '[gather] Test gather with indices\'s dataType = uint32',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    indices: {dataType: 'uint32', shape: [5, 6]},
    axis: 2,
    output: {dataType: 'float32', shape: [1, 2, 5, 6, 4]}
  },
  {
    name: '[gather] Test gather with indices\'s dataType = int32',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    indices: {dataType: 'int32', shape: [5, 6]},
    axis: 2,
    output: {dataType: 'float32', shape: [1, 2, 5, 6, 4]}
  },
  {
    name: '[gather] TypeError is expected if the input is a scalar',
    input: {dataType: 'float16', shape: []},
    indices: {dataType: 'int64', shape: [1]},
  },
  {
    name:
        '[gather] TypeError is expected if the axis is greater than the rank of input',
    input: {dataType: 'float16', shape: [1, 2, 3]},
    indices: {dataType: 'int32', shape: [5, 6]},
    axis: 4,
  },
  {
    name:
        '[gather] TypeError is expected if the data type of indices is float32 which is invalid',
    input: {dataType: 'float16', shape: [1, 2, 3, 4]},
    indices: {dataType: 'float32', shape: [5, 6]},
  },
  {
    name:
        '[gather] TypeError is expected if the data type of indices is uint64 which is invalid',
    input: {dataType: 'float16', shape: [1, 2, 3, 4]},
    indices: {dataType: 'uint64', shape: [5, 6]},
  },
];

// Register one promise_test per table entry.
// NOTE(review): `context` and assert_throws_with_label are not defined in this
// file; presumably they come from ../resources/utils_validation.js - confirm.
tests.forEach(
    test => promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      const input = builder.input('input', test.input);
      const indices = builder.input('indices', test.indices);

      const options = {};
      // Compare against undefined rather than relying on truthiness so that an
      // explicit `axis: 0` entry would still be forwarded to gather().
      if (test.axis !== undefined) {
        options.axis = test.axis;
      }

      if (test.output) {
        // Success case: the result must carry the expected type and shape.
        const output = builder.gather(input, indices, options);
        assert_equals(output.dataType, test.output.dataType);
        assert_array_equals(output.shape, test.output.shape);
      } else {
        // Failure case: label the operation so the thrown error can be
        // matched against `regrexp`.
        options.label = label;
        assert_throws_with_label(
            () => builder.gather(input, indices, options), regrexp);
      }
    }, test.name));

// An input operand created by a different builder must be rejected.
multi_builder_test(async (t, builder, otherBuilder) => {
  const inputFromOtherBuilder =
      otherBuilder.input('input', {dataType: 'float32', shape: [2, 2]});

  const indices = builder.input('indices', {dataType: 'int64', shape: [2, 2]});
  assert_throws_js(
      TypeError, () => builder.gather(inputFromOtherBuilder, indices));
}, '[gather] throw if input is from another builder');

// An indices operand created by a different builder must be rejected.
multi_builder_test(async (t, builder, otherBuilder) => {
  const indicesFromOtherBuilder =
      otherBuilder.input('indices', {dataType: 'int64', shape: [2, 2]});

  const input = builder.input('input', {dataType: 'float32', shape: [2, 2]});
  assert_throws_js(
      TypeError, () => builder.gather(input, indicesFromOtherBuilder));
}, '[gather] throw if indices is from another builder');

promise_test(async t => {
  const builder = new MLGraphBuilder(context);

  const input = builder.input('input', {
      dataType: 'float32', shape: [1, 3, 3, 4]});
  // The indices operand is sized from the context's reported
  // maxTensorByteLength (divided by 4, the byte width of an int32 element).
  // Gathering along axis 2 of a [1, 3, 3, 4] input then multiplies that
  // element count by the remaining input dimensions, so the result is
  // expected to exceed the limit.
  const indices = builder.input('indices', {
      dataType: 'int32',
      shape: [context.opSupportLimits().maxTensorByteLength / 4] });

  const options = {};
  options.label = label;
  options.axis = 2;
  assert_throws_with_label(
      () => builder.gather(input, indices, options), regrexp);
}, '[gather] throw if the output tensor byte length exceeds limit');