Skip to content

Commit bbe86e9

Browse files
authored
[AutoDiff upstream] Add Differentiable protocol derived conformances. (#30671)
Add `AdditiveArithmetic` derived conformances for structs and classes, gated by the `-enable-experimental-differentiable-programming` flag. Structs and classes whose stored properties all conform to `Differentiable` can derive `Differentiable`: - `associatedtype TangentVector: Differentiable & AdditiveArithmetic` - Member `TangentVector` structs are synthesized whose stored properties are all `var` stored properties that conform to `Differentiable` and that are not `@noDerivative`. - `mutating func move(along: TangentVector)` The `@noDerivative` attribute may be declared on stored properties to opt out of inclusion in synthesized `TangentVector` structs. Some stored properties cannot be used in `TangentVector` struct synthesis and are implicitly marked as `@noDerivative`, with a warning: - `let` stored properties. - These cannot be updated by `mutating func move(along: TangentVector)`. - Non-`Differentiable`-conforming stored properties. `@noDerivative` also implies `@_semantics("autodiff.nonvarying")`, which is relevant for differentiable activity analysis. Add type-checking and SILGen tests. Resolves TF-845.
1 parent df5ba9c commit bbe86e9

27 files changed

+1871
-15
lines changed

include/swift/AST/Attr.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,11 @@ DECL_ATTR(transpose, Transpose,
552552
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
553553
99)
554554

555+
SIMPLE_DECL_ATTR(noDerivative, NoDerivative,
556+
OnAbstractFunction | OnVar | OnSubscript |
557+
ABIBreakingToAdd | ABIBreakingToRemove | APIBreakingToAdd | APIBreakingToRemove,
558+
100)
559+
555560
#undef TYPE_ATTR
556561
#undef DECL_ATTR_ALIAS
557562
#undef CONTEXTUAL_DECL_ATTR_ALIAS

include/swift/AST/DiagnosticsSema.def

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2712,6 +2712,19 @@ ERROR(broken_encodable_requirement,none,
27122712
"Encodable protocol is broken: unexpected requirement", ())
27132713
ERROR(broken_decodable_requirement,none,
27142714
"Decodable protocol is broken: unexpected requirement", ())
2715+
ERROR(broken_differentiable_requirement,none,
2716+
"Differentiable protocol is broken: unexpected requirement", ())
2717+
WARNING(differentiable_nondiff_type_implicit_noderivative_fixit,none,
2718+
"stored property %0 has no derivative because %1 does not conform to "
2719+
"'Differentiable'; add an explicit '@noDerivative' attribute"
2720+
"%select{|, or conform %2 to 'AdditiveArithmetic'}3",
2721+
(Identifier, Type, Identifier, bool))
2722+
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
2723+
"synthesis of the 'Differentiable.move(along:)' requirement for %1 "
2724+
"requires all stored properties to be mutable; use 'var' instead, or add "
2725+
"an explicit '@noDerivative' attribute"
2726+
"%select{|, or conform %1 to 'AdditiveArithmetic'}2",
2727+
(Identifier, Identifier, bool))
27152728

27162729
NOTE(codable_extraneous_codingkey_case_here,none,
27172730
"CodingKey case %0 does not match any stored properties", (Identifier))

include/swift/AST/KnownIdentifiers.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,10 @@ IDENTIFIER_(nsError)
204204
IDENTIFIER(OSLogMessage)
205205

206206
// Differentiable programming
207+
IDENTIFIER(along)
207208
IDENTIFIER(differential)
209+
IDENTIFIER(direction)
210+
IDENTIFIER(move)
208211
IDENTIFIER(pullback)
209212
IDENTIFIER(TangentVector)
210213
IDENTIFIER(zero)

