diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index d70d91ebbdfd4..890bf4cd274c9 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -246,16 +246,16 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s, return s; } -/// A semantic function result type: either a formal function result type or -/// an `inout` parameter type. Used in derivative function type calculation. +/// A semantic function result type: either a formal function result type or a +/// semantic result (an `inout`) parameter type. Used in derivative function type +/// calculation. struct AutoDiffSemanticFunctionResultType { Type type; unsigned index : 30; - bool isInout : 1; - bool isWrtParam : 1; + bool isSemanticResultParameter : 1; - AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool inout, bool wrt) - : type(t), index(idx), isInout(inout), isWrtParam(wrt) { } + AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool param) + : type(t), index(idx), isSemanticResultParameter(param) { } }; /// Key for caching SIL derivative function types. diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index c19f5a7a9706f..f8ff0bbafa643 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -3192,6 +3192,12 @@ class AnyFunctionType : public TypeBase { /// Whether the parameter is marked '@noDerivative'. bool isNoDerivative() const { return Flags.isNoDerivative(); } + /// Whether the parameter might be a semantic result for autodiff purposes. + /// This includes inout parameters. + bool isAutoDiffSemanticResult() const { + return isInOut(); + } + ValueOwnership getValueOwnership() const { return Flags.getValueOwnership(); } @@ -3509,8 +3515,8 @@ class AnyFunctionType : public TypeBase { /// Preconditions: /// - Parameters corresponding to parameter indices must conform to /// `Differentiable`. - /// - There is one semantic function result type: either the formal original - /// result or an `inout` parameter. It must conform to `Differentiable`. + /// - There are semantic function result type: either the formal original + /// result or a "wrt" semantic result parameter. /// /// Differential typing rules: takes "wrt" parameter derivatives and returns a /// "wrt" result derivative. @@ -3518,10 +3524,7 @@ class AnyFunctionType : public TypeBase { /// - Case 1: original function has no `inout` parameters. /// - Original: `(T0, T1, ...) -> R` /// - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan` - /// - Case 2: original function has a non-wrt `inout` parameter. - /// - Original: `(T0, inout T1, ...) -> Void` - /// - Differential: `(T0.Tan, ...) -> T1.Tan` - /// - Case 3: original function has a wrt `inout` parameter. + /// - Case 2: original function has a wrt `inout` parameter. /// - Original: `(T0, inout T1, ...) -> Void` /// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void` /// @@ -3531,10 +3534,7 @@ class AnyFunctionType : public TypeBase { /// - Case 1: original function has no `inout` parameters. /// - Original: `(T0, T1, ...) -> R` /// - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)` - /// - Case 2: original function has a non-wrt `inout` parameter. - /// - Original: `(T0, inout T1, ...) -> Void` - /// - Pullback: `(T1.Tan) -> (T0.Tan, ...)` - /// - Case 3: original function has a wrt `inout` parameter. + /// - Case 2: original function has a wrt `inout` parameter. /// - Original: `(T0, inout T1, ...) -> Void` /// - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)` /// @@ -4101,6 +4101,9 @@ class SILParameterInfo { return getConvention() == ParameterConvention::Indirect_Inout || getConvention() == ParameterConvention::Indirect_InoutAliasable; } + bool isAutoDiffSemanticResult() const { + return isIndirectMutating(); + } bool isPack() const { return isPackParameter(getConvention()); @@ -4836,6 +4839,37 @@ class SILFunctionType final return llvm::count_if(getParameters(), IndirectMutatingParameterFilter()); } + struct AutoDiffSemanticResultsParameterFilter { + bool operator()(SILParameterInfo param) const { + return param.isAutoDiffSemanticResult(); + } + }; + + using AutoDiffSemanticResultsParameterIter = + llvm::filter_iterator; + using AutoDiffSemanticResultsParameterRange = + iterator_range; + + /// A range of SILParameterInfo for all semantic results parameters. + AutoDiffSemanticResultsParameterRange + getAutoDiffSemanticResultsParameters() const { + return llvm::make_filter_range(getParameters(), + AutoDiffSemanticResultsParameterFilter()); + } + + /// Returns the number of semantic results parameters. + unsigned getNumAutoDiffSemanticResultsParameters() const { + return llvm::count_if(getParameters(), AutoDiffSemanticResultsParameterFilter()); + } + + /// Returns the number of function potential semantic results: + /// * Usual results + /// * Inout parameters + unsigned getNumAutoDiffSemanticResults() const { + return getNumResults() + getNumAutoDiffSemanticResultsParameters(); + } + /// Get the generic signature that the component types are specified /// in terms of, if any. CanGenericSignature getSubstGenericSignature() const { diff --git a/include/swift/SIL/ApplySite.h b/include/swift/SIL/ApplySite.h index a38e25bb3313a..06888dc4d681c 100644 --- a/include/swift/SIL/ApplySite.h +++ b/include/swift/SIL/ApplySite.h @@ -681,6 +681,18 @@ class FullApplySite : public ApplySite { llvm_unreachable("invalid apply kind"); } + AutoDiffSemanticResultArgumentRange getAutoDiffSemanticResultArguments() const { + switch (getKind()) { + case FullApplySiteKind::ApplyInst: + return cast(getInstruction())->getAutoDiffSemanticResultArguments(); + case FullApplySiteKind::TryApplyInst: + return cast(getInstruction())->getAutoDiffSemanticResultArguments(); + case FullApplySiteKind::BeginApplyInst: + return cast(getInstruction())->getAutoDiffSemanticResultArguments(); + } + llvm_unreachable("invalid apply kind"); + } + /// Returns true if \p op is the callee operand of this apply site /// and not an argument operand. bool isCalleeOperand(const Operand &op) const { diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index 9a7eb8c8dd49d..b929ba0f67b86 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -2785,6 +2785,25 @@ struct OperandToInoutArgument { using InoutArgumentRange = OptionalTransformRange, OperandToInoutArgument>; +/// Predicate used to filter AutoDiffSemanticResultArgumentRange. +struct OperandToAutoDiffSemanticResultArgument { + ArrayRef paramInfos; + OperandValueArrayRef arguments; + OperandToAutoDiffSemanticResultArgument(ArrayRef paramInfos, + OperandValueArrayRef arguments) + : paramInfos(paramInfos), arguments(arguments) { + assert(paramInfos.size() == arguments.size()); + } + llvm::Optional operator()(size_t i) const { + if (paramInfos[i].isAutoDiffSemanticResult()) + return arguments[i]; + return llvm::None; + } +}; + +using AutoDiffSemanticResultArgumentRange = + OptionalTransformRange, OperandToAutoDiffSemanticResultArgument>; + /// The partial specialization of ApplyInstBase for full applications. /// Adds some methods relating to 'self' and to result types that don't /// make sense for partial applications. @@ -2894,6 +2913,16 @@ class ApplyInstBase impl.getArgumentsWithoutIndirectResults())); } + /// Returns all autodiff semantic result (`@inout`, `@inout_aliasable`) + /// arguments passed to the instruction. + AutoDiffSemanticResultArgumentRange getAutoDiffSemanticResultArguments() const { + auto &impl = asImpl(); + return AutoDiffSemanticResultArgumentRange( + indices(getArgumentsWithoutIndirectResults()), + OperandToAutoDiffSemanticResultArgument(impl.getSubstCalleeConv().getParameters(), + impl.getArgumentsWithoutIndirectResults())); + } + bool hasSemantics(StringRef semanticsString) const { return doesApplyCalleeHaveSemantics(getCallee(), semanticsString); } diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index c7a04a16939e7..3778543fc9acf 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -199,32 +199,28 @@ void autodiff::getFunctionSemanticResults( if (formalResultType->is()) { for (auto elt : formalResultType->castTo()->getElements()) { resultTypes.emplace_back(elt.getType(), resultIdx++, - /*isInout*/ false, /*isWrt*/ false); + /*isParameter*/ false); } } else { resultTypes.emplace_back(formalResultType, resultIdx++, - /*isInout*/ false, /*isWrt*/ false); + /*isParameter*/ false); } } - bool addNonWrts = resultTypes.empty(); - - // Collect wrt `inout` parameters as semantic results - // As an extention, collect all (including non-wrt) inouts as results for - // functions returning void. + // Collect wrt semantic result (`inout`) parameters as + // semantic results auto collectSemanticResults = [&](const AnyFunctionType *functionType, unsigned curryOffset = 0) { for (auto paramAndIndex : enumerate(functionType->getParams())) { - if (!paramAndIndex.value().isInOut()) + if (!paramAndIndex.value().isAutoDiffSemanticResult()) continue; unsigned idx = paramAndIndex.index() + curryOffset; assert(idx < parameterIndices->getCapacity() && "invalid parameter index"); - bool isWrt = parameterIndices->contains(idx); - if (addNonWrts || isWrt) + if (parameterIndices->contains(idx)) resultTypes.emplace_back(paramAndIndex.value().getPlainType(), - resultIdx, /*isInout*/ true, isWrt); + resultIdx, /*isParameter*/ true); resultIdx += 1; } }; diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index 36f0a9f95f134..f0385dfdbbed6 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -5558,7 +5558,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( return llvm::make_error( this, DerivativeFunctionTypeError::Kind::NoSemanticResults); - // Accumulate non-inout result tangent spaces. + // Accumulate non-semantic result tangent spaces. SmallVector resultTanTypes, inoutTanTypes; for (auto i : range(originalResults.size())) { auto originalResult = originalResults[i]; @@ -5577,16 +5577,10 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult, std::make_pair(originalResultType, unsigned(originalResult.index))); - if (!originalResult.isInout) + if (!originalResult.isSemanticResultParameter) resultTanTypes.push_back(resultTan->getType()); - else if (originalResult.isInout && !originalResult.isWrtParam) - inoutTanTypes.push_back(resultTan->getType()); } - // Treat non-wrt inouts as semantic results for functions returning Void - if (resultTanTypes.empty()) - resultTanTypes = inoutTanTypes; - // Compute the result linear map function type. FunctionType *linearMapType; switch (kind) { @@ -5597,11 +5591,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( // - Original: `(T0, T1, ...) -> R` // - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan` // - // Case 2: original function has a non-wrt `inout` parameter. - // - Original: `(T0, inout T1, ...) -> Void` - // - Differential: `(T0.Tan, ...) -> T1.Tan` - // - // Case 3: original function has a wrt `inout` parameter. + // Case 2: original function has a wrt `inout` parameter. // - Original: `(T0, inout T1, ...) -> Void` // - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void` SmallVector differentialParams; @@ -5648,15 +5638,11 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( // - Original: `(T0, T1, ...) -> R` // - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)` // - // Case 2: original function has a non-wrt `inout` parameter. - // - Original: `(T0, inout T1, ...) -> Void` - // - Pullback: `(T1.Tan) -> (T0.Tan, ...)` - // - // Case 3: original function has wrt `inout` parameters. + // Case 2: original function has wrt `inout` parameters. // - Original: `(T0, inout T1, ...) -> R` // - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)` SmallVector pullbackResults; - SmallVector inoutParams; + SmallVector semanticResultParams; for (auto i : range(diffParams.size())) { auto diffParam = diffParams[i]; auto paramType = diffParam.getPlainType(); @@ -5669,10 +5655,10 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( NonDifferentiableDifferentiabilityParameter, std::make_pair(paramType, i)); - if (diffParam.isInOut()) { + if (diffParam.isAutoDiffSemanticResult()) { if (paramType->isVoid()) continue; - inoutParams.push_back(diffParam); + semanticResultParams.push_back(diffParam); continue; } pullbackResults.emplace_back(paramTan->getType()); @@ -5693,15 +5679,15 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( pullbackParams.push_back(AnyFunctionType::Param( resultTanType, Identifier(), flags)); } - // Then append inout parameters. - for (auto i : range(inoutParams.size())) { - auto inoutParam = inoutParams[i]; - auto inoutParamType = inoutParam.getPlainType(); - auto inoutParamTan = - inoutParamType->getAutoDiffTangentSpace(lookupConformance); + // Then append semantic result parameters. + for (auto i : range(semanticResultParams.size())) { + auto semanticResultParam = semanticResultParams[i]; + auto semanticResultParamType = semanticResultParam.getPlainType(); + auto semanticResultParamTan = + semanticResultParamType->getAutoDiffTangentSpace(lookupConformance); auto flags = ParameterTypeFlags().withInOut(true); pullbackParams.push_back(AnyFunctionType::Param( - inoutParamTan->getType(), Identifier(), flags)); + semanticResultParamTan->getType(), Identifier(), flags)); } // FIXME: Verify ExtInfo state is correct, not working by accident. FunctionType::ExtInfo info; @@ -5709,6 +5695,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( break; } } + assert(linearMapType && "Expected linear map type"); return linearMapType; } diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index 6853badd31b2c..ad9f32baedf01 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -237,9 +237,11 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() { if (resultAndIndex.value().getDifferentiability() != SILResultDifferentiability::NotDifferentiable) resultIndices.push_back(resultAndIndex.index()); + + auto numSemanticResults = getNumResults(); - // Check `inout` parameters. - for (auto inoutParamAndIndex : enumerate(getIndirectMutatingParameters())) + // Check semantic results (`inout`) parameters. + for (auto resultParamAndIndex : enumerate(getAutoDiffSemanticResultsParameters())) // Currently, an `inout` parameter can either be: // 1. Both a differentiability parameter and a differentiability result. // 2. `@noDerivative`: neither a differentiability parameter nor a @@ -251,16 +253,13 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() { // cases, so supporting it is a non-goal. // // See TF-1305 for solution ideas. For now, `@noDerivative` `inout` - // parameters are not treated as differentiability results, unless the - // original function has no formal results, in which case all `inout` - // parameters are treated as differentiability results. - if (resultIndices.empty() || - inoutParamAndIndex.value().getDifferentiability() != + // parameters are not treated as differentiability results. + if (resultParamAndIndex.value().getDifferentiability() != SILParameterDifferentiability::NotDifferentiable) - resultIndices.push_back(getNumResults() + inoutParamAndIndex.index()); + resultIndices.push_back(getNumResults() + resultParamAndIndex.index()); + + numSemanticResults += getNumAutoDiffSemanticResultsParameters(); - auto numSemanticResults = - getNumResults() + getNumIndirectMutatingParameters(); return IndexSubset::get(getASTContext(), numSemanticResults, resultIndices); } @@ -369,18 +368,19 @@ getDifferentiabilityParameters(SILFunctionType *originalFnTy, /// Collects the semantic results of the given function type in /// `originalResults`. The semantic results are formal results followed by -/// `inout` parameters, in type order. +/// semantic result parameters, in type order. static void -getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices, +getSemanticResults(SILFunctionType *functionType, + IndexSubset *parameterIndices, SmallVectorImpl &originalResults) { // Collect original formal results. originalResults.append(functionType->getResults().begin(), functionType->getResults().end()); - // Collect original `inout` parameters. + // Collect original semantic result parameters. for (auto i : range(functionType->getNumParameters())) { auto param = functionType->getParameters()[i]; - if (!param.isIndirectMutating()) + if (!param.isAutoDiffSemanticResult()) continue; if (param.getDifferentiability() != SILParameterDifferentiability::NotDifferentiable) originalResults.emplace_back(param.getInterfaceType(), ResultConvention::Indirect); @@ -597,23 +597,25 @@ static CanSILFunctionType getAutoDiffDifferentialType( differentialResults.push_back({resultTanType, resultConv}); continue; } - // Handle original `inout` parameters. - auto inoutParamIndex = resultIndex - originalFnTy->getNumResults(); - auto inoutParamIt = std::next( - originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex); + // Handle original semantic result parameters. + auto resultParamIndex = resultIndex - originalFnTy->getNumResults(); + auto resultParamIt = std::next( + originalFnTy->getAutoDiffSemanticResultsParameters().begin(), + resultParamIndex); auto paramIndex = - std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt); - // If the original `inout` parameter is a differentiability parameter, then - // it already has a corresponding differential parameter. Skip adding a - // corresponding differential result. + std::distance(originalFnTy->getParameters().begin(), &*resultParamIt); + // If the original semantic result parameter is a differentiability + // parameter, then it already has a corresponding differential + // parameter. Skip adding a corresponding differential result. if (parameterIndices->contains(paramIndex)) continue; - auto inoutParam = originalFnTy->getParameters()[paramIndex]; - auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap( - inoutParam.getInterfaceType(), lookupConformance, + + auto resultParam = originalFnTy->getParameters()[paramIndex]; + auto resultParamTanType = getAutoDiffTangentTypeForLinearMap( + resultParam.getInterfaceType(), lookupConformance, substGenericParams, substReplacements, ctx); - differentialResults.push_back( - {inoutParamTanType, ResultConvention::Indirect}); + differentialResults.emplace_back(resultParamTanType, + ResultConvention::Indirect); } SubstitutionMap substitutions; @@ -734,28 +736,29 @@ static CanSILFunctionType getAutoDiffPullbackType( ->getAutoDiffTangentSpace(lookupConformance) ->getCanonicalType(), origRes.getConvention()); - pullbackParams.push_back({resultTanType, paramConv}); + pullbackParams.emplace_back(resultTanType, paramConv); continue; } - // Handle `inout` parameters. - auto inoutParamIndex = resultIndex - originalFnTy->getNumResults(); - auto inoutParamIt = std::next( - originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex); + // Handle original semantic result parameters. + auto resultParamIndex = resultIndex - originalFnTy->getNumResults(); + auto resultParamIt = std::next( + originalFnTy->getAutoDiffSemanticResultsParameters().begin(), + resultParamIndex); auto paramIndex = - std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt); - auto inoutParam = originalFnTy->getParameters()[paramIndex]; + std::distance(originalFnTy->getParameters().begin(), &*resultParamIt); + auto resultParam = originalFnTy->getParameters()[paramIndex]; // The pullback parameter convention depends on whether the original `inout` // parameter is a differentiability parameter. // - If yes, the pullback parameter convention is `@inout`. // - If no, the pullback parameter convention is `@in_guaranteed`. - auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap( - inoutParam.getInterfaceType(), lookupConformance, - substGenericParams, substReplacements, ctx); - bool isWrtInoutParameter = parameterIndices->contains(paramIndex); - auto paramTanConvention = isWrtInoutParameter - ? inoutParam.getConvention() - : ParameterConvention::Indirect_In_Guaranteed; - pullbackParams.push_back({inoutParamTanType, paramTanConvention}); + auto resultParamTanType = getAutoDiffTangentTypeForLinearMap( + resultParam.getInterfaceType(), lookupConformance, + substGenericParams, substReplacements, ctx); + ParameterConvention paramTanConvention = resultParam.getConvention(); + if (!parameterIndices->contains(paramIndex)) + paramTanConvention = ParameterConvention::Indirect_In_Guaranteed; + + pullbackParams.emplace_back(resultParamTanType, paramTanConvention); } // Collect pullback results. @@ -763,9 +766,9 @@ static CanSILFunctionType getAutoDiffPullbackType( getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams); SmallVector pullbackResults; for (auto ¶m : diffParams) { - // Skip `inout` parameters, which semantically behave as original results - // and always appear as pullback parameters. - if (param.isIndirectMutating()) + // Skip semantic result parameters, which semantically behave as original + // results and always appear as pullback parameters. + if (param.isAutoDiffSemanticResult()) continue; auto paramTanType = getAutoDiffTangentTypeForLinearMap( param.getInterfaceType(), lookupConformance, @@ -898,6 +901,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType( origTypeOfAbstraction, TC); break; } + // Compute the derivative function parameters. SmallVector newParameters; newParameters.reserve(constrainedOriginalFnTy->getNumParameters()); @@ -4091,6 +4095,40 @@ static llvm::cl::opt DisableConstantInfoCache("sil-disable-typelowering-constantinfo-cache", llvm::cl::init(false)); +static IndexSubset * +getLoweredResultIndices(const SILFunctionType *functionType, + const IndexSubset *parameterIndices) { + SmallVector resultIndices; + + // Check formal results. + for (auto resultAndIndex : enumerate(functionType->getResults())) + if (resultAndIndex.value().getDifferentiability() != + SILResultDifferentiability::NotDifferentiable) + resultIndices.push_back(resultAndIndex.index()); + + auto numResults = functionType->getNumResults(); + + // Collect semantic result parameters. + unsigned semResultParamIdx = 0; + for (auto resultParamAndIndex + : enumerate(functionType->getParameters())) { + if (!resultParamAndIndex.value().isAutoDiffSemanticResult()) + continue; + + if (resultParamAndIndex.value().getDifferentiability() != + SILParameterDifferentiability::NotDifferentiable && + parameterIndices->contains(resultParamAndIndex.index())) + resultIndices.push_back(numResults + semResultParamIdx); + semResultParamIdx += 1; + } + + numResults += semResultParamIdx; + + return IndexSubset::get(functionType->getASTContext(), + numResults, resultIndices); +} + + const SILConstantInfo & TypeConverter::getConstantInfo(TypeExpansionContext expansion, SILDeclRef constant) { @@ -4149,11 +4187,9 @@ TypeConverter::getConstantInfo(TypeExpansionContext expansion, // Use it to compute lowered derivative function type. auto *loweredParamIndices = autodiff::getLoweredParameterIndices( derivativeId->getParameterIndices(), formalInterfaceType); - auto numResults = - origFnConstantInfo.SILFnType->getNumResults() + - origFnConstantInfo.SILFnType->getNumIndirectMutatingParameters(); - auto *loweredResultIndices = IndexSubset::getDefault( - M.getASTContext(), numResults, /*includeAll*/ true); + auto *loweredResultIndices + = getLoweredResultIndices(origFnConstantInfo.SILFnType, loweredParamIndices); + silFnType = origFnConstantInfo.SILFnType->getAutoDiffDerivativeFunctionType( loweredParamIndices, loweredResultIndices, derivativeId->getKind(), *this, LookUpConformanceInModule(&M)); diff --git a/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp b/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp index 0cd474d74b521..2c1fbc5876bbf 100644 --- a/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp +++ b/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp @@ -138,8 +138,8 @@ void DifferentiableActivityInfo::propagateVaried( if (isVaried(operand->get(), i)) { for (auto indRes : applySite.getIndirectSILResults()) propagateVariedInwardsThroughProjections(indRes, i); - for (auto inoutArg : applySite.getInoutArguments()) - propagateVariedInwardsThroughProjections(inoutArg, i); + for (auto semresArg : applySite.getAutoDiffSemanticResultArguments()) + propagateVariedInwardsThroughProjections(semresArg, i); // Propagate variedness to apply site direct results. forEachApplyDirectResult(applySite, [&](SILValue directResult) { setVariedAndPropagateToUsers(directResult, i); diff --git a/lib/SILOptimizer/Differentiation/Common.cpp b/lib/SILOptimizer/Differentiation/Common.cpp index a8f749b3187c3..aec5f1e09b34a 100644 --- a/lib/SILOptimizer/Differentiation/Common.cpp +++ b/lib/SILOptimizer/Differentiation/Common.cpp @@ -147,11 +147,11 @@ void collectAllFormalResultsInTypeOrder(SILFunction &function, for (auto &resInfo : convs.getResults()) results.push_back(resInfo.isFormalDirect() ? dirResults[dirResIdx++] : indResults[indResIdx++]); - // Treat `inout` parameters as semantic results. - // Append `inout` parameters after formal results. + // Treat semantic result parameters as semantic results. + // Append them` parameters after formal results. for (auto i : range(convs.getNumParameters())) { auto paramInfo = convs.getParameters()[i]; - if (!paramInfo.isIndirectMutating()) + if (!paramInfo.isAutoDiffSemanticResult()) continue; auto *argument = function.getArgumentsWithoutIndirectResults()[i]; results.push_back(argument); @@ -190,6 +190,7 @@ void collectMinimalIndicesForFunctionCall( SmallVectorImpl &resultIndices) { auto calleeFnTy = ai->getSubstCalleeType(); auto calleeConvs = ai->getSubstCalleeConv(); + // Parameter indices are indices (in the callee type signature) of parameter // arguments that are varied or are arguments. // Record all parameter indices in type order. @@ -199,6 +200,7 @@ void collectMinimalIndicesForFunctionCall( paramIndices.push_back(currentParamIdx); ++currentParamIdx; } + // Result indices are indices (in the callee type signature) of results that // are useful. SmallVector directResults; @@ -226,22 +228,21 @@ void collectMinimalIndicesForFunctionCall( ++indResIdx; } } - // Record all `inout` parameters as results. - auto inoutParamResultIndex = calleeFnTy->getNumResults(); + + // Record all semantic result parameters as results. + auto semanticResultParamResultIndex = calleeFnTy->getNumResults(); for (auto ¶mAndIdx : enumerate(calleeConvs.getParameters())) { auto ¶m = paramAndIdx.value(); - if (!param.isIndirectMutating()) + if (!param.isAutoDiffSemanticResult()) continue; unsigned idx = paramAndIdx.index() + calleeFnTy->getNumIndirectFormalResults(); - auto inoutArg = ai->getArgument(idx); - results.push_back(inoutArg); - resultIndices.push_back(inoutParamResultIndex++); + results.push_back(ai->getArgument(idx)); + resultIndices.push_back(semanticResultParamResultIndex++); } + // Make sure the function call has active results. #ifndef NDEBUG - auto numResults = calleeFnTy->getNumResults() + - calleeFnTy->getNumIndirectMutatingParameters(); - assert(results.size() == numResults); + assert(results.size() == calleeFnTy->getNumAutoDiffSemanticResults()); assert(llvm::any_of(results, [&](SILValue result) { return activityInfo.isActive(result, parentConfig); })); diff --git a/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp b/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp index f469607de2758..9797c3e982381 100644 --- a/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp +++ b/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp @@ -177,7 +177,7 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { auto hasActiveResults = llvm::any_of(allResults, [&](SILValue res) { return activityInfo.isActive(res, config); }); - bool hasActiveInoutArgument = false; + bool hasActiveSemanticResultArgument = false; bool hasActiveArguments = false; auto numIndirectResults = ai->getNumIndirectResults(); for (auto argIdx : range(ai->getSubstCalleeConv().getNumParameters())) { @@ -186,13 +186,13 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { hasActiveArguments = true; auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( numIndirectResults + argIdx); - if (paramInfo.isIndirectMutating()) - hasActiveInoutArgument = true; + if (paramInfo.isAutoDiffSemanticResult()) + hasActiveSemanticResultArgument = true; } } if (!hasActiveArguments) return {}; - if (!hasActiveResults && !hasActiveInoutArgument) + if (!hasActiveResults && !hasActiveSemanticResultArgument) return {}; // Compute differentiability parameters. @@ -213,9 +213,8 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { ai->getArgumentsWithoutIndirectResults().size(), activeParamIndices); } // Compute differentiability results. - auto numResults = remappedOrigFnSubstTy->getNumResults() + - remappedOrigFnSubstTy->getNumIndirectMutatingParameters(); - auto *results = IndexSubset::get(original->getASTContext(), numResults, + auto *results = IndexSubset::get(original->getASTContext(), + remappedOrigFnSubstTy->getNumAutoDiffSemanticResults(), activeResultIndices); // Create autodiff indices for the `apply` instruction. AutoDiffConfig applyConfig(parameters, results); @@ -234,10 +233,11 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { for (auto resultIndex : applyConfig.resultIndices->getIndices()) { SILType remappedResultType; if (resultIndex >= origFnTy->getNumResults()) { - auto inoutArgIdx = resultIndex - origFnTy->getNumResults(); - auto inoutArg = - *std::next(ai->getInoutArguments().begin(), inoutArgIdx); - remappedResultType = inoutArg->getType(); + auto semanticResultArgIdx = resultIndex - origFnTy->getNumResults(); + auto semanticResultArg = + *std::next(ai->getAutoDiffSemanticResultArguments().begin(), + semanticResultArgIdx); + remappedResultType = semanticResultArg->getType(); } else { remappedResultType = origFnTy->getResults()[resultIndex].getSILStorageInterfaceType(); @@ -277,8 +277,9 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { SmallVector params; for (auto ¶m : silFnTy->getParameters()) { ParameterTypeFlags flags; - if (param.isIndirectMutating()) + if (param.isAutoDiffSemanticResult()) flags = flags.withInOut(true); + params.push_back( AnyFunctionType::Param(param.getInterfaceType(), Identifier(), flags)); } diff --git a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp index c7eb9aa769424..f6ef9a33fdfb7 100644 --- a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp @@ -864,6 +864,7 @@ class PullbackCloner::Implementation final /// Adjoint: (adj[x0], adj[x1], ...) += apply @fn_pullback (adj[y0], ...) void visitApplyInst(ApplyInst *ai) { assert(getPullbackInfo().shouldDifferentiateApplySite(ai)); + // Skip `array.uninitialized_intrinsic` applications, which have special // `store` and `copy_addr` support. if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) @@ -901,11 +902,11 @@ class PullbackCloner::Implementation final }); SmallVector origAllResults; collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults); - // Append `inout` arguments after original results. + // Append semantic result arguments after original results. for (auto paramIdx : applyInfo.config.parameterIndices->getIndices()) { auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( ai->getNumIndirectResults() + paramIdx); - if (!paramInfo.isIndirectMutating()) + if (!paramInfo.isAutoDiffSemanticResult()) continue; origAllResults.push_back( ai->getArgumentsWithoutIndirectResults()[paramIdx]); @@ -981,10 +982,10 @@ class PullbackCloner::Implementation final auto allResultsIt = allResults.begin(); for (unsigned i : applyInfo.config.parameterIndices->getIndices()) { auto origArg = ai->getArgument(ai->getNumIndirectResults() + i); - // Skip adjoint accumulation for `inout` arguments. + // Skip adjoint accumulation for semantic results arguments. auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( ai->getNumIndirectResults() + i); - if (paramInfo.isIndirectMutating()) + if (paramInfo.isAutoDiffSemanticResult()) continue; auto tan = *allResultsIt++; if (tan->getType().isAddress()) { @@ -2036,6 +2037,7 @@ bool PullbackCloner::Implementation::run() { // the adjoint buffer of the original result. auto seedParamInfo = pullback.getLoweredFunctionType()->getParameters()[seedIndex]; + if (seedParamInfo.isIndirectInOut()) { setAdjointBuffer(originalExitBlock, origResult, seed); } @@ -2123,7 +2125,7 @@ bool PullbackCloner::Implementation::run() { // Collect differentiation parameter adjoints. // Do a first pass to collect non-inout values. for (auto i : getConfig().parameterIndices->getIndices()) { - if (!conv.getParameters()[i].isIndirectMutating()) { + if (!conv.getParameters()[i].isAutoDiffSemanticResult()) { addRetElt(i); } } @@ -2136,14 +2138,14 @@ bool PullbackCloner::Implementation::run() { const auto &pullbackConv = pullback.getConventions(); SmallVector pullbackInOutArgs; for (auto pullbackArg : enumerate(pullback.getArgumentsWithoutIndirectResults())) { - if (pullbackConv.getParameters()[pullbackArg.index()].isIndirectMutating()) + if (pullbackConv.getParameters()[pullbackArg.index()].isAutoDiffSemanticResult()) pullbackInOutArgs.push_back(pullbackArg.value()); } unsigned pullbackInoutArgumentIdx = 0; for (auto i : getConfig().parameterIndices->getIndices()) { // Skip non-inout parameters. - if (!conv.getParameters()[i].isIndirectMutating()) + if (!conv.getParameters()[i].isAutoDiffSemanticResult()) continue; // For functions with multiple basic blocks, accumulation is needed diff --git a/lib/SILOptimizer/Differentiation/Thunk.cpp b/lib/SILOptimizer/Differentiation/Thunk.cpp index 487bce2929183..116ea3a2d228a 100644 --- a/lib/SILOptimizer/Differentiation/Thunk.cpp +++ b/lib/SILOptimizer/Differentiation/Thunk.cpp @@ -365,7 +365,9 @@ getOrCreateSubsetParametersThunkForLinearMap( const AutoDiffConfig &desiredConfig, const AutoDiffConfig &actualConfig, ADContext &adContext) { LLVM_DEBUG(getADDebugStream() - << "Getting a subset parameters thunk for " << linearMapType + << "Getting a subset parameters thunk for " + << (kind == AutoDiffDerivativeFunctionKind::JVP ? "jvp" : "vjp") + << " linear map " << linearMapType << " from " << actualConfig << " to " << desiredConfig << '\n'); assert(!linearMapType->getCombinedSubstitutions()); @@ -539,10 +541,10 @@ getOrCreateSubsetParametersThunkForLinearMap( unsigned pullbackResultIndex = 0; for (unsigned i : actualConfig.parameterIndices->getIndices()) { auto origParamInfo = origFnType->getParameters()[i]; - // Skip original `inout` parameters. All non-indirect-result pullback - // arguments (including `inout` arguments) are appended to `arguments` + // Skip original semantic result parameters. All non-indirect-result pullback + // arguments (including semantic result` arguments) are appended to `arguments` // later. - if (origParamInfo.isIndirectMutating()) + if (origParamInfo.isAutoDiffSemanticResult()) continue; auto resultInfo = linearMapType->getResults()[pullbackResultIndex]; assert(pullbackResultIndex < linearMapType->getNumResults()); @@ -619,16 +621,18 @@ getOrCreateSubsetParametersThunkForLinearMap( extractAllElements(ai, builder, pullbackDirectResults); SmallVector allResults; collectAllActualResultsInTypeOrder(ai, pullbackDirectResults, allResults); - // Collect pullback `inout` arguments in type order. - unsigned inoutArgIdx = 0; + // Collect pullback semantic result arguments in type order. + unsigned semanticResultArgIdx = 0; SILFunctionConventions origConv(origFnType, thunk->getModule()); for (auto paramIdx : actualConfig.parameterIndices->getIndices()) { auto paramInfo = origConv.getParameters()[paramIdx]; - if (!paramInfo.isIndirectMutating()) + if (!paramInfo.isAutoDiffSemanticResult()) continue; - auto inoutArg = *std::next(ai->getInoutArguments().begin(), inoutArgIdx++); + auto semanticResultArg = + *std::next(ai->getAutoDiffSemanticResultArguments().begin(), + semanticResultArgIdx++); unsigned mappedParamIdx = mapOriginalParameterIndex(paramIdx); - allResults.insert(allResults.begin() + mappedParamIdx, inoutArg); + allResults.insert(allResults.begin() + mappedParamIdx, semanticResultArg); } assert(allResults.size() == actualConfig.parameterIndices->getNumIndices() && "Number of pullback results should match number of differentiability " @@ -668,8 +672,10 @@ getOrCreateSubsetParametersThunkForDerivativeFunction( AutoDiffDerivativeFunctionKind kind, const AutoDiffConfig &desiredConfig, const AutoDiffConfig &actualConfig, ADContext &adContext) { LLVM_DEBUG(getADDebugStream() - << "Getting a subset parameters thunk for derivative function " - << derivativeFn << " of the original function " << origFnOperand + << "Getting a subset parameters thunk for derivative " + << (kind == AutoDiffDerivativeFunctionKind::JVP ? "jvp" : "vjp") + << " function " << derivativeFn + << " of the original function " << origFnOperand << " from " << actualConfig << " to " << desiredConfig << '\n'); auto origFnType = origFnOperand->getType().castTo(); @@ -823,9 +829,7 @@ getOrCreateSubsetParametersThunkForDerivativeFunction( SILType::getPrimitiveObjectType(linearMapTargetType), /*withoutActuallyEscaping*/ false); } - assert(origFnType->getNumResults() + - origFnType->getNumIndirectMutatingParameters() > - 0); + assert(origFnType->getNumAutoDiffSemanticResults() > 0); if (origFnType->getNumResults() > 0 && origFnType->getResults().front().isFormalDirect()) { directResults.push_back(thunkedLinearMap); diff --git a/lib/SILOptimizer/Differentiation/VJPCloner.cpp b/lib/SILOptimizer/Differentiation/VJPCloner.cpp index f91175c655e91..7cb715b278553 100644 --- a/lib/SILOptimizer/Differentiation/VJPCloner.cpp +++ b/lib/SILOptimizer/Differentiation/VJPCloner.cpp @@ -449,24 +449,22 @@ class VJPCloner::Implementation final activeResultIndices); assert(!activeParamIndices.empty() && "Parameter indices cannot be empty"); assert(!activeResultIndices.empty() && "Result indices cannot be empty"); - LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params={"; + LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params=("; llvm::interleave( activeParamIndices.begin(), activeParamIndices.end(), [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); - s << "}, results={"; llvm::interleave( + s << "), results=("; llvm::interleave( activeResultIndices.begin(), activeResultIndices.end(), [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); - s << "}\n";); + s << ")\n";); // Form expected indices. - auto numSemanticResults = - ai->getSubstCalleeType()->getNumResults() + - ai->getSubstCalleeType()->getNumIndirectMutatingParameters(); AutoDiffConfig config( IndexSubset::get(getASTContext(), ai->getArgumentsWithoutIndirectResults().size(), activeParamIndices), - IndexSubset::get(getASTContext(), numSemanticResults, + IndexSubset::get(getASTContext(), + ai->getSubstCalleeType()->getNumAutoDiffSemanticResults(), activeResultIndices)); // Emit the VJP. @@ -537,10 +535,11 @@ class VJPCloner::Implementation final for (auto resultIndex : config.resultIndices->getIndices()) { SILType remappedResultType; if (resultIndex >= originalFnTy->getNumResults()) { - auto inoutArgIdx = resultIndex - originalFnTy->getNumResults(); - auto inoutArg = - *std::next(ai->getInoutArguments().begin(), inoutArgIdx); - remappedResultType = inoutArg->getType(); + auto semanticResultArgIdx = resultIndex - originalFnTy->getNumResults(); + auto semanticResultArg = + *std::next(ai->getAutoDiffSemanticResultArguments().begin(), + semanticResultArgIdx); + remappedResultType = semanticResultArg->getType(); } else { remappedResultType = originalFnTy->getResults()[resultIndex] .getSILStorageInterfaceType(); @@ -891,55 +890,57 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() { auto config = witness->getConfig(); // Add pullback parameters based on original result indices. - SmallVector inoutParamIndices; + SmallVector semanticResultParamIndices; for (auto i : range(origTy->getNumParameters())) { auto origParam = origParams[i]; - if (!origParam.isIndirectInOut()) + if (!origParam.isAutoDiffSemanticResult()) continue; - inoutParamIndices.push_back(i); + semanticResultParamIndices.push_back(i); } + for (auto resultIndex : config.resultIndices->getIndices()) { // Handle formal result. if (resultIndex < origTy->getNumResults()) { auto origResult = origTy->getResults()[resultIndex]; origResult = origResult.getWithInterfaceType( origResult.getInterfaceType()->getReducedType(witnessCanGenSig)); - pbParams.push_back(getTangentParameterInfoForOriginalResult( + auto paramInfo = getTangentParameterInfoForOriginalResult( origResult.getInterfaceType() ->getAutoDiffTangentSpace(lookupConformance) ->getType() ->getReducedType(witnessCanGenSig), - origResult.getConvention())); + origResult.getConvention()); + pbParams.push_back(paramInfo); continue; } - // Handle `inout` parameter. + + // Handle semantic result parameter. unsigned paramIndex = 0; - unsigned inoutParamIndex = 0; + unsigned resultParamIndex = 0; for (auto i : range(origTy->getNumParameters())) { auto origParam = origTy->getParameters()[i]; - if (!origParam.isIndirectMutating()) { + if (!origParam.isAutoDiffSemanticResult()) { ++paramIndex; continue; } - if (inoutParamIndex == resultIndex - origTy->getNumResults()) + if (resultParamIndex == resultIndex - origTy->getNumResults()) break; ++paramIndex; - ++inoutParamIndex; + ++resultParamIndex; } - auto inoutParam = origParams[paramIndex]; - auto origResult = inoutParam.getWithInterfaceType( - inoutParam.getInterfaceType()->getReducedType(witnessCanGenSig)); - auto inoutParamTanConvention = - config.isWrtParameter(paramIndex) - ? inoutParam.getConvention() - : ParameterConvention::Indirect_In_Guaranteed; - SILParameterInfo inoutParamTanParam( - origResult.getInterfaceType() - ->getAutoDiffTangentSpace(lookupConformance) - ->getType() - ->getReducedType(witnessCanGenSig), - inoutParamTanConvention); - pbParams.push_back(inoutParamTanParam); + auto resultParam = origParams[paramIndex]; + auto origResult = resultParam.getWithInterfaceType( + resultParam.getInterfaceType()->getReducedType(witnessCanGenSig)); + + auto resultParamTanConvention = resultParam.getConvention(); + if (!config.isWrtParameter(paramIndex)) + resultParamTanConvention = ParameterConvention::Indirect_In_Guaranteed; + + pbParams.emplace_back(origResult.getInterfaceType() + ->getAutoDiffTangentSpace(lookupConformance) + ->getType() + ->getReducedType(witnessCanGenSig), + resultParamTanConvention); } if (pullbackInfo.hasHeapAllocatedContext()) { @@ -961,7 +962,7 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() { // Add pullback results for the requested wrt parameters. for (auto i : config.parameterIndices->getIndices()) { auto origParam = origParams[i]; - if (origParam.isIndirectMutating()) + if (origParam.isAutoDiffSemanticResult()) continue; origParam = origParam.getWithInterfaceType( origParam.getInterfaceType()->getReducedType(witnessCanGenSig)); @@ -997,6 +998,7 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() { original->isRuntimeAccessible()); pullback->setDebugScope(new (module) SILDebugScope(original->getLocation(), pullback)); + return pullback; } diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 8a1f59a260a76..60f57f80ebe85 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -544,11 +544,11 @@ emitDerivativeFunctionReference( for (auto resultIndex : desiredResultIndices->getIndices()) { SILType resultType; if (resultIndex >= originalFnTy->getNumResults()) { - auto inoutParamIdx = resultIndex - originalFnTy->getNumResults(); - auto inoutParam = - *std::next(originalFnTy->getIndirectMutatingParameters().begin(), - inoutParamIdx); - resultType = inoutParam.getSILStorageInterfaceType(); + auto semanticResultParamIdx = resultIndex - originalFnTy->getNumResults(); + auto semanticResultParam = + *std::next(originalFnTy->getAutoDiffSemanticResultsParameters().begin(), + semanticResultParamIdx); + resultType = semanticResultParam.getSILStorageInterfaceType(); } else { resultType = originalFnTy->getResults()[resultIndex] .getSILStorageInterfaceType(); diff --git a/test/AutoDiff/SILGen/inout_differentiability_witness.swift b/test/AutoDiff/SILGen/inout_differentiability_witness.swift index e49b4e92a947d..f146f966a2b14 100644 --- a/test/AutoDiff/SILGen/inout_differentiability_witness.swift +++ b/test/AutoDiff/SILGen/inout_differentiability_witness.swift @@ -17,7 +17,7 @@ func test3(x: Int, y: inout DiffableStruct, z: Float) -> (Float, Float) { return @differentiable(reverse, wrt: y) func test4(x: Int, y: inout DiffableStruct, z: Float) -> Void { } -@differentiable(reverse, wrt: z) +@differentiable(reverse, wrt: (y, z)) func test5(x: Int, y: inout DiffableStruct, z: Float) -> Void { } @differentiable(reverse, wrt: (y, z)) @@ -48,9 +48,9 @@ func test6(x: Int, y: inout DiffableStruct, z: Float) -> (Float, Float) { return // CHECK: } // CHECK-LABEL: differentiability witness for test5(x:y:z:) -// CHECK: sil_differentiability_witness hidden [reverse] [parameters 2] [results 0] @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> () { -// CHECK: jvp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJfUUSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (Float) -> @out DiffableStruct.TangentVector -// CHECK: vjp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJrUUSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@in_guaranteed DiffableStruct.TangentVector) -> Float +// CHECK: sil_differentiability_witness hidden [reverse] [parameters 1 2] [results 0] @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> () { +// CHECK: jvp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJfUSSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector, Float) -> () +// CHECK: vjp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJrUSSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector) -> Float // CHECK: } // CHECK-LABEL: differentiability witness for test6(x:y:z:) diff --git a/test/AutoDiff/SILGen/witness_table.swift b/test/AutoDiff/SILGen/witness_table.swift index 5928bd844d09f..4f631f4067dc2 100644 --- a/test/AutoDiff/SILGen/witness_table.swift +++ b/test/AutoDiff/SILGen/witness_table.swift @@ -12,7 +12,7 @@ protocol Protocol: Differentiable { @differentiable(reverse) var property: Float { get set } - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (self, x)) subscript(_ x: Float, _ y: Float) -> Float { get set } } @@ -82,22 +82,22 @@ struct Struct: Protocol { // CHECK: apply [[VJP_FN]] // CHECK: } - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (self, x)) subscript(_ x: Float, _ y: Float) -> Float { get { x } set {} } - // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SUU : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SUS : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float, @in_guaranteed τ_0_0) -> Float for ) // CHECK: [[ORIG_FN:%.*]] = function_ref @$s13witness_table6StructVyS2f_Sftcig : $@convention(method) (Float, Float, Struct) -> Float - // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [results 0] [[ORIG_FN]] + // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0 2] [results 0] [[ORIG_FN]] // CHECK: [[JVP_FN:%.*]] = differentiable_function_extract [jvp] [[DIFF_FN]] // CHECK: apply [[JVP_FN]] // CHECK: } - // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUU : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUS : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> (Float, @out τ_0_0) for ) // CHECK: [[ORIG_FN:%.*]] = function_ref @$s13witness_table6StructVyS2f_Sftcig : $@convention(method) (Float, Float, Struct) -> Float - // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [results 0] [[ORIG_FN]] + // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0 2] [results 0] [[ORIG_FN]] // CHECK: [[VJP_FN:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]] // CHECK: apply [[VJP_FN]] // CHECK: } @@ -118,10 +118,10 @@ struct Struct: Protocol { // CHECK-NEXT: method #Protocol.property!setter.vjp.SS.: (inout Self) -> (Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvsTW_vjp_SS // CHECK-NEXT: method #Protocol.property!modify: (inout Self) -> () -> () : @$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvMTW // CHECK-NEXT: method #Protocol.subscript!getter: (Self) -> (Float, Float) -> Float : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW -// CHECK-NEXT: method #Protocol.subscript!getter.jvp.SUU.: (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SU -// CHECK-NEXT: method #Protocol.subscript!getter.vjp.SUU.: (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUU +// CHECK-NEXT: method #Protocol.subscript!getter.jvp.SUS.: (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SUS +// CHECK-NEXT: method #Protocol.subscript!getter.vjp.SUS.: (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUS // CHECK-NEXT: method #Protocol.subscript!setter: (inout Self) -> (Float, Float, Float) -> () : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW -// CHECK-NEXT: method #Protocol.subscript!setter.jvp.USUU.: (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_jvp_USUU -// CHECK-NEXT: method #Protocol.subscript!setter.vjp.USUU.: (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_vjp_USUU +// CHECK-NEXT: method #Protocol.subscript!setter.jvp.USUS.: (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_jvp_USUS +// CHECK-NEXT: method #Protocol.subscript!setter.vjp.USUS.: (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_vjp_USUS // CHECK-NEXT: method #Protocol.subscript!modify: (inout Self) -> (Float, Float) -> () : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftciMTW // CHECK: } diff --git a/test/AutoDiff/SILOptimizer/activity_analysis.swift b/test/AutoDiff/SILOptimizer/activity_analysis.swift index 02f75ebba20f4..f2cf4d0d9bf0c 100644 --- a/test/AutoDiff/SILOptimizer/activity_analysis.swift +++ b/test/AutoDiff/SILOptimizer/activity_analysis.swift @@ -407,23 +407,21 @@ func testArrayUninitializedIntrinsicApplyIndirectResult(_ x: T, _ y: T) -> [W struct Mut: Differentiable {} extension Mut { - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (self, x)) mutating func mutatingMethod(_ x: Mut) {} } -// CHECK-LABEL: [AD] Activity info for ${{.*}}3MutV14mutatingMethodyyACF at parameter indices (0) and result indices (0) +// CHECK-LABEL: [AD] Activity info for ${{.*}}3MutV14mutatingMethodyyACF at parameter indices (0, 1) and result indices (0) // CHECK: [VARIED] %0 = argument of bb0 : $Mut -// CHECK: [USEFUL] %1 = argument of bb0 : $*Mut +// CHECK: [ACTIVE] %1 = argument of bb0 : $*Mut -// TODO(TF-985): Find workaround to avoid marking non-wrt `inout` argument as -// active. -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func nonActiveInoutArg(_ nonactive: inout Mut, _ x: Mut) { nonactive.mutatingMethod(x) nonactive = x } -// CHECK-LABEL: [AD] Activity info for ${{.*}}17nonActiveInoutArgyyAA3MutVz_ADtF at parameter indices (1) and result indices (0) +// CHECK-LABEL: [AD] Activity info for ${{.*}}17nonActiveInoutArgyyAA3MutVz_ADtF at parameter indices (0, 1) and result indices (0) // CHECK: [ACTIVE] %0 = argument of bb0 : $*Mut // CHECK: [ACTIVE] %1 = argument of bb0 : $Mut // CHECK: [ACTIVE] %4 = begin_access [modify] [static] %0 : $*Mut @@ -449,14 +447,14 @@ func activeInoutArgMutatingMethod(_ x: Mut) -> Mut { // CHECK: [ACTIVE] %11 = begin_access [read] [static] %2 : $*Mut // CHECK: [ACTIVE] %12 = load [trivial] %11 : $*Mut -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func activeInoutArgMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) { var result = nonactive result.mutatingMethod(x) nonactive = result } -// CHECK-LABEL: [AD] Activity info for ${{.*}}31activeInoutArgMutatingMethodVaryyAA3MutVz_ADtF at parameter indices (1) and result indices (0) +// CHECK-LABEL: [AD] Activity info for ${{.*}}31activeInoutArgMutatingMethodVaryyAA3MutVz_ADtF at parameter indices (0, 1) and result indices (0) // CHECK: [ACTIVE] %0 = argument of bb0 : $*Mut // CHECK: [ACTIVE] %1 = argument of bb0 : $Mut // CHECK: [ACTIVE] %4 = alloc_stack $Mut, var, name "result" @@ -470,14 +468,14 @@ func activeInoutArgMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) { // CHECK: [ACTIVE] %15 = begin_access [modify] [static] %0 : $*Mut // CHECK: [NONE] %19 = tuple () -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func activeInoutArgMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) { var result = (nonactive, x) result.0.mutatingMethod(result.0) nonactive = result.0 } -// CHECK-LABEL: [AD] Activity info for ${{.*}}33activeInoutArgMutatingMethodTupleyyAA3MutVz_ADtF at parameter indices (1) and result indices (0) +// CHECK-LABEL: [AD] Activity info for ${{.*}}33activeInoutArgMutatingMethodTupleyyAA3MutVz_ADtF at parameter indices (0, 1) and result indices (0) // CHECK: [ACTIVE] %0 = argument of bb0 : $*Mut // CHECK: [ACTIVE] %1 = argument of bb0 : $Mut // CHECK: [ACTIVE] %4 = alloc_stack $(Mut, Mut), var, name "result" @@ -499,39 +497,39 @@ func activeInoutArgMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) { // Check `inout` arguments. @differentiable(reverse) -func activeInoutArg(_ x: Float) -> Float { +func activeInoutArg(_ x: inout Float) -> Float { var result = x result += x return result } -// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArg{{.*}} at parameter indices (0) and result indices (0) -// CHECK: [ACTIVE] %0 = argument of bb0 : $Float +// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArg{{.*}} at parameter indices (0) and result indices (0, 1) +// CHECK: [ACTIVE] %0 = argument of bb0 : $*Float // CHECK: [ACTIVE] %2 = alloc_stack $Float, var, name "result" -// CHECK: [ACTIVE] %5 = begin_access [modify] [static] %2 : $*Float +// CHECK: [ACTIVE] %10 = begin_access [modify] [static] %2 : $*Float // CHECK: [NONE] // function_ref static Float.+= infix(_:_:) -// CHECK: [NONE] %7 = apply %6(%5, %0, %4) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> () -// CHECK: [ACTIVE] %9 = begin_access [read] [static] %2 : $*Float -// CHECK: [ACTIVE] %10 = load [trivial] %9 : $*Float +// CHECK: [NONE] %12 = apply %11(%10, %8, %6) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> () +// CHECK: [ACTIVE] %14 = begin_access [read] [static] %2 : $*Float +// CHECK: [ACTIVE] %15 = load [trivial] %14 : $*Float @differentiable(reverse) -func activeInoutArgNonactiveInitialResult(_ x: Float) -> Float { +func activeInoutArgNonactiveInitialResult(_ x: inout Float) -> Float { var result: Float = 1 result += x return result } -// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArgNonactiveInitialResult{{.*}} at parameter indices (0) and result indices (0) -// CHECK: [ACTIVE] %0 = argument of bb0 : $Float +// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArgNonactiveInitialResult{{.*}} at parameter indices (0) and result indices (0, 1) +// CHECK: [ACTIVE] %0 = argument of bb0 : $*Float // CHECK: [ACTIVE] %2 = alloc_stack $Float, var, name "result" // CHECK: [NONE] // function_ref Float.init(_builtinIntegerLiteral:) // CHECK: [USEFUL] %6 = apply %5(%3, %4) : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // CHECK: [USEFUL] %8 = metatype $@thin Float.Type -// CHECK: [ACTIVE] %9 = begin_access [modify] [static] %2 : $*Float +// CHECK: [ACTIVE] %12 = begin_access [modify] [static] %2 : $*Float // CHECK: [NONE] // function_ref static Float.+= infix(_:_:) -// CHECK: [NONE] %11 = apply %10(%9, %0, %8) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> () -// CHECK: [ACTIVE] %13 = begin_access [read] [static] %2 : $*Float -// CHECK: [ACTIVE] %14 = load [trivial] %13 : $*Float +// CHECK: [NONE] %14 = apply %13(%12, %10, %8) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> () +// CHECK: [ACTIVE] %16 = begin_access [read] [static] %2 : $*Float +// CHECK: [ACTIVE] %17 = load [trivial] %16 : $*Float //===----------------------------------------------------------------------===// // Throwing function differentiation (`try_apply`) diff --git a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift index 0f106c55bff56..e149bcb3116db 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift @@ -400,11 +400,11 @@ func activeInoutParamControlFlowComplex(_ array: [Float], _ bool: Bool) -> Float struct Mut: Differentiable {} extension Mut { - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (self, x)) mutating func mutatingMethod(_ x: Mut) {} } -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func nonActiveInoutParam(_ nonactive: inout Mut, _ x: Mut) { nonactive.mutatingMethod(x) } @@ -416,14 +416,14 @@ func activeInoutParamMutatingMethod(_ x: Mut) -> Mut { return result } -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func activeInoutParamMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) { var result = nonactive result.mutatingMethod(x) nonactive = result } -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func activeInoutParamMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) { var result = (nonactive, x) result.0.mutatingMethod(result.0) diff --git a/test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift b/test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift index 5fb4d13407bba..cd29f9019c45c 100644 --- a/test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift +++ b/test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift @@ -89,7 +89,7 @@ func activeInoutParamControlFlow(_ array: [Float]) -> Float { struct X: Differentiable { var x: Float - @differentiable(reverse, wrt: y) + @differentiable(reverse, wrt: (self, y)) mutating func mutate(_ y: X) { self.x = y.x } } @@ -104,7 +104,7 @@ func activeMutatingMethod(_ x: Float) -> Float { struct Mut: Differentiable {} extension Mut { - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (self, x)) mutating func mutatingMethod(_ x: Mut) {} } diff --git a/test/AutoDiff/Sema/derivative_attr_type_checking.swift b/test/AutoDiff/Sema/derivative_attr_type_checking.swift index f3e26e8a32b13..5bd615d1188a1 100644 --- a/test/AutoDiff/Sema/derivative_attr_type_checking.swift +++ b/test/AutoDiff/Sema/derivative_attr_type_checking.swift @@ -769,7 +769,8 @@ struct InoutParameters: Differentiable { } extension InoutParameters { - // expected-note @+1 4 {{'staticMethod(_:rhs:)' defined here}} + // expected-note @+2 {{'staticMethod(_:rhs:)' defined here}} + // expected-note @+1 {{'staticMethod(_:rhs:)' defined here}} static func staticMethod(_ lhs: inout Self, rhs: Self) {} // Test wrt `inout` parameter. @@ -800,33 +801,34 @@ extension InoutParameters { // Test non-wrt `inout` parameter. + // expected-error @+1 {{cannot differentiate void function 'staticMethod(_:rhs:)'}} @derivative(of: staticMethod, wrt: rhs) static func vjpNotWrtInout(_ lhs: inout Self, _ rhs: Self) -> ( value: Void, pullback: (TangentVector) -> TangentVector ) { fatalError() } - // expected-error @+1 {{function result's 'pullback' type does not match 'staticMethod(_:rhs:)'}} + // expected-error @+1 {{cannot differentiate void function 'staticMethod(_:rhs:)'}} @derivative(of: staticMethod, wrt: rhs) static func vjpNotWrtInoutMismatch(_ lhs: inout Self, _ rhs: Self) -> ( - // expected-note @+1 {{'pullback' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}} value: Void, pullback: (inout TangentVector) -> TangentVector ) { fatalError() } + // expected-error @+1 {{cannot differentiate void function 'staticMethod(_:rhs:)'}} @derivative(of: staticMethod, wrt: rhs) static func jvpNotWrtInout(_ lhs: inout Self, _ rhs: Self) -> ( value: Void, differential: (TangentVector) -> TangentVector ) { fatalError() } - // expected-error @+1 {{function result's 'differential' type does not match 'staticMethod(_:rhs:)'}} + // expected-error @+1 {{cannot differentiate void function 'staticMethod(_:rhs:)'}} @derivative(of: staticMethod, wrt: rhs) static func jvpNotWrtInout(_ lhs: inout Self, _ rhs: Self) -> ( - // expected-note @+1 {{'differential' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}} value: Void, differential: (inout TangentVector) -> TangentVector ) { fatalError() } } extension InoutParameters { - // expected-note @+1 4 {{'mutatingMethod' defined here}} + // expected-note @+2 {{'mutatingMethod' defined here}} + // expected-note @+1 {{'mutatingMethod' defined here}} mutating func mutatingMethod(_ other: Self) {} // Test wrt `inout` `self` parameter. @@ -857,27 +859,27 @@ extension InoutParameters { // Test non-wrt `inout` `self` parameter. + // expected-error @+1 {{cannot differentiate void function 'mutatingMethod'}} @derivative(of: mutatingMethod, wrt: other) mutating func vjpNotWrtInout(_ other: Self) -> ( value: Void, pullback: (TangentVector) -> TangentVector ) { fatalError() } - // expected-error @+1 {{function result's 'pullback' type does not match 'mutatingMethod'}} + // expected-error @+1 {{cannot differentiate void function 'mutatingMethod'}} @derivative(of: mutatingMethod, wrt: other) mutating func vjpNotWrtInoutMismatch(_ other: Self) -> ( - // expected-note @+1 {{'pullback' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}} value: Void, pullback: (inout TangentVector) -> TangentVector ) { fatalError() } + // expected-error @+1 {{cannot differentiate void function 'mutatingMethod'}} @derivative(of: mutatingMethod, wrt: other) mutating func jvpNotWrtInout(_ other: Self) -> ( value: Void, differential: (TangentVector) -> TangentVector ) { fatalError() } - // expected-error @+1 {{function result's 'differential' type does not match 'mutatingMethod'}} + // expected-error @+1 {{cannot differentiate void function 'mutatingMethod'}} @derivative(of: mutatingMethod, wrt: other) mutating func jvpNotWrtInoutMismatch(_ other: Self) -> ( - // expected-note @+1 {{'differential' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}} value: Void, differential: (TangentVector, TangentVector) -> Void ) { fatalError() } } diff --git a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift index 53299fe2f639b..eefa81f12a70b 100644 --- a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift @@ -681,7 +681,7 @@ struct InoutParameters: Differentiable { } extension NonDiffableStruct { - // expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'NonDiffableStruct' does not conform to 'Differentiable'}} + // expected-error @+1 {{cannot differentiate void function 'nondiffResult(x:y:z:)'}} @differentiable(reverse) static func nondiffResult(x: Int, y: inout NonDiffableStruct, z: Float) {} diff --git a/test/AutoDiff/compiler_crashers_fixed/issue-55745-noderivative-inout-parameter.swift b/test/AutoDiff/compiler_crashers_fixed/issue-55745-noderivative-inout-parameter.swift index 62a7d2ea018cf..5ab20d9360a82 100644 --- a/test/AutoDiff/compiler_crashers_fixed/issue-55745-noderivative-inout-parameter.swift +++ b/test/AutoDiff/compiler_crashers_fixed/issue-55745-noderivative-inout-parameter.swift @@ -4,15 +4,16 @@ import _Differentiation // https://github.com/apple/swift/issues/55745 // Test protocol witness thunk for `@differentiable` protocol requirement, where -// the required method has a non-wrt `inout` parameter that should be treated as -// a differentiability result. +// the required method has a non-wrt `inout` parameter. protocol Proto { + // expected-error @+1 {{cannot differentiate void function 'method(x:y:)'}} @differentiable(reverse, wrt: x) func method(x: Float, y: inout Float) } struct Struct: Proto { + // expected-error @+1 {{cannot differentiate void function 'method(x:y:)'}} @differentiable(reverse, wrt: x) func method(x: Float, y: inout Float) { y = y * x diff --git a/test/AutoDiff/validation-test/forward_mode_simple.swift b/test/AutoDiff/validation-test/forward_mode_simple.swift index 0b4cb384e368c..90e52c4e71650 100644 --- a/test/AutoDiff/validation-test/forward_mode_simple.swift +++ b/test/AutoDiff/validation-test/forward_mode_simple.swift @@ -1320,29 +1320,14 @@ ForwardModeTests.test("ForceUnwrapping") { } ForwardModeTests.test("NonVariedResult") { - @differentiable(reverse, wrt: x) - func nonWrtInoutParam(_ x: T, _ y: inout T) { - y = x - } - @differentiable(reverse) func wrtInoutParam(_ x: T, _ y: inout T) { y = x } - @differentiable(reverse, wrt: x) - func nonWrtInoutParamNonVaried(_ x: T, _ y: inout T) {} - - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (x, y)) func wrtInoutParamNonVaried(_ x: T, _ y: inout T) {} - @differentiable(reverse) - func variedResultTracked(_ x: Tracked) -> Tracked { - var result: Tracked = 0 - nonWrtInoutParam(x, &result) - return result - } - @differentiable(reverse) func variedResultTracked2(_ x: Tracked) -> Tracked { var result: Tracked = 0 @@ -1352,13 +1337,6 @@ ForwardModeTests.test("NonVariedResult") { @differentiable(reverse) func nonVariedResultTracked(_ x: Tracked) -> Tracked { - var result: Tracked = 0 - nonWrtInoutParamNonVaried(x, &result) - return result - } - - @differentiable(reverse) - func nonVariedResultTracked2(_ x: Tracked) -> Tracked { // expected-warning @+1 {{variable 'result' was never mutated}} var result: Tracked = 0 return result diff --git a/test/AutoDiff/validation-test/inout_parameters.swift b/test/AutoDiff/validation-test/inout_parameters.swift index 800b373ffcec0..95ede7ee938c6 100644 --- a/test/AutoDiff/validation-test/inout_parameters.swift +++ b/test/AutoDiff/validation-test/inout_parameters.swift @@ -191,26 +191,27 @@ InoutParameterAutoDiffTests.test("InoutClassParameter") { } } -// https://github.com/apple/swift/issues/55745 -// Test function with non-wrt `inout` parameter, which should be -// treated as a differentiability result. +// Test function with wrt `inout` parameter, which should be treated as a differentiability result. +// Original issue https://github.com/apple/swift/issues/55745 deals with non-wrt `inout` which +// we explicitly disallow now + protocol P_55745 { - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (x, y)) func method(_ x: Float, _ y: inout Float) - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (x, y)) func genericMethod(_ x: T, _ y: inout T) } InoutParameterAutoDiffTests.test("non-wrt inout parameter") { struct Struct: P_55745 { - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (x, y)) func method(_ x: Float, _ y: inout Float) { y = y * x } - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (x, y)) func genericMethod(_ x: T, _ y: inout T) { y = x }