Skip to content

Commit

Permalink
Support extension on generic type. (#4968)
Browse files Browse the repository at this point in the history
  • Loading branch information
csyonghe authored Aug 30, 2024
1 parent 49862e7 commit 24df551
Show file tree
Hide file tree
Showing 13 changed files with 406 additions and 102 deletions.
12 changes: 11 additions & 1 deletion source/slang/slang-ast-builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ class SharedASTBuilder : public RefObject

ASTBuilder* getInnerASTBuilder() { return m_astBuilder; }

Name* getThisTypeName()
{
if (!m_thisTypeName)
{
m_thisTypeName = getNamePool()->getName("This");
}
return m_thisTypeName;
}
protected:
// State shared between ASTBuilders

Expand Down Expand Up @@ -105,6 +113,8 @@ class SharedASTBuilder : public RefObject

NamePool* m_namePool = nullptr;

Name* m_thisTypeName = nullptr;

// This is a private builder used for these shared types
ASTBuilder* m_astBuilder = nullptr;
Session* m_session = nullptr;
Expand Down Expand Up @@ -289,7 +299,7 @@ class ASTBuilder : public RefObject
auto interfaceDecl = create<InterfaceDecl>();
// Always include a `This` member and a `This:IThisInterface` member.
auto thisDecl = create<ThisTypeDecl>();
thisDecl->nameAndLoc.name = m_sharedASTBuilder->getNamePool()->getName(UnownedStringSlice("This", 4));
thisDecl->nameAndLoc.name = getSharedASTBuilder()->getThisTypeName();
thisDecl->nameAndLoc.loc = loc;
interfaceDecl->addMember(thisDecl);
auto thisConstraint = create<ThisTypeConstraintDecl>();
Expand Down
4 changes: 2 additions & 2 deletions source/slang/slang-check-conformance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ namespace Slang
return result;
result = checkAndConstructSubtypeWitness(subType, superType, isSubTypeOptions);

if(int(isSubTypeOptions) & int(IsSubTypeOptions::NotReadyForLookup))
if(!result && (int(isSubTypeOptions) & int(IsSubTypeOptions::NoCaching)))
return result;

getShared()->cacheSubtypeWitness(subType, superType, result);
Expand Down Expand Up @@ -112,7 +112,7 @@ namespace Slang
// tangling convertibility into it.

// First, make sure both sub type and super type decl are ready for lookup.
if ( !(int(isSubTypeOptions) & int(IsSubTypeOptions::NotReadyForLookup)) )
if ( !(int(isSubTypeOptions) & int(IsSubTypeOptions::NoCaching)) )
{
if (auto subDeclRefType = as<DeclRefType>(subType))
{
Expand Down
36 changes: 32 additions & 4 deletions source/slang/slang-check-constraint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,17 @@ namespace Slang
Type* interfaceType)
{
// The most basic test here should be: does the type declare conformance to the trait.
if(isSubtype(type, interfaceType, IsSubTypeOptions::None))
if (isSubtype(type, interfaceType, constraints->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None))
return type;

// If additional subtype witnesses are provided for `type` in `constraints`,
// try to use them to see if the interface is satisfied.
if (constraints->subTypeForAdditionalWitnesses == type)
{
if (constraints->additionalSubtypeWitnesses->containsKey(interfaceType))
return type;
}

// Just because `type` doesn't conform to the given `interfaceDeclRef`, that
// doesn't necessarily indicate a failure. It is possible that we have a call
// like `sqrt(2)` so that `type` is `int` and `interfaceDeclRef` is
Expand Down Expand Up @@ -183,6 +191,15 @@ namespace Slang
return type;
}
}
if (constraints->subTypeForAdditionalWitnesses)
{
for (auto witnessKV : *constraints->additionalSubtypeWitnesses)
{
auto unificationResult = TryUnifyTypes(*constraints, ValUnificationContext(), QualType(witnessKV.first), interfaceType);
if (unificationResult)
return type;
}
}
}
}

Expand Down Expand Up @@ -610,15 +627,15 @@ namespace Slang

HashSet<Decl*> constrainedGenericParams;

for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() )
for (auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>())
{
DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getGenericAppDeclRef(
genericDeclRef, args.getArrayView().arrayView, constraintDecl).as<GenericTypeConstraintDecl>();

// Extract the (substituted) sub- and super-type from the constraint.
auto sub = getSub(m_astBuilder, constraintDeclRef);
auto sup = getSup(m_astBuilder, constraintDeclRef);

// Mark sub type as constrained.
if (auto subDeclRefType = as<DeclRefType>(constraintDeclRef.getDecl()->sub.type))
constrainedGenericParams.add(subDeclRefType->getDeclRef().getDecl());
Expand All @@ -636,7 +653,18 @@ namespace Slang
}

