SimplifyLoopConditions.cpp (18459B)
1 // 2 // Copyright 2016 The ANGLE Project Authors. All rights reserved. 3 // Use of this source code is governed by a BSD-style license that can be 4 // found in the LICENSE file. 5 // 6 // SimplifyLoopConditions is an AST traverser that converts loop conditions and loop expressions 7 // to regular statements inside the loop. This way further transformations that generate statements 8 // from loop conditions and loop expressions work correctly. 9 // 10 11 #include "compiler/translator/tree_ops/SimplifyLoopConditions.h" 12 13 #include "compiler/translator/StaticType.h" 14 #include "compiler/translator/tree_util/IntermNodePatternMatcher.h" 15 #include "compiler/translator/tree_util/IntermNode_util.h" 16 #include "compiler/translator/tree_util/IntermTraverse.h" 17 18 namespace sh 19 { 20 21 namespace 22 { 23 24 struct LoopInfo 25 { 26 const TVariable *conditionVariable = nullptr; 27 TIntermTyped *condition = nullptr; 28 TIntermTyped *expression = nullptr; 29 }; 30 31 class SimplifyLoopConditionsTraverser : public TLValueTrackingTraverser 32 { 33 public: 34 SimplifyLoopConditionsTraverser(const IntermNodePatternMatcher *conditionsToSimplify, 35 TSymbolTable *symbolTable); 36 37 void traverseLoop(TIntermLoop *node) override; 38 39 bool visitUnary(Visit visit, TIntermUnary *node) override; 40 bool visitBinary(Visit visit, TIntermBinary *node) override; 41 bool visitAggregate(Visit visit, TIntermAggregate *node) override; 42 bool visitTernary(Visit visit, TIntermTernary *node) override; 43 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override; 44 bool visitBranch(Visit visit, TIntermBranch *node) override; 45 46 bool foundLoopToChange() const { return mFoundLoopToChange; } 47 48 protected: 49 // Marked to true once an operation that needs to be hoisted out of a loop expression has been 50 // found. 51 bool mFoundLoopToChange; 52 bool mInsideLoopInitConditionOrExpression; 53 const IntermNodePatternMatcher *mConditionsToSimplify; 54 55 private: 56 LoopInfo mLoop; 57 }; 58 59 SimplifyLoopConditionsTraverser::SimplifyLoopConditionsTraverser( 60 const IntermNodePatternMatcher *conditionsToSimplify, 61 TSymbolTable *symbolTable) 62 : TLValueTrackingTraverser(true, false, false, symbolTable), 63 mFoundLoopToChange(false), 64 mInsideLoopInitConditionOrExpression(false), 65 mConditionsToSimplify(conditionsToSimplify) 66 {} 67 68 // If we're inside a loop initialization, condition, or expression, we check for expressions that 69 // should be moved out of the loop condition or expression. If one is found, the loop is 70 // transformed. 71 // If we're not inside loop initialization, condition, or expression, we only need to traverse nodes 72 // that may contain loops. 73 74 bool SimplifyLoopConditionsTraverser::visitUnary(Visit visit, TIntermUnary *node) 75 { 76 if (!mInsideLoopInitConditionOrExpression) 77 return false; 78 79 if (mFoundLoopToChange) 80 return false; // Already decided to change this loop. 81 82 ASSERT(mConditionsToSimplify); 83 mFoundLoopToChange = mConditionsToSimplify->match(node); 84 return !mFoundLoopToChange; 85 } 86 87 bool SimplifyLoopConditionsTraverser::visitBinary(Visit visit, TIntermBinary *node) 88 { 89 if (!mInsideLoopInitConditionOrExpression) 90 return false; 91 92 if (mFoundLoopToChange) 93 return false; // Already decided to change this loop. 94 95 ASSERT(mConditionsToSimplify); 96 mFoundLoopToChange = 97 mConditionsToSimplify->match(node, getParentNode(), isLValueRequiredHere()); 98 return !mFoundLoopToChange; 99 } 100 101 bool SimplifyLoopConditionsTraverser::visitAggregate(Visit visit, TIntermAggregate *node) 102 { 103 if (!mInsideLoopInitConditionOrExpression) 104 return false; 105 106 if (mFoundLoopToChange) 107 return false; // Already decided to change this loop. 108 109 ASSERT(mConditionsToSimplify); 110 mFoundLoopToChange = mConditionsToSimplify->match(node, getParentNode()); 111 return !mFoundLoopToChange; 112 } 113 114 bool SimplifyLoopConditionsTraverser::visitTernary(Visit visit, TIntermTernary *node) 115 { 116 if (!mInsideLoopInitConditionOrExpression) 117 return false; 118 119 if (mFoundLoopToChange) 120 return false; // Already decided to change this loop. 121 122 ASSERT(mConditionsToSimplify); 123 mFoundLoopToChange = mConditionsToSimplify->match(node); 124 return !mFoundLoopToChange; 125 } 126 127 bool SimplifyLoopConditionsTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node) 128 { 129 if (!mInsideLoopInitConditionOrExpression) 130 return false; 131 132 if (mFoundLoopToChange) 133 return false; // Already decided to change this loop. 134 135 ASSERT(mConditionsToSimplify); 136 mFoundLoopToChange = mConditionsToSimplify->match(node); 137 return !mFoundLoopToChange; 138 } 139 140 bool SimplifyLoopConditionsTraverser::visitBranch(Visit visit, TIntermBranch *node) 141 { 142 if (node->getFlowOp() == EOpContinue && (mLoop.condition || mLoop.expression)) 143 { 144 TIntermBlock *parent = getParentNode()->getAsBlock(); 145 ASSERT(parent); 146 TIntermSequence seq; 147 if (mLoop.expression) 148 { 149 seq.push_back(mLoop.expression->deepCopy()); 150 } 151 if (mLoop.condition) 152 { 153 ASSERT(mLoop.conditionVariable); 154 seq.push_back( 155 CreateTempAssignmentNode(mLoop.conditionVariable, mLoop.condition->deepCopy())); 156 } 157 seq.push_back(node); 158 mMultiReplacements.push_back(NodeReplaceWithMultipleEntry(parent, node, std::move(seq))); 159 } 160 161 return true; 162 } 163 164 TIntermBlock *CreateFromBody(TIntermLoop *node, bool *bodyEndsInBranchOut) 165 { 166 TIntermBlock *newBody = new TIntermBlock(); 167 *bodyEndsInBranchOut = false; 168 169 TIntermBlock *nodeBody = node->getBody(); 170 if (nodeBody != nullptr) 171 { 172 newBody->getSequence()->push_back(nodeBody); 173 *bodyEndsInBranchOut = EndsInBranch(nodeBody); 174 } 175 return newBody; 176 } 177 178 void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node) 179 { 180 // Mark that we're inside a loop condition or expression, and determine if the loop needs to be 181 // transformed. 182 183 ScopedNodeInTraversalPath addToPath(this, node); 184 185 mInsideLoopInitConditionOrExpression = true; 186 mFoundLoopToChange = !mConditionsToSimplify; 187 188 if (!mFoundLoopToChange && node->getInit()) 189 { 190 node->getInit()->traverse(this); 191 } 192 193 if (!mFoundLoopToChange && node->getCondition()) 194 { 195 node->getCondition()->traverse(this); 196 } 197 198 if (!mFoundLoopToChange && node->getExpression()) 199 { 200 node->getExpression()->traverse(this); 201 } 202 203 mInsideLoopInitConditionOrExpression = false; 204 205 const LoopInfo prevLoop = mLoop; 206 207 if (mFoundLoopToChange) 208 { 209 const TType *boolType = StaticType::Get<EbtBool, EbpUndefined, EvqTemporary, 1, 1>(); 210 mLoop.conditionVariable = CreateTempVariable(mSymbolTable, boolType); 211 mLoop.condition = node->getCondition(); 212 mLoop.expression = node->getExpression(); 213 214 // Replace the loop condition with a boolean variable that's updated on each iteration. 215 TLoopType loopType = node->getType(); 216 if (loopType == ELoopWhile) 217 { 218 ASSERT(!mLoop.expression); 219 220 if (mLoop.condition->getAsSymbolNode()) 221 { 222 // Mask continue statement condition variable update. 223 mLoop.condition = nullptr; 224 } 225 else if (mLoop.condition->getAsConstantUnion()) 226 { 227 // Transform: 228 // while (expr) { body; } 229 // into 230 // bool s0 = expr; 231 // while (s0) { body; } 232 TIntermDeclaration *tempInitDeclaration = 233 CreateTempInitDeclarationNode(mLoop.conditionVariable, mLoop.condition); 234 insertStatementInParentBlock(tempInitDeclaration); 235 236 node->setCondition(CreateTempSymbolNode(mLoop.conditionVariable)); 237 238 // Mask continue statement condition variable update. 239 mLoop.condition = nullptr; 240 } 241 else 242 { 243 // Transform: 244 // while (expr) { body; } 245 // into 246 // bool s0 = expr; 247 // while (s0) { { body; } s0 = expr; } 248 // 249 // Local case statements are transformed into: 250 // s0 = expr; continue; 251 TIntermDeclaration *tempInitDeclaration = 252 CreateTempInitDeclarationNode(mLoop.conditionVariable, mLoop.condition); 253 insertStatementInParentBlock(tempInitDeclaration); 254 255 bool bodyEndsInBranch; 256 TIntermBlock *newBody = CreateFromBody(node, &bodyEndsInBranch); 257 if (!bodyEndsInBranch) 258 { 259 newBody->getSequence()->push_back(CreateTempAssignmentNode( 260 mLoop.conditionVariable, mLoop.condition->deepCopy())); 261 } 262 263 // Can't use queueReplacement to replace old body, since it may have been nullptr. 264 // It's safe to do the replacements in place here - the new body will still be 265 // traversed, but that won't create any problems. 266 node->setBody(newBody); 267 node->setCondition(CreateTempSymbolNode(mLoop.conditionVariable)); 268 } 269 } 270 else if (loopType == ELoopDoWhile) 271 { 272 ASSERT(!mLoop.expression); 273 274 if (mLoop.condition->getAsSymbolNode()) 275 { 276 // Mask continue statement condition variable update. 277 mLoop.condition = nullptr; 278 } 279 else if (mLoop.condition->getAsConstantUnion()) 280 { 281 // Transform: 282 // do { 283 // body; 284 // } while (expr); 285 // into 286 // bool s0 = expr; 287 // do { 288 // body; 289 // } while (s0); 290 TIntermDeclaration *tempInitDeclaration = 291 CreateTempInitDeclarationNode(mLoop.conditionVariable, mLoop.condition); 292 insertStatementInParentBlock(tempInitDeclaration); 293 294 node->setCondition(CreateTempSymbolNode(mLoop.conditionVariable)); 295 296 // Mask continue statement condition variable update. 297 mLoop.condition = nullptr; 298 } 299 else 300 { 301 // Transform: 302 // do { 303 // body; 304 // } while (expr); 305 // into 306 // bool s0; 307 // do { 308 // { body; } 309 // s0 = expr; 310 // } while (s0); 311 // Local case statements are transformed into: 312 // s0 = expr; continue; 313 TIntermDeclaration *tempInitDeclaration = 314 CreateTempDeclarationNode(mLoop.conditionVariable); 315 insertStatementInParentBlock(tempInitDeclaration); 316 317 bool bodyEndsInBranch; 318 TIntermBlock *newBody = CreateFromBody(node, &bodyEndsInBranch); 319 if (!bodyEndsInBranch) 320 { 321 newBody->getSequence()->push_back( 322 CreateTempAssignmentNode(mLoop.conditionVariable, mLoop.condition)); 323 } 324 325 // Can't use queueReplacement to replace old body, since it may have been nullptr. 326 // It's safe to do the replacements in place here - the new body will still be 327 // traversed, but that won't create any problems. 328 node->setBody(newBody); 329 node->setCondition(CreateTempSymbolNode(mLoop.conditionVariable)); 330 } 331 } 332 else if (loopType == ELoopFor) 333 { 334 if (!mLoop.condition) 335 { 336 mLoop.condition = CreateBoolNode(true); 337 } 338 339 TIntermLoop *whileLoop; 340 TIntermBlock *loopScope = new TIntermBlock(); 341 TIntermSequence *loopScopeSequence = loopScope->getSequence(); 342 343 // Insert "init;" 344 if (node->getInit()) 345 { 346 loopScopeSequence->push_back(node->getInit()); 347 } 348 349 if (mLoop.condition->getAsSymbolNode()) 350 { 351 // Move the loop condition inside the loop. 352 // Transform: 353 // for (init; expr; exprB) { body; } 354 // into 355 // { 356 // init; 357 // while (expr) { 358 // { body; } 359 // exprB; 360 // } 361 // } 362 // 363 // Local case statements are transformed into: 364 // exprB; continue; 365 366 // Insert "{ body; }" in the while loop 367 bool bodyEndsInBranch; 368 TIntermBlock *whileLoopBody = CreateFromBody(node, &bodyEndsInBranch); 369 // Insert "exprB;" in the while loop 370 if (!bodyEndsInBranch && node->getExpression()) 371 { 372 whileLoopBody->getSequence()->push_back(node->getExpression()); 373 } 374 // Create "while(expr) { whileLoopBody }" 375 whileLoop = 376 new TIntermLoop(ELoopWhile, nullptr, mLoop.condition, nullptr, whileLoopBody); 377 378 // Mask continue statement condition variable update. 379 mLoop.condition = nullptr; 380 } 381 else if (mLoop.condition->getAsConstantUnion()) 382 { 383 // Move the loop condition inside the loop. 384 // Transform: 385 // for (init; expr; exprB) { body; } 386 // into 387 // { 388 // init; 389 // bool s0 = expr; 390 // while (s0) { 391 // { body; } 392 // exprB; 393 // } 394 // } 395 // 396 // Local case statements are transformed into: 397 // exprB; continue; 398 399 // Insert "bool s0 = expr;" 400 loopScopeSequence->push_back( 401 CreateTempInitDeclarationNode(mLoop.conditionVariable, mLoop.condition)); 402 // Insert "{ body; }" in the while loop 403 bool bodyEndsInBranch; 404 TIntermBlock *whileLoopBody = CreateFromBody(node, &bodyEndsInBranch); 405 // Insert "exprB;" in the while loop 406 if (!bodyEndsInBranch && node->getExpression()) 407 { 408 whileLoopBody->getSequence()->push_back(node->getExpression()); 409 } 410 // Create "while(s0) { whileLoopBody }" 411 whileLoop = new TIntermLoop(ELoopWhile, nullptr, 412 CreateTempSymbolNode(mLoop.conditionVariable), nullptr, 413 whileLoopBody); 414 415 // Mask continue statement condition variable update. 416 mLoop.condition = nullptr; 417 } 418 else 419 { 420 // Move the loop condition inside the loop. 421 // Transform: 422 // for (init; expr; exprB) { body; } 423 // into 424 // { 425 // init; 426 // bool s0 = expr; 427 // while (s0) { 428 // { body; } 429 // exprB; 430 // s0 = expr; 431 // } 432 // } 433 // 434 // Local case statements are transformed into: 435 // exprB; s0 = expr; continue; 436 437 // Insert "bool s0 = expr;" 438 loopScopeSequence->push_back( 439 CreateTempInitDeclarationNode(mLoop.conditionVariable, mLoop.condition)); 440 // Insert "{ body; }" in the while loop 441 bool bodyEndsInBranch; 442 TIntermBlock *whileLoopBody = CreateFromBody(node, &bodyEndsInBranch); 443 // Insert "exprB;" in the while loop 444 if (!bodyEndsInBranch && node->getExpression()) 445 { 446 whileLoopBody->getSequence()->push_back(node->getExpression()); 447 } 448 // Insert "s0 = expr;" in the while loop 449 if (!bodyEndsInBranch) 450 { 451 whileLoopBody->getSequence()->push_back(CreateTempAssignmentNode( 452 mLoop.conditionVariable, mLoop.condition->deepCopy())); 453 } 454 // Create "while(s0) { whileLoopBody }" 455 whileLoop = new TIntermLoop(ELoopWhile, nullptr, 456 CreateTempSymbolNode(mLoop.conditionVariable), nullptr, 457 whileLoopBody); 458 } 459 460 loopScope->getSequence()->push_back(whileLoop); 461 queueReplacement(loopScope, OriginalNode::IS_DROPPED); 462 463 // After this the old body node will be traversed and loops inside it may be 464 // transformed. This is fine, since the old body node will still be in the AST after 465 // the transformation that's queued here, and transforming loops inside it doesn't 466 // need to know the exact post-transform path to it. 467 } 468 } 469 470 mFoundLoopToChange = false; 471 472 // We traverse the body of the loop even if the loop is transformed. 473 if (node->getBody()) 474 node->getBody()->traverse(this); 475 476 mLoop = prevLoop; 477 } 478 479 } // namespace 480 481 bool SimplifyLoopConditions(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable) 482 { 483 SimplifyLoopConditionsTraverser traverser(nullptr, symbolTable); 484 root->traverse(&traverser); 485 return traverser.updateTree(compiler, root); 486 } 487 488 bool SimplifyLoopConditions(TCompiler *compiler, 489 TIntermNode *root, 490 unsigned int conditionsToSimplifyMask, 491 TSymbolTable *symbolTable) 492 { 493 IntermNodePatternMatcher conditionsToSimplify(conditionsToSimplifyMask); 494 SimplifyLoopConditionsTraverser traverser(&conditionsToSimplify, symbolTable); 495 root->traverse(&traverser); 496 return traverser.updateTree(compiler, root); 497 } 498 499 } // namespace sh