tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

RewriteAtomicCounters.cpp (12342B)


      1 //
      2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
      3 // Use of this source code is governed by a BSD-style license that can be
      4 // found in the LICENSE file.
      5 //
      6 // RewriteAtomicCounters: Emulate atomic counter buffers with storage buffers.
      7 //
      8 
      9 #include "compiler/translator/tree_ops/RewriteAtomicCounters.h"
     10 
     11 #include "compiler/translator/Compiler.h"
     12 #include "compiler/translator/ImmutableStringBuilder.h"
     13 #include "compiler/translator/SymbolTable.h"
     14 #include "compiler/translator/tree_util/IntermNode_util.h"
     15 #include "compiler/translator/tree_util/IntermTraverse.h"
     16 #include "compiler/translator/tree_util/ReplaceVariable.h"
     17 
     18 namespace sh
     19 {
     20 namespace
     21 {
     22 constexpr ImmutableString kAtomicCountersVarName  = ImmutableString("atomicCounters");
     23 constexpr ImmutableString kAtomicCounterFieldName = ImmutableString("counters");
     24 
     25 // DeclareAtomicCountersBuffer adds a storage buffer array that's used with atomic counters.
     26 const TVariable *DeclareAtomicCountersBuffers(TIntermBlock *root, TSymbolTable *symbolTable)
     27 {
     28    // Define `uint counters[];` as the only field in the interface block.
     29    TFieldList *fieldList = new TFieldList;
     30    TType *counterType    = new TType(EbtUInt, EbpHigh, EvqGlobal);
     31    counterType->makeArray(0);
     32 
     33    TField *countersField =
     34        new TField(counterType, kAtomicCounterFieldName, TSourceLoc(), SymbolType::AngleInternal);
     35 
     36    fieldList->push_back(countersField);
     37 
     38    TMemoryQualifier coherentMemory = TMemoryQualifier::Create();
     39    coherentMemory.coherent         = true;
     40 
     41    // There are a maximum of 8 atomic counter buffers per IMPLEMENTATION_MAX_ATOMIC_COUNTER_BUFFERS
     42    // in libANGLE/Constants.h.
     43    constexpr uint32_t kMaxAtomicCounterBuffers = 8;
     44 
     45    // Define a storage block "ANGLEAtomicCounters" with instance name "atomicCounters".
     46    TLayoutQualifier layoutQualifier = TLayoutQualifier::Create();
     47    layoutQualifier.blockStorage     = EbsStd430;
     48 
     49    return DeclareInterfaceBlock(root, symbolTable, fieldList, EvqBuffer, layoutQualifier,
     50                                 coherentMemory, kMaxAtomicCounterBuffers,
     51                                 ImmutableString(vk::kAtomicCountersBlockName),
     52                                 kAtomicCountersVarName);
     53 }
     54 
     55 TIntermTyped *CreateUniformBufferOffset(const TIntermTyped *uniformBufferOffsets, int binding)
     56 {
     57    // Each uint in the |acbBufferOffsets| uniform contains offsets for 4 bindings.  Therefore, the
     58    // expression to get the uniform offset for the binding is:
     59    //
     60    //     acbBufferOffsets[binding / 4] >> ((binding % 4) * 8) & 0xFF
     61 
     62    // acbBufferOffsets[binding / 4]
     63    TIntermBinary *uniformBufferOffsetUint = new TIntermBinary(
     64        EOpIndexDirect, uniformBufferOffsets->deepCopy(), CreateIndexNode(binding / 4));
     65 
     66    // acbBufferOffsets[binding / 4] >> ((binding % 4) * 8)
     67    TIntermBinary *uniformBufferOffsetShifted = uniformBufferOffsetUint;
     68    if (binding % 4 != 0)
     69    {
     70        uniformBufferOffsetShifted = new TIntermBinary(EOpBitShiftRight, uniformBufferOffsetUint,
     71                                                       CreateUIntNode((binding % 4) * 8));
     72    }
     73 
     74    // acbBufferOffsets[binding / 4] >> ((binding % 4) * 8) & 0xFF
     75    return new TIntermBinary(EOpBitwiseAnd, uniformBufferOffsetShifted, CreateUIntNode(0xFF));
     76 }
     77 
     78 TIntermBinary *CreateAtomicCounterRef(TIntermTyped *atomicCounterExpression,
     79                                      const TVariable *atomicCounters,
     80                                      const TIntermTyped *uniformBufferOffsets)
     81 {
     82    // The atomic counters storage buffer declaration looks as such:
     83    //
     84    // layout(...) buffer ANGLEAtomicCounters
     85    // {
     86    //     uint counters[];
     87    // } atomicCounters[N];
     88    //
     89    // Where N is large enough to accommodate atomic counter buffer bindings used in the shader.
     90    //
     91    // This function takes an expression that uses an atomic counter, which can either be:
     92    //
     93    //  - ac
     94    //  - acArray[index]
     95    //
     96    // Note that RewriteArrayOfArrayOfOpaqueUniforms has already flattened array of array of atomic
     97    // counters.
     98    //
     99    // For the first case (ac), the following code is generated:
    100    //
    101    //     atomicCounters[binding].counters[offset]
    102    //
    103    // For the second case (acArray[index]), the following code is generated:
    104    //
    105    //     atomicCounters[binding].counters[offset + index]
    106    //
    107    // In either case, an offset given through uniforms is also added to |offset|.  The binding is
    108    // necessarily a constant thanks to MonomorphizeUnsupportedFunctions.
    109 
    110    // First determine if there's an index, and extract the atomic counter symbol out of the
    111    // expression.
    112    TIntermSymbol *atomicCounterSymbol = atomicCounterExpression->getAsSymbolNode();
    113    TIntermTyped *atomicCounterIndex   = nullptr;
    114    int atomicCounterConstIndex        = 0;
    115    TIntermBinary *asBinary            = atomicCounterExpression->getAsBinaryNode();
    116    if (asBinary != nullptr)
    117    {
    118        atomicCounterSymbol = asBinary->getLeft()->getAsSymbolNode();
    119 
    120        switch (asBinary->getOp())
    121        {
    122            case EOpIndexDirect:
    123                atomicCounterConstIndex = asBinary->getRight()->getAsConstantUnion()->getIConst(0);
    124                break;
    125            case EOpIndexIndirect:
    126                atomicCounterIndex = asBinary->getRight();
    127                break;
    128            default:
    129                UNREACHABLE();
    130        }
    131    }
    132 
    133    // Extract binding and offset information out of the atomic counter symbol.
    134    ASSERT(atomicCounterSymbol);
    135    const TVariable *atomicCounterVar = &atomicCounterSymbol->variable();
    136    const TType &atomicCounterType    = atomicCounterVar->getType();
    137 
    138    const int binding = atomicCounterType.getLayoutQualifier().binding;
    139    int offset        = atomicCounterType.getLayoutQualifier().offset / 4;
    140 
    141    // Create the expression:
    142    //
    143    //     offset + arrayIndex + uniformOffset
    144    //
    145    // If arrayIndex is a constant, it's added with offset right here.
    146 
    147    offset += atomicCounterConstIndex;
    148 
    149    TIntermTyped *index = CreateUniformBufferOffset(uniformBufferOffsets, binding);
    150    if (atomicCounterIndex != nullptr)
    151    {
    152        index = new TIntermBinary(EOpAdd, index, atomicCounterIndex);
    153    }
    154    if (offset != 0)
    155    {
    156        index = new TIntermBinary(EOpAdd, index, CreateIndexNode(offset));
    157    }
    158 
    159    // Finally, create the complete expression:
    160    //
    161    //     atomicCounters[binding].counters[index]
    162 
    163    TIntermSymbol *atomicCountersRef = new TIntermSymbol(atomicCounters);
    164 
    165    // atomicCounters[binding]
    166    TIntermBinary *countersBlock =
    167        new TIntermBinary(EOpIndexDirect, atomicCountersRef, CreateIndexNode(binding));
    168 
    169    // atomicCounters[binding].counters
    170    TIntermBinary *counters =
    171        new TIntermBinary(EOpIndexDirectInterfaceBlock, countersBlock, CreateIndexNode(0));
    172 
    173    return new TIntermBinary(EOpIndexIndirect, counters, index);
    174 }
    175 
    176 // Traverser that:
    177 //
    178 // 1. Removes the |uniform atomic_uint| declarations and remembers the binding and offset.
    179 // 2. Substitutes |atomicVar[n]| with |buffer[binding].counters[offset + n]|.
    180 class RewriteAtomicCountersTraverser : public TIntermTraverser
    181 {
    182  public:
    183    RewriteAtomicCountersTraverser(TSymbolTable *symbolTable,
    184                                   const TVariable *atomicCounters,
    185                                   const TIntermTyped *acbBufferOffsets)
    186        : TIntermTraverser(true, false, false, symbolTable),
    187          mAtomicCounters(atomicCounters),
    188          mAcbBufferOffsets(acbBufferOffsets)
    189    {}
    190 
    191    bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
    192    {
    193        if (!mInGlobalScope)
    194        {
    195            return true;
    196        }
    197 
    198        const TIntermSequence &sequence = *(node->getSequence());
    199 
    200        TIntermTyped *variable = sequence.front()->getAsTyped();
    201        const TType &type      = variable->getType();
    202        bool isAtomicCounter   = type.isAtomicCounter();
    203 
    204        if (isAtomicCounter)
    205        {
    206            ASSERT(type.getQualifier() == EvqUniform);
    207            TIntermSequence emptySequence;
    208            mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
    209                                            std::move(emptySequence));
    210 
    211            return false;
    212        }
    213 
    214        return true;
    215    }
    216 
    217    bool visitAggregate(Visit visit, TIntermAggregate *node) override
    218    {
    219        if (BuiltInGroup::IsBuiltIn(node->getOp()))
    220        {
    221            bool converted = convertBuiltinFunction(node);
    222            return !converted;
    223        }
    224 
    225        // AST functions don't require modification as atomic counter function parameters are
    226        // removed by MonomorphizeUnsupportedFunctions.
    227        return true;
    228    }
    229 
    230    void visitSymbol(TIntermSymbol *symbol) override
    231    {
    232        // Cannot encounter the atomic counter symbol directly.  It can only be used with functions,
    233        // and therefore it's handled by visitAggregate.
    234        ASSERT(!symbol->getType().isAtomicCounter());
    235    }
    236 
    237    bool visitBinary(Visit visit, TIntermBinary *node) override
    238    {
    239        // Cannot encounter an atomic counter expression directly.  It can only be used with
    240        // functions, and therefore it's handled by visitAggregate.
    241        ASSERT(!node->getType().isAtomicCounter());
    242        return true;
    243    }
    244 
    245  private:
    246    bool convertBuiltinFunction(TIntermAggregate *node)
    247    {
    248        const TOperator op = node->getOp();
    249 
    250        // If the function is |memoryBarrierAtomicCounter|, simply replace it with
    251        // |memoryBarrierBuffer|.
    252        if (op == EOpMemoryBarrierAtomicCounter)
    253        {
    254            TIntermSequence emptySequence;
    255            TIntermTyped *substituteCall = CreateBuiltInFunctionCallNode(
    256                "memoryBarrierBuffer", &emptySequence, *mSymbolTable, 310);
    257            queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
    258            return true;
    259        }
    260 
    261        // If it's an |atomicCounter*| function, replace the function with an |atomic*| equivalent.
    262        if (!node->getFunction()->isAtomicCounterFunction())
    263        {
    264            return false;
    265        }
    266 
    267        // Note: atomicAdd(0) is used for atomic reads.
    268        uint32_t valueChange                = 0;
    269        constexpr char kAtomicAddFunction[] = "atomicAdd";
    270        bool isDecrement                    = false;
    271 
    272        if (op == EOpAtomicCounterIncrement)
    273        {
    274            valueChange = 1;
    275        }
    276        else if (op == EOpAtomicCounterDecrement)
    277        {
    278            // uint values are required to wrap around, so 0xFFFFFFFFu is used as -1.
    279            valueChange = std::numeric_limits<uint32_t>::max();
    280            static_assert(static_cast<uint32_t>(-1) == std::numeric_limits<uint32_t>::max(),
    281                          "uint32_t max is not -1");
    282 
    283            isDecrement = true;
    284        }
    285        else
    286        {
    287            ASSERT(op == EOpAtomicCounter);
    288        }
    289 
    290        TIntermTyped *param = (*node->getSequence())[0]->getAsTyped();
    291 
    292        TIntermSequence substituteArguments;
    293        substituteArguments.push_back(
    294            CreateAtomicCounterRef(param, mAtomicCounters, mAcbBufferOffsets));
    295        substituteArguments.push_back(CreateUIntNode(valueChange));
    296 
    297        TIntermTyped *substituteCall = CreateBuiltInFunctionCallNode(
    298            kAtomicAddFunction, &substituteArguments, *mSymbolTable, 310);
    299 
    300        // Note that atomicCounterDecrement returns the *new* value instead of the prior value,
    301        // unlike atomicAdd.  So we need to do a -1 on the result as well.
    302        if (isDecrement)
    303        {
    304            substituteCall = new TIntermBinary(EOpSub, substituteCall, CreateUIntNode(1));
    305        }
    306 
    307        queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
    308        return true;
    309    }
    310 
    311    const TVariable *mAtomicCounters;
    312    const TIntermTyped *mAcbBufferOffsets;
    313 };
    314 
    315 }  // anonymous namespace
    316 
    317 bool RewriteAtomicCounters(TCompiler *compiler,
    318                           TIntermBlock *root,
    319                           TSymbolTable *symbolTable,
    320                           const TIntermTyped *acbBufferOffsets)
    321 {
    322    const TVariable *atomicCounters = DeclareAtomicCountersBuffers(root, symbolTable);
    323 
    324    RewriteAtomicCountersTraverser traverser(symbolTable, atomicCounters, acbBufferOffsets);
    325    root->traverse(&traverser);
    326    return traverser.updateTree(compiler, root);
    327 }
    328 }  // namespace sh