ReplaceArrayOfMatrixVarying.cpp (6272B)
1 // 2 // Copyright 2020 The ANGLE Project Authors. All rights reserved. Use of this 3 // source code is governed by a BSD-style license that can be found in the 4 // LICENSE file. 5 // 6 // ReplaceArrayOfMatrixVarying: Find any references to array of matrices varying 7 // and replace it with array of vectors. 8 // 9 10 #include "compiler/translator/tree_util/ReplaceArrayOfMatrixVarying.h" 11 12 #include <vector> 13 14 #include "common/bitset_utils.h" 15 #include "common/debug.h" 16 #include "common/utilities.h" 17 #include "compiler/translator/Compiler.h" 18 #include "compiler/translator/SymbolTable.h" 19 #include "compiler/translator/tree_util/BuiltIn.h" 20 #include "compiler/translator/tree_util/FindMain.h" 21 #include "compiler/translator/tree_util/IntermNode_util.h" 22 #include "compiler/translator/tree_util/IntermTraverse.h" 23 #include "compiler/translator/tree_util/ReplaceVariable.h" 24 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h" 25 #include "compiler/translator/util.h" 26 27 namespace sh 28 { 29 30 // We create two variables to replace the given varying: 31 // - The new varying which is an array of vectors to be used at input/ouput only. 32 // - The new global variable which is a same type as given variable, to temporarily be used 33 // as replacements for assignments, arithmetic ops and so on. During input/ouput phrase, this temp 34 // variable will be copied from/to the array of vectors variable above. 35 // NOTE(hqle): Consider eliminating the need for using temp variable. 36 37 namespace 38 { 39 class CollectVaryingTraverser : public TIntermTraverser 40 { 41 public: 42 CollectVaryingTraverser(std::vector<const TVariable *> *varyingsOut) 43 : TIntermTraverser(true, false, false), mVaryingsOut(varyingsOut) 44 {} 45 46 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override 47 { 48 const TIntermSequence &sequence = *(node->getSequence()); 49 50 if (sequence.size() != 1) 51 { 52 return false; 53 } 54 55 TIntermTyped *variableType = sequence.front()->getAsTyped(); 56 if (!variableType || !IsVarying(variableType->getQualifier()) || 57 !variableType->isMatrix() || !variableType->isArray()) 58 { 59 return false; 60 } 61 62 TIntermSymbol *variableSymbol = variableType->getAsSymbolNode(); 63 if (!variableSymbol) 64 { 65 return false; 66 } 67 68 mVaryingsOut->push_back(&variableSymbol->variable()); 69 70 return false; 71 } 72 73 private: 74 std::vector<const TVariable *> *mVaryingsOut; 75 }; 76 } // namespace 77 78 [[nodiscard]] bool ReplaceArrayOfMatrixVarying(TCompiler *compiler, 79 TIntermBlock *root, 80 TSymbolTable *symbolTable, 81 const TVariable *varying) 82 { 83 const TType &type = varying->getType(); 84 85 // Create global variable to temporarily acts as the given variable in places such as 86 // arithmetic, assignments an so on. 87 TType *tmpReplacementType = new TType(type); 88 tmpReplacementType->setQualifier(EvqGlobal); 89 90 TVariable *tempReplaceVar = new TVariable( 91 symbolTable, ImmutableString(std::string("ANGLE_AOM_Temp_") + varying->name().data()), 92 tmpReplacementType, SymbolType::AngleInternal); 93 94 if (!ReplaceVariable(compiler, root, varying, tempReplaceVar)) 95 { 96 return false; 97 } 98 99 // Create array of vectors type 100 TType *varyingReplaceType = new TType(type); 101 varyingReplaceType->toMatrixColumnType(); 102 varyingReplaceType->toArrayElementType(); 103 varyingReplaceType->makeArray(type.getCols() * type.getOutermostArraySize()); 104 105 TVariable *varyingReplaceVar = 106 new TVariable(symbolTable, varying->name(), varyingReplaceType, SymbolType::UserDefined); 107 108 TIntermSymbol *varyingReplaceDeclarator = new TIntermSymbol(varyingReplaceVar); 109 TIntermDeclaration *varyingReplaceDecl = new TIntermDeclaration; 110 varyingReplaceDecl->appendDeclarator(varyingReplaceDeclarator); 111 root->insertStatement(0, varyingReplaceDecl); 112 113 // Copy from/to the temp variable 114 TIntermBlock *reassignBlock = new TIntermBlock; 115 TIntermSymbol *tempReplaceSymbol = new TIntermSymbol(tempReplaceVar); 116 TIntermSymbol *varyingReplaceSymbol = new TIntermSymbol(varyingReplaceVar); 117 bool isInput = IsVaryingIn(type.getQualifier()); 118 119 for (unsigned int i = 0; i < type.getOutermostArraySize(); ++i) 120 { 121 TIntermBinary *tempMatrixIndexed = 122 new TIntermBinary(EOpIndexDirect, tempReplaceSymbol->deepCopy(), CreateIndexNode(i)); 123 for (uint8_t col = 0; col < type.getCols(); ++col) 124 { 125 126 TIntermBinary *tempMatrixColIndexed = new TIntermBinary( 127 EOpIndexDirect, tempMatrixIndexed->deepCopy(), CreateIndexNode(col)); 128 TIntermBinary *vectorIndexed = 129 new TIntermBinary(EOpIndexDirect, varyingReplaceSymbol->deepCopy(), 130 CreateIndexNode(i * type.getCols() + col)); 131 TIntermBinary *assignment; 132 if (isInput) 133 { 134 assignment = new TIntermBinary(EOpAssign, tempMatrixColIndexed, vectorIndexed); 135 } 136 else 137 { 138 assignment = new TIntermBinary(EOpAssign, vectorIndexed, tempMatrixColIndexed); 139 } 140 reassignBlock->appendStatement(assignment); 141 } 142 } 143 144 if (isInput) 145 { 146 TIntermFunctionDefinition *main = FindMain(root); 147 main->getBody()->insertStatement(0, reassignBlock); 148 return compiler->validateAST(root); 149 } 150 else 151 { 152 return RunAtTheEndOfShader(compiler, root, reassignBlock, symbolTable); 153 } 154 } 155 156 [[nodiscard]] bool ReplaceArrayOfMatrixVaryings(TCompiler *compiler, 157 TIntermBlock *root, 158 TSymbolTable *symbolTable) 159 { 160 std::vector<const TVariable *> arrayOfMatrixVars; 161 CollectVaryingTraverser varCollector(&arrayOfMatrixVars); 162 root->traverse(&varCollector); 163 164 for (const TVariable *var : arrayOfMatrixVars) 165 { 166 if (!ReplaceArrayOfMatrixVarying(compiler, root, symbolTable, var)) 167 { 168 return false; 169 } 170 } 171 172 return true; 173 } 174 175 } // namespace sh