From e919e2859a2e1ccf8a8fcdde8500b49778a56f5e Mon Sep 17 00:00:00 2001
From: Yong He
Date: Tue, 12 Nov 2024 15:12:24 -0800
Subject: [PATCH 1/7] Push buffer load to end of access chain.

---
 source/slang/slang-emit.cpp                 |   4 +
 source/slang/slang-ir-defer-buffer-load.cpp | 205 ++++++++++++++++++++
 source/slang/slang-ir-defer-buffer-load.h   |  26 +++
 source/slang/slang-ir.cpp                   |   2 +-
 tests/spirv/sb-load.slang                   |  24 +++
 5 files changed, 260 insertions(+), 1 deletion(-)
 create mode 100644 source/slang/slang-ir-defer-buffer-load.cpp
 create mode 100644 source/slang/slang-ir-defer-buffer-load.h
 create mode 100644 tests/spirv/sb-load.slang

diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 1950f251c2..435d5b1155 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -30,6 +30,7 @@
 #include "slang-ir-com-interface.h"
 #include "slang-ir-composite-reg-to-mem.h"
 #include "slang-ir-dce.h"
+#include "slang-ir-defer-buffer-load.h"
 #include "slang-ir-defunctionalization.h"
 #include "slang-ir-diff-call.h"
 #include "slang-ir-dll-export.h"
@@ -951,6 +952,9 @@ Result linkAndOptimizeIR(
     // Inline calls to any functions marked with [__unsafeInlineEarly] or [ForceInline].
     performForceInlining(irModule);
 
+    // Push `structuredBufferLoad` to the end of the access chain to avoid loading unnecessary data.
+    deferBufferLoad(irModule);
+
     // Specialization can introduce dead code that could trip
     // up downstream passes like type legalization, so we
     // will run a DCE pass to clean up after the specialization.
diff --git a/source/slang/slang-ir-defer-buffer-load.cpp b/source/slang/slang-ir-defer-buffer-load.cpp
new file mode 100644
index 0000000000..389c770939
--- /dev/null
+++ b/source/slang/slang-ir-defer-buffer-load.cpp
@@ -0,0 +1,205 @@
+#include "slang-ir-defer-buffer-load.h"
+
+#include "slang-ir-clone.h"
+#include "slang-ir-insts.h"
+#include "slang-ir.h"
+namespace Slang
+{
+struct DeferBufferLoadContext
+{
+    // Map an original SSA value to a pointer that can be used to load the value.
+    Dictionary<IRInst*, IRInst*> mapValueToPtr;
+
+    // Map an original SSA value to a load(ptr) where ptr is mapValueToPtr[value].
+    Dictionary<IRInst*, IRInst*> mapValueToMaterializedValue;
+
+    // Ensure that for an original SSA value, we have formed a pointer that can be used to load the
+    // value.
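+    // The pointer is built by recursing through GetElement/FieldExtract chains down to the
+    // originating buffer load, mirroring each value-level access with the corresponding
+    // address-level instruction, and caching the result. Returns nullptr if the chain does
+    // not bottom out at a supported buffer load.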
+    IRInst* ensurePtr(IRInst* valueInst)
+    {
+        IRInst* result = nullptr;
+        if (mapValueToPtr.tryGetValue(valueInst, result))
+            return result;
+        IRBuilder b(valueInst);
+        b.setInsertBefore(valueInst);
+        switch (valueInst->getOp())
+        {
+        case kIROp_StructuredBufferLoad:
+        case kIROp_RWStructuredBufferLoad:
+        case kIROp_StructuredBufferLoadStatus:
+        case kIROp_RWStructuredBufferLoadStatus:
+            {
+                result = b.emitRWStructuredBufferGetElementPtr(
+                    valueInst->getOperand(0),
+                    valueInst->getOperand(1));
+                break;
+            }
+        case kIROp_GetElement:
+            {
+                auto ptr = ensurePtr(valueInst->getOperand(0));
+                if (!ptr)
+                    return nullptr;
+                result = b.emitElementAddress(ptr, valueInst->getOperand(1));
+                break;
+            }
+        case kIROp_FieldExtract:
+            {
+                auto ptr = ensurePtr(valueInst->getOperand(0));
+                if (!ptr)
+                    return nullptr;
+                result = b.emitFieldAddress(ptr, valueInst->getOperand(1));
+                break;
+            }
+        }
+        if (result)
+            mapValueToPtr[valueInst] = result;
+        return result;
+    }
+
+    static bool isStructuredBufferLoad(IRInst* inst)
+    {
+        switch (inst->getOp())
+        {
+        case kIROp_StructuredBufferLoad:
+        case kIROp_RWStructuredBufferLoad:
+        case kIROp_StructuredBufferLoadStatus:
+        case kIROp_RWStructuredBufferLoadStatus:
+            return true;
+        default:
+            return false;
+        }
+    }
+
+    // Ensure that for a pointer value, we have created a load instruction to materialize the value.
+    IRInst* materializePointer(IRBuilder& builder, IRInst* loadInst)
+    {
+        IRInst* result;
+        if (mapValueToMaterializedValue.tryGetValue(loadInst, result))
+            return result;
+        auto ptr = ensurePtr(loadInst);
+        builder.setInsertAfter(ptr);
+        result = builder.emitLoad(ptr);
+        mapValueToMaterializedValue[loadInst] = result;
+        return result;
+    }
+
+    static bool isSimpleType(IRInst* type)
+    {
+        if (as<IRBasicType>(type))
+            return true;
+        if (as<IRVectorType>(type))
+            return true;
+        if (as<IRMatrixType>(type))
+            return true;
+        return false;
+    }
+
+    void deferBufferLoadInst(IRBuilder& builder, List<IRInst*>& workList, IRInst* loadInst)
+    {
+        // Don't defer the load any further once the type is simple.
+        if (isSimpleType(loadInst->getDataType()))
+        {
+            if (!isStructuredBufferLoad(loadInst))
+            {
+                auto materializedVal = materializePointer(builder, loadInst);
+                loadInst->replaceUsesWith(materializedVal);
+            }
+            return;
+        }
+
+        // Otherwise, walk all uses and try to defer the load to the actual points of use.
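+        // Uses that are themselves GetElement/FieldExtract can keep being deferred, so they
+        // are queued for further processing; any other use of an intermediate (non-buffer-load)
+        // value forces the value to be materialized with an actual load.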
+        ShortList<IRInst*> pendingWorkList;
+        bool needMaterialize = false;
+        traverseUses(
+            loadInst,
+            [&](IRUse* use)
+            {
+                if (needMaterialize)
+                    return;
+
+                auto user = use->getUser();
+                switch (user->getOp())
+                {
+                case kIROp_GetElement:
+                case kIROp_FieldExtract:
+                    {
+                        auto basePtr = ensurePtr(loadInst);
+                        if (!basePtr)
+                            return;
+                        pendingWorkList.add(user);
+                    }
+                    break;
+                default:
+                    if (!isStructuredBufferLoad(loadInst))
+                    {
+                        needMaterialize = true;
+                        return;
+                    }
+                    break;
+                }
+            });
+
+        if (needMaterialize)
+        {
+            auto val = materializePointer(builder, loadInst);
+            loadInst->replaceUsesWith(val);
+            loadInst->removeAndDeallocate();
+        }
+        else
+        {
+            for (auto item : pendingWorkList)
+                workList.add(item);
+        }
+    }
+
+    void deferBufferLoadInFunc(IRFunc* func)
+    {
+        List<IRInst*> workList;
+
+        for (auto block : func->getBlocks())
+        {
+            for (auto inst : block->getChildren())
+            {
+                if (isStructuredBufferLoad(inst))
+                {
+                    workList.add(inst);
+                }
+            }
+        }
+        IRBuilder builder(func);
+        for (Index i = 0; i < workList.getCount(); i++)
+        {
+            auto inst = workList[i];
+            deferBufferLoadInst(builder, workList, inst);
+        }
+    }
+
+    void deferBufferLoad(IRGlobalValueWithCode* inst)
+    {
+        if (auto func = as<IRFunc>(inst))
+        {
+            deferBufferLoadInFunc(func);
+        }
+        else if (auto generic = as<IRGeneric>(inst))
+        {
+            auto inner = findGenericReturnVal(generic);
+            if (auto innerFunc = as<IRFunc>(inner))
+                deferBufferLoadInFunc(innerFunc);
+        }
+    }
+};
+
+void deferBufferLoad(IRModule* module)
+{
+    DeferBufferLoadContext context;
+    for (auto childInst : module->getGlobalInsts())
+    {
+        if (auto code = as<IRGlobalValueWithCode>(childInst))
+        {
+            context.deferBufferLoad(code);
+        }
+    }
+}
+} // namespace Slang
diff --git a/source/slang/slang-ir-defer-buffer-load.h b/source/slang/slang-ir-defer-buffer-load.h
new file mode 100644
index 0000000000..b542718838
--- /dev/null
+++ b/source/slang/slang-ir-defer-buffer-load.h
@@ -0,0 +1,26 @@
+#pragma once
+
+namespace Slang
+{
+
+/*
+This pass implements a targeted optimization that defers the loading of structured buffer elements
+to the end of the access chain, to avoid loading and repacking unnecessary data.
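+
+In Slang source terms, this turns a pattern like `foo(buffer[i].member[j])` (names here are
+purely illustrative) into code that addresses and loads only the accessed part, instead of
+loading the whole buffer element and extracting from the loaded value.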
+For example, if we see:
+    val = StructuredBufferLoad(s, i)
+    val2 = GetElement(val, j)
+    val3 = FieldExtract(val2, field_key_0)
+    call(foo, val3)
+We should rewrite the code into:
+    ptr = RWStructuredBufferGetElementPtr(s, i)
+    ptr2 = ElementAddress(ptr, j)
+    ptr3 = FieldAddress(ptr2, field_key_0)
+    val3 = Load(ptr3)
+    call(foo, val3)
+*/
+
+struct IRModule;
+
+void deferBufferLoad(IRModule* module);
+
+} // namespace Slang
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 823b3cd7db..5e4db43b2c 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -6175,7 +6175,7 @@ IRInst* IRBuilder::emitGenericAsm(UnownedStringSlice asmText)
 IRInst* IRBuilder::emitRWStructuredBufferGetElementPtr(IRInst* structuredBuffer, IRInst* index)
 {
-    const auto sbt = cast<IRHLSLRWStructuredBufferType>(structuredBuffer->getDataType());
+    const auto sbt = cast<IRHLSLStructuredBufferTypeBase>(structuredBuffer->getDataType());
     const auto t = getPtrType(sbt->getElementType());
     IRInst* const operands[2] = {structuredBuffer, index};
     const auto i = createInst<IRRWStructuredBufferGetElementPtr>(
diff --git a/tests/spirv/sb-load.slang b/tests/spirv/sb-load.slang
new file mode 100644
index 0000000000..be3454c0d1
--- /dev/null
+++ b/tests/spirv/sb-load.slang
@@ -0,0 +1,24 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+
+#define FILL_PATTERN_DIMENSIONS_X 16
+#define FILL_PATTERN_DIMENSIONS_Y 16
+
+struct FillPatternBuffer
+{
+    float4 px[FILL_PATTERN_DIMENSIONS_Y][FILL_PATTERN_DIMENSIONS_X];
+};
+
+RWStructuredBuffer<FillPatternBuffer> dp;
+RWStructuredBuffer<float4> outputBuffer;
+
+// CHECK-NOT: OpCompositeConstruct
+
+[numthreads(4, 4, 1)]
+void main(uint3 GTid : SV_GroupThreadID,
+    uint GI : SV_GroupIndex)
+{
+    const uint ii = GTid.x;
+    const uint jj = GTid.y;
+    const float4 pmv = dp[0].px[ii][jj];
+    outputBuffer[GI] = pmv;
+}
\ No newline at end of file

From 49acf541f440c908d0dbe0af85a3a31306e64483 Mon Sep 17 00:00:00 2001
From: Yong He
Date: Tue, 12 Nov 2024 15:15:39 -0800
Subject: [PATCH 2/7] Update test.

---
 tests/spirv/sb-load.slang | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/spirv/sb-load.slang b/tests/spirv/sb-load.slang
index be3454c0d1..1b0df0be8a 100644
--- a/tests/spirv/sb-load.slang
+++ b/tests/spirv/sb-load.slang
@@ -8,7 +8,7 @@ struct FillPatternBuffer
     float4 px[FILL_PATTERN_DIMENSIONS_Y][FILL_PATTERN_DIMENSIONS_X];
 };
 
-RWStructuredBuffer<FillPatternBuffer> dp;
+StructuredBuffer<FillPatternBuffer> dp;
 RWStructuredBuffer<float4> outputBuffer;
 
 // CHECK-NOT: OpCompositeConstruct

From a2babae3f93510676f046deec6d13f5002483ef5 Mon Sep 17 00:00:00 2001
From: Yong He
Date: Tue, 12 Nov 2024 15:24:54 -0800
Subject: [PATCH 3/7] Fix.

---
 source/slang/slang-ir-defer-buffer-load.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/source/slang/slang-ir-defer-buffer-load.cpp b/source/slang/slang-ir-defer-buffer-load.cpp
index 389c770939..023aa76991 100644
--- a/source/slang/slang-ir-defer-buffer-load.cpp
+++ b/source/slang/slang-ir-defer-buffer-load.cpp
@@ -58,12 +58,12 @@ struct DeferBufferLoadContext
 
     static bool isStructuredBufferLoad(IRInst* inst)
     {
+        // Note: we cannot defer loads from RWStructuredBuffer because there can be other
+        // instructions that modify the buffer.
         switch (inst->getOp())
         {
         case kIROp_StructuredBufferLoad:
-        case kIROp_RWStructuredBufferLoad:
         case kIROp_StructuredBufferLoadStatus:
-        case kIROp_RWStructuredBufferLoadStatus:
             return true;
         default:
             return false;

From 6abc9b35b5c63868f2b89f864f9744bef29c77ce Mon Sep 17 00:00:00 2001
From: Yong He
Date: Tue, 12 Nov 2024 15:25:06 -0800
Subject: [PATCH 4/7] Fix.

---
 source/slang/slang-ir-defer-buffer-load.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/source/slang/slang-ir-defer-buffer-load.cpp b/source/slang/slang-ir-defer-buffer-load.cpp
index 023aa76991..b4174d565b 100644
--- a/source/slang/slang-ir-defer-buffer-load.cpp
+++ b/source/slang/slang-ir-defer-buffer-load.cpp
@@ -25,9 +25,7 @@ struct DeferBufferLoadContext
         switch (valueInst->getOp())
         {
         case kIROp_StructuredBufferLoad:
-        case kIROp_RWStructuredBufferLoad:
         case kIROp_StructuredBufferLoadStatus:
-        case kIROp_RWStructuredBufferLoadStatus:
             {
                 result = b.emitRWStructuredBufferGetElementPtr(
                     valueInst->getOperand(0),

From ca9c952ba4c8b19796cf25df68dbc60743f2b256 Mon Sep 17 00:00:00 2001
From: Yong He
Date: Tue, 12 Nov 2024 16:05:03 -0800
Subject: [PATCH 5/7] Fix.

---
 source/slang/slang-emit.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 435d5b1155..05bb12ecc7 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -953,7 +953,9 @@ Result linkAndOptimizeIR(
     performForceInlining(irModule);
 
     // Push `structuredBufferLoad` to the end of the access chain to avoid loading unnecessary data.
-    deferBufferLoad(irModule);
+    if (isKhronosTarget(targetRequest) || isMetalTarget(targetRequest) ||
+        isWGPUTarget(targetRequest))
+        deferBufferLoad(irModule);
 
     // Specialization can introduce dead code that could trip
     // up downstream passes like type legalization, so we

From 4a780220222135d2a702a4191af7b2eed0615454 Mon Sep 17 00:00:00 2001
From: Yong He
Date: Tue, 12 Nov 2024 19:25:40 -0800
Subject: [PATCH 6/7] Make more robust.

---
 source/slang/slang-ir-defer-buffer-load.cpp | 183 +++++++++++++++++++-
 tests/spirv/sb-load-2.slang                 |  23 +++
 2 files changed, 197 insertions(+), 9 deletions(-)
 create mode 100644 tests/spirv/sb-load-2.slang

diff --git a/source/slang/slang-ir-defer-buffer-load.cpp b/source/slang/slang-ir-defer-buffer-load.cpp
index b4174d565b..b94284d894 100644
--- a/source/slang/slang-ir-defer-buffer-load.cpp
+++ b/source/slang/slang-ir-defer-buffer-load.cpp
@@ -1,17 +1,139 @@
 #include "slang-ir-defer-buffer-load.h"
 
 #include "slang-ir-clone.h"
+#include "slang-ir-dominators.h"
 #include "slang-ir-insts.h"
+#include "slang-ir-redundancy-removal.h"
+#include "slang-ir-util.h"
 #include "slang-ir.h"
+
 namespace Slang
 {
 struct DeferBufferLoadContext
 {
+    struct AccessChain
+    {
+        List<IRInst*> chain;
+        mutable HashCode64 hash = 0;
+
+        bool operator==(const AccessChain& rhs) const
+        {
+            ensureHash();
+            rhs.ensureHash();
+            if (hash != rhs.hash)
+                return false;
+            if (chain.getCount() != rhs.chain.getCount())
+                return false;
+            for (Index i = 0; i < chain.getCount(); i++)
+            {
+                if (chain[i] != rhs.chain[i])
+                    return false;
+            }
+            return true;
+        }
+        void ensureHash() const
+        {
+            if (hash == 0)
+            {
+                for (auto inst : chain)
+                {
+                    hash = combineHash(hash, Slang::getHashCode(inst));
+                }
+            }
+        }
+        HashCode64 getHashCode() const
+        {
+            ensureHash();
+            return hash;
+        }
+    };
+
     // Map an original SSA value to a pointer that can be used to load the value.
+    Dictionary<AccessChain, IRInst*> mapAccessChainToPtr;
     Dictionary<IRInst*, IRInst*> mapValueToPtr;
+    // Map a pointer to its loaded value.
+    Dictionary<IRInst*, IRInst*> mapPtrToValue;
+
+    IRFunc* currentFunc = nullptr;
+    IRDominatorTree* dominatorTree = nullptr;
 
-    // Map an original SSA value to a load(ptr) where ptr is mapValueToPtr[value].
-    Dictionary<IRInst*, IRInst*> mapValueToMaterializedValue;
+    // Find the block that is dominated by all dependent blocks, and is the earliest block that
+    // dominates the target block.
+    // This is the place where we can insert the load instruction such that all access chain
+    // operands are defined and the load can be made available to the location of valueInst.
+    //
+    IRBlock* findEarliestDominatingBlock(IRInst* valueInst, List<IRBlock*>& dependentBlocks)
+    {
+        auto targetBlock = getBlock(valueInst);
+        while (targetBlock)
+        {
+            auto idom = dominatorTree->getImmediateDominator(targetBlock);
+            if (!idom)
+                break;
+            bool isValid = true;
+            for (auto block : dependentBlocks)
+            {
+                if (!dominatorTree->dominates(block, idom))
+                {
+                    isValid = false;
+                    break;
+                }
+            }
+            if (isValid)
+            {
+                targetBlock = idom;
+            }
+            else
+            {
+                break;
+            }
+        }
+        return targetBlock;
+    }
+
+    // Find the earliest instruction before which we can insert the load instruction such that
+    // all dependent instructions for the load address are defined, and the load can reach all
+    // locations where the address is available.
+    //
+    IRInst* findEarliestInsertionPoint(IRInst* valueInst, AccessChain& chain)
+    {
+        List<IRBlock*> dependentBlocks;
+        List<IRInst*> dependentInsts;
+        for (auto inst : chain.chain)
+        {
+            if (auto block = getBlock(inst))
+            {
+                dependentBlocks.add(block);
+                dependentInsts.add(inst);
+            }
+        }
+        auto targetBlock = findEarliestDominatingBlock(valueInst, dependentBlocks);
+        IRInst* insertBeforeInst = targetBlock->getTerminator();
+        for (;;)
+        {
+            auto prev = insertBeforeInst->getPrevInst();
+            if (!prev)
+                break;
+            bool valid = true;
+            for (auto inst : dependentInsts)
+            {
+                if (!dominatorTree->dominates(inst, prev) || inst == prev)
+                {
+                    valid = false;
+                    break;
+                }
+            }
+            if (valid)
+            {
+                insertBeforeInst = prev;
+            }
+            else
+            {
+                break;
+            }
+        }
+        return insertBeforeInst;
+    }
 
     // Ensure that for an original SSA value, we have formed a pointer that can be used to load the
     // value.
@@ -20,8 +142,39 @@ struct DeferBufferLoadContext
         IRInst* result = nullptr;
         if (mapValueToPtr.tryGetValue(valueInst, result))
             return result;
+        AccessChain chain;
+        IRInst* current = valueInst;
+        while (current)
+        {
+            bool processed = false;
+            switch (current->getOp())
+            {
+            case kIROp_GetElement:
+            case kIROp_FieldExtract:
+                chain.chain.add(current->getOperand(1));
+                current = current->getOperand(0);
+                processed = true;
+                break;
+            default:
+                break;
+            }
+            if (!processed)
+                break;
+        }
+        chain.chain.add(current);
+        chain.chain.reverse();
+        if (mapAccessChainToPtr.tryGetValue(chain, result))
+            return result;
+
+        // Find the proper place to insert the load instruction.
+        // This is the location where all operands of the access chain are defined,
+        // and the earliest point from which all possible uses of the value at this
+        // access chain can be reached.
         IRBuilder b(valueInst);
-        b.setInsertBefore(valueInst);
+
+        auto insertBeforeInst = findEarliestInsertionPoint(valueInst, chain);
+        b.setInsertBefore(insertBeforeInst);
+
         switch (valueInst->getOp())
        {
         case kIROp_StructuredBufferLoad:
@@ -50,7 +203,10 @@ struct DeferBufferLoadContext
             }
         }
         if (result)
+        {
+            mapAccessChainToPtr[chain] = result;
             mapValueToPtr[valueInst] = result;
+        }
         return result;
     }
 
@@ -71,13 +227,15 @@ struct DeferBufferLoadContext
     // Ensure that for a pointer value, we have created a load instruction to materialize the value.
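+    // The cached load is keyed on the pointer rather than on the original SSA value, so
+    // distinct values that resolve to the same access chain share a single load.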
     IRInst* materializePointer(IRBuilder& builder, IRInst* loadInst)
     {
-        IRInst* result;
-        if (mapValueToMaterializedValue.tryGetValue(loadInst, result))
-            return result;
         auto ptr = ensurePtr(loadInst);
+        if (!ptr)
+            return nullptr;
+        IRInst* result = nullptr;
+        if (mapPtrToValue.tryGetValue(ptr, result))
+            return result;
         builder.setInsertAfter(ptr);
         result = builder.emitLoad(ptr);
-        mapValueToMaterializedValue[loadInst] = result;
+        mapPtrToValue[ptr] = result;
         return result;
     }
 
@@ -145,13 +303,20 @@ struct DeferBufferLoadContext
         }
         else
         {
-            for (auto item : pendingWorkList)
-                workList.add(item);
+            // Append to the worklist in reverse order so we process the uses in their
+            // natural order of appearance.
+            for (Index i = pendingWorkList.getCount() - 1; i >= 0; i--)
+                workList.add(pendingWorkList[i]);
         }
     }
 
     void deferBufferLoadInFunc(IRFunc* func)
     {
+        removeRedundancyInFunc(func);
+
+        currentFunc = func;
+        dominatorTree = func->getModule()->findOrCreateDominatorTree(func);
+
         List<IRInst*> workList;
 
         for (auto block : func->getBlocks())
diff --git a/tests/spirv/sb-load-2.slang b/tests/spirv/sb-load-2.slang
new file mode 100644
index 0000000000..b4c10cb4a5
--- /dev/null
+++ b/tests/spirv/sb-load-2.slang
@@ -0,0 +1,23 @@
+//TEST:SIMPLE(filecheck=CHECK): -target glsl -entry main -stage compute
+
+struct Test1
+{
+    float2x3 a; // 24B
+    float3x4 b; // 48B
+    float16_t3x2 c; // 12B
+    float16_t2x4 d; // 16B
+};
+
+StructuredBuffer<Test1> dp;
+RWStructuredBuffer<float4> outputBuffer;
+
+// CHECK-COUNT-2: unpackStorage
+// CHECK-NOT: unpackStorage
+[numthreads(4, 4, 1)]
+void main(uint3 GTid : SV_GroupThreadID,
+    uint GI : SV_GroupIndex)
+{
+    var tmp = dp[0];
+    var rs = tmp.a[0][0] + tmp.a[0][1];
+    outputBuffer[GI] = float4(rs);
+}
\ No newline at end of file

From 65fc2df763c397c6964cc34d19f0bfcfd01bcb13 Mon Sep 17 00:00:00 2001
From: Yong He
Date: Tue, 12 Nov 2024 19:35:51 -0800
Subject: [PATCH 7/7] Fix.

---
 source/slang/slang-ir-defer-buffer-load.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/source/slang/slang-ir-defer-buffer-load.cpp b/source/slang/slang-ir-defer-buffer-load.cpp
index b94284d894..d1eb4b5e5c 100644
--- a/source/slang/slang-ir-defer-buffer-load.cpp
+++ b/source/slang/slang-ir-defer-buffer-load.cpp
@@ -108,7 +108,8 @@ struct DeferBufferLoadContext
             }
         }
         auto targetBlock = findEarliestDominatingBlock(valueInst, dependentBlocks);
-        IRInst* insertBeforeInst = targetBlock->getTerminator();
+        IRInst* insertBeforeInst =
+            targetBlock == getBlock(valueInst) ? valueInst : targetBlock->getTerminator();
         for (;;)
         {
             auto prev = insertBeforeInst->getPrevInst();