RemoveDynamicIndexing.cpp (24004B)
1 // 2 // Copyright 2002 The ANGLE Project Authors. All rights reserved. 3 // Use of this source code is governed by a BSD-style license that can be 4 // found in the LICENSE file. 5 // 6 // RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of non-SSBO vectors and 7 // matrices, replacing them with calls to functions that choose which component to return or write. 8 // We don't need to consider dynamic indexing in SSBO since it can be directly as part of the offset 9 // of RWByteAddressBuffer. 10 // 11 12 #include "compiler/translator/tree_ops/RemoveDynamicIndexing.h" 13 14 #include "compiler/translator/Compiler.h" 15 #include "compiler/translator/Diagnostics.h" 16 #include "compiler/translator/InfoSink.h" 17 #include "compiler/translator/StaticType.h" 18 #include "compiler/translator/SymbolTable.h" 19 #include "compiler/translator/tree_util/IntermNodePatternMatcher.h" 20 #include "compiler/translator/tree_util/IntermNode_util.h" 21 #include "compiler/translator/tree_util/IntermTraverse.h" 22 23 namespace sh 24 { 25 26 namespace 27 { 28 29 using DynamicIndexingNodeMatcher = std::function<bool(TIntermBinary *)>; 30 31 const TType *kIndexType = StaticType::Get<EbtInt, EbpHigh, EvqParamIn, 1, 1>(); 32 33 constexpr const ImmutableString kBaseName("base"); 34 constexpr const ImmutableString kIndexName("index"); 35 constexpr const ImmutableString kValueName("value"); 36 37 std::string GetIndexFunctionName(const TType &type, bool write) 38 { 39 TInfoSinkBase nameSink; 40 nameSink << "dyn_index_"; 41 if (write) 42 { 43 nameSink << "write_"; 44 } 45 if (type.isMatrix()) 46 { 47 nameSink << "mat" << static_cast<uint32_t>(type.getCols()) << "x" 48 << static_cast<uint32_t>(type.getRows()); 49 } 50 else 51 { 52 switch (type.getBasicType()) 53 { 54 case EbtInt: 55 nameSink << "ivec"; 56 break; 57 case EbtBool: 58 nameSink << "bvec"; 59 break; 60 case EbtUInt: 61 nameSink << "uvec"; 62 break; 63 case EbtFloat: 64 nameSink << "vec"; 65 break; 66 default: 67 UNREACHABLE(); 68 } 69 nameSink << static_cast<uint32_t>(type.getNominalSize()); 70 } 71 return nameSink.str(); 72 } 73 74 TIntermConstantUnion *CreateIntConstantNode(int i) 75 { 76 TConstantUnion *constant = new TConstantUnion(); 77 constant->setIConst(i); 78 return new TIntermConstantUnion(constant, TType(EbtInt, EbpHigh)); 79 } 80 81 TIntermTyped *EnsureSignedInt(TIntermTyped *node) 82 { 83 if (node->getBasicType() == EbtInt) 84 return node; 85 86 TIntermSequence arguments; 87 arguments.push_back(node); 88 return TIntermAggregate::CreateConstructor(TType(EbtInt), &arguments); 89 } 90 91 TType *GetFieldType(const TType &indexedType) 92 { 93 TType *fieldType = new TType(indexedType); 94 if (indexedType.isMatrix()) 95 { 96 fieldType->toMatrixColumnType(); 97 } 98 else 99 { 100 ASSERT(indexedType.isVector()); 101 fieldType->toComponentType(); 102 } 103 // Default precision to highp if not specified. For example in |vec3(0)[i], i < 0|, there is no 104 // precision assigned to vec3(0). 105 if (fieldType->getPrecision() == EbpUndefined) 106 { 107 fieldType->setPrecision(EbpHigh); 108 } 109 return fieldType; 110 } 111 112 const TType *GetBaseType(const TType &type, bool write) 113 { 114 TType *baseType = new TType(type); 115 // Conservatively use highp here, even if the indexed type is not highp. That way the code can't 116 // end up using mediump version of an indexing function for a highp value, if both mediump and 117 // highp values are being indexed in the shader. For HLSL precision doesn't matter, but in 118 // principle this code could be used with multiple backends. 119 baseType->setPrecision(EbpHigh); 120 baseType->setQualifier(EvqParamInOut); 121 if (!write) 122 baseType->setQualifier(EvqParamIn); 123 return baseType; 124 } 125 126 // Generate a read or write function for one field in a vector/matrix. 127 // Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range 128 // indices in other places. 129 // Note that indices can be either int or uint. We create only int versions of the functions, 130 // and convert uint indices to int at the call site. 131 // read function example: 132 // float dyn_index_vec2(in vec2 base, in int index) 133 // { 134 // switch(index) 135 // { 136 // case (0): 137 // return base[0]; 138 // case (1): 139 // return base[1]; 140 // default: 141 // break; 142 // } 143 // if (index < 0) 144 // return base[0]; 145 // return base[1]; 146 // } 147 // write function example: 148 // void dyn_index_write_vec2(inout vec2 base, in int index, in float value) 149 // { 150 // switch(index) 151 // { 152 // case (0): 153 // base[0] = value; 154 // return; 155 // case (1): 156 // base[1] = value; 157 // return; 158 // default: 159 // break; 160 // } 161 // if (index < 0) 162 // { 163 // base[0] = value; 164 // return; 165 // } 166 // base[1] = value; 167 // } 168 // Note that else is not used in above functions to avoid the RewriteElseBlocks transformation. 169 TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type, 170 bool write, 171 const TFunction &func, 172 TSymbolTable *symbolTable) 173 { 174 ASSERT(!type.isArray()); 175 176 uint8_t numCases = 0; 177 if (type.isMatrix()) 178 { 179 numCases = type.getCols(); 180 } 181 else 182 { 183 numCases = type.getNominalSize(); 184 } 185 186 std::string functionName = GetIndexFunctionName(type, write); 187 TIntermFunctionPrototype *prototypeNode = CreateInternalFunctionPrototypeNode(func); 188 189 TIntermSymbol *baseParam = new TIntermSymbol(func.getParam(0)); 190 TIntermSymbol *indexParam = new TIntermSymbol(func.getParam(1)); 191 TIntermSymbol *valueParam = nullptr; 192 if (write) 193 { 194 valueParam = new TIntermSymbol(func.getParam(2)); 195 } 196 197 TIntermBlock *statementList = new TIntermBlock(); 198 for (uint8_t i = 0; i < numCases; ++i) 199 { 200 TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i)); 201 statementList->getSequence()->push_back(caseNode); 202 203 TIntermBinary *indexNode = 204 new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(i)); 205 if (write) 206 { 207 TIntermBinary *assignNode = 208 new TIntermBinary(EOpAssign, indexNode, valueParam->deepCopy()); 209 statementList->getSequence()->push_back(assignNode); 210 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr); 211 statementList->getSequence()->push_back(returnNode); 212 } 213 else 214 { 215 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode); 216 statementList->getSequence()->push_back(returnNode); 217 } 218 } 219 220 // Default case 221 TIntermCase *defaultNode = new TIntermCase(nullptr); 222 statementList->getSequence()->push_back(defaultNode); 223 TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr); 224 statementList->getSequence()->push_back(breakNode); 225 226 TIntermSwitch *switchNode = new TIntermSwitch(indexParam->deepCopy(), statementList); 227 228 TIntermBlock *bodyNode = new TIntermBlock(); 229 bodyNode->getSequence()->push_back(switchNode); 230 231 TIntermBinary *cond = 232 new TIntermBinary(EOpLessThan, indexParam->deepCopy(), CreateIntConstantNode(0)); 233 234 // Two blocks: one accesses (either reads or writes) the first element and returns, 235 // the other accesses the last element. 236 TIntermBlock *useFirstBlock = new TIntermBlock(); 237 TIntermBlock *useLastBlock = new TIntermBlock(); 238 TIntermBinary *indexFirstNode = 239 new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(0)); 240 TIntermBinary *indexLastNode = 241 new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(numCases - 1)); 242 if (write) 243 { 244 TIntermBinary *assignFirstNode = 245 new TIntermBinary(EOpAssign, indexFirstNode, valueParam->deepCopy()); 246 useFirstBlock->getSequence()->push_back(assignFirstNode); 247 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr); 248 useFirstBlock->getSequence()->push_back(returnNode); 249 250 TIntermBinary *assignLastNode = 251 new TIntermBinary(EOpAssign, indexLastNode, valueParam->deepCopy()); 252 useLastBlock->getSequence()->push_back(assignLastNode); 253 } 254 else 255 { 256 TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode); 257 useFirstBlock->getSequence()->push_back(returnFirstNode); 258 259 TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode); 260 useLastBlock->getSequence()->push_back(returnLastNode); 261 } 262 TIntermIfElse *ifNode = new TIntermIfElse(cond, useFirstBlock, nullptr); 263 bodyNode->getSequence()->push_back(ifNode); 264 bodyNode->getSequence()->push_back(useLastBlock); 265 266 TIntermFunctionDefinition *indexingFunction = 267 new TIntermFunctionDefinition(prototypeNode, bodyNode); 268 return indexingFunction; 269 } 270 271 class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser 272 { 273 public: 274 RemoveDynamicIndexingTraverser(DynamicIndexingNodeMatcher &&matcher, 275 TSymbolTable *symbolTable, 276 PerformanceDiagnostics *perfDiagnostics); 277 278 bool visitBinary(Visit visit, TIntermBinary *node) override; 279 280 void insertHelperDefinitions(TIntermNode *root); 281 282 void nextIteration(); 283 284 bool usedTreeInsertion() const { return mUsedTreeInsertion; } 285 286 protected: 287 // Maps of types that are indexed to the indexing function ids used for them. Note that these 288 // can not store multiple variants of the same type with different precisions - only one 289 // precision gets stored. 290 std::map<TType, TFunction *> mIndexedVecAndMatrixTypes; 291 std::map<TType, TFunction *> mWrittenVecAndMatrixTypes; 292 293 bool mUsedTreeInsertion; 294 295 // When true, the traverser will remove side effects from any indexing expression. 296 // This is done so that in code like 297 // V[j++][i]++. 298 // where V is an array of vectors, j++ will only be evaluated once. 299 bool mRemoveIndexSideEffectsInSubtree; 300 301 DynamicIndexingNodeMatcher mMatcher; 302 PerformanceDiagnostics *mPerfDiagnostics; 303 }; 304 305 RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser( 306 DynamicIndexingNodeMatcher &&matcher, 307 TSymbolTable *symbolTable, 308 PerformanceDiagnostics *perfDiagnostics) 309 : TLValueTrackingTraverser(true, false, false, symbolTable), 310 mUsedTreeInsertion(false), 311 mRemoveIndexSideEffectsInSubtree(false), 312 mMatcher(matcher), 313 mPerfDiagnostics(perfDiagnostics) 314 {} 315 316 void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root) 317 { 318 TIntermBlock *rootBlock = root->getAsBlock(); 319 ASSERT(rootBlock != nullptr); 320 TIntermSequence insertions; 321 for (auto &type : mIndexedVecAndMatrixTypes) 322 { 323 insertions.push_back( 324 GetIndexFunctionDefinition(type.first, false, *type.second, mSymbolTable)); 325 } 326 for (auto &type : mWrittenVecAndMatrixTypes) 327 { 328 insertions.push_back( 329 GetIndexFunctionDefinition(type.first, true, *type.second, mSymbolTable)); 330 } 331 rootBlock->insertChildNodes(0, insertions); 332 } 333 334 // Create a call to dyn_index_*() based on an indirect indexing op node 335 TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node, 336 TIntermTyped *index, 337 TFunction *indexingFunction) 338 { 339 ASSERT(node->getOp() == EOpIndexIndirect); 340 TIntermSequence arguments; 341 arguments.push_back(node->getLeft()); 342 arguments.push_back(index); 343 344 TIntermAggregate *indexingCall = 345 TIntermAggregate::CreateFunctionCall(*indexingFunction, &arguments); 346 indexingCall->setLine(node->getLine()); 347 return indexingCall; 348 } 349 350 TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node, 351 TVariable *index, 352 TVariable *writtenValue, 353 TFunction *indexedWriteFunction) 354 { 355 ASSERT(node->getOp() == EOpIndexIndirect); 356 TIntermSequence arguments; 357 // Deep copy the child nodes so that two pointers to the same node don't end up in the tree. 358 arguments.push_back(node->getLeft()->deepCopy()); 359 arguments.push_back(CreateTempSymbolNode(index)); 360 arguments.push_back(CreateTempSymbolNode(writtenValue)); 361 362 TIntermAggregate *indexedWriteCall = 363 TIntermAggregate::CreateFunctionCall(*indexedWriteFunction, &arguments); 364 indexedWriteCall->setLine(node->getLine()); 365 return indexedWriteCall; 366 } 367 368 bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node) 369 { 370 if (mUsedTreeInsertion) 371 return false; 372 373 if (node->getOp() == EOpIndexIndirect) 374 { 375 if (mRemoveIndexSideEffectsInSubtree) 376 { 377 ASSERT(node->getRight()->hasSideEffects()); 378 // In case we're just removing index side effects, convert 379 // v_expr[index_expr] 380 // to this: 381 // int s0 = index_expr; v_expr[s0]; 382 // Now v_expr[s0] can be safely executed several times without unintended side effects. 383 TIntermDeclaration *indexVariableDeclaration = nullptr; 384 TVariable *indexVariable = DeclareTempVariable(mSymbolTable, node->getRight(), 385 EvqTemporary, &indexVariableDeclaration); 386 insertStatementInParentBlock(indexVariableDeclaration); 387 mUsedTreeInsertion = true; 388 389 // Replace the index with the temp variable 390 TIntermSymbol *tempIndex = CreateTempSymbolNode(indexVariable); 391 queueReplacementWithParent(node, node->getRight(), tempIndex, OriginalNode::IS_DROPPED); 392 } 393 else if (mMatcher(node)) 394 { 395 if (mPerfDiagnostics) 396 { 397 mPerfDiagnostics->warning(node->getLine(), 398 "Performance: dynamic indexing of vectors and " 399 "matrices is emulated and can be slow.", 400 "[]"); 401 } 402 bool write = isLValueRequiredHere(); 403 404 #if defined(ANGLE_ENABLE_ASSERTS) 405 // Make sure that IntermNodePatternMatcher is consistent with the slightly differently 406 // implemented checks in this traverser. 407 IntermNodePatternMatcher matcher( 408 IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue); 409 ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write); 410 #endif 411 412 const TType &type = node->getLeft()->getType(); 413 ImmutableString indexingFunctionName(GetIndexFunctionName(type, false)); 414 TFunction *indexingFunction = nullptr; 415 if (mIndexedVecAndMatrixTypes.find(type) == mIndexedVecAndMatrixTypes.end()) 416 { 417 indexingFunction = 418 new TFunction(mSymbolTable, indexingFunctionName, SymbolType::AngleInternal, 419 GetFieldType(type), true); 420 indexingFunction->addParameter(new TVariable( 421 mSymbolTable, kBaseName, GetBaseType(type, false), SymbolType::AngleInternal)); 422 indexingFunction->addParameter( 423 new TVariable(mSymbolTable, kIndexName, kIndexType, SymbolType::AngleInternal)); 424 mIndexedVecAndMatrixTypes[type] = indexingFunction; 425 } 426 else 427 { 428 indexingFunction = mIndexedVecAndMatrixTypes[type]; 429 } 430 431 if (write) 432 { 433 // Convert: 434 // v_expr[index_expr]++; 435 // to this: 436 // int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++; 437 // dyn_index_write(v_expr, s0, s1); 438 // This works even if index_expr has some side effects. 439 if (node->getLeft()->hasSideEffects()) 440 { 441 // If v_expr has side effects, those need to be removed before proceeding. 442 // Otherwise the side effects of v_expr would be evaluated twice. 443 // The only case where an l-value can have side effects is when it is 444 // indexing. For example, it can be V[j++] where V is an array of vectors. 445 mRemoveIndexSideEffectsInSubtree = true; 446 return true; 447 } 448 449 TIntermBinary *leftBinary = node->getLeft()->getAsBinaryNode(); 450 if (leftBinary != nullptr && mMatcher(leftBinary)) 451 { 452 // This is a case like: 453 // mat2 m; 454 // m[a][b]++; 455 // Process the child node m[a] first. 456 return true; 457 } 458 459 // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value 460 // only writes it and doesn't need the previous value. http://anglebug.com/1116 461 462 TFunction *indexedWriteFunction = nullptr; 463 if (mWrittenVecAndMatrixTypes.find(type) == mWrittenVecAndMatrixTypes.end()) 464 { 465 ImmutableString functionName( 466 GetIndexFunctionName(node->getLeft()->getType(), true)); 467 indexedWriteFunction = 468 new TFunction(mSymbolTable, functionName, SymbolType::AngleInternal, 469 StaticType::GetBasic<EbtVoid, EbpUndefined>(), false); 470 indexedWriteFunction->addParameter(new TVariable(mSymbolTable, kBaseName, 471 GetBaseType(type, true), 472 SymbolType::AngleInternal)); 473 indexedWriteFunction->addParameter(new TVariable( 474 mSymbolTable, kIndexName, kIndexType, SymbolType::AngleInternal)); 475 TType *valueType = GetFieldType(type); 476 valueType->setQualifier(EvqParamIn); 477 indexedWriteFunction->addParameter(new TVariable( 478 mSymbolTable, kValueName, static_cast<const TType *>(valueType), 479 SymbolType::AngleInternal)); 480 mWrittenVecAndMatrixTypes[type] = indexedWriteFunction; 481 } 482 else 483 { 484 indexedWriteFunction = mWrittenVecAndMatrixTypes[type]; 485 } 486 487 TIntermSequence insertionsBefore; 488 TIntermSequence insertionsAfter; 489 490 // Store the index in a temporary signed int variable. 491 // s0 = index_expr; 492 TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight()); 493 TIntermDeclaration *indexVariableDeclaration = nullptr; 494 TVariable *indexVariable = DeclareTempVariable( 495 mSymbolTable, indexInitializer, EvqTemporary, &indexVariableDeclaration); 496 insertionsBefore.push_back(indexVariableDeclaration); 497 498 // s1 = dyn_index(v_expr, s0); 499 TIntermAggregate *indexingCall = CreateIndexFunctionCall( 500 node, CreateTempSymbolNode(indexVariable), indexingFunction); 501 TIntermDeclaration *fieldVariableDeclaration = nullptr; 502 TVariable *fieldVariable = DeclareTempVariable( 503 mSymbolTable, indexingCall, EvqTemporary, &fieldVariableDeclaration); 504 insertionsBefore.push_back(fieldVariableDeclaration); 505 506 // dyn_index_write(v_expr, s0, s1); 507 TIntermAggregate *indexedWriteCall = CreateIndexedWriteFunctionCall( 508 node, indexVariable, fieldVariable, indexedWriteFunction); 509 insertionsAfter.push_back(indexedWriteCall); 510 insertStatementsInParentBlock(insertionsBefore, insertionsAfter); 511 512 // replace the node with s1 513 queueReplacement(CreateTempSymbolNode(fieldVariable), OriginalNode::IS_DROPPED); 514 mUsedTreeInsertion = true; 515 } 516 else 517 { 518 // The indexed value is not being written, so we can simply convert 519 // v_expr[index_expr] 520 // into 521 // dyn_index(v_expr, index_expr) 522 // If the index_expr is unsigned, we'll convert it to signed. 523 ASSERT(!mRemoveIndexSideEffectsInSubtree); 524 TIntermAggregate *indexingCall = CreateIndexFunctionCall( 525 node, EnsureSignedInt(node->getRight()), indexingFunction); 526 queueReplacement(indexingCall, OriginalNode::IS_DROPPED); 527 } 528 } 529 } 530 return !mUsedTreeInsertion; 531 } 532 533 void RemoveDynamicIndexingTraverser::nextIteration() 534 { 535 mUsedTreeInsertion = false; 536 mRemoveIndexSideEffectsInSubtree = false; 537 } 538 539 bool RemoveDynamicIndexingIf(DynamicIndexingNodeMatcher &&matcher, 540 TCompiler *compiler, 541 TIntermNode *root, 542 TSymbolTable *symbolTable, 543 PerformanceDiagnostics *perfDiagnostics) 544 { 545 // This transformation adds function declarations after the fact and so some validation is 546 // momentarily disabled. 547 bool enableValidateFunctionCall = compiler->disableValidateFunctionCall(); 548 549 RemoveDynamicIndexingTraverser traverser(std::move(matcher), symbolTable, perfDiagnostics); 550 do 551 { 552 traverser.nextIteration(); 553 root->traverse(&traverser); 554 if (!traverser.updateTree(compiler, root)) 555 { 556 return false; 557 } 558 } while (traverser.usedTreeInsertion()); 559 // TODO(oetuaho@nvidia.com): It might be nicer to add the helper definitions also in the middle 560 // of traversal. Now the tree ends up in an inconsistent state in the middle, since there are 561 // function call nodes with no corresponding definition nodes. This needs special handling in 562 // TIntermLValueTrackingTraverser, and creates intricacies that are not easily apparent from a 563 // superficial reading of the code. 564 traverser.insertHelperDefinitions(root); 565 566 compiler->restoreValidateFunctionCall(enableValidateFunctionCall); 567 return compiler->validateAST(root); 568 } 569 570 } // namespace 571 572 [[nodiscard]] bool RemoveDynamicIndexingOfNonSSBOVectorOrMatrix( 573 TCompiler *compiler, 574 TIntermNode *root, 575 TSymbolTable *symbolTable, 576 PerformanceDiagnostics *perfDiagnostics) 577 { 578 DynamicIndexingNodeMatcher matcher = [](TIntermBinary *node) { 579 return IntermNodePatternMatcher::IsDynamicIndexingOfNonSSBOVectorOrMatrix(node); 580 }; 581 return RemoveDynamicIndexingIf(std::move(matcher), compiler, root, symbolTable, 582 perfDiagnostics); 583 } 584 585 [[nodiscard]] bool RemoveDynamicIndexingOfSwizzledVector(TCompiler *compiler, 586 TIntermNode *root, 587 TSymbolTable *symbolTable, 588 PerformanceDiagnostics *perfDiagnostics) 589 { 590 DynamicIndexingNodeMatcher matcher = [](TIntermBinary *node) { 591 return IntermNodePatternMatcher::IsDynamicIndexingOfSwizzledVector(node); 592 }; 593 return RemoveDynamicIndexingIf(std::move(matcher), compiler, root, symbolTable, 594 perfDiagnostics); 595 } 596 597 } // namespace sh