tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

ASTMetadataHLSL.cpp (14788B)


      1 //
      2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
      3 // Use of this source code is governed by a BSD-style license that can be
      4 // found in the LICENSE file.
      5 //
      6 
      7 // Analysis of the AST needed for HLSL generation
      8 
      9 #include "compiler/translator/ASTMetadataHLSL.h"
     10 
     11 #include "compiler/translator/CallDAG.h"
     12 #include "compiler/translator/SymbolTable.h"
     13 #include "compiler/translator/tree_util/IntermTraverse.h"
     14 
     15 namespace sh
     16 {
     17 
     18 namespace
     19 {
     20 
     21 // Class used to traverse the AST of a function definition, checking if the
     22 // function uses a gradient, and writing the set of control flow using gradients.
     23 // It assumes that the analysis has already been made for the function's
     24 // callees.
     25 class PullGradient : public TIntermTraverser
     26 {
     27  public:
     28    PullGradient(MetadataList *metadataList, size_t index, const CallDAG &dag)
     29        : TIntermTraverser(true, false, true),
     30          mMetadataList(metadataList),
     31          mMetadata(&(*metadataList)[index]),
     32          mIndex(index),
     33          mDag(dag)
     34    {
     35        ASSERT(index < metadataList->size());
     36 
     37        // ESSL 100 builtin gradient functions
     38        mGradientBuiltinFunctions.insert(ImmutableString("texture2D"));
     39        mGradientBuiltinFunctions.insert(ImmutableString("texture2DProj"));
     40        mGradientBuiltinFunctions.insert(ImmutableString("textureCube"));
     41 
     42        // ESSL 300 builtin gradient functions
     43        mGradientBuiltinFunctions.insert(ImmutableString("dFdx"));
     44        mGradientBuiltinFunctions.insert(ImmutableString("dFdy"));
     45        mGradientBuiltinFunctions.insert(ImmutableString("fwidth"));
     46        mGradientBuiltinFunctions.insert(ImmutableString("texture"));
     47        mGradientBuiltinFunctions.insert(ImmutableString("textureProj"));
     48        mGradientBuiltinFunctions.insert(ImmutableString("textureOffset"));
     49        mGradientBuiltinFunctions.insert(ImmutableString("textureProjOffset"));
     50 
     51        // ESSL 310 doesn't add builtin gradient functions
     52    }
     53 
     54    void traverse(TIntermFunctionDefinition *node)
     55    {
     56        node->traverse(this);
     57        ASSERT(mParents.empty());
     58    }
     59 
     60    // Called when a gradient operation or a call to a function using a gradient is found.
     61    void onGradient()
     62    {
     63        mMetadata->mUsesGradient = true;
     64        // Mark the latest control flow as using a gradient.
     65        if (!mParents.empty())
     66        {
     67            mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
     68        }
     69    }
     70 
     71    void visitControlFlow(Visit visit, TIntermNode *node)
     72    {
     73        if (visit == PreVisit)
     74        {
     75            mParents.push_back(node);
     76        }
     77        else if (visit == PostVisit)
     78        {
     79            ASSERT(mParents.back() == node);
     80            mParents.pop_back();
     81            // A control flow's using a gradient means its parents are too.
     82            if (mMetadata->mControlFlowsContainingGradient.count(node) > 0 && !mParents.empty())
     83            {
     84                mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
     85            }
     86        }
     87    }
     88 
     89    bool visitLoop(Visit visit, TIntermLoop *loop) override
     90    {
     91        visitControlFlow(visit, loop);
     92        return true;
     93    }
     94 
     95    bool visitIfElse(Visit visit, TIntermIfElse *ifElse) override
     96    {
     97        visitControlFlow(visit, ifElse);
     98        return true;
     99    }
    100 
    101    bool visitAggregate(Visit visit, TIntermAggregate *node) override
    102    {
    103        if (visit == PreVisit)
    104        {
    105            if (node->getOp() == EOpCallFunctionInAST)
    106            {
    107                size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
    108                ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
    109 
    110                if ((*mMetadataList)[calleeIndex].mUsesGradient)
    111                {
    112                    onGradient();
    113                }
    114            }
    115            else if (BuiltInGroup::IsBuiltIn(node->getOp()) && !BuiltInGroup::IsMath(node->getOp()))
    116            {
    117                if (mGradientBuiltinFunctions.find(node->getFunction()->name()) !=
    118                    mGradientBuiltinFunctions.end())
    119                {
    120                    onGradient();
    121                }
    122            }
    123        }
    124 
    125        return true;
    126    }
    127 
    128  private:
    129    MetadataList *mMetadataList;
    130    ASTMetadataHLSL *mMetadata;
    131    size_t mIndex;
    132    const CallDAG &mDag;
    133 
    134    // Contains a stack of the control flow nodes that are parents of the node being
    135    // currently visited. It is used to mark control flows using a gradient.
    136    std::vector<TIntermNode *> mParents;
    137 
    138    // A list of builtin functions that use gradients
    139    std::set<ImmutableString> mGradientBuiltinFunctions;
    140 };
    141 
    142 // Traverses the AST of a function definition to compute the the discontinuous loops
    143 // and the if statements containing gradient loops. It assumes that the gradient loops
    144 // (loops that contain a gradient) have already been computed and that it has already
    145 // traversed the current function's callees.
    146 class PullComputeDiscontinuousAndGradientLoops : public TIntermTraverser
    147 {
    148  public:
    149    PullComputeDiscontinuousAndGradientLoops(MetadataList *metadataList,
    150                                             size_t index,
    151                                             const CallDAG &dag)
    152        : TIntermTraverser(true, false, true),
    153          mMetadataList(metadataList),
    154          mMetadata(&(*metadataList)[index]),
    155          mIndex(index),
    156          mDag(dag)
    157    {}
    158 
    159    void traverse(TIntermFunctionDefinition *node)
    160    {
    161        node->traverse(this);
    162        ASSERT(mLoopsAndSwitches.empty());
    163        ASSERT(mIfs.empty());
    164    }
    165 
    166    // Called when traversing a gradient loop or a call to a function with a
    167    // gradient loop in its call graph.
    168    void onGradientLoop()
    169    {
    170        mMetadata->mHasGradientLoopInCallGraph = true;
    171        // Mark the latest if as using a discontinuous loop.
    172        if (!mIfs.empty())
    173        {
    174            mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
    175        }
    176    }
    177 
    178    bool visitLoop(Visit visit, TIntermLoop *loop) override
    179    {
    180        if (visit == PreVisit)
    181        {
    182            mLoopsAndSwitches.push_back(loop);
    183 
    184            if (mMetadata->hasGradientInCallGraph(loop))
    185            {
    186                onGradientLoop();
    187            }
    188        }
    189        else if (visit == PostVisit)
    190        {
    191            ASSERT(mLoopsAndSwitches.back() == loop);
    192            mLoopsAndSwitches.pop_back();
    193        }
    194 
    195        return true;
    196    }
    197 
    198    bool visitIfElse(Visit visit, TIntermIfElse *node) override
    199    {
    200        if (visit == PreVisit)
    201        {
    202            mIfs.push_back(node);
    203        }
    204        else if (visit == PostVisit)
    205        {
    206            ASSERT(mIfs.back() == node);
    207            mIfs.pop_back();
    208            // An if using a discontinuous loop means its parents ifs are also discontinuous.
    209            if (mMetadata->mIfsContainingGradientLoop.count(node) > 0 && !mIfs.empty())
    210            {
    211                mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
    212            }
    213        }
    214 
    215        return true;
    216    }
    217 
    218    bool visitBranch(Visit visit, TIntermBranch *node) override
    219    {
    220        if (visit == PreVisit)
    221        {
    222            switch (node->getFlowOp())
    223            {
    224                case EOpBreak:
    225                {
    226                    ASSERT(!mLoopsAndSwitches.empty());
    227                    TIntermLoop *loop = mLoopsAndSwitches.back()->getAsLoopNode();
    228                    if (loop != nullptr)
    229                    {
    230                        mMetadata->mDiscontinuousLoops.insert(loop);
    231                    }
    232                }
    233                break;
    234                case EOpContinue:
    235                {
    236                    ASSERT(!mLoopsAndSwitches.empty());
    237                    TIntermLoop *loop = nullptr;
    238                    size_t i          = mLoopsAndSwitches.size();
    239                    while (loop == nullptr && i > 0)
    240                    {
    241                        --i;
    242                        loop = mLoopsAndSwitches.at(i)->getAsLoopNode();
    243                    }
    244                    ASSERT(loop != nullptr);
    245                    mMetadata->mDiscontinuousLoops.insert(loop);
    246                }
    247                break;
    248                case EOpKill:
    249                case EOpReturn:
    250                    // A return or discard jumps out of all the enclosing loops
    251                    if (!mLoopsAndSwitches.empty())
    252                    {
    253                        for (TIntermNode *intermNode : mLoopsAndSwitches)
    254                        {
    255                            TIntermLoop *loop = intermNode->getAsLoopNode();
    256                            if (loop)
    257                            {
    258                                mMetadata->mDiscontinuousLoops.insert(loop);
    259                            }
    260                        }
    261                    }
    262                    break;
    263                default:
    264                    UNREACHABLE();
    265            }
    266        }
    267 
    268        return true;
    269    }
    270 
    271    bool visitAggregate(Visit visit, TIntermAggregate *node) override
    272    {
    273        if (visit == PreVisit && node->getOp() == EOpCallFunctionInAST)
    274        {
    275            size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
    276            ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
    277 
    278            if ((*mMetadataList)[calleeIndex].mHasGradientLoopInCallGraph)
    279            {
    280                onGradientLoop();
    281            }
    282        }
    283 
    284        return true;
    285    }
    286 
    287    bool visitSwitch(Visit visit, TIntermSwitch *node) override
    288    {
    289        if (visit == PreVisit)
    290        {
    291            mLoopsAndSwitches.push_back(node);
    292        }
    293        else if (visit == PostVisit)
    294        {
    295            ASSERT(mLoopsAndSwitches.back() == node);
    296            mLoopsAndSwitches.pop_back();
    297        }
    298        return true;
    299    }
    300 
    301  private:
    302    MetadataList *mMetadataList;
    303    ASTMetadataHLSL *mMetadata;
    304    size_t mIndex;
    305    const CallDAG &mDag;
    306 
    307    std::vector<TIntermNode *> mLoopsAndSwitches;
    308    std::vector<TIntermIfElse *> mIfs;
    309 };
    310 
    311 // Tags all the functions called in a discontinuous loop
    312 class PushDiscontinuousLoops : public TIntermTraverser
    313 {
    314  public:
    315    PushDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag)
    316        : TIntermTraverser(true, true, true),
    317          mMetadataList(metadataList),
    318          mMetadata(&(*metadataList)[index]),
    319          mIndex(index),
    320          mDag(dag),
    321          mNestedDiscont(mMetadata->mCalledInDiscontinuousLoop ? 1 : 0)
    322    {}
    323 
    324    void traverse(TIntermFunctionDefinition *node)
    325    {
    326        node->traverse(this);
    327        ASSERT(mNestedDiscont == (mMetadata->mCalledInDiscontinuousLoop ? 1 : 0));
    328    }
    329 
    330    bool visitLoop(Visit visit, TIntermLoop *loop) override
    331    {
    332        bool isDiscontinuous = mMetadata->mDiscontinuousLoops.count(loop) > 0;
    333 
    334        if (visit == PreVisit && isDiscontinuous)
    335        {
    336            mNestedDiscont++;
    337        }
    338        else if (visit == PostVisit && isDiscontinuous)
    339        {
    340            mNestedDiscont--;
    341        }
    342 
    343        return true;
    344    }
    345 
    346    bool visitAggregate(Visit visit, TIntermAggregate *node) override
    347    {
    348        switch (node->getOp())
    349        {
    350            case EOpCallFunctionInAST:
    351                if (visit == PreVisit && mNestedDiscont > 0)
    352                {
    353                    size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
    354                    ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
    355 
    356                    (*mMetadataList)[calleeIndex].mCalledInDiscontinuousLoop = true;
    357                }
    358                break;
    359            default:
    360                break;
    361        }
    362        return true;
    363    }
    364 
    365  private:
    366    MetadataList *mMetadataList;
    367    ASTMetadataHLSL *mMetadata;
    368    size_t mIndex;
    369    const CallDAG &mDag;
    370 
    371    int mNestedDiscont;
    372 };
    373 }  // namespace
    374 
    375 bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermLoop *node)
    376 {
    377    return mControlFlowsContainingGradient.count(node) > 0;
    378 }
    379 
    380 bool ASTMetadataHLSL::hasGradientLoop(TIntermIfElse *node)
    381 {
    382    return mIfsContainingGradientLoop.count(node) > 0;
    383 }
    384 
    385 MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag)
    386 {
    387    MetadataList metadataList(callDag.size());
    388 
    389    // Compute all the information related to when gradient operations are used.
    390    // We want to know for each function and control flow operation if they have
    391    // a gradient operation in their call graph (shortened to "using a gradient"
    392    // in the rest of the file).
    393    //
    394    // This computation is logically split in three steps:
    395    //  1 - For each function compute if it uses a gradient in its body, ignoring
    396    // calls to other user-defined functions.
    397    //  2 - For each function determine if it uses a gradient in its call graph,
    398    // using the result of step 1 and the CallDAG to know its callees.
    399    //  3 - For each control flow statement of each function, check if it uses a
    400    // gradient in the function's body, or if it calls a user-defined function that
    401    // uses a gradient.
    402    //
    403    // We take advantage of the call graph being a DAG and instead compute 1, 2 and 3
    404    // for leaves first, then going down the tree. This is correct because 1 doesn't
    405    // depend on other functions, and 2 and 3 depend only on callees.
    406    for (size_t i = 0; i < callDag.size(); i++)
    407    {
    408        PullGradient pull(&metadataList, i, callDag);
    409        pull.traverse(callDag.getRecordFromIndex(i).node);
    410    }
    411 
    412    // Compute which loops are discontinuous and which function are called in
    413    // these loops. The same way computing gradient usage is a "pull" process,
    414    // computing "bing used in a discont. loop" is a push process. However we also
    415    // need to know what ifs have a discontinuous loop inside so we do the same type
    416    // of callgraph analysis as for the gradient.
    417 
    418    // First compute which loops are discontinuous (no specific order) and pull
    419    // the ifs and functions using a gradient loop.
    420    for (size_t i = 0; i < callDag.size(); i++)
    421    {
    422        PullComputeDiscontinuousAndGradientLoops pull(&metadataList, i, callDag);
    423        pull.traverse(callDag.getRecordFromIndex(i).node);
    424    }
    425 
    426    // Then push the information to callees, either from the a local discontinuous
    427    // loop or from the caller being called in a discontinuous loop already
    428    for (size_t i = callDag.size(); i-- > 0;)
    429    {
    430        PushDiscontinuousLoops push(&metadataList, i, callDag);
    431        push.traverse(callDag.getRecordFromIndex(i).node);
    432    }
    433 
    434    // We create "Lod0" version of functions with the gradient operations replaced
    435    // by non-gradient operations so that the D3D compiler is happier with discont
    436    // loops.
    437    for (auto &metadata : metadataList)
    438    {
    439        metadata.mNeedsLod0 = metadata.mCalledInDiscontinuousLoop && metadata.mUsesGradient;
    440    }
    441 
    442    return metadataList;
    443 }
    444 
    445 }  // namespace sh