Skip to content

Commit d38c72b

Browse files
committed
[AutoDiff] Serialize derivative function configurations per module.
`@differentiable` and `@derivative` attributes register derivatives for `AbstractFunctionDecl`s for a particular "derivative function configuration": parameter indices and dervative generic signature. To find `@derivative` functions registered in other Swift modules, derivative function configurations must be serialized per module. When configurations for a `AbstractFunctionDecl` are requested, all configurations from imported modules are deserialized. This module serialization technique has precedent: it is used for protocol conformances (e.g. extension declarations for a nominal type) and Obj-C members for a class type. Add `AbstractFunctionDecl::getDerivativeFunctionConfigurations` entry point for accessing derivative function configurations. Use `AbstractFunctionDecl::getDerivativeFunctionConfigurations` to implement `findMinimalDerivativeConfiguration` for canonical derivative function configuration lookup, replacing `getMinimalASTDifferentiableAttr`. Unblocks TF-815: lowering `@derivative` attributes directly to SIL differentiability witnesses without generating implicit `@differentiable` attributes.
1 parent 542c236 commit d38c72b

File tree

24 files changed

+676
-110
lines changed

24 files changed

+676
-110
lines changed

include/swift/AST/ASTContext.h

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,13 @@ namespace swift {
110110
class VarDecl;
111111
class UnifiedStatsReporter;
112112
class IndexSubset;
113-
class VectorSpace;
113+
// SWIFT_ENABLE_TENSORFLOW
114+
struct AutoDiffConfig;
115+
struct AutoDiffDerivativeFunctionKind;
116+
class DerivativeAttr;
114117
class DifferentiableAttr;
118+
class VectorSpace;
119+
// SWIFT_ENABLE_TENSORFLOW END
115120

116121
enum class KnownProtocolKind : uint8_t;
117122

@@ -287,11 +292,18 @@ class ASTContext final {
287292
/// Cache of autodiff-associated vector spaces.
288293
llvm::DenseMap<Type, Optional<VectorSpace>> AutoDiffVectorSpaces;
289294

290-
/// Cache of `@differentiable` attributes keyed by parameter indices. This
291-
/// helps us diagnose multiple `@differentiable`s that are with respect to the
292-
/// same set of parameters.
295+
/// Cache of `@differentiable` attributes keyed by parameter indices. Used to
296+
/// diagnose duplicate `@differentiable` attributes for the same key.
293297
llvm::DenseMap<std::pair<Decl *, IndexSubset *>, DifferentiableAttr *>
294298
DifferentiableAttrs;
299+
300+
/// Cache of `@derivative` attributes keyed by parameter indices and
301+
/// derivative function kind. Used to diagnose duplicate `@derivative`
302+
/// attributes for the same key.
303+
llvm::DenseMap<
304+
std::tuple<Decl *, IndexSubset *, AutoDiffDerivativeFunctionKind>,
305+
DerivativeAttr *>
306+
DerivativeAttrs;
295307
// SWIFT_ENABLE_TENSORFLOW END
296308

297309
private:
@@ -702,6 +714,21 @@ class ASTContext final {
702714
unsigned previousGeneration,
703715
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods);
704716

717+
// SWIFT_ENABLE_TENSORFLOW
718+
/// Load derivative function configurations for the given
719+
/// AbstractFunctionDecl.
720+
///
721+
/// \param originalAFD The declaration whose derivative function
722+
/// configurations should be loaded.
723+
///
724+
/// \param previousGeneration The previous generation number. The AST already
725+
/// contains derivative function configurations loaded from any generation up
726+
/// to and including this one.
727+
void loadDerivativeFunctionConfigurations(
728+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
729+
llvm::SetVector<AutoDiffConfig> &results);
730+
// SWIFT_ENABLE_TENSORFLOW END
731+
705732
/// Retrieve the Clang module loader for this ASTContext.
706733
///
707734
/// If there is no Clang module loader, returns a null pointer.

include/swift/AST/Attr.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1947,6 +1947,8 @@ class DerivativeAttr final
19471947
unsigned NumParsedParameters = 0;
19481948
/// The differentiation parameters' indices, resolved by the type checker.
19491949
IndexSubset *ParameterIndices = nullptr;
1950+
/// The derivative function kind (JVP or VJP), resolved by the type checker.
1951+
Optional<AutoDiffDerivativeFunctionKind> Kind = None;
19501952

19511953
explicit DerivativeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
19521954
DeclNameWithLoc original,
@@ -1975,6 +1977,12 @@ class DerivativeAttr final
19751977
OriginalFunction = decl;
19761978
}
19771979

1980+
AutoDiffDerivativeFunctionKind getDerivativeKind() const {
1981+
assert(Kind && "Derivative function kind has not yet been resolved");
1982+
return *Kind;
1983+
}
1984+
void setDerivativeKind(AutoDiffDerivativeFunctionKind kind) { Kind = kind; }
1985+
19781986
/// The parsed differentiation parameters, i.e. the list of parameters
19791987
/// specified in 'wrt:'.
19801988
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {

include/swift/AST/AutoDiff.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ struct AutoDiffConfig {
306306
class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
307307
const AutoDiffDerivativeFunctionKind kind;
308308
IndexSubset *const parameterIndices;
309+
// TODO(TF-680): Mangle derivative generic signature requirements as well.
309310

310311
AutoDiffDerivativeFunctionIdentifier(
311312
AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices) :
@@ -508,6 +509,27 @@ template<> struct DenseMapInfo<AutoDiffConfig> {
508509
}
509510
};
510511

512+
template<> struct DenseMapInfo<AutoDiffDerivativeFunctionKind> {
513+
static AutoDiffDerivativeFunctionKind getEmptyKey() {
514+
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
515+
DenseMapInfo<unsigned>::getEmptyKey());
516+
}
517+
518+
static AutoDiffDerivativeFunctionKind getTombstoneKey() {
519+
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
520+
DenseMapInfo<unsigned>::getTombstoneKey());
521+
}
522+
523+
static unsigned getHashValue(const AutoDiffDerivativeFunctionKind &Val) {
524+
return DenseMapInfo<unsigned>::getHashValue(Val);
525+
}
526+
527+
static bool isEqual(const AutoDiffDerivativeFunctionKind &LHS,
528+
const AutoDiffDerivativeFunctionKind &RHS) {
529+
return LHS == RHS;
530+
}
531+
};
532+
511533
template<> struct DenseMapInfo<SILAutoDiffIndices> {
512534
static SILAutoDiffIndices getEmptyKey() {
513535
return { DenseMapInfo<unsigned>::getEmptyKey(), nullptr };

include/swift/AST/Decl.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5694,6 +5694,25 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
56945694
private:
56955695
ParameterList *Params;
56965696

5697+
// SWIFT_ENABLE_TENSORFLOW
5698+
private:
5699+
/// The generation at which we last loaded derivative function configurations.
5700+
unsigned DerivativeFunctionConfigGeneration = 0;
5701+
/// Prepare to traverse the list of derivative function configurations.
5702+
void prepareDerivativeFunctionConfigurations();
5703+
5704+
/// A uniqued list of derivative function configurations.
5705+
struct DerivativeFunctionConfigurationList;
5706+
DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr;
5707+
5708+
public:
5709+
/// Get all derivative function configurations.
5710+
ArrayRef<AutoDiffConfig> getDerivativeFunctionConfigurations();
5711+
5712+
/// Add the given derivative function configuration.
5713+
void addDerivativeFunctionConfiguration(AutoDiffConfig config);
5714+
// SWIFT_ENABLE_TENSORFLOW END
5715+
56975716
protected:
56985717
// If a function has a body at all, we have either a parsed body AST node or
56995718
// we have saved the end location of the unparsed body.

include/swift/AST/ModuleLoader.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ class DependencyCollector;
3434

3535
namespace swift {
3636

37+
// SWIFT_ENABLE_TENSORFLOW
38+
struct AutoDiffConfig;
39+
// SWIFT_ENABLE_TENSORFLOW END
3740
class AbstractFunctionDecl;
3841
class ClangImporterOptions;
3942
class ClassDecl;
@@ -151,6 +154,25 @@ class ModuleLoader {
151154
unsigned previousGeneration,
152155
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods) = 0;
153156

157+
// SWIFT_ENABLE_TENSORFLOW
158+
/// Load derivative function configurations for the given
159+
/// AbstractFunctionDecl.
160+
///
161+
/// \param originalAFD The declaration whose derivative function
162+
/// configurations should be loaded.
163+
///
164+
/// \param previousGeneration The previous generation number. The AST already
165+
/// contains derivative function configurations loaded from any generation up
166+
/// to and including this one.
167+
///
168+
/// \param results The result list of derivative function configurations.
169+
/// This list will be extended with any methods found in subsequent
170+
/// generations.
171+
virtual void loadDerivativeFunctionConfigurations(
172+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
173+
llvm::SetVector<AutoDiffConfig> &results){};
174+
// SWIFT_ENABLE_TENSORFLOW END
175+
154176
/// Verify all modules loaded by this loader.
155177
virtual void verifyAllModules() { }
156178
};

include/swift/AST/Types.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4359,7 +4359,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
43594359
/// is necessary for the differentiation transform to support reabstraction
43604360
/// thunk differentiation because the function argument is opaque and cannot
43614361
/// be differentiated. Instead, the argument is made `@differentiable` and
4362-
/// reabstraction thunk JVP/VJP callers are reponsible for passing a
4362+
/// reabstraction thunk JVP/VJP callers are responsible for passing a
43634363
/// `@differentiable` function.
43644364
/// - TODO(TF-1036): Investigate more efficient reabstraction thunk
43654365
/// derivative approaches. The last argument can simply be a

include/swift/SILOptimizer/Utils/Differentiation/DerivativeLookup.h

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,26 @@ getExactDifferentiabilityWitness(SILModule &module, SILFunction *original,
3636
IndexSubset *parameterIndices,
3737
IndexSubset *resultIndices);
3838

39-
/// Finds the "@differentiable" attribute on `original` whose parameter indices
40-
/// are a minimal superset of the specified parameter indices. Returns `nullptr`
41-
/// if no such attribute exists.
39+
/// Finds the derivative configuration (from `@differentiable` and
40+
/// `@derivative` attributes) for `original` whose parameter indices are a
41+
/// minimal superset of the specified AST parameter indices. Returns true if
42+
/// such a configuration is found.
4243
///
4344
/// \param parameterIndices must be lowered to SIL.
44-
/// \param minimalParameterIndices is an output parameter that is set to the SIL
45-
/// indices of the minimal attribute, or to `nullptr` if no attribute exists.
46-
const DifferentiableAttr *
47-
getMinimalASTDifferentiableAttr(AbstractFunctionDecl *original,
48-
IndexSubset *parameterIndices,
49-
IndexSubset *&minimalParameterIndices);
45+
/// \param minimalASTParameterIndices is an output parameter that is set to the
46+
/// AST indices of the minimal configuration, or to `nullptr` if no such
47+
/// configuration exists.
48+
/// \param minimalSILParameterIndices is an output parameter that is set to the
49+
/// SIL indices of the minimal configuration, or to `nullptr` if no such
50+
/// configuration exists.
51+
/// \param derivativeGenericSignature is an output parameter that is set to the
52+
/// derivative generic signature of the minimal configuration, or the `nullptr`
53+
/// if no such configuration exists.
54+
bool findMinimalDerivativeConfiguration(
55+
AbstractFunctionDecl *original, IndexSubset *parameterIndices,
56+
IndexSubset *&minimalASTParameterIndices,
57+
IndexSubset *&minimalSILParameterIndices,
58+
GenericSignature &derivativeGenericSignature);
5059

5160
/// Returns a differentiability witness for `original` whose parameter indices
5261
/// are a minimal superset of the specified parameter indices and whose result

include/swift/Serialization/SerializedModuleLoader.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,12 @@ class SerializedModuleLoaderBase : public ModuleLoader {
166166
unsigned previousGeneration,
167167
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods) override;
168168

169+
// SWIFT_ENABLE_TENSORFLOW
170+
virtual void loadDerivativeFunctionConfigurations(
171+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
172+
llvm::SetVector<AutoDiffConfig> &results) override;
173+
// SWIFT_ENABLE_TENSORFLOW END
174+
169175
virtual void verifyAllModules() override;
170176
};
171177

lib/AST/ASTContext.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,6 +1611,19 @@ void ASTContext::loadObjCMethods(
16111611
}
16121612
}
16131613

