Skip to content

Commit ee7644d

Browse files
authored
[AutoDiff] Serialize derivative function configurations per module. (#28608)
`@differentiable` and `@derivative` attributes register derivatives for `AbstractFunctionDecl`s for a particular "derivative function configuration": parameter indices and dervative generic signature. To find `@derivative` functions registered in other Swift modules, derivative function configurations must be serialized per module. When configurations for a `AbstractFunctionDecl` are requested, all configurations from imported modules are deserialized. This module serialization technique has precedent: it is used for protocol conformances (e.g. extension declarations for a nominal type) and Obj-C members for a class type. Add `AbstractFunctionDecl::getDerivativeFunctionConfigurations` entry point for accessing derivative function configurations. Use `AbstractFunctionDecl::getDerivativeFunctionConfigurations` to implement `findMinimalDerivativeConfiguration` for canonical derivative function configuration lookup, replacing `getMinimalASTDifferentiableAttr`. Unblocks TF-815: lowering `@derivative` attributes directly to SIL differentiability witnesses without generating implicit `@differentiable` attributes.
1 parent 59816d8 commit ee7644d

File tree

19 files changed

+521
-75
lines changed

19 files changed

+521
-75
lines changed

include/swift/AST/ASTContext.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,11 @@ namespace swift {
110110
class VarDecl;
111111
class UnifiedStatsReporter;
112112
class IndexSubset;
113+
// SWIFT_ENABLE_TENSORFLOW
114+
struct AutoDiffConfig;
113115
class VectorSpace;
114116
class DifferentiableAttr;
117+
// SWIFT_ENABLE_TENSORFLOW END
115118

116119
enum class KnownProtocolKind : uint8_t;
117120

@@ -702,6 +705,21 @@ class ASTContext final {
702705
unsigned previousGeneration,
703706
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods);
704707

708+
// SWIFT_ENABLE_TENSORFLOW
709+
/// Load derivative function configurations for the given
710+
/// AbstractFunctionDecl.
711+
///
712+
/// \param originalAFD The declaration whose derivative function
713+
/// configurations should be loaded.
714+
///
715+
/// \param previousGeneration The previous generation number. The AST already
716+
/// contains derivative function configurations loaded from any generation up
717+
/// to and including this one.
718+
void loadDerivativeFunctionConfigurations(
719+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
720+
llvm::SetVector<AutoDiffConfig> &results);
721+
// SWIFT_ENABLE_TENSORFLOW END
722+
705723
/// Retrieve the Clang module loader for this ASTContext.
706724
///
707725
/// If there is no Clang module loader, returns a null pointer.

include/swift/AST/Decl.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5694,6 +5694,30 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
56945694
private:
56955695
ParameterList *Params;
56965696

5697+
// SWIFT_ENABLE_TENSORFLOW
5698+
private:
5699+
/// The generation at which we last loaded derivative function configurations.
5700+
unsigned DerivativeFunctionConfigGeneration = 0;
5701+
/// Prepare to traverse the list of derivative function configurations.
5702+
void prepareDerivativeFunctionConfigurations();
5703+
5704+
/// A uniqued list of derivative function configurations.
5705+
/// - `@differentiable` and `@derivative` attribute type-checking is
5706+
/// responsible for populating derivative function configurations specified
5707+
/// in the current module.
5708+
/// - Module loading is responsible for populating derivative function
5709+
/// configurations from imported modules.
5710+
struct DerivativeFunctionConfigurationList;
5711+
DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr;
5712+
5713+
public:
5714+
/// Get all derivative function configurations.
5715+
ArrayRef<AutoDiffConfig> getDerivativeFunctionConfigurations();
5716+
5717+
/// Add the given derivative function configuration.
5718+
void addDerivativeFunctionConfiguration(AutoDiffConfig config);
5719+
// SWIFT_ENABLE_TENSORFLOW END
5720+
56975721
protected:
56985722
// If a function has a body at all, we have either a parsed body AST node or
56995723
// we have saved the end location of the unparsed body.

include/swift/AST/ModuleLoader.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ class DependencyCollector;
3434

3535
namespace swift {
3636

37+
// SWIFT_ENABLE_TENSORFLOW
38+
struct AutoDiffConfig;
39+
// SWIFT_ENABLE_TENSORFLOW END
3740
class AbstractFunctionDecl;
3841
class ClangImporterOptions;
3942
class ClassDecl;
@@ -151,6 +154,25 @@ class ModuleLoader {
151154
unsigned previousGeneration,
152155
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods) = 0;
153156

157+
// SWIFT_ENABLE_TENSORFLOW
158+
/// Load derivative function configurations for the given
159+
/// AbstractFunctionDecl.
160+
///
161+
/// \param originalAFD The declaration whose derivative function
162+
/// configurations should be loaded.
163+
///
164+
/// \param previousGeneration The previous generation number. The AST already
165+
/// contains derivative function configurations loaded from any generation up
166+
/// to and including this one.
167+
///
168+
/// \param results The result list of derivative function configurations.
169+
/// This list will be extended with any methods found in subsequent
170+
/// generations.
171+
virtual void loadDerivativeFunctionConfigurations(
172+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
173+
llvm::SetVector<AutoDiffConfig> &results) {};
174+
// SWIFT_ENABLE_TENSORFLOW END
175+
154176
/// Verify all modules loaded by this loader.
155177
virtual void verifyAllModules() { }
156178
};

include/swift/AST/Types.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4359,7 +4359,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
43594359
/// is necessary for the differentiation transform to support reabstraction
43604360
/// thunk differentiation because the function argument is opaque and cannot
43614361
/// be differentiated. Instead, the argument is made `@differentiable` and
4362-
/// reabstraction thunk JVP/VJP callers are reponsible for passing a
4362+
/// reabstraction thunk JVP/VJP callers are responsible for passing a
43634363
/// `@differentiable` function.
43644364
/// - TODO(TF-1036): Investigate more efficient reabstraction thunk
43654365
/// derivative approaches. The last argument can simply be a

include/swift/SILOptimizer/Utils/Differentiation/DerivativeLookup.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,19 @@ getExactDifferentiabilityWitness(SILModule &module, SILFunction *original,
3636
IndexSubset *parameterIndices,
3737
IndexSubset *resultIndices);
3838

39-
/// Finds the "@differentiable" attribute on `original` whose parameter indices
40-
/// are a minimal superset of the specified parameter indices. Returns `nullptr`
41-
/// if no such attribute exists.
39+
/// Finds the derivative configuration (from `@differentiable` and
40+
/// `@derivative` attributes) for `original` whose parameter indices are a
41+
/// minimal superset of the specified AST parameter indices. Returns `None` if
42+
/// no such configuration is found.
4243
///
4344
/// \param parameterIndices must be lowered to SIL.
44-
/// \param minimalParameterIndices is an output parameter that is set to the SIL
45-
/// indices of the minimal attribute, or to `nullptr` if no attribute exists.
46-
const DifferentiableAttr *
47-
getMinimalASTDifferentiableAttr(AbstractFunctionDecl *original,
48-
IndexSubset *parameterIndices,
49-
IndexSubset *&minimalParameterIndices);
45+
/// \param minimalASTParameterIndices is an output parameter that is set to the
46+
/// AST indices of the minimal configuration, or to `nullptr` if no such
47+
/// configuration exists.
48+
Optional<AutoDiffConfig>
49+
findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
50+
IndexSubset *parameterIndices,
51+
IndexSubset *&minimalASTParameterIndices);
5052

5153
/// Returns a differentiability witness for `original` whose parameter indices
5254
/// are a minimal superset of the specified parameter indices and whose result

include/swift/Serialization/SerializedModuleLoader.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,12 @@ class SerializedModuleLoaderBase : public ModuleLoader {
166166
unsigned previousGeneration,
167167
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods) override;
168168

169+
// SWIFT_ENABLE_TENSORFLOW
170+
virtual void loadDerivativeFunctionConfigurations(
171+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
172+
llvm::SetVector<AutoDiffConfig> &results) override;
173+
// SWIFT_ENABLE_TENSORFLOW END
174+
169175
virtual void verifyAllModules() override;
170176
};
171177

