diff --git a/include/swift/AST/ASTContext.h b/include/swift/AST/ASTContext.h index 2347b879e4ee6..3ecc78d049ae9 100644 --- a/include/swift/AST/ASTContext.h +++ b/include/swift/AST/ASTContext.h @@ -110,8 +110,11 @@ namespace swift { class VarDecl; class UnifiedStatsReporter; class IndexSubset; + // SWIFT_ENABLE_TENSORFLOW + struct AutoDiffConfig; class VectorSpace; class DifferentiableAttr; + // SWIFT_ENABLE_TENSORFLOW END enum class KnownProtocolKind : uint8_t; @@ -702,6 +705,21 @@ class ASTContext final { unsigned previousGeneration, llvm::TinyPtrVector &methods); + // SWIFT_ENABLE_TENSORFLOW + /// Load derivative function configurations for the given + /// AbstractFunctionDecl. + /// + /// \param originalAFD The declaration whose derivative function + /// configurations should be loaded. + /// + /// \param previousGeneration The previous generation number. The AST already + /// contains derivative function configurations loaded from any generation up + /// to and including this one. + void loadDerivativeFunctionConfigurations( + AbstractFunctionDecl *originalAFD, unsigned previousGeneration, + llvm::SetVector &results); + // SWIFT_ENABLE_TENSORFLOW END + /// Retrieve the Clang module loader for this ASTContext. /// /// If there is no Clang module loader, returns a null pointer. diff --git a/include/swift/AST/Decl.h b/include/swift/AST/Decl.h index 39d346c8679bf..c0046fbfe180a 100644 --- a/include/swift/AST/Decl.h +++ b/include/swift/AST/Decl.h @@ -5694,6 +5694,30 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl { private: ParameterList *Params; +// SWIFT_ENABLE_TENSORFLOW +private: + /// The generation at which we last loaded derivative function configurations. + unsigned DerivativeFunctionConfigGeneration = 0; + /// Prepare to traverse the list of derivative function configurations. + void prepareDerivativeFunctionConfigurations(); + + /// A uniqued list of derivative function configurations. + /// - `@differentiable` and `@derivative` attribute type-checking is + /// responsible for populating derivative function configurations specified + /// in the current module. + /// - Module loading is responsible for populating derivative function + /// configurations from imported modules. + struct DerivativeFunctionConfigurationList; + DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr; + +public: + /// Get all derivative function configurations. + ArrayRef getDerivativeFunctionConfigurations(); + + /// Add the given derivative function configuration. + void addDerivativeFunctionConfiguration(AutoDiffConfig config); +// SWIFT_ENABLE_TENSORFLOW END + protected: // If a function has a body at all, we have either a parsed body AST node or // we have saved the end location of the unparsed body. diff --git a/include/swift/AST/ModuleLoader.h b/include/swift/AST/ModuleLoader.h index af1d372659b69..d4a4a07639d43 100644 --- a/include/swift/AST/ModuleLoader.h +++ b/include/swift/AST/ModuleLoader.h @@ -34,6 +34,9 @@ class DependencyCollector; namespace swift { +// SWIFT_ENABLE_TENSORFLOW +struct AutoDiffConfig; +// SWIFT_ENABLE_TENSORFLOW END class AbstractFunctionDecl; class ClangImporterOptions; class ClassDecl; @@ -151,6 +154,25 @@ class ModuleLoader { unsigned previousGeneration, llvm::TinyPtrVector &methods) = 0; + // SWIFT_ENABLE_TENSORFLOW + /// Load derivative function configurations for the given + /// AbstractFunctionDecl. + /// + /// \param originalAFD The declaration whose derivative function + /// configurations should be loaded. + /// + /// \param previousGeneration The previous generation number. The AST already + /// contains derivative function configurations loaded from any generation up + /// to and including this one. + /// + /// \param results The result list of derivative function configurations. + /// This list will be extended with any methods found in subsequent + /// generations. + virtual void loadDerivativeFunctionConfigurations( + AbstractFunctionDecl *originalAFD, unsigned previousGeneration, + llvm::SetVector &results) {}; + // SWIFT_ENABLE_TENSORFLOW END + /// Verify all modules loaded by this loader. virtual void verifyAllModules() { } }; diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index 9e732f7975d79..21d079cb05a32 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -4359,7 +4359,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode, /// is necessary for the differentiation transform to support reabstraction /// thunk differentiation because the function argument is opaque and cannot /// be differentiated. Instead, the argument is made `@differentiable` and - /// reabstraction thunk JVP/VJP callers are reponsible for passing a + /// reabstraction thunk JVP/VJP callers are responsible for passing a /// `@differentiable` function. /// - TODO(TF-1036): Investigate more efficient reabstraction thunk /// derivative approaches. The last argument can simply be a diff --git a/include/swift/SILOptimizer/Utils/Differentiation/DerivativeLookup.h b/include/swift/SILOptimizer/Utils/Differentiation/DerivativeLookup.h index 46ce13b1ff717..d951018bbd963 100644 --- a/include/swift/SILOptimizer/Utils/Differentiation/DerivativeLookup.h +++ b/include/swift/SILOptimizer/Utils/Differentiation/DerivativeLookup.h @@ -36,17 +36,19 @@ getExactDifferentiabilityWitness(SILModule &module, SILFunction *original, IndexSubset *parameterIndices, IndexSubset *resultIndices); -/// Finds the "@differentiable" attribute on `original` whose parameter indices -/// are a minimal superset of the specified parameter indices. Returns `nullptr` -/// if no such attribute exists. +/// Finds the derivative configuration (from `@differentiable` and +/// `@derivative` attributes) for `original` whose parameter indices are a +/// minimal superset of the specified AST parameter indices. Returns `None` if +/// no such configuration is found. /// /// \param parameterIndices must be lowered to SIL. -/// \param minimalParameterIndices is an output parameter that is set to the SIL -/// indices of the minimal attribute, or to `nullptr` if no attribute exists. -const DifferentiableAttr * -getMinimalASTDifferentiableAttr(AbstractFunctionDecl *original, - IndexSubset *parameterIndices, - IndexSubset *&minimalParameterIndices); +/// \param minimalASTParameterIndices is an output parameter that is set to the +/// AST indices of the minimal configuration, or to `nullptr` if no such +/// configuration exists. +Optional +findMinimalDerivativeConfiguration(AbstractFunctionDecl *original, + IndexSubset *parameterIndices, + IndexSubset *&minimalASTParameterIndices); /// Returns a differentiability witness for `original` whose parameter indices /// are a minimal superset of the specified parameter indices and whose result diff --git a/include/swift/Serialization/SerializedModuleLoader.h b/include/swift/Serialization/SerializedModuleLoader.h index 6b638cfd166a4..ad31c2992a925 100644 --- a/include/swift/Serialization/SerializedModuleLoader.h +++ b/include/swift/Serialization/SerializedModuleLoader.h @@ -166,6 +166,12 @@ class SerializedModuleLoaderBase : public ModuleLoader { unsigned previousGeneration, llvm::TinyPtrVector &methods) override; + // SWIFT_ENABLE_TENSORFLOW + virtual void loadDerivativeFunctionConfigurations( + AbstractFunctionDecl *originalAFD, unsigned previousGeneration, + llvm::SetVector &results) override; + // SWIFT_ENABLE_TENSORFLOW END + virtual void verifyAllModules() override; }; diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index 1d057d5f3d6a6..db78d8fdb41ed 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -1611,6 +1611,19 @@ void ASTContext::loadObjCMethods( } } +// SWIFT_ENABLE_TENSORFLOW +void ASTContext::loadDerivativeFunctionConfigurations( + AbstractFunctionDecl *originalAFD, unsigned previousGeneration, + llvm::SetVector &results) { + PrettyStackTraceDecl stackTrace( + "loading derivative function configurations for", originalAFD); + for (auto &loader : getImpl().ModuleLoaders) { + loader->loadDerivativeFunctionConfigurations(originalAFD, + previousGeneration, results); + } +} +// SWIFT_ENABLE_TENSORFLOW END + void ASTContext::verifyAllLoadedModules() const { #ifndef NDEBUG FrontendStatsTracer tracer(Stats, "verify-all-loaded-modules"); diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index d43caf2e06a8d..060925ba42420 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -6998,6 +6998,49 @@ StringRef AbstractFunctionDecl::getInlinableBodyText( return extractInlinableText(getASTContext().SourceMgr, body, scratch); } +// SWIFT_ENABLE_TENSORFLOW +/// A uniqued list of derivative function configurations. +struct AbstractFunctionDecl::DerivativeFunctionConfigurationList + : public llvm::SetVector { + // Necessary for `ASTContext` allocation. + void *operator new( + size_t bytes, ASTContext &ctx, + unsigned alignment = alignof(DerivativeFunctionConfigurationList)) { + return ctx.Allocate(bytes, alignment); + } +}; + +void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() { + if (DerivativeFunctionConfigs) + return; + auto &ctx = getASTContext(); + DerivativeFunctionConfigs = new (ctx) DerivativeFunctionConfigurationList(); + // Register an `ASTContext` cleanup calling the list destructor. + ctx.addCleanup([this]() { + this->DerivativeFunctionConfigs->~DerivativeFunctionConfigurationList(); + }); +} + +ArrayRef +AbstractFunctionDecl::getDerivativeFunctionConfigurations() { + prepareDerivativeFunctionConfigurations(); + auto &ctx = getASTContext(); + if (ctx.getCurrentGeneration() > DerivativeFunctionConfigGeneration) { + unsigned previousGeneration = DerivativeFunctionConfigGeneration; + DerivativeFunctionConfigGeneration = ctx.getCurrentGeneration(); + ctx.loadDerivativeFunctionConfigurations(this, previousGeneration, + *DerivativeFunctionConfigs); + } + return DerivativeFunctionConfigs->getArrayRef(); +} + +void AbstractFunctionDecl::addDerivativeFunctionConfiguration( + AutoDiffConfig config) { + prepareDerivativeFunctionConfigurations(); + DerivativeFunctionConfigs->insert(config); +} +// SWIFT_ENABLE_TENSORFLOW END + FuncDecl *FuncDecl::createImpl(ASTContext &Context, SourceLoc StaticLoc, StaticSpellingKind StaticSpelling, diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 3c62fe7cc8a4a..357799d52d99d 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -748,19 +748,19 @@ emitDerivativeFunctionReference( original, invoker, diag::autodiff_protocol_member_not_differentiable); return None; } - // Get the minimal `@differentiable` attribute and parameter index subset. - IndexSubset *minimalParamIndexSet = nullptr; - const auto *minimalAttr = getMinimalASTDifferentiableAttr( - requirementDecl, desiredIndices.parameters, minimalParamIndexSet); - SILAutoDiffIndices minimalIndices(/*source*/ 0, minimalParamIndexSet); - // If minimal `@differentiable` attribute does not exist, then no attribute - // exists with a superset of the desired indices. Produce an error. - if (!minimalAttr) { + // Find the minimal derivative configuration: minimal parameter indices and + // corresponding derivative generic signature. If it does not exist, produce + // an error. + IndexSubset *minimalASTParamIndices = nullptr; + auto minimalConfig = findMinimalDerivativeConfiguration( + requirementDecl, desiredIndices.parameters, minimalASTParamIndices); + if (!minimalConfig) { context.emitNondifferentiabilityError( original, invoker, diag::autodiff_member_subset_indices_not_differentiable); return None; } + auto minimalIndices = minimalConfig->getSILAutoDiffIndices(); // Emit a `witness_method` instruction for the derivative function. auto originalType = witnessMethod->getType().castTo(); auto assocType = originalType->getAutoDiffDerivativeFunctionType( @@ -768,7 +768,7 @@ emitDerivativeFunctionReference( kind, context.getTypeConverter(), LookUpConformanceInModule(builder.getModule().getSwiftModule())); auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get( - kind, minimalAttr->getParameterIndices(), context.getASTContext()); + kind, minimalASTParamIndices, context.getASTContext()); auto *ref = builder.createWitnessMethod( loc, witnessMethod->getLookupType(), witnessMethod->getConformance(), requirementDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId), @@ -792,28 +792,27 @@ emitDerivativeFunctionReference( original, invoker, diag::autodiff_class_member_not_differentiable); return None; } - // Get the minimal `@differentiable` attribute and parameter index subset. - IndexSubset *minimalParamIndexSet = nullptr; - const auto *minimalAttr = getMinimalASTDifferentiableAttr( - methodDecl, desiredIndices.parameters, minimalParamIndexSet); - SILAutoDiffIndices minimalIndices(/*source*/ 0, minimalParamIndexSet); - // If minimal `@differentiable` attribute does not exist, then no attribute - // exists with a superset of the desired indices. Produce an error. - if (!minimalAttr) { + // Find the minimal derivative configuration: minimal parameter indices and + // corresponding derivative generic signature. If it does not exist, produce + // an error. + IndexSubset *minimalASTParamIndices = nullptr; + auto minimalConfig = findMinimalDerivativeConfiguration( + methodDecl, desiredIndices.parameters, minimalASTParamIndices); + if (!minimalConfig) { context.emitNondifferentiabilityError( original, invoker, diag::autodiff_member_subset_indices_not_differentiable); return None; } + auto minimalIndices = minimalConfig->getSILAutoDiffIndices(); // Emit a `class_method` instruction for the derivative function. auto originalType = classMethodInst->getType().castTo(); auto assocType = originalType->getAutoDiffDerivativeFunctionType( - minimalIndices.parameters, minimalIndices.source, - kind, context.getTypeConverter(), + minimalIndices.parameters, minimalIndices.source, kind, + context.getTypeConverter(), LookUpConformanceInModule(builder.getModule().getSwiftModule())); auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get( - kind, minimalAttr->getParameterIndices(), - context.getASTContext()); + kind, minimalASTParamIndices, context.getASTContext()); auto *ref = builder.createClassMethod( loc, classMethodInst->getOperand(), methodDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId), diff --git a/lib/SILOptimizer/Utils/Differentiation/DerivativeLookup.cpp b/lib/SILOptimizer/Utils/Differentiation/DerivativeLookup.cpp index 0abe40f0844c3..dbac4ebec9fba 100644 --- a/lib/SILOptimizer/Utils/Differentiation/DerivativeLookup.cpp +++ b/lib/SILOptimizer/Utils/Differentiation/DerivativeLookup.cpp @@ -45,34 +45,35 @@ getExactDifferentiabilityWitness(SILModule &module, SILFunction *original, return nullptr; } -const DifferentiableAttr * -getMinimalASTDifferentiableAttr(AbstractFunctionDecl *original, - IndexSubset *parameterIndices, - IndexSubset *&minimalParameterIndices) { - const DifferentiableAttr *minimalAttr = nullptr; - minimalParameterIndices = nullptr; - for (auto *attr : original->getAttrs().getAttributes()) { - auto *attrParameterIndices = autodiff::getLoweredParameterIndices( - attr->getParameterIndices(), +Optional +findMinimalDerivativeConfiguration(AbstractFunctionDecl *original, + IndexSubset *parameterIndices, + IndexSubset *&minimalASTParameterIndices) { + Optional minimalConfig = None; + auto configs = original->getDerivativeFunctionConfigurations(); + for (auto config : configs) { + auto *silParameterIndices = autodiff::getLoweredParameterIndices( + config.parameterIndices, original->getInterfaceType()->castTo()); - // If all indices in `parameterIndices` are in `daParameterIndices`, and it - // has fewer indices than our current candidate and a primitive VJP, then - // `attr` is our new candidate. + // If all indices in `parameterIndices` are in `daParameterIndices`, and + // it has fewer indices than our current candidate and a primitive VJP, + // then `attr` is our new candidate. // // NOTE(TF-642): `attr` may come from a un-partial-applied function and // have larger capacity than the desired indices. We expect this logic to // go away when `partial_apply` supports `@differentiable` callees. - if (attrParameterIndices->isSupersetOf(parameterIndices->extendingCapacity( - original->getASTContext(), attrParameterIndices->getCapacity())) && + if (silParameterIndices->isSupersetOf(parameterIndices->extendingCapacity( + original->getASTContext(), silParameterIndices->getCapacity())) && // fewer parameters than before - (!minimalParameterIndices || - attrParameterIndices->getNumIndices() < - minimalParameterIndices->getNumIndices())) { - minimalAttr = attr; - minimalParameterIndices = attrParameterIndices; + (!minimalConfig || + silParameterIndices->getNumIndices() < + minimalConfig->parameterIndices->getNumIndices())) { + minimalASTParameterIndices = config.parameterIndices; + minimalConfig = AutoDiffConfig(silParameterIndices, config.resultIndices, + config.derivativeGenericSignature); } } - return minimalAttr; + return minimalConfig; } SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness( @@ -88,22 +89,14 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness( if (!originalAFD) return nullptr; - IndexSubset *minimalParameterIndices = nullptr; - const auto *minimalAttr = getMinimalASTDifferentiableAttr( - originalAFD, parameterIndices, minimalParameterIndices); - - // TODO(TF-835): This will also need to search all `@differentiating` - // attributes after we stop synthesizing `@differentiable` attributes for - // `@differentiating` attributes. - - if (!minimalAttr) + IndexSubset *minimalASTParameterIndices = nullptr; + auto minimalConfig = findMinimalDerivativeConfiguration( + originalAFD, parameterIndices, minimalASTParameterIndices); + if (!minimalConfig) return nullptr; - AutoDiffConfig minimalConfig(minimalParameterIndices, resultIndices, - minimalAttr->getDerivativeGenericSignature()); - auto *existingWitness = module.lookUpDifferentiabilityWitness( - {original->getName(), minimalConfig}); + {original->getName(), *minimalConfig}); if (existingWitness) return existingWitness; @@ -113,8 +106,8 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness( return SILDifferentiabilityWitness::createDeclaration( module, SILLinkage::PublicExternal, original, - minimalConfig.parameterIndices, minimalConfig.resultIndices, - minimalConfig.derivativeGenericSignature); + minimalConfig->parameterIndices, minimalConfig->resultIndices, + minimalConfig->derivativeGenericSignature); } } // end namespace swift diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index f07d01fb2624e..a08f1cdf5c3e9 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -3628,6 +3628,10 @@ DifferentiableAttributeParameterIndicesRequest::evaluate( return nullptr; } getterDecl->getAttrs().add(newAttr); + // Register derivative function configuration. + auto *resultIndices = IndexSubset::get(ctx, 1, {0}); + getterDecl->addDerivativeFunctionConfiguration( + {checkedWrtParamIndices, resultIndices, whereClauseGenSig}); return checkedWrtParamIndices; } auto insertion = ctx.DifferentiableAttrs.try_emplace( @@ -3640,6 +3644,10 @@ DifferentiableAttributeParameterIndicesRequest::evaluate( diag::differentiable_attr_duplicate_note); return nullptr; } + // Register derivative function configuration. + auto *resultIndices = IndexSubset::get(ctx, 1, {0}); + original->addDerivativeFunctionConfiguration( + {checkedWrtParamIndices, resultIndices, whereClauseGenSig}); return checkedWrtParamIndices; } @@ -3902,9 +3910,8 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) { } // Reject different-file retroactive derivatives. - // TODO(TF-136): Full support for cross-file/cross-module retroactive - // differentiability will require SIL differentiability witnesses and lots of - // plumbing. + // TODO(TF-136): Lift this restriction now that SIL differentiability witness + // infrastructure is ready. if (originalAFD->getParentSourceFile() != derivative->getParentSourceFile()) { diagnoseAndRemoveAttr(attr, diag::derivative_attr_not_in_same_file_as_original); @@ -3980,6 +3987,12 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) { da->setVJPFunction(derivative); break; } + + // Register derivative function configuration. + auto *resultIndices = IndexSubset::get(Ctx, 1, {0}); + originalAFD->addDerivativeFunctionConfiguration( + {checkedWrtParamIndices, resultIndices, + derivative->getGenericSignature()}); } void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) { diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index f08bd9d8ff2e8..879d256d8e754 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -565,7 +565,12 @@ swift::matchWitness( reqDiffAttr->getParameterIndices(), /*jvp*/ None, /*vjp*/ None, reqDiffAttr->getDerivativeGenericSignature()); auto insertion = ctx.DifferentiableAttrs.try_emplace( - {witness, newAttr->getParameterIndices()}, newAttr); + {witnessAFD, newAttr->getParameterIndices()}, newAttr); + // Register derivative function configuration. + auto *resultIndices = IndexSubset::get(ctx, 1, {0}); + witnessAFD->addDerivativeFunctionConfiguration( + {newAttr->getParameterIndices(), resultIndices, + newAttr->getDerivativeGenericSignature()}); // Valid `@differentiable` attributes are uniqued by their parameter // indices. Reject duplicate attributes for the same decl and parameter // indices pair. diff --git a/lib/Serialization/DeclTypeRecordNodes.def b/lib/Serialization/DeclTypeRecordNodes.def index 70cb3c8b8de68..5000def14362a 100644 --- a/lib/Serialization/DeclTypeRecordNodes.def +++ b/lib/Serialization/DeclTypeRecordNodes.def @@ -190,6 +190,10 @@ OTHER(SELF_PROTOCOL_CONFORMANCE, 251) OTHER(XREF_OPAQUE_RETURN_TYPE_PATH_PIECE, 252) +// SWIFT_ENABLE_TENSORFLOW +OTHER(DERIVATIVE_FUNCTION_CONFIGURATION, 253) +// SWIFT_ENABLE_TENSORFLOW END + #undef RECORD #undef DECLTYPERECORDNODES_HAS_RECORD_VAL #undef RECORD_VAL diff --git a/lib/Serialization/ModuleFile.cpp b/lib/Serialization/ModuleFile.cpp index 710254f943a47..ba5ac444b2549 100644 --- a/lib/Serialization/ModuleFile.cpp +++ b/lib/Serialization/ModuleFile.cpp @@ -910,6 +910,66 @@ ModuleFile::readObjCMethodTable(ArrayRef fields, StringRef blobData) { base + sizeof(uint32_t), base)); } +/// Used to deserialize entries in the on-disk derivative function configuration +/// table. +class ModuleFile::DerivativeFunctionConfigTableInfo { +public: + using internal_key_type = StringRef; + using external_key_type = internal_key_type; + using data_type = SmallVector, 8>; + using hash_value_type = uint32_t; + using offset_type = unsigned; + + external_key_type GetExternalKey(internal_key_type ID) { return ID; } + + internal_key_type GetInternalKey(external_key_type ID) { return ID; } + + hash_value_type ComputeHash(internal_key_type key) { + return llvm::djbHash(key, SWIFTMODULE_HASH_SEED); + } + + static bool EqualKey(internal_key_type lhs, internal_key_type rhs) { + return lhs == rhs; + } + + static std::pair ReadKeyDataLength(const uint8_t *&data) { + unsigned keyLength = endian::readNext(data); + unsigned dataLength = endian::readNext(data); + return {keyLength, dataLength}; + } + + static internal_key_type ReadKey(const uint8_t *data, unsigned length) { + return StringRef(reinterpret_cast(data), length); + } + + static data_type ReadData(internal_key_type key, const uint8_t *data, + unsigned length) { + data_type result; + const uint8_t *limit = data + length; + while (data < limit) { + DeclID genSigId = endian::readNext(data); + int32_t nameLength = endian::readNext(data); + StringRef mangledName(reinterpret_cast(data), nameLength); + data += nameLength; + result.push_back({mangledName, genSigId}); + } + return result; + } +}; + +std::unique_ptr +ModuleFile::readDerivativeFunctionConfigTable(ArrayRef fields, + StringRef blobData) { + uint32_t tableOffset; + index_block::DerivativeFunctionConfigTableLayout::readRecord(fields, + tableOffset); + auto base = reinterpret_cast(blobData.data()); + + using OwnedTable = std::unique_ptr; + return OwnedTable(SerializedDerivativeFunctionConfigTable::Create( + base + tableOffset, base + sizeof(uint32_t), base)); +} + bool ModuleFile::readIndexBlock(llvm::BitstreamCursor &cursor) { if (llvm::Error Err = cursor.EnterSubBlock(INDEX_BLOCK_ID)) { // FIXME this drops the error on the floor. @@ -1011,6 +1071,12 @@ bool ModuleFile::readIndexBlock(llvm::BitstreamCursor &cursor) { case index_block::OBJC_METHODS: ObjCMethods = readObjCMethodTable(scratch, blobData); break; + // SWIFT_ENABLE_TENSORFLOW + case index_block::DERIVATIVE_FUNCTION_CONFIGURATIONS: + DerivativeFunctionConfigurations = + readDerivativeFunctionConfigTable(scratch, blobData); + break; + // SWIFT_ENABLE_TENSORFLOW END case index_block::ENTRY_POINT: assert(blobData.empty()); setEntryPointClassID(scratch.front()); @@ -2372,6 +2438,36 @@ void ModuleFile::loadObjCMethods( } } +// SWIFT_ENABLE_TENSORFLOW +void ModuleFile::loadDerivativeFunctionConfigurations( + AbstractFunctionDecl *originalAFD, + llvm::SetVector &results) { + if (!DerivativeFunctionConfigurations) + return; + auto &ctx = originalAFD->getASTContext(); + Mangle::ASTMangler Mangler; + auto mangledName = Mangler.mangleDeclAsUSR(originalAFD, ""); + auto configs = DerivativeFunctionConfigurations->find(mangledName); + if (configs == DerivativeFunctionConfigurations->end()) + return; + for (auto entry : *configs) { + auto *parameterIndices = IndexSubset::getFromString(ctx, entry.first); + auto derivativeGenSigOrError = getGenericSignatureChecked(entry.second); + if (!derivativeGenSigOrError) { + if (!getContext().LangOpts.EnableDeserializationRecovery) + fatal(derivativeGenSigOrError.takeError()); + llvm::consumeError(derivativeGenSigOrError.takeError()); + } + auto derivativeGenSig = derivativeGenSigOrError.get(); + // NOTE(TF-1038): Result indices are currently unsupported in derivative + // registration attributes. In the meantime, always use `{0}` (wrt the + // first and only result). + auto resultIndices = IndexSubset::get(ctx, 1, {0}); + results.insert({parameterIndices, resultIndices, derivativeGenSig}); + } +} +// SWIFT_ENABLE_TENSORFLOW END + Optional> ModuleFile::loadNamedMembers(const IterableDeclContext *IDC, DeclBaseName N, uint64_t contextData) { diff --git a/lib/Serialization/ModuleFile.h b/lib/Serialization/ModuleFile.h index a18febcfdab1b..a2c6e8afa2d21 100644 --- a/lib/Serialization/ModuleFile.h +++ b/lib/Serialization/ModuleFile.h @@ -411,6 +411,14 @@ class ModuleFile llvm::OnDiskIterableChainedHashTable; std::unique_ptr DeclUSRsTable; + // SWIFT_ENABLE_TENSORFLOW + class DerivativeFunctionConfigTableInfo; + using SerializedDerivativeFunctionConfigTable = + llvm::OnDiskIterableChainedHashTable; + std::unique_ptr + DerivativeFunctionConfigurations; + // SWIFT_ENABLE_TENSORFLOW + /// A blob of 0 terminated string segments referenced in \c SourceLocsTextData StringRef SourceLocsTextData; @@ -540,6 +548,14 @@ class ModuleFile std::unique_ptr readDeclMembersTable(ArrayRef fields, StringRef blobData); + // SWIFT_ENABLE_TENSORFLOW + /// Read an on-disk derivative function configuration table stored in + /// index_block::DerivativeFunctionConfigTableLayout format. + std::unique_ptr + readDerivativeFunctionConfigTable(ArrayRef fields, + StringRef blobData); + // SWIFT_ENABLE_TENSORFLOW END + /// Reads the index block, which contains global tables. /// /// Returns false if there was an error. @@ -764,6 +780,14 @@ class ModuleFile bool isInstanceMethod, llvm::TinyPtrVector &methods); + // SWIFT_ENABLE_TENSORFLOW + /// Loads all derivative function configurations for the given + /// AbstractFunctionDecl. + void loadDerivativeFunctionConfigurations( + AbstractFunctionDecl *originalAFD, + llvm::SetVector &results); + // SWIFT_ENABLE_TENSORFLOW END + /// Reports all class members in the module to the given consumer. /// /// This is intended for use with id-style lookup and code completion. diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h index 34aaadc818edb..7a17db27a38ae 100644 --- a/lib/Serialization/ModuleFormat.h +++ b/lib/Serialization/ModuleFormat.h @@ -52,7 +52,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0; /// describe what change you made. The content of this comment isn't important; /// it just ensures a conflict if two people change the module format. /// Don't worry about adhering to the 80-column limit for this line. -const uint16_t SWIFTMODULE_VERSION_MINOR = 527; // tensorflow merge; function type differentiability +const uint16_t SWIFTMODULE_VERSION_MINOR = 528; // derivative function config table /// A standard hash seed used for all string hashes in a serialized module. /// @@ -1882,6 +1882,10 @@ namespace index_block { /// produce Objective-C methods. OBJC_METHODS, + // SWIFT_ENABLE_TENSORFLOW + DERIVATIVE_FUNCTION_CONFIGURATIONS, + // SWIFT_ENABLE_TENSORFLOW END + ENTRY_POINT, LOCAL_DECL_CONTEXT_OFFSETS, LOCAL_TYPE_DECLS, @@ -1945,6 +1949,14 @@ namespace index_block { BCBlob // map from member DeclBaseNames to offsets of DECL_MEMBERS records >; + // SWIFT_ENABLE_TENSORFLOW + using DerivativeFunctionConfigTableLayout = BCRecordLayout< + DERIVATIVE_FUNCTION_CONFIGURATIONS, // record ID + BCVBR<16>, // table offset within the blob (see below) + BCBlob // map from original declaration names to derivative configs + >; + // SWIFT_ENABLE_TENSORFLOW END + using EntryPointLayout = BCRecordLayout< ENTRY_POINT, DeclIDField // the ID of the main class; 0 if there was a main source file diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index ca55df9515c66..97dbd951e7b99 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -741,6 +741,9 @@ void Serializer::writeBlockInfoBlock() { BLOCK_RECORD(index_block, CLASS_MEMBERS_FOR_DYNAMIC_LOOKUP); BLOCK_RECORD(index_block, OPERATOR_METHODS); BLOCK_RECORD(index_block, OBJC_METHODS); + // SWIFT_ENABLE_TENSORFLOW + BLOCK_RECORD(index_block, DERIVATIVE_FUNCTION_CONFIGURATIONS); + // SWIFT_ENABLE_TENSORFLOW END BLOCK_RECORD(index_block, ENTRY_POINT); BLOCK_RECORD(index_block, LOCAL_DECL_CONTEXT_OFFSETS); BLOCK_RECORD(index_block, GENERIC_SIGNATURE_OFFSETS); @@ -4572,6 +4575,100 @@ static void writeObjCMethodTable(const index_block::ObjCMethodTableLayout &out, out.emit(scratch, tableOffset, hashTableBlob); } +// SWIFT_ENABLE_TENSORFLOW +namespace { + /// Used to serialize derivative function configurations. + class DerivativeFunctionConfigTableInfo { + public: + using key_type = std::string; + using key_type_ref = StringRef; + using data_type = Serializer::DerivativeFunctionConfigTableData; + using data_type_ref = const data_type &; + using hash_value_type = uint32_t; + using offset_type = unsigned; + + hash_value_type ComputeHash(key_type_ref key) { + assert(!key.empty()); + return llvm::djbHash(key, SWIFTMODULE_HASH_SEED); + } + + std::pair EmitKeyDataLength(raw_ostream &out, + key_type_ref key, + data_type_ref data) { + uint32_t keyLength = key.str().size(); + assert(keyLength == static_cast(keyLength)); + uint32_t dataLength = (sizeof(uint32_t) * 2) * data.size(); + for (auto entry : data) + dataLength += entry.first.size(); + assert(dataLength == static_cast(dataLength)); + endian::Writer writer(out, little); + writer.write(keyLength); + writer.write(dataLength); + return { keyLength, dataLength }; + } + + void EmitKey(raw_ostream &out, key_type_ref key, unsigned len) { + out << key; + } + + void EmitData(raw_ostream &out, key_type_ref key, data_type_ref data, + unsigned len) { + static_assert(declIDFitsIn32Bits(), "DeclID too large"); + endian::Writer writer(out, little); + for (auto &entry : data) { + // Write `GenericSignatureID`. + writer.write(entry.second); + // Write parameter indices string size, followed by data. + writer.write(entry.first.size()); + out << entry.first; + } + } + }; +} // end anonymous namespace + +static void writeDerivativeFunctionConfigs( + Serializer &S, const index_block::DerivativeFunctionConfigTableLayout &out, + Serializer::DerivativeFunctionConfigTable &derivativeConfigs) { + // Create the on-disk hash table. + llvm::OnDiskChainedHashTableGenerator + generator; + llvm::SmallString<32> hashTableBlob; + uint32_t tableOffset; + { + llvm::raw_svector_ostream blobStream(hashTableBlob); + for (auto &entry : derivativeConfigs) + generator.insert(entry.first.get(), entry.second); + // Make sure that no bucket is at offset 0. + endian::write(blobStream, 0, little); + tableOffset = generator.Emit(blobStream); + } + SmallVector scratch; + out.emit(scratch, tableOffset, hashTableBlob); +} + +// Records derivative function configurations for the given AbstractFunctionDecl +// by visiting `@differentiable` and `@derivative` attributes. +static void recordDerivativeFunctionConfig( + Serializer &S, const AbstractFunctionDecl *AFD, + Serializer::UniquedDerivativeFunctionConfigTable &derivativeConfigs) { + auto &ctx = AFD->getASTContext(); + Mangle::ASTMangler Mangler; + for (auto *attr : AFD->getAttrs().getAttributes()) { + auto mangledName = ctx.getIdentifier(Mangler.mangleDeclAsUSR(AFD, "")); + derivativeConfigs[mangledName].insert( + {ctx.getIdentifier(attr->getParameterIndices()->getString()), + attr->getDerivativeGenericSignature()}); + } + for (auto *attr : AFD->getAttrs().getAttributes()) { + auto *origAFD = attr->getOriginalFunction(); + auto mangledName = ctx.getIdentifier(Mangler.mangleDeclAsUSR(origAFD, "")); + derivativeConfigs[mangledName].insert( + {ctx.getIdentifier(attr->getParameterIndices()->getString()), + AFD->getGenericSignature()}); + } +}; +// SWIFT_ENABLE_TENSORFLOW END + /// Recursively walks the members and derived global decls of any nominal types /// to build up global tables. template @@ -4581,6 +4678,9 @@ static void collectInterestingNestedDeclarations( Serializer::DeclTable &operatorMethodDecls, Serializer::ObjCMethodTable &objcMethods, Serializer::NestedTypeDeclsTable &nestedTypeDecls, + // SWIFT_ENABLE_TENSORFLOW + Serializer::UniquedDerivativeFunctionConfigTable &derivativeConfigs, + // SWIFT_ENABLE_TENSORFLOW END bool isLocal = false) { const NominalTypeDecl *nominalParent = nullptr; @@ -4617,14 +4717,21 @@ static void collectInterestingNestedDeclarations( } } - // Record Objective-C methods. - if (auto *func = dyn_cast(member)) + // SWIFT_ENABLE_TENSORFLOW + // Record Objective-C methods and derivative function configurations. + if (auto *func = dyn_cast(member)) { recordObjCMethod(func); + recordDerivativeFunctionConfig(S, func, derivativeConfigs); + } + // SWIFT_ENABLE_TENSORFLOW END // Handle accessors. if (auto storage = dyn_cast(member)) { for (auto *accessor : storage->getAllAccessors()) { recordObjCMethod(accessor); + // SWIFT_ENABLE_TENSORFLOW + recordDerivativeFunctionConfig(S, accessor, derivativeConfigs); + // SWIFT_ENABLE_TENSORFLOW END } } @@ -4647,6 +4754,9 @@ static void collectInterestingNestedDeclarations( collectInterestingNestedDeclarations(S, iterable->getMembers(), operatorMethodDecls, objcMethods, nestedTypeDecls, + // SWIFT_ENABLE_TENSORFLOW + derivativeConfigs, + // SWIFT_ENABLE_TENSORFLOW END isLocal); } } @@ -4660,6 +4770,9 @@ void Serializer::writeAST(ModuleOrSourceFile DC, NestedTypeDeclsTable nestedTypeDecls; LocalTypeHashTableGenerator localTypeGenerator, opaqueReturnTypeGenerator; ExtensionTable extensionDecls; + // SWIFT_ENABLE_TENSORFLOW + UniquedDerivativeFunctionConfigTable uniquedDerivativeConfigs; + // SWIFT_ENABLE_TENSORFLOW END bool hasLocalTypes = false; bool hasOpaqueReturnTypes = false; @@ -4708,6 +4821,10 @@ void Serializer::writeAST(ModuleOrSourceFile DC, } else { llvm_unreachable("all top-level declaration kinds accounted for"); } + // SWIFT_ENABLE_TENSORFLOW + if (auto *AFD = dyn_cast(D)) + recordDerivativeFunctionConfig(*this, AFD, uniquedDerivativeConfigs); + // SWIFT_ENABLE_TENSORFLOW END orderedTopLevelDecls.push_back(addDeclRef(D)); @@ -4717,7 +4834,10 @@ void Serializer::writeAST(ModuleOrSourceFile DC, if (auto IDC = dyn_cast(D)) { collectInterestingNestedDeclarations(*this, IDC->getMembers(), operatorMethodDecls, objcMethods, - nestedTypeDecls); + // SWIFT_ENABLE_TENSORFLOW + nestedTypeDecls, + uniquedDerivativeConfigs); + // SWIFT_ENABLE_TENSORFLOW END } } @@ -4746,7 +4866,11 @@ void Serializer::writeAST(ModuleOrSourceFile DC, if (auto IDC = dyn_cast(TD)) { collectInterestingNestedDeclarations(*this, IDC->getMembers(), operatorMethodDecls, objcMethods, - nestedTypeDecls, /*isLocal=*/true); + // SWIFT_ENABLE_TENSORFLOW + nestedTypeDecls, + uniquedDerivativeConfigs, + /*isLocal=*/true); + // SWIFT_ENABLE_TENSORFLOW END } } @@ -4809,6 +4933,22 @@ void Serializer::writeAST(ModuleOrSourceFile DC, writeNestedTypeDeclsTable(NestedTypeDeclsTable, nestedTypeDecls); } + // SWIFT_ENABLE_TENSORFLOW + // Convert uniqued derivative function config table to serialization- + // ready format: turn `GenericSignature` to `GenericSignatureID`. + DerivativeFunctionConfigTable derivativeConfigs; + for (auto entry : uniquedDerivativeConfigs) { + for (auto config : entry.second) { + auto paramIndices = config.first.str(); + auto genSigID = addGenericSignatureRef(config.second); + derivativeConfigs[entry.first].push_back({paramIndices, genSigID}); + } + } + index_block::DerivativeFunctionConfigTableLayout DerivativeConfigTable(Out); + writeDerivativeFunctionConfigs(*this, DerivativeConfigTable, + derivativeConfigs); + // SWIFT_ENABLE_TENSORFLOW END + if (entryPointClassID.hasValue()) { index_block::EntryPointLayout EntryPoint(Out); EntryPoint.emit(ScratchRecord, entryPointClassID.getValue()); diff --git a/lib/Serialization/Serialization.h b/lib/Serialization/Serialization.h index 3e66e0ae4877a..3df89dc985229 100644 --- a/lib/Serialization/Serialization.h +++ b/lib/Serialization/Serialization.h @@ -246,6 +246,25 @@ class Serializer : public SerializerBase { SmallVector, 4>; using ExtensionTable = llvm::MapVector; + // SWIFT_ENABLE_TENSORFLOW + using DerivativeFunctionConfigTableData = + llvm::SmallVector, 4>; + // In-memory representation of what will eventually be an on-disk hash table + // mapping original declaration USRs to derivative function configurations. + using DerivativeFunctionConfigTable = + llvm::MapVector; + // Uniqued mapping from original declarations USRs to derivative function + // configurations. + // Note: this exists because `GenericSignature` can be used as a `DenseMap` + // key, while `GenericSignatureID` cannot + // (`DenseMapInfo::getEmptyKey()` crashes). To work + // around this, a `UniquedDerivativeFunctionConfigTable` is first + // constructed, and then converted to a `DerivativeFunctionConfigTableData`. + using UniquedDerivativeFunctionConfigTable = llvm::MapVector< + Identifier, + llvm::SmallSetVector, 4>>; + // SWIFT_ENABLE_TENSORFLOW END + private: /// A map from identifiers to methods and properties with the given name. /// diff --git a/lib/Serialization/SerializedModuleLoader.cpp b/lib/Serialization/SerializedModuleLoader.cpp index 3e89e71082f04..b79c59d474a3c 100644 --- a/lib/Serialization/SerializedModuleLoader.cpp +++ b/lib/Serialization/SerializedModuleLoader.cpp @@ -954,6 +954,19 @@ void SerializedModuleLoaderBase::loadObjCMethods( } } +// SWIFT_ENABLE_TENSORFLOW +void SerializedModuleLoaderBase::loadDerivativeFunctionConfigurations( + AbstractFunctionDecl *originalAFD, unsigned int previousGeneration, + llvm::SetVector &results) { + for (auto &modulePair : LoadedModuleFiles) { + if (modulePair.second <= previousGeneration) + continue; + modulePair.first->loadDerivativeFunctionConfigurations(originalAFD, + results); + } +} +// SWIFT_ENABLE_TENSORFLOW END + std::error_code MemoryBufferSerializedModuleLoader::findModuleFilesInDirectory( AccessPathElem ModuleID, StringRef DirPath, StringRef ModuleFilename, StringRef ModuleDocFilename, StringRef ModuleSourceInfoFilename,