tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

ExpandIntegerPowExpressions.cpp (4087B)


      1 //
      2 // Copyright 2016 The ANGLE Project Authors. All rights reserved.
      3 // Use of this source code is governed by a BSD-style license that can be
      4 // found in the LICENSE file.
      5 //
      6 // Implementation of the integer pow expressions HLSL bug workaround.
      7 // See header for more info.
      8 
      9 #include "compiler/translator/tree_ops/d3d/ExpandIntegerPowExpressions.h"
     10 
     11 #include <cmath>
     12 #include <cstdlib>
     13 
     14 #include "compiler/translator/tree_util/IntermNode_util.h"
     15 #include "compiler/translator/tree_util/IntermTraverse.h"
     16 
     17 namespace sh
     18 {
     19 
     20 namespace
     21 {
     22 
     23 class Traverser : public TIntermTraverser
     24 {
     25  public:
     26    [[nodiscard]] static bool Apply(TCompiler *compiler,
     27                                    TIntermNode *root,
     28                                    TSymbolTable *symbolTable);
     29 
     30  private:
     31    Traverser(TSymbolTable *symbolTable);
     32    bool visitAggregate(Visit visit, TIntermAggregate *node) override;
     33    void nextIteration();
     34 
     35    bool mFound = false;
     36 };
     37 
     38 // static
     39 bool Traverser::Apply(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable)
     40 {
     41    Traverser traverser(symbolTable);
     42    do
     43    {
     44        traverser.nextIteration();
     45        root->traverse(&traverser);
     46        if (traverser.mFound)
     47        {
     48            if (!traverser.updateTree(compiler, root))
     49            {
     50                return false;
     51            }
     52        }
     53    } while (traverser.mFound);
     54 
     55    return true;
     56 }
     57 
     58 Traverser::Traverser(TSymbolTable *symbolTable) : TIntermTraverser(true, false, false, symbolTable)
     59 {}
     60 
     61 void Traverser::nextIteration()
     62 {
     63    mFound = false;
     64 }
     65 
     66 bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
     67 {
     68    if (mFound)
     69    {
     70        return false;
     71    }
     72 
     73    // Test 0: skip non-pow operators.
     74    if (node->getOp() != EOpPow)
     75    {
     76        return true;
     77    }
     78 
     79    const TIntermSequence *sequence = node->getSequence();
     80    ASSERT(sequence->size() == 2u);
     81    const TIntermConstantUnion *constantExponent = sequence->at(1)->getAsConstantUnion();
     82 
     83    // Test 1: check for a single constant.
     84    if (!constantExponent || constantExponent->getNominalSize() != 1)
     85    {
     86        return true;
     87    }
     88 
     89    float exponentValue = constantExponent->getConstantValue()->getFConst();
     90 
     91    // Test 2: exponentValue is in the problematic range.
     92    if (exponentValue < -5.0f || exponentValue > 9.0f)
     93    {
     94        return true;
     95    }
     96 
     97    // Test 3: exponentValue is integer or pretty close to an integer.
     98    if (std::abs(exponentValue - std::round(exponentValue)) > 0.0001f)
     99    {
    100        return true;
    101    }
    102 
    103    // Test 4: skip -1, 0, and 1
    104    int exponent = static_cast<int>(std::round(exponentValue));
    105    int n        = std::abs(exponent);
    106    if (n < 2)
    107    {
    108        return true;
    109    }
    110 
    111    // Potential problem case detected, apply workaround.
    112 
    113    TIntermTyped *lhs = sequence->at(0)->getAsTyped();
    114    ASSERT(lhs);
    115 
    116    TIntermDeclaration *lhsVariableDeclaration = nullptr;
    117    TVariable *lhsVariable =
    118        DeclareTempVariable(mSymbolTable, lhs, EvqTemporary, &lhsVariableDeclaration);
    119    insertStatementInParentBlock(lhsVariableDeclaration);
    120 
    121    // Create a chain of n-1 multiples.
    122    TIntermTyped *current = CreateTempSymbolNode(lhsVariable);
    123    for (int i = 1; i < n; ++i)
    124    {
    125        TIntermBinary *mul = new TIntermBinary(EOpMul, current, CreateTempSymbolNode(lhsVariable));
    126        mul->setLine(node->getLine());
    127        current = mul;
    128    }
    129 
    130    // For negative pow, compute the reciprocal of the positive pow.
    131    if (exponent < 0)
    132    {
    133        TConstantUnion *oneVal = new TConstantUnion();
    134        oneVal->setFConst(1.0f);
    135        TIntermConstantUnion *oneNode = new TIntermConstantUnion(oneVal, node->getType());
    136        TIntermBinary *div            = new TIntermBinary(EOpDiv, oneNode, current);
    137        current                       = div;
    138    }
    139 
    140    queueReplacement(current, OriginalNode::IS_DROPPED);
    141    mFound = true;
    142    return false;
    143 }
    144 
    145 }  // anonymous namespace
    146 
    147 bool ExpandIntegerPowExpressions(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable)
    148 {
    149    return Traverser::Apply(compiler, root, symbolTable);
    150 }
    151 
    152 }  // namespace sh