lib/AST/ASTContext.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,6 +1611,19 @@ void ASTContext::loadObjCMethods(
16111611
}
16121612
}
16131613

1614+
// SWIFT_ENABLE_TENSORFLOW
1615+
void ASTContext::loadDerivativeFunctionConfigurations(
1616+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
1617+
llvm::SetVector<AutoDiffConfig> &results) {
1618+
PrettyStackTraceDecl stackTrace(
1619+
"loading derivative function configurations for", originalAFD);
1620+
for (auto &loader : getImpl().ModuleLoaders) {
1621+
loader->loadDerivativeFunctionConfigurations(originalAFD,
1622+
previousGeneration, results);
1623+
}
1624+
}
1625+
// SWIFT_ENABLE_TENSORFLOW END
1626+
16141627
void ASTContext::verifyAllLoadedModules() const {
16151628
#ifndef NDEBUG
16161629
FrontendStatsTracer tracer(Stats, "verify-all-loaded-modules");

lib/AST/Decl.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6998,6 +6998,49 @@ StringRef AbstractFunctionDecl::getInlinableBodyText(
69986998
return extractInlinableText(getASTContext().SourceMgr, body, scratch);
69996999
}
70007000

7001+
// SWIFT_ENABLE_TENSORFLOW
7002+
/// A uniqued list of derivative function configurations.
7003+
struct AbstractFunctionDecl::DerivativeFunctionConfigurationList
7004+
: public llvm::SetVector<AutoDiffConfig> {
7005+
// Necessary for `ASTContext` allocation.
7006+
void *operator new(
7007+
size_t bytes, ASTContext &ctx,
7008+
unsigned alignment = alignof(DerivativeFunctionConfigurationList)) {
7009+
return ctx.Allocate(bytes, alignment);
7010+
}
7011+
};
7012+
7013+
void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
7014+
if (DerivativeFunctionConfigs)
7015+
return;
7016+
auto &ctx = getASTContext();
7017+
DerivativeFunctionConfigs = new (ctx) DerivativeFunctionConfigurationList();
7018+
// Register an `ASTContext` cleanup calling the list destructor.
7019+
ctx.addCleanup([this]() {
7020+
this->DerivativeFunctionConfigs->~DerivativeFunctionConfigurationList();
7021+
});
7022+
}
7023+
7024+
ArrayRef<AutoDiffConfig>
7025+
AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
7026+
prepareDerivativeFunctionConfigurations();
7027+
auto &ctx = getASTContext();
7028+
if (ctx.getCurrentGeneration() > DerivativeFunctionConfigGeneration) {
7029+
unsigned previousGeneration = DerivativeFunctionConfigGeneration;
7030+
DerivativeFunctionConfigGeneration = ctx.getCurrentGeneration();
7031+
ctx.loadDerivativeFunctionConfigurations(this, previousGeneration,
7032+
*DerivativeFunctionConfigs);
7033+
}
7034+
return DerivativeFunctionConfigs->getArrayRef();
7035+
}
7036+
7037+
void AbstractFunctionDecl::addDerivativeFunctionConfiguration(
7038+
AutoDiffConfig config) {
7039+
prepareDerivativeFunctionConfigurations();
7040+
DerivativeFunctionConfigs->insert(config);
7041+
}
7042+
// SWIFT_ENABLE_TENSORFLOW END
7043+
70017044
FuncDecl *FuncDecl::createImpl(ASTContext &Context,
70027045
SourceLoc StaticLoc,
70037046
StaticSpellingKind StaticSpelling,

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -679,27 +679,27 @@ emitDerivativeFunctionReference(
679679
original, invoker, diag::autodiff_protocol_member_not_differentiable);
680680
return None;
681681
}
682-
// Get the minimal `@differentiable` attribute and parameter index subset.
683-
IndexSubset *minimalParamIndexSet = nullptr;
684-
const auto *minimalAttr = getMinimalASTDifferentiableAttr(
685-
requirementDecl, desiredIndices.parameters, minimalParamIndexSet);
686-
SILAutoDiffIndices minimalIndices(/*source*/ 0, minimalParamIndexSet);
687-
// If minimal `@differentiable` attribute does not exist, then no attribute
688-
// exists with a superset of the desired indices. Produce an error.
689-
if (!minimalAttr) {
682+
// Find the minimal derivative configuration: minimal parameter indices and
683+
// corresponding derivative generic signature. If it does not exist, produce
684+
// an error.
685+
IndexSubset *minimalASTParamIndices = nullptr;
686+
auto minimalConfig = findMinimalDerivativeConfiguration(
687+
requirementDecl, desiredIndices.parameters, minimalASTParamIndices);
688+
if (!minimalConfig) {
690689
context.emitNondifferentiabilityError(
691690
original, invoker,
692691
diag::autodiff_member_subset_indices_not_differentiable);
693692
return None;
694693
}
694+
auto minimalIndices = minimalConfig->getSILAutoDiffIndices();
695695
// Emit a `witness_method` instruction for the derivative function.
696696
auto originalType = witnessMethod->getType().castTo<SILFunctionType>();
697697
auto assocType = originalType->getAutoDiffDerivativeFunctionType(
698698
minimalIndices.parameters, minimalIndices.source,
699699
kind, context.getTypeConverter(),
700700
LookUpConformanceInModule(builder.getModule().getSwiftModule()));
701701
auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get(
702-
kind, minimalAttr->getParameterIndices(), context.getASTContext());
702+
kind, minimalASTParamIndices, context.getASTContext());
703703
auto *ref = builder.createWitnessMethod(
704704
loc, witnessMethod->getLookupType(), witnessMethod->getConformance(),
705705
requirementDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId),
@@ -723,28 +723,27 @@ emitDerivativeFunctionReference(
723723
original, invoker, diag::autodiff_class_member_not_differentiable);
724724
return None;
725725
}
726-
// Get the minimal `@differentiable` attribute and parameter index subset.
727-
IndexSubset *minimalParamIndexSet = nullptr;
728-
const auto *minimalAttr = getMinimalASTDifferentiableAttr(
729-
methodDecl, desiredIndices.parameters, minimalParamIndexSet);
730-
SILAutoDiffIndices minimalIndices(/*source*/ 0, minimalParamIndexSet);
731-
// If minimal `@differentiable` attribute does not exist, then no attribute
732-
// exists with a superset of the desired indices. Produce an error.
733-
if (!minimalAttr) {
726+
// Find the minimal derivative configuration: minimal parameter indices and
727+
// corresponding derivative generic signature. If it does not exist, produce
728+
// an error.
729+
IndexSubset *minimalASTParamIndices = nullptr;
730+
auto minimalConfig = findMinimalDerivativeConfiguration(
731+
methodDecl, desiredIndices.parameters, minimalASTParamIndices);
732+
if (!minimalConfig) {
734733
context.emitNondifferentiabilityError(
735734
original, invoker,
736735
diag::autodiff_member_subset_indices_not_differentiable);
737736
return None;
738737
}
738+
auto minimalIndices = minimalConfig->getSILAutoDiffIndices();
739739
// Emit a `class_method` instruction for the derivative function.
740740
auto originalType = classMethodInst->getType().castTo<SILFunctionType>();
741741
auto assocType = originalType->getAutoDiffDerivativeFunctionType(
742-
minimalIndices.parameters, minimalIndices.source,
743-
kind, context.getTypeConverter(),
742+
minimalIndices.parameters, minimalIndices.source, kind,
743+
context.getTypeConverter(),
744744
LookUpConformanceInModule(builder.getModule().getSwiftModule()));
745745
auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get(
746-
kind, minimalAttr->getParameterIndices(),
747-
context.getASTContext());
746+
kind, minimalASTParamIndices, context.getASTContext());
748747
auto *ref = builder.createClassMethod(
749748
loc, classMethodInst->getOperand(),
750749
methodDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId),

