diff --git a/include/slang.h b/include/slang.h index 3024aa8844..3bcdcbba8c 100644 --- a/include/slang.h +++ b/include/slang.h @@ -852,6 +852,7 @@ extern "C" EmitIr, // bool ReportDownstreamTime, // bool ReportPerfBenchmark, // bool + ReportCheckpointIntermediates, // bool SkipSPIRVValidation, // bool SourceEmbedStyle, SourceEmbedName, diff --git a/source/slang-record-replay/util/emum-to-string.h b/source/slang-record-replay/util/emum-to-string.h index 7226edc04c..8c140cf3d6 100644 --- a/source/slang-record-replay/util/emum-to-string.h +++ b/source/slang-record-replay/util/emum-to-string.h @@ -149,6 +149,7 @@ namespace SlangRecord CASE(EmitIr); CASE(ReportDownstreamTime); CASE(ReportPerfBenchmark); + CASE(ReportCheckpointIntermediates); CASE(SkipSPIRVValidation); CASE(SourceEmbedStyle); CASE(SourceEmbedName); diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 541085b4ee..c89d94c807 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -2451,12 +2451,16 @@ namespace Slang return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIr); } + bool CodeGenContext::shouldReportCheckpointIntermediates() + { + return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::ReportCheckpointIntermediates); + } + bool CodeGenContext::shouldDumpIntermediates() { return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIntermediates); } - bool CodeGenContext::shouldTrackLiveness() { return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::TrackLiveness); diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 0c788ae182..4b20d1f763 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -2728,6 +2728,7 @@ namespace Slang bool shouldValidateIR(); bool shouldDumpIR(); + bool shouldReportCheckpointIntermediates(); bool shouldTrackLiveness(); diff --git a/source/slang/slang-diagnostic-defs.h 
b/source/slang/slang-diagnostic-defs.h index 81170fac3e..e0f1e90c5c 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -894,6 +894,12 @@ DIAGNOSTIC(58002, Error, unhandledGLSLSSBOType, "Unhandled GLSL Shader Storage B DIAGNOSTIC(58003, Error, inconsistentPointerAddressSpace, "'$0': use of pointer with inconsistent address space.") +// Autodiff checkpoint reporting +DIAGNOSTIC(-1, Note, reportCheckpointIntermediates, "checkpointing context of $1 bytes associated with function: '$0'") +DIAGNOSTIC(-1, Note, reportCheckpointVariable, "$0 bytes ($1) used to checkpoint the following item:") +DIAGNOSTIC(-1, Note, reportCheckpointCounter, "$0 bytes ($1) used for a loop counter here:") +DIAGNOSTIC(-1, Note, reportCheckpointNone, "no checkpoint contexts to report") + // // 8xxxx - Issues specific to a particular library/technology/platform/etc. // diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index cdd2ca5b66..6e3556064e 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -34,6 +34,7 @@ #include "slang-ir-wgsl-legalize.h" #include "slang-ir-insts.h" #include "slang-ir-inline.h" +#include "slang-ir-layout.h" #include "slang-ir-legalize-array-return-type.h" #include "slang-ir-legalize-mesh-outputs.h" #include "slang-ir-legalize-varying-params.h" @@ -214,6 +215,68 @@ static void dumpIRIfEnabled( } } +static void reportCheckpointIntermediates(CodeGenContext* codeGenContext, DiagnosticSink* sink, IRModule* irModule) +{ + // Report checkpointing information + CompilerOptionSet& optionSet = codeGenContext->getTargetProgram()->getOptionSet(); + SourceManager* sourceManager = sink->getSourceManager(); + + SourceWriter typeWriter(sourceManager, LineDirectiveMode::None, nullptr); + + CLikeSourceEmitter::Desc description; + description.codeGenContext = codeGenContext; + description.sourceWriter = &typeWriter; + + CPPSourceEmitter emitter(description); + + int nonEmptyStructs = 0; + for (auto 
inst : irModule->getGlobalInsts()) + { + IRStructType *structType = as(inst); + if (!structType) + continue; + + auto checkpointDecoration = structType->findDecoration(); + if (!checkpointDecoration) + continue; + + IRSizeAndAlignment structSize; + getNaturalSizeAndAlignment(optionSet, structType, &structSize); + + // Reporting happens before empty structs are optimized out + // and we still want to keep the checkpointing decorations, + // so we end up needing to check for non-zero-ness + if (structSize.size == 0) + continue; + + auto func = checkpointDecoration->getSourceFunction(); + sink->diagnose(structType, Diagnostics::reportCheckpointIntermediates, func, structSize.size); + nonEmptyStructs++; + + for (auto field : structType->getFields()) + { + IRType *fieldType = field->getFieldType(); + IRSizeAndAlignment fieldSize; + getNaturalSizeAndAlignment(optionSet, fieldType, &fieldSize); + if (fieldSize.size == 0) + continue; + + typeWriter.clearContent(); + emitter.emitType(fieldType); + + sink->diagnose(field->sourceLoc, + field->findDecoration() + ? 
Diagnostics::reportCheckpointCounter + : Diagnostics::reportCheckpointVariable, + fieldSize.size, + typeWriter.getContent()); + } + } + + if (nonEmptyStructs == 0) + sink->diagnose(SourceLoc(), Diagnostics::reportCheckpointNone); +} + struct LinkingAndOptimizationOptions { bool shouldLegalizeExistentialAndResourceTypes = true; @@ -767,6 +830,10 @@ Result linkAndOptimizeIR( break; } + // Report checkpointing information + if (codeGenContext->shouldReportCheckpointIntermediates()) + reportCheckpointIntermediates(codeGenContext, sink, irModule); + if (requiredLoweringPassSet.autodiff) finalizeAutoDiffPass(targetProgram, irModule); diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp index 8a48936d7e..b55f6b93d8 100644 --- a/source/slang/slang-ir-addr-inst-elimination.cpp +++ b/source/slang/slang-ir-addr-inst-elimination.cpp @@ -69,30 +69,28 @@ struct AddressInstEliminationContext } } - void transformLoadAddr(IRUse* use) + void transformLoadAddr(IRBuilder& builder, IRUse* use) { auto addr = use->get(); auto load = as(use->getUser()); - IRBuilder builder(module); builder.setInsertBefore(use->getUser()); auto value = getValue(builder, addr); load->replaceUsesWith(value); load->removeAndDeallocate(); } - void transformStoreAddr(IRUse* use) + void transformStoreAddr(IRBuilder& builder, IRUse* use) { auto addr = use->get(); auto store = as(use->getUser()); - IRBuilder builder(module); builder.setInsertBefore(use->getUser()); storeValue(builder, addr, store->getVal()); store->removeAndDeallocate(); } - void transformCallAddr(IRUse* use) + void transformCallAddr(IRBuilder& builder, IRUse* use) { auto addr = use->get(); auto call = as(use->getUser()); @@ -103,7 +101,6 @@ struct AddressInstEliminationContext return; } - IRBuilder builder(module); builder.setInsertBefore(call); auto tempVar = builder.emitVar(cast(addr->getFullType())->getValueType()); @@ -155,17 +152,20 @@ struct AddressInstEliminationContext use = 
nextUse; continue; } + + IRBuilder transformBuilder(module); + IRBuilderSourceLocRAII sourceLocationScope(&transformBuilder, use->getUser()->sourceLoc); switch (use->getUser()->getOp()) { case kIROp_Load: - transformLoadAddr(use); + transformLoadAddr(transformBuilder, use); break; case kIROp_Store: - transformStoreAddr(use); + transformStoreAddr(transformBuilder, use); break; case kIROp_Call: - transformCallAddr(use); + transformCallAddr(transformBuilder, use); break; case kIROp_GetElementPtr: case kIROp_FieldAddress: diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 9fe4ec70b6..f51178f0fc 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -3,8 +3,9 @@ #include "slang-ir-autodiff-region.h" #include "slang-ir-simplify-cfg.h" #include "slang-ir-util.h" -#include "../core/slang-func-ptr.h" +#include "slang-ir-insts.h" #include "slang-ir.h" +#include "../core/slang-func-ptr.h" namespace Slang { @@ -1092,7 +1093,8 @@ IRType* getTypeForLocalStorage( IRVar* emitIndexedLocalVar( IRBlock* varBlock, IRType* baseType, - const List& defBlockIndices) + const List& defBlockIndices, + SourceLoc location) { // Cannot store pointers. Case should have been handled by now. 
SLANG_RELEASE_ASSERT(!as(baseType)); @@ -1101,6 +1103,8 @@ IRVar* emitIndexedLocalVar( SLANG_RELEASE_ASSERT(!as(baseType)); IRBuilder varBuilder(varBlock->getModule()); + IRBuilderSourceLocRAII sourceLocationScope(&varBuilder, location); + varBuilder.setInsertBefore(varBlock->getFirstOrdinaryInst()); IRType* varType = getTypeForLocalStorage(&varBuilder, baseType, defBlockIndices); @@ -1179,9 +1183,14 @@ IRVar* storeIndexedValue( IRInst* instToStore, const List& defBlockIndices) { - IRVar* localVar = emitIndexedLocalVar(defaultVarBlock, instToStore->getDataType(), defBlockIndices); + IRVar* localVar = emitIndexedLocalVar(defaultVarBlock, + instToStore->getDataType(), + defBlockIndices, + instToStore->sourceLoc); - IRInst* addr = emitIndexedStoreAddressForVar(builder, localVar, defBlockIndices); + IRInst* addr = emitIndexedStoreAddressForVar(builder, + localVar, + defBlockIndices); builder->emitStore(addr, instToStore); @@ -1574,12 +1583,16 @@ RefPtr ensurePrimalAvailability( // region, that means there's no need to allocate a fully indexed var. // defBlockIndices = maybeTrimIndices(defBlockIndices, indexedBlockInfo, outOfScopeUses); - - IRVar* localVar = storeIndexedValue( - &builder, - varBlock, - builder.emitLoad(varToStore), - defBlockIndices); + + IRVar* localVar = nullptr; + { + IRBuilderSourceLocRAII sourceLocationScope(&builder, varToStore->sourceLoc); + localVar = storeIndexedValue( + &builder, + varBlock, + builder.emitLoad(varToStore), + defBlockIndices); + } for (auto use : outOfScopeUses) { @@ -1626,6 +1639,8 @@ RefPtr ensurePrimalAvailability( } else { + IRBuilderSourceLocRAII sourceLocationScope(&builder, instToStore->sourceLoc); + // Handle the special case of loop counters. // The only case where there will be a reference of primal loop counter from rev blocks // is the start of a loop in the reverse code. 
Since loop counters are not considered a @@ -1643,6 +1658,8 @@ RefPtr ensurePrimalAvailability( setInsertAfterOrdinaryInst(&builder, instToStore); auto localVar = storeIndexedValue(&builder, varBlock, instToStore, defBlockIndices); + if (isLoopCounter) + builder.addLoopCounterDecoration(localVar); for (auto use : outOfScopeUses) { @@ -1728,6 +1745,8 @@ static IRBlock* getUpdateBlock(IRLoop* loop) void lowerIndexedRegion(IRLoop*& primalLoop, IRLoop*& diffLoop, IRInst*& primalCountParam, IRInst*& diffCountParam) { IRBuilder builder(primalLoop); + IRBuilderSourceLocRAII sourceLocationScope(&builder, primalLoop->sourceLoc); + primalCountParam = nullptr; // Grab first primal block. @@ -1899,8 +1918,7 @@ RefPtr applyCheckpointPolicy(IRGlobalValueWithCode* func) // Legalize the primal inst accesses by introducing local variables / arrays and emitting // necessary load/store logic. // - primalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo); - return primalsInfo; + return ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo); } void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode* func) diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 35a197f29b..2fb73c4ac9 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -403,8 +403,11 @@ namespace Slang List primalTypes, propagateTypes; IRType* primalResultType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getResultType()); + IRParam *currentParam = origFunc->getFirstParam(); for (UInt i = 0; i < origFuncType->getParamCount(); i++) { + IRBuilderSourceLocRAII sourceLocationScope(&builder, currentParam->sourceLoc); + auto primalParamType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getParamType(i)); auto propagateParamType = transcribeParamTypeForPropagateFunc(&builder, origFuncType->getParamType(i)); if (propagateParamType) @@ -453,6 +456,7 @@ namespace Slang 
primalArgs.add(var); } primalTypes.add(primalParamType); + currentParam = currentParam->getNextParam(); } // Add dOut argument to propagateArgs. @@ -588,6 +592,8 @@ namespace Slang autoDiffSharedContext->transcriberSet.forwardTranscriber); auto oldCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount(); IRFunc* fwdDiffFunc = as(getGenericReturnVal(fwdTranscriber.transcribe(builder, primalOuterParent))); SLANG_ASSERT(fwdDiffFunc); + fwdDiffFunc->sourceLoc = primalFunc->sourceLoc; + auto newCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount(); for (auto i = oldCount; i < newCount; i++) @@ -712,8 +718,10 @@ namespace Slang } // Transpose the first block (parameter block) - auto paramTransposeInfo = - splitAndTransposeParameterBlock(builder, diffPropagateFunc, isResultDifferentiable); + auto paramTransposeInfo = splitAndTransposeParameterBlock(builder, + diffPropagateFunc, + primalFunc->sourceLoc, + isResultDifferentiable); // The insts we inserted in paramTransposeInfo.mapPrimalSpecificParamToReplacementInPropFunc // may be used by write back logic that we are going to insert later. @@ -815,6 +823,7 @@ namespace Slang ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParameterBlock( IRBuilder* builder, IRFunc* diffFunc, + SourceLoc primalLoc, bool isResultDifferentiable) { // This method splits transposes the all the parameters for both the primal and propagate computation. 
@@ -841,6 +850,7 @@ namespace Slang auto nextBlockBuilder = *builder; nextBlockBuilder.setInsertBefore(paramPreludeBlock->getFirstOrdinaryInst()); + SourceLoc returnLoc; IRBlock* firstDiffBlock = nullptr; for (auto block : diffFunc->getBlocks()) { @@ -849,6 +859,13 @@ namespace Slang firstDiffBlock = block; break; } + + auto terminator = block->getTerminator(); + if (as(terminator)) + { + returnLoc = terminator->sourceLoc; + break; + } } SLANG_RELEASE_ASSERT(firstDiffBlock); @@ -895,6 +912,8 @@ namespace Slang // from the primal compuation logic in the future propagate function be replaced to. for (auto fwdParam : fwdParams) { + IRBuilderSourceLocRAII sourceLocationScope(builder, fwdParam->sourceLoc); + // Define the replacement insts that we are going to fill in for each case. IRInst* diffRefReplacement = nullptr; IRInst* primalRefReplacement = nullptr; @@ -1186,6 +1205,7 @@ namespace Slang SLANG_ASSERT(dOutParamType); dOutParam = builder->emitParam(dOutParamType); + dOutParam->sourceLoc = returnLoc; builder->addNameHintDecoration(dOutParam, UnownedStringSlice("_s_dOut")); result.propagateFuncParams.add(dOutParam); } @@ -1196,6 +1216,10 @@ namespace Slang result.primalFuncParams.add(ctxParam); result.propagateFuncParams.add(ctxParam); result.dOutParam = dOutParam; + + diffFunc->sourceLoc = primalLoc; + ctxParam->sourceLoc = primalLoc; + return result; } diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index 68cb4e0c9a..b65701a7a9 100644 --- a/source/slang/slang-ir-autodiff-rev.h +++ b/source/slang/slang-ir-autodiff-rev.h @@ -105,6 +105,7 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase ParameterBlockTransposeInfo splitAndTransposeParameterBlock( IRBuilder* builder, IRFunc* diffFunc, + SourceLoc primalLoc, bool isResultDifferentiable); void writeBackDerivativeToInOutParams(ParameterBlockTransposeInfo& info, IRFunc* diffFunc); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp 
b/source/slang/slang-ir-autodiff-transcriber-base.cpp index da69ed8aea..1fa76c7303 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -1033,8 +1033,9 @@ InstPair AutoDiffTranscriberBase::transcribeInst(IRBuilder* builder, IRInst* ori if (as(origInst->getParent()) && !as(origInst)) return InstPair(origInst, nullptr); - auto result = transcribeInstImpl(builder, origInst); + IRBuilderSourceLocRAII sourceLocationScope(builder, origInst->sourceLoc); + auto result = transcribeInstImpl(builder, origInst); if (result.primal == nullptr && result.differential == nullptr) { if (auto origType = as(origInst)) diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index d42462e1ba..1f8c3052ed 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -609,6 +609,8 @@ struct DiffTransposePass auto nextInst = inst->getNextInst(); if (auto varInst = as(inst)) { + IRBuilderSourceLocRAII sourceLocationScope(&builder, varInst->sourceLoc); + if (isDifferentialInst(varInst) && tryGetPrimalTypeFromDiffInst(varInst)) { if (auto ptrPrimalType = as(tryGetPrimalTypeFromDiffInst(varInst))) @@ -692,7 +694,11 @@ struct DiffTransposePass SLANG_ASSERT(lastRevBlock->getTerminator() == nullptr); builder.setInsertInto(lastRevBlock); - builder.emitReturn(); + + { + IRBuilderSourceLocRAII sourceLocationScope(&builder, revDiffFunc->sourceLoc); + builder.emitReturn(); + } // Remove fwd-mode blocks. 
for (auto block : workList) @@ -703,6 +709,8 @@ struct DiffTransposePass IRInst* extractAccumulatorVarGradient(IRBuilder* builder, IRInst* fwdInst) { + IRBuilderSourceLocRAII sourceLocationScope(builder, fwdInst->sourceLoc); + if (auto accVar = getOrCreateAccumulatorVar(fwdInst)) { auto gradValue = builder->emitLoad(accVar); @@ -731,6 +739,7 @@ struct DiffTransposePass return revAccumulatorVarMap[fwdInst]; IRBuilder tempVarBuilder(autodiffContext->moduleInst->getModule()); + IRBuilderSourceLocRAII sourceLocationScope(&tempVarBuilder, fwdInst->sourceLoc); IRBlock* firstDiffBlock = firstRevDiffBlockMap[as(fwdInst->getParent()->getParent())]; @@ -785,6 +794,8 @@ struct DiffTransposePass for (UIndex ii = 0; ii < branchInst->getArgCount(); ii++) { auto arg = branchInst->getArg(ii); + + IRBuilderSourceLocRAII sourceLocationScope(&builder, arg->sourceLoc); if (isDifferentialInst(arg)) { // If the arg is a differential, emit a parameter @@ -885,6 +896,8 @@ struct DiffTransposePass List phiParamRevGradInsts; for (IRParam* param = fwdBlock->getFirstParam(); param; param = param->getNextParam()) { + IRBuilderSourceLocRAII sourceLocationScope(&builder, param->sourceLoc); + if (isDifferentialInst(param)) { // This param might be used outside this block. @@ -949,6 +962,8 @@ struct DiffTransposePass if (auto accVar = getOrCreateAccumulatorVar(externInst)) { + IRBuilderSourceLocRAII sourceLocationScope(&builder, externInst->sourceLoc); + // Accumulate all gradients, including our accumulator variable, // into one inst. // @@ -1050,6 +1065,7 @@ struct DiffTransposePass // Emit the aggregate of all the gradients here. // This will form the total derivative for this inst. 
+ IRBuilderSourceLocRAII sourceLocationScope(builder, inst->sourceLoc); auto revValue = emitAggregateValue(builder, primalType, gradients); auto transposeResult = transposeInst(builder, inst, revValue); @@ -2738,7 +2754,6 @@ struct DiffTransposePass gradient.revGradInst, gradient.fwdGradInst )); - } for (auto pair : bucketedGradients) diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 9b3e3a324a..0953c535a5 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -75,6 +75,9 @@ struct ExtractPrimalFuncContext builder.setInsertBefore(destFunc); IRFuncType* originalFuncType = nullptr; outIntermediateType = createIntermediateType(destFunc); + + builder.addCheckpointIntermediateDecoration(outIntermediateType, originalFunc); + outIntermediateType->sourceLoc = originalFunc->sourceLoc; GenericChildrenMigrationContext migrationContext; migrationContext.init(as(findOuterGeneric(originalFunc)), as(findOuterGeneric(destFunc)), destFunc); @@ -154,6 +157,7 @@ struct ExtractPrimalFuncContext IRInst* intermediateOutput) { auto field = addIntermediateContextField(inst->getDataType(), intermediateOutput); + field->sourceLoc = inst->sourceLoc; auto key = field->getKey(); if (auto nameHint = inst->findDecoration()) cloneDecoration(nameHint, key); @@ -219,6 +223,10 @@ struct ExtractPrimalFuncContext if (inst->hasUses()) { auto field = addIntermediateContextField(cast(inst->getDataType())->getValueType(), outIntermediary); + field->sourceLoc = inst->sourceLoc; + if (inst->findDecoration()) + builder.addLoopCounterDecoration(field); + builder.setInsertBefore(inst); auto fieldAddr = builder.emitFieldAddress( inst->getFullType(), outIntermediary, field->getKey()); @@ -379,12 +387,16 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( use->set(builder.getVoidValue()); continue; } + + IRBuilderSourceLocRAII sourceLocationScope(&builder, use->getUser()->sourceLoc); + 
builder.setInsertBefore(use->getUser()); auto valType = cast(inst->getFullType())->getValueType(); auto val = builder.emitFieldExtract( valType, intermediateVar, structKeyDecor->getStructKey()); + if (use->getUser()->getOp() == kIROp_Load) { use->getUser()->replaceUsesWith(val); @@ -392,8 +404,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( } else { - auto tempVar = - builder.emitVar(valType); + auto tempVar = builder.emitVar(valType); builder.emitStore(tempVar, val); use->set(tempVar); } @@ -401,7 +412,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( } else { - // Orindary value. + // Ordinary value. // We insert a fieldExtract at each use site instead of before `inst`, // since at this stage of autodiff pass, `inst` does not necessarily // dominate all the use sites if `inst` is defined in partial branch @@ -417,6 +428,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( inst->getFullType(), intermediateVar, structKeyDecor->getStructKey()); + val->sourceLoc = user->sourceLoc; builder.replaceOperand(iuse, val); } } diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 9f18db6e06..6ae5126f9b 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -588,7 +588,6 @@ struct DiffUnzipPass as(diffMap[targetBlock]), diffArgs.getCount(), diffArgs.getBuffer())); - } case kIROp_conditionalBranch: @@ -710,6 +709,9 @@ struct DiffUnzipPass void splitMixedInst(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRInst* inst) { + IRBuilderSourceLocRAII primalLocationScope(primalBuilder, inst->sourceLoc); + IRBuilderSourceLocRAII diffLocationScope(diffBuilder, inst->sourceLoc); + auto instPair = _splitMixedInst(primalBuilder, diffBuilder, inst); primalMap[inst] = instPair.primal; diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 0979c097c4..07a6a76fb3 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1203,6 +1203,7 @@ 
void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_AutoDiffOriginalValueDecoration: case kIROp_UserDefinedBackwardDerivativeDecoration: case kIROp_IntermediateContextFieldDifferentialTypeDecoration: + case kIROp_CheckpointIntermediateDecoration: decor->removeAndDeallocate(); break; case kIROp_AutoDiffBuiltinDecoration: diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp index e2297bcb2c..a8b9b548e0 100644 --- a/source/slang/slang-ir-clone.cpp +++ b/source/slang/slang-ir-clone.cpp @@ -220,6 +220,7 @@ static void _cloneInstDecorationsAndChildren( auto oldType = oldParam->getFullType(); auto newType = (IRType*)findCloneForOperand(env, oldType); newParam->setFullType(newType); + newParam->sourceLoc = oldParam->sourceLoc; } } diff --git a/source/slang/slang-ir-eliminate-phis.cpp b/source/slang/slang-ir-eliminate-phis.cpp index b17fad6ec0..0db2fc765c 100644 --- a/source/slang/slang-ir-eliminate-phis.cpp +++ b/source/slang/slang-ir-eliminate-phis.cpp @@ -462,6 +462,7 @@ struct PhiEliminationContext // to the temporary that will replace it. 
// param->transferDecorationsTo(temp); + temp->sourceLoc = param->sourceLoc; } // The other main auxilliary sxtructure is used to track @@ -550,6 +551,7 @@ struct PhiEliminationContext auto user = use->getUser(); m_builder.setInsertBefore(user); auto newVal = m_builder.emitLoad(temp); + newVal->sourceLoc = param->sourceLoc; m_builder.replaceOperand(use, newVal); } @@ -938,6 +940,7 @@ struct PhiEliminationContext newOperands.getCount(), newOperands.getArrayView().getBuffer()); oldBranch->transferDecorationsTo(newBranch); + newBranch->sourceLoc = oldBranch->sourceLoc; // TODO: We could consider just modifying `branch` in-place by clearing // the relevant operands for the phi arguments and setting its operand diff --git a/source/slang/slang-ir-init-local-var.cpp b/source/slang/slang-ir-init-local-var.cpp index 34a0e5ff4c..fa556bc58e 100644 --- a/source/slang/slang-ir-init-local-var.cpp +++ b/source/slang/slang-ir-init-local-var.cpp @@ -47,6 +47,9 @@ void initializeLocalVariables(IRModule* module, IRGlobalValueWithCode* func) breakLabel:; if (initialized) continue; + + IRBuilderSourceLocRAII sourceLocationScope(&builder, inst->sourceLoc); + builder.setInsertAfter(inst); builder.emitStore( inst, diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index b526df3a92..301a9c789f 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -1056,6 +1056,9 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace) /// Hint that the result from a call to the decorated function should be recomputed in backward prop function. INST(PreferRecomputeDecoration, PreferRecomputeDecoration, 0, 0) + /// Hint that a struct is used for reverse mode checkpointing + INST(CheckpointIntermediateDecoration, CheckpointIntermediateDecoration, 1, 0) + INST_RANGE(CheckpointHintDecoration, PreferCheckpointDecoration, PreferRecomputeDecoration) /// Marks a function whose return value is never dynamic uniform. 
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 69f1299862..37f242e55a 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -947,6 +947,16 @@ struct IRPreferCheckpointDecoration : IRCheckpointHintDecoration IR_LEAF_ISA(PreferCheckpointDecoration) }; +struct IRCheckpointIntermediateDecoration : IRCheckpointHintDecoration +{ + enum + { + kOp = kIROp_CheckpointIntermediateDecoration + }; + IR_LEAF_ISA(CheckpointIntermediateDecoration) + + IRInst* getSourceFunction() { return getOperand(0); } +}; struct IRLoopCounterDecoration : IRDecoration { @@ -5152,6 +5162,11 @@ struct IRBuilder { addDecoration(inst, kIROp_MemoryQualifierSetDecoration, getIntValue(getIntType(), flags)); } + + void addCheckpointIntermediateDecoration(IRInst* inst, IRGlobalValueWithCode *func) + { + addDecoration(inst, kIROp_CheckpointIntermediateDecoration, func); + } }; // Helper to establish the source location that will be used diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp index 753c930a86..ef05511612 100644 --- a/source/slang/slang-ir-loop-unroll.cpp +++ b/source/slang/slang-ir-loop-unroll.cpp @@ -526,6 +526,7 @@ void eliminateContinueBlocks(IRModule* module, IRLoop* loopInst) // we will now introduce a breakable region for each iteration. 
IRBuilder builder(module); + IRBuilderSourceLocRAII sourceLocationScope(&builder, loopInst->sourceLoc); auto targetBlock = loopInst->getTargetBlock(); diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index e44c4079b4..506e6a3350 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -431,6 +431,7 @@ PhiInfo* addPhi( RefPtr phiInfo = new PhiInfo(); context->phiInfos.add(phi, phiInfo); + phi->sourceLoc = var->sourceLoc; phiInfo->phi = phi; phiInfo->var = var; diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 9305d17830..6c7691d134 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3512,6 +3512,7 @@ namespace Slang auto inst = createInstWithTrailingArgs( this, kIROp_MakeDifferentialPair, type, 2, args); addInst(inst); + inst->sourceLoc = primal->sourceLoc; return inst; } @@ -3524,6 +3525,7 @@ namespace Slang auto inst = createInstWithTrailingArgs( this, kIROp_MakeDifferentialPairUserCode, type, 2, args); addInst(inst); + inst->sourceLoc = primal->sourceLoc; return inst; } diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index c02a009570..b9a12f971e 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -339,6 +339,7 @@ void initCommandOptions(CommandOptions& options) { OptionKind::InputFilesRemain, "--", nullptr, "Treat the rest of the command line as input files."}, { OptionKind::ReportDownstreamTime, "-report-downstream-time", nullptr, "Reports the time spent in the downstream compiler." }, { OptionKind::ReportPerfBenchmark, "-report-perf-benchmark", nullptr, "Reports compiler performance benchmark results." }, + { OptionKind::ReportCheckpointIntermediates, "-report-checkpoint-intermediates", nullptr, "Reports information about checkpoint contexts used for reverse-mode automatic differentiation." }, { OptionKind::SkipSPIRVValidation, "-skip-spirv-validation", nullptr, "Skips spirv validation." 
}, { OptionKind::SourceEmbedStyle, "-source-embed-style", "-source-embed-style ", "If source embedding is enabled, defines the style used. When enabled (with any style other than `none`), " @@ -1703,6 +1704,7 @@ SlangResult OptionsParser::_parse( case OptionKind::DumpReproOnError: case OptionKind::ReportDownstreamTime: case OptionKind::ReportPerfBenchmark: + case OptionKind::ReportCheckpointIntermediates: case OptionKind::SkipSPIRVValidation: case OptionKind::DisableSpecialization: case OptionKind::DisableDynamicDispatch: diff --git a/tests/autodiff/reverse-checkpoint-1.slang b/tests/autodiff/reverse-checkpoint-1.slang index 5172970135..3d6e9e702f 100644 --- a/tests/autodiff/reverse-checkpoint-1.slang +++ b/tests/autodiff/reverse-checkpoint-1.slang @@ -2,6 +2,7 @@ //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj //TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none +//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -16,13 +17,16 @@ float g(float x) return log(x); } +//CHK: note: checkpointing context of 4 bytes associated with function: 'f' [BackwardDifferentiable] float f(int p, float x) { float y = 1.0; // Test that phi parameter can be restored. 
if (p == 0) + //CHK: note: 4 bytes (float) used to checkpoint the following item: y = g(x); + return y * y; } @@ -41,3 +45,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) __bwd_diff(f)(0, dpa, 1.0f); outputBuffer[0] = dpa.d; // Expect: 1 } + +//CHK-NOT: note \ No newline at end of file diff --git a/tests/autodiff/reverse-checkpoint-2.slang b/tests/autodiff/reverse-checkpoint-2.slang index 8a7262aa4d..1dd3f29638 100644 --- a/tests/autodiff/reverse-checkpoint-2.slang +++ b/tests/autodiff/reverse-checkpoint-2.slang @@ -41,3 +41,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) __bwd_diff(f)(0, dpa, 1.0f); outputBuffer[0] = dpa.d; // Expect: 1 } + +//CHK-NOT: note \ No newline at end of file diff --git a/tests/autodiff/reverse-continue-loop.slang b/tests/autodiff/reverse-continue-loop.slang index 0f95026734..0b6e56f783 100644 --- a/tests/autodiff/reverse-continue-loop.slang +++ b/tests/autodiff/reverse-continue-loop.slang @@ -1,6 +1,7 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj +//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -8,11 +9,14 @@ RWStructuredBuffer outputBuffer; typedef DifferentialPair dpfloat; typedef float.Differential dfloat; +//CHK: note: checkpointing context of 24 bytes associated with function: 'test_loop_with_continue' [BackwardDifferentiable] float test_loop_with_continue(float y) { + //CHK: note: 20 bytes (FixedArray ) used to checkpoint the following item: float t = y; + //CHK: note: 4 bytes (int32_t) used for a loop counter here: for (int i = 0; i < 3; i++) { if (t > 4.0) @@ -41,3 +45,5 @@ void computeMain(uint3 dispatchThreadID : 
SV_DispatchThreadID) outputBuffer[1] = dpa.d; // Expect: 0.0131072 } } + +//CHK-NOT: note \ No newline at end of file diff --git a/tests/autodiff/reverse-control-flow-1.slang b/tests/autodiff/reverse-control-flow-1.slang index 7d2f518be9..334de4137e 100644 --- a/tests/autodiff/reverse-control-flow-1.slang +++ b/tests/autodiff/reverse-control-flow-1.slang @@ -1,5 +1,6 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -40,3 +41,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) outputBuffer[1] = dpa.d; // Expect: 1.0 } } + +//CHK: (0): note: no checkpoint contexts to report \ No newline at end of file diff --git a/tests/autodiff/reverse-control-flow-2.slang b/tests/autodiff/reverse-control-flow-2.slang index cde707b4d3..c3790367cf 100644 --- a/tests/autodiff/reverse-control-flow-2.slang +++ b/tests/autodiff/reverse-control-flow-2.slang @@ -1,5 +1,6 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -73,3 +74,5 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) outputBuffer[1] = dpx.d; } } + +//CHK: (0): note: no checkpoint contexts to report \ No newline at end of file diff --git a/tests/autodiff/reverse-control-flow-3.slang b/tests/autodiff/reverse-control-flow-3.slang index 01b5332793..b4fa68e3a3 100644 --- a/tests/autodiff/reverse-control-flow-3.slang +++ 
b/tests/autodiff/reverse-control-flow-3.slang @@ -1,4 +1,5 @@ //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer @@ -75,7 +76,8 @@ void d_getParam(uint id, MaterialParam.Differential diff) outputBuffer[id] += diff.roughness; } - +//CHK-DAG: note: checkpointing context of 8 bytes associated with function: 'updatePathThroughput' +//CHK-DAG: note: 8 bytes (PathResult_0) used to checkpoint the following item: [BackwardDifferentiable] void updatePathThroughput(inout PathResult path, const float weight) { @@ -122,9 +124,13 @@ bool generateScatterRay(const BSDFSample bs, const MaterialParam bsdfParams, ino \param[in,out] path The path state. \return True if a ray was generated, false otherwise. */ + +//CHK-DAG: note: checkpointing context of 16 bytes associated with function: 'generateScatterRay' [BackwardDifferentiable] bool generateScatterRay(const BSDFSample bs, const MaterialParam bsdfParams, inout PathState path, inout PathResult pathRes) { + //CHK-DAG: note: 8 bytes (s_bwd_prop_updatePathThroughput_Intermediates_0) used to checkpoint the following item: + //CHK-DAG: note: 8 bytes (PathResult_0) used to checkpoint the following item: updatePathThroughput(pathRes, bs.val); return true; } @@ -215,5 +221,6 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) var dpx = diffPair(pathRes, pathResD); __bwd_diff(tracePath)(1, dpx); // Expect: 5.0 in outputBuffer[3] } - } + +//CHK-NOT: note \ No newline at end of file diff --git a/tests/autodiff/reverse-loop-checkpoint-test.slang b/tests/autodiff/reverse-loop-checkpoint-test.slang index fc206e1289..68ad823ac2 100644 --- a/tests/autodiff/reverse-loop-checkpoint-test.slang +++ b/tests/autodiff/reverse-loop-checkpoint-test.slang @@ -1,5 +1,6 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj 
-output-using-type //TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none +//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -44,13 +45,18 @@ float3 infinitesimal(float3 x) return x - detach(x); } +//CHK: note: checkpointing context of 20 bytes associated with function: 'computeLoop' [BackwardDifferentiable] [PreferRecompute] float3 computeLoop(float y) { + //CHK: note: 4 bytes (float) used to checkpoint the following item: float w = 0; + + //CHK: note: 12 bytes (Vector ) used to checkpoint the following item: float3 w3 = float3(0, 0, 0); + //CHK: note: 4 bytes (int32_t) used for a loop counter here: for (int i = 0; i < 8; i++) { float k = compute(i, y); @@ -93,3 +99,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) outputBuffer[2] = computeLoop(1.0).x; } + +//CHK-NOT: note \ No newline at end of file diff --git a/tests/autodiff/reverse-loop.slang b/tests/autodiff/reverse-loop.slang index a2c826be98..2ba8535bee 100644 --- a/tests/autodiff/reverse-loop.slang +++ b/tests/autodiff/reverse-loop.slang @@ -1,6 +1,7 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj +//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -8,11 +9,14 @@ RWStructuredBuffer outputBuffer; typedef DifferentialPair dpfloat; typedef float.Differential dfloat; +//CHK: note: checkpointing context of 24 bytes associated with function: 'test_simple_loop' [Differentiable] float test_simple_loop(float y) { + //CHK: 
note: 20 bytes (FixedArray ) used to checkpoint the following item: float t = y; + //CHK: note: 4 bytes (int32_t) used for a loop counter here: for (int i = 0; i < 3; i++) { t = t * t; @@ -38,3 +42,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) outputBuffer[1] = dpa.d; // Expect: 0.0131072 } } + +//CHK-NOT: note \ No newline at end of file diff --git a/tests/autodiff/reverse-nested-calls.slang b/tests/autodiff/reverse-nested-calls.slang index caf2df6f8c..3c1a52c21c 100644 --- a/tests/autodiff/reverse-nested-calls.slang +++ b/tests/autodiff/reverse-nested-calls.slang @@ -1,6 +1,7 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj +//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -15,9 +16,11 @@ float g(float y) return result * result; } +//CHK: note: checkpointing context of 4 bytes associated with function: 'f' [BackwardDifferentiable] float f(float x) { + //CHK: note: 4 bytes (float) used to checkpoint the following item: return 3.0f * g(2.0f * x); } @@ -29,3 +32,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) __bwd_diff(f)(dpa, 1.0f); outputBuffer[0] = dpa.d; // Expect: 96.0 } + +//CHK-NOT: note \ No newline at end of file