ScalarizeVecAndMatConstructorArgs.cpp (7330B)
1 // 2 // Copyright 2002 The ANGLE Project Authors. All rights reserved. 3 // Use of this source code is governed by a BSD-style license that can be 4 // found in the LICENSE file. 5 // 6 // Scalarize vector and matrix constructor args, so that vectors built from components don't have 7 // matrix arguments, and matrices built from components don't have vector arguments. This avoids 8 // driver bugs around vector and matrix constructors. 9 // 10 11 #include "compiler/translator/tree_ops/ScalarizeVecAndMatConstructorArgs.h" 12 #include "common/debug.h" 13 14 #include <algorithm> 15 16 #include "angle_gl.h" 17 #include "common/angleutils.h" 18 #include "compiler/translator/Compiler.h" 19 #include "compiler/translator/tree_util/IntermNodePatternMatcher.h" 20 #include "compiler/translator/tree_util/IntermNode_util.h" 21 #include "compiler/translator/tree_util/IntermTraverse.h" 22 #include "compiler/translator/util.h" 23 24 namespace sh 25 { 26 27 namespace 28 { 29 30 TIntermBinary *ConstructVectorIndexBinaryNode(TIntermTyped *symbolNode, int index) 31 { 32 return new TIntermBinary(EOpIndexDirect, symbolNode, CreateIndexNode(index)); 33 } 34 35 TIntermBinary *ConstructMatrixIndexBinaryNode(TIntermTyped *symbolNode, int colIndex, int rowIndex) 36 { 37 TIntermBinary *colVectorNode = ConstructVectorIndexBinaryNode(symbolNode, colIndex); 38 39 return new TIntermBinary(EOpIndexDirect, colVectorNode, CreateIndexNode(rowIndex)); 40 } 41 42 class ScalarizeArgsTraverser : public TIntermTraverser 43 { 44 public: 45 ScalarizeArgsTraverser(TSymbolTable *symbolTable) 46 : TIntermTraverser(true, false, false, symbolTable), 47 mNodesToScalarize(IntermNodePatternMatcher::kScalarizedVecOrMatConstructor) 48 {} 49 50 protected: 51 bool visitAggregate(Visit visit, TIntermAggregate *node) override; 52 bool visitBlock(Visit visit, TIntermBlock *node) override; 53 54 private: 55 void scalarizeArgs(TIntermAggregate *aggregate, bool scalarizeVector, bool scalarizeMatrix); 56 57 // If we have the following code: 58 // mat4 m(0); 59 // vec4 v(1, m); 60 // We will rewrite to: 61 // mat4 m(0); 62 // mat4 s0 = m; 63 // vec4 v(1, s0[0][0], s0[0][1], s0[0][2]); 64 // This function is to create nodes for "mat4 s0 = m;" and insert it to the code sequence. This 65 // way the possible side effects of the constructor argument will only be evaluated once. 66 TIntermTyped *createTempVariable(TIntermTyped *original); 67 68 std::vector<TIntermSequence> mBlockStack; 69 70 IntermNodePatternMatcher mNodesToScalarize; 71 }; 72 73 bool ScalarizeArgsTraverser::visitAggregate(Visit visit, TIntermAggregate *node) 74 { 75 ASSERT(visit == PreVisit); 76 if (mNodesToScalarize.match(node, getParentNode())) 77 { 78 if (node->getType().isVector()) 79 { 80 scalarizeArgs(node, false, true); 81 } 82 else 83 { 84 ASSERT(node->getType().isMatrix()); 85 scalarizeArgs(node, true, false); 86 } 87 } 88 return true; 89 } 90 91 bool ScalarizeArgsTraverser::visitBlock(Visit visit, TIntermBlock *node) 92 { 93 mBlockStack.push_back(TIntermSequence()); 94 { 95 for (TIntermNode *child : *node->getSequence()) 96 { 97 ASSERT(child != nullptr); 98 child->traverse(this); 99 mBlockStack.back().push_back(child); 100 } 101 } 102 if (mBlockStack.back().size() > node->getSequence()->size()) 103 { 104 node->getSequence()->clear(); 105 *(node->getSequence()) = mBlockStack.back(); 106 } 107 mBlockStack.pop_back(); 108 return false; 109 } 110 111 void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate, 112 bool scalarizeVector, 113 bool scalarizeMatrix) 114 { 115 ASSERT(aggregate); 116 ASSERT(!aggregate->isArray()); 117 int size = static_cast<int>(aggregate->getType().getObjectSize()); 118 TIntermSequence *sequence = aggregate->getSequence(); 119 TIntermSequence originalArgs(*sequence); 120 sequence->clear(); 121 for (TIntermNode *originalArgNode : originalArgs) 122 { 123 ASSERT(size > 0); 124 TIntermTyped *originalArg = originalArgNode->getAsTyped(); 125 ASSERT(originalArg); 126 TIntermTyped *argVariable = createTempVariable(originalArg); 127 if (originalArg->isScalar()) 128 { 129 sequence->push_back(argVariable); 130 size--; 131 } 132 else if (originalArg->isVector()) 133 { 134 if (scalarizeVector) 135 { 136 int repeat = std::min<int>(size, originalArg->getNominalSize()); 137 size -= repeat; 138 for (int index = 0; index < repeat; ++index) 139 { 140 TIntermBinary *newNode = 141 ConstructVectorIndexBinaryNode(argVariable->deepCopy(), index); 142 sequence->push_back(newNode); 143 } 144 } 145 else 146 { 147 sequence->push_back(argVariable); 148 size -= originalArg->getNominalSize(); 149 } 150 } 151 else 152 { 153 ASSERT(originalArg->isMatrix()); 154 if (scalarizeMatrix) 155 { 156 int colIndex = 0, rowIndex = 0; 157 int repeat = std::min<int>(size, originalArg->getCols() * originalArg->getRows()); 158 size -= repeat; 159 while (repeat > 0) 160 { 161 TIntermBinary *newNode = 162 ConstructMatrixIndexBinaryNode(argVariable->deepCopy(), colIndex, rowIndex); 163 sequence->push_back(newNode); 164 rowIndex++; 165 if (rowIndex >= originalArg->getRows()) 166 { 167 rowIndex = 0; 168 colIndex++; 169 } 170 repeat--; 171 } 172 } 173 else 174 { 175 sequence->push_back(argVariable); 176 size -= originalArg->getCols() * originalArg->getRows(); 177 } 178 } 179 } 180 } 181 182 TIntermTyped *ScalarizeArgsTraverser::createTempVariable(TIntermTyped *original) 183 { 184 ASSERT(original); 185 186 TType *type = new TType(original->getType()); 187 type->setQualifier(EvqTemporary); 188 189 // The precision of the constant must have been retained (or derived), which will now apply to 190 // the temp variable. In some cases, the precision cannot be derived, so use the constant as 191 // is. For example, in the following standalone statement, the precision of the constant 0 192 // cannot be determined: 193 // 194 // mat2(0, bvec3(m)); 195 // 196 if (IsPrecisionApplicableToType(type->getBasicType()) && type->getPrecision() == EbpUndefined) 197 { 198 return original; 199 } 200 201 TVariable *variable = CreateTempVariable(mSymbolTable, type); 202 203 ASSERT(mBlockStack.size() > 0); 204 TIntermSequence &sequence = mBlockStack.back(); 205 TIntermDeclaration *declaration = CreateTempInitDeclarationNode(variable, original); 206 sequence.push_back(declaration); 207 208 return CreateTempSymbolNode(variable); 209 } 210 211 } // namespace 212 213 bool ScalarizeVecAndMatConstructorArgs(TCompiler *compiler, 214 TIntermBlock *root, 215 TSymbolTable *symbolTable) 216 { 217 ScalarizeArgsTraverser scalarizer(symbolTable); 218 root->traverse(&scalarizer); 219 220 return compiler->validateAST(root); 221 } 222 223 } // namespace sh