FindPreciseNodes.cpp (22619B)
//
// Copyright 2021 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// FindPreciseNodes.cpp: Propagates |precise| to AST nodes.
//
// The high level algorithm is as follows. For every node that "assigns" to a precise object,
// subobject (a precise struct whose field is being assigned) or superobject (a struct with a
// precise field), two things happen:
//
// - The operation is marked precise if it's an arithmetic operation
// - The right hand side of the assignment is made precise. If only a subobject is precise, only
//   the corresponding subobject of the right hand side is made precise.
//

#include "compiler/translator/tree_util/FindPreciseNodes.h"

#include <algorithm>

#include "common/hash_utils.h"
#include "compiler/translator/Compiler.h"
#include "compiler/translator/IntermNode.h"
#include "compiler/translator/Symbol.h"
#include "compiler/translator/tree_util/IntermTraverse.h"

namespace sh
{

namespace
{

// An access chain applied to a variable. The |precise|-ness of a node does not change when
// indexing arrays, selecting matrix columns or swizzle vectors. This access chain thus only
// includes block field selections. The access chain is used to identify the part of an object
// that is or should be |precise|. If both a.b.c and a.b are precise, only a.b is ever considered.
35 class AccessChain 36 { 37 public: 38 AccessChain() = default; 39 40 bool operator==(const AccessChain &other) const { return mChain == other.mChain; } 41 42 const TVariable *build(TIntermTyped *lvalue); 43 44 const TVector<size_t> &getChain() const { return mChain; } 45 46 void reduceChain(size_t newSize) 47 { 48 ASSERT(newSize <= mChain.size()); 49 mChain.resize(newSize); 50 } 51 void clear() { reduceChain(0); } 52 void push_back(size_t index) { mChain.push_back(index); } 53 void pop_front(size_t n); 54 void append(const AccessChain &other) 55 { 56 mChain.insert(mChain.end(), other.mChain.begin(), other.mChain.end()); 57 } 58 bool removePrefix(const AccessChain &other); 59 60 private: 61 TVector<size_t> mChain; 62 }; 63 64 bool IsIndexOp(TOperator op) 65 { 66 switch (op) 67 { 68 case EOpIndexDirect: 69 case EOpIndexDirectStruct: 70 case EOpIndexDirectInterfaceBlock: 71 case EOpIndexIndirect: 72 return true; 73 default: 74 return false; 75 } 76 } 77 78 const TVariable *AccessChain::build(TIntermTyped *lvalue) 79 { 80 if (lvalue->getAsSwizzleNode()) 81 { 82 return build(lvalue->getAsSwizzleNode()->getOperand()); 83 } 84 if (lvalue->getAsSymbolNode()) 85 { 86 const TVariable *var = &lvalue->getAsSymbolNode()->variable(); 87 88 // For fields of nameless interface blocks, add the field index too. 
89 if (var->getType().getInterfaceBlock() != nullptr) 90 { 91 mChain.push_back(var->getType().getInterfaceBlockFieldIndex()); 92 } 93 94 return var; 95 } 96 TIntermBinary *binary = lvalue->getAsBinaryNode(); 97 ASSERT(binary); 98 99 TOperator op = binary->getOp(); 100 ASSERT(IsIndexOp(op)); 101 102 const TVariable *var = build(binary->getLeft()); 103 104 if (op == EOpIndexDirectStruct || op == EOpIndexDirectInterfaceBlock) 105 { 106 int fieldIndex = binary->getRight()->getAsConstantUnion()->getIConst(0); 107 mChain.push_back(fieldIndex); 108 } 109 110 return var; 111 } 112 113 void AccessChain::pop_front(size_t n) 114 { 115 std::rotate(mChain.begin(), mChain.begin() + n, mChain.end()); 116 reduceChain(mChain.size() - n); 117 } 118 119 bool AccessChain::removePrefix(const AccessChain &other) 120 { 121 // First, make sure the common part of the two access chains match. 122 size_t commonSize = std::min(mChain.size(), other.mChain.size()); 123 124 for (size_t index = 0; index < commonSize; ++index) 125 { 126 if (mChain[index] != other.mChain[index]) 127 { 128 return false; 129 } 130 } 131 132 // Remove the common part from the access chain. If other is a deeper access chain, this access 133 // chain will become empty. 134 pop_front(commonSize); 135 136 return true; 137 } 138 139 AccessChain GetAssignmentAccessChain(TIntermOperator *node) 140 { 141 // The assignment is either a unary or a binary node, and the lvalue is always the first child. 
142 AccessChain lvalueAccessChain; 143 lvalueAccessChain.build(node->getChildNode(0)->getAsTyped()); 144 return lvalueAccessChain; 145 } 146 147 template <typename Traverser> 148 void TraverseIndexNodesOnly(TIntermNode *node, Traverser *traverser) 149 { 150 if (node->getAsSwizzleNode()) 151 { 152 node = node->getAsSwizzleNode()->getOperand(); 153 } 154 155 if (node->getAsSymbolNode()) 156 { 157 return; 158 } 159 160 TIntermBinary *binary = node->getAsBinaryNode(); 161 ASSERT(binary); 162 163 TOperator op = binary->getOp(); 164 ASSERT(IsIndexOp(op)); 165 166 if (op == EOpIndexIndirect) 167 { 168 binary->getRight()->traverse(traverser); 169 } 170 171 TraverseIndexNodesOnly(binary->getLeft(), traverser); 172 } 173 174 // An object, which could be a sub-object of a variable. 175 struct ObjectAndAccessChain 176 { 177 const TVariable *variable; 178 AccessChain accessChain; 179 }; 180 181 bool operator==(const ObjectAndAccessChain &a, const ObjectAndAccessChain &b) 182 { 183 return a.variable == b.variable && a.accessChain == b.accessChain; 184 } 185 186 struct ObjectAndAccessChainHash 187 { 188 size_t operator()(const ObjectAndAccessChain &object) const 189 { 190 size_t result = angle::ComputeGenericHash(&object.variable, sizeof(object.variable)); 191 if (!object.accessChain.getChain().empty()) 192 { 193 result = 194 result ^ angle::ComputeGenericHash(object.accessChain.getChain().data(), 195 object.accessChain.getChain().size() * 196 sizeof(object.accessChain.getChain()[0])); 197 } 198 return result; 199 } 200 }; 201 202 // A map from variables to AST nodes that modify them (i.e. nodes where IsAssignment(op)). 203 using VariableToAssignmentNodeMap = angle::HashMap<const TVariable *, TVector<TIntermOperator *>>; 204 // A set of |return| nodes from functions with a |precise| return value. 205 using PreciseReturnNodes = angle::HashSet<TIntermBranch *>; 206 // A set of precise objects that need processing, or have been processed. 
207 using PreciseObjectSet = angle::HashSet<ObjectAndAccessChain, ObjectAndAccessChainHash>; 208 209 struct ASTInfo 210 { 211 // Generic information about the tree: 212 VariableToAssignmentNodeMap variableAssignmentNodeMap; 213 // Information pertaining to |precise| expressions: 214 PreciseReturnNodes preciseReturnNodes; 215 PreciseObjectSet preciseObjectsToProcess; 216 PreciseObjectSet preciseObjectsVisited; 217 }; 218 219 int GetObjectPreciseSubChainLength(const ObjectAndAccessChain &object) 220 { 221 const TType &type = object.variable->getType(); 222 223 if (type.isPrecise()) 224 { 225 return 0; 226 } 227 228 const TFieldListCollection *block = type.getInterfaceBlock(); 229 if (block == nullptr) 230 { 231 block = type.getStruct(); 232 } 233 const TVector<size_t> &accessChain = object.accessChain.getChain(); 234 235 for (size_t length = 0; length < accessChain.size(); ++length) 236 { 237 ASSERT(block != nullptr); 238 239 const TField *field = block->fields()[accessChain[length]]; 240 if (field->type()->isPrecise()) 241 { 242 return static_cast<int>(length + 1); 243 } 244 245 block = field->type()->getStruct(); 246 } 247 248 return -1; 249 } 250 251 void AddPreciseObject(ASTInfo *info, const ObjectAndAccessChain &object) 252 { 253 if (info->preciseObjectsVisited.count(object) > 0) 254 { 255 return; 256 } 257 258 info->preciseObjectsToProcess.insert(object); 259 info->preciseObjectsVisited.insert(object); 260 } 261 262 void AddPreciseSubObjects(ASTInfo *info, const ObjectAndAccessChain &object); 263 264 void AddObjectIfPrecise(ASTInfo *info, const ObjectAndAccessChain &object) 265 { 266 // See if the access chain is already precise, and if so add the minimum access chain that is 267 // precise. 268 int preciseSubChainLength = GetObjectPreciseSubChainLength(object); 269 if (preciseSubChainLength == -1) 270 { 271 // If the access chain is not precise, see if there are any fields of it that are precise, 272 // and add those individually. 
273 AddPreciseSubObjects(info, object); 274 return; 275 } 276 277 ObjectAndAccessChain preciseObject = object; 278 preciseObject.accessChain.reduceChain(preciseSubChainLength); 279 280 AddPreciseObject(info, preciseObject); 281 } 282 283 void AddPreciseSubObjects(ASTInfo *info, const ObjectAndAccessChain &object) 284 { 285 const TFieldListCollection *block = object.variable->getType().getInterfaceBlock(); 286 if (block == nullptr) 287 { 288 block = object.variable->getType().getStruct(); 289 } 290 const TVector<size_t> &accessChain = object.accessChain.getChain(); 291 292 for (size_t length = 0; length < accessChain.size(); ++length) 293 { 294 block = block->fields()[accessChain[length]]->type()->getStruct(); 295 } 296 297 if (block == nullptr) 298 { 299 return; 300 } 301 302 for (size_t fieldIndex = 0; fieldIndex < block->fields().size(); ++fieldIndex) 303 { 304 ObjectAndAccessChain subObject = object; 305 subObject.accessChain.push_back(fieldIndex); 306 307 // If the field is precise, add it as a precise subobject. Otherwise recurse. 
308 if (block->fields()[fieldIndex]->type()->isPrecise()) 309 { 310 AddPreciseObject(info, subObject); 311 } 312 else 313 { 314 AddPreciseSubObjects(info, subObject); 315 } 316 } 317 } 318 319 bool IsArithmeticOp(TOperator op) 320 { 321 switch (op) 322 { 323 case EOpNegative: 324 325 case EOpPostIncrement: 326 case EOpPostDecrement: 327 case EOpPreIncrement: 328 case EOpPreDecrement: 329 330 case EOpAdd: 331 case EOpSub: 332 case EOpMul: 333 case EOpDiv: 334 case EOpIMod: 335 336 case EOpVectorTimesScalar: 337 case EOpVectorTimesMatrix: 338 case EOpMatrixTimesVector: 339 case EOpMatrixTimesScalar: 340 case EOpMatrixTimesMatrix: 341 342 case EOpAddAssign: 343 case EOpSubAssign: 344 345 case EOpMulAssign: 346 case EOpVectorTimesMatrixAssign: 347 case EOpVectorTimesScalarAssign: 348 case EOpMatrixTimesScalarAssign: 349 case EOpMatrixTimesMatrixAssign: 350 351 case EOpDivAssign: 352 case EOpIModAssign: 353 354 case EOpDot: 355 return true; 356 default: 357 return false; 358 } 359 } 360 361 // A traverser that gathers the following information, used to kick off processing: 362 // 363 // - For each variable, the AST nodes that modify it. 364 // - The set of |precise| return AST node. 365 // - The set of |precise| access chains assigned to. 366 // 367 class InfoGatherTraverser : public TIntermTraverser 368 { 369 public: 370 InfoGatherTraverser(ASTInfo *info) : TIntermTraverser(true, false, false), mInfo(info) {} 371 372 bool visitUnary(Visit visit, TIntermUnary *node) override 373 { 374 // If the node is an assignment (i.e. ++ and --), store the relevant information. 
375 if (!IsAssignment(node->getOp())) 376 { 377 return true; 378 } 379 380 visitLvalue(node, node->getOperand()); 381 return false; 382 } 383 384 bool visitBinary(Visit visit, TIntermBinary *node) override 385 { 386 if (IsAssignment(node->getOp())) 387 { 388 visitLvalue(node, node->getLeft()); 389 390 node->getRight()->traverse(this); 391 392 return false; 393 } 394 395 return true; 396 } 397 398 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override 399 { 400 const TIntermSequence &sequence = *(node->getSequence()); 401 TIntermSymbol *symbol = sequence.front()->getAsSymbolNode(); 402 TIntermBinary *initNode = sequence.front()->getAsBinaryNode(); 403 TIntermTyped *initExpression = nullptr; 404 405 if (symbol == nullptr) 406 { 407 ASSERT(initNode->getOp() == EOpInitialize); 408 409 symbol = initNode->getLeft()->getAsSymbolNode(); 410 initExpression = initNode->getRight(); 411 } 412 413 ASSERT(symbol); 414 ObjectAndAccessChain object = {&symbol->variable(), {}}; 415 AddObjectIfPrecise(mInfo, object); 416 417 if (initExpression) 418 { 419 mInfo->variableAssignmentNodeMap[object.variable].push_back(initNode); 420 421 // Visit the init expression, which may itself have assignments. 
422 initExpression->traverse(this); 423 } 424 425 return false; 426 } 427 428 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override 429 { 430 mCurrentFunction = node->getFunction(); 431 432 for (size_t paramIndex = 0; paramIndex < mCurrentFunction->getParamCount(); ++paramIndex) 433 { 434 ObjectAndAccessChain param = {mCurrentFunction->getParam(paramIndex), {}}; 435 AddObjectIfPrecise(mInfo, param); 436 } 437 438 return true; 439 } 440 441 bool visitBranch(Visit visit, TIntermBranch *node) override 442 { 443 if (node->getFlowOp() == EOpReturn && node->getChildCount() == 1 && 444 mCurrentFunction->getReturnType().isPrecise()) 445 { 446 mInfo->preciseReturnNodes.insert(node); 447 } 448 449 return true; 450 } 451 452 bool visitGlobalQualifierDeclaration(Visit visit, 453 TIntermGlobalQualifierDeclaration *node) override 454 { 455 if (node->isPrecise()) 456 { 457 ObjectAndAccessChain preciseObject = {&node->getSymbol()->variable(), {}}; 458 AddPreciseObject(mInfo, preciseObject); 459 } 460 461 return false; 462 } 463 464 private: 465 void visitLvalue(TIntermOperator *assignmentNode, TIntermTyped *lvalueNode) 466 { 467 AccessChain lvalueChain; 468 const TVariable *lvalueBase = lvalueChain.build(lvalueNode); 469 mInfo->variableAssignmentNodeMap[lvalueBase].push_back(assignmentNode); 470 471 ObjectAndAccessChain lvalue = {lvalueBase, lvalueChain}; 472 AddObjectIfPrecise(mInfo, lvalue); 473 474 TraverseIndexNodesOnly(lvalueNode, this); 475 } 476 477 ASTInfo *mInfo = nullptr; 478 const TFunction *mCurrentFunction = nullptr; 479 }; 480 481 // A traverser that, given an access chain, traverses an expression and marks parts of it |precise|. 482 // For example, in the expression |Struct1(a, Struct2(b, c), d)|: 483 // 484 // - Given access chain [1], both |b| and |c| are marked precise. 485 // - Given access chain [1, 0], only |b| is marked precise. 
486 // 487 // When access chain is empty, arithmetic nodes are marked |precise| and any access chains found in 488 // their children is recursively added for processing. 489 // 490 // The access chain given to the traverser is derived from the left hand side of an assignment, 491 // while the traverser is run on the right hand side. 492 class PropagatePreciseTraverser : public TIntermTraverser 493 { 494 public: 495 PropagatePreciseTraverser(ASTInfo *info) : TIntermTraverser(true, false, false), mInfo(info) {} 496 497 void propagatePrecise(TIntermNode *expression, const AccessChain &accessChain) 498 { 499 mCurrentAccessChain = accessChain; 500 expression->traverse(this); 501 } 502 503 bool visitUnary(Visit visit, TIntermUnary *node) override 504 { 505 // Unary operations cannot be applied to structures. 506 ASSERT(mCurrentAccessChain.getChain().empty()); 507 508 // Mark arithmetic nodes as |precise|. 509 if (IsArithmeticOp(node->getOp())) 510 { 511 node->setIsPrecise(); 512 } 513 514 // Mark the operand itself |precise| too. 515 return true; 516 } 517 518 bool visitBinary(Visit visit, TIntermBinary *node) override 519 { 520 if (IsIndexOp(node->getOp())) 521 { 522 // Append the remaining access chain with that of the node, and mark that as |precise|. 523 // For example, if we are evaluating an expression and expecting to mark the access 524 // chain [1, 3] as |precise|, and the node itself has access chain [0, 2] applied to 525 // variable V, then what ends up being |precise| is V with access chain [0, 2, 1, 3]. 526 AccessChain nodeAccessChain; 527 const TVariable *baseVariable = nodeAccessChain.build(node); 528 nodeAccessChain.append(mCurrentAccessChain); 529 530 ObjectAndAccessChain preciseObject = {baseVariable, nodeAccessChain}; 531 AddPreciseObject(mInfo, preciseObject); 532 533 // Visit index nodes, each of which should be considered |precise| in its entirety. 
534 mCurrentAccessChain.clear(); 535 TraverseIndexNodesOnly(node, this); 536 537 return false; 538 } 539 540 if (node->getOp() == EOpComma) 541 { 542 // For expr1,expr2, consider only expr2 as that's the one whose calculation is relevant. 543 node->getRight()->traverse(this); 544 return false; 545 } 546 547 // Mark arithmetic nodes as |precise|. 548 if (IsArithmeticOp(node->getOp())) 549 { 550 node->setIsPrecise(); 551 } 552 553 if (IsAssignment(node->getOp()) || node->getOp() == EOpInitialize) 554 { 555 // If the node itself is a[...] op= expr, consider only expr as |precise|, as that's the 556 // one whose calculation is significant. 557 node->getRight()->traverse(this); 558 559 // The indices used on the left hand side are also significant in their entirety. 560 mCurrentAccessChain.clear(); 561 TraverseIndexNodesOnly(node->getLeft(), this); 562 563 return false; 564 } 565 566 // Binary operations cannot be applied to structures. 567 ASSERT(mCurrentAccessChain.getChain().empty()); 568 569 // Mark the operands themselves |precise| too. 570 return true; 571 } 572 573 void visitSymbol(TIntermSymbol *symbol) override 574 { 575 // Mark the symbol together with the current access chain as |precise|. 576 ObjectAndAccessChain preciseObject = {&symbol->variable(), mCurrentAccessChain}; 577 AddPreciseObject(mInfo, preciseObject); 578 } 579 580 bool visitAggregate(Visit visit, TIntermAggregate *node) override 581 { 582 // If this is a struct constructor and the access chain is not empty, only apply |precise| 583 // to the field selected by the access chain. 
584 const TType &type = node->getType(); 585 const bool isStructConstructor = 586 node->getOp() == EOpConstruct && type.getStruct() != nullptr && !type.isArray(); 587 588 if (!mCurrentAccessChain.getChain().empty() && isStructConstructor) 589 { 590 size_t selectedFieldIndex = mCurrentAccessChain.getChain().front(); 591 mCurrentAccessChain.pop_front(1); 592 593 ASSERT(selectedFieldIndex < node->getChildCount()); 594 595 // Visit only said field. 596 node->getChildNode(selectedFieldIndex)->traverse(this); 597 return false; 598 } 599 600 // If this is an array constructor, each element is equally |precise| with the same access 601 // chain. Otherwise there cannot be any access chain for constructors. 602 if (node->getOp() == EOpConstruct) 603 { 604 ASSERT(type.isArray() || mCurrentAccessChain.getChain().empty()); 605 return true; 606 } 607 608 // Otherwise this is a function call. The access chain is irrelevant and every (non-out) 609 // parameter of the function call should be considered |precise|. 610 mCurrentAccessChain.clear(); 611 612 const TFunction *function = node->getFunction(); 613 ASSERT(function); 614 615 for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex) 616 { 617 if (function->getParam(paramIndex)->getType().getQualifier() != EvqParamOut) 618 { 619 node->getChildNode(paramIndex)->traverse(this); 620 } 621 } 622 623 // Mark arithmetic nodes as |precise|. 624 if (IsArithmeticOp(node->getOp())) 625 { 626 node->setIsPrecise(); 627 } 628 629 return false; 630 } 631 632 private: 633 ASTInfo *mInfo = nullptr; 634 AccessChain mCurrentAccessChain; 635 }; 636 } // anonymous namespace 637 638 void FindPreciseNodes(TCompiler *compiler, TIntermBlock *root) 639 { 640 ASTInfo info; 641 642 InfoGatherTraverser infoGather(&info); 643 root->traverse(&infoGather); 644 645 PropagatePreciseTraverser propagator(&info); 646 647 // First, get return expressions out of the way by propagating |precise|. 
648 for (TIntermBranch *returnNode : info.preciseReturnNodes) 649 { 650 ASSERT(returnNode->getChildCount() == 1); 651 propagator.propagatePrecise(returnNode->getChildNode(0), {}); 652 } 653 654 // Now take |precise| access chains one by one, and propagate their |precise|-ness to the right 655 // hand side of all assignments in which they are on the left hand side, as well as the 656 // arithmetic expression that assigns to them. 657 658 while (!info.preciseObjectsToProcess.empty()) 659 { 660 // Get one |precise| object to process. 661 auto first = info.preciseObjectsToProcess.begin(); 662 const ObjectAndAccessChain toProcess = *first; 663 info.preciseObjectsToProcess.erase(first); 664 665 // Propagate |precise| to every node where it's assigned to. 666 const TVector<TIntermOperator *> &assignmentNodes = 667 info.variableAssignmentNodeMap[toProcess.variable]; 668 for (TIntermOperator *assignmentNode : assignmentNodes) 669 { 670 AccessChain assignmentAccessChain = GetAssignmentAccessChain(assignmentNode); 671 672 // There are two possibilities: 673 // 674 // - The assignment is to a bigger access chain than that which is being processed, in 675 // which case the entire right hand side is marked |precise|, 676 // - The assignment is to a smaller access chain, in which case only the subobject of 677 // the right hand side that corresponds to the remaining part of the access chain must 678 // be marked |precise|. 679 // 680 // For example, if processing |a.b.c| as a |precise| access chain: 681 // 682 // - If the assignment is to |a.b.c.d|, then the entire right hand side must be 683 // |precise|. 684 // - If the assignment is to |a.b|, only the |.c| part of the right hand side expression 685 // must be |precise|. 686 // - If the assignment is to |a.e|, there is nothing to do. 
687 // 688 AccessChain remainingAccessChain = toProcess.accessChain; 689 if (!remainingAccessChain.removePrefix(assignmentAccessChain)) 690 { 691 continue; 692 } 693 694 propagator.propagatePrecise(assignmentNode, remainingAccessChain); 695 } 696 } 697 698 // The AST nodes now contain information gathered by this post-processing step, and so the tree 699 // must no longer be transformed. 700 compiler->enableValidateNoMoreTransformations(); 701 } 702 703 } // namespace sh