From 714fdaaaf751656ac2b9a48ca95530343e82d215 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Tue, 23 Jul 2024 11:28:41 -0400 Subject: [PATCH 1/2] Fix auto-diff synthesized method naming conventions --- .../slang-ir-autodiff-transcriber-base.cpp | 21 +++++++++++++------ source/slang/slang-ir-autodiff-unzip.cpp | 2 +- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index f4a34c5aab..0bbe7f63e1 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -1079,6 +1079,20 @@ IRInst* getActualInstToTranscribe(IRInst* inst) return inst; } +void handleNameHint(IRBuilder* builder, IRInst* primal, IRInst* diff) +{ + // Ignore types that already have a name hint. + if (as(diff) && diff->findDecoration()) + return; + + if (auto nameHint = primal->findDecoration()) + { + StringBuilder sb; + sb << "s_diff_" << nameHint->getName(); + builder->addNameHintDecoration(diff, sb.getUnownedSlice()); + } +} + IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst) { // If a differential instruction is already mapped for @@ -1124,12 +1138,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst break; default: // Generate name hint for the inst. - if (auto primalNameHint = primalInst->findDecoration()) - { - StringBuilder sb; - sb << "s_diff_" << primalNameHint->getName(); - builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); - } + handleNameHint(builder, pair.primal, pair.differential); // Automatically tag the primal and differential results // if they haven't already been handled by the diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 59653c4ae7..9b3e3a324a 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -346,7 +346,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( } if (auto originalNameHint = originalFunc->findDecoration()) { - auto primalName = String("s_bwd_primal_") + UnownedStringSlice(originalNameHint->getName()); + auto primalName = String("s_primal_ctx_") + UnownedStringSlice(originalNameHint->getName()); builder.addNameHintDecoration(primalFunc, builder.getStringValue(primalName.getUnownedSlice())); } From 41a272d7bb1942fd724f32052747bd183c2ffcb6 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Tue, 6 Aug 2024 17:45:29 -0400 Subject: [PATCH 2/2] Update tests; remove unused var --- source/slang/slang-ir-autodiff-transcriber-base.cpp | 2 +- tests/autodiff/reverse-checkpoint-1.slang | 2 +- tests/autodiff/reverse-checkpoint-2.slang | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 0bbe7f63e1..a1fa5f21a8 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -1113,7 +1113,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst instsInProgress.remove(origInst); - if (auto primalInst = pair.primal) + if (pair.primal) { mapPrimalInst(origInst, pair.primal); mapDifferentialInst(origInst, pair.differential); diff --git a/tests/autodiff/reverse-checkpoint-1.slang b/tests/autodiff/reverse-checkpoint-1.slang index beb983b3b7..5172970135 100644 --- a/tests/autodiff/reverse-checkpoint-1.slang +++ b/tests/autodiff/reverse-checkpoint-1.slang @@ -29,7 +29,7 @@ float f(int p, float x) // Check that there are no calls to primal_g in bwd_f. // CHECK: void s_bwd_f_{{[0-9]+}} -// CHECK-NOT: {{[_a-zA-Z0-9]+}} = s_bwd_primal_g_{{[0-9]+}} +// CHECK-NOT: {{[_a-zA-Z0-9]+}} = s_primal_ctx_g_{{[0-9]+}} // CHECK: return diff --git a/tests/autodiff/reverse-checkpoint-2.slang b/tests/autodiff/reverse-checkpoint-2.slang index 68ff62176f..8a7262aa4d 100644 --- a/tests/autodiff/reverse-checkpoint-2.slang +++ b/tests/autodiff/reverse-checkpoint-2.slang @@ -29,7 +29,7 @@ float f(int p, float x) // Check that there are no calls to primal_g in bwd_f. // CHECK: void s_bwd_prop_f_{{[0-9]+}} -// CHECK: {{[_a-zA-Z0-9]+}} = s_bwd_primal_g_{{[0-9]+}} +// CHECK: {{[_a-zA-Z0-9]+}} = s_primal_ctx_g_{{[0-9]+}} // CHECK: return