diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index d70d91ebbdfd4..6d4c8775cd9bb 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -246,16 +246,16 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s, return s; } -/// A semantic function result type: either a formal function result type or -/// an `inout` parameter type. Used in derivative function type calculation. +/// A semantic function result type: either a formal function result type or a +/// semantic result (an `inout` or class-bound) parameter type. Used in +/// derivative function type calculation. struct AutoDiffSemanticFunctionResultType { Type type; unsigned index : 30; - bool isInout : 1; - bool isWrtParam : 1; + bool isSemanticResultParameter : 1; - AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool inout, bool wrt) - : type(t), index(idx), isInout(inout), isWrtParam(wrt) { } + AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool param) + : type(t), index(idx), isSemanticResultParameter(param) { } }; /// Key for caching SIL derivative function types. diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index c19f5a7a9706f..ed4f8e29bb09b 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -3192,6 +3192,12 @@ class AnyFunctionType : public TypeBase { /// Whether the parameter is marked '@noDerivative'. bool isNoDerivative() const { return Flags.isNoDerivative(); } + /// Whether the parameter might be a semantic result for autodiff purposes. + /// This includes inout and class-bound parameters. + bool isAutoDiffSemanticResult() const { + return isInOut() || Ty->getClassOrBoundGenericClass() != nullptr; + } + ValueOwnership getValueOwnership() const { return Flags.getValueOwnership(); } @@ -3509,8 +3515,8 @@ class AnyFunctionType : public TypeBase { /// Preconditions: /// - Parameters corresponding to parameter indices must conform to /// `Differentiable`. - /// - There is one semantic function result type: either the formal original - /// result or an `inout` parameter. It must conform to `Differentiable`. + /// - There are semantic function result type: either the formal original + /// result or a "wrt" semantic result parameter. /// /// Differential typing rules: takes "wrt" parameter derivatives and returns a /// "wrt" result derivative. @@ -3518,10 +3524,7 @@ class AnyFunctionType : public TypeBase { /// - Case 1: original function has no `inout` parameters. /// - Original: `(T0, T1, ...) -> R` /// - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan` - /// - Case 2: original function has a non-wrt `inout` parameter. - /// - Original: `(T0, inout T1, ...) -> Void` - /// - Differential: `(T0.Tan, ...) -> T1.Tan` - /// - Case 3: original function has a wrt `inout` parameter. + /// - Case 2: original function has a wrt `inout` parameter. /// - Original: `(T0, inout T1, ...) -> Void` /// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void` /// @@ -3531,10 +3534,7 @@ class AnyFunctionType : public TypeBase { /// - Case 1: original function has no `inout` parameters. /// - Original: `(T0, T1, ...) -> R` /// - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)` - /// - Case 2: original function has a non-wrt `inout` parameter. - /// - Original: `(T0, inout T1, ...) -> Void` - /// - Pullback: `(T1.Tan) -> (T0.Tan, ...)` - /// - Case 3: original function has a wrt `inout` parameter. + /// - Case 2: original function has a wrt `inout` parameter. /// - Original: `(T0, inout T1, ...) -> Void` /// - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)` /// @@ -4101,6 +4101,10 @@ class SILParameterInfo { return getConvention() == ParameterConvention::Indirect_Inout || getConvention() == ParameterConvention::Indirect_InoutAliasable; } + bool isAutoDiffSemanticResult() const { + return isIndirectMutating() || + getInterfaceType().getClassOrBoundGenericClass() != nullptr; + } bool isPack() const { return isPackParameter(getConvention()); @@ -4836,6 +4840,38 @@ class SILFunctionType final return llvm::count_if(getParameters(), IndirectMutatingParameterFilter()); } + struct AutoDiffSemanticResultsParameterFilter { + bool operator()(SILParameterInfo param) const { + return param.isAutoDiffSemanticResult(); + } + }; + + using AutoDiffSemanticResultsParameterIter = + llvm::filter_iterator; + using AutoDiffSemanticResultsParameterRange = + iterator_range; + + /// A range of SILParameterInfo for all semantic results parameters. + AutoDiffSemanticResultsParameterRange + getAutoDiffSemanticResultsParameters() const { + return llvm::make_filter_range(getParameters(), + AutoDiffSemanticResultsParameterFilter()); + } + + /// Returns the number of semantic results parameters. + unsigned getNumAutoDiffSemanticResultsParameters() const { + return llvm::count_if(getParameters(), AutoDiffSemanticResultsParameterFilter()); + } + + /// Returns the number of function potential semantic results: + /// * Usual results + /// * Inout parameters + /// * Class or class-bound parameters + unsigned getNumAutoDiffSemanticResults() const { + return getNumResults() + getNumAutoDiffSemanticResultsParameters(); + } + /// Get the generic signature that the component types are specified /// in terms of, if any. CanGenericSignature getSubstGenericSignature() const { diff --git a/include/swift/SIL/ApplySite.h b/include/swift/SIL/ApplySite.h index a38e25bb3313a..06888dc4d681c 100644 --- a/include/swift/SIL/ApplySite.h +++ b/include/swift/SIL/ApplySite.h @@ -681,6 +681,18 @@ class FullApplySite : public ApplySite { llvm_unreachable("invalid apply kind"); } + AutoDiffSemanticResultArgumentRange getAutoDiffSemanticResultArguments() const { + switch (getKind()) { + case FullApplySiteKind::ApplyInst: + return cast(getInstruction())->getAutoDiffSemanticResultArguments(); + case FullApplySiteKind::TryApplyInst: + return cast(getInstruction())->getAutoDiffSemanticResultArguments(); + case FullApplySiteKind::BeginApplyInst: + return cast(getInstruction())->getAutoDiffSemanticResultArguments(); + } + llvm_unreachable("invalid apply kind"); + } + /// Returns true if \p op is the callee operand of this apply site /// and not an argument operand. bool isCalleeOperand(const Operand &op) const { diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index 9a7eb8c8dd49d..b929ba0f67b86 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -2785,6 +2785,25 @@ struct OperandToInoutArgument { using InoutArgumentRange = OptionalTransformRange, OperandToInoutArgument>; +/// Predicate used to filter AutoDiffSemanticResultArgumentRange. +struct OperandToAutoDiffSemanticResultArgument { + ArrayRef paramInfos; + OperandValueArrayRef arguments; + OperandToAutoDiffSemanticResultArgument(ArrayRef paramInfos, + OperandValueArrayRef arguments) + : paramInfos(paramInfos), arguments(arguments) { + assert(paramInfos.size() == arguments.size()); + } + llvm::Optional operator()(size_t i) const { + if (paramInfos[i].isAutoDiffSemanticResult()) + return arguments[i]; + return llvm::None; + } +}; + +using AutoDiffSemanticResultArgumentRange = + OptionalTransformRange, OperandToAutoDiffSemanticResultArgument>; + /// The partial specialization of ApplyInstBase for full applications. /// Adds some methods relating to 'self' and to result types that don't /// make sense for partial applications. @@ -2894,6 +2913,16 @@ class ApplyInstBase impl.getArgumentsWithoutIndirectResults())); } + /// Returns all autodiff semantic result (`@inout`, `@inout_aliasable`) + /// arguments passed to the instruction. + AutoDiffSemanticResultArgumentRange getAutoDiffSemanticResultArguments() const { + auto &impl = asImpl(); + return AutoDiffSemanticResultArgumentRange( + indices(getArgumentsWithoutIndirectResults()), + OperandToAutoDiffSemanticResultArgument(impl.getSubstCalleeConv().getParameters(), + impl.getArgumentsWithoutIndirectResults())); + } + bool hasSemantics(StringRef semanticsString) const { return doesApplyCalleeHaveSemantics(getCallee(), semanticsString); } diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index c7a04a16939e7..42a98d0f5a410 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -180,10 +180,9 @@ void AnyFunctionType::getSubsetParameters( } } -void autodiff::getFunctionSemanticResults( - const AnyFunctionType *functionType, - const IndexSubset *parameterIndices, - SmallVectorImpl &resultTypes) { +static void getFunctionFormalResults( + const AnyFunctionType *functionType, + SmallVectorImpl &resultTypes) { auto &ctx = functionType->getASTContext(); // Collect formal result type as a semantic result, unless it is @@ -199,33 +198,36 @@ void autodiff::getFunctionSemanticResults( if (formalResultType->is()) { for (auto elt : formalResultType->castTo()->getElements()) { resultTypes.emplace_back(elt.getType(), resultIdx++, - /*isInout*/ false, /*isWrt*/ false); + /*isParameter*/ false); } } else { resultTypes.emplace_back(formalResultType, resultIdx++, - /*isInout*/ false, /*isWrt*/ false); + /*isParameter*/ false); } } +} - bool addNonWrts = resultTypes.empty(); - - // Collect wrt `inout` parameters as semantic results - // As an extention, collect all (including non-wrt) inouts as results for - // functions returning void. +void autodiff::getFunctionSemanticResults( + const AnyFunctionType *functionType, + const IndexSubset *parameterIndices, + SmallVectorImpl &resultTypes) { + getFunctionFormalResults(functionType, resultTypes); + unsigned numResults = resultTypes.size(); + + // Collect wrt semantic result (`inout` and class references) parameters as + // semantic results auto collectSemanticResults = [&](const AnyFunctionType *functionType, unsigned curryOffset = 0) { for (auto paramAndIndex : enumerate(functionType->getParams())) { - if (!paramAndIndex.value().isInOut()) + if (!paramAndIndex.value().isAutoDiffSemanticResult()) continue; unsigned idx = paramAndIndex.index() + curryOffset; assert(idx < parameterIndices->getCapacity() && "invalid parameter index"); - bool isWrt = parameterIndices->contains(idx); - if (addNonWrts || isWrt) + if (parameterIndices->contains(idx)) resultTypes.emplace_back(paramAndIndex.value().getPlainType(), - resultIdx, /*isInout*/ true, isWrt); - resultIdx += 1; + numResults + idx, /*isParameter*/ true); } }; @@ -245,17 +247,26 @@ autodiff::getFunctionSemanticResultIndices(const AnyFunctionType *functionType, const IndexSubset *parameterIndices) { auto &ctx = functionType->getASTContext(); - SmallVector semanticResults; + SmallVector formalResults, semanticResults; + getFunctionFormalResults(functionType, formalResults); autodiff::getFunctionSemanticResults(functionType, parameterIndices, semanticResults); + unsigned numSemanticResults = formalResults.size(); + if (auto *resultFnType = + functionType->getResult()->getAs()) { + assert(functionType->getNumParams() == 1 && "unexpected function type"); + numSemanticResults += 1 + resultFnType->getNumParams(); + } else { + numSemanticResults += functionType->getNumParams(); + } + SmallVector resultIndices; - unsigned cap = 0; for (const auto& result : semanticResults) { + assert(result.index < numSemanticResults); resultIndices.push_back(result.index); - cap = std::max(cap, result.index + 1U); } - return IndexSubset::get(ctx, cap, resultIndices); + return IndexSubset::get(ctx, numSemanticResults, resultIndices); } IndexSubset * diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index 36f0a9f95f134..5d31060d3e30c 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -5558,7 +5558,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( return llvm::make_error( this, DerivativeFunctionTypeError::Kind::NoSemanticResults); - // Accumulate non-inout result tangent spaces. + // Accumulate non-semantic result tangent spaces. SmallVector resultTanTypes, inoutTanTypes; for (auto i : range(originalResults.size())) { auto originalResult = originalResults[i]; @@ -5577,31 +5577,22 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult, std::make_pair(originalResultType, unsigned(originalResult.index))); - if (!originalResult.isInout) + if (!originalResult.isSemanticResultParameter) resultTanTypes.push_back(resultTan->getType()); - else if (originalResult.isInout && !originalResult.isWrtParam) - inoutTanTypes.push_back(resultTan->getType()); } - // Treat non-wrt inouts as semantic results for functions returning Void - if (resultTanTypes.empty()) - resultTanTypes = inoutTanTypes; - // Compute the result linear map function type. FunctionType *linearMapType; switch (kind) { case AutoDiffLinearMapKind::Differential: { // Compute the differential type, returned by JVP functions. // - // Case 1: original function has no `inout` parameters. + // Case 1: original function has no semantic result parameters. // - Original: `(T0, T1, ...) -> R` // - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan` // - // Case 2: original function has a non-wrt `inout` parameter. - // - Original: `(T0, inout T1, ...) -> Void` - // - Differential: `(T0.Tan, ...) -> T1.Tan` - // - // Case 3: original function has a wrt `inout` parameter. + // Case 2: original function has a wrt semantic result parameter + // (e.g. `inout`) // - Original: `(T0, inout T1, ...) -> Void` // - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void` SmallVector differentialParams; @@ -5644,19 +5635,16 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( case AutoDiffLinearMapKind::Pullback: { // Compute the pullback type, returned by VJP functions. // - // Case 1: original function has no `inout` parameters. + // Case 1: original function has no semantic result parameters. // - Original: `(T0, T1, ...) -> R` // - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)` // - // Case 2: original function has a non-wrt `inout` parameter. - // - Original: `(T0, inout T1, ...) -> Void` - // - Pullback: `(T1.Tan) -> (T0.Tan, ...)` - // - // Case 3: original function has wrt `inout` parameters. + // Case 2: original function has wrt semantic result parameters + // (e.g. an `inout` one) // - Original: `(T0, inout T1, ...) -> R` // - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)` SmallVector pullbackResults; - SmallVector inoutParams; + SmallVector semanticResultParams; for (auto i : range(diffParams.size())) { auto diffParam = diffParams[i]; auto paramType = diffParam.getPlainType(); @@ -5669,10 +5657,10 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( NonDifferentiableDifferentiabilityParameter, std::make_pair(paramType, i)); - if (diffParam.isInOut()) { + if (diffParam.isAutoDiffSemanticResult()) { if (paramType->isVoid()) continue; - inoutParams.push_back(diffParam); + semanticResultParams.push_back(diffParam); continue; } pullbackResults.emplace_back(paramTan->getType()); @@ -5685,7 +5673,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( } else { pullbackResult = TupleType::get(pullbackResults, ctx); } - // First accumulate non-inout results as pullback parameters. + // First accumulate results as pullback parameters. SmallVector pullbackParams; for (auto i : range(resultTanTypes.size())) { auto resultTanType = resultTanTypes[i]; @@ -5693,15 +5681,15 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( pullbackParams.push_back(AnyFunctionType::Param( resultTanType, Identifier(), flags)); } - // Then append inout parameters. - for (auto i : range(inoutParams.size())) { - auto inoutParam = inoutParams[i]; - auto inoutParamType = inoutParam.getPlainType(); - auto inoutParamTan = - inoutParamType->getAutoDiffTangentSpace(lookupConformance); + // Then append semantic result parameters. + for (auto i : range(semanticResultParams.size())) { + auto semanticResultParam = semanticResultParams[i]; + auto semanticResultParamType = semanticResultParam.getPlainType(); + auto semanticResultParamTan = + semanticResultParamType->getAutoDiffTangentSpace(lookupConformance); auto flags = ParameterTypeFlags().withInOut(true); pullbackParams.push_back(AnyFunctionType::Param( - inoutParamTan->getType(), Identifier(), flags)); + semanticResultParamTan->getType(), Identifier(), flags)); } // FIXME: Verify ExtInfo state is correct, not working by accident. FunctionType::ExtInfo info; @@ -5709,6 +5697,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( break; } } + assert(linearMapType && "Expected linear map type"); return linearMapType; } diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index 6853badd31b2c..6212321589846 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -237,30 +237,33 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() { if (resultAndIndex.value().getDifferentiability() != SILResultDifferentiability::NotDifferentiable) resultIndices.push_back(resultAndIndex.index()); - - // Check `inout` parameters. - for (auto inoutParamAndIndex : enumerate(getIndirectMutatingParameters())) - // Currently, an `inout` parameter can either be: + + auto numSemanticResults = getNumResults(); + + // Check semantic results (`inout` or class-bound) parameters. + for (auto resultParamAndIndex : enumerate(getParameters())) { + if (!resultParamAndIndex.value().isAutoDiffSemanticResult()) + continue; + + // Currently, a semantic result parameter can either be: // 1. Both a differentiability parameter and a differentiability result. // 2. `@noDerivative`: neither a differentiability parameter nor a // differentiability result. - // However, there is no way to represent an `inout` parameter that: + // However, there is no way to represent a semantic result parameter that: // 3. Is a differentiability result but not a differentiability parameter. // 4. Is a differentiability parameter but not a differentiability result. // This case is not currently expressible and does not yet have clear use // cases, so supporting it is a non-goal. // - // See TF-1305 for solution ideas. For now, `@noDerivative` `inout` - // parameters are not treated as differentiability results, unless the - // original function has no formal results, in which case all `inout` - // parameters are treated as differentiability results. - if (resultIndices.empty() || - inoutParamAndIndex.value().getDifferentiability() != + // See TF-1305 for solution ideas. For now, `@noDerivative` semantic result + // parameters are not treated as differentiability results. + if (resultParamAndIndex.value().getDifferentiability() != SILParameterDifferentiability::NotDifferentiable) - resultIndices.push_back(getNumResults() + inoutParamAndIndex.index()); + resultIndices.push_back(getNumResults() + resultParamAndIndex.index()); + } + + numSemanticResults += getNumParameters(); - auto numSemanticResults = - getNumResults() + getNumIndirectMutatingParameters(); return IndexSubset::get(getASTContext(), numSemanticResults, resultIndices); } @@ -369,21 +372,24 @@ getDifferentiabilityParameters(SILFunctionType *originalFnTy, /// Collects the semantic results of the given function type in /// `originalResults`. The semantic results are formal results followed by -/// `inout` parameters, in type order. +/// semantic result parameters, in type order. static void -getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices, +getSemanticResults(SILFunctionType *functionType, + IndexSubset *parameterIndices, SmallVectorImpl &originalResults) { // Collect original formal results. originalResults.append(functionType->getResults().begin(), functionType->getResults().end()); - // Collect original `inout` parameters. + // Collect original semantic result parameters. for (auto i : range(functionType->getNumParameters())) { auto param = functionType->getParameters()[i]; - if (!param.isIndirectMutating()) + if (!param.isAutoDiffSemanticResult()) continue; - if (param.getDifferentiability() != SILParameterDifferentiability::NotDifferentiable) - originalResults.emplace_back(param.getInterfaceType(), ResultConvention::Indirect); + if (param.getDifferentiability() != SILParameterDifferentiability::NotDifferentiable && + parameterIndices->contains(i)) + originalResults.emplace_back(param.getInterfaceType(), + ResultConvention::Indirect); } } @@ -578,7 +584,9 @@ static CanSILFunctionType getAutoDiffDifferentialType( ->getAutoDiffTangentSpace(lookupConformance) ->getCanonicalType(), param.getConvention()); - differentialParams.push_back({paramTanType, paramConv}); + if (param.getInterfaceType()->getClassOrBoundGenericClass()) + paramConv = ParameterConvention::Indirect_Inout; + differentialParams.emplace_back(paramTanType, paramConv); } SmallVector differentialResults; for (auto resultIndex : resultIndices->getIndices()) { @@ -594,26 +602,23 @@ static CanSILFunctionType getAutoDiffDifferentialType( ->getAutoDiffTangentSpace(lookupConformance) ->getCanonicalType(), result.getConvention()); - differentialResults.push_back({resultTanType, resultConv}); + differentialResults.emplace_back(resultTanType, resultConv); continue; } - // Handle original `inout` parameters. - auto inoutParamIndex = resultIndex - originalFnTy->getNumResults(); - auto inoutParamIt = std::next( - originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex); - auto paramIndex = - std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt); - // If the original `inout` parameter is a differentiability parameter, then - // it already has a corresponding differential parameter. Skip adding a - // corresponding differential result. + // Handle original semantic result parameters. + auto paramIndex = resultIndex - originalFnTy->getNumResults(); + // If the original semantic result parameter is a differentiability + // parameter, then it already has a corresponding differential + // parameter. Skip adding a corresponding differential result. if (parameterIndices->contains(paramIndex)) continue; - auto inoutParam = originalFnTy->getParameters()[paramIndex]; - auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap( - inoutParam.getInterfaceType(), lookupConformance, + + auto resultParam = originalFnTy->getParameters()[paramIndex]; + auto resultParamTanType = getAutoDiffTangentTypeForLinearMap( + resultParam.getInterfaceType(), lookupConformance, substGenericParams, substReplacements, ctx); - differentialResults.push_back( - {inoutParamTanType, ResultConvention::Indirect}); + differentialResults.emplace_back(resultParamTanType, + ResultConvention::Indirect); } SubstitutionMap substitutions; @@ -734,28 +739,37 @@ static CanSILFunctionType getAutoDiffPullbackType( ->getAutoDiffTangentSpace(lookupConformance) ->getCanonicalType(), origRes.getConvention()); - pullbackParams.push_back({resultTanType, paramConv}); + if (origRes.getInterfaceType().getClassOrBoundGenericClass()) + paramConv = ParameterConvention::Indirect_In_Guaranteed; + + pullbackParams.emplace_back(resultTanType, paramConv); continue; } - // Handle `inout` parameters. - auto inoutParamIndex = resultIndex - originalFnTy->getNumResults(); - auto inoutParamIt = std::next( - originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex); - auto paramIndex = - std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt); - auto inoutParam = originalFnTy->getParameters()[paramIndex]; - // The pullback parameter convention depends on whether the original `inout` - // parameter is a differentiability parameter. + + // Handle original semantic result parameters. + auto paramIndex = resultIndex - originalFnTy->getNumResults(); + auto resultParam = originalFnTy->getParameters()[paramIndex]; + // The pullback parameter convention depends on whether the original + // semantic result parameter is a differentiability parameter. // - If yes, the pullback parameter convention is `@inout`. // - If no, the pullback parameter convention is `@in_guaranteed`. - auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap( - inoutParam.getInterfaceType(), lookupConformance, - substGenericParams, substReplacements, ctx); - bool isWrtInoutParameter = parameterIndices->contains(paramIndex); - auto paramTanConvention = isWrtInoutParameter - ? inoutParam.getConvention() - : ParameterConvention::Indirect_In_Guaranteed; - pullbackParams.push_back({inoutParamTanType, paramTanConvention}); + auto resultParamTanType = getAutoDiffTangentTypeForLinearMap( + resultParam.getInterfaceType(), lookupConformance, + substGenericParams, substReplacements, ctx); + ParameterConvention paramTanConvention = resultParam.getConvention(); + if (resultParam.isIndirectMutating()) { + if (!parameterIndices->contains(paramIndex)) + paramTanConvention = ParameterConvention::Indirect_In_Guaranteed; + } else { + assert(resultParam.getInterfaceType().getClassOrBoundGenericClass() && + "expected class bound parameter"); + if (!parameterIndices->contains(paramIndex)) + paramTanConvention = ParameterConvention::Indirect_In_Guaranteed; + else + paramTanConvention = ParameterConvention::Indirect_Inout; + } + + pullbackParams.emplace_back(resultParamTanType, paramTanConvention); } // Collect pullback results. @@ -763,9 +777,9 @@ static CanSILFunctionType getAutoDiffPullbackType( getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams); SmallVector pullbackResults; for (auto ¶m : diffParams) { - // Skip `inout` parameters, which semantically behave as original results - // and always appear as pullback parameters. - if (param.isIndirectMutating()) + // Skip semantic result parameters, which semantically behave as original + // results and always appear as pullback parameters. + if (param.isAutoDiffSemanticResult()) continue; auto paramTanType = getAutoDiffTangentTypeForLinearMap( param.getInterfaceType(), lookupConformance, @@ -898,6 +912,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType( origTypeOfAbstraction, TC); break; } + // Compute the derivative function parameters. SmallVector newParameters; newParameters.reserve(constrainedOriginalFnTy->getNumParameters()); @@ -4091,6 +4106,38 @@ static llvm::cl::opt DisableConstantInfoCache("sil-disable-typelowering-constantinfo-cache", llvm::cl::init(false)); +static IndexSubset * +getLoweredResultIndices(const SILFunctionType *functionType, + const IndexSubset *parameterIndices) { + SmallVector resultIndices; + + // Check formal results. + for (auto resultAndIndex : enumerate(functionType->getResults())) + if (resultAndIndex.value().getDifferentiability() != + SILResultDifferentiability::NotDifferentiable) + resultIndices.push_back(resultAndIndex.index()); + + auto numResults = functionType->getNumResults(); + + // Collect semantic result parameters. + for (auto resultParamAndIndex + : enumerate(functionType->getParameters())) { + if (!resultParamAndIndex.value().isAutoDiffSemanticResult()) + continue; + + if (resultParamAndIndex.value().getDifferentiability() != + SILParameterDifferentiability::NotDifferentiable && + parameterIndices->contains(resultParamAndIndex.index())) + resultIndices.push_back(numResults + resultParamAndIndex.index()); + } + + numResults += functionType->getNumParameters(); + + return IndexSubset::get(functionType->getASTContext(), + numResults, resultIndices); +} + + const SILConstantInfo & TypeConverter::getConstantInfo(TypeExpansionContext expansion, SILDeclRef constant) { @@ -4149,11 +4196,9 @@ TypeConverter::getConstantInfo(TypeExpansionContext expansion, // Use it to compute lowered derivative function type. auto *loweredParamIndices = autodiff::getLoweredParameterIndices( derivativeId->getParameterIndices(), formalInterfaceType); - auto numResults = - origFnConstantInfo.SILFnType->getNumResults() + - origFnConstantInfo.SILFnType->getNumIndirectMutatingParameters(); - auto *loweredResultIndices = IndexSubset::getDefault( - M.getASTContext(), numResults, /*includeAll*/ true); + auto *loweredResultIndices + = getLoweredResultIndices(origFnConstantInfo.SILFnType, loweredParamIndices); + silFnType = origFnConstantInfo.SILFnType->getAutoDiffDerivativeFunctionType( loweredParamIndices, loweredResultIndices, derivativeId->getKind(), *this, LookUpConformanceInModule(&M)); diff --git a/lib/SIL/Parser/ParseSIL.cpp b/lib/SIL/Parser/ParseSIL.cpp index 68c777b39a085..bb93cf1997a1c 100644 --- a/lib/SIL/Parser/ParseSIL.cpp +++ b/lib/SIL/Parser/ParseSIL.cpp @@ -2515,7 +2515,7 @@ static bool parseSILDifferentiabilityWitnessConfigAndFunction( P.Context, origFnType->getNumParameters(), rawParameterIndices); auto *resultIndices = IndexSubset::get(P.Context, origFnType->getNumResults() + - origFnType->getNumIndirectMutatingParameters(), + origFnType->getNumParameters(), rawResultIndices); resultConfig = AutoDiffConfig(parameterIndices, resultIndices, witnessGenSig); return false; @@ -6398,7 +6398,7 @@ bool SILParser::parseSpecificSILInstruction(SILBuilder &B, P.Context, fnType->getNumParameters(), rawParameterIndices); auto *resultIndices = IndexSubset::get( P.Context, - fnType->getNumResults() + fnType->getNumIndirectMutatingParameters(), + fnType->getNumResults() + fnType->getNumParameters(), rawResultIndices); if (forwardingOwnership != OwnershipKind::None) { ResultVal = B.createDifferentiableFunction( diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index c2cff48ad0ace..4045b05134877 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -5275,11 +5275,13 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk( .getIdentifier( mangler.mangleAutoDiffDerivativeFunction(originalAFD, kind, config)) .str(); - auto loc = customDerivativeFn->getLocation(); + SILGenFunctionBuilder fb(*this); // Derivative thunks have the same linkage as the original function, stripping // external. + // FIXME: Currently class-scoped thunks are not allowed. Do we need to + // introduce special "derivative thunk" to allow this? auto linkage = stripExternalFromLinkage(originalFn->getLinkage()); auto *thunk = fb.getOrCreateFunction( loc, name, linkage, thunkFnTy, IsBare, IsNotTransparent, @@ -5287,8 +5289,8 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk( customDerivativeFn->isDynamicallyReplaceable(), customDerivativeFn->isDistributed(), customDerivativeFn->isRuntimeAccessible(), - customDerivativeFn->getEntryCount(), IsThunk, - customDerivativeFn->getClassSubclassScope()); + customDerivativeFn->getEntryCount(), + IsThunk, SubclassScope::NotApplicable); // This thunk may be publicly exposed and cannot be transparent. // Instead, mark it as "always inline" for optimization. thunk->setInlineStrategy(AlwaysInline); diff --git a/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp b/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp index 0cd474d74b521..148d3b90bc69c 100644 --- a/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp +++ b/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp @@ -85,8 +85,10 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, LLVM_DEBUG({ auto &s = getADDebugStream(); s << "Outputs in @" << function.getName() << ":\n"; - for (auto val : outputValues) - s << val << '\n'; + for (auto val : outputValues) { + if (val) + s << val << '\n'; + } }); // Propagate variedness starting from the inputs. @@ -104,7 +106,8 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, auto output = outputAndIdx.value(); unsigned i = outputAndIdx.index(); usefulValueSets.push_back({}); - setUsefulAndPropagateToOperands(output, i); + if (output) + setUsefulAndPropagateToOperands(output, i); } } @@ -133,13 +136,13 @@ void DifferentiableActivityInfo::propagateVaried( // Skip non-varying callees. if (isWithoutDerivative(applySite.getCallee())) return; - // If operand is varied, set all direct/indirect results and inout arguments - // as varied. + // If operand is varied, set all direct/indirect results and semantic result + // arguments as varied. if (isVaried(operand->get(), i)) { for (auto indRes : applySite.getIndirectSILResults()) propagateVariedInwardsThroughProjections(indRes, i); - for (auto inoutArg : applySite.getInoutArguments()) - propagateVariedInwardsThroughProjections(inoutArg, i); + for (auto semresArg : applySite.getAutoDiffSemanticResultArguments()) + propagateVariedInwardsThroughProjections(semresArg, i); // Propagate variedness to apply site direct results. forEachApplyDirectResult(applySite, [&](SILValue directResult) { setVariedAndPropagateToUsers(directResult, i); @@ -297,6 +300,7 @@ void DifferentiableActivityInfo::setUsefulAndPropagateToOperands( // Skip already-useful values to prevent infinite recursion. if (isUseful(value, dependentVariableIndex)) return; + if (value->getType().isAddress() || value->getType().getClassOrBoundGenericClass()) { propagateUsefulThroughAddress(value, dependentVariableIndex); @@ -333,12 +337,15 @@ void DifferentiableActivityInfo::propagateUseful( // Propagate usefulness for the given instruction: mark operands as useful and // recursively propagate usefulness to defining instructions of operands. auto i = dependentVariableIndex; + // Handle full apply sites: `apply`, `try_apply`, and `begin_apply`. if (FullApplySite::isa(inst)) { FullApplySite applySite(inst); + // If callee is non-varying, skip. if (isWithoutDerivative(applySite.getCallee())) return; + // If callee is a `modify` accessor, propagate usefulness through yielded // addresses. Semantically, yielded addresses can be viewed as a projection // into the `inout` argument. @@ -350,10 +357,12 @@ void DifferentiableActivityInfo::propagateUseful( for (auto yield : bai->getYieldedValues()) setUsefulAndPropagateToOperands(yield, i); } + // Propagate usefulness through apply site arguments. for (auto arg : applySite.getArgumentsWithoutIndirectResults()) setUsefulAndPropagateToOperands(arg, i); } + // Handle store-like instructions: // `store`, `store_borrow`, `copy_addr`, `unconditional_checked_cast` #define PROPAGATE_USEFUL_THROUGH_STORE(INST) \ @@ -390,14 +399,17 @@ void DifferentiableActivityInfo::propagateUsefulThroughAddress( // Skip already-useful values to prevent infinite recursion. if (isUseful(value, dependentVariableIndex)) return; + setUseful(value, dependentVariableIndex); if (auto *inst = value->getDefiningInstruction()) propagateUseful(inst, dependentVariableIndex); + // Recursively propagate usefulness through users that are projections or // `begin_access` instructions. for (auto use : value->getUses()) { // Propagate usefulness through user's operands. propagateUseful(use->getUser(), dependentVariableIndex); + for (auto res : use->getUser()->getResults()) { #define SKIP_NODERIVATIVE(INST) \ if (auto *projInst = dyn_cast(res)) \ @@ -410,6 +422,12 @@ void DifferentiableActivityInfo::propagateUsefulThroughAddress( if (Projection::isAddressProjection(res) || isa(res) || isa(res)) propagateUsefulThroughAddress(res, dependentVariableIndex); + // class values have reference semantics. Therefore load / load_borrow of + // useful $*Class should produce useful value + else if (auto *li = dyn_cast(res)) { + if (li->getType().getClassOrBoundGenericClass()) + propagateUsefulThroughAddress(li, dependentVariableIndex); + } } } } diff --git a/lib/SILOptimizer/Differentiation/Common.cpp b/lib/SILOptimizer/Differentiation/Common.cpp index a8f749b3187c3..ad7c42ac96bb6 100644 --- a/lib/SILOptimizer/Differentiation/Common.cpp +++ b/lib/SILOptimizer/Differentiation/Common.cpp @@ -131,6 +131,7 @@ void forEachApplyDirectResult( void collectAllFormalResultsInTypeOrder(SILFunction &function, SmallVectorImpl &results) { + LLVM_DEBUG(llvm::dbgs() << "Calculating results for: " << function.getName() << "\n"); SILFunctionConventions convs(function.getLoweredFunctionType(), function.getModule()); auto indResults = function.getIndirectResults(); @@ -147,14 +148,15 @@ void collectAllFormalResultsInTypeOrder(SILFunction &function, for (auto &resInfo : convs.getResults()) results.push_back(resInfo.isFormalDirect() ? dirResults[dirResIdx++] : indResults[indResIdx++]); - // Treat `inout` parameters as semantic results. - // Append `inout` parameters after formal results. + // Treat semantic result parameters as semantic results. + // Append them after formal results. for (auto i : range(convs.getNumParameters())) { auto paramInfo = convs.getParameters()[i]; - if (!paramInfo.isIndirectMutating()) - continue; - auto *argument = function.getArgumentsWithoutIndirectResults()[i]; - results.push_back(argument); + if (paramInfo.isAutoDiffSemanticResult()) { + auto *argument = function.getArgumentsWithoutIndirectResults()[i]; + results.push_back(argument); + } else + results.push_back(SILValue()); } } @@ -190,6 +192,7 @@ void collectMinimalIndicesForFunctionCall( SmallVectorImpl &resultIndices) { auto calleeFnTy = ai->getSubstCalleeType(); auto calleeConvs = ai->getSubstCalleeConv(); + // Parameter indices are indices (in the callee type signature) of parameter // arguments that are varied or are arguments. // Record all parameter indices in type order. @@ -199,6 +202,7 @@ void collectMinimalIndicesForFunctionCall( paramIndices.push_back(currentParamIdx); ++currentParamIdx; } + // Result indices are indices (in the callee type signature) of results that // are useful. SmallVector directResults; @@ -226,22 +230,21 @@ void collectMinimalIndicesForFunctionCall( ++indResIdx; } } - // Record all `inout` parameters as results. - auto inoutParamResultIndex = calleeFnTy->getNumResults(); + + // Record all semantic result parameters as results. + auto semanticResultParamOffset = calleeFnTy->getNumResults(); for (auto ¶mAndIdx : enumerate(calleeConvs.getParameters())) { auto ¶m = paramAndIdx.value(); - if (!param.isIndirectMutating()) + if (!param.isAutoDiffSemanticResult()) continue; unsigned idx = paramAndIdx.index() + calleeFnTy->getNumIndirectFormalResults(); - auto inoutArg = ai->getArgument(idx); - results.push_back(inoutArg); - resultIndices.push_back(inoutParamResultIndex++); + results.push_back(ai->getArgument(idx)); + resultIndices.push_back(semanticResultParamOffset + paramAndIdx.index()); } + // Make sure the function call has active results. #ifndef NDEBUG - auto numResults = calleeFnTy->getNumResults() + - calleeFnTy->getNumIndirectMutatingParameters(); - assert(results.size() == numResults); + assert(results.size() == calleeFnTy->getNumAutoDiffSemanticResults()); assert(llvm::any_of(results, [&](SILValue result) { return activityInfo.isActive(result, parentConfig); })); diff --git a/lib/SILOptimizer/Differentiation/JVPCloner.cpp b/lib/SILOptimizer/Differentiation/JVPCloner.cpp index ca58da53baf3e..9a31381001ff7 100644 --- a/lib/SILOptimizer/Differentiation/JVPCloner.cpp +++ b/lib/SILOptimizer/Differentiation/JVPCloner.cpp @@ -103,7 +103,7 @@ class JVPCloner::Implementation final TangentBuilder diffLocalAllocBuilder; /// Stack buffers allocated for storing local tangent values. - SmallVector differentialLocalAllocations; + SmallVector differentialLocalAllocations; /// Mapping from original blocks to differential values. Used to build /// differential struct instances. @@ -301,7 +301,8 @@ class JVPCloner::Implementation final /// original buffer does not already have a tangent buffer. void setTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer, SILValue tangentBuffer) { - assert(originalBuffer->getType().isAddress()); + assert(originalBuffer->getType().isAddress() || + originalBuffer->getType().getClassOrBoundGenericClass()); auto insertion = bufferMap.try_emplace({origBB, originalBuffer}, tangentBuffer); assert(insertion.second && "Tangent buffer already exists"); @@ -311,7 +312,8 @@ class JVPCloner::Implementation final /// Returns the tangent buffer for the original buffer. Asserts that the /// original buffer has a tangent buffer. SILValue &getTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer) { - assert(originalBuffer->getType().isAddress()); + assert(originalBuffer->getType().isAddress() || + originalBuffer->getType().getClassOrBoundGenericClass()); assert(originalBuffer->getFunction() == original); auto it = bufferMap.find({origBB, originalBuffer}); assert(it != bufferMap.end() && "Tangent buffer should already exist"); @@ -436,7 +438,7 @@ class JVPCloner::Implementation final TypeSubstCloner::visitInstructionsInBlock(bb); } - // If an `apply` has active results or active inout parameters, replace it + // If an `apply` has active results or active semantic result parameters, replace it // with an `apply` of its JVP. void visitApplyInst(ApplyInst *ai) { bool shouldDifferentiate = @@ -487,13 +489,10 @@ class JVPCloner::Implementation final s << "}\n";); // Form expected indices. - auto numResults = - ai->getSubstCalleeType()->getNumResults() + - ai->getSubstCalleeType()->getNumIndirectMutatingParameters(); + auto numParams = ai->getArgumentsWithoutIndirectResults().size(); + auto numResults = ai->getSubstCalleeType()->getNumResults() + numParams; AutoDiffConfig config( - IndexSubset::get(getASTContext(), - ai->getArgumentsWithoutIndirectResults().size(), - activeParamIndices), + IndexSubset::get(getASTContext(), numParams, activeParamIndices), IndexSubset::get(getASTContext(), numResults, activeResultIndices)); // Emit the JVP. @@ -576,10 +575,9 @@ class JVPCloner::Implementation final for (auto resultIndex : config.resultIndices->getIndices()) { SILType remappedResultType; if (resultIndex >= originalFnTy->getNumResults()) { - auto inoutArgIdx = resultIndex - originalFnTy->getNumResults(); - auto inoutArg = - *std::next(ai->getInoutArguments().begin(), inoutArgIdx); - remappedResultType = inoutArg->getType(); + auto semanticResArgIdx = resultIndex - originalFnTy->getNumResults(); + auto semanticResArg = ai->getArgumentsWithoutIndirectResults()[semanticResArgIdx]; + remappedResultType = semanticResArg->getType(); } else { remappedResultType = originalFnTy->getResults()[resultIndex] .getSILStorageInterfaceType(); @@ -775,17 +773,33 @@ class JVPCloner::Implementation final CLONE_AND_EMIT_TANGENT(BeginBorrow, bbi) { auto &diffBuilder = getDifferentialBuilder(); auto loc = bbi->getLoc(); - auto tanVal = materializeTangent(getTangentValue(bbi->getOperand()), loc); - auto tanValBorrow = diffBuilder.emitBeginBorrowOperation(loc, tanVal); - setTangentValue(bbi->getParent(), bbi, - makeConcreteTangentValue(tanValBorrow)); + auto *bb = bbi->getParent(); + if (bbi->getType().getClassOrBoundGenericClass()) { + //setTangentBuffer(bb, bbi, getTangentBuffer(bb, bbi->getOperand())); + auto tanBufAcc = diffBuilder.createBeginAccess( + loc, getTangentBuffer(bb, bbi->getOperand()), + SILAccessKind::Modify, SILAccessEnforcement::Static, false, false); + setTangentBuffer(bb, bbi, tanBufAcc); + } else { + auto tanVal = materializeTangent(getTangentValue(bbi->getOperand()), loc); + auto tanValBorrow = diffBuilder.emitBeginBorrowOperation(loc, tanVal); + setTangentValue(bbi->getParent(), bbi, + makeConcreteTangentValue(tanValBorrow)); + } } CLONE_AND_EMIT_TANGENT(EndBorrow, ebi) { auto &diffBuilder = getDifferentialBuilder(); auto loc = ebi->getLoc(); - auto tanVal = materializeTangent(getTangentValue(ebi->getOperand()), loc); - diffBuilder.emitEndBorrowOperation(loc, tanVal); + auto *bb = ebi->getParent(); + + if (ebi->getOperand()->getType().getClassOrBoundGenericClass()) { + auto tanBuf = getTangentBuffer(bb, ebi->getOperand()); + diffBuilder.createEndAccess(loc, tanBuf, false); + } else { + auto tanVal = materializeTangent(getTangentValue(ebi->getOperand()), loc); + diffBuilder.emitEndBorrowOperation(loc, tanVal); + } } CLONE_AND_EMIT_TANGENT(DestroyValue, dvi) { @@ -809,6 +823,7 @@ class JVPCloner::Implementation final /// Tangent: tan[y] = load tan[x] void visitLoadInst(LoadInst *li) { TypeSubstCloner::visitLoadInst(li); + auto *bb = li->getParent(); // If an active buffer is loaded with take to a non-active value, destroy // the active buffer's tangent buffer. if (!differentialInfo.shouldDifferentiateInstruction(li)) { @@ -819,10 +834,14 @@ class JVPCloner::Implementation final getDifferentialBuilder().emitDestroyOperation(tanBuf.getLoc(), tanBuf); } return; + } else if (li->getType().getClassOrBoundGenericClass()) { // Treat load of class value as a projection + setTangentBuffer(bb, li, + getTangentBuffer(bb, li->getOperand())); + return; } + // Otherwise, do standard differential cloning. auto &diffBuilder = getDifferentialBuilder(); - auto *bb = li->getParent(); auto loc = li->getLoc(); auto tanBuf = getTangentBuffer(bb, li->getOperand()); auto tanVal = diffBuilder.emitLoadValueOperation( @@ -992,6 +1011,20 @@ class JVPCloner::Implementation final setTangentBuffer(asi->getParent(), asi, mappedAllocStackInst); } + /// Handle `alloc_ref` instruction. + /// Original: y = alloc_ref $C + /// Tangent: tan[y] = alloc_stack $C.Tangent + CLONE_AND_EMIT_TANGENT(AllocRef, ari) { + auto &diffBuilder = getDifferentialBuilder(); + auto *mappedAllocStackInst = diffBuilder.createAllocStack( + ari->getLoc(), getRemappedTangentType(ari->getType())); + diffBuilder.emitZeroIntoBuffer(ari->getLoc(), + mappedAllocStackInst, + IsInitialization_t::IsInitialization); + differentialLocalAllocations.push_back(mappedAllocStackInst); + setTangentBuffer(ari->getParent(), ari, mappedAllocStackInst); + } + /// Handle `dealloc_stack` instruction. /// Original: dealloc_stack x /// Tangent: dealloc_stack tan[x] @@ -1082,6 +1115,35 @@ class JVPCloner::Implementation final setTangentBuffer(bb, seai, tangentInst); } + /// Handle `ref_element_addr` instruction. + /// Original: y = ref_element_addr C, #field + /// Tangent: tan[y] = struct_element_addr *tan[x], #field' + /// ^~~~~~~ + /// field in tangent space corresponding to #field + CLONE_AND_EMIT_TANGENT(RefElementAddr, reai) { + assert(!reai->getField()->getAttrs().hasAttribute() && + "`ref_element_addr` with `@noDerivative` field should not be " + "differentiated; activity analysis should not marked as varied."); + auto diffBuilder = getDifferentialBuilder(); + auto *bb = reai->getParent(); + auto loc = getValidLocation(reai); + // Find the corresponding field in the tangent space. + auto classOperandType = + remapSILTypeInDifferential(reai->getOperand()->getType()).getASTType(); + auto *tanField = + getTangentStoredProperty(context, reai, classOperandType, invoker); + if (!tanField) { + errorOccurred = true; + return; + } + // Emit tangent `struct_element_addr`. + auto tanOperand = getTangentBuffer(bb, reai->getOperand()); + auto tangentInst = + diffBuilder.createStructElementAddr(loc, tanOperand, tanField); + // Update tangent buffer map for `ref_element_addr`. + setTangentBuffer(bb, reai, tangentInst); + } + /// Handle `tuple` instruction. /// Original: y = tuple (x0, x1, x2, ...) /// Tangent: tan[y] = tuple (tan[x0], tan[x1], tan[x2], ...) @@ -1235,7 +1297,8 @@ class JVPCloner::Implementation final origCalleeType->getDifferentiabilityParameterIndices(); if (actualOrigCalleeIndices->contains(i)) { SILValue tanParam; - if (origArg->getType().isObject()) { + if (origArg->getType().isObject() && + !origArg->getType().getClassOrBoundGenericClass()) { tanParam = emitZeroDirect( getRemappedTangentType(origArg->getType()).getASTType(), loc); diffArgs.push_back(tanParam); @@ -1252,7 +1315,8 @@ class JVPCloner::Implementation final // getting its tangent value. else { SILValue tanParam; - if (origArg->getType().isObject()) { + if (origArg->getType().isObject() && + !origArg->getType().getClassOrBoundGenericClass()) { tanParam = materializeTangent(getTangentValue(origArg), loc); } else { tanParam = getTangentBuffer(ai->getParent(), origArg); @@ -1279,35 +1343,40 @@ class JVPCloner::Implementation final diffBuilder.createApply(loc, differential, SubstitutionMap(), diffArgs); diffBuilder.emitDestroyValueOperation(loc, differential); - // Get the original `apply` results. + // Get the original `apply` results in result indices order SmallVector origDirectResults; forEachApplyDirectResult(ai, [&](SILValue directResult) { origDirectResults.push_back(directResult); }); SmallVector origAllResults; collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults); - - // Get the callee differential `apply` results. + for (auto i : range(ai->getSubstCalleeConv().getNumParameters())) { + auto paramInfo = ai->getSubstCalleeConv().getParameters()[i]; + origAllResults.push_back(paramInfo.isAutoDiffSemanticResult() ? + ai->getArgumentsWithoutIndirectResults()[i] : SILValue()); + } + + // Get the callee differential `apply` results in type order SmallVector differentialDirectResults; extractAllElements(differentialCall, getDifferentialBuilder(), differentialDirectResults); SmallVector differentialAllResults; collectAllActualResultsInTypeOrder( differentialCall, differentialDirectResults, differentialAllResults); - for (auto inoutArg : ai->getInoutArguments()) - origAllResults.push_back(inoutArg); - for (auto inoutArg : differentialCall->getInoutArguments()) - differentialAllResults.push_back(inoutArg); - assert(applyConfig.resultIndices->getNumIndices() == - differentialAllResults.size()); + for (auto semResultArg : differentialCall->getAutoDiffSemanticResultArguments()) + differentialAllResults.push_back(semResultArg); + assert(applyConfig.resultIndices->getNumIndices() == differentialAllResults.size()); // Set tangent values for original `apply` results. unsigned differentialResultIndex = 0; for (auto resultIndex : applyConfig.resultIndices->getIndices()) { + assert(resultIndex < origAllResults.size()); auto origResult = origAllResults[resultIndex]; + assert(origResult && "expected non-trivial result"); auto differentialResult = differentialAllResults[differentialResultIndex++]; - if (origResult->getType().isObject()) { + if (origResult->getType().isObject() && + !origResult->getType().getClassOrBoundGenericClass()) { if (!origResult->getType().is()) { setTangentValue(bb, origResult, makeConcreteTangentValue(differentialResult)); @@ -1347,6 +1416,15 @@ class JVPCloner::Implementation final retElts.push_back(tanVal); } + // Deallocate local allocations. + for (auto alloc : differentialLocalAllocations) { + // Assert that local allocations have at least one use. + // Buffers should not be allocated needlessly. + assert(!alloc->use_empty()); + diffBuilder.emitDestroyAddrAndFold(diffLoc, alloc); + diffBuilder.createDeallocStack(diffLoc, alloc); + } + diffBuilder.createReturn(diffLoc, joinElements(retElts, diffBuilder, diffLoc)); } @@ -1537,7 +1615,7 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() { } // Initialize tangent mapping for original indirect results and non-wrt - // `inout` parameters. The tangent buffers of these address values are + // semantic result parameters. The tangent buffers of these address values are // differential indirect results. // Collect original results. @@ -1564,20 +1642,16 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() { diffLoc); continue; } - // Handle original non-wrt `inout` parameter. - // Only original *non-wrt* `inout` parameters have corresponding + // Handle original non-wrt semantic result parameter. + // Only original *non-wrt* semantic result parameters have corresponding // differential indirect results. - auto inoutParamIndex = resultIndex - origFnTy->getNumResults(); - auto inoutParamIt = std::next( - origFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex); - auto paramIndex = - std::distance(origFnTy->getParameters().begin(), &*inoutParamIt); + auto paramIndex = resultIndex - origFnTy->getNumResults(); if (getConfig().parameterIndices->contains(paramIndex)) continue; auto diffIndResult = diffIndResults[differentialIndirectResultIndex++]; setTangentBuffer(origEntry, origResult, diffIndResult); - // Original `inout` parameters are initialized, so their tangent buffers - // must also be initialized. + // Original semantic result parameters are initialized, so their tangent + // buffers must also be initialized. emitZeroIndirect(diffIndResult->getType().getASTType(), diffIndResult, diffLoc); } @@ -1621,19 +1695,15 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() { ->getReducedType(witnessCanGenSig), origResult.getConvention())); } else { - // Handle original `inout` parameter. - auto inoutParamIndex = resultIndex - origTy->getNumResults(); - auto inoutParamIt = std::next( - origTy->getIndirectMutatingParameters().begin(), inoutParamIndex); - auto paramIndex = - std::distance(origTy->getParameters().begin(), &*inoutParamIt); - // If the original `inout` parameter is a differentiability parameter, + // Handle semantic result parameter. + auto paramIndex = resultIndex - origTy->getNumResults(); + // If the original semantic result parameter is a differentiability parameter, // then it already has a corresponding differential parameter. Do not add // a corresponding differential result. if (config.parameterIndices->contains(paramIndex)) continue; - auto inoutParam = origTy->getParameters()[paramIndex]; - auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace( + auto resultParam = origTy->getParameters()[paramIndex]; + auto paramTan = resultParam.getInterfaceType()->getAutoDiffTangentSpace( lookupConformance); assert(paramTan && "Parameter type does not have a tangent space?"); dfResults.push_back( @@ -1646,12 +1716,15 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() { auto origParam = origParams[i]; origParam = origParam.getWithInterfaceType( origParam.getInterfaceType()->getReducedType(witnessCanGenSig)); - dfParams.push_back( - SILParameterInfo(origParam.getInterfaceType() - ->getAutoDiffTangentSpace(lookupConformance) - ->getType() - ->getReducedType(witnessCanGenSig), - origParam.getConvention())); + + SILParameterInfo paramInfo(origParam.getInterfaceType() + ->getAutoDiffTangentSpace(lookupConformance) + ->getType() + ->getReducedType(witnessCanGenSig), + origParam.getConvention()); + if (origParam.getInterfaceType()->getClassOrBoundGenericClass()) + paramInfo = paramInfo.getWithConvention(ParameterConvention::Indirect_Inout); + dfParams.push_back(paramInfo); } // Accept a differential struct in the differential parameter list. This is diff --git a/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp b/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp index f469607de2758..6bc78f08f0bef 100644 --- a/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp +++ b/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp @@ -177,7 +177,7 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { auto hasActiveResults = llvm::any_of(allResults, [&](SILValue res) { return activityInfo.isActive(res, config); }); - bool hasActiveInoutArgument = false; + bool hasActiveSemanticResultArgument = false; bool hasActiveArguments = false; auto numIndirectResults = ai->getNumIndirectResults(); for (auto argIdx : range(ai->getSubstCalleeConv().getNumParameters())) { @@ -186,13 +186,13 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { hasActiveArguments = true; auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( numIndirectResults + argIdx); - if (paramInfo.isIndirectMutating()) - hasActiveInoutArgument = true; + if (paramInfo.isAutoDiffSemanticResult()) + hasActiveSemanticResultArgument = true; } } if (!hasActiveArguments) return {}; - if (!hasActiveResults && !hasActiveInoutArgument) + if (!hasActiveResults && !hasActiveSemanticResultArgument) return {}; // Compute differentiability parameters. @@ -213,9 +213,8 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { ai->getArgumentsWithoutIndirectResults().size(), activeParamIndices); } // Compute differentiability results. - auto numResults = remappedOrigFnSubstTy->getNumResults() + - remappedOrigFnSubstTy->getNumIndirectMutatingParameters(); - auto *results = IndexSubset::get(original->getASTContext(), numResults, + auto *results = IndexSubset::get(original->getASTContext(), + remappedOrigFnSubstTy->getNumParameters() + remappedOrigFnSubstTy->getNumResults(), activeResultIndices); // Create autodiff indices for the `apply` instruction. AutoDiffConfig applyConfig(parameters, results); @@ -234,10 +233,10 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { for (auto resultIndex : applyConfig.resultIndices->getIndices()) { SILType remappedResultType; if (resultIndex >= origFnTy->getNumResults()) { - auto inoutArgIdx = resultIndex - origFnTy->getNumResults(); - auto inoutArg = - *std::next(ai->getInoutArguments().begin(), inoutArgIdx); - remappedResultType = inoutArg->getType(); + auto semanticResultArgIdx = resultIndex - origFnTy->getNumResults(); + const auto ¶m = origFnTy->getParameters()[semanticResultArgIdx]; + assert(param.isAutoDiffSemanticResult() && "expected autodiff semantic result parameter"); + remappedResultType = param.getSILStorageInterfaceType(); } else { remappedResultType = origFnTy->getResults()[resultIndex].getSILStorageInterfaceType(); @@ -277,8 +276,9 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { SmallVector params; for (auto ¶m : silFnTy->getParameters()) { ParameterTypeFlags flags; - if (param.isIndirectMutating()) + if (param.isAutoDiffSemanticResult()) flags = flags.withInOut(true); + params.push_back( AnyFunctionType::Param(param.getInterfaceType(), Identifier(), flags)); } @@ -398,17 +398,17 @@ void LinearMapInfo::generateDifferentiationDataStructures( /// differentiated, given the differentiation indices of the instruction's /// parent function. Whether the `apply` should be differentiated is determined /// sequentially from the following conditions: -/// 1. The instruction has an active `inout` argument. +/// 1. The instruction has an active semantic result argument. /// 2. The instruction is a call to the array literal initialization intrinsic /// ("array.uninitialized_intrinsic"), where the result is active and where /// there is a `store` of an active value into the array's buffer. /// 3. The instruction has both an active result (direct or indirect) and an /// active argument. bool LinearMapInfo::shouldDifferentiateApplySite(FullApplySite applySite) { - // Function applications with an active inout argument should be + // Function applications with an active semantic result argument should be // differentiated. - for (auto inoutArg : applySite.getInoutArguments()) - if (activityInfo.isActive(inoutArg, config)) + for (auto semanticResArg : applySite.getAutoDiffSemanticResultArguments()) + if (activityInfo.isActive(semanticResArg, config)) return true; bool hasActiveDirectResults = false; diff --git a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp index c7eb9aa769424..702b433e6ff40 100644 --- a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp @@ -257,6 +257,10 @@ class PullbackCloner::Implementation final // is currently always "address". if (v->getType().isAddress()) return SILValueCategory::Address; + // Classes are reference-counted and therefore should be treated as address + // values for the purpose of differentiation + if (v->getType().isClassOrClassMetatype()) + return SILValueCategory::Address; // If the value has an object type and the tangent type is not address-only, // then the tangent value category is "object". auto tanSpace = getTangentSpace(remapType(v->getType()).getASTType()); @@ -567,7 +571,7 @@ class PullbackCloner::Implementation final LLVM_DEBUG(getADDebugStream() << "Creating new adjoint buffer for " << originalValue - << "in bb" << origBB->getDebugID() << '\n'); + << "in bb" << origBB->getDebugID() << ": "); auto bufType = getRemappedTangentType(originalValue->getType()); // Set insertion point for local allocation builder: before the last local @@ -593,6 +597,7 @@ class PullbackCloner::Implementation final dv.Name = adjName; return dv; })); + LLVM_DEBUG(llvm::dbgs() << *newBuf); return (insertion.first->getSecond() = newBuf); } @@ -864,6 +869,7 @@ class PullbackCloner::Implementation final /// Adjoint: (adj[x0], adj[x1], ...) += apply @fn_pullback (adj[y0], ...) void visitApplyInst(ApplyInst *ai) { assert(getPullbackInfo().shouldDifferentiateApplySite(ai)); + // Skip `array.uninitialized_intrinsic` applications, which have special // `store` and `copy_addr` support. if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) @@ -901,14 +907,11 @@ class PullbackCloner::Implementation final }); SmallVector origAllResults; collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults); - // Append `inout` arguments after original results. - for (auto paramIdx : applyInfo.config.parameterIndices->getIndices()) { - auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( - ai->getNumIndirectResults() + paramIdx); - if (!paramInfo.isIndirectMutating()) - continue; - origAllResults.push_back( - ai->getArgumentsWithoutIndirectResults()[paramIdx]); + // Append semantic result arguments after original results in result indices order + for (auto i : range(ai->getSubstCalleeConv().getNumParameters())) { + auto paramInfo = ai->getSubstCalleeConv().getParameters()[i]; + origAllResults.push_back(paramInfo.isAutoDiffSemanticResult() ? + ai->getArgumentsWithoutIndirectResults()[i] : SILValue()); } // Get callee pullback arguments. @@ -936,6 +939,7 @@ class PullbackCloner::Implementation final for (auto resultIndex : applyInfo.config.resultIndices->getIndices()) { assert(resultIndex < origAllResults.size()); auto origResult = origAllResults[resultIndex]; + assert(origResult && "expected non-trivial result"); // Get the seed (i.e. adjoint value of the original result). SILValue seed; switch (getTangentValueCategory(origResult)) { @@ -981,10 +985,10 @@ class PullbackCloner::Implementation final auto allResultsIt = allResults.begin(); for (unsigned i : applyInfo.config.parameterIndices->getIndices()) { auto origArg = ai->getArgument(ai->getNumIndirectResults() + i); - // Skip adjoint accumulation for `inout` arguments. + // Skip adjoint accumulation for semantic results arguments. auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( ai->getNumIndirectResults() + i); - if (paramInfo.isIndirectMutating()) + if (paramInfo.isAutoDiffSemanticResult()) continue; auto tan = *allResultsIt++; if (tan->getType().isAddress()) { @@ -1218,60 +1222,6 @@ class PullbackCloner::Implementation final } } - /// Handle `ref_element_addr` instruction. - /// Original: y = ref_element_addr x, - /// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0) - /// ^~~~~~~ - /// field in tangent space corresponding to #field - void visitRefElementAddrInst(RefElementAddrInst *reai) { - auto *bb = reai->getParent(); - auto loc = reai->getLoc(); - auto adjBuf = getAdjointBuffer(bb, reai); - auto classOperand = reai->getOperand(); - auto classType = remapType(reai->getOperand()->getType()).getASTType(); - auto *tanField = - getTangentStoredProperty(getContext(), reai, classType, getInvoker()); - assert(tanField && "Invalid projections should have been diagnosed"); - switch (getTangentValueCategory(classOperand)) { - case SILValueCategory::Object: { - auto classTy = remapType(classOperand->getType()).getASTType(); - auto tangentVectorTy = getTangentSpace(classTy)->getCanonicalType(); - auto tangentVectorSILTy = - SILType::getPrimitiveObjectType(tangentVectorTy); - auto *tangentVectorDecl = - tangentVectorTy->getStructOrBoundGenericStruct(); - // Accumulate adjoint for the `ref_element_addr` operand. - SmallVector eltVals; - for (auto *field : tangentVectorDecl->getStoredProperties()) { - if (field == tanField) { - auto adjElt = builder.emitLoadValueOperation( - reai->getLoc(), adjBuf, LoadOwnershipQualifier::Copy); - eltVals.push_back(makeConcreteAdjointValue(adjElt)); - recordTemporary(adjElt); - } else { - auto substMap = tangentVectorTy->getMemberSubstitutionMap( - field->getModuleContext(), field); - auto fieldTy = field->getType().subst(substMap); - auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType(); - assert(fieldSILTy.isObject()); - eltVals.push_back(makeZeroAdjointValue(fieldSILTy)); - } - } - addAdjointValue(bb, classOperand, - makeAggregateAdjointValue(tangentVectorSILTy, eltVals), - loc); - break; - } - case SILValueCategory::Address: { - auto adjBufClass = getAdjointBuffer(bb, classOperand); - auto adjBufElt = - builder.createStructElementAddr(loc, adjBufClass, tanField); - builder.emitInPlaceAdd(loc, adjBufElt, adjBuf); - break; - } - } - } - /// Handle `tuple` instruction. /// Original: y = tuple (x0, x1, x2, ...) /// Adjoint: (adj[x0], adj[x1], adj[x2], ...) += destructure_tuple adj[y] @@ -1467,6 +1417,18 @@ class PullbackCloner::Implementation final assert(isa(inst) || isa(inst)); auto *bb = inst->getParent(); auto loc = inst->getLoc(); + + // Loading of class values does not produce a new value as classes are + // reference-counted. Treat load as a "projection" in such case, do not + // create a new adjoint buffer for it + if (inst->getType().isObject() && + inst->getType().isReferenceCounted(getModule())) { + assert(getTangentValueCategory(inst) == SILValueCategory::Address && + "expected address tangent"); + // No adjoint here + return; + } + switch (getTangentValueCategory(inst)) { case SILValueCategory::Object: { auto adjVal = materializeAdjointDirect(getAdjointValue(bb, inst), loc); @@ -1703,6 +1665,7 @@ class PullbackCloner::Implementation final // Address projections. NO_ADJOINT(StructElementAddr) NO_ADJOINT(TupleElementAddr) + NO_ADJOINT(RefElementAddr) // Array literal initialization address projections. NO_ADJOINT(PointerToAddress) @@ -1712,6 +1675,7 @@ class PullbackCloner::Implementation final NO_ADJOINT(AllocStack) NO_ADJOINT(DeallocStack) NO_ADJOINT(EndAccess) + NO_ADJOINT(AllocRef) // Debugging/reference counting instructions. NO_ADJOINT(DebugValue) @@ -1784,6 +1748,7 @@ bool PullbackCloner::Implementation::run() { collectAllFormalResultsInTypeOrder(original, origFormalResults); for (auto resultIndex : getConfig().resultIndices->getIndices()) { auto origResult = origFormalResults[resultIndex]; + assert(origResult && "expected result"); // If original result is non-varied, it will always have a zero derivative. // Skip full pullback generation and simply emit zero derivatives for wrt // parameters. @@ -2036,6 +2001,7 @@ bool PullbackCloner::Implementation::run() { // the adjoint buffer of the original result. auto seedParamInfo = pullback.getLoweredFunctionType()->getParameters()[seedIndex]; + if (seedParamInfo.isIndirectInOut()) { setAdjointBuffer(originalExitBlock, origResult, seed); } @@ -2064,7 +2030,7 @@ bool PullbackCloner::Implementation::run() { // If the original function is an accessor with special-case pullback // generation logic, do special-case generation. - if (isSemanticMemberAccessor(&original)) { + if (0 && isSemanticMemberAccessor(&original)) { if (runForSemanticMemberAccessor()) return true; } @@ -2123,7 +2089,7 @@ bool PullbackCloner::Implementation::run() { // Collect differentiation parameter adjoints. // Do a first pass to collect non-inout values. for (auto i : getConfig().parameterIndices->getIndices()) { - if (!conv.getParameters()[i].isIndirectMutating()) { + if (!conv.getParameters()[i].isAutoDiffSemanticResult()) { addRetElt(i); } } @@ -2136,14 +2102,14 @@ bool PullbackCloner::Implementation::run() { const auto &pullbackConv = pullback.getConventions(); SmallVector pullbackInOutArgs; for (auto pullbackArg : enumerate(pullback.getArgumentsWithoutIndirectResults())) { - if (pullbackConv.getParameters()[pullbackArg.index()].isIndirectMutating()) + if (pullbackConv.getParameters()[pullbackArg.index()].isAutoDiffSemanticResult()) pullbackInOutArgs.push_back(pullbackArg.value()); } unsigned pullbackInoutArgumentIdx = 0; for (auto i : getConfig().parameterIndices->getIndices()) { // Skip non-inout parameters. - if (!conv.getParameters()[i].isIndirectMutating()) + if (!conv.getParameters()[i].isAutoDiffSemanticResult()) continue; // For functions with multiple basic blocks, accumulation is needed @@ -2662,6 +2628,7 @@ bool PullbackCloner::Implementation::runForSemanticMemberGetter() { assert(getConfig().resultIndices->getNumIndices() == 1 && "Getter should have one semantic result"); auto origResult = origFormalResults[*getConfig().resultIndices->begin()]; + assert(origResult && "expected result"); auto tangentVectorSILTy = pullback.getConventions().getResults().front() .getSILStorageType(getModule(), @@ -2842,56 +2809,23 @@ SILValue PullbackCloner::Implementation::getAdjointProjection( } return builder.createTupleElementAddr(teai->getLoc(), adjSource, adjIndex); } + // Handle `ref_element_addr`. - // Adjoint projection: a local allocation initialized with the corresponding - // field value from the class's base adjoint value. + // Adjoint projection: a `struct_element_addr` into the base adjoint buffer. if (auto *reai = dyn_cast(originalProjection)) { assert(!reai->getField()->getAttrs().hasAttribute() && "`@noDerivative` class projections should never be active"); - auto loc = reai->getLoc(); // Get the class operand, stripping `begin_borrow`. auto classOperand = stripBorrow(reai->getOperand()); + auto adjSource = getAdjointBuffer(origBB, classOperand); auto classType = remapType(reai->getOperand()->getType()).getASTType(); auto *tanField = getTangentStoredProperty(getContext(), reai->getField(), classType, reai->getLoc(), getInvoker()); assert(tanField && "Invalid projections should have been diagnosed"); - // Create a local allocation for the element adjoint buffer. - auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType(); - auto eltTanSILType = - remapType(SILType::getPrimitiveAddressType(eltTanType)); - auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILType, loc); - // Check the class operand's `TangentVector` value category. - switch (getTangentValueCategory(classOperand)) { - case SILValueCategory::Object: { - // Get the class operand's adjoint value. Currently, it must be a - // `TangentVector` struct. - auto adjClass = - materializeAdjointDirect(getAdjointValue(origBB, classOperand), loc); - builder.emitScopedBorrowOperation( - loc, adjClass, [&](SILValue borrowedAdjClass) { - // Initialize the element adjoint buffer with the base adjoint - // value. - auto *adjElt = - builder.createStructExtract(loc, borrowedAdjClass, tanField); - auto adjEltCopy = builder.emitCopyValueOperation(loc, adjElt); - builder.emitStoreValueOperation(loc, adjEltCopy, eltAdjBuffer, - StoreOwnershipQualifier::Init); - }); - return eltAdjBuffer; - } - case SILValueCategory::Address: { - // Get the class operand's adjoint buffer. Currently, it must be a - // `TangentVector` struct. - auto adjClass = getAdjointBuffer(origBB, classOperand); - // Initialize the element adjoint buffer with the base adjoint buffer. - auto *adjElt = builder.createStructElementAddr(loc, adjClass, tanField); - builder.createCopyAddr(loc, adjElt, eltAdjBuffer, IsNotTake, - IsInitialization); - return eltAdjBuffer; - } - } + return builder.createStructElementAddr(reai->getLoc(), adjSource, tanField); } + // Handle `begin_access`. // Adjoint projection: the base adjoint buffer itself. if (auto *bai = dyn_cast(originalProjection)) { @@ -2901,6 +2835,21 @@ SILValue PullbackCloner::Implementation::getAdjointProjection( // Return the base buffer's adjoint buffer. return adjBase; } + + // Handle `load` that produces class value. Loading of class values does not + // produce a new value as classes are reference-counted. Treat load as a + // "projection" in such case, do not create a new adjoint buffer for it + if (auto *li = dyn_cast(originalProjection)) { + if (li->getType().isObject() && + li->getType().isReferenceCounted(getModule())) { + auto adjBase = getAdjointBuffer(origBB, li->getOperand()); + if (errorOccurred) + return (bufferMap[{origBB, originalProjection}] = SILValue()); + // Return the base buffer's adjoint buffer. + return adjBase; + } + } + // Handle `array.uninitialized_intrinsic` application element addresses. // Adjoint projection: a local allocation initialized by applying // `Array.TangentVector.subscript` to the base array's adjoint value. diff --git a/lib/SILOptimizer/Differentiation/Thunk.cpp b/lib/SILOptimizer/Differentiation/Thunk.cpp index 487bce2929183..30317cef1f512 100644 --- a/lib/SILOptimizer/Differentiation/Thunk.cpp +++ b/lib/SILOptimizer/Differentiation/Thunk.cpp @@ -365,7 +365,9 @@ getOrCreateSubsetParametersThunkForLinearMap( const AutoDiffConfig &desiredConfig, const AutoDiffConfig &actualConfig, ADContext &adContext) { LLVM_DEBUG(getADDebugStream() - << "Getting a subset parameters thunk for " << linearMapType + << "Getting a subset parameters thunk for " + << (kind == AutoDiffDerivativeFunctionKind::JVP ? "jvp" : "vjp") + << " linear map " << linearMapType << " from " << actualConfig << " to " << desiredConfig << '\n'); assert(!linearMapType->getCombinedSubstitutions()); @@ -539,10 +541,10 @@ getOrCreateSubsetParametersThunkForLinearMap( unsigned pullbackResultIndex = 0; for (unsigned i : actualConfig.parameterIndices->getIndices()) { auto origParamInfo = origFnType->getParameters()[i]; - // Skip original `inout` parameters. All non-indirect-result pullback - // arguments (including `inout` arguments) are appended to `arguments` + // Skip original semantic result parameters. All non-indirect-result pullback + // arguments (including semantic result` arguments) are appended to `arguments` // later. - if (origParamInfo.isIndirectMutating()) + if (origParamInfo.isAutoDiffSemanticResult()) continue; auto resultInfo = linearMapType->getResults()[pullbackResultIndex]; assert(pullbackResultIndex < linearMapType->getNumResults()); @@ -589,11 +591,15 @@ getOrCreateSubsetParametersThunkForLinearMap( extractAllElements(ai, builder, differentialDirectResults); SmallVector allResults; collectAllActualResultsInTypeOrder(ai, differentialDirectResults, allResults); - unsigned numResults = thunk->getConventions().getNumDirectSILResults() + - thunk->getConventions().getNumDirectSILResults(); SmallVector results; + for (unsigned idx : *actualConfig.resultIndices) { - if (idx >= numResults) + // Skip indirect results, they were handled previously (either forwarded + // or function local temporary was passed) + if (idx < ai->getSubstCalleeConv().getNumIndirectSILResults()) + continue; + // Larger indices are reserved for semantic result parameters + if (idx >= allResults.size()) break; auto result = allResults[idx]; @@ -619,16 +625,18 @@ getOrCreateSubsetParametersThunkForLinearMap( extractAllElements(ai, builder, pullbackDirectResults); SmallVector allResults; collectAllActualResultsInTypeOrder(ai, pullbackDirectResults, allResults); - // Collect pullback `inout` arguments in type order. - unsigned inoutArgIdx = 0; + // Collect pullback semantic result arguments in type order. + unsigned semanticResultArgIdx = 0; SILFunctionConventions origConv(origFnType, thunk->getModule()); for (auto paramIdx : actualConfig.parameterIndices->getIndices()) { auto paramInfo = origConv.getParameters()[paramIdx]; - if (!paramInfo.isIndirectMutating()) + if (!paramInfo.isAutoDiffSemanticResult()) continue; - auto inoutArg = *std::next(ai->getInoutArguments().begin(), inoutArgIdx++); + auto semanticResultArg = + *std::next(ai->getAutoDiffSemanticResultArguments().begin(), + semanticResultArgIdx++); unsigned mappedParamIdx = mapOriginalParameterIndex(paramIdx); - allResults.insert(allResults.begin() + mappedParamIdx, inoutArg); + allResults.insert(allResults.begin() + mappedParamIdx, semanticResultArg); } assert(allResults.size() == actualConfig.parameterIndices->getNumIndices() && "Number of pullback results should match number of differentiability " @@ -668,8 +676,10 @@ getOrCreateSubsetParametersThunkForDerivativeFunction( AutoDiffDerivativeFunctionKind kind, const AutoDiffConfig &desiredConfig, const AutoDiffConfig &actualConfig, ADContext &adContext) { LLVM_DEBUG(getADDebugStream() - << "Getting a subset parameters thunk for derivative function " - << derivativeFn << " of the original function " << origFnOperand + << "Getting a subset parameters thunk for derivative " + << (kind == AutoDiffDerivativeFunctionKind::JVP ? "jvp" : "vjp") + << " function " << derivativeFn + << " of the original function " << origFnOperand << " from " << actualConfig << " to " << desiredConfig << '\n'); auto origFnType = origFnOperand->getType().castTo(); @@ -823,9 +833,7 @@ getOrCreateSubsetParametersThunkForDerivativeFunction( SILType::getPrimitiveObjectType(linearMapTargetType), /*withoutActuallyEscaping*/ false); } - assert(origFnType->getNumResults() + - origFnType->getNumIndirectMutatingParameters() > - 0); + assert(origFnType->getNumAutoDiffSemanticResults() > 0); if (origFnType->getNumResults() > 0 && origFnType->getResults().front().isFormalDirect()) { directResults.push_back(thunkedLinearMap); @@ -835,6 +843,9 @@ getOrCreateSubsetParametersThunkForDerivativeFunction( builder.createReturn(loc, thunkedLinearMap); } + LLVM_DEBUG(getADDebugStream() << + "Generated thunk:\n" << *thunk); + return {thunk, interfaceSubs}; } diff --git a/lib/SILOptimizer/Differentiation/VJPCloner.cpp b/lib/SILOptimizer/Differentiation/VJPCloner.cpp index f91175c655e91..2d0ec305ef198 100644 --- a/lib/SILOptimizer/Differentiation/VJPCloner.cpp +++ b/lib/SILOptimizer/Differentiation/VJPCloner.cpp @@ -445,29 +445,25 @@ class VJPCloner::Implementation final SmallVector activeParamIndices; SmallVector activeResultIndices; collectMinimalIndicesForFunctionCall(ai, getConfig(), activityInfo, - allResults, activeParamIndices, - activeResultIndices); + allResults, + activeParamIndices, activeResultIndices); assert(!activeParamIndices.empty() && "Parameter indices cannot be empty"); assert(!activeResultIndices.empty() && "Result indices cannot be empty"); - LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params={"; + LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params=("; llvm::interleave( activeParamIndices.begin(), activeParamIndices.end(), [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); - s << "}, results={"; llvm::interleave( + s << "), results=("; llvm::interleave( activeResultIndices.begin(), activeResultIndices.end(), [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); - s << "}\n";); + s << ")\n";); // Form expected indices. - auto numSemanticResults = - ai->getSubstCalleeType()->getNumResults() + - ai->getSubstCalleeType()->getNumIndirectMutatingParameters(); + auto numParams = ai->getArgumentsWithoutIndirectResults().size(); + auto numResults = ai->getSubstCalleeType()->getNumResults() + numParams; AutoDiffConfig config( - IndexSubset::get(getASTContext(), - ai->getArgumentsWithoutIndirectResults().size(), - activeParamIndices), - IndexSubset::get(getASTContext(), numSemanticResults, - activeResultIndices)); + IndexSubset::get(getASTContext(), numParams, activeParamIndices), + IndexSubset::get(getASTContext(), numResults, activeResultIndices)); // Emit the VJP. SILValue vjpValue; @@ -537,10 +533,10 @@ class VJPCloner::Implementation final for (auto resultIndex : config.resultIndices->getIndices()) { SILType remappedResultType; if (resultIndex >= originalFnTy->getNumResults()) { - auto inoutArgIdx = resultIndex - originalFnTy->getNumResults(); - auto inoutArg = - *std::next(ai->getInoutArguments().begin(), inoutArgIdx); - remappedResultType = inoutArg->getType(); + auto semanticResultArgIdx = resultIndex - originalFnTy->getNumResults(); + const auto ¶m = originalFnTy->getParameters()[semanticResultArgIdx]; + assert(param.isAutoDiffSemanticResult() && "expected autodiff semantic result parameter"); + remappedResultType = param.getSILStorageInterfaceType(); } else { remappedResultType = originalFnTy->getResults()[resultIndex] .getSILStorageInterfaceType(); @@ -891,55 +887,57 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() { auto config = witness->getConfig(); // Add pullback parameters based on original result indices. - SmallVector inoutParamIndices; + SmallVector semanticResultParamIndices; for (auto i : range(origTy->getNumParameters())) { auto origParam = origParams[i]; - if (!origParam.isIndirectInOut()) + if (!origParam.isAutoDiffSemanticResult()) continue; - inoutParamIndices.push_back(i); + semanticResultParamIndices.push_back(i); } + for (auto resultIndex : config.resultIndices->getIndices()) { // Handle formal result. if (resultIndex < origTy->getNumResults()) { auto origResult = origTy->getResults()[resultIndex]; origResult = origResult.getWithInterfaceType( origResult.getInterfaceType()->getReducedType(witnessCanGenSig)); - pbParams.push_back(getTangentParameterInfoForOriginalResult( + auto paramInfo = getTangentParameterInfoForOriginalResult( origResult.getInterfaceType() ->getAutoDiffTangentSpace(lookupConformance) ->getType() ->getReducedType(witnessCanGenSig), - origResult.getConvention())); + origResult.getConvention()); + if (origResult.getInterfaceType()->getClassOrBoundGenericClass()) + paramInfo = paramInfo.getWithConvention(ParameterConvention::Indirect_In_Guaranteed); + pbParams.push_back(paramInfo); continue; } - // Handle `inout` parameter. - unsigned paramIndex = 0; - unsigned inoutParamIndex = 0; - for (auto i : range(origTy->getNumParameters())) { - auto origParam = origTy->getParameters()[i]; - if (!origParam.isIndirectMutating()) { - ++paramIndex; - continue; - } - if (inoutParamIndex == resultIndex - origTy->getNumResults()) - break; - ++paramIndex; - ++inoutParamIndex; + + // Handle semantic result parameter. + unsigned resultParamIndex = resultIndex - origTy->getNumResults(); + auto resultParam = origParams[resultParamIndex]; + assert(resultParam.isAutoDiffSemanticResult() && "expected autodiff semantic result parameter"); + + auto origResult = resultParam.getWithInterfaceType( + resultParam.getInterfaceType()->getReducedType(witnessCanGenSig)); + auto resultParamTanConvention = resultParam.getConvention(); + if (resultParam.isIndirectMutating()) { + if (!config.isWrtParameter(resultParamIndex)) + resultParamTanConvention = ParameterConvention::Indirect_In_Guaranteed; + } else { + assert(resultParam.getInterfaceType()->getClassOrBoundGenericClass() && + "expected class-bound semantic result param"); + if (config.isWrtParameter(resultParamIndex)) + resultParamTanConvention = ParameterConvention::Indirect_Inout; + else + resultParamTanConvention = ParameterConvention::Indirect_In_Guaranteed; } - auto inoutParam = origParams[paramIndex]; - auto origResult = inoutParam.getWithInterfaceType( - inoutParam.getInterfaceType()->getReducedType(witnessCanGenSig)); - auto inoutParamTanConvention = - config.isWrtParameter(paramIndex) - ? inoutParam.getConvention() - : ParameterConvention::Indirect_In_Guaranteed; - SILParameterInfo inoutParamTanParam( - origResult.getInterfaceType() - ->getAutoDiffTangentSpace(lookupConformance) - ->getType() - ->getReducedType(witnessCanGenSig), - inoutParamTanConvention); - pbParams.push_back(inoutParamTanParam); + + pbParams.emplace_back(origResult.getInterfaceType() + ->getAutoDiffTangentSpace(lookupConformance) + ->getType() + ->getReducedType(witnessCanGenSig), + resultParamTanConvention); } if (pullbackInfo.hasHeapAllocatedContext()) { @@ -961,7 +959,7 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() { // Add pullback results for the requested wrt parameters. for (auto i : config.parameterIndices->getIndices()) { auto origParam = origParams[i]; - if (origParam.isIndirectMutating()) + if (origParam.isAutoDiffSemanticResult()) continue; origParam = origParam.getWithInterfaceType( origParam.getInterfaceType()->getReducedType(witnessCanGenSig)); @@ -997,6 +995,8 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() { original->isRuntimeAccessible()); pullback->setDebugScope(new (module) SILDebugScope(original->getLocation(), pullback)); + pullback->setInlineStrategy(original->getInlineStrategy()); + return pullback; } diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 8a1f59a260a76..07d64173f021d 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -544,11 +544,10 @@ emitDerivativeFunctionReference( for (auto resultIndex : desiredResultIndices->getIndices()) { SILType resultType; if (resultIndex >= originalFnTy->getNumResults()) { - auto inoutParamIdx = resultIndex - originalFnTy->getNumResults(); - auto inoutParam = - *std::next(originalFnTy->getIndirectMutatingParameters().begin(), - inoutParamIdx); - resultType = inoutParam.getSILStorageInterfaceType(); + auto semanticResultParamIdx = resultIndex - originalFnTy->getNumResults(); + const auto ¶m = originalFnTy->getParameters()[semanticResultParamIdx]; + assert(param.isAutoDiffSemanticResult() && "expected autodiff semantic result parameter"); + resultType = param.getSILStorageInterfaceType(); } else { resultType = originalFnTy->getResults()[resultIndex] .getSILStorageInterfaceType(); @@ -785,6 +784,7 @@ static SILFunction *createEmptyVJP(ADContext &context, original->isDistributed(), original->isRuntimeAccessible()); vjp->setDebugScope(new (module) SILDebugScope(original->getLocation(), vjp)); + vjp->setInlineStrategy(original->getInlineStrategy()); LLVM_DEBUG(llvm::dbgs() << "VJP type: " << vjp->getLoweredFunctionType() << "\n"); @@ -1105,6 +1105,7 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction( // - For VJPs: the thunked VJP returns a pullback that drops the unused // tangent values. auto actualConfig = derivativeFnAndIndices->second; + // NOTE: `desiredIndices` may come from a partially-applied function and // have smaller capacity than `actualIndices`. We expect this logic to go // away when we support `@differentiable` partial apply. @@ -1112,8 +1113,11 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction( auto extendedDesiredParameterIndices = desiredConfig.parameterIndices->extendingCapacity( astCtx, actualConfig.parameterIndices->getCapacity()); + auto extendedDesiredResultIndices = + desiredConfig.resultIndices->extendingCapacity( + astCtx, actualConfig.resultIndices->getCapacity()); if (!actualConfig.parameterIndices->equals(extendedDesiredParameterIndices) - || !actualConfig.resultIndices->equals(desiredConfig.resultIndices)) { + || !actualConfig.resultIndices->equals(extendedDesiredResultIndices)) { // Destroy the already emitted derivative function reference because it // is no longer used. builder.emitDestroyValueOperation(loc, derivativeFn); @@ -1139,6 +1143,8 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction( // Create the parameter subset thunk. assert(actualConfig.parameterIndices->isSupersetOf( extendedDesiredParameterIndices)); + assert(actualConfig.resultIndices->isSupersetOf( + extendedDesiredResultIndices)); SILFunction *thunk; SubstitutionMap interfaceSubs; SILOptFunctionBuilder fb(transform); diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index ff93b7f366ffd..12d15c61e8f0b 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -5764,18 +5764,9 @@ resolveDifferentiableAccessors(DifferentiableAttr *attr, if (!typecheckAccessor(asd->getSynthesizedAccessor(AccessorKind::Get))) return nullptr; - if (asd->supportsMutation()) { - // FIXME: Class-typed values have reference semantics and can be freely - // mutated. Thus, they should be treated like inout parameters for the - // purposes of @differentiable and @derivative type-checking. Until - // https://github.com/apple/swift/issues/55542 is fixed, check if setter has - // computed semantic results and do not typecheck if they are none - // (class-typed `self' parameter is not treated as a "semantic result" - // currently) - if (!asd->getDeclContext()->getSelfClassDecl()) - if (!typecheckAccessor(asd->getSynthesizedAccessor(AccessorKind::Set))) - return nullptr; - } + if (asd->supportsMutation()) + if (!typecheckAccessor(asd->getSynthesizedAccessor(AccessorKind::Set))) + return nullptr; // Remove `@differentiable` attribute from storage declaration to prevent // duplicate attribute registration during SILGen. diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index e97c685ded1ce..08b57fbcac108 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -3068,7 +3068,7 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, ListOfValues.slice(numParamIndices, numResultIndices), [](uint64_t i) { return (unsigned)i; }); auto *resultIndices = - IndexSubset::get(MF->getContext(), numResults, rawResultIndices); + IndexSubset::get(MF->getContext(), numParams + numResults, rawResultIndices); SmallVector operands; for (auto i = numParamIndices + numResultIndices; i < numParamIndices + numOperands * 3; i += 3) { @@ -4206,10 +4206,10 @@ SILDeserializer::readDifferentiabilityWitness(DeclID DId) { IndexSubset::get(MF->getContext(), originalFnType->getNumParameters(), ArrayRef(parameterAndResultIndices) .take_front(numParameterIndices)); - auto numResults = originalFnType->getNumResults() + - originalFnType->getNumIndirectMutatingParameters(); + auto numSemanticResults = originalFnType->getNumResults() + + originalFnType->getNumParameters(); auto *resultIndices = - IndexSubset::get(MF->getContext(), numResults, + IndexSubset::get(MF->getContext(), numSemanticResults, ArrayRef(parameterAndResultIndices) .take_back(numResultIndices)); diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index 629e9fe90990b..7fb8d293a929e 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -3013,11 +3013,7 @@ void SILSerializer::writeSILDifferentiabilityWitness( dw.getParameterIndices()->getCapacity() && "Original function parameter count should match differentiability " "witness parameter indices capacity"); - unsigned numInoutParameters = llvm::count_if( - originalFnType->getParameters(), [](SILParameterInfo paramInfo) { - return paramInfo.isIndirectMutating(); - }); - assert(originalFnType->getNumResults() + numInoutParameters == + assert(originalFnType->getNumResults() + originalFnType->getNumParameters() == dw.getResultIndices()->getCapacity() && "Original function result count should match differentiability " "witness result indices capacity"); diff --git a/test/AutoDiff/SIL/differentiability_witness_function_inst.sil b/test/AutoDiff/SIL/differentiability_witness_function_inst.sil index 9318d934a9faf..ebdba78e4cc3a 100644 --- a/test/AutoDiff/SIL/differentiability_witness_function_inst.sil +++ b/test/AutoDiff/SIL/differentiability_witness_function_inst.sil @@ -83,26 +83,26 @@ bb0: // CHECK: {{%.*}} = differentiability_witness_function [vjp] [reverse] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 == τ_0_0.TangentVector> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 // CHECK: } -// IRGEN: @fooWJrSUUpSr = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT:[0-9]+]] -// IRGEN: @fooWJrSSUpSr = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] -// IRGEN: @barWJrSUUpSUr = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] -// IRGEN: @barWJrSSUpSSr = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] -// IRGEN: @generic16_Differentiation14DifferentiableRzlWJrSUpSr = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] -// IRGEN: @generics18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlWJrSSpSr = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] -// IRGEN: @generic16_Differentiation14DifferentiableRz13TangentVector{{.*}}WJrSSpSr = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] +// IRGEN: @fooWJrSUUpSUUUr = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT:[0-9]+]] +// IRGEN: @fooWJrSSUpSUUUr = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] +// IRGEN: @barWJrSUUpSUUUUr = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] +// IRGEN: @barWJrSSUpSSUUUr = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] +// IRGEN: @generic16_Differentiation14DifferentiableRzlWJrSUpSUUr = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] +// IRGEN: @generics18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlWJrSSpSUUr = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] +// IRGEN: @generic16_Differentiation14DifferentiableRz13TangentVector{{.*}}WJrSSpSUUr = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] // IRGEN-LABEL: define {{.*}} @test_derivative_witnesses() -// IRGEN: [[PTR1:%.*]] = load ptr, ptr @fooWJrSUUpSr, align [[PTR_ALIGNMENT]] +// IRGEN: [[PTR1:%.*]] = load ptr, ptr @fooWJrSUUpSUUUr, align [[PTR_ALIGNMENT]] -// IRGEN: [[PTR2:%.*]] = load ptr, ptr getelementptr inbounds (%swift.differentiability_witness, ptr @fooWJrSSUpSr, i32 0, i32 1), align [[PTR_ALIGNMENT]] +// IRGEN: [[PTR2:%.*]] = load ptr, ptr getelementptr inbounds (%swift.differentiability_witness, ptr @fooWJrSSUpSUUUr, i32 0, i32 1), align [[PTR_ALIGNMENT]] -// IRGEN: [[PTR3:%.*]] = load ptr, ptr @barWJrSUUpSUr, align [[PTR_ALIGNMENT]] +// IRGEN: [[PTR3:%.*]] = load ptr, ptr @barWJrSUUpSUUUUr, align [[PTR_ALIGNMENT]] -// IRGEN: [[PTR4:%.*]] = load ptr, ptr getelementptr inbounds (%swift.differentiability_witness, ptr @barWJrSSUpSSr, i32 0, i32 1), align [[PTR_ALIGNMENT]] +// IRGEN: [[PTR4:%.*]] = load ptr, ptr getelementptr inbounds (%swift.differentiability_witness, ptr @barWJrSSUpSSUUUr, i32 0, i32 1), align [[PTR_ALIGNMENT]] -// IRGEN: [[PTR5:%.*]] = load ptr, ptr @generic16_Differentiation14DifferentiableRzlWJrSUpSr, align [[PTR_ALIGNMENT]] +// IRGEN: [[PTR5:%.*]] = load ptr, ptr @generic16_Differentiation14DifferentiableRzlWJrSUpSUUr, align [[PTR_ALIGNMENT]] -// IRGEN: [[PTR6:%.*]] = load ptr, ptr getelementptr inbounds (%swift.differentiability_witness, ptr @generics18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlWJrSSpSr, i32 0, i32 1), align [[PTR_ALIGNMENT]] +// IRGEN: [[PTR6:%.*]] = load ptr, ptr getelementptr inbounds (%swift.differentiability_witness, ptr @generics18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlWJrSSpSUUr, i32 0, i32 1), align [[PTR_ALIGNMENT]] -// IRGEN: [[PTR7:%.*]] = load ptr, ptr getelementptr inbounds (%swift.differentiability_witness, ptr @generic16_Differentiation14DifferentiableRz13TangentVector{{.*}}WJrSSpSr, i32 0, i32 1), align [[PTR_ALIGNMENT]] +// IRGEN: [[PTR7:%.*]] = load ptr, ptr getelementptr inbounds (%swift.differentiability_witness, ptr @generic16_Differentiation14DifferentiableRz13TangentVector{{.*}}WJrSSpSUUr, i32 0, i32 1), align [[PTR_ALIGNMENT]] diff --git a/test/AutoDiff/SIL/sil_differentiability_witness.sil b/test/AutoDiff/SIL/sil_differentiability_witness.sil index b050775c91c4a..b7c1119264c17 100644 --- a/test/AutoDiff/SIL/sil_differentiability_witness.sil +++ b/test/AutoDiff/SIL/sil_differentiability_witness.sil @@ -54,7 +54,7 @@ sil_differentiability_witness [reverse] [parameters 0] [results 0] @externalFn1 // ROUNDTRIP: vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) // ROUNDTRIP: } -// IRGEN-LABEL: @externalFn1WJrSpSr ={{( protected)?}} global { ptr, ptr } { +// IRGEN-LABEL: @externalFn1WJrSpSUr ={{( protected)?}} global { ptr, ptr } { // IRGEN-SAME: @AD__externalFn1__jvp_src_0_wrt_0 // IRGEN-SAME: @AD__externalFn1__vjp_src_0_wrt_0 // IRGEN-SAME: } @@ -78,7 +78,7 @@ sil_differentiability_witness [reverse] [parameters 0] [results 0] @externalFn2 // ROUNDTRIP: vjp: @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) // ROUNDTRIP: } -// IRGEN-LABEL: @externalFn2WJrSpSr ={{( protected)?}} global { ptr, ptr } { +// IRGEN-LABEL: @externalFn2WJrSpSUr ={{( protected)?}} global { ptr, ptr } { // IRGEN-SAME: @AD__externalFn2__jvp_src_0_wrt_0 // IRGEN-SAME: @AD__externalFn2__vjp_src_0_wrt_0 // IRGEN-SAME: } @@ -125,7 +125,7 @@ sil_differentiability_witness [reverse] [parameters 0] [results 0] @foo : $@conv // ROUNDTRIP: vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) // ROUNDTRIP: } -// IRGEN-LABEL: @fooWJrSpSr ={{( protected)?}} global { ptr, ptr } { +// IRGEN-LABEL: @fooWJrSpSUr ={{( protected)?}} global { ptr, ptr } { // IRGEN-SAME: @AD__foo__jvp_src_0_wrt_0 // IRGEN-SAME: @AD__foo__vjp_src_0_wrt_0 // IRGEN-SAME: } @@ -163,7 +163,7 @@ sil_differentiability_witness hidden [reverse] [parameters 0 1] [results 0] <τ_ // ROUNDTRIP: vjp: @AD__generic__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) // ROUNDTRIP: } -// IRGEN: @generic16_Differentiation14DifferentiableRzlWJrSSpSr = hidden global { ptr, ptr } { +// IRGEN: @generic16_Differentiation14DifferentiableRzlWJrSSpSUUr = hidden global { ptr, ptr } { // IRGEN-SAME: @AD__generic__jvp_src_0_wrt_0_1 // IRGEN-SAME: @AD__generic__vjp_src_0_wrt_0_1 // IRGEN-SAME: } diff --git a/test/AutoDiff/SILGen/differentiability_witness_generic_signature.swift b/test/AutoDiff/SILGen/differentiability_witness_generic_signature.swift index 96d6a3231f35b..e7a3b6a500425 100644 --- a/test/AutoDiff/SILGen/differentiability_witness_generic_signature.swift +++ b/test/AutoDiff/SILGen/differentiability_witness_generic_signature.swift @@ -49,7 +49,7 @@ extension AllConcrete where T == Float { // CHECK-LABEL: // differentiability witness for allconcrete_where_gensig_constrained // CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @allconcrete_where_gensig_constrained : $@convention(method) (AllConcrete) -> AllConcrete { -// CHECK-NEXT: jvp: @allconcrete_where_gensig_constrainedSfRszlTJfSpSr : $@convention(method) (AllConcrete) -> (AllConcrete, @owned @callee_guaranteed (AllConcrete.TangentVector) -> AllConcrete.TangentVector) +// CHECK-NEXT: jvp: @allconcrete_where_gensig_constrainedSfRszlTJfSpSUr : $@convention(method) (AllConcrete) -> (AllConcrete, @owned @callee_guaranteed (AllConcrete.TangentVector) -> AllConcrete.TangentVector) // CHECK-NEXT: } // If a `@differentiable` or `@derivative` attribute satisfies two conditions: @@ -78,7 +78,7 @@ extension AllConcrete where T == Float { // CHECK-LABEL: // differentiability witness for allconcrete_original_gensig // CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @allconcrete_original_gensig : $@convention(method) (AllConcrete) -> AllConcrete { -// CHECK-NEXT: jvp: @allconcrete_original_gensigTJfSpSr : $@convention(method) (AllConcrete) -> (AllConcrete, @owned @callee_guaranteed (AllConcrete.TangentVector) -> AllConcrete.TangentVector) +// CHECK-NEXT: jvp: @allconcrete_original_gensigTJfSpSUr : $@convention(method) (AllConcrete) -> (AllConcrete, @owned @callee_guaranteed (AllConcrete.TangentVector) -> AllConcrete.TangentVector) // CHECK-NEXT: } // Original generic signature: `` @@ -99,7 +99,7 @@ extension AllConcrete where T == Float { // CHECK-LABEL: // differentiability witness for allconcrete_where_gensig // CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @allconcrete_where_gensig : $@convention(method) (AllConcrete) -> AllConcrete { -// CHECK-NEXT: jvp: @allconcrete_where_gensigTJfSpSr : $@convention(method) (AllConcrete) -> (AllConcrete, @owned @callee_guaranteed (AllConcrete.TangentVector) -> AllConcrete.TangentVector) +// CHECK-NEXT: jvp: @allconcrete_where_gensigTJfSpSUr : $@convention(method) (AllConcrete) -> (AllConcrete, @owned @callee_guaranteed (AllConcrete.TangentVector) -> AllConcrete.TangentVector) // CHECK-NEXT: } } @@ -130,7 +130,7 @@ extension NotAllConcrete where T == Float { // CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig_constrained // CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @notallconcrete_where_gensig_constrained : $@convention(method) (NotAllConcrete) -> NotAllConcrete { -// CHECK-NEXT: jvp: @notallconcrete_where_gensig_constrainedSfRszr0_lTJfSpSr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete) -> (NotAllConcrete, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for .TangentVector, NotAllConcrete.TangentVector>) +// CHECK-NEXT: jvp: @notallconcrete_where_gensig_constrainedSfRszr0_lTJfSpSUr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete) -> (NotAllConcrete, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for .TangentVector, NotAllConcrete.TangentVector>) // CHECK-NEXT: } extension NotAllConcrete where T == Float { @@ -152,7 +152,7 @@ extension NotAllConcrete where T == Float { // CHECK-LABEL: // differentiability witness for notallconcrete_original_gensig // CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @notallconcrete_original_gensig : $@convention(method) (NotAllConcrete) -> NotAllConcrete { -// CHECK-NEXT: jvp: @notallconcrete_original_gensigSfRszr0_lTJfSpSr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete) -> (NotAllConcrete, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for .TangentVector, NotAllConcrete.TangentVector>) +// CHECK-NEXT: jvp: @notallconcrete_original_gensigSfRszr0_lTJfSpSUr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete) -> (NotAllConcrete, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for .TangentVector, NotAllConcrete.TangentVector>) // CHECK-NEXT: } // Original generic signature: `` @@ -173,6 +173,6 @@ extension NotAllConcrete where T == Float { // CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig // CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @notallconcrete_where_gensig : $@convention(method) (NotAllConcrete) -> NotAllConcrete { -// CHECK-NEXT: jvp: @notallconcrete_where_gensigSfRszr0_lTJfSpSr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete) -> (NotAllConcrete, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for .TangentVector, NotAllConcrete.TangentVector>) +// CHECK-NEXT: jvp: @notallconcrete_where_gensigSfRszr0_lTJfSpSUr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete) -> (NotAllConcrete, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for .TangentVector, NotAllConcrete.TangentVector>) // CHECK-NEXT: } } diff --git a/test/AutoDiff/SILGen/has_symbol.swift b/test/AutoDiff/SILGen/has_symbol.swift index f2041fa381f1f..2d1c050e7f816 100644 --- a/test/AutoDiff/SILGen/has_symbol.swift +++ b/test/AutoDiff/SILGen/has_symbol.swift @@ -34,8 +34,8 @@ func testGlobalFunctions() { // --- foo(_:) --- // CHECK: sil @$s7Library3fooyS2fF : $@convention(thin) (Float) -> Float -// CHECK: sil @$s7Library3fooyS2fFTJfSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) -// CHECK: sil @$s7Library3fooyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK: sil @$s7Library3fooyS2fFTJfSpSUr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK: sil @$s7Library3fooyS2fFTJrSpSUr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) // FIXME: missing reverse-mode differentiability witness for foo(_:) // --- bar(_:) --- diff --git a/test/AutoDiff/SILGen/inout_differentiability_witness.swift b/test/AutoDiff/SILGen/inout_differentiability_witness.swift index e49b4e92a947d..809bf3f7ab5e9 100644 --- a/test/AutoDiff/SILGen/inout_differentiability_witness.swift +++ b/test/AutoDiff/SILGen/inout_differentiability_witness.swift @@ -17,7 +17,7 @@ func test3(x: Int, y: inout DiffableStruct, z: Float) -> (Float, Float) { return @differentiable(reverse, wrt: y) func test4(x: Int, y: inout DiffableStruct, z: Float) -> Void { } -@differentiable(reverse, wrt: z) +@differentiable(reverse, wrt: (y, z)) func test5(x: Int, y: inout DiffableStruct, z: Float) -> Void { } @differentiable(reverse, wrt: (y, z)) @@ -25,36 +25,36 @@ func test6(x: Int, y: inout DiffableStruct, z: Float) -> (Float, Float) { return // CHECK-LABEL: differentiability witness for test1(x:y:z:) // CHECK: sil_differentiability_witness hidden [reverse] [parameters 2] [results 0] @$s31inout_differentiability_witness5test11x1y1zSfSi_AA17NonDiffableStructVzSftF : $@convention(thin) (Int, @inout NonDiffableStruct, Float) -> Float { -// CHECK: jvp: @$s31inout_differentiability_witness5test11x1y1zSfSi_AA17NonDiffableStructVzSftFTJfUUSpSr : $@convention(thin) (Int, @inout NonDiffableStruct, Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) -// CHECK: vjp: @$s31inout_differentiability_witness5test11x1y1zSfSi_AA17NonDiffableStructVzSftFTJrUUSpSr : $@convention(thin) (Int, @inout NonDiffableStruct, Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK: jvp: @$s31inout_differentiability_witness5test11x1y1zSfSi_AA17NonDiffableStructVzSftFTJfUUSpSUUUr : $@convention(thin) (Int, @inout NonDiffableStruct, Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK: vjp: @$s31inout_differentiability_witness5test11x1y1zSfSi_AA17NonDiffableStructVzSftFTJrUUSpSUUUr : $@convention(thin) (Int, @inout NonDiffableStruct, Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) // CHECK: } // CHECK-LABEL: differentiability witness for test2(x:y:z:) -// CHECK: sil_differentiability_witness hidden [reverse] [parameters 1 2] [results 0] @$s31inout_differentiability_witness5test21x1y1zySi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> () { -// CHECK: jvp: @$s31inout_differentiability_witness5test21x1y1zySi_AA14DiffableStructVzSftFTJfUSSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector, Float) -> () -// CHECK: vjp: @$s31inout_differentiability_witness5test21x1y1zySi_AA14DiffableStructVzSftFTJrUSSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector) -> Float +// CHECK: sil_differentiability_witness hidden [reverse] [parameters 1 2] [results 1] @$s31inout_differentiability_witness5test21x1y1zySi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> () { +// CHECK: jvp: @$s31inout_differentiability_witness5test21x1y1zySi_AA14DiffableStructVzSftFTJfUSSpUSUr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector, Float) -> () +// CHECK: vjp: @$s31inout_differentiability_witness5test21x1y1zySi_AA14DiffableStructVzSftFTJrUSSpUSUr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector) -> Float // CHECK: } // CHECK-LABEL: differentiability witness for test3(x:y:z:) -// CHECK: sil_differentiability_witness hidden [reverse] [parameters 1 2] [results 0 1 2] @$s31inout_differentiability_witness5test31x1y1zSf_SftSi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float) { -// CHECK: jvp: @$s31inout_differentiability_witness5test31x1y1zSf_SftSi_AA14DiffableStructVzSftFTJfUSSpSSSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float, @owned @callee_guaranteed (@inout DiffableStruct.TangentVector, Float) -> (Float, Float)) -// CHECK: vjp: @$s31inout_differentiability_witness5test31x1y1zSf_SftSi_AA14DiffableStructVzSftFTJrUSSpSSSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float, @owned @callee_guaranteed (Float, Float, @inout DiffableStruct.TangentVector) -> Float) +// CHECK: sil_differentiability_witness hidden [reverse] [parameters 1 2] [results 0 1 3] @$s31inout_differentiability_witness5test31x1y1zSf_SftSi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float) { +// CHECK: jvp: @$s31inout_differentiability_witness5test31x1y1zSf_SftSi_AA14DiffableStructVzSftFTJfUSSpSSUSUr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float, @owned @callee_guaranteed (@inout DiffableStruct.TangentVector, Float) -> (Float, Float)) +// CHECK: vjp: @$s31inout_differentiability_witness5test31x1y1zSf_SftSi_AA14DiffableStructVzSftFTJrUSSpSSUSUr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float, @owned @callee_guaranteed (Float, Float, @inout DiffableStruct.TangentVector) -> Float) // CHECK: } // CHECK-LABEL: differentiability witness for test4(x:y:z:) -// CHECK: sil_differentiability_witness hidden [reverse] [parameters 1] [results 0] @$s31inout_differentiability_witness5test41x1y1zySi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> () { -// CHECK: jvp: @$s31inout_differentiability_witness5test41x1y1zySi_AA14DiffableStructVzSftFTJfUSUpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector) -> () -// CHECK: vjp: @$s31inout_differentiability_witness5test41x1y1zySi_AA14DiffableStructVzSftFTJrUSUpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector) -> () +// CHECK: sil_differentiability_witness hidden [reverse] [parameters 1] [results 1] @$s31inout_differentiability_witness5test41x1y1zySi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> () { +// CHECK: jvp: @$s31inout_differentiability_witness5test41x1y1zySi_AA14DiffableStructVzSftFTJfUSUpUSUr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector) -> () +// CHECK: vjp: @$s31inout_differentiability_witness5test41x1y1zySi_AA14DiffableStructVzSftFTJrUSUpUSUr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector) -> () // CHECK: } // CHECK-LABEL: differentiability witness for test5(x:y:z:) -// CHECK: sil_differentiability_witness hidden [reverse] [parameters 2] [results 0] @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> () { -// CHECK: jvp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJfUUSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (Float) -> @out DiffableStruct.TangentVector -// CHECK: vjp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJrUUSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@in_guaranteed DiffableStruct.TangentVector) -> Float +// CHECK: sil_differentiability_witness hidden [reverse] [parameters 1 2] [results 1] @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> () { +// CHECK: jvp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJfUSSpUSUr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector, Float) -> () +// CHECK: vjp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJrUSSpUSUr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector) -> Float // CHECK: } // CHECK-LABEL: differentiability witness for test6(x:y:z:) -// CHECK: sil_differentiability_witness hidden [reverse] [parameters 1 2] [results 0 1 2] @$s31inout_differentiability_witness5test61x1y1zSf_SftSi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float) { -// CHECK: jvp: @$s31inout_differentiability_witness5test61x1y1zSf_SftSi_AA14DiffableStructVzSftFTJfUSSpSSSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float, @owned @callee_guaranteed (@inout DiffableStruct.TangentVector, Float) -> (Float, Float)) -// CHECK: vjp: @$s31inout_differentiability_witness5test61x1y1zSf_SftSi_AA14DiffableStructVzSftFTJrUSSpSSSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float, @owned @callee_guaranteed (Float, Float, @inout DiffableStruct.TangentVector) -> Float) +// CHECK: sil_differentiability_witness hidden [reverse] [parameters 1 2] [results 0 1 3] @$s31inout_differentiability_witness5test61x1y1zSf_SftSi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float) { +// CHECK: jvp: @$s31inout_differentiability_witness5test61x1y1zSf_SftSi_AA14DiffableStructVzSftFTJfUSSpSSUSUr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float, @owned @callee_guaranteed (@inout DiffableStruct.TangentVector, Float) -> (Float, Float)) +// CHECK: vjp: @$s31inout_differentiability_witness5test61x1y1zSf_SftSi_AA14DiffableStructVzSftFTJrUSSpSSUSUr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float, @owned @callee_guaranteed (Float, Float, @inout DiffableStruct.TangentVector) -> Float) // CHECK: } diff --git a/test/AutoDiff/SILGen/sil_differentiability_witness.swift b/test/AutoDiff/SILGen/sil_differentiability_witness.swift index 995d56e58b96c..07913d21895d8 100644 --- a/test/AutoDiff/SILGen/sil_differentiability_witness.swift +++ b/test/AutoDiff/SILGen/sil_differentiability_witness.swift @@ -31,8 +31,8 @@ public func foo_vjp(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { // CHECK-LABEL: // differentiability witness for foo(_:) // CHECK-NEXT: sil_differentiability_witness [serialized] [reverse] [parameters 0] [results 0] @$s29sil_differentiability_witness3fooyS2fF : $@convention(thin) (Float) -> Float { -// CHECK-NEXT: jvp: @$s29sil_differentiability_witness3fooyS2fFTJfSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) -// CHECK-NEXT: vjp: @$s29sil_differentiability_witness3fooyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK-NEXT: jvp: @$s29sil_differentiability_witness3fooyS2fFTJfSpSUr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK-NEXT: vjp: @$s29sil_differentiability_witness3fooyS2fFTJrSpSUr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) // CHECK-NEXT: } // Test internal non-generic function. @@ -50,7 +50,7 @@ func bar_jvp(_ x: Float, _ y: T) -> (value: Float, differential: (Float) -> F // CHECK-LABEL: // differentiability witness for bar(_:_:) // CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <τ_0_0> @$s29sil_differentiability_witness3baryS2f_xtlF : $@convention(thin) (Float, @in_guaranteed T) -> Float { -// CHECK-NEXT: jvp: @$s29sil_differentiability_witness3baryS2f_xtlFlTJfSUpSr : $@convention(thin) <τ_0_0> (Float, @in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK-NEXT: jvp: @$s29sil_differentiability_witness3baryS2f_xtlFlTJfSUpSUUr : $@convention(thin) <τ_0_0> (Float, @in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> Float) // CHECK-NEXT: } // Test internal generic function. @@ -76,8 +76,8 @@ func generic_vjp(_ x: T, _ y: Float) -> ( // CHECK-LABEL: // differentiability witness for generic(_:_:) // CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @$s29sil_differentiability_witness7genericyxx_SftlF : $@convention(thin) (@in_guaranteed T, Float) -> @out T { -// CHECK-NEXT: jvp: @$s29sil_differentiability_witness7genericyxx_SftlF16_Differentiation14DifferentiableRzlTJfSSpSr : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0, Float) -> @out τ_0_1 for <τ_0_0.TangentVector, τ_0_0.TangentVector>) -// CHECK-NEXT: vjp: @$s29sil_differentiability_witness7genericyxx_SftlF16_Differentiation14DifferentiableRzlTJrSSpSr : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (@out τ_0_1, Float) for <τ_0_0.TangentVector, τ_0_0.TangentVector>) +// CHECK-NEXT: jvp: @$s29sil_differentiability_witness7genericyxx_SftlF16_Differentiation14DifferentiableRzlTJfSSpSUUr : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0, Float) -> @out τ_0_1 for <τ_0_0.TangentVector, τ_0_0.TangentVector>) +// CHECK-NEXT: vjp: @$s29sil_differentiability_witness7genericyxx_SftlF16_Differentiation14DifferentiableRzlTJrSSpSUUr : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (@out τ_0_1, Float) for <τ_0_0.TangentVector, τ_0_0.TangentVector>) // CHECK-NEXT: } public struct Foo: Differentiable { @@ -191,7 +191,7 @@ extension P1 { // CHECK-LABEL: // differentiability witness for P1.foo() // CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : P1> @$s29sil_differentiability_witness2P1PAAE3fooSfyF : $@convention(method) (@in_guaranteed Self) -> Float { -// CHECK-NEXT: vjp: @$s29sil_differentiability_witness2P1PAAE3fooSfyFAaBRzlTJrSpSr : $@convention(method) <τ_0_0 where τ_0_0 : P1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_0.TangentVector>) +// CHECK-NEXT: vjp: @$s29sil_differentiability_witness2P1PAAE3fooSfyFAaBRzlTJrSpSUr : $@convention(method) <τ_0_0 where τ_0_0 : P1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_0.TangentVector>) // CHECK-NEXT: } // Test custom derivatives of functions with generic signatures and `@differentiable` attributes. diff --git a/test/AutoDiff/SILGen/vtable.swift b/test/AutoDiff/SILGen/vtable.swift index 882031b8f9c16..49391bdd7eaf6 100644 --- a/test/AutoDiff/SILGen/vtable.swift +++ b/test/AutoDiff/SILGen/vtable.swift @@ -82,7 +82,7 @@ class Sub: Super { @differentiable(reverse) override var property: Float { base } @derivative(of: property) - final func vjpProperty() -> (value: Float, pullback: (Float) -> TangentVector) { + final func vjpProperty() -> (value: Float, pullback: (Float, inout TangentVector) -> ()) { fatalError() } diff --git a/test/AutoDiff/SILGen/witness_table.swift b/test/AutoDiff/SILGen/witness_table.swift index 5928bd844d09f..4f631f4067dc2 100644 --- a/test/AutoDiff/SILGen/witness_table.swift +++ b/test/AutoDiff/SILGen/witness_table.swift @@ -12,7 +12,7 @@ protocol Protocol: Differentiable { @differentiable(reverse) var property: Float { get set } - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (self, x)) subscript(_ x: Float, _ y: Float) -> Float { get set } } @@ -82,22 +82,22 @@ struct Struct: Protocol { // CHECK: apply [[VJP_FN]] // CHECK: } - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (self, x)) subscript(_ x: Float, _ y: Float) -> Float { get { x } set {} } - // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SUU : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SUS : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float, @in_guaranteed τ_0_0) -> Float for ) // CHECK: [[ORIG_FN:%.*]] = function_ref @$s13witness_table6StructVyS2f_Sftcig : $@convention(method) (Float, Float, Struct) -> Float - // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [results 0] [[ORIG_FN]] + // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0 2] [results 0] [[ORIG_FN]] // CHECK: [[JVP_FN:%.*]] = differentiable_function_extract [jvp] [[DIFF_FN]] // CHECK: apply [[JVP_FN]] // CHECK: } - // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUU : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUS : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> (Float, @out τ_0_0) for ) // CHECK: [[ORIG_FN:%.*]] = function_ref @$s13witness_table6StructVyS2f_Sftcig : $@convention(method) (Float, Float, Struct) -> Float - // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [results 0] [[ORIG_FN]] + // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0 2] [results 0] [[ORIG_FN]] // CHECK: [[VJP_FN:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]] // CHECK: apply [[VJP_FN]] // CHECK: } @@ -118,10 +118,10 @@ struct Struct: Protocol { // CHECK-NEXT: method #Protocol.property!setter.vjp.SS.: (inout Self) -> (Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvsTW_vjp_SS // CHECK-NEXT: method #Protocol.property!modify: (inout Self) -> () -> () : @$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvMTW // CHECK-NEXT: method #Protocol.subscript!getter: (Self) -> (Float, Float) -> Float : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW -// CHECK-NEXT: method #Protocol.subscript!getter.jvp.SUU.: (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SU -// CHECK-NEXT: method #Protocol.subscript!getter.vjp.SUU.: (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUU +// CHECK-NEXT: method #Protocol.subscript!getter.jvp.SUS.: (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SUS +// CHECK-NEXT: method #Protocol.subscript!getter.vjp.SUS.: (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUS // CHECK-NEXT: method #Protocol.subscript!setter: (inout Self) -> (Float, Float, Float) -> () : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW -// CHECK-NEXT: method #Protocol.subscript!setter.jvp.USUU.: (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_jvp_USUU -// CHECK-NEXT: method #Protocol.subscript!setter.vjp.USUU.: (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_vjp_USUU +// CHECK-NEXT: method #Protocol.subscript!setter.jvp.USUS.: (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_jvp_USUS +// CHECK-NEXT: method #Protocol.subscript!setter.vjp.USUS.: (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_vjp_USUS // CHECK-NEXT: method #Protocol.subscript!modify: (inout Self) -> (Float, Float) -> () : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftciMTW // CHECK: } diff --git a/test/AutoDiff/SILOptimizer/activity_analysis.swift b/test/AutoDiff/SILOptimizer/activity_analysis.swift index 02f75ebba20f4..532cc2e22b29b 100644 --- a/test/AutoDiff/SILOptimizer/activity_analysis.swift +++ b/test/AutoDiff/SILOptimizer/activity_analysis.swift @@ -407,23 +407,21 @@ func testArrayUninitializedIntrinsicApplyIndirectResult(_ x: T, _ y: T) -> [W struct Mut: Differentiable {} extension Mut { - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (self, x)) mutating func mutatingMethod(_ x: Mut) {} } -// CHECK-LABEL: [AD] Activity info for ${{.*}}3MutV14mutatingMethodyyACF at parameter indices (0) and result indices (0) +// CHECK-LABEL: [AD] Activity info for ${{.*}}3MutV14mutatingMethodyyACF at parameter indices (0, 1) and result indices (1) // CHECK: [VARIED] %0 = argument of bb0 : $Mut -// CHECK: [USEFUL] %1 = argument of bb0 : $*Mut +// CHECK: [ACTIVE] %1 = argument of bb0 : $*Mut -// TODO(TF-985): Find workaround to avoid marking non-wrt `inout` argument as -// active. -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func nonActiveInoutArg(_ nonactive: inout Mut, _ x: Mut) { nonactive.mutatingMethod(x) nonactive = x } -// CHECK-LABEL: [AD] Activity info for ${{.*}}17nonActiveInoutArgyyAA3MutVz_ADtF at parameter indices (1) and result indices (0) +// CHECK-LABEL: [AD] Activity info for ${{.*}}17nonActiveInoutArgyyAA3MutVz_ADtF at parameter indices (0, 1) and result indices (0) // CHECK: [ACTIVE] %0 = argument of bb0 : $*Mut // CHECK: [ACTIVE] %1 = argument of bb0 : $Mut // CHECK: [ACTIVE] %4 = begin_access [modify] [static] %0 : $*Mut @@ -449,14 +447,14 @@ func activeInoutArgMutatingMethod(_ x: Mut) -> Mut { // CHECK: [ACTIVE] %11 = begin_access [read] [static] %2 : $*Mut // CHECK: [ACTIVE] %12 = load [trivial] %11 : $*Mut -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func activeInoutArgMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) { var result = nonactive result.mutatingMethod(x) nonactive = result } -// CHECK-LABEL: [AD] Activity info for ${{.*}}31activeInoutArgMutatingMethodVaryyAA3MutVz_ADtF at parameter indices (1) and result indices (0) +// CHECK-LABEL: [AD] Activity info for ${{.*}}31activeInoutArgMutatingMethodVaryyAA3MutVz_ADtF at parameter indices (0, 1) and result indices (0) // CHECK: [ACTIVE] %0 = argument of bb0 : $*Mut // CHECK: [ACTIVE] %1 = argument of bb0 : $Mut // CHECK: [ACTIVE] %4 = alloc_stack $Mut, var, name "result" @@ -470,14 +468,14 @@ func activeInoutArgMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) { // CHECK: [ACTIVE] %15 = begin_access [modify] [static] %0 : $*Mut // CHECK: [NONE] %19 = tuple () -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func activeInoutArgMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) { var result = (nonactive, x) result.0.mutatingMethod(result.0) nonactive = result.0 } -// CHECK-LABEL: [AD] Activity info for ${{.*}}33activeInoutArgMutatingMethodTupleyyAA3MutVz_ADtF at parameter indices (1) and result indices (0) +// CHECK-LABEL: [AD] Activity info for ${{.*}}33activeInoutArgMutatingMethodTupleyyAA3MutVz_ADtF at parameter indices (0, 1) and result indices (0) // CHECK: [ACTIVE] %0 = argument of bb0 : $*Mut // CHECK: [ACTIVE] %1 = argument of bb0 : $Mut // CHECK: [ACTIVE] %4 = alloc_stack $(Mut, Mut), var, name "result" @@ -499,39 +497,39 @@ func activeInoutArgMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) { // Check `inout` arguments. @differentiable(reverse) -func activeInoutArg(_ x: Float) -> Float { +func activeInoutArg(_ x: inout Float) -> Float { var result = x result += x return result } -// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArg{{.*}} at parameter indices (0) and result indices (0) -// CHECK: [ACTIVE] %0 = argument of bb0 : $Float +// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArg{{.*}} at parameter indices (0) and result indices (0, 1) +// CHECK: [ACTIVE] %0 = argument of bb0 : $*Float // CHECK: [ACTIVE] %2 = alloc_stack $Float, var, name "result" -// CHECK: [ACTIVE] %5 = begin_access [modify] [static] %2 : $*Float +// CHECK: [ACTIVE] %10 = begin_access [modify] [static] %2 : $*Float // CHECK: [NONE] // function_ref static Float.+= infix(_:_:) -// CHECK: [NONE] %7 = apply %6(%5, %0, %4) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> () -// CHECK: [ACTIVE] %9 = begin_access [read] [static] %2 : $*Float -// CHECK: [ACTIVE] %10 = load [trivial] %9 : $*Float +// CHECK: [NONE] %12 = apply %11(%10, %8, %6) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> () +// CHECK: [ACTIVE] %14 = begin_access [read] [static] %2 : $*Float +// CHECK: [ACTIVE] %15 = load [trivial] %14 : $*Float @differentiable(reverse) -func activeInoutArgNonactiveInitialResult(_ x: Float) -> Float { +func activeInoutArgNonactiveInitialResult(_ x: inout Float) -> Float { var result: Float = 1 result += x return result } -// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArgNonactiveInitialResult{{.*}} at parameter indices (0) and result indices (0) -// CHECK: [ACTIVE] %0 = argument of bb0 : $Float +// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArgNonactiveInitialResult{{.*}} at parameter indices (0) and result indices (0, 1) +// CHECK: [ACTIVE] %0 = argument of bb0 : $*Float // CHECK: [ACTIVE] %2 = alloc_stack $Float, var, name "result" // CHECK: [NONE] // function_ref Float.init(_builtinIntegerLiteral:) // CHECK: [USEFUL] %6 = apply %5(%3, %4) : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // CHECK: [USEFUL] %8 = metatype $@thin Float.Type -// CHECK: [ACTIVE] %9 = begin_access [modify] [static] %2 : $*Float +// CHECK: [ACTIVE] %12 = begin_access [modify] [static] %2 : $*Float // CHECK: [NONE] // function_ref static Float.+= infix(_:_:) -// CHECK: [NONE] %11 = apply %10(%9, %0, %8) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> () -// CHECK: [ACTIVE] %13 = begin_access [read] [static] %2 : $*Float -// CHECK: [ACTIVE] %14 = load [trivial] %13 : $*Float +// CHECK: [NONE] %14 = apply %13(%12, %10, %8) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> () +// CHECK: [ACTIVE] %16 = begin_access [read] [static] %2 : $*Float +// CHECK: [ACTIVE] %17 = load [trivial] %16 : $*Float //===----------------------------------------------------------------------===// // Throwing function differentiation (`try_apply`) @@ -703,7 +701,7 @@ class C: Differentiable { x * float } -// CHECK-LABEL: [AD] Activity info for ${{.*}}1CC6methodyS2fF at parameter indices (0, 1) and result indices (0) +// CHECK-LABEL: [AD] Activity info for ${{.*}}1CC6methodyS2fF at parameter indices (0, 1) and result indices (0, 2) // CHECK: bb0: // CHECK: [ACTIVE] %0 = argument of bb0 : $Float // CHECK: [ACTIVE] %1 = argument of bb0 : $C @@ -716,8 +714,11 @@ class C: Differentiable { } // TF-1176: Test class property `modify` accessor. +// expected-error @+1 {{function is not differentiable}} @differentiable(reverse) +// expected-note @+1 {{when differentiating this function definition}} func testClassModifyAccessor(_ c: inout C) { + // expected-note @+1 {{differentiation of coroutine calls is not yet supported}} c.float *= c.float } @@ -726,9 +727,9 @@ func testClassModifyAccessor(_ c: inout C) { // CHECK: [ACTIVE] %0 = argument of bb0 : $*C // CHECK: [NONE] %2 = metatype $@thin Float.Type // CHECK: [ACTIVE] %3 = begin_access [read] [static] %0 : $*C -// CHECK: [VARIED] %4 = load [copy] %3 : $*C +// CHECK: [ACTIVE] %4 = load [copy] %3 : $*C // CHECK: [ACTIVE] %6 = begin_access [read] [static] %0 : $*C -// CHECK: [VARIED] %7 = load [copy] %6 : $*C +// CHECK: [ACTIVE] %7 = load [copy] %6 : $*C // CHECK: [VARIED] %9 = class_method %7 : $C, #C.float!getter : (C) -> () -> Float, $@convention(method) (@guaranteed C) -> Float // CHECK: [VARIED] %10 = apply %9(%7) : $@convention(method) (@guaranteed C) -> Float // CHECK: [VARIED] %12 = class_method %4 : $C, #C.float!modify : (C) -> () -> (), $@yield_once @convention(method) (@guaranteed C) -> @yields @inout Float diff --git a/test/AutoDiff/SILOptimizer/derivative_sil.swift b/test/AutoDiff/SILOptimizer/derivative_sil.swift index 51f02d97127a6..ac65efe003814 100644 --- a/test/AutoDiff/SILOptimizer/derivative_sil.swift +++ b/test/AutoDiff/SILOptimizer/derivative_sil.swift @@ -32,7 +32,7 @@ func foo(_ x: Float) -> Float { // CHECK-SIL-LABEL: enum _AD__fooMethod_bb0__Pred__src_0_wrt_0 { // CHECK-SIL-NEXT: } -// CHECK-SIL-LABEL: sil hidden [ossa] @fooTJfSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +// CHECK-SIL-LABEL: sil hidden [ossa] @fooTJfSpSUr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { // CHECK-SIL: bb0([[X:%.*]] : $Float): // CHECK-SIL: [[ADD_ORIG_REF:%.*]] = function_ref @add : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL: [[ADD_JVP_REF:%.*]] = differentiability_witness_function [jvp] [reverse] [parameters 0 1] [results 0] @add @@ -42,13 +42,13 @@ func foo(_ x: Float) -> Float { // CHECK-SIL: [[ADD_RESULT:%.*]] = apply [[ADD_JVP_FN]]([[X]], [[X]], {{.*}}) // CHECK-SIL: ([[ORIG_RES:%.*]], [[ADD_DF:%.*]]) = destructure_tuple [[ADD_RESULT]] // CHECK-SIL: [[DF_STRUCT:%.*]] = tuple ([[ADD_DF]] : $@callee_guaranteed (Float, Float) -> Float) -// CHECK-SIL: [[DF_REF:%.*]] = function_ref @fooTJdSpSr : $@convention(thin) (Float, @owned (_: @callee_guaranteed (Float, Float) -> Float)) -> Float +// CHECK-SIL: [[DF_REF:%.*]] = function_ref @fooTJdSpSUr : $@convention(thin) (Float, @owned (_: @callee_guaranteed (Float, Float) -> Float)) -> Float // CHECK-SIL: [[DF_FN:%.*]] = partial_apply [callee_guaranteed] [[DF_REF]]([[DF_STRUCT]]) // CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[DF_FN]] : $@callee_guaranteed (Float) -> Float) // CHECK-SIL: return [[VJP_RESULT]] : $(Float, @callee_guaranteed (Float) -> Float) // CHECK-SIL: } -// CHECK-SIL-LABEL: sil private [ossa] @fooTJdSpSr : $@convention(thin) (Float, @owned (_: @callee_guaranteed (Float, Float) -> Float)) -> Float { +// CHECK-SIL-LABEL: sil private [ossa] @fooTJdSpSUr : $@convention(thin) (Float, @owned (_: @callee_guaranteed (Float, Float) -> Float)) -> Float { // CHECK-SIL: bb0([[DX:%.*]] : $Float, [[DF_STRUCT:%.*]] : @owned $(_: @callee_guaranteed (Float, Float) -> Float)): // CHECK-SIL: [[ADD_DF:%.*]] = destructure_tuple [[DF_STRUCT]] : $(_: @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[DY:%.*]] = apply [[ADD_DF]]([[DX]], [[DX]]) : $@callee_guaranteed (Float, Float) -> Float @@ -56,7 +56,7 @@ func foo(_ x: Float) -> Float { // CHECK-SIL: return [[DY]] : $Float // CHECK-SIL: } -// CHECK-SIL-LABEL: sil hidden [ossa] @fooTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +// CHECK-SIL-LABEL: sil hidden [ossa] @fooTJrSpSUr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { // CHECK-SIL: bb0([[X:%.*]] : $Float): // CHECK-SIL: [[ADD_ORIG_REF:%.*]] = function_ref @add : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL: [[ADD_JVP_REF:%.*]] = differentiability_witness_function [jvp] [reverse] [parameters 0 1] [results 0] @add @@ -65,13 +65,13 @@ func foo(_ x: Float) -> Float { // CHECK-SIL: [[ADD_VJP_FN:%.*]] = differentiable_function_extract [vjp] [[ADD_DIFF_FN]] // CHECK-SIL: [[ADD_RESULT:%.*]] = apply [[ADD_VJP_FN]]([[X]], [[X]], {{.*}}) // CHECK-SIL: ([[ORIG_RES:%.*]], [[ADD_PB:%.*]]) = destructure_tuple [[ADD_RESULT]] -// CHECK-SIL: [[PB_REF:%.*]] = function_ref @fooTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float +// CHECK-SIL: [[PB_REF:%.*]] = function_ref @fooTJpSpSUr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float // CHECK-SIL: [[PB_FN:%.*]] = partial_apply [callee_guaranteed] [[PB_REF]]([[ADD_PB]]) // CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB_FN]] : $@callee_guaranteed (Float) -> Float) // CHECK-SIL: return [[VJP_RESULT]] : $(Float, @callee_guaranteed (Float) -> Float) // CHECK-SIL: } -// CHECK-SIL-LABEL: sil private [ossa] @fooTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float { +// CHECK-SIL-LABEL: sil private [ossa] @fooTJpSpSUr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float { // CHECK-SIL: bb0([[DY:%.*]] : $Float, [[ADD_PB:%.*]] : @owned $@callee_guaranteed (Float) -> (Float, Float)): // CHECK-SIL: debug_value [[DY]] : $Float, let, name "y" // CHECK-SIL: [[ADD_PB_RES:%.*]] = apply [[ADD_PB]]([[DY]]) : $@callee_guaranteed (Float) -> (Float, Float) @@ -104,10 +104,10 @@ struct ExampleStruct { } } -// CHECK-SIL-LABEL: sil hidden [ossa] @fooMethodTJfSUpSr : $@convention(method) (Float, ExampleStruct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +// CHECK-SIL-LABEL: sil hidden [ossa] @fooMethodTJfSUpSUUr : $@convention(method) (Float, ExampleStruct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { -// CHECK-SIL-LABEL: sil private [ossa] @fooMethodTJdSUpSr : $@convention(thin) (Float, @owned (_: @callee_guaranteed (Float, Float) -> Float)) -> Float { +// CHECK-SIL-LABEL: sil private [ossa] @fooMethodTJdSUpSUUr : $@convention(thin) (Float, @owned (_: @callee_guaranteed (Float, Float) -> Float)) -> Float { -// CHECK-SIL-LABEL: sil hidden [ossa] @fooMethodTJrSUpSr : $@convention(method) (Float, ExampleStruct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +// CHECK-SIL-LABEL: sil hidden [ossa] @fooMethodTJrSUpSUUr : $@convention(method) (Float, ExampleStruct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { -// CHECK-SIL-LABEL: sil private [ossa] @fooMethodTJpSUpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float { +// CHECK-SIL-LABEL: sil private [ossa] @fooMethodTJpSUpSUUr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float { diff --git a/test/AutoDiff/SILOptimizer/differentiability_witness_inlining.sil b/test/AutoDiff/SILOptimizer/differentiability_witness_inlining.sil index 42051ae462985..13a98f62f4b58 100644 --- a/test/AutoDiff/SILOptimizer/differentiability_witness_inlining.sil +++ b/test/AutoDiff/SILOptimizer/differentiability_witness_inlining.sil @@ -37,7 +37,7 @@ bb0(%0 : $Float): // CHECK: differentiability_witness_function [vjp] [reverse] [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float %3 = differentiability_witness_function [vjp] [reverse] [parameters 0 1] [results 0] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float - // CHECK: function_ref @$sSf1poiyS2f_SftFZTJrSSUpSr : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) + // CHECK: function_ref @$sSf1poiyS2f_SftFZTJrSSUpSUUUr : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) return undef : $() } diff --git a/test/AutoDiff/SILOptimizer/differentiation_control_flow_sil.swift b/test/AutoDiff/SILOptimizer/differentiation_control_flow_sil.swift index 31c137a502d49..4f4e0d4cc11e8 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_control_flow_sil.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_control_flow_sil.swift @@ -39,7 +39,7 @@ func cond(_ x: Float) -> Float { // CHECK-DATA-STRUCTURES: case bb1((predecessor: _AD__cond_bb1__Pred__src_0_wrt_0, (Float) -> (Float, Float))) // CHECK-DATA-STRUCTURES: } -// CHECK-SIL-LABEL: sil hidden [ossa] @condTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +// CHECK-SIL-LABEL: sil hidden [ossa] @condTJrSpSUr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { // CHECK-SIL: bb0([[INPUT_ARG:%.*]] : $Float): // CHECK-SIL: [[BB0_PB_STRUCT:%.*]] = tuple () // CHECK-SIL: cond_br {{%.*}}, bb1, bb2 @@ -57,13 +57,13 @@ func cond(_ x: Float) -> Float { // CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0) // CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0) -// CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @condTJpSpSr +// CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @condTJpSpSUr // CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PRED_ARG]]) // CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float) // CHECK-SIL: return [[VJP_RESULT]] -// CHECK-SIL-LABEL: sil private [ossa] @condTJpSpSr : $@convention(thin) (Float, @owned _AD__cond_bb3__Pred__src_0_wrt_0) -> Float { +// CHECK-SIL-LABEL: sil private [ossa] @condTJpSpSUr : $@convention(thin) (Float, @owned _AD__cond_bb3__Pred__src_0_wrt_0) -> Float { // CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PRED:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0): // CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_bb3__Pred__src_0_wrt_0, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb1, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb3 @@ -147,7 +147,7 @@ func enum_notactive(_ e: Enum, _ x: Float) -> Float { } } -// CHECK-SIL-LABEL: sil hidden [ossa] @enum_notactiveTJrUSpSr : $@convention(thin) (Enum, Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +// CHECK-SIL-LABEL: sil hidden [ossa] @enum_notactiveTJrUSpSUUr : $@convention(thin) (Enum, Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { // CHECK-SIL: bb0([[ENUM_ARG:%.*]] : $Enum, [[X_ARG:%.*]] : $Float): // CHECK-SIL: [[BB0_PB_STRUCT:%.*]] = tuple () // CHECK-SIL: switch_enum [[ENUM_ARG]] : $Enum, case #Enum.a!enumelt: bb1, case #Enum.b!enumelt: bb2 @@ -165,7 +165,7 @@ func enum_notactive(_ e: Enum, _ x: Float) -> Float { // CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1) // CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1) -// CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @enum_notactiveTJpUSpSr +// CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @enum_notactiveTJpUSpSUUr // CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PRED_ARG]]) // CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float) // CHECK-SIL: return [[VJP_RESULT]] @@ -187,7 +187,7 @@ func enum_addr_notactive(_ e: AddressOnlyEnum, _ x: Float) -> Float { return x } -// CHECK-SIL-LABEL: sil hidden [ossa] @enum_addr_notactivelTJrUSpSr : $@convention(thin) <τ_0_0> (@in_guaranteed AddressOnlyEnum<τ_0_0>, Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +// CHECK-SIL-LABEL: sil hidden [ossa] @enum_addr_notactivelTJrUSpSUUr : $@convention(thin) <τ_0_0> (@in_guaranteed AddressOnlyEnum<τ_0_0>, Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { // CHECK-SIL: bb0([[ENUM_ARG:%.*]] : $*AddressOnlyEnum<τ_0_0>, [[X_ARG:%.*]] : $Float): // CHECK-SIL: [[ENUM_ADDR:%.*]] = alloc_stack $AddressOnlyEnum<τ_0_0> // CHECK-SIL: copy_addr [[ENUM_ARG]] to [init] [[ENUM_ADDR]] : $*AddressOnlyEnum<τ_0_0> @@ -212,7 +212,7 @@ func enum_addr_notactive(_ e: AddressOnlyEnum, _ x: Float) -> Float { // CHECK-SIL: bb3([[BB3_PRED_ARG:%.*]] : $_AD__enum_addr_notactive_bb3__Pred__src_0_wrt_1_l<τ_0_0>): -// CHECK-SIL: [[PB_FNREF:%.*]] = function_ref @enum_addr_notactivelTJpUSpSr : $@convention(thin) <τ_0_0> (Float, @owned _AD__enum_addr_notactive_bb3__Pred__src_0_wrt_1_l<τ_0_0>) -> Float +// CHECK-SIL: [[PB_FNREF:%.*]] = function_ref @enum_addr_notactivelTJpUSpSUUr : $@convention(thin) <τ_0_0> (Float, @owned _AD__enum_addr_notactive_bb3__Pred__src_0_wrt_1_l<τ_0_0>) -> Float // CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PB_FNREF]]<τ_0_0>([[BB3_PRED_ARG]]) : $@convention(thin) <τ_0_0> (Float, @owned _AD__enum_addr_notactive_bb3__Pred__src_0_wrt_1_l<τ_0_0>) -> Float // CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[X_ARG]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float) // CHECK-SIL: return [[VJP_RESULT]] : $(Float, @callee_guaranteed (Float) -> Float) @@ -231,7 +231,7 @@ func cond_tuple_var(_ x: Float) -> Float { return y.1 } -// CHECK-SIL-LABEL: sil private [ossa] @cond_tuple_varTJpSpSr : $@convention(thin) (Float, @owned _AD__cond_tuple_var_bb3__Pred__src_0_wrt_0) -> Float { +// CHECK-SIL-LABEL: sil private [ossa] @cond_tuple_varTJpSpSUr : $@convention(thin) (Float, @owned _AD__cond_tuple_var_bb3__Pred__src_0_wrt_0) -> Float { // CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PRED:%.*]] : $_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0): // CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float) // CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float diff --git a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift index 0f106c55bff56..e149bcb3116db 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift @@ -400,11 +400,11 @@ func activeInoutParamControlFlowComplex(_ array: [Float], _ bool: Bool) -> Float struct Mut: Differentiable {} extension Mut { - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (self, x)) mutating func mutatingMethod(_ x: Mut) {} } -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func nonActiveInoutParam(_ nonactive: inout Mut, _ x: Mut) { nonactive.mutatingMethod(x) } @@ -416,14 +416,14 @@ func activeInoutParamMutatingMethod(_ x: Mut) -> Mut { return result } -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func activeInoutParamMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) { var result = nonactive result.mutatingMethod(x) nonactive = result } -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func activeInoutParamMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) { var result = (nonactive, x) result.0.mutatingMethod(result.0) diff --git a/test/AutoDiff/SILOptimizer/differentiation_sil.swift b/test/AutoDiff/SILOptimizer/differentiation_sil.swift index 92454bcfde542..773ccd01615d9 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_sil.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_sil.swift @@ -15,8 +15,8 @@ func basic(_ x: Float) -> Float { x } // CHECK-SILGEN-NEXT: } // CHECK-SIL-LABEL: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @basic : $@convention(thin) (Float) -> Float { -// CHECK-SIL-NEXT: jvp: @basicTJfSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) -// CHECK-SIL-NEXT: vjp: @basicTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK-SIL-NEXT: jvp: @basicTJfSpSUr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK-SIL-NEXT: vjp: @basicTJrSpSUr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) // CHECK-SIL-NEXT: } // Test `differentiable_function` instructions. diff --git a/test/AutoDiff/SILOptimizer/differentiation_subset_parameters_thunk.swift b/test/AutoDiff/SILOptimizer/differentiation_subset_parameters_thunk.swift index 2ed9a34cf7cae..1f57cba3ab43e 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_subset_parameters_thunk.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_subset_parameters_thunk.swift @@ -19,17 +19,17 @@ func differentiate_foo_wrt_0(_ x: Float) -> Float { foo(x, 1) } -// CHECK-LABEL: sil hidden @$s39differentiation_subset_parameters_thunk23differentiate_foo_wrt_0yS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +// CHECK-LABEL: sil hidden @$s39differentiation_subset_parameters_thunk23differentiate_foo_wrt_0yS2fFTJrSpSUr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { // CHECK: bb0 // CHECK: [[FOO_ORIG:%.*]] = function_ref @{{.*}}foo{{.*}} : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0 // CHECK: [[FOO_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_ORIG]]() : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0 // CHECK: [[FOO_JVP:%.*]] = differentiability_witness_function [jvp] [reverse] [parameters 0 1] [results 0] @{{.*}}foo{{.*}} : $@convention(thin) (@in_guaranteed T, @in_guaranteed T) -> @out T // CHECK: [[FOO_JVP_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_JVP]]() : $@convention(thin) <τ_0_0 where τ_0_0 : NumericDifferentiable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_1) -> @out τ_0_2 for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>) -// CHECK: [[FOO_JVP_SUBSET_THUNK_THIN:%.*]] = function_ref @$s39differentiation_subset_parameters_thunk3fooyxx_xtSjRzlFS5fIegnr_Iegnnro_TJSfSSpSrSUP : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) +// CHECK: [[FOO_JVP_SUBSET_THUNK_THIN:%.*]] = function_ref @$s39differentiation_subset_parameters_thunk3fooyxx_xtSjRzlFS5fIegnr_Iegnnro_TJSfSSpSUUrSUP : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) // CHECK: [[FOO_JVP_SUBSET_THUNK:%.*]] = thin_to_thick_function [[FOO_JVP_SUBSET_THUNK_THIN]] : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) to $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) // CHECK: [[FOO_VJP:%.*]] = differentiability_witness_function [vjp] [reverse] [parameters 0 1] [results 0] @{{.*}}foo{{.*}} : $@convention(thin) (@in_guaranteed T, @in_guaranteed T) -> @out T // CHECK: [[FOO_VJP_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_VJP]]() : $@convention(thin) <τ_0_0 where τ_0_0 : NumericDifferentiable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>) -// CHECK: [[FOO_VJP_SUBSET_THUNK_THIN:%.*]] = function_ref @$s39differentiation_subset_parameters_thunk3fooyxx_xtSjRzlFS5fIegnr_Iegnnro_TJSrSSpSrSUP : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) +// CHECK: [[FOO_VJP_SUBSET_THUNK_THIN:%.*]] = function_ref @$s39differentiation_subset_parameters_thunk3fooyxx_xtSjRzlFS5fIegnr_Iegnnro_TJSrSSpSUUrSUP : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) // CHECK: [[FOO_VJP_SUBSET_THUNK:%.*]] = thin_to_thick_function [[FOO_VJP_SUBSET_THUNK_THIN]] : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) to $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) // CHECK: [[FOO_DIFF:%.*]] = differentiable_function [parameters 0] [results 0] [[FOO_FLOAT]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> @out Float with_derivative {[[FOO_JVP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[FOO_VJP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)} // CHECK: } @@ -81,7 +81,7 @@ func genericInoutIndirectCaller( return inoutIndirectCaller(x, y, z) } -// CHECK-LABEL: sil shared [transparent] [thunk] @$sSdSfSdSfIegnrrr_TJSpSSSpSrSUSP : $@convention(thin) (@in_guaranteed Double, @guaranteed @callee_guaranteed (@in_guaranteed Double) -> (@out Float, @out Double, @out Float)) -> (@out Float, @out Float) { +// CHECK-LABEL: sil shared [transparent] [thunk] @$sSdSfSdSfIegnrrr_TJSpSSSpSUUUrSUSP : $@convention(thin) (@in_guaranteed Double, @guaranteed @callee_guaranteed (@in_guaranteed Double) -> (@out Float, @out Double, @out Float)) -> (@out Float, @out Float) { // CHECK: bb0(%0 : $*Float, %1 : $*Float, %2 : $*Double, %3 : $@callee_guaranteed (@in_guaranteed Double) -> (@out Float, @out Double, @out Float)): // CHECK: %4 = alloc_stack $Double // CHECK: %5 = apply %3(%0, %4, %1, %2) : $@callee_guaranteed (@in_guaranteed Double) -> (@out Float, @out Double, @out Float) @@ -91,7 +91,7 @@ func genericInoutIndirectCaller( // CHECK: return %8 : $() // CHECK: } -// CHECK-LABEL: sil shared [transparent] [thunk] @$s13TangentVector16_Differentiation14DifferentiablePQy_AaDQzAaDQy0_Ieglrr_TJSpSSSpSrSSUP : $@convention(thin) <τ_0_0, τ_0_1, τ_0_2 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_2 : Differentiable> (@inout τ_0_1.TangentVector, @guaranteed @callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)) -> @out τ_0_0.TangentVector { +// CHECK-LABEL: sil shared [transparent] [thunk] @$s13TangentVector16_Differentiation14DifferentiablePQy_AaDQzAaDQy0_Ieglrr_TJSpSSSpUSUrSSUP : $@convention(thin) <τ_0_0, τ_0_1, τ_0_2 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_2 : Differentiable> (@inout τ_0_1.TangentVector, @guaranteed @callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)) -> @out τ_0_0.TangentVector { // CHECK: bb0(%0 : $*τ_0_0.TangentVector, %1 : $*τ_0_1.TangentVector, %2 : $@callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)): // CHECK: %3 = alloc_stack $τ_0_2.TangentVector // CHECK: %4 = apply %2(%0, %3, %1) : $@callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector) @@ -101,7 +101,7 @@ func genericInoutIndirectCaller( // CHECK: return %7 : $() // CHECK: } -// CHECK-LABEL: sil shared [transparent] [thunk] @$s13TangentVector16_Differentiation14DifferentiablePQy_AaDQzAaDQy0_Ieglrr_TJSpSSSpSrUSUP : $@convention(thin) <τ_0_0, τ_0_1, τ_0_2 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_2 : Differentiable> (@inout τ_0_1.TangentVector, @guaranteed @callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)) -> () { +// CHECK-LABEL: sil shared [transparent] [thunk] @$s13TangentVector16_Differentiation14DifferentiablePQy_AaDQzAaDQy0_Ieglrr_TJSpSSSpUSUrUSUP : $@convention(thin) <τ_0_0, τ_0_1, τ_0_2 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_2 : Differentiable> (@inout τ_0_1.TangentVector, @guaranteed @callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)) -> () { // CHECK: bb0(%0 : $*τ_0_1.TangentVector, %1 : $@callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)): // CHECK: %2 = alloc_stack $τ_0_0.TangentVector // CHECK: %3 = alloc_stack $τ_0_2.TangentVector diff --git a/test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift b/test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift index 5fb4d13407bba..cd29f9019c45c 100644 --- a/test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift +++ b/test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift @@ -89,7 +89,7 @@ func activeInoutParamControlFlow(_ array: [Float]) -> Float { struct X: Differentiable { var x: Float - @differentiable(reverse, wrt: y) + @differentiable(reverse, wrt: (self, y)) mutating func mutate(_ y: X) { self.x = y.x } } @@ -104,7 +104,7 @@ func activeMutatingMethod(_ x: Float) -> Float { struct Mut: Differentiable {} extension Mut { - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (self, x)) mutating func mutatingMethod(_ x: Mut) {} } diff --git a/test/AutoDiff/SILOptimizer/generics.swift b/test/AutoDiff/SILOptimizer/generics.swift index dfdca134889fb..dbb3be0a8de8d 100644 --- a/test/AutoDiff/SILOptimizer/generics.swift +++ b/test/AutoDiff/SILOptimizer/generics.swift @@ -11,7 +11,7 @@ _ = gradient(at: Float(1), of: { x in identity(x) }) // Test PullbackCloner local buffer allocation. // Verify that local buffers are immediately set to zero. -// CHECK-SIL-LABEL: sil private @identity16_Differentiation14DifferentiableRzlTJpSpSr +// CHECK-SIL-LABEL: sil private @identity16_Differentiation14DifferentiableRzlTJpSpSUr // CHECK-SIL: [[ORIG_COTAN:%.*]] = alloc_stack $τ_0_0.TangentVector // CHECK-SIL-NEXT: [[ZERO_WITNESS:%.*]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic.zero!getter // CHECK-SIL-NEXT: [[ORIG_COTAN_METATYPE:%.*]] = metatype $@thick τ_0_0.TangentVector.Type diff --git a/test/AutoDiff/SILOptimizer/licm_context.swift b/test/AutoDiff/SILOptimizer/licm_context.swift index b12ab8f43533c..0f0ba4950afa0 100644 --- a/test/AutoDiff/SILOptimizer/licm_context.swift +++ b/test/AutoDiff/SILOptimizer/licm_context.swift @@ -73,10 +73,10 @@ public func s(y: B) -> B { return w(y) } -// CHECK-LABEL: sil private @$s12licm_context1s1yAA1BVAE_tF1qL_yA2EFTJrSpSr : +// CHECK-LABEL: sil private @$s12licm_context1s1yAA1BVAE_tF1qL_yA2EFTJrSpSUr : // CHECK: autoDiffCreateLinearMapContext // CHECK: autoDiffCreateLinearMapContext -// CHECK-LABEL: end sil function '$s12licm_context1s1yAA1BVAE_tF1qL_yA2EFTJrSpSr' +// CHECK-LABEL: end sil function '$s12licm_context1s1yAA1BVAE_tF1qL_yA2EFTJrSpSUr' func o(_ x: T, _ f: @differentiable(reverse) (T) -> R) -> R { f(x) diff --git a/test/AutoDiff/SILOptimizer/replicated-adjoint-prop.swift b/test/AutoDiff/SILOptimizer/replicated-adjoint-prop.swift index 7c3e39a72b3ec..c9c8b9c020135 100644 --- a/test/AutoDiff/SILOptimizer/replicated-adjoint-prop.swift +++ b/test/AutoDiff/SILOptimizer/replicated-adjoint-prop.swift @@ -8,7 +8,7 @@ struct Test: Differentiable { @differentiable(reverse) mutating func doSomething(input: Float) { -// CHECK-SIL-LABEL: TestV11doSomething5inputySf_tFTJpSSpSr : +// CHECK-SIL-LABEL: TestV11doSomething5inputySf_tFTJpSSpUSr : // Ensure that only two adjoint buffers will be propagated // CHECK-SIL: copy_addr %0 to %22 : $*Test.TangentVector // CHECK-SIL-NEXT: debug_value diff --git a/test/AutoDiff/Sema/DerivedConformances/class_differentiable.swift b/test/AutoDiff/Sema/DerivedConformances/class_differentiable.swift index f5e29d0a6edd1..19f969248bab1 100644 --- a/test/AutoDiff/Sema/DerivedConformances/class_differentiable.swift +++ b/test/AutoDiff/Sema/DerivedConformances/class_differentiable.swift @@ -553,7 +553,7 @@ class C_55238: Differentiable { func method() -> Float { x } @derivative(of: method) - func vjpMethod() -> (value: Float, pullback: (Float) -> TangentVector) { fatalError() } + func vjpMethod() -> (value: Float, pullback: (Float, inout TangentVector) -> ()) { fatalError() } // Test usage of synthesized `TangentVector` type. // This should not produce an error: "reference to invalid associated type 'TangentVector'". diff --git a/test/AutoDiff/Sema/derivative_attr_type_checking.swift b/test/AutoDiff/Sema/derivative_attr_type_checking.swift index f3e26e8a32b13..f8a40d2f2469e 100644 --- a/test/AutoDiff/Sema/derivative_attr_type_checking.swift +++ b/test/AutoDiff/Sema/derivative_attr_type_checking.swift @@ -615,15 +615,15 @@ extension Class { } extension Class where T: Differentiable { @derivative(of: subscript.get) - func vjpSubscriptGetter() -> (value: Float, pullback: (Float) -> TangentVector) { - return (1, { _ in .zero }) + func vjpSubscriptGetter() -> (value: Float, pullback: (Float, inout TangentVector) -> ()) { + return (1, { $1 = TangentVector.zero }) } // expected-error @+2 {{a derivative already exists for getter for 'subscript()'}} // expected-note @-6 {{other attribute declared here}} @derivative(of: subscript) - func vjpSubscript() -> (value: Float, pullback: (Float) -> TangentVector) { - return (1, { _ in .zero }) + func vjpSubscript() -> (value: Float, pullback: (Float, inout TangentVector) -> ()) { + return (1, { $1 = TangentVector.zero }) } // FIXME: Enable derivative registration for class property/subscript setters (https://github.com/apple/swift/issues/55542). @@ -769,7 +769,8 @@ struct InoutParameters: Differentiable { } extension InoutParameters { - // expected-note @+1 4 {{'staticMethod(_:rhs:)' defined here}} + // expected-note @+2 {{'staticMethod(_:rhs:)' defined here}} + // expected-note @+1 {{'staticMethod(_:rhs:)' defined here}} static func staticMethod(_ lhs: inout Self, rhs: Self) {} // Test wrt `inout` parameter. @@ -800,33 +801,34 @@ extension InoutParameters { // Test non-wrt `inout` parameter. + // expected-error @+1 {{cannot differentiate void function 'staticMethod(_:rhs:)'}} @derivative(of: staticMethod, wrt: rhs) static func vjpNotWrtInout(_ lhs: inout Self, _ rhs: Self) -> ( value: Void, pullback: (TangentVector) -> TangentVector ) { fatalError() } - // expected-error @+1 {{function result's 'pullback' type does not match 'staticMethod(_:rhs:)'}} + // expected-error @+1 {{cannot differentiate void function 'staticMethod(_:rhs:)'}} @derivative(of: staticMethod, wrt: rhs) static func vjpNotWrtInoutMismatch(_ lhs: inout Self, _ rhs: Self) -> ( - // expected-note @+1 {{'pullback' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}} value: Void, pullback: (inout TangentVector) -> TangentVector ) { fatalError() } + // expected-error @+1 {{cannot differentiate void function 'staticMethod(_:rhs:)'}} @derivative(of: staticMethod, wrt: rhs) static func jvpNotWrtInout(_ lhs: inout Self, _ rhs: Self) -> ( value: Void, differential: (TangentVector) -> TangentVector ) { fatalError() } - // expected-error @+1 {{function result's 'differential' type does not match 'staticMethod(_:rhs:)'}} + // expected-error @+1 {{cannot differentiate void function 'staticMethod(_:rhs:)'}} @derivative(of: staticMethod, wrt: rhs) static func jvpNotWrtInout(_ lhs: inout Self, _ rhs: Self) -> ( - // expected-note @+1 {{'differential' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}} value: Void, differential: (inout TangentVector) -> TangentVector ) { fatalError() } } extension InoutParameters { - // expected-note @+1 4 {{'mutatingMethod' defined here}} + // expected-note @+2 {{'mutatingMethod' defined here}} + // expected-note @+1 {{'mutatingMethod' defined here}} mutating func mutatingMethod(_ other: Self) {} // Test wrt `inout` `self` parameter. @@ -857,27 +859,27 @@ extension InoutParameters { // Test non-wrt `inout` `self` parameter. + // expected-error @+1 {{cannot differentiate void function 'mutatingMethod'}} @derivative(of: mutatingMethod, wrt: other) mutating func vjpNotWrtInout(_ other: Self) -> ( value: Void, pullback: (TangentVector) -> TangentVector ) { fatalError() } - // expected-error @+1 {{function result's 'pullback' type does not match 'mutatingMethod'}} + // expected-error @+1 {{cannot differentiate void function 'mutatingMethod'}} @derivative(of: mutatingMethod, wrt: other) mutating func vjpNotWrtInoutMismatch(_ other: Self) -> ( - // expected-note @+1 {{'pullback' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}} value: Void, pullback: (inout TangentVector) -> TangentVector ) { fatalError() } + // expected-error @+1 {{cannot differentiate void function 'mutatingMethod'}} @derivative(of: mutatingMethod, wrt: other) mutating func jvpNotWrtInout(_ other: Self) -> ( value: Void, differential: (TangentVector) -> TangentVector ) { fatalError() } - // expected-error @+1 {{function result's 'differential' type does not match 'mutatingMethod'}} + // expected-error @+1 {{cannot differentiate void function 'mutatingMethod'}} @derivative(of: mutatingMethod, wrt: other) mutating func jvpNotWrtInoutMismatch(_ other: Self) -> ( - // expected-note @+1 {{'differential' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}} value: Void, differential: (TangentVector, TangentVector) -> Void ) { fatalError() } } diff --git a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift index 53299fe2f639b..eefa81f12a70b 100644 --- a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift @@ -681,7 +681,7 @@ struct InoutParameters: Differentiable { } extension NonDiffableStruct { - // expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'NonDiffableStruct' does not conform to 'Differentiable'}} + // expected-error @+1 {{cannot differentiate void function 'nondiffResult(x:y:z:)'}} @differentiable(reverse) static func nondiffResult(x: Int, y: inout NonDiffableStruct, z: Float) {} diff --git a/test/AutoDiff/TBD/derivative_symbols.swift b/test/AutoDiff/TBD/derivative_symbols.swift index 5408e9e8cab1b..6c357eb83a133 100644 --- a/test/AutoDiff/TBD/derivative_symbols.swift +++ b/test/AutoDiff/TBD/derivative_symbols.swift @@ -141,7 +141,7 @@ public final class Class: Differentiable { @derivative(of: subscript) public func vjpSubscript(_ x: Float) -> ( - value: Float, pullback: (Float) -> (TangentVector, Float) + value: Float, pullback: (Float, inout TangentVector) -> Float ) { fatalError() } diff --git a/test/AutoDiff/compiler_crashers_fixed/issue-55745-noderivative-inout-parameter.swift b/test/AutoDiff/compiler_crashers_fixed/issue-55745-noderivative-inout-parameter.swift index 62a7d2ea018cf..5ab20d9360a82 100644 --- a/test/AutoDiff/compiler_crashers_fixed/issue-55745-noderivative-inout-parameter.swift +++ b/test/AutoDiff/compiler_crashers_fixed/issue-55745-noderivative-inout-parameter.swift @@ -4,15 +4,16 @@ import _Differentiation // https://github.com/apple/swift/issues/55745 // Test protocol witness thunk for `@differentiable` protocol requirement, where -// the required method has a non-wrt `inout` parameter that should be treated as -// a differentiability result. +// the required method has a non-wrt `inout` parameter. protocol Proto { + // expected-error @+1 {{cannot differentiate void function 'method(x:y:)'}} @differentiable(reverse, wrt: x) func method(x: Float, y: inout Float) } struct Struct: Proto { + // expected-error @+1 {{cannot differentiate void function 'method(x:y:)'}} @differentiable(reverse, wrt: x) func method(x: Float, y: inout Float) { y = y * x diff --git a/test/AutoDiff/compiler_crashers_fixed/issue-58123-invalid-debug-info.swift b/test/AutoDiff/compiler_crashers_fixed/issue-58123-invalid-debug-info.swift index 081606ba3b358..3868feaf6959b 100644 --- a/test/AutoDiff/compiler_crashers_fixed/issue-58123-invalid-debug-info.swift +++ b/test/AutoDiff/compiler_crashers_fixed/issue-58123-invalid-debug-info.swift @@ -4,7 +4,7 @@ // Mutating functions with control flow can cause assertion failure for // conflicting debug variable type -// CHECK-LABEL: define internal swiftcc float @"$s4main8TestTypeV24doDifferentiableTimeStep04timeG0ySf_tFTJpSSpSrTA" +// CHECK-LABEL: define internal swiftcc float @"$s4main8TestTypeV24doDifferentiableTimeStep04timeG0ySf_tFTJpSSpUSrTA" // CHECK: [[SELF:%.*]] = alloca %T4main8TestTypeV06ManualB7TangentV // CHECK: call void @llvm.dbg.declare(metadata ptr [[SELF]] diff --git a/test/AutoDiff/validation-test/class_differentiation.swift b/test/AutoDiff/validation-test/class_differentiation.swift index bc3f4b926eb4a..82a0ad70cfd5b 100644 --- a/test/AutoDiff/validation-test/class_differentiation.swift +++ b/test/AutoDiff/validation-test/class_differentiation.swift @@ -260,10 +260,11 @@ ClassTests.test("ClassMethods - wrt self") { @derivative(of: f) final func vjpf( _ x: Tracked - ) -> (value: Tracked, pullback: (Tracked) -> (TangentVector, Tracked)) { + ) -> (value: Tracked, pullback: (Tracked, inout TangentVector) -> Tracked) { let base = self.base - return (f(x), { v in - (TangentVector(base: v * x), base * v) + return (f(x), { v, tv in + tv = TangentVector(base: v * x) + return base * v }) } } @@ -504,9 +505,9 @@ ClassTests.test("ClassProperties") { var squared: Tracked { base * base } @derivative(of: squared) - final func vjpSquared() -> (value: Tracked, pullback: (Tracked) -> TangentVector) { + final func vjpSquared() -> (value: Tracked, pullback: (Tracked, inout TangentVector) -> ()) { let base = self.base - return (base * base, { v in TangentVector(base: 2 * base * v) }) + return (base * base, { v, tv in tv = TangentVector(base: 2 * base * v) }) } } diff --git a/test/AutoDiff/validation-test/forward_mode_simple.swift b/test/AutoDiff/validation-test/forward_mode_simple.swift index 0b4cb384e368c..cedadde1ab356 100644 --- a/test/AutoDiff/validation-test/forward_mode_simple.swift +++ b/test/AutoDiff/validation-test/forward_mode_simple.swift @@ -761,10 +761,11 @@ ForwardModeTests.test("SimpleWrtSelf") { return (f(x), { (dself, dx) in dself.base * dx }) } @derivative(of: f) - final func vjpf(_ x: Float) -> (value: Float, pullback: (Float) -> (TangentVector, Float)) { + final func vjpf(_ x: Float) -> (value: Float, pullback: (Float, inout TangentVector) -> Float) { let base = self.base - return (f(x), { v in - (TangentVector(base: v * x, _nontrivial: []), base * v) + return (f(x), { v, tv in + tv = TangentVector(base: v * x, _nontrivial: []) + return base * v }) } } @@ -1320,29 +1321,14 @@ ForwardModeTests.test("ForceUnwrapping") { } ForwardModeTests.test("NonVariedResult") { - @differentiable(reverse, wrt: x) - func nonWrtInoutParam(_ x: T, _ y: inout T) { - y = x - } - @differentiable(reverse) func wrtInoutParam(_ x: T, _ y: inout T) { y = x } - @differentiable(reverse, wrt: x) - func nonWrtInoutParamNonVaried(_ x: T, _ y: inout T) {} - - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (x, y)) func wrtInoutParamNonVaried(_ x: T, _ y: inout T) {} - @differentiable(reverse) - func variedResultTracked(_ x: Tracked) -> Tracked { - var result: Tracked = 0 - nonWrtInoutParam(x, &result) - return result - } - @differentiable(reverse) func variedResultTracked2(_ x: Tracked) -> Tracked { var result: Tracked = 0 @@ -1352,13 +1338,6 @@ ForwardModeTests.test("NonVariedResult") { @differentiable(reverse) func nonVariedResultTracked(_ x: Tracked) -> Tracked { - var result: Tracked = 0 - nonWrtInoutParamNonVaried(x, &result) - return result - } - - @differentiable(reverse) - func nonVariedResultTracked2(_ x: Tracked) -> Tracked { // expected-warning @+1 {{variable 'result' was never mutated}} var result: Tracked = 0 return result diff --git a/test/AutoDiff/validation-test/inout_parameters.swift b/test/AutoDiff/validation-test/inout_parameters.swift index 800b373ffcec0..a7df0577354c9 100644 --- a/test/AutoDiff/validation-test/inout_parameters.swift +++ b/test/AutoDiff/validation-test/inout_parameters.swift @@ -174,43 +174,62 @@ InoutParameterAutoDiffTests.test("InoutClassParameter") { } do { - func squaredViaModifyAccessor(_ c: inout Class) { - // The line below calls `Class.x.modify`. - c.x *= c.x + func squaredViaGetSetAccessors(_ c: inout Class) { + // The code below crafted not to call `Class.x.modify`. + // `Class.x.get` / `Class.x.set` are called instead + let a = c.x * c.x + c.x = a } func squared(_ x: Float) -> Float { var c = Class(x) - squaredViaModifyAccessor(&c) + squaredViaGetSetAccessors(&c) return c.x } - // FIXME(TF-1080): Fix incorrect class property `modify` accessor derivative values. - // expectEqual((100, 20), valueWithGradient(at: 10, of: squared)) - // expectEqual(200, pullback(at: 10, of: squared)(10)) - expectEqual((100, 1), valueWithGradient(at: 10, of: squared)) - expectEqual(10, pullback(at: 10, of: squared)(10)) + expectEqual((100, 20), valueWithGradient(at: 10, of: squared)) + expectEqual(200, pullback(at: 10, of: squared)(10)) } + + // FIXME: Support differentiation of `modify` accessors: + // https://github.com/apple/swift/issues/54401 + //do { + // func squaredViaModifyAccessor(_ c: inout Class) { + // // The line below calls `Class.x.modify`. + // c.x *= c.x + // } + // func squared(_ x: Float) -> Float { + // var c = Class(x) + // squaredViaModifyAccessor(&c) + // return c.x + // } + // // FIXME(TF-1080): Fix incorrect class property `modify` accessor derivative values. + // // expectEqual((100, 20), valueWithGradient(at: 10, of: squared)) + // // expectEqual(200, pullback(at: 10, of: squared)(10)) + // expectEqual((100, 1), valueWithGradient(at: 10, of: squared)) + // expectEqual(10, pullback(at: 10, of: squared)(10)) + //} } -// https://github.com/apple/swift/issues/55745 -// Test function with non-wrt `inout` parameter, which should be -// treated as a differentiability result. +// Test function with wrt `inout` parameter, which should be treated as a differentiability result. +// Original issue https://github.com/apple/swift/issues/55745 deals with non-wrt `inout` which +// we explicitly disallow now + protocol P_55745 { - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (x, y)) func method(_ x: Float, _ y: inout Float) - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (x, y)) func genericMethod(_ x: T, _ y: inout T) } InoutParameterAutoDiffTests.test("non-wrt inout parameter") { struct Struct: P_55745 { - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (x, y)) func method(_ x: Float, _ y: inout Float) { y = y * x } - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (x, y)) func genericMethod(_ x: T, _ y: inout T) { y = x }