Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Report AD checkpoint contexts #5058

Merged
merged 35 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
61c34d0
Transferring source locations when creating phi instructions
venkataram-nv Sep 6, 2024
95532d2
Tracking for simple variables
venkataram-nv Sep 6, 2024
bae1c71
Deriving source locations for loop counters
venkataram-nv Sep 6, 2024
490933b
Printing checkpoint structure breakdown
venkataram-nv Sep 9, 2024
95eb9df
More readable output format
venkataram-nv Sep 9, 2024
401185f
Special behavior for loop counters
venkataram-nv Sep 9, 2024
d7577d9
Writing report to file
venkataram-nv Sep 9, 2024
25c3fd4
Add slangc option to enable checkpoint reports
venkataram-nv Sep 9, 2024
e1ae8f4
Display types of checkpointed fields
venkataram-nv Sep 9, 2024
1d96da7
Message in case there are no checkpointing contexts
venkataram-nv Sep 9, 2024
d7ebae6
Catch source locations for function calls
venkataram-nv Sep 10, 2024
6060bdb
Source cleanup
venkataram-nv Sep 10, 2024
0cf45fe
Fix compilation warnings
venkataram-nv Sep 11, 2024
d96bd61
Remove stray dump()
venkataram-nv Sep 16, 2024
7f44e44
Merge branch 'master' into report-ad-checkpoint-info
venkataram-nv Sep 16, 2024
9e1439f
Provide the report through diagnostic notes
venkataram-nv Sep 17, 2024
b402855
Add missing path for sourceLoc during unzip pass
venkataram-nv Sep 17, 2024
f1c6200
Add tests for reporting intermediates
venkataram-nv Sep 17, 2024
5fadff3
Include more transfer cases for source locations
venkataram-nv Sep 17, 2024
0c17bcf
Fix ordering in address elimination
venkataram-nv Sep 18, 2024
62f2115
Fill in more holes with source location transfer
venkataram-nv Sep 18, 2024
9b222d1
Merge remote-tracking branch 'original/master' into report-ad-checkpo…
venkataram-nv Sep 18, 2024
52f1310
Remove debugging line
venkataram-nv Sep 18, 2024
8779f24
Reverting changes to diagnostic sink
venkataram-nv Sep 18, 2024
c9f1e5b
Simplify address elimination using source location RAII contexts
venkataram-nv Sep 18, 2024
5efd9c8
Eliminating manual source loc transfers in forward transcription
venkataram-nv Sep 18, 2024
6f82f6e
Fix local var adaptation to use RAII location setter
venkataram-nv Sep 18, 2024
83c68c1
Simplify primal hoisting logic for source location transfer
venkataram-nv Sep 18, 2024
a85fba7
Simplify unzipping with RAII location scopes
venkataram-nv Sep 18, 2024
e876019
Simplify transpose logic
venkataram-nv Sep 18, 2024
e4ab417
Cleaning up for rev.cpp
venkataram-nv Sep 18, 2024
bc1587e
Reverting spacing changes
venkataram-nv Sep 18, 2024
ca83b4b
Fix mistake with source loc RAII instantiation
venkataram-nv Sep 18, 2024
67eb327
Merge remote-tracking branch 'original/master' into report-ad-checkpo…
venkataram-nv Sep 18, 2024
1960e23
Fix formatting issues
venkataram-nv Sep 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/slang.h
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,7 @@ extern "C"
EmitIr, // bool
ReportDownstreamTime, // bool
ReportPerfBenchmark, // bool
ReportCheckpointIntermediates, // bool
SkipSPIRVValidation, // bool
SourceEmbedStyle,
SourceEmbedName,
Expand Down
1 change: 1 addition & 0 deletions source/slang-record-replay/util/emum-to-string.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ namespace SlangRecord
CASE(EmitIr);
CASE(ReportDownstreamTime);
CASE(ReportPerfBenchmark);
CASE(ReportCheckpointIntermediates);
CASE(SkipSPIRVValidation);
CASE(SourceEmbedStyle);
CASE(SourceEmbedName);
Expand Down
6 changes: 5 additions & 1 deletion source/slang/slang-compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2451,12 +2451,16 @@ namespace Slang
return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIr);
}

bool CodeGenContext::shouldReportCheckpointIntermediates()
{
return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::ReportCheckpointIntermediates);
}

bool CodeGenContext::shouldDumpIntermediates()
{
return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIntermediates);
}


bool CodeGenContext::shouldTrackLiveness()
{
return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::TrackLiveness);
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -2728,6 +2728,7 @@ namespace Slang

bool shouldValidateIR();
bool shouldDumpIR();
bool shouldReportCheckpointIntermediates();

bool shouldTrackLiveness();

