Support IDifferentiablePtrType
#5031
@@ -1606,6 +1606,7 @@ namespace Slang
DefaultInitializableConstructor, ///< The `IDefaultInitializable.__init()` method

DifferentialType, ///< The `IDifferentiable.Differential` associated type requirement
DifferentialPtrType, ///< The `IDifferentiable.DifferentialPtr` associated type requirement
Can we remove `DMulFunc` down below?
I'm doing the `dmul` removal in a separate patch.
source/slang/slang-check-expr.cpp
Outdated
// Check if the provided type inherits from IDifferentiable.
// If not, return the original type.
if (conformanceWitness)
if (auto conformanceValWitness = as<Witness>(
We should make the `isTypeDifferentiable` function return the witness instead, so we don't have to call `isSubtype` directly.
This should help improve readability.
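A minimal sketch of that suggestion, assuming simplified stand-in types: `Witness`, `TypeInfo`, and `getDifferentiableWitnessOrNull` are hypothetical names for illustration, not the actual Slang frontend API. The point is only that a single query can answer "is it differentiable?" and hand back the witness at the same time.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins; the real Slang Type and SubtypeWitness are richer.
struct Witness
{
    std::string interfaceName;
};

struct TypeInfo
{
    std::string name;
    // Conformances registered for this type, keyed by interface name.
    std::unordered_map<std::string, Witness> conformances;
};

// Returns the conformance witness if `type` conforms to IDifferentiable,
// or nullptr otherwise, so callers need no separate subtype query.
const Witness* getDifferentiableWitnessOrNull(const TypeInfo& type)
{
    auto it = type.conformances.find("IDifferentiable");
    return it == type.conformances.end() ? nullptr : &it->second;
}

// The boolean query becomes a thin wrapper over the witness lookup.
bool isTypeDifferentiable(const TypeInfo& type)
{
    return getDifferentiableWitnessOrNull(type) != nullptr;
}
```

With this shape, a call site can write `if (auto witness = getDifferentiableWitnessOrNull(type))` and use the witness directly in the body.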
@@ -16,6 +16,86 @@

namespace Slang
{

IRInst* emitMakeDifferentialPair(IRBuilder* builder, IRType* pairType, IRInst* primalVal, IRInst* diffVal)
Can these functions be inside `IRBuilder`?
I didn't want to put this one in `IRBuilder` because it has fairly non-trivial logic. `IRBuilder` has its own `emitMakeDifferentialPair` and `emitMakeDifferentialPtrPair`, which emit the specific op.
This helper function decides which one to emit based on the conformance of the primal type. We need the `DiffTypeConformanceContext` to be around, so we can't easily move this into `IRBuilder`.
This function isn't using DiffTypeConformanceContext, and feels like it fits well in IRBuilder.
We should move in the direction where `IRBuilder` can do more folding and instruction selection instead of emitting exactly the opcode being requested.
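To make the direction discussed in this thread concrete, here is a rough sketch of a builder that selects the opcode from the pair type instead of leaving that choice to every caller. `Op`, `Conformance`, `PairType`, `Inst`, and `Builder` are all hypothetical simplifications, and the assumption that the pair type alone carries enough information to choose is mine, not something confirmed by the patch.

```cpp
#include <cassert>

// Hypothetical model of the two pair-construction opcodes.
enum class Op { MakeDifferentialPair, MakeDifferentialPtrPair };

// Hypothetical model of which interface the primal type conforms to.
enum class Conformance { Value, Ptr };

// Assumption: the pair type records the conformance of its primal type.
struct PairType
{
    Conformance primalConformance;
};

struct Inst
{
    Op op;
    int primal;
    int diff;
};

struct Builder
{
    // Low-level emitters: each emits exactly the requested opcode.
    Inst emitValuePair(int primal, int diff)
    {
        return Inst{Op::MakeDifferentialPair, primal, diff};
    }
    Inst emitPtrPair(int primal, int diff)
    {
        return Inst{Op::MakeDifferentialPtrPair, primal, diff};
    }

    // Selecting entry point: picks the opcode from the pair type,
    // so call sites do not repeat the value/ptr branch.
    Inst emitMakePair(const PairType& type, int primal, int diff)
    {
        if (type.primalConformance == Conformance::Ptr)
            return emitPtrPair(primal, diff);
        return emitValuePair(primal, diff);
    }
};
```

The low-level emitters stay available for passes that already know which op they need; only the selecting entry point hides the branch.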
@@ -445,8 +457,17 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder*
auto interfaceType = as<IRInterfaceType>(unwrapAttributedType(origType->getOperand(0)->getDataType()));
if (!interfaceType)
return nullptr;
List<IRInterfaceRequirementEntry*> lookupKeyPath = differentiableTypeConformanceContext.findDifferentiableInterfaceLookupPath(
List<IRInterfaceRequirementEntry*> lookupPathValueType = differentiableTypeConformanceContext.findInterfaceLookupPath(
This seems inefficient: it attempts the lookup twice and then drops one of the results. Can we at least skip the lookup for ptr conformance if value conformance is found?
I wonder why there is a need for `findInterfaceLookupPath` in the first place. For any `interfaceType`, if we want to know whether `interfaceType : IDifferentiable`, there should be a `SubtypeWitness` that we can construct in the frontend and register into the dictionary, so we don't have to rediscover it in the backend.
This might be from before the inheritance flattening patch. You can have `interfaceType : middleInterface : IDifferentiable`, so this was for emitting lookup insts into the current block to recurse into the sub-type witness when the witness is a run-time variable.
Post-flattening, is there no longer a need for this?
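The short-circuit asked for in this thread could look roughly like the sketch below. `LookupPath`, `Finder`, `Found`, and `findFirstLookupPath` are hypothetical names; the real code works with `List<IRInterfaceRequirementEntry*>` and `findInterfaceLookupPath`, whose exact signature is not shown here. The ptr lookup runs only when the value lookup finds nothing.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical: a lookup path modelled as requirement-entry indices.
using LookupPath = std::vector<int>;

enum class Kind { Value, Ptr };

struct Found
{
    Kind kind;
    LookupPath path;
};

// A finder returns an empty optional when no path to the interface exists.
using Finder = std::function<std::optional<LookupPath>()>;

// Tries value conformance first; the ptr finder is invoked only if the
// value finder fails, so at most one successful lookup is ever computed.
std::optional<Found> findFirstLookupPath(const Finder& findValue, const Finder& findPtr)
{
    if (auto path = findValue())
        return Found{Kind::Value, *path};
    if (auto path = findPtr())
        return Found{Kind::Ptr, *path};
    return std::nullopt;
}
```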
@@ -1075,4 +1121,64 @@ InstPair AutoDiffTranscriberBase::transcribeInst(IRBuilder* builder, IRInst* ori
return result;
}

void AutoDiffTranscriberBase::markDiffTypeInst(IRBuilder* builder, IRInst* diffInst, IRType* primalType)
Is this function just saying that anything other than a DifferentiablePtr type should be marked as differential?
It marks insts of `T : IDifferentiable` as differential (to move them into the derivative function and transpose them) and insts of `T : IDifferentiablePtrType` as primal (to move them to the primal blocks and not consider them during transposition).
Line 1147 is what I don't understand.
If `diffInst` isn't an `IRType`, then it can only be one of lookupWitness, specialization, or extractExistentialType. How is this related to the subject of this function?
Ah, I think the misunderstanding may be due to what `diffInst` is supposed to be.
Here `diffInst` is any inst (not just a type) that is produced during forward-mode as the derivative counterpart of the primal inst. It is not the type of the inst but the inst itself (which can be an `IRType` in cases where we construct types within the method body).
This method is responsible for deciding whether this inst requires further processing, by marking it appropriately.
Perhaps `markDiffInst` is a better name here.
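As a rough illustration of the marking rule described in this thread, the sketch below models the decision with hypothetical enums; `Conformance`, `Mark`, and this `Inst` are stand-ins, not the Slang IR decorations. Leaving insts of non-differentiable types unmarked is an assumption, since the thread does not say what happens in that case.

```cpp
#include <cassert>

// Hypothetical: how the primal type of an inst conforms, if at all.
enum class Conformance { None, Value, Ptr };

// Hypothetical: the mark that later passes inspect.
enum class Mark { Unmarked, Differential, Primal };

struct Inst
{
    Mark mark = Mark::Unmarked;
};

// Marks the derivative counterpart `diffInst` of a primal inst, based on
// the conformance of the primal inst's type.
void markDiffInst(Inst& diffInst, Conformance primalTypeConformance)
{
    switch (primalTypeConformance)
    {
    case Conformance::Value:
        // T : IDifferentiable -> moved into the derivative function and transposed.
        diffInst.mark = Mark::Differential;
        break;
    case Conformance::Ptr:
        // T : IDifferentiablePtrType -> kept in primal blocks, never transposed.
        diffInst.mark = Mark::Primal;
        break;
    case Conformance::None:
        // Assumption: insts of non-differentiable types are left untouched.
        break;
    }
}
```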
source/slang/slang-ir-autodiff.cpp
Outdated
@@ -514,7 +665,8 @@ IRInst* DifferentiableTypeConformanceContext::getDifferentialTypeFromDiffPairTyp

IRInst* DifferentiableTypeConformanceContext::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
{
return _getDiffTypeFromPairType(sharedContext, builder, type);
return this->differentiateType(builder, type->getValueType());
//return _getDiffTypeFromPairType(sharedContext, builder, type);
Delete the commented-out code.
source/slang/slang-ir-autodiff.cpp
Outdated
@@ -722,7 +895,8 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
{
if (auto pairType = as<IRDifferentialPairTypeBase>(globalInst))
{
differentiableWitnessDictionary.addIfNotExists(pairType->getValueType(), pairType->getWitness());
addTypeToDictionary(pairType->getValueType(), pairType->getWitness());
//differentiableWitnessDictionary.addIfNotExists(pairType->getValueType(), pairType->getWitness());
Commented-out code.
I think a lot of duplication can be avoided if we consolidate the two dictionaries in type conformance context.
{
return differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType);
if (kind == DiffConformanceKind::Any)
Should `differentiableTypeConformanceContext.tryGetDifferentiableWitness` handle the `Any` case as well?
source/slang/slang-ir-autodiff.cpp
Outdated
{
IRInst* foundResult = nullptr;
differentiableWitnessDictionary.tryGetValue(type, foundResult);
I still wonder why we need two separate dictionaries. A type can conform to either the value-type witness or the ref-type witness, but not both. Why can't we consolidate the dictionaries so all lookups can be done in a single pass?
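One possible shape for a consolidated dictionary, sketched under the assumption stated above that a type conforms to at most one of the two interfaces. `WitnessDictionary`, `WitnessEntry`, and the integer ids are hypothetical stand-ins for the real `IRInst*` keys and values; each entry carries its conformance kind, so a single lookup answers both "is there a witness?" and "which kind?".

```cpp
#include <cassert>
#include <unordered_map>

enum class ConformanceKind { Value, Ptr };

// Hypothetical entry: the witness plus the kind of conformance it proves.
struct WitnessEntry
{
    ConformanceKind kind;
    int witnessId; // stand-in for an IRInst* witness
};

class WitnessDictionary
{
public:
    // Registers a witness for a type; keeps the first entry if one exists.
    // Returns true when the entry was newly added.
    bool addIfNotExists(int typeId, WitnessEntry entry)
    {
        return entries.emplace(typeId, entry).second;
    }

    // Single lookup covering both value and ptr conformance.
    const WitnessEntry* tryGet(int typeId) const
    {
        auto it = entries.find(typeId);
        return it == entries.end() ? nullptr : &it->second;
    }

private:
    std::unordered_map<int, WitnessEntry> entries; // keyed by type id
};
```

A caller that only wants one kind can filter on `entry->kind` after the lookup, which would also give a natural home for an `Any` query.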
{
if (isNoDiffType((IRType*)primalType))
return nullptr;

IRInst* witness = lookUpConformanceForType((IRType*)primalType);
trailing space.
This patch adds initial support for the `IDifferentiablePtrType` interface, which implements part of the new diff type system proposal.

Concretely, we add a new interface, `IDifferentiablePtrType`, and a new pair type, `DifferentialPtrPair<T : IDifferentiablePtrType>`, that can be used to properly represent differentiable reference, pointer, buffer, and resource types that need to be transformed into pairs during auto-diff but should not be used with the transposition logic.

Our forward-mode pass now treats types conforming to `IDifferentiable` as differentiable 'value' types and types conforming to `IDifferentiablePtrType` as differentiable 'ptr' types. Upon differentiation, the differential part of each differentiable-ptr-typed inst is marked as 'primal' to instruct the later passes to move it to the primal part and not transpose it during the reverse-mode steps.

We also alter the primal context function's signature rules to translate `T : IDifferentiablePtrType` to the appropriate `DifferentialPtrPair<T>`. This is because we treat the full pair as a primal value that may be modified in the function, so we need to run the context function to record it (if necessary).

Effectively, the semantics of our primal context functions have changed from "run primal computation and record intermediate values" to "run primal computation for differentiable value types and ptr-pair computations for differentiable ptr types, and record necessary intermediate values".

Fixes: #4998