ValidateLimitations.cpp (14122B)
1 // 2 // Copyright 2002 The ANGLE Project Authors. All rights reserved. 3 // Use of this source code is governed by a BSD-style license that can be 4 // found in the LICENSE file. 5 // 6 7 #include "compiler/translator/ValidateLimitations.h" 8 9 #include "angle_gl.h" 10 #include "compiler/translator/Diagnostics.h" 11 #include "compiler/translator/ParseContext.h" 12 #include "compiler/translator/tree_util/IntermTraverse.h" 13 14 namespace sh 15 { 16 17 namespace 18 { 19 20 int GetLoopSymbolId(TIntermLoop *loop) 21 { 22 // Here we assume all the operations are valid, because the loop node is 23 // already validated before this call. 24 TIntermSequence *declSeq = loop->getInit()->getAsDeclarationNode()->getSequence(); 25 TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode(); 26 TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode(); 27 28 return symbol->uniqueId().get(); 29 } 30 31 // Traverses a node to check if it represents a constant index expression. 32 // Definition: 33 // constant-index-expressions are a superset of constant-expressions. 34 // Constant-index-expressions can include loop indices as defined in 35 // GLSL ES 1.0 spec, Appendix A, section 4. 36 // The following are constant-index-expressions: 37 // - Constant expressions 38 // - Loop indices as defined in section 4 39 // - Expressions composed of both of the above 40 class ValidateConstIndexExpr : public TIntermTraverser 41 { 42 public: 43 ValidateConstIndexExpr(const std::vector<int> &loopSymbols) 44 : TIntermTraverser(true, false, false), mValid(true), mLoopSymbolIds(loopSymbols) 45 {} 46 47 // Returns true if the parsed node represents a constant index expression. 48 bool isValid() const { return mValid; } 49 50 void visitSymbol(TIntermSymbol *symbol) override 51 { 52 // Only constants and loop indices are allowed in a 53 // constant index expression. 54 if (mValid) 55 { 56 bool isLoopSymbol = std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(), 57 symbol->uniqueId().get()) != mLoopSymbolIds.end(); 58 mValid = (symbol->getQualifier() == EvqConst) || isLoopSymbol; 59 } 60 } 61 62 private: 63 bool mValid; 64 const std::vector<int> mLoopSymbolIds; 65 }; 66 67 // Traverses intermediate tree to ensure that the shader does not exceed the 68 // minimum functionality mandated in GLSL 1.0 spec, Appendix A. 69 class ValidateLimitationsTraverser : public TLValueTrackingTraverser 70 { 71 public: 72 ValidateLimitationsTraverser(sh::GLenum shaderType, 73 TSymbolTable *symbolTable, 74 TDiagnostics *diagnostics); 75 76 void visitSymbol(TIntermSymbol *node) override; 77 bool visitBinary(Visit, TIntermBinary *) override; 78 bool visitLoop(Visit, TIntermLoop *) override; 79 80 private: 81 void error(TSourceLoc loc, const char *reason, const char *token); 82 void error(TSourceLoc loc, const char *reason, const ImmutableString &token); 83 84 bool isLoopIndex(TIntermSymbol *symbol); 85 bool validateLoopType(TIntermLoop *node); 86 87 bool validateForLoopHeader(TIntermLoop *node); 88 // If valid, return the index symbol id; Otherwise, return -1. 89 int validateForLoopInit(TIntermLoop *node); 90 bool validateForLoopCond(TIntermLoop *node, int indexSymbolId); 91 bool validateForLoopExpr(TIntermLoop *node, int indexSymbolId); 92 93 // Returns true if indexing does not exceed the minimum functionality 94 // mandated in GLSL 1.0 spec, Appendix A, Section 5. 95 bool isConstExpr(TIntermNode *node); 96 bool isConstIndexExpr(TIntermNode *node); 97 bool validateIndexing(TIntermBinary *node); 98 99 sh::GLenum mShaderType; 100 TDiagnostics *mDiagnostics; 101 std::vector<int> mLoopSymbolIds; 102 }; 103 104 ValidateLimitationsTraverser::ValidateLimitationsTraverser(sh::GLenum shaderType, 105 TSymbolTable *symbolTable, 106 TDiagnostics *diagnostics) 107 : TLValueTrackingTraverser(true, false, false, symbolTable), 108 mShaderType(shaderType), 109 mDiagnostics(diagnostics) 110 { 111 ASSERT(diagnostics); 112 } 113 114 void ValidateLimitationsTraverser::visitSymbol(TIntermSymbol *node) 115 { 116 if (isLoopIndex(node) && isLValueRequiredHere()) 117 { 118 error(node->getLine(), 119 "Loop index cannot be statically assigned to within the body of the loop", 120 node->getName()); 121 } 122 } 123 124 bool ValidateLimitationsTraverser::visitBinary(Visit, TIntermBinary *node) 125 { 126 // Check indexing. 127 switch (node->getOp()) 128 { 129 case EOpIndexDirect: 130 case EOpIndexIndirect: 131 validateIndexing(node); 132 break; 133 default: 134 break; 135 } 136 return true; 137 } 138 139 bool ValidateLimitationsTraverser::visitLoop(Visit, TIntermLoop *node) 140 { 141 if (!validateLoopType(node)) 142 return false; 143 144 if (!validateForLoopHeader(node)) 145 return false; 146 147 TIntermNode *body = node->getBody(); 148 if (body != nullptr) 149 { 150 mLoopSymbolIds.push_back(GetLoopSymbolId(node)); 151 body->traverse(this); 152 mLoopSymbolIds.pop_back(); 153 } 154 155 // The loop is fully processed - no need to visit children. 156 return false; 157 } 158 159 void ValidateLimitationsTraverser::error(TSourceLoc loc, const char *reason, const char *token) 160 { 161 mDiagnostics->error(loc, reason, token); 162 } 163 164 void ValidateLimitationsTraverser::error(TSourceLoc loc, 165 const char *reason, 166 const ImmutableString &token) 167 { 168 error(loc, reason, token.data()); 169 } 170 171 bool ValidateLimitationsTraverser::isLoopIndex(TIntermSymbol *symbol) 172 { 173 return std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(), symbol->uniqueId().get()) != 174 mLoopSymbolIds.end(); 175 } 176 177 bool ValidateLimitationsTraverser::validateLoopType(TIntermLoop *node) 178 { 179 TLoopType type = node->getType(); 180 if (type == ELoopFor) 181 return true; 182 183 // Reject while and do-while loops. 184 error(node->getLine(), "This type of loop is not allowed", type == ELoopWhile ? "while" : "do"); 185 return false; 186 } 187 188 bool ValidateLimitationsTraverser::validateForLoopHeader(TIntermLoop *node) 189 { 190 ASSERT(node->getType() == ELoopFor); 191 192 // 193 // The for statement has the form: 194 // for ( init-declaration ; condition ; expression ) statement 195 // 196 int indexSymbolId = validateForLoopInit(node); 197 if (indexSymbolId < 0) 198 return false; 199 if (!validateForLoopCond(node, indexSymbolId)) 200 return false; 201 if (!validateForLoopExpr(node, indexSymbolId)) 202 return false; 203 204 return true; 205 } 206 207 int ValidateLimitationsTraverser::validateForLoopInit(TIntermLoop *node) 208 { 209 TIntermNode *init = node->getInit(); 210 if (init == nullptr) 211 { 212 error(node->getLine(), "Missing init declaration", "for"); 213 return -1; 214 } 215 216 // 217 // init-declaration has the form: 218 // type-specifier identifier = constant-expression 219 // 220 TIntermDeclaration *decl = init->getAsDeclarationNode(); 221 if (decl == nullptr) 222 { 223 error(init->getLine(), "Invalid init declaration", "for"); 224 return -1; 225 } 226 // To keep things simple do not allow declaration list. 227 TIntermSequence *declSeq = decl->getSequence(); 228 if (declSeq->size() != 1) 229 { 230 error(decl->getLine(), "Invalid init declaration", "for"); 231 return -1; 232 } 233 TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode(); 234 if ((declInit == nullptr) || (declInit->getOp() != EOpInitialize)) 235 { 236 error(decl->getLine(), "Invalid init declaration", "for"); 237 return -1; 238 } 239 TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode(); 240 if (symbol == nullptr) 241 { 242 error(declInit->getLine(), "Invalid init declaration", "for"); 243 return -1; 244 } 245 // The loop index has type int or float. 246 TBasicType type = symbol->getBasicType(); 247 if ((type != EbtInt) && (type != EbtUInt) && (type != EbtFloat)) 248 { 249 error(symbol->getLine(), "Invalid type for loop index", getBasicString(type)); 250 return -1; 251 } 252 // The loop index is initialized with constant expression. 253 if (!isConstExpr(declInit->getRight())) 254 { 255 error(declInit->getLine(), "Loop index cannot be initialized with non-constant expression", 256 symbol->getName()); 257 return -1; 258 } 259 260 return symbol->uniqueId().get(); 261 } 262 263 bool ValidateLimitationsTraverser::validateForLoopCond(TIntermLoop *node, int indexSymbolId) 264 { 265 TIntermNode *cond = node->getCondition(); 266 if (cond == nullptr) 267 { 268 error(node->getLine(), "Missing condition", "for"); 269 return false; 270 } 271 // 272 // condition has the form: 273 // loop_index relational_operator constant_expression 274 // 275 TIntermBinary *binOp = cond->getAsBinaryNode(); 276 if (binOp == nullptr) 277 { 278 error(node->getLine(), "Invalid condition", "for"); 279 return false; 280 } 281 // Loop index should be to the left of relational operator. 282 TIntermSymbol *symbol = binOp->getLeft()->getAsSymbolNode(); 283 if (symbol == nullptr) 284 { 285 error(binOp->getLine(), "Invalid condition", "for"); 286 return false; 287 } 288 if (symbol->uniqueId().get() != indexSymbolId) 289 { 290 error(symbol->getLine(), "Expected loop index", symbol->getName()); 291 return false; 292 } 293 // Relational operator is one of: > >= < <= == or !=. 294 switch (binOp->getOp()) 295 { 296 case EOpEqual: 297 case EOpNotEqual: 298 case EOpLessThan: 299 case EOpGreaterThan: 300 case EOpLessThanEqual: 301 case EOpGreaterThanEqual: 302 break; 303 default: 304 error(binOp->getLine(), "Invalid relational operator", 305 GetOperatorString(binOp->getOp())); 306 break; 307 } 308 // Loop index must be compared with a constant. 309 if (!isConstExpr(binOp->getRight())) 310 { 311 error(binOp->getLine(), "Loop index cannot be compared with non-constant expression", 312 symbol->getName()); 313 return false; 314 } 315 316 return true; 317 } 318 319 bool ValidateLimitationsTraverser::validateForLoopExpr(TIntermLoop *node, int indexSymbolId) 320 { 321 TIntermNode *expr = node->getExpression(); 322 if (expr == nullptr) 323 { 324 error(node->getLine(), "Missing expression", "for"); 325 return false; 326 } 327 328 // for expression has one of the following forms: 329 // loop_index++ 330 // loop_index-- 331 // loop_index += constant_expression 332 // loop_index -= constant_expression 333 // ++loop_index 334 // --loop_index 335 // The last two forms are not specified in the spec, but I am assuming 336 // its an oversight. 337 TIntermUnary *unOp = expr->getAsUnaryNode(); 338 TIntermBinary *binOp = unOp ? nullptr : expr->getAsBinaryNode(); 339 340 TOperator op = EOpNull; 341 const TFunction *opFunc = nullptr; 342 TIntermSymbol *symbol = nullptr; 343 if (unOp != nullptr) 344 { 345 op = unOp->getOp(); 346 opFunc = unOp->getFunction(); 347 symbol = unOp->getOperand()->getAsSymbolNode(); 348 } 349 else if (binOp != nullptr) 350 { 351 op = binOp->getOp(); 352 symbol = binOp->getLeft()->getAsSymbolNode(); 353 } 354 355 // The operand must be loop index. 356 if (symbol == nullptr) 357 { 358 error(expr->getLine(), "Invalid expression", "for"); 359 return false; 360 } 361 if (symbol->uniqueId().get() != indexSymbolId) 362 { 363 error(symbol->getLine(), "Expected loop index", symbol->getName()); 364 return false; 365 } 366 367 // The operator is one of: ++ -- += -=. 368 switch (op) 369 { 370 case EOpPostIncrement: 371 case EOpPostDecrement: 372 case EOpPreIncrement: 373 case EOpPreDecrement: 374 ASSERT((unOp != nullptr) && (binOp == nullptr)); 375 break; 376 case EOpAddAssign: 377 case EOpSubAssign: 378 ASSERT((unOp == nullptr) && (binOp != nullptr)); 379 break; 380 default: 381 if (BuiltInGroup::IsBuiltIn(op)) 382 { 383 ASSERT(opFunc != nullptr); 384 error(expr->getLine(), "Invalid built-in call", opFunc->name().data()); 385 } 386 else 387 { 388 error(expr->getLine(), "Invalid operator", GetOperatorString(op)); 389 } 390 return false; 391 } 392 393 // Loop index must be incremented/decremented with a constant. 394 if (binOp != nullptr) 395 { 396 if (!isConstExpr(binOp->getRight())) 397 { 398 error(binOp->getLine(), "Loop index cannot be modified by non-constant expression", 399 symbol->getName()); 400 return false; 401 } 402 } 403 404 return true; 405 } 406 407 bool ValidateLimitationsTraverser::isConstExpr(TIntermNode *node) 408 { 409 ASSERT(node != nullptr); 410 return node->getAsConstantUnion() != nullptr && node->getAsTyped()->getQualifier() == EvqConst; 411 } 412 413 bool ValidateLimitationsTraverser::isConstIndexExpr(TIntermNode *node) 414 { 415 ASSERT(node != nullptr); 416 417 ValidateConstIndexExpr validate(mLoopSymbolIds); 418 node->traverse(&validate); 419 return validate.isValid(); 420 } 421 422 bool ValidateLimitationsTraverser::validateIndexing(TIntermBinary *node) 423 { 424 ASSERT((node->getOp() == EOpIndexDirect) || (node->getOp() == EOpIndexIndirect)); 425 426 bool valid = true; 427 TIntermTyped *index = node->getRight(); 428 // The index expession must be a constant-index-expression unless 429 // the operand is a uniform in a vertex shader. 430 TIntermTyped *operand = node->getLeft(); 431 bool skip = (mShaderType == GL_VERTEX_SHADER) && (operand->getQualifier() == EvqUniform); 432 if (!skip && !isConstIndexExpr(index)) 433 { 434 error(index->getLine(), "Index expression must be constant", "[]"); 435 valid = false; 436 } 437 return valid; 438 } 439 440 } // namespace 441 442 bool ValidateLimitations(TIntermNode *root, 443 GLenum shaderType, 444 TSymbolTable *symbolTable, 445 TDiagnostics *diagnostics) 446 { 447 ValidateLimitationsTraverser validate(shaderType, symbolTable, diagnostics); 448 root->traverse(&validate); 449 return diagnostics->numErrors() == 0; 450 } 451 452 } // namespace sh