Expand Down
6 changes: 6 additions & 0 deletions source/slang/slang-diagnostic-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,12 @@ DIAGNOSTIC(58002, Error, unhandledGLSLSSBOType, "Unhandled GLSL Shader Storage B

DIAGNOSTIC(58003, Error, inconsistentPointerAddressSpace, "'$0': use of pointer with inconsistent address space.")

// Autodiff checkpoint reporting
DIAGNOSTIC(-1, Note, reportCheckpointIntermediates, "checkpointing context of $1 bytes associated with function: '$0'")
DIAGNOSTIC(-1, Note, reportCheckpointVariable, "$0 bytes ($1) used to checkpoint the following item:")
DIAGNOSTIC(-1, Note, reportCheckpointCounter, "$0 bytes ($1) used for a loop counter here:")
DIAGNOSTIC(-1, Note, reportCheckpointNone, "no checkpoint contexts to report")

//
// 8xxxx - Issues specific to a particular library/technology/platform/etc.
//
Expand Down
67 changes: 67 additions & 0 deletions source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "slang-ir-wgsl-legalize.h"
#include "slang-ir-insts.h"
#include "slang-ir-inline.h"
#include "slang-ir-layout.h"
#include "slang-ir-legalize-array-return-type.h"
#include "slang-ir-legalize-mesh-outputs.h"
#include "slang-ir-legalize-varying-params.h"
Expand Down Expand Up @@ -214,6 +215,68 @@ static void dumpIRIfEnabled(
}
}

static void reportCheckpointIntermediates(CodeGenContext* codeGenContext, DiagnosticSink* sink, IRModule* irModule)
{
// Report checkpointing information
CompilerOptionSet& optionSet = codeGenContext->getTargetProgram()->getOptionSet();
SourceManager* sourceManager = sink->getSourceManager();

SourceWriter typeWriter(sourceManager, LineDirectiveMode::None, nullptr);

CLikeSourceEmitter::Desc description;
description.codeGenContext = codeGenContext;
description.sourceWriter = &typeWriter;

CPPSourceEmitter emitter(description);

int nonEmptyStructs = 0;
for (auto inst : irModule->getGlobalInsts())
{
IRStructType *structType = as<IRStructType>(inst);
if (!structType)
continue;

auto checkpointDecoration = structType->findDecoration<IRCheckpointIntermediateDecoration>();
if (!checkpointDecoration)
continue;

IRSizeAndAlignment structSize;
getNaturalSizeAndAlignment(optionSet, structType, &structSize);

// Reporting happens before empty structs are optimized out
// and we still want to keep the checkpointing decorations,
// so we end up needing to check for non-zero-ness
if (structSize.size == 0)
continue;

auto func = checkpointDecoration->getSourceFunction();
sink->diagnose(structType, Diagnostics::reportCheckpointIntermediates, func, structSize.size);
nonEmptyStructs++;

for (auto field : structType->getFields())
{
IRType *fieldType = field->getFieldType();
IRSizeAndAlignment fieldSize;
getNaturalSizeAndAlignment(optionSet, fieldType, &fieldSize);
if (fieldSize.size == 0)
continue;

typeWriter.clearContent();
emitter.emitType(fieldType);

sink->diagnose(field->sourceLoc,
field->findDecoration<IRLoopCounterDecoration>()
? Diagnostics::reportCheckpointCounter
: Diagnostics::reportCheckpointVariable,
fieldSize.size,
typeWriter.getContent());
}
}

if (nonEmptyStructs == 0)
sink->diagnose(SourceLoc(), Diagnostics::reportCheckpointNone);
}

struct LinkingAndOptimizationOptions
{
bool shouldLegalizeExistentialAndResourceTypes = true;
Expand Down Expand Up @@ -767,6 +830,10 @@ Result linkAndOptimizeIR(
break;
}

// Report checkpointing information
if (codeGenContext->shouldReportCheckpointIntermediates())
reportCheckpointIntermediates(codeGenContext, sink, irModule);

if (requiredLoweringPassSet.autodiff)
finalizeAutoDiffPass(targetProgram, irModule);

Expand Down
18 changes: 9 additions & 9 deletions source/slang/slang-ir-addr-inst-elimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,30 +69,28 @@ struct AddressInstEliminationContext
}
}

void transformLoadAddr(IRUse* use)
void transformLoadAddr(IRBuilder& builder, IRUse* use)
{
auto addr = use->get();
auto load = as<IRLoad>(use->getUser());

IRBuilder builder(module);
builder.setInsertBefore(use->getUser());
auto value = getValue(builder, addr);
load->replaceUsesWith(value);
load->removeAndDeallocate();
}

void transformStoreAddr(IRUse* use)
void transformStoreAddr(IRBuilder& builder, IRUse* use)
{
auto addr = use->get();
auto store = as<IRStore>(use->getUser());

IRBuilder builder(module);
builder.setInsertBefore(use->getUser());
storeValue(builder, addr, store->getVal());
store->removeAndDeallocate();
}

void transformCallAddr(IRUse* use)
void transformCallAddr(IRBuilder& builder, IRUse* use)
{
auto addr = use->get();
auto call = as<IRCall>(use->getUser());
Expand All @@ -103,7 +101,6 @@ struct AddressInstEliminationContext
return;
}

IRBuilder builder(module);
builder.setInsertBefore(call);
auto tempVar = builder.emitVar(cast<IRPtrTypeBase>(addr->getFullType())->getValueType());

