Support IDifferentiablePtrType #5031

Merged on Sep 19, 2024 (19 commits)

Collaborator

@saipraveenb25 saipraveenb25 commented Sep 6, 2024

This patch adds initial support for the IDifferentiablePtrType interface, which implements part of the new diff type system proposal.

Concretely, we add a new interface, IDifferentiablePtrType, and a new pair type, DifferentialPtrPair<T : IDifferentiablePtrType>, that can be used to properly represent differentiable reference, pointer, buffer, and resource types. These types need to be transformed into pairs during auto-diff, but should not be used with the transposition logic.

Our forward-mode pass now treats types conforming to IDifferentiable as differentiable 'value' types and types conforming to IDifferentiablePtrType as differentiable 'ptr' types. Upon differentiation, the differential part of each differentiable-ptr-typed inst is marked as 'primal', which instructs the later passes to move it to the primal part and not transpose it during the reverse-mode steps.
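A minimal sketch of the marking rule described above, using stand-in enums rather than Slang's IR classes. `Conformance`, `Mark`, and `markForDiffInst` are hypothetical names introduced for illustration, and leaving non-differentiable types unmarked is an assumption of this sketch, not something the patch description states.

```cpp
#include <cassert>

// Hypothetical stand-ins for illustration; these are not Slang's IR classes.
enum class Conformance { None, Value, Ptr }; // none, IDifferentiable, IDifferentiablePtrType
enum class Mark { Unmarked, Differential, Primal };

// Decide how the derivative counterpart of an inst is tagged, based on the
// conformance of the primal type.
inline Mark markForDiffInst(Conformance primalTypeConformance)
{
    switch (primalTypeConformance)
    {
    case Conformance::Value:
        // Value types: moved into the derivative function and transposed.
        return Mark::Differential;
    case Conformance::Ptr:
        // Ptr types: moved to the primal part and never transposed.
        return Mark::Primal;
    case Conformance::None:
        break;
    }
    // Assumption of this sketch: nothing to mark for other types.
    return Mark::Unmarked;
}
```

The point of the split is that later passes only need to inspect the mark, not re-derive the conformance of every inst.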

We also alter the primal context function's signature rules to translate T : IDifferentiablePtrType to the appropriate DifferentialPtrPair<T>. This is because we treat the full pair as a primal value that may be modified in the function, and so we need to run the context function to record it (if necessary).

Effectively, the semantics for our primal context functions have changed from "run primal computation and record intermediate values" to "run primal computation for differentiable value types & ptr-pair computations for differentiable ptr types and record necessary intermediate values".
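The signature rule for the primal context function can be pictured with a small string-based sketch. `contextParamType` and the type name `MyBufferRef` are hypothetical, and the sketch assumes that parameters of every other type keep their primal type unchanged.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for the conformance of a parameter's type.
enum class Conformance { None, Value, Ptr };

// Parameter type as it would appear in the primal context function's
// signature under the rule described above.
inline std::string contextParamType(const std::string& typeName, Conformance conformance)
{
    if (conformance == Conformance::Ptr)
        return "DifferentialPtrPair<" + typeName + ">"; // the full pair is treated as primal
    // Assumption of this sketch: all other parameters stay as their primal type.
    return typeName;
}
```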

Fixes: #4998

@saipraveenb25 saipraveenb25 added pr: new feature pr: non-breaking PRs without breaking changes labels Sep 6, 2024
@@ -1606,6 +1606,7 @@ namespace Slang
DefaultInitializableConstructor, ///< The `IDefaultInitializable.__init()` method

DifferentialType, ///< The `IDifferentiable.Differential` associated type requirement
DifferentialPtrType, ///< The `IDifferentiable.DifferentialPtr` associated type requirement
Collaborator

Can remove DMulFunc down below?

Collaborator Author

I'm doing dmul removal in a separate patch

// Check if the provided type inherits from IDifferentiable.
// If not, return the original type.
if (conformanceWitness)
if (auto conformanceValWitness = as<Witness>(
Collaborator

We should make the isTypeDifferentiable function return the witness instead, so we don't have to call isSubtype directly.
This should help improve readability.
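The suggestion can be sketched as a lookup that returns the witness itself (or null) instead of a bool, so the caller tests and uses the result in one step. The table, `Witness` struct, and function names below are hypothetical stand-ins, not Slang's actual API.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for a conformance witness.
struct Witness
{
    std::string differentialType;
};

using WitnessTable = std::map<std::string, Witness>;

// Returns the witness when the type is differentiable, nullptr otherwise.
// A single call replaces "is it differentiable?" followed by a second lookup.
inline const Witness* tryGetWitness(const WitnessTable& table, const std::string& typeName)
{
    const auto it = table.find(typeName);
    return it == table.end() ? nullptr : &it->second;
}

// Mirrors the rule in the hunk above: if the type is not differentiable,
// return the original type.
inline std::string differentialTypeOrOriginal(const WitnessTable& table, const std::string& typeName)
{
    if (const Witness* witness = tryGetWitness(table, typeName))
        return witness->differentialType;
    return typeName;
}
```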

@@ -16,6 +16,86 @@

namespace Slang
{

IRInst* emitMakeDifferentialPair(IRBuilder* builder, IRType* pairType, IRInst* primalVal, IRInst* diffVal)
Collaborator

Can these functions be inside IRBuilder?

Collaborator Author

@saipraveenb25 saipraveenb25 Sep 17, 2024

I didn't want to put this one in IRBuilder because it has fairly non-trivial logic. IRBuilder has its own emitMakeDifferentialPair and emitMakeDifferentialPtrPair, which emit the specific op.
This helper function decides which one to emit based on the conformance of the primal type. We need the DiffTypeConformanceContext to be around, so we can't easily move this into IRBuilder

Collaborator

This function isn't using DiffTypeConformanceContext, and feels like it fits well in IRBuilder.

Collaborator

We should move into the direction where IRBuilder can do more folding and instruction selection instead of emitting exactly the opcode being requested.
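The kind of instruction selection discussed in this thread can be sketched as a single entry point that picks the op from the pair type. All types here (`Op`, `PairKind`, `PairType`, `Inst`) and the function name are hypothetical stand-ins, and dispatching on the pair type's kind is an assumption of this sketch rather than a description of the merged helper.

```cpp
#include <cassert>

// Hypothetical stand-ins for the two ops and the two pair-type kinds.
enum class Op { MakeDifferentialPair, MakeDifferentialPtrPair };
enum class PairKind { Value, Ptr };

struct PairType
{
    PairKind kind;
};

struct Inst
{
    Op op;
    int primalVal;
    int diffVal;
};

// One entry point that selects the op, so call sites do not have to know
// which of the two pair ops applies.
inline Inst emitMakeDifferentialPairSketch(const PairType& pairType, int primalVal, int diffVal)
{
    const Op op = (pairType.kind == PairKind::Ptr)
        ? Op::MakeDifferentialPtrPair
        : Op::MakeDifferentialPair;
    return Inst{op, primalVal, diffVal};
}
```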

@@ -445,8 +457,17 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder*
auto interfaceType = as<IRInterfaceType>(unwrapAttributedType(origType->getOperand(0)->getDataType()));
if (!interfaceType)
return nullptr;
List<IRInterfaceRequirementEntry*> lookupKeyPath = differentiableTypeConformanceContext.findDifferentiableInterfaceLookupPath(
List<IRInterfaceRequirementEntry*> lookupPathValueType = differentiableTypeConformanceContext.findInterfaceLookupPath(
Collaborator

This seems not very efficient, in that it attempts the lookup twice and then drops one of the results. Can we at least not do the lookup for ptr conformance if value conformance is found?

I wonder why there is the need for findInterfaceLookupPath in the first place. For any interfaceType, if we want to know if interfaceType:IDifferentiable, there should be a SubtypeWitness that we can construct in the frontend and register into the dictionary, so we don't have to rediscover them in the backend.

Collaborator Author

This might be from before the inheritance flattening patch. You can have interfaceType:middleInterface:IDifferentiable, so this was for emitting lookup insts into the current block to recurse into the sub-type witness when the witness is a run-time variable.

Post-flattening, is there no need for this?
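The short-circuit the reviewer asks for in this thread could look roughly like the following: try the value-conformance lookup first and only fall back to the ptr-conformance lookup when it fails. `LookupPath`, `PathFinder`, and `findConformancePath` are hypothetical stand-ins; the two finders are passed in as callbacks so the sketch stays independent of Slang's real lookup code.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Stand-in for a list of interface requirement entries.
using LookupPath = std::vector<int>;

enum class ConformanceKind { None, Value, Ptr };

struct LookupResult
{
    ConformanceKind kind;
    LookupPath path;
};

// A finder fills `out` and returns true when a lookup path exists.
using PathFinder = std::function<bool(LookupPath&)>;

// Runs the value lookup first; the ptr lookup only runs if the first fails.
inline LookupResult findConformancePath(const PathFinder& findValuePath, const PathFinder& findPtrPath)
{
    LookupResult result;
    result.kind = ConformanceKind::None;

    if (findValuePath(result.path))
    {
        result.kind = ConformanceKind::Value;
        return result;
    }
    result.path.clear();

    if (findPtrPath(result.path))
    {
        result.kind = ConformanceKind::Ptr;
        return result;
    }
    result.path.clear();
    return result;
}
```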

@@ -1075,4 +1121,64 @@ InstPair AutoDiffTranscriberBase::transcribeInst(IRBuilder* builder, IRInst* ori
return result;
}


void AutoDiffTranscriberBase::markDiffTypeInst(IRBuilder* builder, IRInst* diffInst, IRType* primalType)
Collaborator

Is all this function saying that anything other than a DifferentiablePtr type should be marked as differential?

Collaborator Author

It marks insts of T : IDifferentiable as differential (to move them into the derivative function and transpose them) and insts of T : IDifferentiablePtrType as primal (to move them to the primal blocks and not consider them during transposition).

Collaborator

line 1147 is what I don't understand.
if diffInst isn't an IRType, then it can only be one of lookupWitness, specialization, extractExistentialType. How is this related to the subject of this function?

Collaborator Author

Ah, I think the misunderstanding may be due to what diffInst is supposed to be.
Here diffInst is any inst (not just a type) that is produced during forward-mode as the derivative counterpart for the primal inst. It is not the type of the inst, but the inst itself (which can be an IRType in cases where we construct types within the method body).
This method is responsible for deciding whether this inst requires further processing by marking it appropriately.

Perhaps markDiffInst is a better name here.

@@ -514,7 +665,8 @@ IRInst* DifferentiableTypeConformanceContext::getDifferentialTypeFromDiffPairTyp

IRInst* DifferentiableTypeConformanceContext::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
{
- return _getDiffTypeFromPairType(sharedContext, builder, type);
+ return this->differentiateType(builder, type->getValueType());
+ //return _getDiffTypeFromPairType(sharedContext, builder, type);
Collaborator

delete commented code.

@@ -722,7 +895,8 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
{
if (auto pairType = as<IRDifferentialPairTypeBase>(globalInst))
{
- differentiableWitnessDictionary.addIfNotExists(pairType->getValueType(), pairType->getWitness());
+ addTypeToDictionary(pairType->getValueType(), pairType->getWitness());
+ //differentiableWitnessDictionary.addIfNotExists(pairType->getValueType(), pairType->getWitness());
Collaborator

commented code.

Collaborator

@csyonghe csyonghe left a comment

I think a lot of duplication can be avoided if we consolidate the two dictionaries in type conformance context.

{
return differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType);
if (kind == DiffConformanceKind::Any)
Collaborator

should differentiableTypeConformanceContext.tryGetDifferentiableWitness handle the Any case as well?

{
IRInst* foundResult = nullptr;
differentiableWitnessDictionary.tryGetValue(type, foundResult);

Collaborator

I still wonder why we need two separate dictionaries. A type can conform to either the value-type witness or the ref-type witness, but not both. Why can't we consolidate the dictionaries so all lookups can be done in a single run?
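One way to picture the consolidation proposed here is a single map whose entries record both the witness and which kind of conformance it represents. `ConformanceDictionary`, `WitnessEntry`, and the integer witness ids are hypothetical stand-ins for illustration; the sketch relies on the reviewer's premise that a type conforms to at most one of the two interfaces.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

enum class ConformanceKind { Value, Ptr };

// Hypothetical entry: the kind travels with the witness, so one lookup
// answers both "is it differentiable?" and "as a value or as a ptr?".
struct WitnessEntry
{
    ConformanceKind kind;
    int witnessId;
};

class ConformanceDictionary
{
public:
    // First registration wins, in the spirit of addIfNotExists.
    // Returns true if the entry was added.
    bool addIfNotExists(const std::string& type, ConformanceKind kind, int witnessId)
    {
        return entries.emplace(type, WitnessEntry{kind, witnessId}).second;
    }

    // Returns nullptr when the type has no registered conformance.
    const WitnessEntry* tryGet(const std::string& type) const
    {
        const auto it = entries.find(type);
        return it == entries.end() ? nullptr : &it->second;
    }

private:
    std::unordered_map<std::string, WitnessEntry> entries;
};
```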

{
if (isNoDiffType((IRType*)primalType))
return nullptr;

IRInst* witness = lookUpConformanceForType((IRType*)primalType);

Collaborator

trailing space.

@csyonghe csyonghe merged commit ccc310f into shader-slang:master Sep 19, 2024
12 checks passed
Successfully merging this pull request may close these issues.

Implement IDifferentiablePtrType interface
2 participants