// Search for a witness that shows the constraint is satisfied.
auto subTypeWitness = isSubtype(sub, sup, IsSubTypeOptions::None);
auto subTypeWitness = isSubtype(
sub,
sup,
system->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None);
if (!subTypeWitness)
{
if (sub == system->subTypeForAdditionalWitnesses)
{
// If no witness was found, try to find the witness from additional witness.
system->additionalSubtypeWitnesses->tryGetValue(sup, subTypeWitness);
}
}
if(subTypeWitness)
{
// We found a witness, so it will become an (implicit) argument.
Expand Down
55 changes: 44 additions & 11 deletions source/slang/slang-check-decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1075,12 +1075,13 @@ namespace Slang
DeclRef<ExtensionDecl> applyExtensionToType(
SemanticsVisitor* semantics,
ExtensionDecl* extDecl,
Type* type)
Type* type,
Dictionary<Type*, SubtypeWitness*>* additionalSubtypeWitness)
{
if(!semantics)
return DeclRef<ExtensionDecl>();

return semantics->applyExtensionToType(extDecl, type);
return semantics->applyExtensionToType(extDecl, type, additionalSubtypeWitness);
}

bool SemanticsVisitor::isDeclUsableAsStaticMember(
Expand Down Expand Up @@ -6320,7 +6321,7 @@ namespace Slang
{
// Force add IDefaultInitializable to any struct missing (transitively) `IDefaultInitializable`.
auto* defaultInitializableType = m_astBuilder->getDefaultInitializableType();
if(!isSubtype(DeclRefType::create(m_astBuilder, decl), defaultInitializableType, IsSubTypeOptions::NotReadyForLookup))
if(!isSubtype(DeclRefType::create(m_astBuilder, decl), defaultInitializableType, IsSubTypeOptions::NoCaching))
{
InheritanceDecl* conformanceDecl = m_astBuilder->create<InheritanceDecl>();
conformanceDecl->parentDecl = decl;
Expand Down Expand Up @@ -8253,7 +8254,31 @@ namespace Slang

return;
}
else if (auto genericTypeParamDecl = targetDeclRefType->getDeclRef().as<GenericTypeParamDecl>())
{
// If we are extending a generic type parameter as in `extension<T:IFoo> T`,
// we want to register the extension with the interface type `IFoo` instead.
auto genericDecl = as<GenericDecl>(genericTypeParamDecl.getDecl()->parentDecl);
if (!genericDecl)
goto error;
if (genericDecl != decl->parentDecl)
goto error;
bool isTypeConstrained = false;
for (auto constraintDecl : genericDecl->getMembersOfType<GenericTypeConstraintDecl>())
{
ensureDecl(constraintDecl, DeclCheckState::ReadyForReference);
if (targetDeclRefType == constraintDecl->sub.type)
{
auto supTypeDeclRef = isDeclRefTypeOf<AggTypeDecl>(constraintDecl->sup.type);
getShared()->registerCandidateExtension(supTypeDeclRef.getDecl(), decl);
isTypeConstrained = true;
}
}
if (isTypeConstrained)
return;
}
}
error:;
if (!as<ErrorType>(decl->targetType.type))
{
getSink()->diagnose(decl->targetType.exp, Diagnostics::invalidExtensionOnType, decl->targetType);
Expand Down Expand Up @@ -8357,6 +8382,15 @@ namespace Slang
//
return DeclRefType::create(m_astBuilder, aggTypeDeclRef);
}
else if (auto genTypeParam = declRef.as<GenericTypeParamDecl>())
{
// We will reach here when we are checking `extension<T> T {...}`,
// where inside the extension, `This` type is the target type
// of the extension, in this case this is a DeclRefType to
// a GenericTypeParamDecl.
//
return DeclRefType::create(m_astBuilder, declRef);
}
else if (auto extDeclRef = declRef.as<ExtensionDecl>())
{
// In the body of an `extension`, the `This`
Expand Down Expand Up @@ -8661,7 +8695,8 @@ namespace Slang

DeclRef<ExtensionDecl> SemanticsVisitor::applyExtensionToType(
ExtensionDecl* extDecl,
Type* type)
Type* type,
Dictionary<Type*, SubtypeWitness*>* additionalSubtypeWitnessesForType)
{
DeclRef<ExtensionDecl> extDeclRef = makeDeclRef(extDecl);

Expand All @@ -8674,20 +8709,18 @@ namespace Slang
ConstraintSystem constraints;
constraints.loc = extDecl->loc;
constraints.genericDecl = extGenericDecl;
if (additionalSubtypeWitnessesForType)
{
constraints.subTypeForAdditionalWitnesses = type;
constraints.additionalSubtypeWitnesses = additionalSubtypeWitnessesForType;
}

// Inside the body of an extension declaration, we may end up trying to apply that
// extension to its own target type.
// If we see that we are in that case, we can apply the extension declaration as - is,
// without any additional substitutions.
if (extDecl->targetType->equals(type))
{
/*
auto subst = trySolveConstraintSystem(
&constraints,
DeclRef<Decl>(extGenericDecl, nullptr).as<GenericDecl>(),
as<GenericSubstitution>(as<DeclRefType>(type)->declRef.substitutions.substitutions));
return DeclRef<Decl>(extDecl, subst).as<ExtensionDecl>();
*/
return createDefaultSubstitutionsIfNeeded(m_astBuilder, this, extDeclRef).as<ExtensionDecl>();
}

Expand Down
52 changes: 43 additions & 9 deletions source/slang/slang-check-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ namespace Slang
enum class IsSubTypeOptions
{
None = 0,
/// Type may not be finished 'DeclCheckState::ReadyForLookup`
NotReadyForLookup = 1 << 0,

/// A type may not be finished 'DeclCheckState::ReadyForLookup` while `isSubType` is called.
/// We should not cache any negative results when this flag is set.
NoCaching = 1 << 0,
};

/// Should the given `decl` be treated as a static rather than instance declaration?
Expand Down Expand Up @@ -686,11 +688,37 @@ namespace Slang
FunctionDifferentiableLevel _getFuncDifferentiableLevelImpl(FunctionDeclBase* func, int recurseLimit);
FunctionDifferentiableLevel getFuncDifferentiableLevel(FunctionDeclBase* func);

struct InheritanceCircularityInfo
{
InheritanceCircularityInfo(
Decl* decl,
InheritanceCircularityInfo* next)
: decl(decl)
, next(next)
{}

/// A declaration whose inheritance is being calculated
Decl* decl = nullptr;

/// The rest of the links in the chain of declarations being processed
InheritanceCircularityInfo* next = nullptr;
};

/// Get the processed inheritance information for `type`, including all its facets
InheritanceInfo getInheritanceInfo(Type* type);
InheritanceInfo getInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo = nullptr);

/// Get the processed inheritance information for `extension`, including all its facets
InheritanceInfo getInheritanceInfo(DeclRef<ExtensionDecl> const& extension);
InheritanceInfo getInheritanceInfo(DeclRef<ExtensionDecl> const& extension, InheritanceCircularityInfo* circularityInfo = nullptr);

/// Prevent an unsupported case of
/// ```
/// extension<T:IFoo> : IBar{};
/// extesnion<T:IBar> : IFoo{};
/// ```
/// from causing infinite recursion.
bool _checkForCircularityInExtensionTargetType(
Decl* decl,
InheritanceCircularityInfo* circularityInfo);

/// Try get subtype witness from cache, returns true if cache contains a result for the query.
bool tryGetSubtypeWitnessFromCache(Type* sub, Type* sup, SubtypeWitness*& outWitness)
Expand Down Expand Up @@ -734,9 +762,9 @@ namespace Slang

ASTBuilder* _getASTBuilder() { return m_linkage->getASTBuilder(); }

InheritanceInfo _getInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* correspondingType);
InheritanceInfo _calcInheritanceInfo(Type* type);
InheritanceInfo _calcInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* correspondingType);
InheritanceInfo _getInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* correspondingType, InheritanceCircularityInfo* circularityInfo);
InheritanceInfo _calcInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo);
InheritanceInfo _calcInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* correspondingType, InheritanceCircularityInfo* circularityInfo);

