conv2d.https.any.js (17520B)
// META: title=validation tests for WebNN API conv2d operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// Example input in NCHW layout.
const kExampleInputDescriptor = {
  dataType: 'float32',
  shape: [1, 1, 5, 5]
};
// Example filter in OIHW layout.
const kExampleFilterDescriptor = {
  dataType: 'float32',
  shape: [1, 1, 3, 3]
};
const kExampleBiasDescriptor = {
  dataType: 'float32',
  shape: [/* output channels */ 1]
};
// Label attached to the failing cases; the table runner at the bottom of this
// file checks that the thrown error mentions it in square brackets.
const label = `conv_2d_*`;

// The next three tests pass an operand created by a different builder as the
// input, the filter and the bias respectively, and expect a TypeError.
multi_builder_test(async (t, builder, otherBuilder) => {
  const inputFromOtherBuilder =
      otherBuilder.input('input', kExampleInputDescriptor);

  const filter = builder.input('filter', kExampleFilterDescriptor);
  assert_throws_js(
      TypeError, () => builder.conv2d(inputFromOtherBuilder, filter));
}, '[conv2d] throw if input is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const filterFromOtherBuilder =
      otherBuilder.input('filter', kExampleFilterDescriptor);

  const input = builder.input('input', kExampleInputDescriptor);
  assert_throws_js(
      TypeError, () => builder.conv2d(input, filterFromOtherBuilder));
}, '[conv2d] throw if filter is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
  const biasFromOtherBuilder =
      otherBuilder.input('bias', kExampleBiasDescriptor);
  const options = {inputLayout: 'nchw', bias: biasFromOtherBuilder};

  const input = builder.input('input', kExampleInputDescriptor);
  const filter = builder.input('filter', kExampleFilterDescriptor);
  assert_throws_js(TypeError, () => builder.conv2d(input, filter, options));
}, '[conv2d] throw if bias option is from another builder');

