From 24df5515d6c2f8537683d0e48d27a161c394e7cd Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 30 Aug 2024 16:32:34 -0700 Subject: [PATCH] Support extension on generic type. (#4968) --- source/slang/slang-ast-builder.h | 12 +- source/slang/slang-check-conformance.cpp | 4 +- source/slang/slang-check-constraint.cpp | 36 ++- source/slang/slang-check-decl.cpp | 55 ++++- source/slang/slang-check-impl.h | 52 ++++- source/slang/slang-check-inheritance.cpp | 205 ++++++++++++------ source/slang/slang-diagnostic-defs.h | 3 +- source/slang/slang-lookup.cpp | 6 +- source/slang/slang-parser.cpp | 15 +- .../extensions/generic-extension-1.slang | 49 +++++ .../extensions/generic-extension-2.slang | 30 +++ .../extensions/interface-extension.slang | 3 +- .../extensions/this-in-extension.slang | 38 ++++ 13 files changed, 406 insertions(+), 102 deletions(-) create mode 100644 tests/language-feature/extensions/generic-extension-1.slang create mode 100644 tests/language-feature/extensions/generic-extension-2.slang create mode 100644 tests/language-feature/extensions/this-in-extension.slang diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index 3e2a88dd88..b9b1f7ab85 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -72,6 +72,14 @@ class SharedASTBuilder : public RefObject ASTBuilder* getInnerASTBuilder() { return m_astBuilder; } + Name* getThisTypeName() + { + if (!m_thisTypeName) + { + m_thisTypeName = getNamePool()->getName("This"); + } + return m_thisTypeName; + } protected: // State shared between ASTBuilders @@ -105,6 +113,8 @@ class SharedASTBuilder : public RefObject NamePool* m_namePool = nullptr; + Name* m_thisTypeName = nullptr; + // This is a private builder used for these shared types ASTBuilder* m_astBuilder = nullptr; Session* m_session = nullptr; @@ -289,7 +299,7 @@ class ASTBuilder : public RefObject auto interfaceDecl = create(); // Always include a `This` member and a `This:IThisInterface` member. auto thisDecl = create(); - thisDecl->nameAndLoc.name = m_sharedASTBuilder->getNamePool()->getName(UnownedStringSlice("This", 4)); + thisDecl->nameAndLoc.name = getSharedASTBuilder()->getThisTypeName(); thisDecl->nameAndLoc.loc = loc; interfaceDecl->addMember(thisDecl); auto thisConstraint = create(); diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index 9a44cbbb47..ffa0379962 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -60,7 +60,7 @@ namespace Slang return result; result = checkAndConstructSubtypeWitness(subType, superType, isSubTypeOptions); - if(int(isSubTypeOptions) & int(IsSubTypeOptions::NotReadyForLookup)) + if(!result && (int(isSubTypeOptions) & int(IsSubTypeOptions::NoCaching))) return result; getShared()->cacheSubtypeWitness(subType, superType, result); @@ -112,7 +112,7 @@ namespace Slang // tangling convertibility into it. // First, make sure both sub type and super type decl are ready for lookup. - if ( !(int(isSubTypeOptions) & int(IsSubTypeOptions::NotReadyForLookup)) ) + if ( !(int(isSubTypeOptions) & int(IsSubTypeOptions::NoCaching)) ) { if (auto subDeclRefType = as(subType)) { diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index afcde8a5bb..e5551a875b 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -83,9 +83,17 @@ namespace Slang Type* interfaceType) { // The most basic test here should be: does the type declare conformance to the trait. - if(isSubtype(type, interfaceType, IsSubTypeOptions::None)) + if (isSubtype(type, interfaceType, constraints->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None)) return type; + // If additional subtype witnesses are provided for `type` in `constraints`, + // try to use them to see if the interface is satisfied. + if (constraints->subTypeForAdditionalWitnesses == type) + { + if (constraints->additionalSubtypeWitnesses->containsKey(interfaceType)) + return type; + } + // Just because `type` doesn't conform to the given `interfaceDeclRef`, that // doesn't necessarily indicate a failure. It is possible that we have a call // like `sqrt(2)` so that `type` is `int` and `interfaceDeclRef` is @@ -183,6 +191,15 @@ namespace Slang return type; } } + if (constraints->subTypeForAdditionalWitnesses) + { + for (auto witnessKV : *constraints->additionalSubtypeWitnesses) + { + auto unificationResult = TryUnifyTypes(*constraints, ValUnificationContext(), QualType(witnessKV.first), interfaceType); + if (unificationResult) + return type; + } + } } } @@ -610,7 +627,7 @@ namespace Slang HashSet constrainedGenericParams; - for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType() ) + for (auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType()) { DeclRef constraintDeclRef = m_astBuilder->getGenericAppDeclRef( genericDeclRef, args.getArrayView().arrayView, constraintDecl).as(); @@ -618,7 +635,7 @@ namespace Slang // Extract the (substituted) sub- and super-type from the constraint. auto sub = getSub(m_astBuilder, constraintDeclRef); auto sup = getSup(m_astBuilder, constraintDeclRef); - + // Mark sub type as constrained. if (auto subDeclRefType = as(constraintDeclRef.getDecl()->sub.type)) constrainedGenericParams.add(subDeclRefType->getDeclRef().getDecl()); @@ -636,7 +653,18 @@ namespace Slang } // Search for a witness that shows the constraint is satisfied. - auto subTypeWitness = isSubtype(sub, sup, IsSubTypeOptions::None); + auto subTypeWitness = isSubtype( + sub, + sup, + system->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None); + if (!subTypeWitness) + { + if (sub == system->subTypeForAdditionalWitnesses) + { + // If no witness was found, try to find the witness from additional witness. + system->additionalSubtypeWitnesses->tryGetValue(sup, subTypeWitness); + } + } if(subTypeWitness) { // We found a witness, so it will become an (implicit) argument. diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 190433e2f5..5654ac7a6b 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1075,12 +1075,13 @@ namespace Slang DeclRef applyExtensionToType( SemanticsVisitor* semantics, ExtensionDecl* extDecl, - Type* type) + Type* type, + Dictionary* additionalSubtypeWitness) { if(!semantics) return DeclRef(); - return semantics->applyExtensionToType(extDecl, type); + return semantics->applyExtensionToType(extDecl, type, additionalSubtypeWitness); } bool SemanticsVisitor::isDeclUsableAsStaticMember( @@ -6320,7 +6321,7 @@ namespace Slang { // Force add IDefaultInitializable to any struct missing (transitively) `IDefaultInitializable`. auto* defaultInitializableType = m_astBuilder->getDefaultInitializableType(); - if(!isSubtype(DeclRefType::create(m_astBuilder, decl), defaultInitializableType, IsSubTypeOptions::NotReadyForLookup)) + if(!isSubtype(DeclRefType::create(m_astBuilder, decl), defaultInitializableType, IsSubTypeOptions::NoCaching)) { InheritanceDecl* conformanceDecl = m_astBuilder->create(); conformanceDecl->parentDecl = decl; @@ -8253,7 +8254,31 @@ namespace Slang return; } + else if (auto genericTypeParamDecl = targetDeclRefType->getDeclRef().as()) + { + // If we are extending a generic type parameter as in `extension T`, + // we want to register the extension with the interface type `IFoo` instead. + auto genericDecl = as(genericTypeParamDecl.getDecl()->parentDecl); + if (!genericDecl) + goto error; + if (genericDecl != decl->parentDecl) + goto error; + bool isTypeConstrained = false; + for (auto constraintDecl : genericDecl->getMembersOfType()) + { + ensureDecl(constraintDecl, DeclCheckState::ReadyForReference); + if (targetDeclRefType == constraintDecl->sub.type) + { + auto supTypeDeclRef = isDeclRefTypeOf(constraintDecl->sup.type); + getShared()->registerCandidateExtension(supTypeDeclRef.getDecl(), decl); + isTypeConstrained = true; + } + } + if (isTypeConstrained) + return; + } } + error:; if (!as(decl->targetType.type)) { getSink()->diagnose(decl->targetType.exp, Diagnostics::invalidExtensionOnType, decl->targetType); @@ -8357,6 +8382,15 @@ namespace Slang // return DeclRefType::create(m_astBuilder, aggTypeDeclRef); } + else if (auto genTypeParam = declRef.as()) + { + // We will reach here when we are checking `extension T {...}`, + // where inside the extension, `This` type is the target type + // of the extension, in this case this is a DeclRefType to + // a GenericTypeParamDecl. + // + return DeclRefType::create(m_astBuilder, declRef); + } else if (auto extDeclRef = declRef.as()) { // In the body of an `extension`, the `This` @@ -8661,7 +8695,8 @@ namespace Slang DeclRef SemanticsVisitor::applyExtensionToType( ExtensionDecl* extDecl, - Type* type) + Type* type, + Dictionary* additionalSubtypeWitnessesForType) { DeclRef extDeclRef = makeDeclRef(extDecl); @@ -8674,6 +8709,11 @@ namespace Slang ConstraintSystem constraints; constraints.loc = extDecl->loc; constraints.genericDecl = extGenericDecl; + if (additionalSubtypeWitnessesForType) + { + constraints.subTypeForAdditionalWitnesses = type; + constraints.additionalSubtypeWitnesses = additionalSubtypeWitnessesForType; + } // Inside the body of an extension declaration, we may end up trying to apply that // extension to its own target type. @@ -8681,13 +8721,6 @@ namespace Slang // without any additional substitutions. if (extDecl->targetType->equals(type)) { - /* - auto subst = trySolveConstraintSystem( - &constraints, - DeclRef(extGenericDecl, nullptr).as(), - as(as(type)->declRef.substitutions.substitutions)); - return DeclRef(extDecl, subst).as(); - */ return createDefaultSubstitutionsIfNeeded(m_astBuilder, this, extDeclRef).as(); } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index caec9dceec..44f7e0029d 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -22,8 +22,10 @@ namespace Slang enum class IsSubTypeOptions { None = 0, - /// Type may not be finished 'DeclCheckState::ReadyForLookup` - NotReadyForLookup = 1 << 0, + + /// A type may not be finished 'DeclCheckState::ReadyForLookup` while `isSubType` is called. + /// We should not cache any negative results when this flag is set. + NoCaching = 1 << 0, }; /// Should the given `decl` be treated as a static rather than instance declaration? @@ -686,11 +688,37 @@ namespace Slang FunctionDifferentiableLevel _getFuncDifferentiableLevelImpl(FunctionDeclBase* func, int recurseLimit); FunctionDifferentiableLevel getFuncDifferentiableLevel(FunctionDeclBase* func); + struct InheritanceCircularityInfo + { + InheritanceCircularityInfo( + Decl* decl, + InheritanceCircularityInfo* next) + : decl(decl) + , next(next) + {} + + /// A declaration whose inheritance is being calculated + Decl* decl = nullptr; + + /// The rest of the links in the chain of declarations being processed + InheritanceCircularityInfo* next = nullptr; + }; + /// Get the processed inheritance information for `type`, including all its facets - InheritanceInfo getInheritanceInfo(Type* type); + InheritanceInfo getInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo = nullptr); /// Get the processed inheritance information for `extension`, including all its facets - InheritanceInfo getInheritanceInfo(DeclRef const& extension); + InheritanceInfo getInheritanceInfo(DeclRef const& extension, InheritanceCircularityInfo* circularityInfo = nullptr); + + /// Prevent an unsupported case of + /// ``` + /// extension : IBar{}; + /// extesnion : IFoo{}; + /// ``` + /// from causing infinite recursion. + bool _checkForCircularityInExtensionTargetType( + Decl* decl, + InheritanceCircularityInfo* circularityInfo); /// Try get subtype witness from cache, returns true if cache contains a result for the query. bool tryGetSubtypeWitnessFromCache(Type* sub, Type* sup, SubtypeWitness*& outWitness) @@ -734,9 +762,9 @@ namespace Slang ASTBuilder* _getASTBuilder() { return m_linkage->getASTBuilder(); } - InheritanceInfo _getInheritanceInfo(DeclRef declRef, DeclRefType* correspondingType); - InheritanceInfo _calcInheritanceInfo(Type* type); - InheritanceInfo _calcInheritanceInfo(DeclRef declRef, DeclRefType* correspondingType); + InheritanceInfo _getInheritanceInfo(DeclRef declRef, DeclRefType* correspondingType, InheritanceCircularityInfo* circularityInfo); + InheritanceInfo _calcInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo); + InheritanceInfo _calcInheritanceInfo(DeclRef declRef, DeclRefType* correspondingType, InheritanceCircularityInfo* circularityInfo); struct DirectBaseInfo { @@ -2101,6 +2129,10 @@ namespace Slang // Constraints we have accumulated, which constrain // the possible arguments for those parameters. List constraints; + + // Additional subtype witnesses available to the currentt constraint solving context. + Type* subTypeForAdditionalWitnesses = nullptr; + Dictionary* additionalSubtypeWitnesses = nullptr; }; Type* TryJoinVectorAndScalarType( @@ -2536,7 +2568,8 @@ namespace Slang // Is the candidate extension declaration actually applicable to the given type DeclRef applyExtensionToType( ExtensionDecl* extDecl, - Type* type); + Type* type, + Dictionary* additionalSubtypeWitnessesForType = nullptr); // Take a generic declaration that is being applied // in a context and attempt to infer any missing generic @@ -2702,7 +2735,8 @@ namespace Slang DeclRef applyExtensionToType( SemanticsVisitor* semantics, ExtensionDecl* extDecl, - Type* type); + Type* type, + Dictionary* additionalSubtypeWitness = nullptr); struct SemanticsExprVisitor diff --git a/source/slang/slang-check-inheritance.cpp b/source/slang/slang-check-inheritance.cpp index 3e59c5e8dc..0dc80cdc31 100644 --- a/source/slang/slang-check-inheritance.cpp +++ b/source/slang/slang-check-inheritance.cpp @@ -7,14 +7,14 @@ namespace Slang { - InheritanceInfo SharedSemanticsContext::getInheritanceInfo(Type* type) + InheritanceInfo SharedSemanticsContext::getInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo) { // We cache the computed inheritance information for types, // and re-use that information whenever possible. // DeclRefTypes will have their inheritance info cached in m_mapDeclRefToInheritanceInfo. if (auto declRefType = as(type)) - return _getInheritanceInfo(declRefType->getDeclRef(), declRefType); + return _getInheritanceInfo(declRefType->getDeclRef(), declRefType, circularityInfo); // Non ordinary types are cached on m_mapTypeToInheritanceInfo. if (auto found = m_mapTypeToInheritanceInfo.tryGetValue(type)) @@ -29,22 +29,48 @@ namespace Slang // m_mapTypeToInheritanceInfo[type] = InheritanceInfo(); - auto info = _calcInheritanceInfo(type); + auto info = _calcInheritanceInfo(type, circularityInfo); m_mapTypeToInheritanceInfo[type] = info; return info; } - InheritanceInfo SharedSemanticsContext::getInheritanceInfo(DeclRef const& extension) + InheritanceInfo SharedSemanticsContext::getInheritanceInfo(DeclRef const& extension, InheritanceCircularityInfo* circularityInfo) { + if (_checkForCircularityInExtensionTargetType(extension.getDecl(), circularityInfo)) + { + // If we detect a circularity in the inheritance graph, + // we will return an empty `InheritanceInfo` to avoid + // infinite recursion. + // + return InheritanceInfo(); + } + // We bottleneck the calculation of inheritance information // for type and `extension` `DeclRef`s through a single // routine with an optional `Type` parameter. // - return _getInheritanceInfo(extension, nullptr); + InheritanceCircularityInfo newCircularityInfo(extension.getDecl(), circularityInfo); + return _getInheritanceInfo(extension, nullptr, &newCircularityInfo); } - InheritanceInfo SharedSemanticsContext::_getInheritanceInfo(DeclRef declRef, DeclRefType* declRefType) + bool SharedSemanticsContext::_checkForCircularityInExtensionTargetType( + Decl* decl, + InheritanceCircularityInfo* circularityInfo) + { + for (auto info = circularityInfo; info; info = info->next) + { + if (decl == info->decl) + { + getSink()->diagnose(decl, Diagnostics::circularityInExtension, decl); + return true; + } + } + + return false; + } + + InheritanceInfo SharedSemanticsContext::_getInheritanceInfo(DeclRef declRef, DeclRefType* declRefType, InheritanceCircularityInfo* circularityInfo) { // Just as with `Type`s, we cache and re-use the inheritance // information that has been computed for a `DeclRef` whenever @@ -62,7 +88,7 @@ namespace Slang // m_mapDeclRefToInheritanceInfo[declRef] = InheritanceInfo(); - auto info = _calcInheritanceInfo(declRef, declRefType); + auto info = _calcInheritanceInfo(declRef, declRefType, circularityInfo); m_mapDeclRefToInheritanceInfo[declRef] = info; getSession()->m_typeDictionarySize = Math::Max( @@ -71,7 +97,7 @@ namespace Slang return info; } - InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(DeclRef declRef, DeclRefType* declRefType) + InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(DeclRef declRef, DeclRefType* declRefType, InheritanceCircularityInfo* circularityInfo) { // This method is the main engine for computing linearized inheritance // lists for types and `extension` declarations. @@ -215,7 +241,7 @@ namespace Slang // SLANG_ASSERT(selfIsBaseWitness); - auto baseInheritanceInfo = getInheritanceInfo(baseType); + auto baseInheritanceInfo = getInheritanceInfo(baseType, circularityInfo); DeclRef baseDeclRef; if (auto baseDeclRefType = as(baseType)) @@ -231,6 +257,51 @@ namespace Slang baseInheritanceInfo); }; + // If we know the type has a facet represented by `extensionTargetDeclRef`, we can consider + // all extensions on this decl to see if they apply to the type. + // + auto considerExtension = [&](DeclRef extensionTargetDeclRef, Dictionary* additionalSubtypeWitness) + { + bool result = false; + for (auto extDecl : getCandidateExtensions(extensionTargetDeclRef, &visitor)) + { + // The list of *candidate* extensions is computed and + // cached based on the identity of the declaration alone, + // and does not take into account any generic arguments + // of either the type or the `extension`. + // + // For example, we might have an `extension` that applies + // to `vector` for any `N`, but the `selfType` + // that we are working with could be `` so + // that the extension doesn't match. + // + // In order to make sure that we don't enumerate members + // that don't make sense in context, we must apply + // the extension to the type and see if we succeed in + // making a match. + // + auto extDeclRef = applyExtensionToType(&visitor, extDecl, selfType, additionalSubtypeWitness); + if (!extDeclRef) + continue; + + // In the case where we *do* find an extension that + // applies to the type, we add a declared base to + // represent the `extension`, knowing that its + // own linearized inheritance list will include + // any transitive based declared on the `extension`. + // + auto extInheritanceInfo = getInheritanceInfo(extDeclRef, circularityInfo); + addDirectBaseFacet( + Facet::Kind::Extension, + selfType, + selfIsSelf, + extDeclRef, + extInheritanceInfo); + result = true; + } + return result; + }; + // We now look at the structure of the declaration itself // to help us enumerate the direct bases. // @@ -280,9 +351,26 @@ namespace Slang auto genericDeclRef = genericTypeParamDeclRef.getParent().as(); SLANG_ASSERT(genericDeclRef); - ensureDecl(&visitor, genericDeclRef.getDecl(), DeclCheckState::CanSpecializeGeneric); + if (auto extensionDecl = as(genericDeclRef.getDecl()->inner)) + { + if (isDeclRefTypeOf(extensionDecl->targetType.type) == genericTypeParamDeclRef) + { + // If `T` is a generic parameter where the same generic is an extension on `T`, + // then we need to add the extension itself as a facet. + // + auto extDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, &visitor, extensionDecl); + auto selfExtFacet = new(arena) Facet::Impl( + Facet::Kind::Extension, + Facet::Directness::Direct, + extDeclRef, + selfType, + astBuilder->getTypeEqualityWitness(selfType)); + allFacets.add(selfExtFacet); + } + } + for (auto constraintDeclRef : getMembersOfType(astBuilder, genericDeclRef)) { auto subType = getSub(astBuilder, constraintDeclRef); @@ -326,63 +414,48 @@ namespace Slang // to consider any `extension` declarations that might apply to // a type being delared. // - // In our current system, only nominal types (those with `Decl`s) - // can be extended, so we begin by checking if the `selfType` - // is a nominal/`DeclRef` type. - // - // Note: this step will *not* apply when `declRef` is an `extension` - // declaration, since it directly checks for an `AggTypeDecl` - // instead of an `AggTypeDeclBase`. - // - // Similarly, we do *not* add the type being extended to the list - // of bases for an `extension`. - // - // These choices are important to avoid circular dependencies, where - // the linearization of an `extension` would end up depending on its - // own linearization (either directly or through a dependency on - // the linearization of the type being extended). - // - // Instead, the linearization we create here for an `extension` will - // *only* contain facets for the members introduced by the `extension` - // itself, as well as any transitive bases declared on that `extension`. + // An `extension` may apply to our type, if it directly extends + // the type, or extends a generic `T` type that are constrained + // on one of the interfaces that our type conforms to. // if (auto directAggTypeDeclRef = declRef.as()) { - for (auto extDecl : getCandidateExtensions(directAggTypeDeclRef, &visitor)) + considerExtension(directAggTypeDeclRef, nullptr); + } + HashSet supTypesConsideredForExtensionApplication; + Dictionary additionalSubtypeWitnesses; + for (;;) + { + // After we flatten the list of bases, we may discover additional opportunities + // to apply extensions. + List> supTypeWorkList; + for (auto curFacet : directBaseFacets) { - // The list of *candidate* extensions is computed and - // cached based on the identity of the declaration alone, - // and does not take into account any generic arguments - // of either the type or the `extension`. - // - // For example, we might have an `extension` that applies - // to `vector` for any `N`, but the `selfType` - // that we are working with could be `` so - // that the extension doesn't match. - // - // In order to make sure that we don't enumerate members - // that don't make sense in context, we must apply - // the extension to the type and see if we succeed in - // making a match. - // - auto extDeclRef = applyExtensionToType(&visitor, extDecl, selfType); - if (!extDeclRef) + if (!curFacet->subtypeWitness) continue; - - // In the case where we *do* find an extension that - // applies to the type, we add a declared base to - // represent the `extension`, knowing that its - // own linearized inheritance list will include - // any transitive based declared on the `extension`. - // - auto extInheritanceInfo = getInheritanceInfo(extDeclRef); - addDirectBaseFacet( - Facet::Kind::Extension, - selfType, - selfIsSelf, - extDeclRef, - extInheritanceInfo); + auto inheritanceInfo = getInheritanceInfo(curFacet->subtypeWitness->getSup(), circularityInfo); + for (auto facet : inheritanceInfo.facets) + { + if (auto interfaceDeclRef = facet->origin.declRef.as()) + { + SubtypeWitness* transitiveWitness = curFacet->subtypeWitness; + transitiveWitness = astBuilder->getTransitiveSubtypeWitness(curFacet->subtypeWitness, facet->subtypeWitness); + additionalSubtypeWitnesses.addIfNotExists(facet->origin.type, transitiveWitness); + if (supTypesConsideredForExtensionApplication.add(facet->origin.type)) + { + supTypeWorkList.add(interfaceDeclRef); + } + } + } } + bool canExit = true; + for (auto baseItem : supTypeWorkList) + { + if (considerExtension(baseItem, &additionalSubtypeWitnesses)) + canExit = false; + } + if (canExit) + break; } // At this point, the list of direct bases (each with its own linearization) @@ -846,7 +919,7 @@ namespace Slang return false; } - InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(Type* type) + InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo) { // The majority of the interesting for for computing linearized // inheritance information arises for `DeclRef`s, but we still @@ -861,7 +934,7 @@ namespace Slang // bottleneck through the logic that gets shared between // type and `extension` declarations. // - return _getInheritanceInfo(declRefType->getDeclRef(), declRefType); + return _getInheritanceInfo(declRefType->getDeclRef(), declRefType, circularityInfo); } else if (auto conjunctionType = as(type)) { @@ -875,8 +948,8 @@ namespace Slang // must include all the facets from the lists for `L` // and `R`, respectively. // - auto leftInfo = getInheritanceInfo(leftType); - auto rightInfo = getInheritanceInfo(rightType); + auto leftInfo = getInheritanceInfo(leftType, circularityInfo); + auto rightInfo = getInheritanceInfo(rightType, circularityInfo); // We have a case of subtype witness that can show that // `T : L` or `T : R` based on `T : L&R`. In this case, @@ -931,7 +1004,7 @@ namespace Slang } else if (auto eachType = as(type)) { - auto elementInheritanceInfo = getInheritanceInfo(eachType->getElementType()); + auto elementInheritanceInfo = getInheritanceInfo(eachType->getElementType(), circularityInfo); SemanticsVisitor visitor(this); auto directFacet = new(arena) Facet::Impl( Facet::Kind::Type, diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index b35acd2c37..5285b5c6e1 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -542,6 +542,7 @@ DIAGNOSTIC(30811, Error, baseOfStructMustBeStructOrInterface, "struct '$0' canno DIAGNOSTIC(30812, Error, baseOfEnumMustBeIntegerOrInterface, "enum '$0' cannot inherit from type '$1' that is neither an interface not a builtin integer type") DIAGNOSTIC(30813, Error, baseOfExtensionMustBeInterface, "extension cannot inherit from non-interface type '$1'") DIAGNOSTIC(30814, Error, baseOfClassMustBeClassOrInterface, "class '$0' cannot inherit from type '$1' that is neither a class nor an interface") +DIAGNOSTIC(30815, Error, circularityInExtension, "circular extension is not allowed.") DIAGNOSTIC(30820, Error, baseStructMustBeListedFirst, "a struct type may only inherit from one other struct type, and that type must appear first in the list of bases") DIAGNOSTIC(30821, Error, tagTypeMustBeListedFirst, "an unum type may only have a single tag type, and that type must be listed first in the list of bases") @@ -552,7 +553,7 @@ DIAGNOSTIC(30831, Error, cannotInheritFromImplicitlySealedDeclarationInAnotherMo DIAGNOSTIC(30832, Error, invalidTypeForInheritance, "type '$0' cannot be used for inheritance") DIAGNOSTIC(30850, Error, invalidExtensionOnType, "type '$0' cannot be extended. `extension` can only be used to extend a nominal type.") -DIAGNOSTIC(30851, Error, invalidMemberTypeInExtension, "$0 cannot be apart of an `extension`") +DIAGNOSTIC(30851, Error, invalidMemberTypeInExtension, "$0 cannot be a part of an `extension`") // 309xx: subscripts DIAGNOSTIC(30900, Error, multiDimensionalArrayNotSupported, "multi-dimensional array is not supported.") diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp index 7acaa030b1..1d35fe915a 100644 --- a/source/slang/slang-lookup.cpp +++ b/source/slang/slang-lookup.cpp @@ -499,7 +499,7 @@ static void _lookUpMembersInSuperTypeDeclImpl( BreadcrumbInfo* inBreadcrumbs) { auto semantics = request.semantics; - if (!as(declRef.getDecl()) && getText(name) == "This") + if (!as(declRef.getDecl()) && name == astBuilder->getSharedASTBuilder()->getThisTypeName()) { // If we are looking for `This` in anything other than an InterfaceDecl, // we just need to return the declRef itself. @@ -806,6 +806,10 @@ static void _lookUpInScopes( // a type that uses the "target type" of the `extension`. // type = getTargetType(astBuilder, extDeclRef); + if (name == astBuilder->getSharedASTBuilder()->getThisTypeName()) + { + breadcrumbPtr = nullptr; + } } else { diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 04ada006c0..121c68be70 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -3335,12 +3335,15 @@ namespace Slang static NodeBase* parseExtensionDecl(Parser* parser, void* /*userData*/) { - ExtensionDecl* decl = parser->astBuilder->create(); - parser->FillPosition(decl); - decl->targetType = parser->ParseTypeExp(); - parseOptionalInheritanceClause(parser, decl); - parseDeclBody(parser, decl); - return decl; + return parseOptGenericDecl(parser, [&](GenericDecl*) + { + ExtensionDecl* decl = parser->astBuilder->create(); + parser->FillPosition(decl); + decl->targetType = parser->ParseTypeExp(); + parseOptionalInheritanceClause(parser, decl); + parseDeclBody(parser, decl); + return decl; + }); } diff --git a/tests/language-feature/extensions/generic-extension-1.slang b/tests/language-feature/extensions/generic-extension-1.slang new file mode 100644 index 0000000000..47940e31a7 --- /dev/null +++ b/tests/language-feature/extensions/generic-extension-1.slang @@ -0,0 +1,49 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj +interface IFoo +{ + int getVal(); +} + +interface IBar +{ + int getValPlusOne(); +} + +interface IBaz +{ + int getValPlusTwo(); +} + +struct MyInt { int v; } + +extension MyInt : IFoo +{ + int getVal() { return v; } +} + +// Since MyInt:IFoo, the following extension will make MyInt:IBar. +extension T : IBar +{ + int getValPlusOne() { return this.getVal() + 1; } +} + +// Since MyInt:IBar, the following extension will make MyInt:IBaz. +extension T : IBaz +{ + int getValPlusTwo() { return this.getValPlusOne() + 1; } +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + MyInt v = {1}; + + // Check that the extensions applied to MyInt correctly, i.e. + // MyInt.getValPlusTwo() eixsts. + + // CHECK: 3 + outputBuffer[0] = v.getValPlusTwo(); +} \ No newline at end of file diff --git a/tests/language-feature/extensions/generic-extension-2.slang b/tests/language-feature/extensions/generic-extension-2.slang new file mode 100644 index 0000000000..2728a73d67 --- /dev/null +++ b/tests/language-feature/extensions/generic-extension-2.slang @@ -0,0 +1,30 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj +interface IFoo +{ + T getFirst(); +} + +__generic> +extension T : IFoo +{ + S getFirst() + { + return this[0]; + } +} + +T getFirstOuter(IFoo arr) +{ + return arr.getFirst(); +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + int arr[] = {1,2,3}; + // CHECK: 1 + outputBuffer[0] = getFirstOuter(arr); +} \ No newline at end of file diff --git a/tests/language-feature/extensions/interface-extension.slang b/tests/language-feature/extensions/interface-extension.slang index 1d84ba8443..50bbbe22a8 100644 --- a/tests/language-feature/extensions/interface-extension.slang +++ b/tests/language-feature/extensions/interface-extension.slang @@ -16,7 +16,8 @@ struct MyCounter : ICounter [mutating] void add(int value) { _state += value; } } -extension ICounter +__generic +extension T { [mutating] void increment() { diff --git a/tests/language-feature/extensions/this-in-extension.slang b/tests/language-feature/extensions/this-in-extension.slang new file mode 100644 index 0000000000..374eabe6fb --- /dev/null +++ b/tests/language-feature/extensions/this-in-extension.slang @@ -0,0 +1,38 @@ +// this-in-extension.slang + +// Test that an `This` type works correctly when there is an extension. + +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj + +interface IFoo +{ + static const This identity; +} + +__generic +extension T +{ + This getIdentity() + { + return identity; + } +} + +struct FooImpl : IFoo +{ + int v = 1; + static const This identity = This(); +} + + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) +{ + FooImpl foo; + var ident = foo.getIdentity(); + // CHECK: 1 + outputBuffer[0] = ident.v; +}