tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

RewritePixelLocalStorage.cpp (37512B)


      1 //
      2 // Copyright 2022 The ANGLE Project Authors. All rights reserved.
      3 // Use of this source code is governed by a BSD-style license that can be
      4 // found in the LICENSE file.
      5 //
      6 
      7 #include "compiler/translator/tree_ops/RewritePixelLocalStorage.h"
      8 
      9 #include "common/angleutils.h"
     10 #include "compiler/translator/StaticType.h"
     11 #include "compiler/translator/SymbolTable.h"
     12 #include "compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.h"
     13 #include "compiler/translator/tree_util/BuiltIn.h"
     14 #include "compiler/translator/tree_util/FindMain.h"
     15 #include "compiler/translator/tree_util/IntermNode_util.h"
     16 #include "compiler/translator/tree_util/IntermTraverse.h"
     17 
     18 namespace sh
     19 {
     20 namespace
     21 {
     22 constexpr static TBasicType DataTypeOfPLSType(TBasicType plsType)
     23 {
     24    switch (plsType)
     25    {
     26        case EbtPixelLocalANGLE:
     27            return EbtFloat;
     28        case EbtIPixelLocalANGLE:
     29            return EbtInt;
     30        case EbtUPixelLocalANGLE:
     31            return EbtUInt;
     32        default:
     33            UNREACHABLE();
     34            return EbtVoid;
     35    }
     36 }
     37 
     38 constexpr static TBasicType DataTypeOfImageType(TBasicType imageType)
     39 {
     40    switch (imageType)
     41    {
     42        case EbtImage2D:
     43            return EbtFloat;
     44        case EbtIImage2D:
     45            return EbtInt;
     46        case EbtUImage2D:
     47            return EbtUInt;
     48        default:
     49            UNREACHABLE();
     50            return EbtVoid;
     51    }
     52 }
     53 
     54 // Maps PLS symbols to a backing store.
     55 template <typename T>
     56 class PLSBackingStoreMap
     57 {
     58  public:
     59    // Sets the given variable as the backing storage for the plsSymbol's binding point. An entry
     60    // must not already exist in the map for this binding point.
     61    void insertNew(TIntermSymbol *plsSymbol, const T &backingStore)
     62    {
     63        ASSERT(plsSymbol);
     64        ASSERT(IsPixelLocal(plsSymbol->getBasicType()));
     65        int binding = plsSymbol->getType().getLayoutQualifier().binding;
     66        ASSERT(binding >= 0);
     67        auto result = mMap.insert({binding, backingStore});
     68        ASSERT(result.second);  // Ensure an image didn't already exist for this symbol.
     69    }
     70 
     71    // Looks up the backing store for the given plsSymbol's binding point. An entry must already
     72    // exist in the map for this binding point.
     73    const T &find(TIntermSymbol *plsSymbol)
     74    {
     75        ASSERT(plsSymbol);
     76        ASSERT(IsPixelLocal(plsSymbol->getBasicType()));
     77        int binding = plsSymbol->getType().getLayoutQualifier().binding;
     78        ASSERT(binding >= 0);
     79        auto iter = mMap.find(binding);
     80        ASSERT(iter != mMap.end());  // Ensure PLSImages already exist for this symbol.
     81        return iter->second;
     82    }
     83 
     84    const std::map<int, T> &bindingOrderedMap() const { return mMap; }
     85 
     86  private:
     87    // Use std::map so the backing stores are ordered by binding when we iterate.
     88    std::map<int, T> mMap;
     89 };
     90 
     91 // Base class for rewriting high level PLS operations to AST operations specified by
     92 // ShPixelLocalStorageType.
     93 class RewritePLSTraverser : public TIntermTraverser
     94 {
     95  public:
     96    RewritePLSTraverser(TCompiler *compiler,
     97                        TSymbolTable &symbolTable,
     98                        const ShCompileOptions &compileOptions,
     99                        int shaderVersion)
    100        : TIntermTraverser(true, false, false, &symbolTable),
    101          mCompiler(compiler),
    102          mCompileOptions(&compileOptions),
    103          mShaderVersion(shaderVersion)
    104    {}
    105 
    106    bool visitDeclaration(Visit, TIntermDeclaration *decl) override
    107    {
    108        TIntermTyped *declVariable = (decl->getSequence())->front()->getAsTyped();
    109        ASSERT(declVariable);
    110 
    111        if (!IsPixelLocal(declVariable->getBasicType()))
    112        {
    113            return true;
    114        }
    115 
    116        // PLS is not allowed in arrays.
    117        ASSERT(!declVariable->isArray());
    118 
    119        // This visitDeclaration doesn't get called for function arguments, and opaque types can
    120        // otherwise only be uniforms.
    121        ASSERT(declVariable->getQualifier() == EvqUniform);
    122 
    123        TIntermSymbol *plsSymbol = declVariable->getAsSymbolNode();
    124        ASSERT(plsSymbol);
    125 
    126        visitPLSDeclaration(plsSymbol);
    127 
    128        return false;
    129    }
    130 
    131    bool visitAggregate(Visit, TIntermAggregate *aggregate) override
    132    {
    133        if (!BuiltInGroup::IsPixelLocal(aggregate->getOp()))
    134        {
    135            return true;
    136        }
    137 
    138        const TIntermSequence &args = *aggregate->getSequence();
    139        ASSERT(args.size() >= 1);
    140        TIntermSymbol *plsSymbol = args[0]->getAsSymbolNode();
    141 
    142        // Rewrite pixelLocalLoadANGLE -> imageLoad.
    143        if (aggregate->getOp() == EOpPixelLocalLoadANGLE)
    144        {
    145            visitPLSLoad(plsSymbol);
    146            return false;  // No need to recurse since this node is being dropped.
    147        }
    148 
    149        // Rewrite pixelLocalStoreANGLE -> imageStore.
    150        if (aggregate->getOp() == EOpPixelLocalStoreANGLE)
    151        {
    152            // Also hoist the 'value' expression into a temp. In the event of
    153            // "pixelLocalStoreANGLE(..., pixelLocalLoadANGLE(...))", this ensures the load occurs
    154            // _before_ any potential barriers required by the subclass.
    155            //
    156            // NOTE: It is generally unsafe to hoist function arguments due to short circuiting,
    157            // e.g., "if (false && function(...))", but pixelLocalStoreANGLE returns type void, so
    158            // it is safe in this particular case.
    159            TType *valueType    = new TType(DataTypeOfPLSType(plsSymbol->getBasicType()),
    160                                            plsSymbol->getPrecision(), EvqTemporary, 4);
    161            TVariable *valueVar = CreateTempVariable(mSymbolTable, valueType);
    162            TIntermDeclaration *valueDecl =
    163                CreateTempInitDeclarationNode(valueVar, args[1]->getAsTyped());
    164            valueDecl->traverse(this);  // Rewrite any potential pixelLocalLoadANGLEs in valueDecl.
    165            insertStatementInParentBlock(valueDecl);
    166 
    167            visitPLSStore(plsSymbol, valueVar);
    168            return false;  // No need to recurse since this node is being dropped.
    169        }
    170 
    171        return true;
    172    }
    173 
    174    // Called after rewrite. Injects one-time setup code that needs to run before any PLS accesses.
    175    virtual void injectSetupCode(TCompiler *,
    176                                 TSymbolTable &,
    177                                 const ShCompileOptions &,
    178                                 TIntermBlock *mainBody,
    179                                 size_t plsBeginPosition)
    180    {}
    181 
    182    // Called after rewrite. Injects one-time finalization code that needs to run after all PLS.
    183    virtual void injectFinalizeCode(TCompiler *,
    184                                    TSymbolTable &,
    185                                    const ShCompileOptions &,
    186                                    TIntermBlock *mainBody,
    187                                    size_t plsEndPosition)
    188    {}
    189 
    190    TVariable *globalPixelCoord() const { return mGlobalPixelCoord; }
    191 
    192  protected:
    193    virtual void visitPLSDeclaration(TIntermSymbol *plsSymbol)             = 0;
    194    virtual void visitPLSLoad(TIntermSymbol *plsSymbol)                    = 0;
    195    virtual void visitPLSStore(TIntermSymbol *plsSymbol, TVariable *value) = 0;
    196 
    197    void ensureGlobalPixelCoordDeclared()
    198    {
    199        // Insert a global to hold the pixel coordinate as soon as we see PLS declared. This will be
    200        // initialized at the beginning of main().
    201        if (!mGlobalPixelCoord)
    202        {
    203            TType *coordType  = new TType(EbtInt, EbpHigh, EvqGlobal, 2);
    204            mGlobalPixelCoord = CreateTempVariable(mSymbolTable, coordType);
    205            insertStatementInParentBlock(CreateTempDeclarationNode(mGlobalPixelCoord));
    206        }
    207    }
    208 
    209    const TCompiler *const mCompiler;
    210    const ShCompileOptions *const mCompileOptions;
    211    const int mShaderVersion;
    212 
    213    // Stores the shader invocation's pixel coordinate as "ivec2(floor(gl_FragCoord.xy))".
    214    TVariable *mGlobalPixelCoord = nullptr;
    215 };
    216 
// Rewrites high level PLS operations to shader image operations.
//
// Each PLS uniform is replaced by an image2D uniform that keeps the PLS variable's unique id and
// name. pixelLocalLoadANGLE/pixelLocalStoreANGLE become imageLoad/imageStore at the coordinate held
// in mGlobalPixelCoord. When compileOptions.pls.type is ImageStoreR32PackedFormats, rgba8, rgba8i
// and rgba8ui planes are stored in r32ui/r32i images and packed/unpacked around each access.
class RewritePLSToImagesTraverser : public RewritePLSTraverser
{
  public:
    RewritePLSToImagesTraverser(TCompiler *compiler,
                                TSymbolTable &symbolTable,
                                const ShCompileOptions &compileOptions,
                                int shaderVersion)
        : RewritePLSTraverser(compiler, symbolTable, compileOptions, shaderVersion)
    {}

  private:
    // Replaces a PLS uniform declaration with an equivalent image2D declaration and records the
    // image in mImages so later loads/stores of this binding can find it.
    void visitPLSDeclaration(TIntermSymbol *plsSymbol) override
    {
        // Replace the PLS declaration with an image2D.
        ensureGlobalPixelCoordDeclared();
        TVariable *image2D = createPLSImageReplacement(plsSymbol);
        mImages.insertNew(plsSymbol, image2D);
        queueReplacement(new TIntermDeclaration({new TIntermSymbol(image2D)}),
                         OriginalNode::IS_DROPPED);
    }

    // Do all PLS formats need to be packed into r32f, r32i, or r32ui image2Ds?
    bool needsR32Packing() const
    {
        return mCompileOptions->pls.type == ShPixelLocalStorageType::ImageStoreR32PackedFormats;
    }

    // Creates an image2D that replaces a pixel local storage handle.
    //
    // The image type is a copy of the PLS type with:
    //   * the basic type switched to image2D / iimage2D / uimage2D to match the PLS format,
    //   * rgba8* formats repacked to r32ui / r32i (with highp precision) when needsR32Packing(),
    //   * rasterOrdered set when synchronizing with D3D rasterizer ordered views,
    //   * memory qualifiers "coherent restrict".
    // The returned variable reuses the PLS variable's uniqueId, name, symbolType and extensions.
    TVariable *createPLSImageReplacement(const TIntermSymbol *plsSymbol)
    {
        ASSERT(plsSymbol);
        ASSERT(IsPixelLocal(plsSymbol->getBasicType()));

        TType *imageType = new TType(plsSymbol->getType());

        TLayoutQualifier layoutQualifier = imageType->getLayoutQualifier();
        switch (layoutQualifier.imageInternalFormat)
        {
            case TLayoutImageInternalFormat::EiifRGBA8:
                if (needsR32Packing())
                {
                    layoutQualifier.imageInternalFormat = EiifR32UI;
                    imageType->setPrecision(EbpHigh);
                    imageType->setBasicType(EbtUImage2D);
                }
                else
                {
                    imageType->setBasicType(EbtImage2D);
                }
                break;
            case TLayoutImageInternalFormat::EiifRGBA8I:
                if (needsR32Packing())
                {
                    layoutQualifier.imageInternalFormat = EiifR32I;
                    imageType->setPrecision(EbpHigh);
                }
                imageType->setBasicType(EbtIImage2D);
                break;
            case TLayoutImageInternalFormat::EiifRGBA8UI:
                if (needsR32Packing())
                {
                    layoutQualifier.imageInternalFormat = EiifR32UI;
                    imageType->setPrecision(EbpHigh);
                }
                imageType->setBasicType(EbtUImage2D);
                break;
            case TLayoutImageInternalFormat::EiifR32F:
                imageType->setBasicType(EbtImage2D);
                break;
            case TLayoutImageInternalFormat::EiifR32UI:
                imageType->setBasicType(EbtUImage2D);
                break;
            default:
                UNREACHABLE();
        }
        // D3D ROVs provide the ordering themselves (no begin/end interlock calls are injected for
        // them in injectSetupCode/injectFinalizeCode), so the image must be declared raster-ordered.
        layoutQualifier.rasterOrdered = mCompileOptions->pls.fragmentSynchronizationType ==
                                        ShFragmentSynchronizationType::RasterizerOrderViews_D3D;
        imageType->setLayoutQualifier(layoutQualifier);

        TMemoryQualifier memoryQualifier{};
        memoryQualifier.coherent          = true;
        memoryQualifier.restrictQualifier = true;
        memoryQualifier.volatileQualifier = false;
        // TODO(anglebug.com/7279): Maybe we could walk the tree first and see which PLS is used
        // how. If the PLS is never loaded, we could add a writeonly qualifier, for example.
        memoryQualifier.readonly  = false;
        memoryQualifier.writeonly = false;
        imageType->setMemoryQualifier(memoryQualifier);

        const TVariable &plsVar = plsSymbol->variable();
        return new TVariable(plsVar.uniqueId(), plsVar.name(), plsVar.symbolType(),
                             plsVar.extensions(), imageType);
    }

    // Replaces a pixelLocalLoadANGLE call with "imageLoad(image, pixelCoord)", unpacked back to
    // the PLS format when the image is r32-packed.
    void visitPLSLoad(TIntermSymbol *plsSymbol) override
    {
        // Replace the pixelLocalLoadANGLE with imageLoad.
        TVariable *image2D = mImages.find(plsSymbol);
        ASSERT(mGlobalPixelCoord);
        TIntermTyped *pls = CreateBuiltInFunctionCallNode(
            "imageLoad", {new TIntermSymbol(image2D), new TIntermSymbol(mGlobalPixelCoord)},
            *mSymbolTable, 310);
        pls = unpackImageDataIfNecessary(pls, plsSymbol, image2D);
        queueReplacement(pls, OriginalNode::IS_DROPPED);
    }

    // Unpacks the raw PLS data if the output shader language needs r32* packing.
    //
    // 'data' is the imageLoad result. If the PLS and image formats match, it is returned unchanged.
    // Otherwise it is expanded from the single 32-bit value in its .r component back to 4 channels.
    TIntermTyped *unpackImageDataIfNecessary(TIntermTyped *data,
                                             TIntermSymbol *plsSymbol,
                                             TVariable *image2D)
    {
        TLayoutImageInternalFormat plsFormat =
            plsSymbol->getType().getLayoutQualifier().imageInternalFormat;
        TLayoutImageInternalFormat imageFormat =
            image2D->getType().getLayoutQualifier().imageInternalFormat;
        if (plsFormat == imageFormat)
        {
            return data;  // This PLS storage isn't packed.
        }
        ASSERT(needsR32Packing());
        switch (plsFormat)
        {
            case EiifRGBA8:
                // Unpack and normalize r,g,b,a from a single 32-bit unsigned int:
                //
                //     unpackUnorm4x8(data.r)
                //
                data = CreateBuiltInFunctionCallNode("unpackUnorm4x8", {CreateSwizzle(data, 0)},
                                                     *mSymbolTable, 310);
                break;
            case EiifRGBA8I:
            case EiifRGBA8UI:
            {
                constexpr unsigned shifts[] = {24, 16, 8, 0};
                // Unpack r,g,b,a from a single (signed or unsigned) 32-bit int. Shift left,
                // then right, to preserve the sign for ints. (highp integers are exactly
                // 32-bit, two's complement.)
                //
                //     data.rrrr << uvec4(24, 16, 8, 0) >> 24u
                //
                // This is the inverse of the "r | (g << 8) | (b << 16) | (a << 24)" packing in
                // clampAndPackPLSDataIfNecessary.
                data = CreateSwizzle(data, 0, 0, 0, 0);
                data = new TIntermBinary(EOpBitShiftLeft, data, CreateUVecNode(shifts, 4, EbpHigh));
                data = new TIntermBinary(EOpBitShiftRight, data, CreateUIntNode(24));
                break;
            }
            default:
                UNREACHABLE();
        }
        return data;
    }

    // Replaces a pixelLocalStoreANGLE call with "imageStore(image, pixelCoord, packedData)",
    // bracketed by memoryBarrierImage() calls. 'value' is the temp that visitAggregate hoisted the
    // stored expression into.
    void visitPLSStore(TIntermSymbol *plsSymbol, TVariable *value) override
    {
        TVariable *image2D       = mImages.find(plsSymbol);
        TIntermTyped *packedData = clampAndPackPLSDataIfNecessary(value, plsSymbol, image2D);

        // Surround the store with memoryBarrierImage calls in order to ensure dependent stores and
        // loads in a single shader invocation are coherent. From the ES 3.1 spec:
        //
        //   Using variables declared as "coherent" guarantees only that the results of stores will
        //   be immediately visible to shader invocations using similarly-declared variables;
        //   calling MemoryBarrier is required to ensure that the stores are visible to other
        //   operations.
        //
        insertStatementsInParentBlock(
            {CreateBuiltInFunctionCallNode("memoryBarrierImage", {}, *mSymbolTable,
                                           310)},  // Before.
            {CreateBuiltInFunctionCallNode("memoryBarrierImage", {}, *mSymbolTable,
                                           310)});  // After.

        // Rewrite the pixelLocalStoreANGLE with imageStore.
        ASSERT(mGlobalPixelCoord);
        queueReplacement(
            CreateBuiltInFunctionCallNode(
                "imageStore",
                {new TIntermSymbol(image2D), new TIntermSymbol(mGlobalPixelCoord), packedData},
                *mSymbolTable, 310),
            OriginalNode::IS_DROPPED);
    }

    // Packs the PLS to raw data if the output shader language needs r32* packing.
    //
    // 'plsVar' is the temp holding the value being stored (not the PLS handle). For rgba8i and
    // rgba8ui formats this first inserts statements that clamp 'plsVar' in place. Returns either a
    // plain reference to 'plsVar' (image format matches the PLS format) or the packed 32-bit value
    // wrapped in a {u,i}vec4 constructor suitable for imageStore.
    TIntermTyped *clampAndPackPLSDataIfNecessary(TVariable *plsVar,
                                                 TIntermSymbol *plsSymbol,
                                                 TVariable *image2D)
    {
        TLayoutImageInternalFormat plsFormat =
            plsSymbol->getType().getLayoutQualifier().imageInternalFormat;
        // anglebug.com/7524: Storing to integer formats with values larger than can be represented
        // is specified differently on different APIs. Clamp integer formats here to make it uniform
        // and more GL-like.
        switch (plsFormat)
        {
            case EiifRGBA8I:
            {
                // Clamp r,g,b,a to their min/max 8-bit values:
                //
                //     plsVar = clamp(plsVar, -128, 127)
                //
                // (Masking off the extra sign bits with "& 0xff" only happens below, when the
                // value gets packed into an r32i image.)
                TIntermTyped *newPLSValue = CreateBuiltInFunctionCallNode(
                    "clamp",
                    {new TIntermSymbol(plsVar), CreateIndexNode(-128), CreateIndexNode(127)},
                    *mSymbolTable, mShaderVersion);
                insertStatementInParentBlock(CreateTempAssignmentNode(plsVar, newPLSValue));
                break;
            }
            case EiifRGBA8UI:
            {
                // Clamp r,g,b,a to their max 8-bit values:
                //
                //     plsVar = min(plsVar, 255)
                //
                TIntermTyped *newPLSValue = CreateBuiltInFunctionCallNode(
                    "min", {new TIntermSymbol(plsVar), CreateUIntNode(255)}, *mSymbolTable,
                    mShaderVersion);
                insertStatementInParentBlock(CreateTempAssignmentNode(plsVar, newPLSValue));
                break;
            }
            default:
                break;
        }
        TIntermTyped *result = new TIntermSymbol(plsVar);
        TLayoutImageInternalFormat imageFormat =
            image2D->getType().getLayoutQualifier().imageInternalFormat;
        if (plsFormat == imageFormat)
        {
            return result;  // This PLS storage isn't packed.
        }
        ASSERT(needsR32Packing());
        switch (plsFormat)
        {
            case EiifRGBA8:
            {
                if (mCompileOptions->passHighpToPackUnormSnormBuiltins)
                {
                    // anglebug.com/7527: packUnorm4x8 doesn't work on Pixel 4 when passed
                    // a mediump vec4. Use an intermediate highp vec4.
                    //
                    // It's safe to inject a variable here because it happens right before
                    // pixelLocalStoreANGLE, which returns type void. (See visitAggregate.)
                    TType *highpType              = new TType(EbtFloat, EbpHigh, EvqTemporary, 4);
                    TVariable *workaroundHighpVar = CreateTempVariable(mSymbolTable, highpType);
                    insertStatementInParentBlock(
                        CreateTempInitDeclarationNode(workaroundHighpVar, result));
                    result = new TIntermSymbol(workaroundHighpVar);
                }

                // Denormalize and pack r,g,b,a into a single 32-bit unsigned int:
                //
                //     packUnorm4x8(result)
                //
                // where 'result' is either plsVar or the highp workaround copy above.
                result =
                    CreateBuiltInFunctionCallNode("packUnorm4x8", {result}, *mSymbolTable, 310);
                break;
            }
            case EiifRGBA8I:
            case EiifRGBA8UI:
            {
                if (plsFormat == EiifRGBA8I)
                {
                    // Mask off extra sign bits beyond 8.
                    //
                    //     plsVar &= 0xff
                    //
                    insertStatementInParentBlock(new TIntermBinary(
                        EOpBitwiseAndAssign, new TIntermSymbol(plsVar), CreateIndexNode(0xff)));
                }
                // Pack r,g,b,a into a single 32-bit (signed or unsigned) int:
                //
                //     r | (g << 8) | (b << 16) | (a << 24)
                //
                auto shiftComponent = [=](int componentIdx) {
                    return new TIntermBinary(EOpBitShiftLeft,
                                             CreateSwizzle(new TIntermSymbol(plsVar), componentIdx),
                                             CreateUIntNode(componentIdx * 8));
                };
                result = CreateSwizzle(result, 0);
                result = new TIntermBinary(EOpBitwiseOr, result, shiftComponent(1));
                result = new TIntermBinary(EOpBitwiseOr, result, shiftComponent(2));
                result = new TIntermBinary(EOpBitwiseOr, result, shiftComponent(3));
                break;
            }
            default:
                UNREACHABLE();
        }
        // Convert the packed data to a {u,i}vec4 for imageStore.
        TImageStoreType note: the scalar type comes from the image's basic type via DataTypeOfImageType.
        TType imageStoreType(DataTypeOfImageType(image2D->getType().getBasicType()), 4);
        return TIntermAggregate::CreateConstructor(imageStoreType, {result});
    }

    // Requests early fragment tests and inserts the "begin" call of the configured fragment
    // synchronization extension (if any) at plsBeginPosition in main's body.
    void injectSetupCode(TCompiler *compiler,
                         TSymbolTable &symbolTable,
                         const ShCompileOptions &compileOptions,
                         TIntermBlock *mainBody,
                         size_t plsBeginPosition) override
    {
        // When PLS is implemented with images, early_fragment_tests ensure that depth/stencil
        // can also block stores to PLS.
        compiler->specifyEarlyFragmentTests();

        // Delimit the beginning of a per-pixel critical section, if supported. This makes pixel
        // local storage coherent.
        //
        // Either: GL_NV_fragment_shader_interlock
        //         GL_INTEL_fragment_shader_ordering
        //         GL_ARB_fragment_shader_interlock (may compile to
        //                                           SPV_EXT_fragment_shader_interlock)
        switch (compileOptions.pls.fragmentSynchronizationType)
        {
            // ROVs don't need explicit synchronization calls.
            case ShFragmentSynchronizationType::RasterizerOrderViews_D3D:
            case ShFragmentSynchronizationType::NotSupported:
                break;
            case ShFragmentSynchronizationType::FragmentShaderInterlock_NV_GL:
                mainBody->insertStatement(
                    plsBeginPosition,
                    CreateBuiltInFunctionCallNode("beginInvocationInterlockNV", {}, symbolTable,
                                                  kESSLInternalBackendBuiltIns));
                break;
            case ShFragmentSynchronizationType::FragmentShaderOrdering_INTEL_GL:
                mainBody->insertStatement(
                    plsBeginPosition,
                    CreateBuiltInFunctionCallNode("beginFragmentShaderOrderingINTEL", {},
                                                  symbolTable, kESSLInternalBackendBuiltIns));
                break;
            case ShFragmentSynchronizationType::FragmentShaderInterlock_ARB_GL:
                mainBody->insertStatement(
                    plsBeginPosition,
                    CreateBuiltInFunctionCallNode("beginInvocationInterlockARB", {}, symbolTable,
                                                  kESSLInternalBackendBuiltIns));
                break;
            default:
                UNREACHABLE();
        }
    }

    // Inserts the matching "end" call of the configured fragment synchronization extension (if
    // it has one) at plsEndPosition in main's body.
    void injectFinalizeCode(TCompiler *,
                            TSymbolTable &symbolTable,
                            const ShCompileOptions &compileOptions,
                            TIntermBlock *mainBody,
                            size_t plsEndPosition) override
    {
        // Delimit the end of the PLS critical section, if required.
        //
        // Either: GL_NV_fragment_shader_interlock
        //         GL_ARB_fragment_shader_interlock (may compile to
        //                                           SPV_EXT_fragment_shader_interlock)
        switch (compileOptions.pls.fragmentSynchronizationType)
        {
            // ROVs don't need explicit synchronization calls.
            case ShFragmentSynchronizationType::RasterizerOrderViews_D3D:
            // GL_INTEL_fragment_shader_ordering doesn't have an "end()" call.
            case ShFragmentSynchronizationType::FragmentShaderOrdering_INTEL_GL:
            case ShFragmentSynchronizationType::NotSupported:
                break;
            case ShFragmentSynchronizationType::FragmentShaderInterlock_NV_GL:

                mainBody->insertStatement(
                    plsEndPosition,
                    CreateBuiltInFunctionCallNode("endInvocationInterlockNV", {}, symbolTable,
                                                  kESSLInternalBackendBuiltIns));
                break;
            case ShFragmentSynchronizationType::FragmentShaderInterlock_ARB_GL:
                mainBody->insertStatement(
                    plsEndPosition,
                    CreateBuiltInFunctionCallNode("endInvocationInterlockARB", {}, symbolTable,
                                                  kESSLInternalBackendBuiltIns));
                break;
            default:
                UNREACHABLE();
        }
    }

    // Image2D replacement for each PLS binding, created in visitPLSDeclaration.
    PLSBackingStoreMap<TVariable *> mImages;
};
    592 
    593 // Rewrites high level PLS operations to framebuffer fetch operations.
    594 class RewritePLSToFramebufferFetchTraverser : public RewritePLSTraverser
    595 {
    596  public:
    597    RewritePLSToFramebufferFetchTraverser(TCompiler *compiler,
    598                                          TSymbolTable &symbolTable,
    599                                          const ShCompileOptions &compileOptions,
    600                                          int shaderVersion)
    601        : RewritePLSTraverser(compiler, symbolTable, compileOptions, shaderVersion)
    602    {}
    603 
    604    void visitPLSDeclaration(TIntermSymbol *plsSymbol) override
    605    {
    606        // Replace the PLS declaration with a framebuffer attachment.
    607        PLSAttachment attachment(mCompiler, mSymbolTable, *mCompileOptions, plsSymbol->variable());
    608        mPLSAttachments.insertNew(plsSymbol, attachment);
    609        insertStatementInParentBlock(
    610            new TIntermDeclaration({new TIntermSymbol(attachment.fragmentVar)}));
    611        queueReplacement(CreateTempDeclarationNode(attachment.accessVar), OriginalNode::IS_DROPPED);
    612    }
    613 
    614    void visitPLSLoad(TIntermSymbol *plsSymbol) override
    615    {
    616        // Read our temporary accessVar.
    617        const PLSAttachment &attachment = mPLSAttachments.find(plsSymbol);
    618        queueReplacement(attachment.expandAccessVar(), OriginalNode::IS_DROPPED);
    619    }
    620 
    621    void visitPLSStore(TIntermSymbol *plsSymbol, TVariable *value) override
    622    {
    623        // Set our temporary accessVar.
    624        const PLSAttachment &attachment = mPLSAttachments.find(plsSymbol);
    625        queueReplacement(CreateTempAssignmentNode(attachment.accessVar, attachment.swizzle(value)),
    626                         OriginalNode::IS_DROPPED);
    627    }
    628 
    629    void injectSetupCode(TCompiler *compiler,
    630                         TSymbolTable &symbolTable,
    631                         const ShCompileOptions &compileOptions,
    632                         TIntermBlock *mainBody,
    633                         size_t plsBeginPosition) override
    634    {
    635        // [OpenGL ES Version 3.0.6, 3.9.2.3 "Shader Output"]: Any colors, or color components,
    636        // associated with a fragment that are not written by the fragment shader are undefined.
    637        //
    638        // [EXT_shader_framebuffer_fetch]: Prior to fragment shading, fragment outputs declared
    639        // inout are populated with the value last written to the framebuffer at the same(x, y,
    640        // sample) position.
    641        //
    642        // It's unclear from the EXT_shader_framebuffer_fetch spec whether inout fragment variables
    643        // become undefined if not explicitly written, but either way, when this compiles to subpass
    644        // loads in Vulkan, we definitely get undefined behavior if PLS variables are not written.
    645        //
    646        // To make sure every PLS variable gets written, we read them all before PLS operations,
    647        // then write them all back out after all PLS is complete.
    648        std::vector<TIntermNode *> plsPreloads;
    649        plsPreloads.reserve(mPLSAttachments.bindingOrderedMap().size());
    650        for (const auto &entry : mPLSAttachments.bindingOrderedMap())
    651        {
    652            const PLSAttachment &attachment = entry.second;
    653            plsPreloads.push_back(
    654                CreateTempAssignmentNode(attachment.accessVar, attachment.swizzleFragmentVar()));
    655        }
    656        mainBody->getSequence()->insert(mainBody->getSequence()->begin() + plsBeginPosition,
    657                                        plsPreloads.begin(), plsPreloads.end());
    658    }
    659 
    660    void injectFinalizeCode(TCompiler *,
    661                            TSymbolTable &symbolTable,
    662                            const ShCompileOptions &compileOptions,
    663                            TIntermBlock *mainBody,
    664                            size_t plsEndPosition) override
    665    {
    666        std::vector<TIntermNode *> plsWrites;
    667        plsWrites.reserve(mPLSAttachments.bindingOrderedMap().size());
    668        for (const auto &entry : mPLSAttachments.bindingOrderedMap())
    669        {
    670            const PLSAttachment &attachment = entry.second;
    671            plsWrites.push_back(new TIntermBinary(EOpAssign, attachment.swizzleFragmentVar(),
    672                                                  new TIntermSymbol(attachment.accessVar)));
    673        }
    674        mainBody->getSequence()->insert(mainBody->getSequence()->begin() + plsEndPosition,
    675                                        plsWrites.begin(), plsWrites.end());
    676    }
    677 
    678  private:
    679    struct PLSAttachment
    680    {
    681        PLSAttachment(const TCompiler *compiler,
    682                      TSymbolTable *symbolTable,
    683                      const ShCompileOptions &compileOptions,
    684                      const TVariable &plsVar)
    685        {
    686            const TType &plsType = plsVar.getType();
    687 
    688            TType *accessVarType;
    689            switch (plsType.getLayoutQualifier().imageInternalFormat)
    690            {
    691                default:
    692                    UNREACHABLE();
    693                    [[fallthrough]];
    694                case EiifRGBA8:
    695                    accessVarType = new TType(EbtFloat, 4);
    696                    break;
    697                case EiifRGBA8I:
    698                    accessVarType = new TType(EbtInt, 4);
    699                    break;
    700                case EiifRGBA8UI:
    701                    accessVarType = new TType(EbtUInt, 4);
    702                    break;
    703                case EiifR32F:
    704                    accessVarType = new TType(EbtFloat, 1);
    705                    break;
    706                case EiifR32UI:
    707                    accessVarType = new TType(EbtUInt, 1);
    708                    break;
    709            }
    710            accessVarType->setPrecision(plsType.getPrecision());
    711            accessVar = CreateTempVariable(symbolTable, accessVarType);
    712 
    713            // Qualcomm seems to want fragment outputs to be 4-component vectors, and produces a
    714            // compile error from "inout uint". Our Metal translator also saturates color outputs to
    715            // 4 components. And since the spec also seems silent on how many components an output
    716            // must have, we always use 4.
    717            TType *fragmentVarType = new TType(accessVarType->getBasicType(), 4);
    718            fragmentVarType->setPrecision(plsType.getPrecision());
    719            fragmentVarType->setQualifier(EvqFragmentInOut);
    720 
    721            // PLS attachments are bound in reverse order from the rear.
    722            TLayoutQualifier layoutQualifier = TLayoutQualifier::Create();
    723            layoutQualifier.location =
    724                compiler->getResources().MaxCombinedDrawBuffersAndPixelLocalStoragePlanes -
    725                plsType.getLayoutQualifier().binding - 1;
    726            layoutQualifier.locationsSpecified = 1;
    727            if (compileOptions.pls.fragmentSynchronizationType ==
    728                ShFragmentSynchronizationType::NotSupported)
    729            {
    730                // We're using EXT_shader_framebuffer_fetch_non_coherent, which requires the
    731                // "noncoherent" qualifier.
    732                layoutQualifier.noncoherent = true;
    733            }
    734            fragmentVarType->setLayoutQualifier(layoutQualifier);
    735 
    736            fragmentVar = new TVariable(plsVar.uniqueId(), plsVar.name(), plsVar.symbolType(),
    737                                        plsVar.extensions(), fragmentVarType);
    738        }
    739 
    740        // Expands our accessVar to 4 components, regardless of the size of the pixel local storage
    741        // internalformat.
    742        TIntermTyped *expandAccessVar() const
    743        {
    744            TIntermTyped *expanded = new TIntermSymbol(accessVar);
    745            if (accessVar->getType().getNominalSize() == 1)
    746            {
    747                switch (accessVar->getType().getBasicType())
    748                {
    749                    case EbtFloat:
    750                        expanded = TIntermAggregate::CreateConstructor(  // "vec4(r, 0, 0, 1)"
    751                            TType(EbtFloat, 4),
    752                            {expanded, CreateFloatNode(0, EbpHigh), CreateFloatNode(0, EbpHigh),
    753                             CreateFloatNode(1, EbpHigh)});
    754                        break;
    755                    case EbtUInt:
    756                        expanded = TIntermAggregate::CreateConstructor(  // "uvec4(r, 0, 0, 1)"
    757                            TType(EbtUInt, 4),
    758                            {expanded, CreateUIntNode(0), CreateUIntNode(0), CreateUIntNode(1)});
    759                        break;
    760                    default:
    761                        UNREACHABLE();
    762                        break;
    763                }
    764            }
    765            return expanded;
    766        }
    767 
    768        // Swizzles a variable down to the same number of components as the PLS internalformat.
    769        TIntermTyped *swizzle(TVariable *var) const
    770        {
    771            TIntermTyped *swizzled = new TIntermSymbol(var);
    772            if (var->getType().getNominalSize() != accessVar->getType().getNominalSize())
    773            {
    774                ASSERT(var->getType().getNominalSize() > accessVar->getType().getNominalSize());
    775                TVector swizzleOffsets{0, 1, 2, 3};
    776                swizzleOffsets.resize(accessVar->getType().getNominalSize());
    777                swizzled = new TIntermSwizzle(swizzled, swizzleOffsets);
    778            }
    779            return swizzled;
    780        }
    781 
    782        TIntermTyped *swizzleFragmentVar() const { return swizzle(fragmentVar); }
    783 
    784        TVariable *fragmentVar;
    785        TVariable *accessVar;
    786    };
    787 
    788    PLSBackingStoreMap<PLSAttachment> mPLSAttachments;
    789 };
    790 }  // anonymous namespace
    791 
    792 bool RewritePixelLocalStorage(TCompiler *compiler,
    793                              TIntermBlock *root,
    794                              TSymbolTable &symbolTable,
    795                              const ShCompileOptions &compileOptions,
    796                              int shaderVersion)
    797 {
    798    // If any functions take PLS arguments, monomorphize the functions by removing said parameters
    799    // and making the PLS calls from main() instead, using the global uniform from the call site
    800    // instead of the function argument. This is necessary because function arguments don't carry
    801    // the necessary "binding" or "format" layout qualifiers.
    802    if (!MonomorphizeUnsupportedFunctions(
    803            compiler, root, &symbolTable, compileOptions,
    804            UnsupportedFunctionArgsBitSet{UnsupportedFunctionArgs::PixelLocalStorage}))
    805    {
    806        return false;
    807    }
    808 
    809    TIntermBlock *mainBody = FindMainBody(root);
    810 
    811    std::unique_ptr<RewritePLSTraverser> traverser;
    812    switch (compileOptions.pls.type)
    813    {
    814        case ShPixelLocalStorageType::ImageStoreR32PackedFormats:
    815        case ShPixelLocalStorageType::ImageStoreNativeFormats:
    816            traverser = std::make_unique<RewritePLSToImagesTraverser>(
    817                compiler, symbolTable, compileOptions, shaderVersion);
    818            break;
    819        case ShPixelLocalStorageType::FramebufferFetch:
    820            traverser = std::make_unique<RewritePLSToFramebufferFetchTraverser>(
    821                compiler, symbolTable, compileOptions, shaderVersion);
    822            break;
    823        default:
    824            UNREACHABLE();
    825            return false;
    826    }
    827 
    828    // Rewrite PLS operations to image operations.
    829    root->traverse(traverser.get());
    830    if (!traverser->updateTree(compiler, root))
    831    {
    832        return false;
    833    }
    834 
    835    // Inject the code that needs to run before and after all PLS operations.
    836    // TODO(anglebug.com/7279): Inject these functions in a tight critical section, instead of
    837    // just locking the entire main() function:
    838    //   - Monomorphize all PLS calls into main().
    839    //   - Insert begin/end calls around the first/last PLS calls (and outside of flow control).
    840    traverser->injectSetupCode(compiler, symbolTable, compileOptions, mainBody, 0);
    841    traverser->injectFinalizeCode(compiler, symbolTable, compileOptions, mainBody,
    842                                  mainBody->getChildCount());
    843 
    844    if (traverser->globalPixelCoord())
    845    {
    846        // Initialize the global pixel coord at the beginning of main():
    847        //
    848        //     pixelCoord = ivec2(floor(gl_FragCoord.xy));
    849        //
    850        TIntermTyped *exp;
    851        exp = ReferenceBuiltInVariable(ImmutableString("gl_FragCoord"), symbolTable, shaderVersion);
    852        exp = CreateSwizzle(exp, 0, 1);
    853        exp = CreateBuiltInFunctionCallNode("floor", {exp}, symbolTable, shaderVersion);
    854        exp = TIntermAggregate::CreateConstructor(TType(EbtInt, 2), {exp});
    855        exp = CreateTempAssignmentNode(traverser->globalPixelCoord(), exp);
    856        mainBody->insertStatement(0, exp);
    857    }
    858 
    859    return compiler->validateAST(root);
    860 }
    861 }  // namespace sh