// Table of conv2d cases consumed by the runner at the bottom of this file.
// Each entry has:
//   name    - test name.
//   input   - descriptor of the input operand.
//   filter  - descriptor of the filter operand.
//   options - (optional) conv2d options; `bias`, when present, is a
//             descriptor that the runner turns into an operand.
//   output  - (optional) expected output data type and shape. Entries without
//             `output` are expected to throw an error mentioning `label`.
const tests = [
  {
    name: '[conv2d] Test with default options.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 3, 3]},
    output: {dataType: 'float32', shape: [1, 1, 3, 3]}
  },
  {
    name: '[conv2d] Test with padding.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 3, 3]},
    options: {
      padding: [1, 1, 1, 1],
    },
    output: {dataType: 'float32', shape: [1, 1, 5, 5]}
  },
  {
    name: '[conv2d] Test with strides and padding.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 3, 3]},
    options: {
      padding: [1, 1, 1, 1],
      strides: [2, 2],
    },
    output: {dataType: 'float32', shape: [1, 1, 3, 3]}
  },
  {
    name: '[conv2d] Test with strides and asymmetric padding.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 4, 2]},
    options: {
      padding: [1, 2, 0, 1],
      strides: [2, 2],
    },
    output: {dataType: 'float32', shape: [1, 1, 3, 3]}
  },
  {
    name: '[conv2d] Test depthwise conv2d by setting groups to input channels.',
    input: {dataType: 'float32', shape: [1, 4, 2, 2]},
    filter: {dataType: 'float32', shape: [4, 1, 2, 2]},
    options: {
      groups: 4,
    },
    output: {dataType: 'float32', shape: [1, 4, 1, 1]}
  },
  {
    name:
        '[conv2d] Test depthwise conv2d with groups, inputLayout="nhwc" and filterLayout="ihwo".',
    input: {dataType: 'float32', shape: [1, 2, 2, 4]},
    filter: {dataType: 'float32', shape: [1, 2, 2, 4]},
    options: {
      groups: 4,
      inputLayout: 'nhwc',
      filterLayout: 'ihwo',
    },
    output: {dataType: 'float32', shape: [1, 1, 1, 4]}
  },
  {
    name:
        '[conv2d] Test with dilations, inputLayout="nhwc" and filterLayout="ihwo".',
    input: {dataType: 'float32', shape: [1, 65, 65, 1]},
    filter: {dataType: 'float32', shape: [1, 3, 3, 1]},
    options: {
      inputLayout: 'nhwc',
      filterLayout: 'ihwo',
      dilations: [4, 4],
    },
    output: {dataType: 'float32', shape: [1, 57, 57, 1]}
  },
  {
    name: '[conv2d] Test with inputLayout="nchw" and filterLayout="oihw".',
    input: {dataType: 'float32', shape: [1, 2, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 2, 3, 3]},
    options: {
      inputLayout: 'nchw',
      filterLayout: 'oihw',
    },
    output: {dataType: 'float32', shape: [1, 1, 3, 3]}
  },
  {
    name: '[conv2d] Test with inputLayout="nchw" and filterLayout="hwio".',
    input: {dataType: 'float32', shape: [1, 2, 5, 5]},
    filter: {dataType: 'float32', shape: [3, 3, 2, 1]},
    options: {
      inputLayout: 'nchw',
      filterLayout: 'hwio',
    },
    output: {dataType: 'float32', shape: [1, 1, 3, 3]}
  },
  {
    name: '[conv2d] Test with inputLayout="nchw" and filterLayout="ohwi".',
    input: {dataType: 'float32', shape: [1, 2, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 3, 3, 2]},
    options: {
      inputLayout: 'nchw',
      filterLayout: 'ohwi',
    },
    output: {dataType: 'float32', shape: [1, 1, 3, 3]}
  },
  {
    name: '[conv2d] Test with inputLayout="nchw" and filterLayout="ihwo".',
    input: {dataType: 'float32', shape: [1, 2, 5, 5]},
    filter: {dataType: 'float32', shape: [2, 3, 3, 1]},
    options: {
      inputLayout: 'nchw',
      filterLayout: 'ihwo',
    },
    output: {dataType: 'float32', shape: [1, 1, 3, 3]}
  },
  {
    name: '[conv2d] Test with inputLayout="nhwc" and filterLayout="oihw".',
    input: {dataType: 'float32', shape: [1, 5, 5, 2]},
    filter: {dataType: 'float32', shape: [1, 2, 3, 3]},
    options: {
      inputLayout: 'nhwc',
      filterLayout: 'oihw',
    },
    output: {dataType: 'float32', shape: [1, 3, 3, 1]}
  },
  {
    name: '[conv2d] Test with inputLayout="nhwc" and filterLayout="hwio".',
    input: {dataType: 'float32', shape: [1, 5, 5, 2]},
    filter: {dataType: 'float32', shape: [3, 3, 2, 1]},
    options: {
      inputLayout: 'nhwc',
      filterLayout: 'hwio',
    },
    output: {dataType: 'float32', shape: [1, 3, 3, 1]}
  },
  {
    name: '[conv2d] Test with inputLayout="nhwc" and filterLayout="ohwi".',
    input: {dataType: 'float32', shape: [1, 5, 5, 2]},
    filter: {dataType: 'float32', shape: [1, 3, 3, 2]},
    options: {
      inputLayout: 'nhwc',
      filterLayout: 'ohwi',
    },
    output: {dataType: 'float32', shape: [1, 3, 3, 1]}
  },
  {
    name: '[conv2d] Test with inputLayout="nhwc" and filterLayout="ihwo".',
    input: {dataType: 'float32', shape: [1, 5, 5, 2]},
    filter: {dataType: 'float32', shape: [2, 3, 3, 1]},
    options: {
      inputLayout: 'nhwc',
      filterLayout: 'ihwo',
    },
    output: {dataType: 'float32', shape: [1, 3, 3, 1]}
  },
  {
    name: '[conv2d] Throw if the input is not a 4-D tensor.',
    input: {dataType: 'float32', shape: [1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 2, 2, 1]},
    options: {label},
  },
  {
    name: '[conv2d] Throw if the input data type is not floating point.',
    input: {dataType: 'int32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'int32', shape: [1, 1, 2, 2]},
    options: {label},
  },
  {
    name: '[conv2d] Throw if the filter is not a 4-D tensor.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [2, 2]},
    options: {label},
  },
  {
    name:
        '[conv2d] Throw if the filter data type doesn\'t match the input data type.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'int32', shape: [1, 1, 2, 2]},
    options: {
      label: label,
    },
  },
  {
    name: '[conv2d] Throw if the length of padding is not 4.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 2, 2]},
    options: {
      padding: [2, 2],
      label: label,
    },
  },
  {
    name: '[conv2d] Throw if the length of strides is not 2.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 2, 2]},
    options: {
      strides: [2],
      label: label,
    },
  },
  {
    name: '[conv2d] Throw if strideHeight is smaller than 1.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 2, 2]},
    options: {
      strides: [0, 1],
      label: label,
    },
  },
  {
    name: '[conv2d] Throw if strideWidth is smaller than 1.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 2, 2]},
    options: {
      strides: [1, 0],
      label: label,
    },
  },
  {
    name: '[conv2d] Throw if the length of dilations is not 2.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 2, 2]},
    options: {
      dilations: [1],
      label: label,
    },
  },
  {
    name: '[conv2d] Throw if dilationHeight is smaller than 1.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 2, 2]},
    options: {
      dilations: [0, 1],
      label: label,
    },
  },
  {
    name: '[conv2d] Throw if dilationWidth is smaller than 1.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 2, 2]},
    options: {
      dilations: [1, 0],
      label: label,
    },
  },
  {
    name: '[conv2d] Throw if inputChannels % groups is not 0.',
    input: {dataType: 'float32', shape: [1, 4, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 2, 2]},
    options: {
      groups: 3,
      label: label,
    },
  },
  {
    name:
        '[conv2d] Throw if inputChannels / groups is not equal to filterInputChannels.',
    input: {dataType: 'float32', shape: [1, 4, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 2, 2]},
    options: {
      groups: 2,
      label: label,
    },
  },
  {
    name: '[conv2d] Throw if the groups is smaller than 1.',
    input: {dataType: 'float32', shape: [1, 4, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 2, 2]},
    options: {
      groups: 0,
      label: label,
    },
  },
  {
    name:
        '[conv2d] Throw due to overflow when calculating the effective filter height.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 434983, 2]},
    options: {
      dilations: [328442, 1],
      label: label,
    },
  },
  {
    name:
        '[conv2d] Throw due to overflow when calculating the effective filter width.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 2, 234545]},
    options: {
      dilations: [2, 843452],
      label: label,
    },
  },
  {
    name: '[conv2d] Throw due to overflow when dilation height is too large.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 3, 3]},
    options: {
      dilations: [kMaxUnsignedLong, 1],
      label: label,
    },
  },
  {
    name: '[conv2d] Throw due to overflow when dilation width is too large.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 3, 3]},
    options: {
      dilations: [1, kMaxUnsignedLong],
      label: label,
    },
  },
  {
    name: '[conv2d] Throw due to underflow when calculating the output height.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 4, 2]},
    options: {
      dilations: [4, 1],
      padding: [1, 1, 1, 1],
      strides: [2, 2],
      label: label,
    },
  },
  {
    name: '[conv2d] Throw due to underflow when calculating the output width.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 2, 8]},
    options: {
      dilations: [1, 4],
      padding: [1, 1, 1, 1],
      strides: [2, 2],
      label: label,
    },
  },
  {
    name: '[conv2d] Throw if the bias is not a 1-D tensor.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 2, 2]},
    options: {
      bias: {dataType: 'float32', shape: [1, 2]},
      label: label,
    },
  },
  // The four bias-shape cases below set filterLayout explicitly so that each
  // filter shape is read in the layout named by the test: the filter then has
  // one input channel and one output channel, and only the bias shape [2]
  // disagrees with the output channel count.
  {
    name:
        '[conv2d] Throw if the bias shape is not equal to [output_channels] with filterLayout="oihw".',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 2, 2]},
    options: {
      filterLayout: 'oihw',
      bias: {dataType: 'float32', shape: [2]},
      label: label,
    },
  },
  {
    name:
        '[conv2d] Throw if the bias shape is not equal to [output_channels] with filterLayout="hwio".',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [2, 2, 1, 1]},
    options: {
      filterLayout: 'hwio',
      bias: {dataType: 'float32', shape: [2]},
      label: label,
    },
  },
  {
    name:
        '[conv2d] Throw if the bias shape is not equal to [output_channels] with filterLayout="ohwi".',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 2, 2, 1]},
    options: {
      filterLayout: 'ohwi',
      bias: {dataType: 'float32', shape: [2]},
      label: label,
    },
  },
  {
    name:
        '[conv2d] Throw if the bias shape is not equal to [output_channels] with filterLayout="ihwo".',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 2, 2, 1]},
    options: {
      filterLayout: 'ihwo',
      bias: {dataType: 'float32', shape: [2]},
      label: label,
    },
  },
  {
    name:
        '[conv2d] Throw if the bias data type doesn\'t match input data type.',
    input: {dataType: 'float32', shape: [1, 1, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 1, 2, 2]},
    options: {
      bias: {dataType: 'int32', shape: [1]},
      label: label,
    },
  },
  {
    name:
        '[conv2d] Throw if inputChannels / groups is not equal to filterInputChannels with inputLayout="nchw" and filterLayout="oihw".',
    input: {dataType: 'float32', shape: [1, 2, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 2, 3, 3]},
    options: {
      inputLayout: 'nchw',
      filterLayout: 'oihw',
      groups: 2,
      label: label,
    },
  },
  {
    name:
        '[conv2d] Throw if inputChannels / groups is not equal to filterInputChannels with inputLayout="nchw" and filterLayout="hwio".',
    input: {dataType: 'float32', shape: [1, 2, 5, 5]},
    filter: {dataType: 'float32', shape: [3, 3, 2, 1]},
    options: {
      inputLayout: 'nchw',
      filterLayout: 'hwio',
      groups: 2,
      label: label,
    },
  },
  {
    name:
        '[conv2d] Throw if inputChannels / groups is not equal to filterInputChannels with inputLayout="nchw" and filterLayout="ohwi".',
    input: {dataType: 'float32', shape: [1, 2, 5, 5]},
    filter: {dataType: 'float32', shape: [1, 3, 3, 2]},
    options: {
      inputLayout: 'nchw',
      filterLayout: 'ohwi',
      groups: 2,
      label: label,
    },
  },
  {
    name:
        '[conv2d] Throw if inputChannels / groups is not equal to filterInputChannels with inputLayout="nchw" and filterLayout="ihwo".',
    input: {dataType: 'float32', shape: [1, 2, 5, 5]},
    filter: {dataType: 'float32', shape: [2, 3, 3, 1]},
    options: {
      inputLayout: 'nchw',
      filterLayout: 'ihwo',
      groups: 2,
      label: label,
    },
  },
  {
    name:
        '[conv2d] Throw if inputChannels / groups is not equal to filterInputChannels with inputLayout="nhwc" and filterLayout="oihw".',
    input: {dataType: 'float32', shape: [1, 5, 5, 2]},
    filter: {dataType: 'float32', shape: [1, 2, 3, 3]},
    options: {
      inputLayout: 'nhwc',
      filterLayout: 'oihw',
      groups: 2,
      label: label,
    },
  },
  {
    name:
        '[conv2d] Throw if inputChannels / groups is not equal to filterInputChannels with inputLayout="nhwc" and filterLayout="hwio".',
    input: {dataType: 'float32', shape: [1, 5, 5, 2]},
    filter: {dataType: 'float32', shape: [3, 3, 2, 1]},
    options: {
      inputLayout: 'nhwc',
      filterLayout: 'hwio',
      groups: 2,
      label: label,
    },
  },
  {
    name:
        '[conv2d] Throw if inputChannels / groups is not equal to filterInputChannels with inputLayout="nhwc" and filterLayout="ohwi".',
    input: {dataType: 'float32', shape: [1, 5, 5, 2]},
    filter: {dataType: 'float32', shape: [1, 3, 3, 2]},
    options: {
      inputLayout: 'nhwc',
      filterLayout: 'ohwi',
      groups: 2,
      label: label,
    },
  },
  {
    name:
        '[conv2d] Throw if inputChannels / groups is not equal to filterInputChannels with inputLayout="nhwc" and filterLayout="ihwo".',
    input: {dataType: 'float32', shape: [1, 5, 5, 2]},
    filter: {dataType: 'float32', shape: [2, 3, 3, 1]},
    options: {
      inputLayout: 'nhwc',
      filterLayout: 'ihwo',
      groups: 2,
      label: label,
    },
  },
];

// Registers one promise_test per table entry. An entry with an expected
// `output` whose input data type is listed in
// context.opSupportLimits().conv2d.input.dataTypes must build and produce that
// data type and shape; every other entry must throw an error that mentions
// the label.
tests.forEach(
    test => promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      const input = builder.input('input', test.input);
      const filter = builder.input('filter', test.filter);

      // Work on a shallow copy so the bias descriptor stored in the shared
      // table is not overwritten by the operand created for this builder.
      // Entries without options still pass `undefined` to conv2d().
      const options =
          test.options ? Object.assign({}, test.options) : test.options;
      if (options && options.bias) {
        options.bias = builder.input('bias', options.bias);
      }

      if (test.output &&
          context.opSupportLimits().conv2d.input.dataTypes.includes(
              test.input.dataType)) {
        const output = builder.conv2d(input, filter, options);
        assert_equals(output.dataType, test.output.dataType);
        assert_array_equals(output.shape, test.output.shape);
      } else {
        const labelRegexp = /\[conv_2d_\*\]/;
        assert_throws_with_label(
            () => builder.conv2d(input, filter, options), labelRegexp);
      }
    }, test.name));