ASTMetadataHLSL.cpp (14788B)
1 // 2 // Copyright 2002 The ANGLE Project Authors. All rights reserved. 3 // Use of this source code is governed by a BSD-style license that can be 4 // found in the LICENSE file. 5 // 6 7 // Analysis of the AST needed for HLSL generation 8 9 #include "compiler/translator/ASTMetadataHLSL.h" 10 11 #include "compiler/translator/CallDAG.h" 12 #include "compiler/translator/SymbolTable.h" 13 #include "compiler/translator/tree_util/IntermTraverse.h" 14 15 namespace sh 16 { 17 18 namespace 19 { 20 21 // Class used to traverse the AST of a function definition, checking if the 22 // function uses a gradient, and writing the set of control flow using gradients. 23 // It assumes that the analysis has already been made for the function's 24 // callees. 25 class PullGradient : public TIntermTraverser 26 { 27 public: 28 PullGradient(MetadataList *metadataList, size_t index, const CallDAG &dag) 29 : TIntermTraverser(true, false, true), 30 mMetadataList(metadataList), 31 mMetadata(&(*metadataList)[index]), 32 mIndex(index), 33 mDag(dag) 34 { 35 ASSERT(index < metadataList->size()); 36 37 // ESSL 100 builtin gradient functions 38 mGradientBuiltinFunctions.insert(ImmutableString("texture2D")); 39 mGradientBuiltinFunctions.insert(ImmutableString("texture2DProj")); 40 mGradientBuiltinFunctions.insert(ImmutableString("textureCube")); 41 42 // ESSL 300 builtin gradient functions 43 mGradientBuiltinFunctions.insert(ImmutableString("dFdx")); 44 mGradientBuiltinFunctions.insert(ImmutableString("dFdy")); 45 mGradientBuiltinFunctions.insert(ImmutableString("fwidth")); 46 mGradientBuiltinFunctions.insert(ImmutableString("texture")); 47 mGradientBuiltinFunctions.insert(ImmutableString("textureProj")); 48 mGradientBuiltinFunctions.insert(ImmutableString("textureOffset")); 49 mGradientBuiltinFunctions.insert(ImmutableString("textureProjOffset")); 50 51 // ESSL 310 doesn't add builtin gradient functions 52 } 53 54 void traverse(TIntermFunctionDefinition *node) 55 { 56 node->traverse(this); 57 ASSERT(mParents.empty()); 58 } 59 60 // Called when a gradient operation or a call to a function using a gradient is found. 61 void onGradient() 62 { 63 mMetadata->mUsesGradient = true; 64 // Mark the latest control flow as using a gradient. 65 if (!mParents.empty()) 66 { 67 mMetadata->mControlFlowsContainingGradient.insert(mParents.back()); 68 } 69 } 70 71 void visitControlFlow(Visit visit, TIntermNode *node) 72 { 73 if (visit == PreVisit) 74 { 75 mParents.push_back(node); 76 } 77 else if (visit == PostVisit) 78 { 79 ASSERT(mParents.back() == node); 80 mParents.pop_back(); 81 // A control flow's using a gradient means its parents are too. 82 if (mMetadata->mControlFlowsContainingGradient.count(node) > 0 && !mParents.empty()) 83 { 84 mMetadata->mControlFlowsContainingGradient.insert(mParents.back()); 85 } 86 } 87 } 88 89 bool visitLoop(Visit visit, TIntermLoop *loop) override 90 { 91 visitControlFlow(visit, loop); 92 return true; 93 } 94 95 bool visitIfElse(Visit visit, TIntermIfElse *ifElse) override 96 { 97 visitControlFlow(visit, ifElse); 98 return true; 99 } 100 101 bool visitAggregate(Visit visit, TIntermAggregate *node) override 102 { 103 if (visit == PreVisit) 104 { 105 if (node->getOp() == EOpCallFunctionInAST) 106 { 107 size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId()); 108 ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex); 109 110 if ((*mMetadataList)[calleeIndex].mUsesGradient) 111 { 112 onGradient(); 113 } 114 } 115 else if (BuiltInGroup::IsBuiltIn(node->getOp()) && !BuiltInGroup::IsMath(node->getOp())) 116 { 117 if (mGradientBuiltinFunctions.find(node->getFunction()->name()) != 118 mGradientBuiltinFunctions.end()) 119 { 120 onGradient(); 121 } 122 } 123 } 124 125 return true; 126 } 127 128 private: 129 MetadataList *mMetadataList; 130 ASTMetadataHLSL *mMetadata; 131 size_t mIndex; 132 const CallDAG &mDag; 133 134 // Contains a stack of the control flow nodes that are parents of the node being 135 // currently visited. It is used to mark control flows using a gradient. 136 std::vector<TIntermNode *> mParents; 137 138 // A list of builtin functions that use gradients 139 std::set<ImmutableString> mGradientBuiltinFunctions; 140 }; 141 142 // Traverses the AST of a function definition to compute the the discontinuous loops 143 // and the if statements containing gradient loops. It assumes that the gradient loops 144 // (loops that contain a gradient) have already been computed and that it has already 145 // traversed the current function's callees. 146 class PullComputeDiscontinuousAndGradientLoops : public TIntermTraverser 147 { 148 public: 149 PullComputeDiscontinuousAndGradientLoops(MetadataList *metadataList, 150 size_t index, 151 const CallDAG &dag) 152 : TIntermTraverser(true, false, true), 153 mMetadataList(metadataList), 154 mMetadata(&(*metadataList)[index]), 155 mIndex(index), 156 mDag(dag) 157 {} 158 159 void traverse(TIntermFunctionDefinition *node) 160 { 161 node->traverse(this); 162 ASSERT(mLoopsAndSwitches.empty()); 163 ASSERT(mIfs.empty()); 164 } 165 166 // Called when traversing a gradient loop or a call to a function with a 167 // gradient loop in its call graph. 168 void onGradientLoop() 169 { 170 mMetadata->mHasGradientLoopInCallGraph = true; 171 // Mark the latest if as using a discontinuous loop. 172 if (!mIfs.empty()) 173 { 174 mMetadata->mIfsContainingGradientLoop.insert(mIfs.back()); 175 } 176 } 177 178 bool visitLoop(Visit visit, TIntermLoop *loop) override 179 { 180 if (visit == PreVisit) 181 { 182 mLoopsAndSwitches.push_back(loop); 183 184 if (mMetadata->hasGradientInCallGraph(loop)) 185 { 186 onGradientLoop(); 187 } 188 } 189 else if (visit == PostVisit) 190 { 191 ASSERT(mLoopsAndSwitches.back() == loop); 192 mLoopsAndSwitches.pop_back(); 193 } 194 195 return true; 196 } 197 198 bool visitIfElse(Visit visit, TIntermIfElse *node) override 199 { 200 if (visit == PreVisit) 201 { 202 mIfs.push_back(node); 203 } 204 else if (visit == PostVisit) 205 { 206 ASSERT(mIfs.back() == node); 207 mIfs.pop_back(); 208 // An if using a discontinuous loop means its parents ifs are also discontinuous. 209 if (mMetadata->mIfsContainingGradientLoop.count(node) > 0 && !mIfs.empty()) 210 { 211 mMetadata->mIfsContainingGradientLoop.insert(mIfs.back()); 212 } 213 } 214 215 return true; 216 } 217 218 bool visitBranch(Visit visit, TIntermBranch *node) override 219 { 220 if (visit == PreVisit) 221 { 222 switch (node->getFlowOp()) 223 { 224 case EOpBreak: 225 { 226 ASSERT(!mLoopsAndSwitches.empty()); 227 TIntermLoop *loop = mLoopsAndSwitches.back()->getAsLoopNode(); 228 if (loop != nullptr) 229 { 230 mMetadata->mDiscontinuousLoops.insert(loop); 231 } 232 } 233 break; 234 case EOpContinue: 235 { 236 ASSERT(!mLoopsAndSwitches.empty()); 237 TIntermLoop *loop = nullptr; 238 size_t i = mLoopsAndSwitches.size(); 239 while (loop == nullptr && i > 0) 240 { 241 --i; 242 loop = mLoopsAndSwitches.at(i)->getAsLoopNode(); 243 } 244 ASSERT(loop != nullptr); 245 mMetadata->mDiscontinuousLoops.insert(loop); 246 } 247 break; 248 case EOpKill: 249 case EOpReturn: 250 // A return or discard jumps out of all the enclosing loops 251 if (!mLoopsAndSwitches.empty()) 252 { 253 for (TIntermNode *intermNode : mLoopsAndSwitches) 254 { 255 TIntermLoop *loop = intermNode->getAsLoopNode(); 256 if (loop) 257 { 258 mMetadata->mDiscontinuousLoops.insert(loop); 259 } 260 } 261 } 262 break; 263 default: 264 UNREACHABLE(); 265 } 266 } 267 268 return true; 269 } 270 271 bool visitAggregate(Visit visit, TIntermAggregate *node) override 272 { 273 if (visit == PreVisit && node->getOp() == EOpCallFunctionInAST) 274 { 275 size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId()); 276 ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex); 277 278 if ((*mMetadataList)[calleeIndex].mHasGradientLoopInCallGraph) 279 { 280 onGradientLoop(); 281 } 282 } 283 284 return true; 285 } 286 287 bool visitSwitch(Visit visit, TIntermSwitch *node) override 288 { 289 if (visit == PreVisit) 290 { 291 mLoopsAndSwitches.push_back(node); 292 } 293 else if (visit == PostVisit) 294 { 295 ASSERT(mLoopsAndSwitches.back() == node); 296 mLoopsAndSwitches.pop_back(); 297 } 298 return true; 299 } 300 301 private: 302 MetadataList *mMetadataList; 303 ASTMetadataHLSL *mMetadata; 304 size_t mIndex; 305 const CallDAG &mDag; 306 307 std::vector<TIntermNode *> mLoopsAndSwitches; 308 std::vector<TIntermIfElse *> mIfs; 309 }; 310 311 // Tags all the functions called in a discontinuous loop 312 class PushDiscontinuousLoops : public TIntermTraverser 313 { 314 public: 315 PushDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag) 316 : TIntermTraverser(true, true, true), 317 mMetadataList(metadataList), 318 mMetadata(&(*metadataList)[index]), 319 mIndex(index), 320 mDag(dag), 321 mNestedDiscont(mMetadata->mCalledInDiscontinuousLoop ? 1 : 0) 322 {} 323 324 void traverse(TIntermFunctionDefinition *node) 325 { 326 node->traverse(this); 327 ASSERT(mNestedDiscont == (mMetadata->mCalledInDiscontinuousLoop ? 1 : 0)); 328 } 329 330 bool visitLoop(Visit visit, TIntermLoop *loop) override 331 { 332 bool isDiscontinuous = mMetadata->mDiscontinuousLoops.count(loop) > 0; 333 334 if (visit == PreVisit && isDiscontinuous) 335 { 336 mNestedDiscont++; 337 } 338 else if (visit == PostVisit && isDiscontinuous) 339 { 340 mNestedDiscont--; 341 } 342 343 return true; 344 } 345 346 bool visitAggregate(Visit visit, TIntermAggregate *node) override 347 { 348 switch (node->getOp()) 349 { 350 case EOpCallFunctionInAST: 351 if (visit == PreVisit && mNestedDiscont > 0) 352 { 353 size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId()); 354 ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex); 355 356 (*mMetadataList)[calleeIndex].mCalledInDiscontinuousLoop = true; 357 } 358 break; 359 default: 360 break; 361 } 362 return true; 363 } 364 365 private: 366 MetadataList *mMetadataList; 367 ASTMetadataHLSL *mMetadata; 368 size_t mIndex; 369 const CallDAG &mDag; 370 371 int mNestedDiscont; 372 }; 373 } // namespace 374 375 bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermLoop *node) 376 { 377 return mControlFlowsContainingGradient.count(node) > 0; 378 } 379 380 bool ASTMetadataHLSL::hasGradientLoop(TIntermIfElse *node) 381 { 382 return mIfsContainingGradientLoop.count(node) > 0; 383 } 384 385 MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag) 386 { 387 MetadataList metadataList(callDag.size()); 388 389 // Compute all the information related to when gradient operations are used. 390 // We want to know for each function and control flow operation if they have 391 // a gradient operation in their call graph (shortened to "using a gradient" 392 // in the rest of the file). 393 // 394 // This computation is logically split in three steps: 395 // 1 - For each function compute if it uses a gradient in its body, ignoring 396 // calls to other user-defined functions. 397 // 2 - For each function determine if it uses a gradient in its call graph, 398 // using the result of step 1 and the CallDAG to know its callees. 399 // 3 - For each control flow statement of each function, check if it uses a 400 // gradient in the function's body, or if it calls a user-defined function that 401 // uses a gradient. 402 // 403 // We take advantage of the call graph being a DAG and instead compute 1, 2 and 3 404 // for leaves first, then going down the tree. This is correct because 1 doesn't 405 // depend on other functions, and 2 and 3 depend only on callees. 406 for (size_t i = 0; i < callDag.size(); i++) 407 { 408 PullGradient pull(&metadataList, i, callDag); 409 pull.traverse(callDag.getRecordFromIndex(i).node); 410 } 411 412 // Compute which loops are discontinuous and which function are called in 413 // these loops. The same way computing gradient usage is a "pull" process, 414 // computing "bing used in a discont. loop" is a push process. However we also 415 // need to know what ifs have a discontinuous loop inside so we do the same type 416 // of callgraph analysis as for the gradient. 417 418 // First compute which loops are discontinuous (no specific order) and pull 419 // the ifs and functions using a gradient loop. 420 for (size_t i = 0; i < callDag.size(); i++) 421 { 422 PullComputeDiscontinuousAndGradientLoops pull(&metadataList, i, callDag); 423 pull.traverse(callDag.getRecordFromIndex(i).node); 424 } 425 426 // Then push the information to callees, either from the a local discontinuous 427 // loop or from the caller being called in a discontinuous loop already 428 for (size_t i = callDag.size(); i-- > 0;) 429 { 430 PushDiscontinuousLoops push(&metadataList, i, callDag); 431 push.traverse(callDag.getRecordFromIndex(i).node); 432 } 433 434 // We create "Lod0" version of functions with the gradient operations replaced 435 // by non-gradient operations so that the D3D compiler is happier with discont 436 // loops. 437 for (auto &metadata : metadataList) 438 { 439 metadata.mNeedsLod0 = metadata.mCalledInDiscontinuousLoop && metadata.mUsesGradient; 440 } 441 442 return metadataList; 443 } 444 445 } // namespace sh