diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index f4a34c5aab..a1fa5f21a8 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -1079,6 +1079,20 @@ IRInst* getActualInstToTranscribe(IRInst* inst) return inst; } +void handleNameHint(IRBuilder* builder, IRInst* primal, IRInst* diff) +{ + // Ignore types that already have a name hint. + if (as(diff) && diff->findDecoration()) + return; + + if (auto nameHint = primal->findDecoration()) + { + StringBuilder sb; + sb << "s_diff_" << nameHint->getName(); + builder->addNameHintDecoration(diff, sb.getUnownedSlice()); + } +} + IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst) { // If a differential instruction is already mapped for @@ -1099,7 +1113,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst instsInProgress.remove(origInst); - if (auto primalInst = pair.primal) + if (pair.primal) { mapPrimalInst(origInst, pair.primal); mapDifferentialInst(origInst, pair.differential); @@ -1124,12 +1138,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst break; default: // Generate name hint for the inst. - if (auto primalNameHint = primalInst->findDecoration()) - { - StringBuilder sb; - sb << "s_diff_" << primalNameHint->getName(); - builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); - } + handleNameHint(builder, pair.primal, pair.differential); // Automatically tag the primal and differential results // if they haven't already been handled by the diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 59653c4ae7..9b3e3a324a 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -346,7 +346,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( } if (auto originalNameHint = originalFunc->findDecoration()) { - auto primalName = String("s_bwd_primal_") + UnownedStringSlice(originalNameHint->getName()); + auto primalName = String("s_primal_ctx_") + UnownedStringSlice(originalNameHint->getName()); builder.addNameHintDecoration(primalFunc, builder.getStringValue(primalName.getUnownedSlice())); } diff --git a/tests/autodiff/reverse-checkpoint-1.slang b/tests/autodiff/reverse-checkpoint-1.slang index beb983b3b7..5172970135 100644 --- a/tests/autodiff/reverse-checkpoint-1.slang +++ b/tests/autodiff/reverse-checkpoint-1.slang @@ -29,7 +29,7 @@ float f(int p, float x) // Check that there are no calls to primal_g in bwd_f. // CHECK: void s_bwd_f_{{[0-9]+}} -// CHECK-NOT: {{[_a-zA-Z0-9]+}} = s_bwd_primal_g_{{[0-9]+}} +// CHECK-NOT: {{[_a-zA-Z0-9]+}} = s_primal_ctx_g_{{[0-9]+}} // CHECK: return diff --git a/tests/autodiff/reverse-checkpoint-2.slang b/tests/autodiff/reverse-checkpoint-2.slang index 68ff62176f..8a7262aa4d 100644 --- a/tests/autodiff/reverse-checkpoint-2.slang +++ b/tests/autodiff/reverse-checkpoint-2.slang @@ -29,7 +29,7 @@ float f(int p, float x) // Check that there are no calls to primal_g in bwd_f. // CHECK: void s_bwd_prop_f_{{[0-9]+}} -// CHECK: {{[_a-zA-Z0-9]+}} = s_bwd_primal_g_{{[0-9]+}} +// CHECK: {{[_a-zA-Z0-9]+}} = s_primal_ctx_g_{{[0-9]+}} // CHECK: return