tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

ValidateLimitations.cpp (14122B)


      1 //
      2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
      3 // Use of this source code is governed by a BSD-style license that can be
      4 // found in the LICENSE file.
      5 //
      6 
      7 #include "compiler/translator/ValidateLimitations.h"
      8 
      9 #include "angle_gl.h"
     10 #include "compiler/translator/Diagnostics.h"
     11 #include "compiler/translator/ParseContext.h"
     12 #include "compiler/translator/tree_util/IntermTraverse.h"
     13 
     14 namespace sh
     15 {
     16 
     17 namespace
     18 {
     19 
     20 int GetLoopSymbolId(TIntermLoop *loop)
     21 {
     22    // Here we assume all the operations are valid, because the loop node is
     23    // already validated before this call.
     24    TIntermSequence *declSeq = loop->getInit()->getAsDeclarationNode()->getSequence();
     25    TIntermBinary *declInit  = (*declSeq)[0]->getAsBinaryNode();
     26    TIntermSymbol *symbol    = declInit->getLeft()->getAsSymbolNode();
     27 
     28    return symbol->uniqueId().get();
     29 }
     30 
     31 // Traverses a node to check if it represents a constant index expression.
     32 // Definition:
     33 // constant-index-expressions are a superset of constant-expressions.
     34 // Constant-index-expressions can include loop indices as defined in
     35 // GLSL ES 1.0 spec, Appendix A, section 4.
     36 // The following are constant-index-expressions:
     37 // - Constant expressions
     38 // - Loop indices as defined in section 4
     39 // - Expressions composed of both of the above
     40 class ValidateConstIndexExpr : public TIntermTraverser
     41 {
     42  public:
     43    ValidateConstIndexExpr(const std::vector<int> &loopSymbols)
     44        : TIntermTraverser(true, false, false), mValid(true), mLoopSymbolIds(loopSymbols)
     45    {}
     46 
     47    // Returns true if the parsed node represents a constant index expression.
     48    bool isValid() const { return mValid; }
     49 
     50    void visitSymbol(TIntermSymbol *symbol) override
     51    {
     52        // Only constants and loop indices are allowed in a
     53        // constant index expression.
     54        if (mValid)
     55        {
     56            bool isLoopSymbol = std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(),
     57                                          symbol->uniqueId().get()) != mLoopSymbolIds.end();
     58            mValid            = (symbol->getQualifier() == EvqConst) || isLoopSymbol;
     59        }
     60    }
     61 
     62  private:
     63    bool mValid;
     64    const std::vector<int> mLoopSymbolIds;
     65 };
     66 
     67 // Traverses intermediate tree to ensure that the shader does not exceed the
     68 // minimum functionality mandated in GLSL 1.0 spec, Appendix A.
     69 class ValidateLimitationsTraverser : public TLValueTrackingTraverser
     70 {
     71  public:
     72    ValidateLimitationsTraverser(sh::GLenum shaderType,
     73                                 TSymbolTable *symbolTable,
     74                                 TDiagnostics *diagnostics);
     75 
     76    void visitSymbol(TIntermSymbol *node) override;
     77    bool visitBinary(Visit, TIntermBinary *) override;
     78    bool visitLoop(Visit, TIntermLoop *) override;
     79 
     80  private:
     81    void error(TSourceLoc loc, const char *reason, const char *token);
     82    void error(TSourceLoc loc, const char *reason, const ImmutableString &token);
     83 
     84    bool isLoopIndex(TIntermSymbol *symbol);
     85    bool validateLoopType(TIntermLoop *node);
     86 
     87    bool validateForLoopHeader(TIntermLoop *node);
     88    // If valid, return the index symbol id; Otherwise, return -1.
     89    int validateForLoopInit(TIntermLoop *node);
     90    bool validateForLoopCond(TIntermLoop *node, int indexSymbolId);
     91    bool validateForLoopExpr(TIntermLoop *node, int indexSymbolId);
     92 
     93    // Returns true if indexing does not exceed the minimum functionality
     94    // mandated in GLSL 1.0 spec, Appendix A, Section 5.
     95    bool isConstExpr(TIntermNode *node);
     96    bool isConstIndexExpr(TIntermNode *node);
     97    bool validateIndexing(TIntermBinary *node);
     98 
     99    sh::GLenum mShaderType;
    100    TDiagnostics *mDiagnostics;
    101    std::vector<int> mLoopSymbolIds;
    102 };
    103 
    104 ValidateLimitationsTraverser::ValidateLimitationsTraverser(sh::GLenum shaderType,
    105                                                           TSymbolTable *symbolTable,
    106                                                           TDiagnostics *diagnostics)
    107    : TLValueTrackingTraverser(true, false, false, symbolTable),
    108      mShaderType(shaderType),
    109      mDiagnostics(diagnostics)
    110 {
    111    ASSERT(diagnostics);
    112 }
    113 
    114 void ValidateLimitationsTraverser::visitSymbol(TIntermSymbol *node)
    115 {
    116    if (isLoopIndex(node) && isLValueRequiredHere())
    117    {
    118        error(node->getLine(),
    119              "Loop index cannot be statically assigned to within the body of the loop",
    120              node->getName());
    121    }
    122 }
    123 
    124 bool ValidateLimitationsTraverser::visitBinary(Visit, TIntermBinary *node)
    125 {
    126    // Check indexing.
    127    switch (node->getOp())
    128    {
    129        case EOpIndexDirect:
    130        case EOpIndexIndirect:
    131            validateIndexing(node);
    132            break;
    133        default:
    134            break;
    135    }
    136    return true;
    137 }
    138 
    139 bool ValidateLimitationsTraverser::visitLoop(Visit, TIntermLoop *node)
    140 {
    141    if (!validateLoopType(node))
    142        return false;
    143 
    144    if (!validateForLoopHeader(node))
    145        return false;
    146 
    147    TIntermNode *body = node->getBody();
    148    if (body != nullptr)
    149    {
    150        mLoopSymbolIds.push_back(GetLoopSymbolId(node));
    151        body->traverse(this);
    152        mLoopSymbolIds.pop_back();
    153    }
    154 
    155    // The loop is fully processed - no need to visit children.
    156    return false;
    157 }
    158 
    159 void ValidateLimitationsTraverser::error(TSourceLoc loc, const char *reason, const char *token)
    160 {
    161    mDiagnostics->error(loc, reason, token);
    162 }
    163 
    164 void ValidateLimitationsTraverser::error(TSourceLoc loc,
    165                                         const char *reason,
    166                                         const ImmutableString &token)
    167 {
    168    error(loc, reason, token.data());
    169 }
    170 
    171 bool ValidateLimitationsTraverser::isLoopIndex(TIntermSymbol *symbol)
    172 {
    173    return std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(), symbol->uniqueId().get()) !=
    174           mLoopSymbolIds.end();
    175 }
    176 
    177 bool ValidateLimitationsTraverser::validateLoopType(TIntermLoop *node)
    178 {
    179    TLoopType type = node->getType();
    180    if (type == ELoopFor)
    181        return true;
    182 
    183    // Reject while and do-while loops.
    184    error(node->getLine(), "This type of loop is not allowed", type == ELoopWhile ? "while" : "do");
    185    return false;
    186 }
    187 
    188 bool ValidateLimitationsTraverser::validateForLoopHeader(TIntermLoop *node)
    189 {
    190    ASSERT(node->getType() == ELoopFor);
    191 
    192    //
    193    // The for statement has the form:
    194    //    for ( init-declaration ; condition ; expression ) statement
    195    //
    196    int indexSymbolId = validateForLoopInit(node);
    197    if (indexSymbolId < 0)
    198        return false;
    199    if (!validateForLoopCond(node, indexSymbolId))
    200        return false;
    201    if (!validateForLoopExpr(node, indexSymbolId))
    202        return false;
    203 
    204    return true;
    205 }
    206 
    207 int ValidateLimitationsTraverser::validateForLoopInit(TIntermLoop *node)
    208 {
    209    TIntermNode *init = node->getInit();
    210    if (init == nullptr)
    211    {
    212        error(node->getLine(), "Missing init declaration", "for");
    213        return -1;
    214    }
    215 
    216    //
    217    // init-declaration has the form:
    218    //     type-specifier identifier = constant-expression
    219    //
    220    TIntermDeclaration *decl = init->getAsDeclarationNode();
    221    if (decl == nullptr)
    222    {
    223        error(init->getLine(), "Invalid init declaration", "for");
    224        return -1;
    225    }
    226    // To keep things simple do not allow declaration list.
    227    TIntermSequence *declSeq = decl->getSequence();
    228    if (declSeq->size() != 1)
    229    {
    230        error(decl->getLine(), "Invalid init declaration", "for");
    231        return -1;
    232    }
    233    TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode();
    234    if ((declInit == nullptr) || (declInit->getOp() != EOpInitialize))
    235    {
    236        error(decl->getLine(), "Invalid init declaration", "for");
    237        return -1;
    238    }
    239    TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
    240    if (symbol == nullptr)
    241    {
    242        error(declInit->getLine(), "Invalid init declaration", "for");
    243        return -1;
    244    }
    245    // The loop index has type int or float.
    246    TBasicType type = symbol->getBasicType();
    247    if ((type != EbtInt) && (type != EbtUInt) && (type != EbtFloat))
    248    {
    249        error(symbol->getLine(), "Invalid type for loop index", getBasicString(type));
    250        return -1;
    251    }
    252    // The loop index is initialized with constant expression.
    253    if (!isConstExpr(declInit->getRight()))
    254    {
    255        error(declInit->getLine(), "Loop index cannot be initialized with non-constant expression",
    256              symbol->getName());
    257        return -1;
    258    }
    259 
    260    return symbol->uniqueId().get();
    261 }
    262 
    263 bool ValidateLimitationsTraverser::validateForLoopCond(TIntermLoop *node, int indexSymbolId)
    264 {
    265    TIntermNode *cond = node->getCondition();
    266    if (cond == nullptr)
    267    {
    268        error(node->getLine(), "Missing condition", "for");
    269        return false;
    270    }
    271    //
    272    // condition has the form:
    273    //     loop_index relational_operator constant_expression
    274    //
    275    TIntermBinary *binOp = cond->getAsBinaryNode();
    276    if (binOp == nullptr)
    277    {
    278        error(node->getLine(), "Invalid condition", "for");
    279        return false;
    280    }
    281    // Loop index should be to the left of relational operator.
    282    TIntermSymbol *symbol = binOp->getLeft()->getAsSymbolNode();
    283    if (symbol == nullptr)
    284    {
    285        error(binOp->getLine(), "Invalid condition", "for");
    286        return false;
    287    }
    288    if (symbol->uniqueId().get() != indexSymbolId)
    289    {
    290        error(symbol->getLine(), "Expected loop index", symbol->getName());
    291        return false;
    292    }
    293    // Relational operator is one of: > >= < <= == or !=.
    294    switch (binOp->getOp())
    295    {
    296        case EOpEqual:
    297        case EOpNotEqual:
    298        case EOpLessThan:
    299        case EOpGreaterThan:
    300        case EOpLessThanEqual:
    301        case EOpGreaterThanEqual:
    302            break;
    303        default:
    304            error(binOp->getLine(), "Invalid relational operator",
    305                  GetOperatorString(binOp->getOp()));
    306            break;
    307    }
    308    // Loop index must be compared with a constant.
    309    if (!isConstExpr(binOp->getRight()))
    310    {
    311        error(binOp->getLine(), "Loop index cannot be compared with non-constant expression",
    312              symbol->getName());
    313        return false;
    314    }
    315 
    316    return true;
    317 }
    318 
    319 bool ValidateLimitationsTraverser::validateForLoopExpr(TIntermLoop *node, int indexSymbolId)
    320 {
    321    TIntermNode *expr = node->getExpression();
    322    if (expr == nullptr)
    323    {
    324        error(node->getLine(), "Missing expression", "for");
    325        return false;
    326    }
    327 
    328    // for expression has one of the following forms:
    329    //     loop_index++
    330    //     loop_index--
    331    //     loop_index += constant_expression
    332    //     loop_index -= constant_expression
    333    //     ++loop_index
    334    //     --loop_index
    335    // The last two forms are not specified in the spec, but I am assuming
    336    // its an oversight.
    337    TIntermUnary *unOp   = expr->getAsUnaryNode();
    338    TIntermBinary *binOp = unOp ? nullptr : expr->getAsBinaryNode();
    339 
    340    TOperator op            = EOpNull;
    341    const TFunction *opFunc = nullptr;
    342    TIntermSymbol *symbol   = nullptr;
    343    if (unOp != nullptr)
    344    {
    345        op     = unOp->getOp();
    346        opFunc = unOp->getFunction();
    347        symbol = unOp->getOperand()->getAsSymbolNode();
    348    }
    349    else if (binOp != nullptr)
    350    {
    351        op     = binOp->getOp();
    352        symbol = binOp->getLeft()->getAsSymbolNode();
    353    }
    354 
    355    // The operand must be loop index.
    356    if (symbol == nullptr)
    357    {
    358        error(expr->getLine(), "Invalid expression", "for");
    359        return false;
    360    }
    361    if (symbol->uniqueId().get() != indexSymbolId)
    362    {
    363        error(symbol->getLine(), "Expected loop index", symbol->getName());
    364        return false;
    365    }
    366 
    367    // The operator is one of: ++ -- += -=.
    368    switch (op)
    369    {
    370        case EOpPostIncrement:
    371        case EOpPostDecrement:
    372        case EOpPreIncrement:
    373        case EOpPreDecrement:
    374            ASSERT((unOp != nullptr) && (binOp == nullptr));
    375            break;
    376        case EOpAddAssign:
    377        case EOpSubAssign:
    378            ASSERT((unOp == nullptr) && (binOp != nullptr));
    379            break;
    380        default:
    381            if (BuiltInGroup::IsBuiltIn(op))
    382            {
    383                ASSERT(opFunc != nullptr);
    384                error(expr->getLine(), "Invalid built-in call", opFunc->name().data());
    385            }
    386            else
    387            {
    388                error(expr->getLine(), "Invalid operator", GetOperatorString(op));
    389            }
    390            return false;
    391    }
    392 
    393    // Loop index must be incremented/decremented with a constant.
    394    if (binOp != nullptr)
    395    {
    396        if (!isConstExpr(binOp->getRight()))
    397        {
    398            error(binOp->getLine(), "Loop index cannot be modified by non-constant expression",
    399                  symbol->getName());
    400            return false;
    401        }
    402    }
    403 
    404    return true;
    405 }
    406 
    407 bool ValidateLimitationsTraverser::isConstExpr(TIntermNode *node)
    408 {
    409    ASSERT(node != nullptr);
    410    return node->getAsConstantUnion() != nullptr && node->getAsTyped()->getQualifier() == EvqConst;
    411 }
    412 
    413 bool ValidateLimitationsTraverser::isConstIndexExpr(TIntermNode *node)
    414 {
    415    ASSERT(node != nullptr);
    416 
    417    ValidateConstIndexExpr validate(mLoopSymbolIds);
    418    node->traverse(&validate);
    419    return validate.isValid();
    420 }
    421 
    422 bool ValidateLimitationsTraverser::validateIndexing(TIntermBinary *node)
    423 {
    424    ASSERT((node->getOp() == EOpIndexDirect) || (node->getOp() == EOpIndexIndirect));
    425 
    426    bool valid          = true;
    427    TIntermTyped *index = node->getRight();
    428    // The index expession must be a constant-index-expression unless
    429    // the operand is a uniform in a vertex shader.
    430    TIntermTyped *operand = node->getLeft();
    431    bool skip = (mShaderType == GL_VERTEX_SHADER) && (operand->getQualifier() == EvqUniform);
    432    if (!skip && !isConstIndexExpr(index))
    433    {
    434        error(index->getLine(), "Index expression must be constant", "[]");
    435        valid = false;
    436    }
    437    return valid;
    438 }
    439 
    440 }  // namespace
    441 
    442 bool ValidateLimitations(TIntermNode *root,
    443                         GLenum shaderType,
    444                         TSymbolTable *symbolTable,
    445                         TDiagnostics *diagnostics)
    446 {
    447    ValidateLimitationsTraverser validate(shaderType, symbolTable, diagnostics);
    448    root->traverse(&validate);
    449    return diagnostics->numErrors() == 0;
    450 }
    451 
    452 }  // namespace sh