RewriteAtomicCounters.cpp (12342B)
1 // 2 // Copyright 2019 The ANGLE Project Authors. All rights reserved. 3 // Use of this source code is governed by a BSD-style license that can be 4 // found in the LICENSE file. 5 // 6 // RewriteAtomicCounters: Emulate atomic counter buffers with storage buffers. 7 // 8 9 #include "compiler/translator/tree_ops/RewriteAtomicCounters.h" 10 11 #include "compiler/translator/Compiler.h" 12 #include "compiler/translator/ImmutableStringBuilder.h" 13 #include "compiler/translator/SymbolTable.h" 14 #include "compiler/translator/tree_util/IntermNode_util.h" 15 #include "compiler/translator/tree_util/IntermTraverse.h" 16 #include "compiler/translator/tree_util/ReplaceVariable.h" 17 18 namespace sh 19 { 20 namespace 21 { 22 constexpr ImmutableString kAtomicCountersVarName = ImmutableString("atomicCounters"); 23 constexpr ImmutableString kAtomicCounterFieldName = ImmutableString("counters"); 24 25 // DeclareAtomicCountersBuffer adds a storage buffer array that's used with atomic counters. 26 const TVariable *DeclareAtomicCountersBuffers(TIntermBlock *root, TSymbolTable *symbolTable) 27 { 28 // Define `uint counters[];` as the only field in the interface block. 29 TFieldList *fieldList = new TFieldList; 30 TType *counterType = new TType(EbtUInt, EbpHigh, EvqGlobal); 31 counterType->makeArray(0); 32 33 TField *countersField = 34 new TField(counterType, kAtomicCounterFieldName, TSourceLoc(), SymbolType::AngleInternal); 35 36 fieldList->push_back(countersField); 37 38 TMemoryQualifier coherentMemory = TMemoryQualifier::Create(); 39 coherentMemory.coherent = true; 40 41 // There are a maximum of 8 atomic counter buffers per IMPLEMENTATION_MAX_ATOMIC_COUNTER_BUFFERS 42 // in libANGLE/Constants.h. 43 constexpr uint32_t kMaxAtomicCounterBuffers = 8; 44 45 // Define a storage block "ANGLEAtomicCounters" with instance name "atomicCounters". 46 TLayoutQualifier layoutQualifier = TLayoutQualifier::Create(); 47 layoutQualifier.blockStorage = EbsStd430; 48 49 return DeclareInterfaceBlock(root, symbolTable, fieldList, EvqBuffer, layoutQualifier, 50 coherentMemory, kMaxAtomicCounterBuffers, 51 ImmutableString(vk::kAtomicCountersBlockName), 52 kAtomicCountersVarName); 53 } 54 55 TIntermTyped *CreateUniformBufferOffset(const TIntermTyped *uniformBufferOffsets, int binding) 56 { 57 // Each uint in the |acbBufferOffsets| uniform contains offsets for 4 bindings. Therefore, the 58 // expression to get the uniform offset for the binding is: 59 // 60 // acbBufferOffsets[binding / 4] >> ((binding % 4) * 8) & 0xFF 61 62 // acbBufferOffsets[binding / 4] 63 TIntermBinary *uniformBufferOffsetUint = new TIntermBinary( 64 EOpIndexDirect, uniformBufferOffsets->deepCopy(), CreateIndexNode(binding / 4)); 65 66 // acbBufferOffsets[binding / 4] >> ((binding % 4) * 8) 67 TIntermBinary *uniformBufferOffsetShifted = uniformBufferOffsetUint; 68 if (binding % 4 != 0) 69 { 70 uniformBufferOffsetShifted = new TIntermBinary(EOpBitShiftRight, uniformBufferOffsetUint, 71 CreateUIntNode((binding % 4) * 8)); 72 } 73 74 // acbBufferOffsets[binding / 4] >> ((binding % 4) * 8) & 0xFF 75 return new TIntermBinary(EOpBitwiseAnd, uniformBufferOffsetShifted, CreateUIntNode(0xFF)); 76 } 77 78 TIntermBinary *CreateAtomicCounterRef(TIntermTyped *atomicCounterExpression, 79 const TVariable *atomicCounters, 80 const TIntermTyped *uniformBufferOffsets) 81 { 82 // The atomic counters storage buffer declaration looks as such: 83 // 84 // layout(...) buffer ANGLEAtomicCounters 85 // { 86 // uint counters[]; 87 // } atomicCounters[N]; 88 // 89 // Where N is large enough to accommodate atomic counter buffer bindings used in the shader. 90 // 91 // This function takes an expression that uses an atomic counter, which can either be: 92 // 93 // - ac 94 // - acArray[index] 95 // 96 // Note that RewriteArrayOfArrayOfOpaqueUniforms has already flattened array of array of atomic 97 // counters. 98 // 99 // For the first case (ac), the following code is generated: 100 // 101 // atomicCounters[binding].counters[offset] 102 // 103 // For the second case (acArray[index]), the following code is generated: 104 // 105 // atomicCounters[binding].counters[offset + index] 106 // 107 // In either case, an offset given through uniforms is also added to |offset|. The binding is 108 // necessarily a constant thanks to MonomorphizeUnsupportedFunctions. 109 110 // First determine if there's an index, and extract the atomic counter symbol out of the 111 // expression. 112 TIntermSymbol *atomicCounterSymbol = atomicCounterExpression->getAsSymbolNode(); 113 TIntermTyped *atomicCounterIndex = nullptr; 114 int atomicCounterConstIndex = 0; 115 TIntermBinary *asBinary = atomicCounterExpression->getAsBinaryNode(); 116 if (asBinary != nullptr) 117 { 118 atomicCounterSymbol = asBinary->getLeft()->getAsSymbolNode(); 119 120 switch (asBinary->getOp()) 121 { 122 case EOpIndexDirect: 123 atomicCounterConstIndex = asBinary->getRight()->getAsConstantUnion()->getIConst(0); 124 break; 125 case EOpIndexIndirect: 126 atomicCounterIndex = asBinary->getRight(); 127 break; 128 default: 129 UNREACHABLE(); 130 } 131 } 132 133 // Extract binding and offset information out of the atomic counter symbol. 134 ASSERT(atomicCounterSymbol); 135 const TVariable *atomicCounterVar = &atomicCounterSymbol->variable(); 136 const TType &atomicCounterType = atomicCounterVar->getType(); 137 138 const int binding = atomicCounterType.getLayoutQualifier().binding; 139 int offset = atomicCounterType.getLayoutQualifier().offset / 4; 140 141 // Create the expression: 142 // 143 // offset + arrayIndex + uniformOffset 144 // 145 // If arrayIndex is a constant, it's added with offset right here. 146 147 offset += atomicCounterConstIndex; 148 149 TIntermTyped *index = CreateUniformBufferOffset(uniformBufferOffsets, binding); 150 if (atomicCounterIndex != nullptr) 151 { 152 index = new TIntermBinary(EOpAdd, index, atomicCounterIndex); 153 } 154 if (offset != 0) 155 { 156 index = new TIntermBinary(EOpAdd, index, CreateIndexNode(offset)); 157 } 158 159 // Finally, create the complete expression: 160 // 161 // atomicCounters[binding].counters[index] 162 163 TIntermSymbol *atomicCountersRef = new TIntermSymbol(atomicCounters); 164 165 // atomicCounters[binding] 166 TIntermBinary *countersBlock = 167 new TIntermBinary(EOpIndexDirect, atomicCountersRef, CreateIndexNode(binding)); 168 169 // atomicCounters[binding].counters 170 TIntermBinary *counters = 171 new TIntermBinary(EOpIndexDirectInterfaceBlock, countersBlock, CreateIndexNode(0)); 172 173 return new TIntermBinary(EOpIndexIndirect, counters, index); 174 } 175 176 // Traverser that: 177 // 178 // 1. Removes the |uniform atomic_uint| declarations and remembers the binding and offset. 179 // 2. Substitutes |atomicVar[n]| with |buffer[binding].counters[offset + n]|. 180 class RewriteAtomicCountersTraverser : public TIntermTraverser 181 { 182 public: 183 RewriteAtomicCountersTraverser(TSymbolTable *symbolTable, 184 const TVariable *atomicCounters, 185 const TIntermTyped *acbBufferOffsets) 186 : TIntermTraverser(true, false, false, symbolTable), 187 mAtomicCounters(atomicCounters), 188 mAcbBufferOffsets(acbBufferOffsets) 189 {} 190 191 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override 192 { 193 if (!mInGlobalScope) 194 { 195 return true; 196 } 197 198 const TIntermSequence &sequence = *(node->getSequence()); 199 200 TIntermTyped *variable = sequence.front()->getAsTyped(); 201 const TType &type = variable->getType(); 202 bool isAtomicCounter = type.isAtomicCounter(); 203 204 if (isAtomicCounter) 205 { 206 ASSERT(type.getQualifier() == EvqUniform); 207 TIntermSequence emptySequence; 208 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node, 209 std::move(emptySequence)); 210 211 return false; 212 } 213 214 return true; 215 } 216 217 bool visitAggregate(Visit visit, TIntermAggregate *node) override 218 { 219 if (BuiltInGroup::IsBuiltIn(node->getOp())) 220 { 221 bool converted = convertBuiltinFunction(node); 222 return !converted; 223 } 224 225 // AST functions don't require modification as atomic counter function parameters are 226 // removed by MonomorphizeUnsupportedFunctions. 227 return true; 228 } 229 230 void visitSymbol(TIntermSymbol *symbol) override 231 { 232 // Cannot encounter the atomic counter symbol directly. It can only be used with functions, 233 // and therefore it's handled by visitAggregate. 234 ASSERT(!symbol->getType().isAtomicCounter()); 235 } 236 237 bool visitBinary(Visit visit, TIntermBinary *node) override 238 { 239 // Cannot encounter an atomic counter expression directly. It can only be used with 240 // functions, and therefore it's handled by visitAggregate. 241 ASSERT(!node->getType().isAtomicCounter()); 242 return true; 243 } 244 245 private: 246 bool convertBuiltinFunction(TIntermAggregate *node) 247 { 248 const TOperator op = node->getOp(); 249 250 // If the function is |memoryBarrierAtomicCounter|, simply replace it with 251 // |memoryBarrierBuffer|. 252 if (op == EOpMemoryBarrierAtomicCounter) 253 { 254 TIntermSequence emptySequence; 255 TIntermTyped *substituteCall = CreateBuiltInFunctionCallNode( 256 "memoryBarrierBuffer", &emptySequence, *mSymbolTable, 310); 257 queueReplacement(substituteCall, OriginalNode::IS_DROPPED); 258 return true; 259 } 260 261 // If it's an |atomicCounter*| function, replace the function with an |atomic*| equivalent. 262 if (!node->getFunction()->isAtomicCounterFunction()) 263 { 264 return false; 265 } 266 267 // Note: atomicAdd(0) is used for atomic reads. 268 uint32_t valueChange = 0; 269 constexpr char kAtomicAddFunction[] = "atomicAdd"; 270 bool isDecrement = false; 271 272 if (op == EOpAtomicCounterIncrement) 273 { 274 valueChange = 1; 275 } 276 else if (op == EOpAtomicCounterDecrement) 277 { 278 // uint values are required to wrap around, so 0xFFFFFFFFu is used as -1. 279 valueChange = std::numeric_limits<uint32_t>::max(); 280 static_assert(static_cast<uint32_t>(-1) == std::numeric_limits<uint32_t>::max(), 281 "uint32_t max is not -1"); 282 283 isDecrement = true; 284 } 285 else 286 { 287 ASSERT(op == EOpAtomicCounter); 288 } 289 290 TIntermTyped *param = (*node->getSequence())[0]->getAsTyped(); 291 292 TIntermSequence substituteArguments; 293 substituteArguments.push_back( 294 CreateAtomicCounterRef(param, mAtomicCounters, mAcbBufferOffsets)); 295 substituteArguments.push_back(CreateUIntNode(valueChange)); 296 297 TIntermTyped *substituteCall = CreateBuiltInFunctionCallNode( 298 kAtomicAddFunction, &substituteArguments, *mSymbolTable, 310); 299 300 // Note that atomicCounterDecrement returns the *new* value instead of the prior value, 301 // unlike atomicAdd. So we need to do a -1 on the result as well. 302 if (isDecrement) 303 { 304 substituteCall = new TIntermBinary(EOpSub, substituteCall, CreateUIntNode(1)); 305 } 306 307 queueReplacement(substituteCall, OriginalNode::IS_DROPPED); 308 return true; 309 } 310 311 const TVariable *mAtomicCounters; 312 const TIntermTyped *mAcbBufferOffsets; 313 }; 314 315 } // anonymous namespace 316 317 bool RewriteAtomicCounters(TCompiler *compiler, 318 TIntermBlock *root, 319 TSymbolTable *symbolTable, 320 const TIntermTyped *acbBufferOffsets) 321 { 322 const TVariable *atomicCounters = DeclareAtomicCountersBuffers(root, symbolTable); 323 324 RewriteAtomicCountersTraverser traverser(symbolTable, atomicCounters, acbBufferOffsets); 325 root->traverse(&traverser); 326 return traverser.updateTree(compiler, root); 327 } 328 } // namespace sh