diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 7784100a6a..8dee7b0c5f 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -311,6 +311,8 @@ namespace Slang void visitAggTypeDecl(AggTypeDecl* aggTypeDecl); + SemanticsContext registerDifferentiableTypesForFunc(FunctionDeclBase* funcDecl); + }; template @@ -3660,9 +3662,12 @@ namespace Slang // the work of constructing our synthesized method. // + bool isInWrapperType = isWrapperTypeDecl(context->parentDecl); + // First, we check that the differentiabliity of the method matches the requirement, // and we don't attempt to synthesize a method if they don't match. - if (getShared()->getFuncDifferentiableLevel( + if (!isInWrapperType && + getShared()->getFuncDifferentiableLevel( as(lookupResult.item.declRef.getDecl())) < getShared()->getFuncDifferentiableLevel( as(requiredMemberDeclRef.getDecl()))) @@ -3689,7 +3694,7 @@ namespace Slang auto synBase = m_astBuilder->create(); synBase->name = requiredMemberDeclRef.getDecl()->getName(); - if (isWrapperTypeDecl(context->parentDecl)) + if (isInWrapperType) { auto aggTypeDecl = as(context->parentDecl); synBase->lookupResult2 = lookUpMember( @@ -3701,6 +3706,10 @@ namespace Slang LookupMask::Default, LookupOptions::IgnoreBaseInterfaces); addModifier(synFuncDecl, m_astBuilder->create()); + + synFuncDecl->parentDecl = aggTypeDecl; + SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); + bodyVisitor.registerDifferentiableTypesForFunc(synFuncDecl); } else { @@ -3714,7 +3723,7 @@ namespace Slang // if (synThis) { - if (isWrapperTypeDecl(context->parentDecl)) + if (isInWrapperType) { // If this is a wrapper type, then use the inner // object as the actual this parameter for the redirected @@ -3723,6 +3732,8 @@ namespace Slang innerExpr->scope = synThis->scope; innerExpr->name = getName("inner"); synBase->base = CheckExpr(innerExpr); + SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); + bodyVisitor.maybeRegisterDifferentiableType(m_astBuilder, synBase->base->type); } else { @@ -6066,7 +6077,7 @@ namespace Slang checkVisibility(decl); } - void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) + SemanticsContext SemanticsDeclBodyVisitor::registerDifferentiableTypesForFunc(FunctionDeclBase* decl) { auto newContext = withParentFunc(decl); if (newContext.getParentDifferentiableAttribute()) @@ -6086,7 +6097,12 @@ namespace Slang } m_parentDifferentiableAttr = oldAttr; } + return newContext; + } + void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) + { + auto newContext = registerDifferentiableTypesForFunc(decl); if (const auto body = decl->body) { checkStmt(decl->body, newContext); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 1d72fd2333..5ec6fa62a7 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -491,6 +491,11 @@ Result linkAndOptimizeIR( validateIRModuleIfEnabled(codeGenContext, irModule); + // If we have any witness tables that are marked as `KeepAlive`, + // but are not used for dynamic dispatch, unpin them so we don't + // do unnecessary work to lower them. + unpinWitnessTables(irModule); + simplifyIR(targetProgram, irModule, IRSimplificationOptions::getFast(), sink); if (!ArtifactDescUtil::isCpuLikeTarget(artifactDesc)) diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 4865aa0b5a..cd79f05a60 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -815,6 +815,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(AnyValueSizeDecoration, AnyValueSize, 1, 0) INST(SpecializeDecoration, SpecializeDecoration, 0, 0) INST(SequentialIDDecoration, SequentialIDDecoration, 1, 0) + INST(DynamicDispatchWitnessDecoration, DynamicDispatchWitnessDecoration, 0, 0) INST(StaticRequirementDecoration, StaticRequirementDecoration, 0, 0) INST(DispatchFuncDecoration, DispatchFuncDecoration, 1, 0) INST(TypeConstraintDecoration, TypeConstraintDecoration, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 6fc94c6576..6fbccab5c0 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -708,6 +708,11 @@ struct IRSequentialIDDecoration : IRDecoration IRIntegerValue getSequentialID() { return getSequentialIDOperand()->getValue(); } }; +struct IRDynamicDispatchWitnessDecoration : IRDecoration +{ + IR_LEAF_ISA(DynamicDispatchWitnessDecoration) +}; + struct IRAutoDiffOriginalValueDecoration : IRDecoration { enum @@ -4692,6 +4697,11 @@ struct IRBuilder addDecoration(inst, kIROp_SequentialIDDecoration, getIntValue(getUIntType(), id)); } + void addDynamicDispatchWitnessDecoration(IRInst* inst) + { + addDecoration(inst, kIROp_DynamicDispatchWitnessDecoration); + } + void addVulkanRayPayloadDecoration(IRInst* inst, int location) { addDecoration(inst, kIROp_VulkanRayPayloadDecoration, getIntValue(getIntType(), location)); diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index e81eddab70..18cb850c0e 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -447,6 +447,7 @@ static void cloneExtraDecorationsFromInst( case kIROp_PrimalSubstituteDecoration: case kIROp_IntrinsicOpDecoration: case kIROp_NonCopyableTypeDecoration: + case kIROp_DynamicDispatchWitnessDecoration: if (!clonedInst->findDecorationImpl(decoration->getOp())) { cloneInst(context, builder, decoration); diff --git a/source/slang/slang-ir-strip-witness-tables.cpp b/source/slang/slang-ir-strip-witness-tables.cpp index 4c8901c52d..b80bd7c232 100644 --- a/source/slang/slang-ir-strip-witness-tables.cpp +++ b/source/slang/slang-ir-strip-witness-tables.cpp @@ -33,4 +33,21 @@ void stripWitnessTables(IRModule* module) } } +void unpinWitnessTables(IRModule* module) +{ + for (auto inst : module->getGlobalInsts()) + { + auto witnessTable = as(inst); + if (!witnessTable) + continue; + + // If a witness table is not used for dynamic dispatch, unpin it. + if (!witnessTable->findDecoration()) + { + while (auto decor = witnessTable->findDecoration()) + decor->removeAndDeallocate(); + } + } +} + } diff --git a/source/slang/slang-ir-strip-witness-tables.h b/source/slang/slang-ir-strip-witness-tables.h index 43bd0127d0..4e31064189 100644 --- a/source/slang/slang-ir-strip-witness-tables.h +++ b/source/slang/slang-ir-strip-witness-tables.h @@ -7,4 +7,7 @@ struct IRModule; /// Strip the contents of all witness table instructions from the given IR `module` void stripWitnessTables(IRModule* module); -} \ No newline at end of file + + /// Remove [KeepAlive] decorations from witness tables. +void unpinWitnessTables(IRModule* module); +} diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index c5a4da1f60..566e5a878b 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -10872,7 +10872,7 @@ struct TypeConformanceIRGenContext auto witness = lowerSimpleVal(context, typeConformance->getSubtypeWitness()); builder->addKeepAliveDecoration(witness); builder->addHLSLExportDecoration(witness); - + builder->addDynamicDispatchWitnessDecoration(witness); if (conformanceIdOverride != -1) { builder->addSequentialIDDecoration(witness, conformanceIdOverride); diff --git a/tools/gfx-unit-test/link-time-type.cpp b/tools/gfx-unit-test/link-time-type.cpp index a522b69038..32a6b6775e 100644 --- a/tools/gfx-unit-test/link-time-type.cpp +++ b/tools/gfx-unit-test/link-time-type.cpp @@ -16,7 +16,12 @@ namespace gfx_test slang::ProgramLayout*& slangReflection) { const char* moduleInterfaceSrc = R"( - interface IFoo + interface IBase : IDifferentiable + { + [Differentiable] + float getBaseValue(); + } + interface IFoo : IBase { static const int offset; [mutating] void setValue(float v); @@ -29,6 +34,8 @@ namespace gfx_test static const int offset = -1; [mutating] void setValue(float v) { val = v; } float getValue() { return val + 1.0; } + [Differentiable] + float getBaseValue() { return val; } property float val2 { get { return val + 2.0; } set { val = newValue; } @@ -44,7 +51,7 @@ namespace gfx_test { Foo foo; foo.setValue(3.0); - buffer[0] = foo.getValue() + foo.val2 + Foo.offset; + buffer[0] = foo.getValue() + foo.val2 + Foo.offset + foo.getBaseValue(); } )"; const char* module1Src = R"( @@ -169,7 +176,7 @@ namespace gfx_test compareComputeResult( device, numbersBuffer, - Slang::makeArray(8.0)); + Slang::makeArray(11.0)); } SLANG_UNIT_TEST(linkTimeTypeD3D12)