scalars.https.any.js (6045B)
// META: title=test that scalar values work as expected
// META: global=window,worker
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils.js
// META: timeout=long

'use strict';

// Shared MLContext, created once in promise_setup() and reused by every test.
let mlContext;

// Skip tests if WebNN is unimplemented.
// NOTE(review): `contextOptions` is not defined in this file; presumably it is
// provided by ../resources/utils.js based on the ?cpu/?gpu/?npu variant.
promise_setup(async () => {
  assert_implements(navigator.ml, 'missing navigator.ml');
  mlContext = await navigator.ml.createContext(contextOptions);
});

// An int32 scalar input (shape []) added to a 4-element int32 constant; the
// scalar value is expected to be added to each element of the constant.
promise_test(async () => {
  const builder = new MLGraphBuilder(mlContext);
  const inputOperand = builder.input('input', {dataType: 'int32', shape: []});
  const constantOperand = builder.constant(
      {dataType: 'int32', shape: [4]}, Int32Array.from([3, 2, 1, 7]));
  const addOperand = builder.add(inputOperand, constantOperand);

  // Tensor creation and graph compilation are independent, so run them
  // concurrently.
  const [inputTensor, outputTensor, mlGraph] = await Promise.all([
    mlContext.createTensor({dataType: 'int32', shape: [], writable: true}),
    mlContext.createTensor({dataType: 'int32', shape: [4], readable: true}),
    builder.build({'output': addOperand})
  ]);

  mlContext.writeTensor(inputTensor, Int32Array.from([4]));
  mlContext.dispatch(mlGraph, {'input': inputTensor}, {'output': outputTensor});
  assert_array_equals(
      new Int32Array(await mlContext.readTensor(outputTensor)),
      Int32Array.from([7, 6, 5, 11]));
}, 'scalar input');

// float32: input, constant and output are all scalars (shape []).
promise_test(async () => {
  const builder = new MLGraphBuilder(mlContext);
  const inputOperand = builder.input('input', {dataType: 'float32', shape: []});
  const constantOperand = builder.constant(
      {dataType: 'float32', shape: []}, Float32Array.from([3]));
  const addOperand = builder.add(inputOperand, constantOperand);

  const [inputTensor, outputTensor, mlGraph] = await Promise.all([
    mlContext.createTensor({dataType: 'float32', shape: [], writable: true}),
    mlContext.createTensor({dataType: 'float32', shape: [], readable: true}),
    builder.build({'output': addOperand})
  ]);

  mlContext.writeTensor(inputTensor, Float32Array.from([4]));

  mlContext.dispatch(mlGraph, {'input': inputTensor}, {'output': outputTensor});

  assert_array_equals(
      new Float32Array(await mlContext.readTensor(outputTensor)),
      Float32Array.from([7]));
}, 'float32 scalar input, constant, and output');

// int32: input, constant and output are all scalars (shape []).
promise_test(async () => {
  const builder = new MLGraphBuilder(mlContext);
  const inputOperand = builder.input('input', {dataType: 'int32', shape: []});
  const constantOperand =
      builder.constant({dataType: 'int32', shape: []}, Int32Array.from([3]));
  const addOperand = builder.add(inputOperand, constantOperand);

  const [inputTensor, outputTensor, mlGraph] = await Promise.all([
    mlContext.createTensor({dataType: 'int32', shape: [], writable: true}),
    mlContext.createTensor({dataType: 'int32', shape: [], readable: true}),
    builder.build({'output': addOperand})
  ]);

  mlContext.writeTensor(inputTensor, Int32Array.from([4]));
  mlContext.dispatch(mlGraph, {'input': inputTensor}, {'output': outputTensor});
  assert_array_equals(
      new Int32Array(await mlContext.readTensor(outputTensor)),
      Int32Array.from([7]));
}, 'int32 scalar input, constant, and output');

// Tests for constant(type, value)

// float32 scalar constant created from a plain number: 2.0 + 3.0 = 5.0.
promise_test(async () => {
  const builder = new MLGraphBuilder(mlContext);
  const inputOperand = builder.input('input', {dataType: 'float32', shape: []});
  const constantOperand = builder.constant('float32', 3.0);
  const addOperand = builder.add(inputOperand, constantOperand);

  const [inputTensor, outputTensor, mlGraph] = await Promise.all([
    mlContext.createTensor({dataType: 'float32', shape: [], writable: true}),
    mlContext.createTensor({dataType: 'float32', shape: [], readable: true}),
    builder.build({'output': addOperand})
  ]);

  mlContext.writeTensor(inputTensor, Float32Array.from([2.0]));
  mlContext.dispatch(mlGraph, {'input': inputTensor}, {'output': outputTensor});

  const result = new Float32Array(await mlContext.readTensor(outputTensor));
  assert_array_equals(result, Float32Array.from([5.0]));
}, 'scalar constant created with constant(type, value) - float32');

// int32 scalar constant created from a plain number: 3 * 42 = 126.
promise_test(async () => {
  const builder = new MLGraphBuilder(mlContext);
  const inputOperand = builder.input('input', {dataType: 'int32', shape: []});
  const constantOperand = builder.constant('int32', 42);
  const mulOperand = builder.mul(inputOperand, constantOperand);

  const [inputTensor, outputTensor, mlGraph] = await Promise.all([
    mlContext.createTensor({dataType: 'int32', shape: [], writable: true}),
    mlContext.createTensor({dataType: 'int32', shape: [], readable: true}),
    builder.build({'output': mulOperand})
  ]);

  mlContext.writeTensor(inputTensor, Int32Array.from([3]));
  mlContext.dispatch(mlGraph, {'input': inputTensor}, {'output': outputTensor});
  assert_array_equals(
      new Int32Array(await mlContext.readTensor(outputTensor)),
      Int32Array.from([126]));
}, 'scalar constant created with constant(type, value) - int32');

// Two float16 scalar constants combined with a 3-element float16 input.
// NOTE(review): this test uses the 'float16' data type and the Float16Array
// global without any capability guard; confirm that every targeted
// backend/runtime provides both, or add a precondition check.
promise_test(async () => {
  const builder = new MLGraphBuilder(mlContext);
  const inputOperand = builder.input('input', {dataType: 'float16', shape: [3]});
  // The multiplier is 2.0, so it is named for its role rather than its value.
  const scaleConstant = builder.constant('float16', 2.0);
  const negativeConstant = builder.constant('float16', -1.0);

  // Test complex expression: input * 2 + (-1.0)
  const mulResult = builder.mul(inputOperand, scaleConstant);
  const addResult = builder.add(mulResult, negativeConstant);

  const [inputTensor, outputTensor, mlGraph] = await Promise.all([
    mlContext.createTensor({dataType: 'float16', shape: [3], writable: true}),
    mlContext.createTensor({dataType: 'float16', shape: [3], readable: true}),
    builder.build({'output': addResult})
  ]);

  mlContext.writeTensor(inputTensor, Float16Array.from([1.0, 2.0, 3.0]));
  mlContext.dispatch(mlGraph, {'input': inputTensor}, {'output': outputTensor});

  const result = new Float16Array(await mlContext.readTensor(outputTensor));
  assert_array_equals(result, Float16Array.from([1.0, 3.0, 5.0]));
}, 'multiple scalar constants in expression - float16');