tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

ScalarizeVecAndMatConstructorArgs.cpp (7330B)


      1 //
      2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
      3 // Use of this source code is governed by a BSD-style license that can be
      4 // found in the LICENSE file.
      5 //
      6 // Scalarize vector and matrix constructor args, so that vectors built from components don't have
      7 // matrix arguments, and matrices built from components don't have vector arguments. This avoids
      8 // driver bugs around vector and matrix constructors.
      9 //
     10 
#include "compiler/translator/tree_ops/ScalarizeVecAndMatConstructorArgs.h"
#include "common/debug.h"

#include <algorithm>
#include <utility>

#include "angle_gl.h"
#include "common/angleutils.h"
#include "compiler/translator/Compiler.h"
#include "compiler/translator/tree_util/IntermNodePatternMatcher.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
#include "compiler/translator/util.h"
     23 
     24 namespace sh
     25 {
     26 
     27 namespace
     28 {
     29 
     30 TIntermBinary *ConstructVectorIndexBinaryNode(TIntermTyped *symbolNode, int index)
     31 {
     32    return new TIntermBinary(EOpIndexDirect, symbolNode, CreateIndexNode(index));
     33 }
     34 
     35 TIntermBinary *ConstructMatrixIndexBinaryNode(TIntermTyped *symbolNode, int colIndex, int rowIndex)
     36 {
     37    TIntermBinary *colVectorNode = ConstructVectorIndexBinaryNode(symbolNode, colIndex);
     38 
     39    return new TIntermBinary(EOpIndexDirect, colVectorNode, CreateIndexNode(rowIndex));
     40 }
     41 
     42 class ScalarizeArgsTraverser : public TIntermTraverser
     43 {
     44  public:
     45    ScalarizeArgsTraverser(TSymbolTable *symbolTable)
     46        : TIntermTraverser(true, false, false, symbolTable),
     47          mNodesToScalarize(IntermNodePatternMatcher::kScalarizedVecOrMatConstructor)
     48    {}
     49 
     50  protected:
     51    bool visitAggregate(Visit visit, TIntermAggregate *node) override;
     52    bool visitBlock(Visit visit, TIntermBlock *node) override;
     53 
     54  private:
     55    void scalarizeArgs(TIntermAggregate *aggregate, bool scalarizeVector, bool scalarizeMatrix);
     56 
     57    // If we have the following code:
     58    //   mat4 m(0);
     59    //   vec4 v(1, m);
     60    // We will rewrite to:
     61    //   mat4 m(0);
     62    //   mat4 s0 = m;
     63    //   vec4 v(1, s0[0][0], s0[0][1], s0[0][2]);
     64    // This function is to create nodes for "mat4 s0 = m;" and insert it to the code sequence. This
     65    // way the possible side effects of the constructor argument will only be evaluated once.
     66    TIntermTyped *createTempVariable(TIntermTyped *original);
     67 
     68    std::vector<TIntermSequence> mBlockStack;
     69 
     70    IntermNodePatternMatcher mNodesToScalarize;
     71 };
     72 
     73 bool ScalarizeArgsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
     74 {
     75    ASSERT(visit == PreVisit);
     76    if (mNodesToScalarize.match(node, getParentNode()))
     77    {
     78        if (node->getType().isVector())
     79        {
     80            scalarizeArgs(node, false, true);
     81        }
     82        else
     83        {
     84            ASSERT(node->getType().isMatrix());
     85            scalarizeArgs(node, true, false);
     86        }
     87    }
     88    return true;
     89 }
     90 
     91 bool ScalarizeArgsTraverser::visitBlock(Visit visit, TIntermBlock *node)
     92 {
     93    mBlockStack.push_back(TIntermSequence());
     94    {
     95        for (TIntermNode *child : *node->getSequence())
     96        {
     97            ASSERT(child != nullptr);
     98            child->traverse(this);
     99            mBlockStack.back().push_back(child);
    100        }
    101    }
    102    if (mBlockStack.back().size() > node->getSequence()->size())
    103    {
    104        node->getSequence()->clear();
    105        *(node->getSequence()) = mBlockStack.back();
    106    }
    107    mBlockStack.pop_back();
    108    return false;
    109 }
    110 
    111 void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate,
    112                                           bool scalarizeVector,
    113                                           bool scalarizeMatrix)
    114 {
    115    ASSERT(aggregate);
    116    ASSERT(!aggregate->isArray());
    117    int size                  = static_cast<int>(aggregate->getType().getObjectSize());
    118    TIntermSequence *sequence = aggregate->getSequence();
    119    TIntermSequence originalArgs(*sequence);
    120    sequence->clear();
    121    for (TIntermNode *originalArgNode : originalArgs)
    122    {
    123        ASSERT(size > 0);
    124        TIntermTyped *originalArg = originalArgNode->getAsTyped();
    125        ASSERT(originalArg);
    126        TIntermTyped *argVariable = createTempVariable(originalArg);
    127        if (originalArg->isScalar())
    128        {
    129            sequence->push_back(argVariable);
    130            size--;
    131        }
    132        else if (originalArg->isVector())
    133        {
    134            if (scalarizeVector)
    135            {
    136                int repeat = std::min<int>(size, originalArg->getNominalSize());
    137                size -= repeat;
    138                for (int index = 0; index < repeat; ++index)
    139                {
    140                    TIntermBinary *newNode =
    141                        ConstructVectorIndexBinaryNode(argVariable->deepCopy(), index);
    142                    sequence->push_back(newNode);
    143                }
    144            }
    145            else
    146            {
    147                sequence->push_back(argVariable);
    148                size -= originalArg->getNominalSize();
    149            }
    150        }
    151        else
    152        {
    153            ASSERT(originalArg->isMatrix());
    154            if (scalarizeMatrix)
    155            {
    156                int colIndex = 0, rowIndex = 0;
    157                int repeat = std::min<int>(size, originalArg->getCols() * originalArg->getRows());
    158                size -= repeat;
    159                while (repeat > 0)
    160                {
    161                    TIntermBinary *newNode =
    162                        ConstructMatrixIndexBinaryNode(argVariable->deepCopy(), colIndex, rowIndex);
    163                    sequence->push_back(newNode);
    164                    rowIndex++;
    165                    if (rowIndex >= originalArg->getRows())
    166                    {
    167                        rowIndex = 0;
    168                        colIndex++;
    169                    }
    170                    repeat--;
    171                }
    172            }
    173            else
    174            {
    175                sequence->push_back(argVariable);
    176                size -= originalArg->getCols() * originalArg->getRows();
    177            }
    178        }
    179    }
    180 }
    181 
    182 TIntermTyped *ScalarizeArgsTraverser::createTempVariable(TIntermTyped *original)
    183 {
    184    ASSERT(original);
    185 
    186    TType *type = new TType(original->getType());
    187    type->setQualifier(EvqTemporary);
    188 
    189    // The precision of the constant must have been retained (or derived), which will now apply to
    190    // the temp variable.  In some cases, the precision cannot be derived, so use the constant as
    191    // is.  For example, in the following standalone statement, the precision of the constant 0
    192    // cannot be determined:
    193    //
    194    //      mat2(0, bvec3(m));
    195    //
    196    if (IsPrecisionApplicableToType(type->getBasicType()) && type->getPrecision() == EbpUndefined)
    197    {
    198        return original;
    199    }
    200 
    201    TVariable *variable = CreateTempVariable(mSymbolTable, type);
    202 
    203    ASSERT(mBlockStack.size() > 0);
    204    TIntermSequence &sequence       = mBlockStack.back();
    205    TIntermDeclaration *declaration = CreateTempInitDeclarationNode(variable, original);
    206    sequence.push_back(declaration);
    207 
    208    return CreateTempSymbolNode(variable);
    209 }
    210 
    211 }  // namespace
    212 
    213 bool ScalarizeVecAndMatConstructorArgs(TCompiler *compiler,
    214                                       TIntermBlock *root,
    215                                       TSymbolTable *symbolTable)
    216 {
    217    ScalarizeArgsTraverser scalarizer(symbolTable);
    218    root->traverse(&scalarizer);
    219 
    220    return compiler->validateAST(root);
    221 }
    222 
    223 }  // namespace sh