lib/SILOptimizer/Utils/Differentiation/DerivativeLookup.cpp

Lines changed: 28 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -45,34 +45,35 @@ getExactDifferentiabilityWitness(SILModule &module, SILFunction *original,
4545
return nullptr;
4646
}
4747

48-
const DifferentiableAttr *
49-
getMinimalASTDifferentiableAttr(AbstractFunctionDecl *original,
50-
IndexSubset *parameterIndices,
51-
IndexSubset *&minimalParameterIndices) {
52-
const DifferentiableAttr *minimalAttr = nullptr;
53-
minimalParameterIndices = nullptr;
54-
for (auto *attr : original->getAttrs().getAttributes<DifferentiableAttr>()) {
55-
auto *attrParameterIndices = autodiff::getLoweredParameterIndices(
56-
attr->getParameterIndices(),
48+
Optional<AutoDiffConfig>
49+
findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
50+
IndexSubset *parameterIndices,
51+
IndexSubset *&minimalASTParameterIndices) {
52+
Optional<AutoDiffConfig> minimalConfig = None;
53+
auto configs = original->getDerivativeFunctionConfigurations();
54+
for (auto config : configs) {
55+
auto *silParameterIndices = autodiff::getLoweredParameterIndices(
56+
config.parameterIndices,
5757
original->getInterfaceType()->castTo<AnyFunctionType>());
58-
// If all indices in `parameterIndices` are in `daParameterIndices`, and it
59-
// has fewer indices than our current candidate and a primitive VJP, then
60-
// `attr` is our new candidate.
58+
// If all indices in `parameterIndices` are in `daParameterIndices`, and
59+
// it has fewer indices than our current candidate and a primitive VJP,
60+
// then `attr` is our new candidate.
6161
//
6262
// NOTE(TF-642): `attr` may come from a un-partial-applied function and
6363
// have larger capacity than the desired indices. We expect this logic to
6464
// go away when `partial_apply` supports `@differentiable` callees.
65-
if (attrParameterIndices->isSupersetOf(parameterIndices->extendingCapacity(
66-
original->getASTContext(), attrParameterIndices->getCapacity())) &&
65+
if (silParameterIndices->isSupersetOf(parameterIndices->extendingCapacity(
66+
original->getASTContext(), silParameterIndices->getCapacity())) &&
6767
// fewer parameters than before
68-
(!minimalParameterIndices ||
69-
attrParameterIndices->getNumIndices() <
70-
minimalParameterIndices->getNumIndices())) {
71-
minimalAttr = attr;
72-
minimalParameterIndices = attrParameterIndices;
68+
(!minimalConfig ||
69+
silParameterIndices->getNumIndices() <
70+
minimalConfig->parameterIndices->getNumIndices())) {
71+
minimalASTParameterIndices = config.parameterIndices;
72+
minimalConfig = AutoDiffConfig(silParameterIndices, config.resultIndices,
73+
config.derivativeGenericSignature);
7374
}
7475
}
75-
return minimalAttr;
76+
return minimalConfig;
7677
}
7778

7879
SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
@@ -88,22 +89,14 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
8889
if (!originalAFD)
8990
return nullptr;
9091

91-
IndexSubset *minimalParameterIndices = nullptr;
92-
const auto *minimalAttr = getMinimalASTDifferentiableAttr(
93-
originalAFD, parameterIndices, minimalParameterIndices);
94-
95-
// TODO(TF-835): This will also need to search all `@differentiating`
96-
// attributes after we stop synthesizing `@differentiable` attributes for
97-
// `@differentiating` attributes.
98-
99-
if (!minimalAttr)
92+
IndexSubset *minimalASTParameterIndices = nullptr;
93+
auto minimalConfig = findMinimalDerivativeConfiguration(
94+
originalAFD, parameterIndices, minimalASTParameterIndices);
95+
if (!minimalConfig)
10096
return nullptr;
10197

102-
AutoDiffConfig minimalConfig(minimalParameterIndices, resultIndices,
103-
minimalAttr->getDerivativeGenericSignature());
104-
10598
auto *existingWitness = module.lookUpDifferentiabilityWitness(
106-
{original->getName(), minimalConfig});
99+
{original->getName(), *minimalConfig});
107100
if (existingWitness)
108101
return existingWitness;
109102

@@ -113,8 +106,8 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
113106

114107
return SILDifferentiabilityWitness::createDeclaration(
115108
module, SILLinkage::PublicExternal, original,
116-
minimalConfig.parameterIndices, minimalConfig.resultIndices,
117-
minimalConfig.derivativeGenericSignature);
109+
minimalConfig->parameterIndices, minimalConfig->resultIndices,
110+
minimalConfig->derivativeGenericSignature);
118111
}
119112

120113
} // end namespace swift

0 commit comments

Comments
 (0)