Expand Down Expand Up @@ -155,17 +152,20 @@ struct AddressInstEliminationContext
use = nextUse;
continue;
}

IRBuilder transformBuilder(module);
IRBuilderSourceLocRAII sourceLocationScope(&transformBuilder, use->getUser()->sourceLoc);

switch (use->getUser()->getOp())
{
case kIROp_Load:
transformLoadAddr(use);
transformLoadAddr(transformBuilder, use);
break;
case kIROp_Store:
transformStoreAddr(use);
transformStoreAddr(transformBuilder, use);
break;
case kIROp_Call:
transformCallAddr(use);
transformCallAddr(transformBuilder, use);
break;
case kIROp_GetElementPtr:
case kIROp_FieldAddress:
Expand Down
42 changes: 30 additions & 12 deletions source/slang/slang-ir-autodiff-primal-hoist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
#include "slang-ir-autodiff-region.h"
#include "slang-ir-simplify-cfg.h"
#include "slang-ir-util.h"
#include "../core/slang-func-ptr.h"
#include "slang-ir-insts.h"
#include "slang-ir.h"
#include "../core/slang-func-ptr.h"

namespace Slang
{
Expand Down Expand Up @@ -1092,7 +1093,8 @@ IRType* getTypeForLocalStorage(
IRVar* emitIndexedLocalVar(
IRBlock* varBlock,
IRType* baseType,
const List<IndexTrackingInfo>& defBlockIndices)
const List<IndexTrackingInfo>& defBlockIndices,
SourceLoc location)
{
// Cannot store pointers. Case should have been handled by now.
SLANG_RELEASE_ASSERT(!as<IRPtrTypeBase>(baseType));
Expand All @@ -1101,6 +1103,8 @@ IRVar* emitIndexedLocalVar(
SLANG_RELEASE_ASSERT(!as<IRTypeType>(baseType));

IRBuilder varBuilder(varBlock->getModule());
IRBuilderSourceLocRAII sourceLocationScope(&varBuilder, location);

varBuilder.setInsertBefore(varBlock->getFirstOrdinaryInst());

IRType* varType = getTypeForLocalStorage(&varBuilder, baseType, defBlockIndices);
Expand Down Expand Up @@ -1179,9 +1183,14 @@ IRVar* storeIndexedValue(
IRInst* instToStore,
const List<IndexTrackingInfo>& defBlockIndices)
{
IRVar* localVar = emitIndexedLocalVar(defaultVarBlock, instToStore->getDataType(), defBlockIndices);
IRVar* localVar = emitIndexedLocalVar(defaultVarBlock,
instToStore->getDataType(),
defBlockIndices,
instToStore->sourceLoc);

IRInst* addr = emitIndexedStoreAddressForVar(builder, localVar, defBlockIndices);
IRInst* addr = emitIndexedStoreAddressForVar(builder,
localVar,
defBlockIndices);

builder->emitStore(addr, instToStore);

Expand Down Expand Up @@ -1574,12 +1583,16 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
// region, that means there's no need to allocate a fully indexed var.
//
defBlockIndices = maybeTrimIndices(defBlockIndices, indexedBlockInfo, outOfScopeUses);

IRVar* localVar = storeIndexedValue(
&builder,
varBlock,
builder.emitLoad(varToStore),
defBlockIndices);

IRVar* localVar = nullptr;
{
IRBuilderSourceLocRAII sourceLocationScope(&builder, varToStore->sourceLoc);
localVar = storeIndexedValue(
&builder,
varBlock,
builder.emitLoad(varToStore),
defBlockIndices);
}

for (auto use : outOfScopeUses)
{
Expand Down Expand Up @@ -1626,6 +1639,8 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
}
else
{
IRBuilderSourceLocRAII sourceLocationScope(&builder, instToStore->sourceLoc);

// Handle the special case of loop counters.
// The only case where there will be a reference of primal loop counter from rev blocks
// is the start of a loop in the reverse code. Since loop counters are not considered a
Expand All @@ -1643,6 +1658,8 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(

setInsertAfterOrdinaryInst(&builder, instToStore);
auto localVar = storeIndexedValue(&builder, varBlock, instToStore, defBlockIndices);
if (isLoopCounter)
builder.addLoopCounterDecoration(localVar);

for (auto use : outOfScopeUses)
{
Expand Down Expand Up @@ -1728,6 +1745,8 @@ static IRBlock* getUpdateBlock(IRLoop* loop)
void lowerIndexedRegion(IRLoop*& primalLoop, IRLoop*& diffLoop, IRInst*& primalCountParam, IRInst*& diffCountParam)
{
IRBuilder builder(primalLoop);
IRBuilderSourceLocRAII sourceLocationScope(&builder, primalLoop->sourceLoc);

primalCountParam = nullptr;

// Grab first primal block.
Expand Down Expand Up @@ -1899,8 +1918,7 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func)
// Legalize the primal inst accesses by introducing local variables / arrays and emitting
// necessary load/store logic.
//
primalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo);
return primalsInfo;
return ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo);
}

void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode* func)
Expand Down
Loading
Loading