ExpandIntegerPowExpressions.cpp (4087B)
1 // 2 // Copyright 2016 The ANGLE Project Authors. All rights reserved. 3 // Use of this source code is governed by a BSD-style license that can be 4 // found in the LICENSE file. 5 // 6 // Implementation of the integer pow expressions HLSL bug workaround. 7 // See header for more info. 8 9 #include "compiler/translator/tree_ops/d3d/ExpandIntegerPowExpressions.h" 10 11 #include <cmath> 12 #include <cstdlib> 13 14 #include "compiler/translator/tree_util/IntermNode_util.h" 15 #include "compiler/translator/tree_util/IntermTraverse.h" 16 17 namespace sh 18 { 19 20 namespace 21 { 22 23 class Traverser : public TIntermTraverser 24 { 25 public: 26 [[nodiscard]] static bool Apply(TCompiler *compiler, 27 TIntermNode *root, 28 TSymbolTable *symbolTable); 29 30 private: 31 Traverser(TSymbolTable *symbolTable); 32 bool visitAggregate(Visit visit, TIntermAggregate *node) override; 33 void nextIteration(); 34 35 bool mFound = false; 36 }; 37 38 // static 39 bool Traverser::Apply(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable) 40 { 41 Traverser traverser(symbolTable); 42 do 43 { 44 traverser.nextIteration(); 45 root->traverse(&traverser); 46 if (traverser.mFound) 47 { 48 if (!traverser.updateTree(compiler, root)) 49 { 50 return false; 51 } 52 } 53 } while (traverser.mFound); 54 55 return true; 56 } 57 58 Traverser::Traverser(TSymbolTable *symbolTable) : TIntermTraverser(true, false, false, symbolTable) 59 {} 60 61 void Traverser::nextIteration() 62 { 63 mFound = false; 64 } 65 66 bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node) 67 { 68 if (mFound) 69 { 70 return false; 71 } 72 73 // Test 0: skip non-pow operators. 74 if (node->getOp() != EOpPow) 75 { 76 return true; 77 } 78 79 const TIntermSequence *sequence = node->getSequence(); 80 ASSERT(sequence->size() == 2u); 81 const TIntermConstantUnion *constantExponent = sequence->at(1)->getAsConstantUnion(); 82 83 // Test 1: check for a single constant. 84 if (!constantExponent || constantExponent->getNominalSize() != 1) 85 { 86 return true; 87 } 88 89 float exponentValue = constantExponent->getConstantValue()->getFConst(); 90 91 // Test 2: exponentValue is in the problematic range. 92 if (exponentValue < -5.0f || exponentValue > 9.0f) 93 { 94 return true; 95 } 96 97 // Test 3: exponentValue is integer or pretty close to an integer. 98 if (std::abs(exponentValue - std::round(exponentValue)) > 0.0001f) 99 { 100 return true; 101 } 102 103 // Test 4: skip -1, 0, and 1 104 int exponent = static_cast<int>(std::round(exponentValue)); 105 int n = std::abs(exponent); 106 if (n < 2) 107 { 108 return true; 109 } 110 111 // Potential problem case detected, apply workaround. 112 113 TIntermTyped *lhs = sequence->at(0)->getAsTyped(); 114 ASSERT(lhs); 115 116 TIntermDeclaration *lhsVariableDeclaration = nullptr; 117 TVariable *lhsVariable = 118 DeclareTempVariable(mSymbolTable, lhs, EvqTemporary, &lhsVariableDeclaration); 119 insertStatementInParentBlock(lhsVariableDeclaration); 120 121 // Create a chain of n-1 multiples. 122 TIntermTyped *current = CreateTempSymbolNode(lhsVariable); 123 for (int i = 1; i < n; ++i) 124 { 125 TIntermBinary *mul = new TIntermBinary(EOpMul, current, CreateTempSymbolNode(lhsVariable)); 126 mul->setLine(node->getLine()); 127 current = mul; 128 } 129 130 // For negative pow, compute the reciprocal of the positive pow. 131 if (exponent < 0) 132 { 133 TConstantUnion *oneVal = new TConstantUnion(); 134 oneVal->setFConst(1.0f); 135 TIntermConstantUnion *oneNode = new TIntermConstantUnion(oneVal, node->getType()); 136 TIntermBinary *div = new TIntermBinary(EOpDiv, oneNode, current); 137 current = div; 138 } 139 140 queueReplacement(current, OriginalNode::IS_DROPPED); 141 mFound = true; 142 return false; 143 } 144 145 } // anonymous namespace 146 147 bool ExpandIntegerPowExpressions(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable) 148 { 149 return Traverser::Apply(compiler, root, symbolTable); 150 } 151 152 } // namespace sh