RewriteStructSamplers.cpp (26004B)
1 // 2 // Copyright 2018 The ANGLE Project Authors. All rights reserved. 3 // Use of this source code is governed by a BSD-style license that can be 4 // found in the LICENSE file. 5 // 6 // RewriteStructSamplers: Extract samplers from structs. 7 // 8 9 #include "compiler/translator/tree_ops/RewriteStructSamplers.h" 10 11 #include "compiler/translator/ImmutableStringBuilder.h" 12 #include "compiler/translator/SymbolTable.h" 13 #include "compiler/translator/tree_util/IntermNode_util.h" 14 #include "compiler/translator/tree_util/IntermTraverse.h" 15 16 namespace sh 17 { 18 namespace 19 { 20 21 // Used to map one structure type to another (one where the samplers are removed). 22 struct StructureData 23 { 24 // The structure this was replaced with. If nullptr, it means the structure is removed (because 25 // it had all samplers). 26 const TStructure *modified; 27 // Indexed by the field index of original structure, to get the field index of the modified 28 // structure. For example: 29 // 30 // struct Original 31 // { 32 // sampler2D s1; 33 // vec4 f1; 34 // sampler2D s2; 35 // sampler2D s3; 36 // vec4 f2; 37 // }; 38 // 39 // struct Modified 40 // { 41 // vec4 f1; 42 // vec4 f2; 43 // }; 44 // 45 // fieldMap: 46 // 0 -> Invalid 47 // 1 -> 0 48 // 2 -> Invalid 49 // 3 -> Invalid 50 // 4 -> 1 51 // 52 TVector<int> fieldMap; 53 }; 54 55 using StructureMap = angle::HashMap<const TStructure *, StructureData>; 56 using StructureUniformMap = angle::HashMap<const TVariable *, const TVariable *>; 57 using ExtractedSamplerMap = angle::HashMap<std::string, const TVariable *>; 58 59 TIntermTyped *RewriteModifiedStructFieldSelectionExpression( 60 TCompiler *compiler, 61 TIntermBinary *node, 62 const StructureMap &structureMap, 63 const StructureUniformMap &structureUniformMap, 64 const ExtractedSamplerMap &extractedSamplers); 65 66 TIntermTyped *RewriteExpressionVisitBinaryHelper(TCompiler *compiler, 67 TIntermBinary *node, 68 const StructureMap &structureMap, 69 const StructureUniformMap &structureUniformMap, 70 const ExtractedSamplerMap &extractedSamplers) 71 { 72 // Only interested in EOpIndexDirectStruct binary nodes. 73 if (node->getOp() != EOpIndexDirectStruct) 74 { 75 return nullptr; 76 } 77 78 const TStructure *structure = node->getLeft()->getType().getStruct(); 79 ASSERT(structure); 80 81 // If the result of the index is not a sampler and the struct is not replaced, there's nothing 82 // to do. 83 if (!node->getType().isSampler() && structureMap.find(structure) == structureMap.end()) 84 { 85 return nullptr; 86 } 87 88 // Otherwise, replace the whole expression such that: 89 // 90 // - if sampler, it's indexed with whatever indices the parent structs were indexed with, 91 // - otherwise, the chain of field selections is rewritten by modifying the base uniform so all 92 // the intermediate nodes would have the correct type (and therefore fields). 93 ASSERT(structureMap.find(structure) != structureMap.end()); 94 95 return RewriteModifiedStructFieldSelectionExpression(compiler, node, structureMap, 96 structureUniformMap, extractedSamplers); 97 } 98 99 // Given an expression, this traverser calculates a new expression where sampler-in-structs are 100 // replaced with their extracted ones, and field indices are adjusted for the rest of the fields. 101 // In particular, this is run on the right node of EOpIndexIndirect binary nodes, so that the 102 // expression in the index gets a chance to go through this transformation. 103 class RewriteExpressionTraverser final : public TIntermTraverser 104 { 105 public: 106 explicit RewriteExpressionTraverser(TCompiler *compiler, 107 const StructureMap &structureMap, 108 const StructureUniformMap &structureUniformMap, 109 const ExtractedSamplerMap &extractedSamplers) 110 : TIntermTraverser(true, false, false), 111 mCompiler(compiler), 112 mStructureMap(structureMap), 113 mStructureUniformMap(structureUniformMap), 114 mExtractedSamplers(extractedSamplers) 115 {} 116 117 bool visitBinary(Visit visit, TIntermBinary *node) override 118 { 119 TIntermTyped *rewritten = RewriteExpressionVisitBinaryHelper( 120 mCompiler, node, mStructureMap, mStructureUniformMap, mExtractedSamplers); 121 122 if (rewritten == nullptr) 123 { 124 return true; 125 } 126 127 queueReplacement(rewritten, OriginalNode::IS_DROPPED); 128 129 // Don't iterate as the expression is rewritten. 130 return false; 131 } 132 133 void visitSymbol(TIntermSymbol *node) override 134 { 135 // It's impossible to reach here with a symbol that needs replacement. 136 // MonomorphizeUnsupportedFunctions makes sure that whole structs containing 137 // samplers are not passed to functions, so any instance of the struct uniform is 138 // necessarily indexed right away. visitBinary should have already taken care of it. 139 ASSERT(mStructureUniformMap.find(&node->variable()) == mStructureUniformMap.end()); 140 } 141 142 private: 143 TCompiler *mCompiler; 144 145 // See RewriteStructSamplersTraverser. 146 const StructureMap &mStructureMap; 147 const StructureUniformMap &mStructureUniformMap; 148 const ExtractedSamplerMap &mExtractedSamplers; 149 }; 150 151 // Rewrite the index of an EOpIndexIndirect expression. The root can never need replacing, because 152 // it cannot be a sampler itself or of a struct type. 153 void RewriteIndexExpression(TCompiler *compiler, 154 TIntermTyped *expression, 155 const StructureMap &structureMap, 156 const StructureUniformMap &structureUniformMap, 157 const ExtractedSamplerMap &extractedSamplers) 158 { 159 RewriteExpressionTraverser traverser(compiler, structureMap, structureUniformMap, 160 extractedSamplers); 161 expression->traverse(&traverser); 162 bool valid = traverser.updateTree(compiler, expression); 163 ASSERT(valid); 164 } 165 166 // Given an expression such as the following: 167 // 168 // EOpIndexDirectStruct (sampler) 169 // / \ 170 // EOpIndex* field index 171 // / \ 172 // EOpIndexDirectStruct index 2 173 // / \ 174 // EOpIndex* field index 175 // / \ 176 // EOpIndexDirectStruct index 1 177 // / \ 178 // Uniform Struct field index 179 // 180 // produces: 181 // 182 // EOpIndex* 183 // / \ 184 // EOpIndex* index 2 185 // / \ 186 // sampler index 1 187 // 188 // Alternatively, if the expression is as such: 189 // 190 // EOpIndexDirectStruct 191 // / \ 192 // (modified struct type) EOpIndex* field index 193 // / \ 194 // EOpIndexDirectStruct index 2 195 // / \ 196 // EOpIndex* field index 197 // / \ 198 // EOpIndexDirectStruct index 1 199 // / \ 200 // Uniform Struct field index 201 // 202 // produces: 203 // 204 // EOpIndexDirectStruct 205 // / \ 206 // EOpIndex* mapped field index 207 // / \ 208 // EOpIndexDirectStruct index 2 209 // / \ 210 // EOpIndex* mapped field index 211 // / \ 212 // EOpIndexDirectStruct index 1 213 // / \ 214 // Uniform Struct mapped field index 215 // 216 TIntermTyped *RewriteModifiedStructFieldSelectionExpression( 217 TCompiler *compiler, 218 TIntermBinary *node, 219 const StructureMap &structureMap, 220 const StructureUniformMap &structureUniformMap, 221 const ExtractedSamplerMap &extractedSamplers) 222 { 223 ASSERT(node->getOp() == EOpIndexDirectStruct); 224 225 const bool isSampler = node->getType().isSampler(); 226 227 TIntermSymbol *baseUniform = nullptr; 228 std::string samplerName; 229 230 TVector<TIntermBinary *> indexNodeStack; 231 232 // Iterate once and build the name of the sampler. 233 TIntermBinary *iter = node; 234 while (baseUniform == nullptr) 235 { 236 indexNodeStack.push_back(iter); 237 baseUniform = iter->getLeft()->getAsSymbolNode(); 238 239 if (isSampler) 240 { 241 if (iter->getOp() == EOpIndexDirectStruct) 242 { 243 // When indexed into a struct, get the field name instead and construct the sampler 244 // name. 245 samplerName.insert(0, iter->getIndexStructFieldName().data()); 246 samplerName.insert(0, "_"); 247 } 248 249 if (baseUniform) 250 { 251 // If left is a symbol, we have reached the end of the chain. Use the struct name 252 // to finish building the name of the sampler. 253 samplerName.insert(0, baseUniform->variable().name().data()); 254 } 255 } 256 257 iter = iter->getLeft()->getAsBinaryNode(); 258 } 259 260 TIntermTyped *rewritten = nullptr; 261 262 if (isSampler) 263 { 264 ASSERT(extractedSamplers.find(samplerName) != extractedSamplers.end()); 265 rewritten = new TIntermSymbol(extractedSamplers.at(samplerName)); 266 } 267 else 268 { 269 const TVariable *baseUniformVar = &baseUniform->variable(); 270 ASSERT(structureUniformMap.find(baseUniformVar) != structureUniformMap.end()); 271 rewritten = new TIntermSymbol(structureUniformMap.at(baseUniformVar)); 272 } 273 274 // Iterate again and build the expression from bottom up. 275 for (auto it = indexNodeStack.rbegin(); it != indexNodeStack.rend(); ++it) 276 { 277 TIntermBinary *indexNode = *it; 278 279 switch (indexNode->getOp()) 280 { 281 case EOpIndexDirectStruct: 282 if (!isSampler) 283 { 284 // Remap the field. 285 const TStructure *structure = indexNode->getLeft()->getType().getStruct(); 286 ASSERT(structureMap.find(structure) != structureMap.end()); 287 288 TIntermConstantUnion *asConstantUnion = 289 indexNode->getRight()->getAsConstantUnion(); 290 ASSERT(asConstantUnion); 291 292 const int fieldIndex = asConstantUnion->getIConst(0); 293 ASSERT(fieldIndex < 294 static_cast<int>(structureMap.at(structure).fieldMap.size())); 295 296 const int mappedFieldIndex = structureMap.at(structure).fieldMap[fieldIndex]; 297 298 rewritten = new TIntermBinary(EOpIndexDirectStruct, rewritten, 299 CreateIndexNode(mappedFieldIndex)); 300 } 301 break; 302 303 case EOpIndexDirect: 304 rewritten = new TIntermBinary(EOpIndexDirect, rewritten, indexNode->getRight()); 305 break; 306 307 case EOpIndexIndirect: 308 { 309 // Run RewriteExpressionTraverser on the right node. It may itself be an expression 310 // with a sampler inside that needs to be rewritten, or simply use a field of a 311 // struct that's remapped. 312 TIntermTyped *indexExpression = indexNode->getRight(); 313 RewriteIndexExpression(compiler, indexExpression, structureMap, structureUniformMap, 314 extractedSamplers); 315 rewritten = new TIntermBinary(EOpIndexIndirect, rewritten, indexExpression); 316 break; 317 } 318 319 default: 320 UNREACHABLE(); 321 break; 322 } 323 } 324 325 return rewritten; 326 } 327 328 class RewriteStructSamplersTraverser final : public TIntermTraverser 329 { 330 public: 331 explicit RewriteStructSamplersTraverser(TCompiler *compiler, TSymbolTable *symbolTable) 332 : TIntermTraverser(true, false, false, symbolTable), 333 mCompiler(compiler), 334 mRemovedUniformsCount(0) 335 {} 336 337 int removedUniformsCount() const { return mRemovedUniformsCount; } 338 339 // Each struct sampler declaration is stripped of its samplers. New uniforms are added for each 340 // stripped struct sampler. 341 bool visitDeclaration(Visit visit, TIntermDeclaration *decl) override 342 { 343 if (!mInGlobalScope) 344 { 345 return true; 346 } 347 348 const TIntermSequence &sequence = *(decl->getSequence()); 349 TIntermTyped *declarator = sequence.front()->getAsTyped(); 350 const TType &type = declarator->getType(); 351 352 if (!type.isStructureContainingSamplers()) 353 { 354 return false; 355 } 356 357 TIntermSequence newSequence; 358 359 if (type.isStructSpecifier()) 360 { 361 // If this is just a struct definition (not a uniform variable declaration of a 362 // struct type), just remove the samplers. They are not instantiated yet. 363 const TStructure *structure = type.getStruct(); 364 ASSERT(structure && mStructureMap.find(structure) == mStructureMap.end()); 365 366 stripStructSpecifierSamplers(structure, &newSequence); 367 } 368 else 369 { 370 const TStructure *structure = type.getStruct(); 371 372 // If the structure is defined at the same time, create the mapping to the stripped 373 // version first. 374 if (mStructureMap.find(structure) == mStructureMap.end()) 375 { 376 stripStructSpecifierSamplers(structure, &newSequence); 377 } 378 379 // Then, extract the samplers from the struct and create global-scope variables instead. 380 TIntermSymbol *asSymbol = declarator->getAsSymbolNode(); 381 ASSERT(asSymbol); 382 const TVariable &variable = asSymbol->variable(); 383 ASSERT(variable.symbolType() != SymbolType::Empty); 384 385 extractStructSamplerUniforms(variable, structure, &newSequence); 386 } 387 388 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), decl, 389 std::move(newSequence)); 390 391 return false; 392 } 393 394 // Same implementation as in RewriteExpressionTraverser. That traverser cannot replace root. 395 bool visitBinary(Visit visit, TIntermBinary *node) override 396 { 397 TIntermTyped *rewritten = RewriteExpressionVisitBinaryHelper( 398 mCompiler, node, mStructureMap, mStructureUniformMap, mExtractedSamplers); 399 400 if (rewritten == nullptr) 401 { 402 return true; 403 } 404 405 queueReplacement(rewritten, OriginalNode::IS_DROPPED); 406 407 // Don't iterate as the expression is rewritten. 408 return false; 409 } 410 411 // Same implementation as in RewriteExpressionTraverser. That traverser cannot replace root. 412 void visitSymbol(TIntermSymbol *node) override 413 { 414 ASSERT(mStructureUniformMap.find(&node->variable()) == mStructureUniformMap.end()); 415 } 416 417 private: 418 // Removes all samplers from a struct specifier. 419 void stripStructSpecifierSamplers(const TStructure *structure, TIntermSequence *newSequence) 420 { 421 TFieldList *newFieldList = new TFieldList; 422 ASSERT(structure->containsSamplers()); 423 424 // Add this struct to the struct map 425 ASSERT(mStructureMap.find(structure) == mStructureMap.end()); 426 StructureData *modifiedData = &mStructureMap[structure]; 427 428 modifiedData->modified = nullptr; 429 modifiedData->fieldMap.resize(structure->fields().size(), std::numeric_limits<int>::max()); 430 431 for (size_t fieldIndex = 0; fieldIndex < structure->fields().size(); ++fieldIndex) 432 { 433 const TField *field = structure->fields()[fieldIndex]; 434 const TType &fieldType = *field->type(); 435 436 // If the field is a sampler, or a struct that's entirely removed, skip it. 437 if (!fieldType.isSampler() && !isRemovedStructType(fieldType)) 438 { 439 TType *newType = nullptr; 440 441 // Otherwise, if it's a struct that's replaced, create a new field of the replaced 442 // type. 443 if (fieldType.isStructureContainingSamplers()) 444 { 445 const TStructure *fieldStruct = fieldType.getStruct(); 446 ASSERT(mStructureMap.find(fieldStruct) != mStructureMap.end()); 447 448 const TStructure *modifiedStruct = mStructureMap[fieldStruct].modified; 449 ASSERT(modifiedStruct); 450 451 newType = new TType(modifiedStruct, true); 452 if (fieldType.isArray()) 453 { 454 newType->makeArrays(fieldType.getArraySizes()); 455 } 456 } 457 else 458 { 459 // If not, duplicate the field as is. 460 newType = new TType(fieldType); 461 } 462 463 // Record the mapping of the field indices, so future EOpIndexDirectStruct's into 464 // this struct can be fixed up. 465 modifiedData->fieldMap[fieldIndex] = static_cast<int>(newFieldList->size()); 466 467 TField *newField = 468 new TField(newType, field->name(), field->line(), field->symbolType()); 469 newFieldList->push_back(newField); 470 } 471 } 472 473 // Prune empty structs. 474 if (newFieldList->empty()) 475 { 476 return; 477 } 478 479 // Declare a new struct with the same name and the new fields. 480 modifiedData->modified = 481 new TStructure(mSymbolTable, structure->name(), newFieldList, structure->symbolType()); 482 TType *newStructType = new TType(modifiedData->modified, true); 483 TVariable *newStructVar = 484 new TVariable(mSymbolTable, kEmptyImmutableString, newStructType, SymbolType::Empty); 485 TIntermSymbol *newStructRef = new TIntermSymbol(newStructVar); 486 487 TIntermDeclaration *structDecl = new TIntermDeclaration; 488 structDecl->appendDeclarator(newStructRef); 489 490 newSequence->push_back(structDecl); 491 } 492 493 // Returns true if the type is a struct that was removed because we extracted all the members. 494 bool isRemovedStructType(const TType &type) const 495 { 496 const TStructure *structure = type.getStruct(); 497 if (structure == nullptr) 498 { 499 // Not a struct 500 return false; 501 } 502 503 // A struct is removed if it is in the map, but doesn't have a replacement struct. 504 auto iter = mStructureMap.find(structure); 505 return iter != mStructureMap.end() && iter->second.modified == nullptr; 506 } 507 508 // Removes samplers from struct uniforms. For each sampler removed also adds a new globally 509 // defined sampler uniform. 510 void extractStructSamplerUniforms(const TVariable &variable, 511 const TStructure *structure, 512 TIntermSequence *newSequence) 513 { 514 ASSERT(structure->containsSamplers()); 515 ASSERT(mStructureMap.find(structure) != mStructureMap.end()); 516 517 const TType &type = variable.getType(); 518 enterArray(type); 519 520 for (const TField *field : structure->fields()) 521 { 522 extractFieldSamplers(variable.name().data(), field, newSequence); 523 } 524 525 // If there's a replacement structure (because there are non-sampler fields in the struct), 526 // add a declaration with that type. 527 const TStructure *modified = mStructureMap[structure].modified; 528 if (modified != nullptr) 529 { 530 TType *newType = new TType(modified, false); 531 if (type.isArray()) 532 { 533 newType->makeArrays(type.getArraySizes()); 534 } 535 newType->setQualifier(EvqUniform); 536 const TVariable *newVariable = 537 new TVariable(mSymbolTable, variable.name(), newType, variable.symbolType()); 538 539 TIntermDeclaration *newDecl = new TIntermDeclaration(); 540 newDecl->appendDeclarator(new TIntermSymbol(newVariable)); 541 542 newSequence->push_back(newDecl); 543 544 ASSERT(mStructureUniformMap.find(&variable) == mStructureUniformMap.end()); 545 mStructureUniformMap[&variable] = newVariable; 546 } 547 else 548 { 549 mRemovedUniformsCount++; 550 } 551 552 exitArray(type); 553 } 554 555 // Extracts samplers from a field of a struct. Works with nested structs and arrays. 556 void extractFieldSamplers(const std::string &prefix, 557 const TField *field, 558 TIntermSequence *newSequence) 559 { 560 const TType &fieldType = *field->type(); 561 if (fieldType.isSampler() || fieldType.isStructureContainingSamplers()) 562 { 563 std::string newPrefix = prefix + "_" + field->name().data(); 564 565 if (fieldType.isSampler()) 566 { 567 extractSampler(newPrefix, fieldType, newSequence); 568 } 569 else 570 { 571 enterArray(fieldType); 572 const TStructure *structure = fieldType.getStruct(); 573 for (const TField *nestedField : structure->fields()) 574 { 575 extractFieldSamplers(newPrefix, nestedField, newSequence); 576 } 577 exitArray(fieldType); 578 } 579 } 580 } 581 582 void GenerateArraySizesFromStack(TVector<unsigned int> *sizesOut) 583 { 584 sizesOut->reserve(mArraySizeStack.size()); 585 586 for (auto it = mArraySizeStack.rbegin(); it != mArraySizeStack.rend(); ++it) 587 { 588 sizesOut->push_back(*it); 589 } 590 } 591 592 // Extracts a sampler from a struct. Declares the new extracted sampler. 593 void extractSampler(const std::string &newName, 594 const TType &fieldType, 595 TIntermSequence *newSequence) 596 { 597 ASSERT(fieldType.isSampler()); 598 599 TType *newType = new TType(fieldType); 600 601 // Add array dimensions accumulated so far due to struct arrays. Note that to support 602 // nested arrays, mArraySizeStack has the outermost size in the front. |makeArrays| thus 603 // expects this in reverse order. 604 TVector<unsigned int> parentArraySizes; 605 GenerateArraySizesFromStack(&parentArraySizes); 606 newType->makeArrays(parentArraySizes); 607 608 ImmutableStringBuilder nameBuilder(newName.size() + 1); 609 nameBuilder << newName; 610 611 newType->setQualifier(EvqUniform); 612 TVariable *newVariable = 613 new TVariable(mSymbolTable, nameBuilder, newType, SymbolType::AngleInternal); 614 TIntermSymbol *newSymbol = new TIntermSymbol(newVariable); 615 616 TIntermDeclaration *samplerDecl = new TIntermDeclaration; 617 samplerDecl->appendDeclarator(newSymbol); 618 619 newSequence->push_back(samplerDecl); 620 621 // TODO: Use a temp name instead of generating a name as currently done. There is no 622 // guarantee that these generated names cannot clash. Create a mapping from the previous 623 // name to the name assigned to the temp variable so ShaderVariable::mappedName can be 624 // updated post-transformation. http://anglebug.com/4301 625 ASSERT(mExtractedSamplers.find(newName) == mExtractedSamplers.end()); 626 mExtractedSamplers[newName] = newVariable; 627 } 628 629 void enterArray(const TType &arrayType) 630 { 631 const TSpan<const unsigned int> &arraySizes = arrayType.getArraySizes(); 632 for (auto it = arraySizes.rbegin(); it != arraySizes.rend(); ++it) 633 { 634 unsigned int arraySize = *it; 635 mArraySizeStack.push_back(arraySize); 636 } 637 } 638 639 void exitArray(const TType &arrayType) 640 { 641 mArraySizeStack.resize(mArraySizeStack.size() - arrayType.getNumArraySizes()); 642 } 643 644 TCompiler *mCompiler; 645 int mRemovedUniformsCount; 646 647 // Map structures with samplers to ones that have their samplers removed. 648 StructureMap mStructureMap; 649 650 // Map uniform variables of structure type that are replaced with another variable. 651 StructureUniformMap mStructureUniformMap; 652 653 // Map a constructed sampler name to its variable. Used to replace an expression that uses this 654 // sampler with the extracted one. 655 ExtractedSamplerMap mExtractedSamplers; 656 657 // A stack of array sizes. Used to figure out the array dimensions of the extracted sampler, 658 // for example when it's nested in an array of structs in an array of structs. 659 TVector<unsigned int> mArraySizeStack; 660 }; 661 } // anonymous namespace 662 663 bool RewriteStructSamplers(TCompiler *compiler, 664 TIntermBlock *root, 665 TSymbolTable *symbolTable, 666 int *removedUniformsCountOut) 667 { 668 RewriteStructSamplersTraverser traverser(compiler, symbolTable); 669 root->traverse(&traverser); 670 *removedUniformsCountOut = traverser.removedUniformsCount(); 671 return traverser.updateTree(compiler, root); 672 } 673 } // namespace sh