struct DirectBaseInfo
{
Expand Down Expand Up @@ -2101,6 +2129,10 @@ namespace Slang
// Constraints we have accumulated, which constrain
// the possible arguments for those parameters.
List<Constraint> constraints;

// Additional subtype witnesses available to the currentt constraint solving context.
Type* subTypeForAdditionalWitnesses = nullptr;
Dictionary<Type*, SubtypeWitness*>* additionalSubtypeWitnesses = nullptr;
};

Type* TryJoinVectorAndScalarType(
Expand Down Expand Up @@ -2536,7 +2568,8 @@ namespace Slang
// Is the candidate extension declaration actually applicable to the given type
DeclRef<ExtensionDecl> applyExtensionToType(
ExtensionDecl* extDecl,
Type* type);
Type* type,
Dictionary<Type*, SubtypeWitness*>* additionalSubtypeWitnessesForType = nullptr);

// Take a generic declaration that is being applied
// in a context and attempt to infer any missing generic
Expand Down Expand Up @@ -2702,7 +2735,8 @@ namespace Slang
DeclRef<ExtensionDecl> applyExtensionToType(
SemanticsVisitor* semantics,
ExtensionDecl* extDecl,
Type* type);
Type* type,
Dictionary<Type*, SubtypeWitness*>* additionalSubtypeWitness = nullptr);


struct SemanticsExprVisitor
Expand Down
Loading

0 comments on commit 24df551

Please sign in to comment.