lib/AST/Decl.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5021,7 +5021,8 @@ ArrayRef<Requirement> ProtocolDecl::getCachedRequirementSignature() const {
50215021
void ProtocolDecl::computeKnownProtocolKind() const {
50225022
auto module = getModuleContext();
50235023
if (module != module->getASTContext().getStdlibModule() &&
5024-
!module->getName().is("Foundation")) {
5024+
!module->getName().is("Foundation") &&
5025+
!module->getName().is("_Differentiation")) {
50255026
const_cast<ProtocolDecl *>(this)->Bits.ProtocolDecl.KnownProtocol = 1;
50265027
return;
50275028
}

lib/SIL/SILFunctionBuilder.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,25 @@ void SILFunctionBuilder::addFunctionAttributes(
6666
if (Attrs.hasAttribute<SILGenNameAttr>() || Attrs.hasAttribute<CDeclAttr>())
6767
F->setHasCReferences(true);
6868

69+
// Validate `@differentiable` attributes by calling `getParameterIndices`.
70+
// This is important for:
71+
// - Skipping invalid `@differentiable` attributes in non-primary files.
72+
// - Preventing duplicate SIL differentiability witness creation for
73+
// `@differentiable` attributes on `AbstractStorageDecl` declarations.
74+
// Such `@differentiable` attributes are deleted and recreated on the getter
75+
// `AccessorDecl` of the `AbstractStorageDecl`.
76+
for (auto *A : Attrs.getAttributes<DifferentiableAttr>())
77+
(void)A->getParameterIndices();
78+
79+
// Propagate `@noDerivative` as `[_semantics "autodiff.nonvarying"]`.
80+
//
81+
// `@noDerivative` implies non-varying semantics for differentiable activity
82+
// analysis. SIL values produced from references to `@noDerivative`
83+
// declarations will not be marked as varying; these values do not need a
84+
// derivative.
85+
if (Attrs.hasAttribute<NoDerivativeAttr>())
86+
F->addSemanticsAttr("autodiff.nonvarying");
87+
6988
// Propagate @_dynamicReplacement(for:).
7089
if (constant.isNull())
7190
return;

lib/Sema/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ add_swift_host_library(swiftSema STATIC
1919
DerivedConformanceCaseIterable.cpp
2020
DerivedConformanceCodable.cpp
2121
DerivedConformanceCodingKey.cpp
22+
DerivedConformanceDifferentiable.cpp
2223
DerivedConformanceEquatableHashable.cpp
2324
DerivedConformanceComparable.cpp
2425
DerivedConformanceError.cpp

lib/Sema/CodeSynthesis.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,3 +1369,21 @@ bool swift::hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) {
13691369
return v->isLet() && v->hasInitialValue();
13701370
});
13711371
}
1372+
1373+
void swift::addFixedLayoutAttr(NominalTypeDecl *nominal) {
1374+
auto &C = nominal->getASTContext();
1375+
// If nominal already has `@_fixed_layout`, return.
1376+
if (nominal->getAttrs().hasAttribute<FixedLayoutAttr>())
1377+
return;
1378+
auto access = nominal->getEffectiveAccess();
1379+
// If nominal does not have at least internal access, return.
1380+
if (access < AccessLevel::Internal)
1381+
return;
1382+
// If nominal is internal, it should have the `@usableFromInline` attribute.
1383+
if (access == AccessLevel::Internal &&
1384+
!nominal->getAttrs().hasAttribute<UsableFromInlineAttr>()) {
1385+
nominal->getAttrs().add(new (C) UsableFromInlineAttr(/*Implicit*/ true));
1386+
}
1387+
// Add `@_fixed_layout` to the nominal.
1388+
nominal->getAttrs().add(new (C) FixedLayoutAttr(/*Implicit*/ true));
1389+
}

lib/Sema/CodeSynthesis.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ ValueDecl *getProtocolRequirement(ProtocolDecl *protocol, Identifier name);
6767
// with an initial value.
6868
bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal);
6969

70+
/// Add `@_fixed_layout` attribute to the nominal type, if possible.
71+
void addFixedLayoutAttr(NominalTypeDecl *nominal);
72+
7073
} // end namespace swift
7174

7275
#endif

0 commit comments

Comments
 (0)