1614+
// SWIFT_ENABLE_TENSORFLOW
1615+
void ASTContext::loadDerivativeFunctionConfigurations(
1616+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
1617+
llvm::SetVector<AutoDiffConfig> &results) {
1618+
PrettyStackTraceDecl stackTrace(
1619+
"loading derivative function configurations for", originalAFD);
1620+
for (auto &loader : getImpl().ModuleLoaders) {
1621+
loader->loadDerivativeFunctionConfigurations(originalAFD,
1622+
previousGeneration, results);
1623+
}
1624+
}
1625+
// SWIFT_ENABLE_TENSORFLOW END
1626+
16141627
void ASTContext::verifyAllLoadedModules() const {
16151628
#ifndef NDEBUG
16161629
FrontendStatsTracer tracer(Stats, "verify-all-loaded-modules");

lib/AST/Decl.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6998,6 +6998,49 @@ StringRef AbstractFunctionDecl::getInlinableBodyText(
69986998
return extractInlinableText(getASTContext().SourceMgr, body, scratch);
69996999
}
70007000

7001+
// SWIFT_ENABLE_TENSORFLOW
7002+
/// A uniqued list of derivative function configurations.
7003+
struct AbstractFunctionDecl::DerivativeFunctionConfigurationList
7004+
: public llvm::SetVector<AutoDiffConfig> {
7005+
// Necessary for `ASTContext` allocation.
7006+
void *operator new(
7007+
size_t bytes, ASTContext &ctx,
7008+
unsigned alignment = alignof(DerivativeFunctionConfigurationList)) {
7009+
return ctx.Allocate(bytes, alignment);
7010+
}
7011+
};
7012+
7013+
void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
7014+
if (DerivativeFunctionConfigs)
7015+
return;
7016+
auto &ctx = getASTContext();
7017+
DerivativeFunctionConfigs = new (ctx) DerivativeFunctionConfigurationList();
7018+
// Register an `ASTContext` cleanup calling the list destructor.
7019+
ctx.addCleanup([this]() {
7020+
this->DerivativeFunctionConfigs->~DerivativeFunctionConfigurationList();
7021+
});
7022+
}
7023+
7024+
ArrayRef<AutoDiffConfig>
7025+
AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
7026+
prepareDerivativeFunctionConfigurations();
7027+
auto &ctx = getASTContext();
7028+
if (ctx.getCurrentGeneration() > DerivativeFunctionConfigGeneration) {
7029+
unsigned previousGeneration = DerivativeFunctionConfigGeneration;
7030+
DerivativeFunctionConfigGeneration = ctx.getCurrentGeneration();
7031+
ctx.loadDerivativeFunctionConfigurations(this, previousGeneration,
7032+
*DerivativeFunctionConfigs);
7033+
}
7034+
return DerivativeFunctionConfigs->getArrayRef();
7035+
}
7036+
7037+
void AbstractFunctionDecl::addDerivativeFunctionConfiguration(
7038+
AutoDiffConfig config) {
7039+
prepareDerivativeFunctionConfigurations();
7040+
DerivativeFunctionConfigs->insert(config);
7041+
}
7042+
// SWIFT_ENABLE_TENSORFLOW END
7043+
70017044
FuncDecl *FuncDecl::createImpl(ASTContext &Context,
70027045
SourceLoc StaticLoc,
70037046
StaticSpellingKind StaticSpelling,

0 commit comments

Comments
 (0)