diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 3bcdc5dc0852e..096831f944a6c 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -32,6 +32,7 @@ namespace swift { +class AbstractFunctionDecl; class AnyFunctionType; class SourceFile; class SILFunctionType; @@ -398,9 +399,6 @@ class DerivativeFunctionTypeError enum class Kind { /// Original function type has no semantic results. NoSemanticResults, - /// Original function type has multiple semantic results. - // TODO(TF-1250): Support function types with multiple semantic results. - MultipleSemanticResults, /// Differentiability parmeter indices are empty. NoDifferentiabilityParameters, /// A differentiability parameter does not conform to `Differentiable`. @@ -429,7 +427,6 @@ class DerivativeFunctionTypeError explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind) : functionType(functionType), kind(kind), value(Value()) { assert(kind == Kind::NoSemanticResults || - kind == Kind::MultipleSemanticResults || kind == Kind::NoDifferentiabilityParameters); }; @@ -579,6 +576,10 @@ void getFunctionSemanticResultTypes( SmallVectorImpl &result, GenericEnvironment *genericEnv = nullptr); +/// Returns the indices of all semantic results for a given function. +IndexSubset *getAllFunctionSemanticResultIndices( + const AbstractFunctionDecl *AFD); + /// Returns the lowered SIL parameter indices for the given AST parameter /// indices and `AnyfunctionType`. /// diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 4c372593c3e96..cb4256ca2d7fc 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -3496,9 +3496,6 @@ NOTE(autodiff_attr_original_decl_not_same_type_context,none, (DescriptiveDeclKind)) ERROR(autodiff_attr_original_void_result,none, "cannot differentiate void function %0", (DeclName)) -ERROR(autodiff_attr_original_multiple_semantic_results,none, - "cannot differentiate functions with both an 'inout' parameter and a " - "result", ()) ERROR(autodiff_attr_result_not_differentiable,none, "can only differentiate functions with results that conform to " "'Differentiable', but %0 does not conform to 'Differentiable'", (Type)) diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 60a49b43e304f..a517854234eb8 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -196,8 +196,16 @@ void autodiff::getFunctionSemanticResultTypes( functionType->getResult()->getAs()) { formalResultType = resultFunctionType->getResult(); } - if (!formalResultType->isEqual(ctx.TheEmptyTupleType)) - result.push_back({remap(formalResultType), /*isInout*/ false}); + if (!formalResultType->isEqual(ctx.TheEmptyTupleType)) { + // Separate tuple elements into individual results. + if (formalResultType->is()) { + for (auto elt : formalResultType->castTo()->getElements()) { + result.push_back({remap(elt.getType()), /*isInout*/ false}); + } + } else { + result.push_back({remap(formalResultType), /*isInout*/ false}); + } + } // Collect `inout` parameters as semantic results. for (auto param : functionType->getParams()) @@ -211,6 +219,16 @@ void autodiff::getFunctionSemanticResultTypes( } } +IndexSubset * +autodiff::getAllFunctionSemanticResultIndices(const AbstractFunctionDecl *AFD) { + auto originalFn = AFD->getInterfaceType()->castTo(); + SmallVector semanticResults; + autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults); + auto numResults = semanticResults.size(); + return IndexSubset::getDefault( + AFD->getASTContext(), numResults, /*includeAll*/ true); +} + // TODO(TF-874): Simplify this helper. See TF-874 for WIP. IndexSubset * autodiff::getLoweredParameterIndices(IndexSubset *parameterIndices, @@ -395,9 +413,6 @@ void DerivativeFunctionTypeError::log(raw_ostream &OS) const { case Kind::NoSemanticResults: OS << "has no semantic results ('Void' result)"; break; - case Kind::MultipleSemanticResults: - OS << "has multiple semantic results"; - break; case Kind::NoDifferentiabilityParameters: OS << "has no differentiability parameters"; break; diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index 3a869beab1a2c..22dedc0cc4cb2 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -6432,31 +6432,86 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( getSubsetParameters(parameterIndices, diffParams, /*reverseCurryLevels*/ !makeSelfParamFirst); - // Get the original semantic result type. + // Get the original non-inout semantic result types. SmallVector originalResults; autodiff::getFunctionSemanticResultTypes(this, originalResults); // Error if no original semantic results. if (originalResults.empty()) return llvm::make_error( this, DerivativeFunctionTypeError::Kind::NoSemanticResults); - // Error if multiple original semantic results. - // TODO(TF-1250): Support functions with multiple semantic results. - if (originalResults.size() > 1) - return llvm::make_error( - this, DerivativeFunctionTypeError::Kind::MultipleSemanticResults); - auto originalResult = originalResults.front(); - auto originalResultType = originalResult.type; - - // Get the original semantic result type's `TangentVector` associated type. - auto resultTan = - originalResultType->getAutoDiffTangentSpace(lookupConformance); - // Error if original semantic result has no tangent space. - if (!resultTan) { + // Accumulate non-inout result tangent spaces. + SmallVector resultTanTypes; + bool hasInoutResult = false; + for (auto i : range(originalResults.size())) { + auto originalResult = originalResults[i]; + auto originalResultType = originalResult.type; + // Voids currently have a defined tangent vector, so ignore them. + if (originalResultType->isVoid()) + continue; + if (originalResult.isInout) { + hasInoutResult = true; + continue; + } + // Get the original semantic result type's `TangentVector` associated type. + auto resultTan = + originalResultType->getAutoDiffTangentSpace(lookupConformance); + if (!resultTan) + continue; + auto resultTanType = resultTan->getType(); + resultTanTypes.push_back(resultTanType); + } + // Append non-wrt inout result tangent spaces. + // This uses the logic from getSubsetParameters(), only operating over all + // parameter indices and looking for non-wrt indices. + SmallVector curryLevels; + // An inlined version of unwrapCurryLevels(). + AnyFunctionType *fnTy = this; + while (fnTy != nullptr) { + curryLevels.push_back(fnTy); + fnTy = fnTy->getResult()->getAs(); + } + + SmallVector curryLevelParameterIndexOffsets(curryLevels.size()); + unsigned currentOffset = 0; + for (unsigned curryLevelIndex : llvm::reverse(indices(curryLevels))) { + curryLevelParameterIndexOffsets[curryLevelIndex] = currentOffset; + currentOffset += curryLevels[curryLevelIndex]->getNumParams(); + } + + if (!makeSelfParamFirst) { + std::reverse(curryLevels.begin(), curryLevels.end()); + std::reverse(curryLevelParameterIndexOffsets.begin(), + curryLevelParameterIndexOffsets.end()); + } + + for (unsigned curryLevelIndex : indices(curryLevels)) { + auto *curryLevel = curryLevels[curryLevelIndex]; + unsigned parameterIndexOffset = + curryLevelParameterIndexOffsets[curryLevelIndex]; + for (unsigned paramIndex : range(curryLevel->getNumParams())) { + if (parameterIndices->contains(parameterIndexOffset + paramIndex)) + continue; + + auto param = curryLevel->getParams()[paramIndex]; + if (param.isInOut()) { + auto resultType = param.getPlainType(); + if (resultType->isVoid()) + continue; + auto resultTan = resultType->getAutoDiffTangentSpace(lookupConformance); + if (!resultTan) + continue; + auto resultTanType = resultTan->getType(); + resultTanTypes.push_back(resultTanType); + } + } + } + + // Error if no semantic result has a tangent space. + if (resultTanTypes.empty() && !hasInoutResult) { return llvm::make_error( this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult, - std::make_pair(originalResultType, /*index*/ 0)); + std::make_pair(originalResults.front().type, /*index*/ 0)); } - auto resultTanType = resultTan->getType(); // Compute the result linear map function type. FunctionType *linearMapType; @@ -6472,11 +6527,10 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( // - Original: `(T0, inout T1, ...) -> Void` // - Differential: `(T0.Tan, ...) -> T1.Tan` // - // Case 3: original function has a wrt `inout` parameter. - // - Original: `(T0, inout T1, ...) -> Void` - // - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void` + // Case 3: original function has wrt `inout` parameters. + // - Original: `(T0, inout T1, ...) -> R` + // - Differential: `(T0.Tan, inout T1.Tan, ...) -> R.Tan` SmallVector differentialParams; - bool hasInoutDiffParameter = false; for (auto i : range(diffParams.size())) { auto diffParam = diffParams[i]; auto paramType = diffParam.getPlainType(); @@ -6491,11 +6545,22 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( } differentialParams.push_back(AnyFunctionType::Param( paramTan->getType(), Identifier(), diffParam.getParameterFlags())); - if (diffParam.isInOut()) - hasInoutDiffParameter = true; } - auto differentialResult = - hasInoutDiffParameter ? Type(ctx.TheEmptyTupleType) : resultTanType; + Type differentialResult; + if (resultTanTypes.empty()) { + differentialResult = ctx.TheEmptyTupleType; + } else if (resultTanTypes.size() == 1) { + differentialResult = resultTanTypes.front(); + } else { + SmallVector differentialResults; + for (auto i : range(resultTanTypes.size())) { + auto resultTanType = resultTanTypes[i]; + differentialResults.push_back( + TupleTypeElt(resultTanType, Identifier())); + } + differentialResult = TupleType::get(differentialResults, ctx); + } + // FIXME: Verify ExtInfo state is correct, not working by accident. FunctionType::ExtInfo info; linearMapType = @@ -6513,11 +6578,11 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( // - Original: `(T0, inout T1, ...) -> Void` // - Pullback: `(T1.Tan) -> (T0.Tan, ...)` // - // Case 3: original function has a wrt `inout` parameter. - // - Original: `(T0, inout T1, ...) -> Void` - // - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)` + // Case 3: original function has wrt `inout` parameters. + // - Original: `(T0, inout T1, ...) -> R` + // - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)` SmallVector pullbackResults; - bool hasInoutDiffParameter = false; + SmallVector inoutParams; for (auto i : range(diffParams.size())) { auto diffParam = diffParams[i]; auto paramType = diffParam.getPlainType(); @@ -6531,7 +6596,9 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( std::make_pair(paramType, i)); } if (diffParam.isInOut()) { - hasInoutDiffParameter = true; + if (paramType->isVoid()) + continue; + inoutParams.push_back(diffParam); continue; } pullbackResults.emplace_back(paramTan->getType()); @@ -6544,12 +6611,27 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( } else { pullbackResult = TupleType::get(pullbackResults, ctx); } - auto flags = ParameterTypeFlags().withInOut(hasInoutDiffParameter); - auto pullbackParam = - AnyFunctionType::Param(resultTanType, Identifier(), flags); + // First accumulate non-inout results as pullback parameters. + SmallVector pullbackParams; + for (auto i : range(resultTanTypes.size())) { + auto resultTanType = resultTanTypes[i]; + auto flags = ParameterTypeFlags().withInOut(false); + pullbackParams.push_back(AnyFunctionType::Param( + resultTanType, Identifier(), flags)); + } + // Then append inout parameters. + for (auto i : range(inoutParams.size())) { + auto inoutParam = inoutParams[i]; + auto inoutParamType = inoutParam.getPlainType(); + auto inoutParamTan = + inoutParamType->getAutoDiffTangentSpace(lookupConformance); + auto flags = ParameterTypeFlags().withInOut(true); + pullbackParams.push_back(AnyFunctionType::Param( + inoutParamTan->getType(), Identifier(), flags)); + } // FIXME: Verify ExtInfo state is correct, not working by accident. FunctionType::ExtInfo info; - linearMapType = FunctionType::get({pullbackParam}, pullbackResult, info); + linearMapType = FunctionType::get(pullbackParams, pullbackResult, info); break; } } diff --git a/lib/IRGen/IRGenMangler.h b/lib/IRGen/IRGenMangler.h index 28faba3baa527..f349ec805281a 100644 --- a/lib/IRGen/IRGenMangler.h +++ b/lib/IRGen/IRGenMangler.h @@ -57,9 +57,11 @@ class IRGenMangler : public Mangle::ASTMangler { AutoDiffDerivativeFunctionIdentifier *derivativeId) { beginManglingWithAutoDiffOriginalFunction(func); auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind()); + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(func); AutoDiffConfig config( derivativeId->getParameterIndices(), - IndexSubset::get(func->getASTContext(), 1, {0}), + resultIndices, derivativeId->getDerivativeGenericSignature()); appendAutoDiffFunctionParts("TJ", kind, config); appendOperator("Tj"); @@ -86,9 +88,11 @@ class IRGenMangler : public Mangle::ASTMangler { AutoDiffDerivativeFunctionIdentifier *derivativeId) { beginManglingWithAutoDiffOriginalFunction(func); auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind()); + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(func); AutoDiffConfig config( derivativeId->getParameterIndices(), - IndexSubset::get(func->getASTContext(), 1, {0}), + resultIndices, derivativeId->getDerivativeGenericSignature()); appendAutoDiffFunctionParts("TJ", kind, config); appendOperator("Tq"); diff --git a/lib/SIL/IR/SILDeclRef.cpp b/lib/SIL/IR/SILDeclRef.cpp index 1907913f9528f..f4640bdc14622 100644 --- a/lib/SIL/IR/SILDeclRef.cpp +++ b/lib/SIL/IR/SILDeclRef.cpp @@ -853,7 +853,8 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const { auto *silParameterIndices = autodiff::getLoweredParameterIndices( derivativeFunctionIdentifier->getParameterIndices(), getDecl()->getInterfaceType()->castTo()); - auto *resultIndices = IndexSubset::get(getDecl()->getASTContext(), 1, {0}); + auto *resultIndices = autodiff::getAllFunctionSemanticResultIndices( + asAutoDiffOriginalFunction().getAbstractFunctionDecl()); AutoDiffConfig silConfig( silParameterIndices, resultIndices, derivativeFunctionIdentifier->getDerivativeGenericSignature()); diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index 38412802f9c9d..c39c7ab59eaeb 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -238,8 +238,6 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() { resultIndices.push_back(resultAndIndex.index()); // Check `inout` parameters. for (auto inoutParamAndIndex : enumerate(getIndirectMutatingParameters())) - // FIXME(TF-1305): The `getResults().empty()` condition is a hack. - // // Currently, an `inout` parameter can either be: // 1. Both a differentiability parameter and a differentiability result. // 2. `@noDerivative`: neither a differentiability parameter nor a @@ -251,13 +249,8 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() { // cases, so supporting it is a non-goal. // // See TF-1305 for solution ideas. For now, `@noDerivative` `inout` - // parameters are not treated as differentiability results, unless the - // original function has no formal results, in which case all `inout` // parameters are treated as differentiability results. - if (getResults().empty() || - inoutParamAndIndex.value().getDifferentiability() != - SILParameterDifferentiability::NotDifferentiable) - resultIndices.push_back(getNumResults() + inoutParamAndIndex.index()); + resultIndices.push_back(getNumResults() + inoutParamAndIndex.index()); auto numSemanticResults = getNumResults() + getNumIndirectMutatingParameters(); return IndexSubset::get(getASTContext(), numSemanticResults, resultIndices); @@ -574,7 +567,7 @@ static CanSILFunctionType getAutoDiffDifferentialType( differentialResults.push_back({resultTanType, resultConv}); continue; } - // Handle original `inout` parameter. + // Handle original `inout` parameters. auto inoutParamIndex = resultIndex - originalFnTy->getNumResults(); auto inoutParamIt = std::next( originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex); @@ -709,7 +702,7 @@ static CanSILFunctionType getAutoDiffPullbackType( pullbackParams.push_back({resultTanType, paramConv}); continue; } - // Handle original `inout` parameter. + // Handle `inout` parameters. auto inoutParamIndex = resultIndex - originalFnTy->getNumResults(); auto inoutParamIt = std::next( originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex); diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index b5c3a8ec9c9ce..1cf60c3ab8fe7 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -1252,11 +1252,12 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction( auto *AFD = constant.getAbstractFunctionDecl(); auto emitWitnesses = [&](DeclAttributes &Attrs) { for (auto *diffAttr : Attrs.getAttributes()) { - auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0}); assert((!F->getLoweredFunctionType()->getSubstGenericSignature() || diffAttr->getDerivativeGenericSignature()) && "Type-checking should resolve derivative generic signatures for " "all original SIL functions with generic signatures"); + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(AFD); auto witnessGenSig = autodiff::getDifferentiabilityWitnessGenericSignature( AFD->getGenericSignature(), @@ -1285,7 +1286,8 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction( auto witnessGenSig = autodiff::getDifferentiabilityWitnessGenericSignature( origAFD->getGenericSignature(), AFD->getGenericSignature()); - auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0}); + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(origAFD); AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices, witnessGenSig); emitDifferentiabilityWitness(origAFD, origFn, diff --git a/lib/SILGen/SILGenThunk.cpp b/lib/SILGen/SILGenThunk.cpp index 5476dc755b695..59d0c865d2478 100644 --- a/lib/SILGen/SILGenThunk.cpp +++ b/lib/SILGen/SILGenThunk.cpp @@ -547,11 +547,13 @@ SILFunction *SILGenModule::getOrCreateDerivativeVTableThunk( SILGenFunctionBuilder builder(*this); auto originalFnDeclRef = derivativeFnDeclRef.asAutoDiffOriginalFunction(); Mangle::ASTMangler mangler; + auto *resultIndices = autodiff::getAllFunctionSemanticResultIndices( + originalFnDeclRef.getAbstractFunctionDecl()); auto name = mangler.mangleAutoDiffDerivativeFunction( originalFnDeclRef.getAbstractFunctionDecl(), derivativeId->getKind(), AutoDiffConfig(derivativeId->getParameterIndices(), - IndexSubset::get(getASTContext(), 1, {0}), + resultIndices, derivativeId->getDerivativeGenericSignature()), /*isVTableThunk*/ true); auto *thunk = builder.getOrCreateFunction( @@ -571,7 +573,8 @@ SILFunction *SILGenModule::getOrCreateDerivativeVTableThunk( auto *loweredParamIndices = autodiff::getLoweredParameterIndices( derivativeId->getParameterIndices(), derivativeFnDecl->getInterfaceType()->castTo()); - auto *loweredResultIndices = IndexSubset::get(getASTContext(), 1, {0}); + auto *loweredResultIndices = autodiff::getAllFunctionSemanticResultIndices( + originalFnDeclRef.getAbstractFunctionDecl()); auto diffFn = SGF.B.createDifferentiableFunction( loc, loweredParamIndices, loweredResultIndices, originalFn); auto derivativeFn = SGF.B.createDifferentiableFunctionExtract( diff --git a/lib/SILOptimizer/Differentiation/Common.cpp b/lib/SILOptimizer/Differentiation/Common.cpp index 33e0a3fa6f067..1e640fe2cd053 100644 --- a/lib/SILOptimizer/Differentiation/Common.cpp +++ b/lib/SILOptimizer/Differentiation/Common.cpp @@ -232,7 +232,8 @@ void collectMinimalIndicesForFunctionCall( auto ¶m = paramAndIdx.value(); if (!param.isIndirectMutating()) continue; - unsigned idx = paramAndIdx.index(); + unsigned idx = + paramAndIdx.index() + calleeFnTy->getNumIndirectFormalResults(); auto inoutArg = ai->getArgument(idx); results.push_back(inoutArg); resultIndices.push_back(inoutParamResultIndex++); @@ -492,10 +493,6 @@ findMinimalDerivativeConfiguration(AbstractFunctionDecl *original, SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness( SILModule &module, SILFunction *original, DifferentiabilityKind kind, IndexSubset *parameterIndices, IndexSubset *resultIndices) { - // AST differentiability witnesses always have a single result. - if (resultIndices->getCapacity() != 1 || !resultIndices->contains(0)) - return nullptr; - // Explicit differentiability witnesses only exist on SIL functions that come // from AST functions. auto *originalAFD = findAbstractFunctionDecl(original); diff --git a/lib/SILOptimizer/Differentiation/Thunk.cpp b/lib/SILOptimizer/Differentiation/Thunk.cpp index f1a94e4320acd..b44ba46db8d19 100644 --- a/lib/SILOptimizer/Differentiation/Thunk.cpp +++ b/lib/SILOptimizer/Differentiation/Thunk.cpp @@ -472,6 +472,12 @@ getOrCreateSubsetParametersThunkForLinearMap( return mappedIndex; }; + auto toIndirectResultsIter = thunk->getIndirectResults().begin(); + auto useNextIndirectResult = [&]() { + assert(toIndirectResultsIter != thunk->getIndirectResults().end()); + arguments.push_back(*toIndirectResultsIter++); + }; + switch (kind) { // Differential arguments are: // - All indirect results, followed by: @@ -480,9 +486,29 @@ getOrCreateSubsetParametersThunkForLinearMap( // indices). // - Zeros (when parameter is not in desired indices). case AutoDiffDerivativeFunctionKind::JVP: { - // Forward all indirect results. - arguments.append(thunk->getIndirectResults().begin(), - thunk->getIndirectResults().end()); + unsigned numIndirectResults = linearMapType->getNumIndirectFormalResults(); + // Forward desired indirect results. + for (unsigned idx : *actualConfig.resultIndices) { + if (idx >= numIndirectResults) + break; + + auto resultInfo = linearMapType->getResults()[idx]; + assert(idx < linearMapType->getNumResults()); + + // Forward result argument in case we do not need to thunk it away. + if (desiredConfig.resultIndices->contains(idx)) { + useNextIndirectResult(); + continue; + } + + // Otherwise, allocate and use an uninitialized indirect result. + auto *indirectResult = builder.createAllocStack( + loc, resultInfo.getSILStorageInterfaceType()); + localAllocations.push_back(indirectResult); + arguments.push_back(indirectResult); + } + assert(toIndirectResultsIter == thunk->getIndirectResults().end()); + auto toArgIter = thunk->getArgumentsWithoutIndirectResults().begin(); auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); }; // Iterate over actual indices. @@ -507,10 +533,6 @@ getOrCreateSubsetParametersThunkForLinearMap( // - Zeros (when parameter is not in desired indices). // - All actual arguments. case AutoDiffDerivativeFunctionKind::VJP: { - auto toIndirectResultsIter = thunk->getIndirectResults().begin(); - auto useNextIndirectResult = [&]() { - arguments.push_back(*toIndirectResultsIter++); - }; // Collect pullback arguments. unsigned pullbackResultIndex = 0; for (unsigned i : actualConfig.parameterIndices->getIndices()) { @@ -539,8 +561,18 @@ getOrCreateSubsetParametersThunkForLinearMap( arguments.push_back(indirectResult); } // Forward all actual non-indirect-result arguments. - arguments.append(thunk->getArgumentsWithoutIndirectResults().begin(), - thunk->getArgumentsWithoutIndirectResults().end() - 1); + auto thunkArgs = thunk->getArgumentsWithoutIndirectResults(); + // Slice out the function to be called. + thunkArgs = thunkArgs.slice(0, thunkArgs.size() - 1); + unsigned thunkArg = 0; + for (unsigned idx : *actualConfig.resultIndices) { + // Forward result argument in case we do not need to thunk it away. + if (desiredConfig.resultIndices->contains(idx)) + arguments.push_back(thunkArgs[thunkArg++]); + else { // Otherwise, zero it out. + buildZeroArgument(linearMapType->getParameters()[arguments.size()]); + } + } break; } } @@ -550,10 +582,33 @@ getOrCreateSubsetParametersThunkForLinearMap( auto *ai = builder.createApply(loc, linearMap, SubstitutionMap(), arguments); // If differential thunk, deallocate local allocations and directly return - // `apply` result. + // `apply` result (if it is desired). if (kind == AutoDiffDerivativeFunctionKind::JVP) { + SmallVector differentialDirectResults; + extractAllElements(ai, builder, differentialDirectResults); + SmallVector allResults; + collectAllActualResultsInTypeOrder(ai, differentialDirectResults, + allResults); + unsigned numResults = thunk->getConventions().getNumDirectSILResults() + + thunk->getConventions().getNumDirectSILResults(); + SmallVector results; + for (unsigned idx : *actualConfig.resultIndices) { + if (idx >= numResults) + break; + + auto result = allResults[idx]; + if (desiredConfig.isWrtResult(idx)) + results.push_back(result); + else { + if (result->getType().isAddress()) + builder.emitDestroyAddrAndFold(loc, result); + else + builder.emitDestroyValueOperation(loc, result); + } + } cleanupValues(); - builder.createReturn(loc, ai); + auto result = joinElements(results, builder, loc); + builder.createReturn(loc, result); return {thunk, interfaceSubs}; } @@ -769,8 +824,8 @@ getOrCreateSubsetParametersThunkForDerivativeFunction( /*withoutActuallyEscaping*/ false); } assert(origFnType->getNumResults() + - origFnType->getNumIndirectMutatingParameters() == - 1); + origFnType->getNumIndirectMutatingParameters() > + 0); if (origFnType->getNumResults() > 0 && origFnType->getResults().front().isFormalDirect()) { auto result = diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 8061f84362915..95f5155d55c6e 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -4909,12 +4909,6 @@ bool resolveDifferentiableAttrDifferentiabilityParameters( original->getName()) .highlight(original->getSourceRange()); return; - case DerivativeFunctionTypeError::Kind::MultipleSemanticResults: - diags - .diagnose(attr->getLocation(), - diag::autodiff_attr_original_multiple_semantic_results) - .highlight(original->getSourceRange()); - return; case DerivativeFunctionTypeError::Kind::NoDifferentiabilityParameters: diags.diagnose(attr->getLocation(), diag::diff_params_clause_no_inferred_parameters); @@ -5079,7 +5073,8 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate( } getterDecl->getAttrs().add(newAttr); // Register derivative function configuration. - auto *resultIndices = IndexSubset::get(ctx, 1, {0}); + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(getterDecl); getterDecl->addDerivativeFunctionConfiguration( {resolvedDiffParamIndices, resultIndices, derivativeGenSig}); return resolvedDiffParamIndices; @@ -5094,7 +5089,11 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate( return nullptr; } // Register derivative function configuration. - auto *resultIndices = IndexSubset::get(ctx, 1, {0}); + SmallVector semanticResults; + autodiff::getFunctionSemanticResultTypes(originalFnRemappedTy, semanticResults); + auto numResults = semanticResults.size(); + auto *resultIndices = IndexSubset::getDefault( + ctx, numResults, /*includeAll*/ true); original->addDerivativeFunctionConfiguration( {resolvedDiffParamIndices, resultIndices, derivativeGenSig}); return resolvedDiffParamIndices; @@ -5427,12 +5426,6 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) { originalAFD->getName()) .highlight(attr->getOriginalFunctionName().Loc.getSourceRange()); return; - case DerivativeFunctionTypeError::Kind::MultipleSemanticResults: - diags - .diagnose(attr->getLocation(), - diag::autodiff_attr_original_multiple_semantic_results) - .highlight(attr->getOriginalFunctionName().Loc.getSourceRange()); - return; case DerivativeFunctionTypeError::Kind::NoDifferentiabilityParameters: diags.diagnose(attr->getLocation(), diag::diff_params_clause_no_inferred_parameters); @@ -5522,7 +5515,8 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) { } // Register derivative function configuration. - auto *resultIndices = IndexSubset::get(Ctx, 1, {0}); + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(originalAFD); originalAFD->addDerivativeFunctionConfiguration( {resolvedDiffParamIndices, resultIndices, derivative->getGenericSignature()}); diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index cc499e24aa585..fafa782c6bed1 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -490,7 +490,8 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req, witness->getAttrs().add(newAttr); success = true; // Register derivative function configuration. - auto *resultIndices = IndexSubset::get(ctx, 1, {0}); + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(witnessAFD); witnessAFD->addDerivativeFunctionConfiguration( {newAttr->getParameterIndices(), resultIndices, newAttr->getDerivativeGenericSignature()}); diff --git a/lib/Serialization/ModuleFile.cpp b/lib/Serialization/ModuleFile.cpp index 642fe6a2a21d0..ae9a3f4042923 100644 --- a/lib/Serialization/ModuleFile.cpp +++ b/lib/Serialization/ModuleFile.cpp @@ -667,9 +667,9 @@ void ModuleFile::loadDerivativeFunctionConfigurations( } auto derivativeGenSig = derivativeGenSigOrError.get(); // NOTE(TF-1038): Result indices are currently unsupported in derivative - // registration attributes. In the meantime, always use `{0}` (wrt the - // first and only result). - auto resultIndices = IndexSubset::get(ctx, 1, {0}); + // registration attributes. In the meantime, always use all results. + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(originalAFD); results.insert({parameterIndices, resultIndices, derivativeGenSig}); } } diff --git a/lib/TBDGen/TBDGen.cpp b/lib/TBDGen/TBDGen.cpp index afd7fd44559cb..c0989007011cc 100644 --- a/lib/TBDGen/TBDGen.cpp +++ b/lib/TBDGen/TBDGen.cpp @@ -730,22 +730,27 @@ void TBDGenVisitor::visitAbstractFunctionDecl(AbstractFunctionDecl *AFD) { // Add derivative function symbols. for (const auto *differentiableAttr : - AFD->getAttrs().getAttributes()) + AFD->getAttrs().getAttributes()) { + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(AFD); addDerivativeConfiguration( differentiableAttr->getDifferentiabilityKind(), AFD, AutoDiffConfig(differentiableAttr->getParameterIndices(), - IndexSubset::get(AFD->getASTContext(), 1, {0}), + resultIndices, differentiableAttr->getDerivativeGenericSignature())); + } for (const auto *derivativeAttr : - AFD->getAttrs().getAttributes()) + AFD->getAttrs().getAttributes()) { + auto *resultIndices = autodiff::getAllFunctionSemanticResultIndices( + derivativeAttr->getOriginalFunction(AFD->getASTContext())); addDerivativeConfiguration( DifferentiabilityKind::Reverse, derivativeAttr->getOriginalFunction(AFD->getASTContext()), AutoDiffConfig(derivativeAttr->getParameterIndices(), - IndexSubset::get(AFD->getASTContext(), 1, {0}), + resultIndices, AFD->getGenericSignature())); - + } visitDefaultArguments(AFD, AFD->getParameters()); if (AFD->hasAsync()) { diff --git a/test/AutoDiff/SILOptimizer/activity_analysis.swift b/test/AutoDiff/SILOptimizer/activity_analysis.swift index 3281dbe7f2358..850a5c6e25489 100644 --- a/test/AutoDiff/SILOptimizer/activity_analysis.swift +++ b/test/AutoDiff/SILOptimizer/activity_analysis.swift @@ -533,6 +533,44 @@ func activeInoutArgNonactiveInitialResult(_ x: Float) -> Float { // CHECK: [ACTIVE] %13 = begin_access [read] [static] %2 : $*Float // CHECK: [ACTIVE] %14 = load [trivial] %13 : $*Float +public struct ArrayWrapper: Differentiable { + var values: [Float] + + @differentiable(reverse) + mutating func get(index: Int) -> Float { + self.values[index] + } + + // Check `inout` with result. + + // CHECK-LABEL: [AD] Activity info for ${{.*}}get{{.*}} at parameter indices (1) and result indices (0, 1) + // CHECK: bb0: + // CHECK: [USEFUL] %0 = argument of bb0 : $Int + // CHECK: [ACTIVE] %1 = argument of bb0 : $*ArrayWrapper + // CHECK: [ACTIVE] %4 = begin_access [read] [static] %1 : $*ArrayWrapper + // CHECK: [ACTIVE] %5 = struct_element_addr %4 : $*ArrayWrapper, #ArrayWrapper.values + // CHECK: [ACTIVE] %6 = load_borrow %5 : $*Array + // CHECK: [ACTIVE] %7 = alloc_stack $Float + // CHECK: [NONE] // function_ref Array.subscript.getter + // CHECK: %8 = function_ref @$sSayxSicig : $@convention(method) <τ_0_0> (Int, @guaranteed Array<τ_0_0>) -> @out τ_0_0 + // CHECK: [NONE] %9 = apply %8(%7, %0, %6) : $@convention(method) <τ_0_0> (Int, @guaranteed Array<τ_0_0>) -> @out τ_0_0 + // CHECK: [ACTIVE] %10 = load [trivial] %7 : $*Float +} + +@differentiable(reverse) +func testInoutAndResult(x: Int, y: inout ArrayWrapper) { + let _ = y.get(index: x) +} + +// CHECK-LABEL: [AD] Activity info for ${{.*}}testInoutAndResult{{.*}} at parameter indices (1) and result indices (0) +// CHECK: bb0: +// CHECK: [USEFUL] %0 = argument of bb0 : $Int +// CHECK: [ACTIVE] %1 = argument of bb0 : $*ArrayWrapper +// CHECK: [ACTIVE] %4 = begin_access [modify] [static] %1 : $*ArrayWrapper +// CHECK: [NONE] // function_ref ArrayWrapper.get(index:) +// CHECK: %5 = function_ref @$s17activity_analysis12ArrayWrapperV3get5indexSfSi_tF : $@convention(method) (Int, @inout ArrayWrapper) -> Float +// CHECK: [VARIED] %6 = apply %5(%0, %4) : $@convention(method) (Int, @inout ArrayWrapper) -> Float + //===----------------------------------------------------------------------===// // Throwing function differentiation (`try_apply`) //===----------------------------------------------------------------------===// diff --git a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/a.swift b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/a.swift index 59ec26ef9bd08..8942432a0a2a9 100644 --- a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/a.swift +++ b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/a.swift @@ -1,3 +1,25 @@ +import _Differentiation + public struct Struct { public func method(_ x: Float) -> Float { x } } + +// Test cross-module recognition of functions with multiple semantic results. +@differentiable(reverse) +public func swap(_ x: inout Float, _ y: inout Float) { + let tmp = x; x = y; y = tmp +} + +@differentiable(reverse) +public func swapCustom(_ x: inout Float, _ y: inout Float) { + let tmp = x; x = y; y = tmp +} +@derivative(of: swapCustom) +public func vjpSwapCustom(_ x: inout Float, _ y: inout Float) -> ( + value: Void, pullback: (inout Float, inout Float) -> Void +) { + swapCustom(&x, &y) + return ((), {v1, v2 in + let tmp = v1; v1 = v2; v2 = tmp + }) +} diff --git a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift index 7947518708ad3..5485a5f9a68b7 100644 --- a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift +++ b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift @@ -11,3 +11,18 @@ extension Struct: Differentiable { (x, { $0 }) } } + +// Test cross-module recognition of functions with multiple semantic results. +@differentiable(reverse) +func multiply_swap(_ x: Float, _ y: Float) -> Float { + var tuple = (x, y) + swap(&tuple.0, &tuple.1) + return tuple.0 * tuple.1 +} + +@differentiable(reverse) +func multiply_swapCustom(_ x: Float, _ y: Float) -> Float { + var tuple = (x, y) + swapCustom(&tuple.0, &tuple.1) + return tuple.0 * tuple.1 +} diff --git a/test/AutoDiff/Sema/derivative_attr_type_checking.swift b/test/AutoDiff/Sema/derivative_attr_type_checking.swift index 8e99b4bdab4e0..690090699a9bf 100644 --- a/test/AutoDiff/Sema/derivative_attr_type_checking.swift +++ b/test/AutoDiff/Sema/derivative_attr_type_checking.swift @@ -746,13 +746,18 @@ extension ProtocolRequirementDerivative { func multipleSemanticResults(_ x: inout Float) -> Float { return x } -// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @derivative(of: multipleSemanticResults) func vjpMultipleSemanticResults(x: inout Float) -> ( - value: Float, pullback: (Float) -> Float -) { - return (multipleSemanticResults(&x), { $0 }) + value: Float, pullback: (Float, inout Float) -> Void +) { fatalError() } + +func inoutNonDifferentiableResult(_ x: inout Float) -> Int { + return 5 } +@derivative(of: inoutNonDifferentiableResult) +func vjpInoutNonDifferentiableResult(x: inout Float) -> ( + value: Int, pullback: (inout Float) -> Void +) { fatalError() } struct InoutParameters: Differentiable { typealias TangentVector = DummyTangentVector @@ -885,17 +890,31 @@ func vjpNoSemanticResults(_ x: Float) -> (value: Void, pullback: Void) {} extension InoutParameters { func multipleSemanticResults(_ x: inout Float) -> Float { x } - // expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} - @derivative(of: multipleSemanticResults) + @derivative(of: multipleSemanticResults, wrt: x) func vjpMultipleSemanticResults(_ x: inout Float) -> ( - value: Float, pullback: (inout Float) -> Void + value: Float, pullback: (Float, inout Float) -> Void ) { fatalError() } func inoutVoid(_ x: Float, _ void: inout Void) -> Float {} - // expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} - @derivative(of: inoutVoid) + @derivative(of: inoutVoid, wrt: (x, void)) func vjpInoutVoidParameter(_ x: Float, _ void: inout Void) -> ( - value: Float, pullback: (inout Float) -> Void + value: Float, pullback: (Float) -> Float + ) { fatalError() } +} + +// Test tuple results. + +extension InoutParameters { + func tupleResults(_ x: Float) -> (Float, Float) { (x, x) } + @derivative(of: tupleResults, wrt: x) + func vjpTupleResults(_ x: Float) -> ( + value: (Float, Float), pullback: (Float, Float) -> Float + ) { fatalError() } + + func tupleResultsInt(_ x: Float) -> (Int, Float) { (1, x) } + @derivative(of: tupleResultsInt, wrt: x) + func vjpTupleResults(_ x: Float) -> ( + value: (Int, Float), pullback: (Float) -> Float ) { fatalError() } } diff --git a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift index 0cd6fa5b1bdb1..b9c2fee69d680 100644 --- a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift @@ -528,7 +528,6 @@ func two9(x: Float, y: Float) -> Float { func inout1(x: Float, y: inout Float) -> Void { let _ = x + y } -// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @differentiable(reverse, wrt: y) func inout2(x: Float, y: inout Float) -> Float { let _ = x + y @@ -670,11 +669,9 @@ final class FinalClass: Differentiable { @differentiable(reverse, wrt: y) func inoutVoid(x: Float, y: inout Float) {} -// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @differentiable(reverse) func multipleSemanticResults(_ x: inout Float) -> Float { x } -// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @differentiable(reverse, wrt: y) func swap(x: inout Float, y: inout Float) {} @@ -687,7 +684,6 @@ extension InoutParameters { @differentiable(reverse) static func staticMethod(_ lhs: inout Self, rhs: Self) {} - // expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @differentiable(reverse) static func multipleSemanticResults(_ lhs: inout Self, rhs: Self) -> Self {} } @@ -696,11 +692,23 @@ extension InoutParameters { @differentiable(reverse) mutating func mutatingMethod(_ other: Self) {} - // expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @differentiable(reverse) mutating func mutatingMethod(_ other: Self) -> Self {} } +// Test tuple results. + +extension InoutParameters { + @differentiable(reverse) + static func tupleResults(_ x: Self) -> (Self, Self) {} + + @differentiable(reverse) + static func tupleResultsInt(_ x: Self) -> (Int, Self) {} + + @differentiable(reverse) + static func tupleResultsInt2(_ x: Self) -> (Self, Int) {} +} + // Test accessors: `set`, `_read`, `_modify`. struct Accessors: Differentiable { diff --git a/test/AutoDiff/Serialization/derivative_attr.swift b/test/AutoDiff/Serialization/derivative_attr.swift index 91677baa80e25..c41c0a36d1a50 100644 --- a/test/AutoDiff/Serialization/derivative_attr.swift +++ b/test/AutoDiff/Serialization/derivative_attr.swift @@ -37,6 +37,26 @@ func derivativeTop2( (y, { (dx, dy) in dy }) } +// Test top-level inout functions. + +func topInout1(_ x: inout S) {} + +// CHECK: @derivative(of: topInout1, wrt: x) +@derivative(of: topInout1) +func derivativeTopInout1(_ x: inout S) -> (value: Void, pullback: (inout S) -> Void) { + fatalError() +} + +func topInout2(_ x: inout S) -> S { + x +} + +// CHECK: @derivative(of: topInout2, wrt: x) +@derivative(of: topInout2) +func derivativeTopInout2(_ x: inout S) -> (value: S, pullback: (S, inout S) -> Void) { + fatalError() +} + // Test instance methods. extension S { diff --git a/test/AutoDiff/Serialization/differentiable_attr.swift b/test/AutoDiff/Serialization/differentiable_attr.swift index b8c83362bd813..e09f7541caf90 100644 --- a/test/AutoDiff/Serialization/differentiable_attr.swift +++ b/test/AutoDiff/Serialization/differentiable_attr.swift @@ -43,6 +43,29 @@ func testWrtClause(x: Float, y: Float) -> Float { return x } +// CHECK: @differentiable(reverse, wrt: x) +// CHECK-NEXT: func testInout(x: inout Float) +@differentiable(reverse) +func testInout(x: inout Float) { + x = x * 2.0 +} + +// CHECK: @differentiable(reverse, wrt: x) +// CHECK-NEXT: func testInoutResult(x: inout Float) -> Float +@differentiable(reverse) +func testInoutResult(x: inout Float) -> Float { + x = x * 2.0 + return x +} + +// CHECK: @differentiable(reverse, wrt: (x, y)) +// CHECK-NEXT: func testMultipleInout(x: inout Float, y: inout Float) +@differentiable(reverse) +func testMultipleInout(x: inout Float, y: inout Float) { + x = x * y + y = x +} + struct InstanceMethod : Differentiable { // CHECK: @differentiable(reverse, wrt: (self, y)) // CHECK-NEXT: func testWrtClause(x: Float, y: Float) -> Float diff --git a/test/AutoDiff/Serialization/differentiable_function.swift b/test/AutoDiff/Serialization/differentiable_function.swift index 316a0a6eca40d..e31d874bb8920 100644 --- a/test/AutoDiff/Serialization/differentiable_function.swift +++ b/test/AutoDiff/Serialization/differentiable_function.swift @@ -15,3 +15,15 @@ func b(_ f: @differentiable(_linear) (Float) -> Float) {} func c(_ f: @differentiable(reverse) (Float, @noDerivative Float) -> Float) {} // CHECK: func c(_ f: @differentiable(reverse) (Float, @noDerivative Float) -> Float) + +func d(_ f: @differentiable(reverse) (inout Float) -> ()) {} +// CHECK: func d(_ f: @differentiable(reverse) (inout Float) -> ()) + +func e(_ f: @differentiable(reverse) (inout Float, inout Float) -> ()) {} +// CHECK: func e(_ f: @differentiable(reverse) (inout Float, inout Float) -> ()) + +func f(_ f: @differentiable(reverse) (inout Float) -> Float) {} +// CHECK: func f(_ f: @differentiable(reverse) (inout Float) -> Float) + +func g(_ f: @differentiable(reverse) (inout Float, Float) -> Float) {} +// CHECK: func g(_ f: @differentiable(reverse) (inout Float, Float) -> Float) diff --git a/test/AutoDiff/validation-test/simple_math.swift b/test/AutoDiff/validation-test/simple_math.swift index 88b33e0ecfeaf..408a991cd2bff 100644 --- a/test/AutoDiff/validation-test/simple_math.swift +++ b/test/AutoDiff/validation-test/simple_math.swift @@ -121,6 +121,69 @@ SimpleMathTests.test("MultipleResults") { expectEqual((4, 3), gradient(at: 3, 4, of: multiply_swapAndReturnProduct)) } +// Test function with multiple `inout` parameters and a custom pullback. +@differentiable(reverse) +func swapCustom(_ x: inout Float, _ y: inout Float) { + let tmp = x; x = y; y = tmp +} +@derivative(of: swapCustom) +func vjpSwapCustom(_ x: inout Float, _ y: inout Float) -> ( + value: Void, pullback: (inout Float, inout Float) -> Void +) { + swapCustom(&x, &y) + return ((), {v1, v2 in + let tmp = v1; v1 = v2; v2 = tmp + }) +} + +SimpleMathTests.test("MultipleResultsWithCustomPullback") { + func multiply_swapCustom(_ x: Float, _ y: Float) -> Float { + var tuple = (x, y) + swapCustom(&tuple.0, &tuple.1) + return tuple.0 * tuple.1 + } + + expectEqual((4, 3), gradient(at: 3, 4, of: multiply_swapCustom)) + expectEqual((10, 5), gradient(at: 5, 10, of: multiply_swapCustom)) +} + +// Test functions returning tuples. +@differentiable(reverse) +func swapTuple(_ x: Float, _ y: Float) -> (Float, Float) { + return (y, x) +} + +@differentiable(reverse) +func swapTupleCustom(_ x: Float, _ y: Float) -> (Float, Float) { + return (y, x) +} +@derivative(of: swapTupleCustom) +func vjpSwapTupleCustom(_ x: Float, _ y: Float) -> ( + value: (Float, Float), pullback: (Float, Float) -> (Float, Float) +) { + return (swapTupleCustom(x, y), {v1, v2 in + return (v2, v1) + }) +} + +SimpleMathTests.test("ReturningTuples") { + func multiply_swapTuple(_ x: Float, _ y: Float) -> Float { + let result = swapTuple(x, y) + return result.0 * result.1 + } + + expectEqual((4, 3), gradient(at: 3, 4, of: multiply_swapTuple)) + expectEqual((10, 5), gradient(at: 5, 10, of: multiply_swapTuple)) + + func multiply_swapTupleCustom(_ x: Float, _ y: Float) -> Float { + let result = swapTupleCustom(x, y) + return result.0 * result.1 + } + + expectEqual((4, 3), gradient(at: 3, 4, of: multiply_swapTupleCustom)) + expectEqual((10, 5), gradient(at: 5, 10, of: multiply_swapTupleCustom)) +} + SimpleMathTests.test("CaptureLocal") { let z: Float = 10 func foo(_ x: Float) -> Float {