RewriteAtomicFunctionExpressions.cpp (6510B)
1 // 2 // Copyright 2018 The ANGLE Project Authors. All rights reserved. 3 // Use of this source code is governed by a BSD-style license that can be 4 // found in the LICENSE file. 5 // 6 // Implementation of the function RewriteAtomicFunctionExpressions. 7 // See the header for more details. 8 9 #include "compiler/translator/tree_ops/d3d/RewriteAtomicFunctionExpressions.h" 10 11 #include "compiler/translator/tree_util/IntermNodePatternMatcher.h" 12 #include "compiler/translator/tree_util/IntermNode_util.h" 13 #include "compiler/translator/tree_util/IntermTraverse.h" 14 #include "compiler/translator/util.h" 15 16 namespace sh 17 { 18 namespace 19 { 20 // Traverser that simplifies all the atomic function expressions into the ones that can be directly 21 // translated into HLSL. 22 // 23 // case 1 (only for atomicExchange and atomicCompSwap): 24 // original: 25 // atomicExchange(counter, newValue); 26 // new: 27 // tempValue = atomicExchange(counter, newValue); 28 // 29 // case 2 (atomic function, temporary variable required): 30 // original: 31 // value = atomicAdd(counter, 1) * otherValue; 32 // someArray[atomicAdd(counter, 1)] = someOtherValue; 33 // new: 34 // value = ((tempValue = atomicAdd(counter, 1)), tempValue) * otherValue; 35 // someArray[((tempValue = atomicAdd(counter, 1)), tempValue)] = someOtherValue; 36 // 37 // case 3 (atomic function used directly initialize a variable): 38 // original: 39 // int value = atomicAdd(counter, 1); 40 // new: 41 // tempValue = atomicAdd(counter, 1); 42 // int value = tempValue; 43 // 44 class RewriteAtomicFunctionExpressionsTraverser : public TIntermTraverser 45 { 46 public: 47 RewriteAtomicFunctionExpressionsTraverser(TSymbolTable *symbolTable, int shaderVersion); 48 49 bool visitAggregate(Visit visit, TIntermAggregate *node) override; 50 bool visitBlock(Visit visit, TIntermBlock *node) override; 51 52 private: 53 static bool IsAtomicExchangeOrCompSwapNoReturnValue(TIntermAggregate *node, 54 TIntermNode *parentNode); 55 static bool IsAtomicFunctionInsideExpression(TIntermAggregate *node, TIntermNode *parentNode); 56 57 void rewriteAtomicFunctionCallNode(TIntermAggregate *oldAtomicFunctionNode); 58 59 const TVariable *getTempVariable(const TType *type); 60 61 int mShaderVersion; 62 TIntermSequence mTempVariables; 63 }; 64 65 RewriteAtomicFunctionExpressionsTraverser::RewriteAtomicFunctionExpressionsTraverser( 66 TSymbolTable *symbolTable, 67 int shaderVersion) 68 : TIntermTraverser(false, false, true, symbolTable), mShaderVersion(shaderVersion) 69 {} 70 71 void RewriteAtomicFunctionExpressionsTraverser::rewriteAtomicFunctionCallNode( 72 TIntermAggregate *oldAtomicFunctionNode) 73 { 74 ASSERT(oldAtomicFunctionNode); 75 76 const TVariable *returnVariable = getTempVariable(&oldAtomicFunctionNode->getType()); 77 78 TIntermBinary *rewrittenNode = new TIntermBinary( 79 TOperator::EOpAssign, CreateTempSymbolNode(returnVariable), oldAtomicFunctionNode); 80 81 auto *parentNode = getParentNode(); 82 83 auto *parentBinary = parentNode->getAsBinaryNode(); 84 if (parentBinary && parentBinary->getOp() == EOpInitialize) 85 { 86 insertStatementInParentBlock(rewrittenNode); 87 queueReplacement(CreateTempSymbolNode(returnVariable), OriginalNode::IS_DROPPED); 88 } 89 else 90 { 91 // As all atomic function assignment will be converted to the last argument of an 92 // interlocked function, if we need the return value, assignment needs to be wrapped with 93 // the comma operator and the temporary variables. 94 if (!parentNode->getAsBlock()) 95 { 96 rewrittenNode = TIntermBinary::CreateComma( 97 rewrittenNode, new TIntermSymbol(returnVariable), mShaderVersion); 98 } 99 100 queueReplacement(rewrittenNode, OriginalNode::IS_DROPPED); 101 } 102 } 103 104 const TVariable *RewriteAtomicFunctionExpressionsTraverser::getTempVariable(const TType *type) 105 { 106 TIntermDeclaration *variableDeclaration; 107 TVariable *returnVariable = 108 DeclareTempVariable(mSymbolTable, type, EvqTemporary, &variableDeclaration); 109 mTempVariables.push_back(variableDeclaration); 110 return returnVariable; 111 } 112 113 bool RewriteAtomicFunctionExpressionsTraverser::IsAtomicExchangeOrCompSwapNoReturnValue( 114 TIntermAggregate *node, 115 TIntermNode *parentNode) 116 { 117 ASSERT(node); 118 return (node->getOp() == EOpAtomicExchange || node->getOp() == EOpAtomicCompSwap) && 119 parentNode && parentNode->getAsBlock(); 120 } 121 122 bool RewriteAtomicFunctionExpressionsTraverser::IsAtomicFunctionInsideExpression( 123 TIntermAggregate *node, 124 TIntermNode *parentNode) 125 { 126 ASSERT(node); 127 // We only need to handle atomic functions with a parent that it is not block nodes. If the 128 // parent node is block, it means that the atomic function is not inside an expression. 129 if (!BuiltInGroup::IsAtomicMemory(node->getOp()) || parentNode->getAsBlock()) 130 { 131 return false; 132 } 133 134 auto *parentAsBinary = parentNode->getAsBinaryNode(); 135 // Assignments are handled in OutputHLSL 136 return !parentAsBinary || parentAsBinary->getOp() != EOpAssign; 137 } 138 139 bool RewriteAtomicFunctionExpressionsTraverser::visitAggregate(Visit visit, TIntermAggregate *node) 140 { 141 ASSERT(visit == PostVisit); 142 // Skip atomic memory functions for SSBO. They will be processed in the OutputHLSL traverser. 143 if (BuiltInGroup::IsAtomicMemory(node->getOp()) && 144 IsInShaderStorageBlock((*node->getSequence())[0]->getAsTyped())) 145 { 146 return false; 147 } 148 149 TIntermNode *parentNode = getParentNode(); 150 if (IsAtomicExchangeOrCompSwapNoReturnValue(node, parentNode) || 151 IsAtomicFunctionInsideExpression(node, parentNode)) 152 { 153 rewriteAtomicFunctionCallNode(node); 154 } 155 156 return true; 157 } 158 159 bool RewriteAtomicFunctionExpressionsTraverser::visitBlock(Visit visit, TIntermBlock *node) 160 { 161 ASSERT(visit == PostVisit); 162 163 if (!mTempVariables.empty() && getParentNode()->getAsFunctionDefinition()) 164 { 165 insertStatementsInBlockAtPosition(node, 0, mTempVariables, TIntermSequence()); 166 mTempVariables.clear(); 167 } 168 169 return true; 170 } 171 172 } // anonymous namespace 173 174 bool RewriteAtomicFunctionExpressions(TCompiler *compiler, 175 TIntermNode *root, 176 TSymbolTable *symbolTable, 177 int shaderVersion) 178 { 179 RewriteAtomicFunctionExpressionsTraverser traverser(symbolTable, shaderVersion); 180 traverser.traverse(root); 181 return traverser.updateTree(compiler, root); 182 } 183 } // namespace sh