gatherND.https.any.js (3407B)
// META: title=validation tests for WebNN API gatherND operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// Label passed in the options of every gatherND() call that is expected to
// fail. `regexp` matches the label in its bracketed form "[gatherND_]" and is
// handed to assert_throws_with_label (presumably provided by
// ../resources/utils_validation.js and matched against the thrown error).
const label = 'gatherND_';
const regexp = new RegExp('\\[' + label + '\\]');

// Table-driven cases. Entries with an `output` describe a graph that should
// build successfully (when the backend supports the data types involved);
// entries without `output` must make gatherND() throw.
const tests = [
  {
    // Output shape = indices.shape[:-1] + input.shape[indices.shape[-1]:]
    //              = [5, 4] + [3, 4].
    name: '[gatherND] Test gatherND with 5D input 3D indices',
    input: {dataType: 'float32', shape: [2, 2, 3, 3, 4]},
    indices: {dataType: 'int32', shape: [5, 4, 3]},
    output: {dataType: 'float32', shape: [5, 4, 3, 4]}
  },
  {
    name: '[gatherND] Throw if input is a scalar',
    input: {dataType: 'float32', shape: []},
    indices: {dataType: 'int32', shape: [1, 1, 1]}
  },
  {
    name: '[gatherND] Throw if indices is a scalar',
    input: {dataType: 'float32', shape: [1, 1, 1]},
    indices: {dataType: 'int32', shape: []}
  },
  {
    name: '[gatherND] Throw if indices data type is float32',
    input: {dataType: 'float32', shape: [1, 2, 3]},
    indices: {dataType: 'float32', shape: [1, 1, 1]},
  },
  {
    name:
        '[gatherND] Throw if indices.shape[-1] is greater than the input rank',
    input: {dataType: 'float32', shape: [1, 2, 3]},
    indices: {dataType: 'int32', shape: [1, 1, 4]}
  }
];

tests.forEach(test => promise_test(async t => {
  const builder = new MLGraphBuilder(context);
  const input = builder.input('input', test.input);
  const indices = builder.input('indices', test.indices);

  // A case with an `output` can only succeed when the backend supports BOTH
  // the input and the indices data types. If either is unsupported,
  // gatherND() must reject, so such cases take the expect-throw branch
  // instead of falsely asserting success.
  const limits = context.opSupportLimits().gatherND;
  if (test.output &&
      limits.input.dataTypes.includes(test.input.dataType) &&
      limits.indices.dataTypes.includes(test.indices.dataType)) {
    const output = builder.gatherND(input, indices);
    assert_equals(output.dataType, test.output.dataType);
    assert_array_equals(output.shape, test.output.shape);
  } else {
    const options = {label};
    assert_throws_with_label(
        () => builder.gatherND(input, indices, options), regexp);
  }
}, test.name));

multi_builder_test(async (t, builder, otherBuilder) => {
  const inputFromOtherBuilder =
      otherBuilder.input('input', {dataType: 'float32', shape: [2, 2]});

  const indices = builder.input('indices', {dataType: 'int32', shape: [2, 1]});
  assert_throws_js(
      TypeError, () => builder.gatherND(inputFromOtherBuilder, indices));
}, '[gatherND] Throw if input is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const indicesFromOtherBuilder =
      otherBuilder.input('indices', {dataType: 'int32', shape: [2, 2]});

  const input = builder.input('input', {dataType: 'float32', shape: [2, 1]});
  assert_throws_js(
      TypeError, () => builder.gatherND(input, indicesFromOtherBuilder));
}, '[gatherND] Throw if indices is from another builder');

promise_test(async t => {
  const builder = new MLGraphBuilder(context);

  const input = builder.input('input', {
    dataType: 'float32', shape: [2, 2, 3, 3, 4]});
  // The int32 indices tensor occupies maxTensorByteLength bytes (4 bytes per
  // element). The output shape is [N, 1] + input.shape[1:] =
  // [N, 1, 2, 3, 3, 4], i.e. 72x as many float32 elements as the indices, so
  // the output's byte length exceeds the limit.
  const indices = builder.input('indices', {
    dataType: 'int32',
    shape: [context.opSupportLimits().maxTensorByteLength / 4, 1, 1]});

  const options = {label};
  assert_throws_with_label(
      () => builder.gatherND(input, indices, options), regexp);
}, '[gatherND] throw if the output tensor byte length exceeds limit');