@@ -119,10 +119,11 @@ bool swift::requiresForeignEntryPoint(ValueDecl *vd) {
119119SILDeclRef::SILDeclRef (ValueDecl *vd, SILDeclRef::Kind kind, bool isForeign,
120120 AutoDiffDerivativeFunctionIdentifier *derivativeId)
121121 : loc(vd), kind(kind), isForeign(isForeign), defaultArgIndex(0 ),
122- derivativeFunctionIdentifier (derivativeId) {}
122+ pointer (derivativeId) {}
123123
124124SILDeclRef::SILDeclRef (SILDeclRef::Loc baseLoc, bool asForeign)
125- : defaultArgIndex(0 ), derivativeFunctionIdentifier(nullptr ) {
125+ : defaultArgIndex(0 ),
126+ pointer((AutoDiffDerivativeFunctionIdentifier *)nullptr) {
126127 if (auto *vd = baseLoc.dyn_cast <ValueDecl*>()) {
127128 if (auto *fd = dyn_cast<FuncDecl>(vd)) {
128129 // Map FuncDecls directly to Func SILDeclRefs.
@@ -164,7 +165,7 @@ SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc, bool asForeign)
164165SILDeclRef::SILDeclRef (SILDeclRef::Loc baseLoc,
165166 GenericSignature prespecializedSig)
166167 : SILDeclRef(baseLoc, false ) {
167- specializedSignature = prespecializedSig;
168+ pointer = prespecializedSig. getPointer () ;
168169}
169170
170171Optional<AnyFunctionRef> SILDeclRef::getAnyFunctionRef () const {
@@ -232,7 +233,7 @@ bool SILDeclRef::isImplicit() const {
232233SILLinkage SILDeclRef::getLinkage (ForDefinition_t forDefinition) const {
233234
234235 // Prespecializations are public.
235- if (specializedSignature ) {
236+ if (getSpecializedSignature () ) {
236237 return SILLinkage::Public;
237238 }
238239
@@ -678,6 +679,7 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
678679 using namespace Mangle ;
679680 ASTMangler mangler;
680681
682+ auto *derivativeFunctionIdentifier = getDerivativeFunctionIdentifier ();
681683 if (derivativeFunctionIdentifier) {
682684 std::string originalMangled = asAutoDiffOriginalFunction ().mangle (MKind);
683685 auto *silParameterIndices = autodiff::getLoweredParameterIndices (
@@ -716,14 +718,15 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
716718 }
717719
718720 // Mangle prespecializations.
719- if (specializedSignature ) {
721+ if (getSpecializedSignature () ) {
720722 SILDeclRef nonSpecializedDeclRef = *this ;
721- nonSpecializedDeclRef.specializedSignature = GenericSignature ();
723+ nonSpecializedDeclRef.pointer =
724+ (AutoDiffDerivativeFunctionIdentifier *)nullptr ;
722725 auto mangledNonSpecializedString = nonSpecializedDeclRef.mangle ();
723726 auto *funcDecl = cast<AbstractFunctionDecl>(getDecl ());
724727 auto genericSig = funcDecl->getGenericSignature ();
725728 return GenericSpecializationMangler::manglePrespecialization (
726- mangledNonSpecializedString, genericSig, specializedSignature );
729+ mangledNonSpecializedString, genericSig, getSpecializedSignature () );
727730 }
728731
729732 ASTMangler::SymbolKind SKind = ASTMangler::SymbolKind::Default;
@@ -818,7 +821,7 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
818821// Returns true if the given JVP/VJP SILDeclRef requires a new vtable entry.
819822// FIXME(TF-1213): Also consider derived declaration `@derivative` attributes.
820823static bool derivativeFunctionRequiresNewVTableEntry (SILDeclRef declRef) {
821- assert (declRef.derivativeFunctionIdentifier &&
824+ assert (declRef.getDerivativeFunctionIdentifier () &&
822825 " Expected a derivative function SILDeclRef" );
823826 auto overridden = declRef.getOverridden ();
824827 if (!overridden)
@@ -828,7 +831,7 @@ static bool derivativeFunctionRequiresNewVTableEntry(SILDeclRef declRef) {
828831 declRef.getDecl ()->getAttrs ().getAttributes <DifferentiableAttr>(),
829832 [&](const DifferentiableAttr *derivedDiffAttr) {
830833 return derivedDiffAttr->getParameterIndices () ==
831- declRef.derivativeFunctionIdentifier ->getParameterIndices ();
834+ declRef.getDerivativeFunctionIdentifier () ->getParameterIndices ();
832835 });
833836 assert (derivedDiffAttr && " Expected `@differentiable` attribute" );
834837 // Otherwise, if the base `@differentiable` attribute specifies a derivative
@@ -838,7 +841,7 @@ static bool derivativeFunctionRequiresNewVTableEntry(SILDeclRef declRef) {
838841 overridden.getDecl ()->getAttrs ().getAttributes <DifferentiableAttr>();
839842 for (auto *baseDiffAttr : baseDiffAttrs) {
840843 if (baseDiffAttr->getParameterIndices () ==
841- declRef.derivativeFunctionIdentifier ->getParameterIndices ())
844+ declRef.getDerivativeFunctionIdentifier () ->getParameterIndices ())
842845 return false ;
843846 }
844847 // Otherwise, if there is no base `@differentiable` attribute exists, then a
@@ -847,7 +850,7 @@ static bool derivativeFunctionRequiresNewVTableEntry(SILDeclRef declRef) {
847850}
848851
849852bool SILDeclRef::requiresNewVTableEntry () const {
850- if (derivativeFunctionIdentifier )
853+ if (getDerivativeFunctionIdentifier () )
851854 if (derivativeFunctionRequiresNewVTableEntry (*this ))
852855 return true ;
853856 if (!hasDecl ())
@@ -928,15 +931,16 @@ SILDeclRef SILDeclRef::getNextOverriddenVTableEntry() const {
928931
929932 // JVPs/VJPs are overridden only if the base declaration has a
930933 // `@differentiable` attribute with the same parameter indices.
931- if (derivativeFunctionIdentifier ) {
934+ if (getDerivativeFunctionIdentifier () ) {
932935 auto overriddenAttrs =
933936 overridden.getDecl ()->getAttrs ().getAttributes <DifferentiableAttr>();
934937 for (const auto *attr : overriddenAttrs) {
935938 if (attr->getParameterIndices () !=
936- derivativeFunctionIdentifier ->getParameterIndices ())
939+ getDerivativeFunctionIdentifier () ->getParameterIndices ())
937940 continue ;
938- auto *overriddenDerivativeId = overridden.derivativeFunctionIdentifier ;
939- overridden.derivativeFunctionIdentifier =
941+ auto *overriddenDerivativeId =
942+ overridden.getDerivativeFunctionIdentifier ();
943+ overridden.pointer =
940944 AutoDiffDerivativeFunctionIdentifier::get (
941945 overriddenDerivativeId->getKind (),
942946 overriddenDerivativeId->getParameterIndices (),
0 commit comments