Skip to content
Draft
12 changes: 6 additions & 6 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,16 +246,16 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s,
return s;
}

/// A semantic function result type: either a formal function result type or
/// an `inout` parameter type. Used in derivative function type calculation.
/// A semantic function result type: either a formal function result type or a
/// semantic result (an `inout` or class-bound) parameter type. Used in
/// derivative function type calculation.
struct AutoDiffSemanticFunctionResultType {
Type type;
unsigned index : 30;
bool isInout : 1;
bool isWrtParam : 1;
bool isSemanticResultParameter : 1;

AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool inout, bool wrt)
: type(t), index(idx), isInout(inout), isWrtParam(wrt) { }
AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool param)
: type(t), index(idx), isSemanticResultParameter(param) { }
};

/// Key for caching SIL derivative function types.
Expand Down
56 changes: 46 additions & 10 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -3192,6 +3192,12 @@ class AnyFunctionType : public TypeBase {
/// Whether the parameter is marked '@noDerivative'.
bool isNoDerivative() const { return Flags.isNoDerivative(); }

/// Whether the parameter might be a semantic result for autodiff purposes.
/// This includes inout and class-bound parameters.
bool isAutoDiffSemanticResult() const {
return isInOut() || Ty->getClassOrBoundGenericClass() != nullptr;
}

ValueOwnership getValueOwnership() const {
return Flags.getValueOwnership();
}
Expand Down Expand Up @@ -3509,19 +3515,16 @@ class AnyFunctionType : public TypeBase {
/// Preconditions:
/// - Parameters corresponding to parameter indices must conform to
/// `Differentiable`.
/// - There is one semantic function result type: either the formal original
/// result or an `inout` parameter. It must conform to `Differentiable`.
/// - There are semantic function result type: either the formal original
/// result or a "wrt" semantic result parameter.
///
/// Differential typing rules: takes "wrt" parameter derivatives and returns a
/// "wrt" result derivative.
///
/// - Case 1: original function has no `inout` parameters.
/// - Original: `(T0, T1, ...) -> R`
/// - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
/// - Case 2: original function has a non-wrt `inout` parameter.
/// - Original: `(T0, inout T1, ...) -> Void`
/// - Differential: `(T0.Tan, ...) -> T1.Tan`
/// - Case 3: original function has a wrt `inout` parameter.
/// - Case 2: original function has a wrt `inout` parameter.
/// - Original: `(T0, inout T1, ...) -> Void`
/// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
///
Expand All @@ -3531,10 +3534,7 @@ class AnyFunctionType : public TypeBase {
/// - Case 1: original function has no `inout` parameters.
/// - Original: `(T0, T1, ...) -> R`
/// - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
/// - Case 2: original function has a non-wrt `inout` parameter.
/// - Original: `(T0, inout T1, ...) -> Void`
/// - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
/// - Case 3: original function has a wrt `inout` parameter.
/// - Case 2: original function has a wrt `inout` parameter.
/// - Original: `(T0, inout T1, ...) -> Void`
/// - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
///
Expand Down Expand Up @@ -4101,6 +4101,10 @@ class SILParameterInfo {
return getConvention() == ParameterConvention::Indirect_Inout
|| getConvention() == ParameterConvention::Indirect_InoutAliasable;
}
bool isAutoDiffSemanticResult() const {
return isIndirectMutating() ||
getInterfaceType().getClassOrBoundGenericClass() != nullptr;
}

bool isPack() const {
return isPackParameter(getConvention());
Expand Down Expand Up @@ -4836,6 +4840,38 @@ class SILFunctionType final
return llvm::count_if(getParameters(), IndirectMutatingParameterFilter());
}

struct AutoDiffSemanticResultsParameterFilter {
bool operator()(SILParameterInfo param) const {
return param.isAutoDiffSemanticResult();
}
};

using AutoDiffSemanticResultsParameterIter =
llvm::filter_iterator<const SILParameterInfo *,
AutoDiffSemanticResultsParameterFilter>;
using AutoDiffSemanticResultsParameterRange =
iterator_range<AutoDiffSemanticResultsParameterIter>;

/// A range of SILParameterInfo for all semantic results parameters.
AutoDiffSemanticResultsParameterRange
getAutoDiffSemanticResultsParameters() const {
return llvm::make_filter_range(getParameters(),
AutoDiffSemanticResultsParameterFilter());
}

/// Returns the number of semantic results parameters.
unsigned getNumAutoDiffSemanticResultsParameters() const {
return llvm::count_if(getParameters(), AutoDiffSemanticResultsParameterFilter());
}

/// Returns the number of function potential semantic results:
/// * Usual results
/// * Inout parameters
/// * Class or class-bound parameters
unsigned getNumAutoDiffSemanticResults() const {
return getNumResults() + getNumAutoDiffSemanticResultsParameters();
}

/// Get the generic signature that the component types are specified
/// in terms of, if any.
CanGenericSignature getSubstGenericSignature() const {
Expand Down
12 changes: 12 additions & 0 deletions include/swift/SIL/ApplySite.h
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,18 @@ class FullApplySite : public ApplySite {
llvm_unreachable("invalid apply kind");
}

AutoDiffSemanticResultArgumentRange getAutoDiffSemanticResultArguments() const {
switch (getKind()) {
case FullApplySiteKind::ApplyInst:
return cast<ApplyInst>(getInstruction())->getAutoDiffSemanticResultArguments();
case FullApplySiteKind::TryApplyInst:
return cast<TryApplyInst>(getInstruction())->getAutoDiffSemanticResultArguments();
case FullApplySiteKind::BeginApplyInst:
return cast<BeginApplyInst>(getInstruction())->getAutoDiffSemanticResultArguments();
}
llvm_unreachable("invalid apply kind");
}

/// Returns true if \p op is the callee operand of this apply site
/// and not an argument operand.
bool isCalleeOperand(const Operand &op) const {
Expand Down
29 changes: 29 additions & 0 deletions include/swift/SIL/SILInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2785,6 +2785,25 @@ struct OperandToInoutArgument {
using InoutArgumentRange =
OptionalTransformRange<IntRange<size_t>, OperandToInoutArgument>;

/// Predicate used to filter AutoDiffSemanticResultArgumentRange.
struct OperandToAutoDiffSemanticResultArgument {
ArrayRef<SILParameterInfo> paramInfos;
OperandValueArrayRef arguments;
OperandToAutoDiffSemanticResultArgument(ArrayRef<SILParameterInfo> paramInfos,
OperandValueArrayRef arguments)
: paramInfos(paramInfos), arguments(arguments) {
assert(paramInfos.size() == arguments.size());
}
llvm::Optional<SILValue> operator()(size_t i) const {
if (paramInfos[i].isAutoDiffSemanticResult())
return arguments[i];
return llvm::None;
}
};

using AutoDiffSemanticResultArgumentRange =
OptionalTransformRange<IntRange<size_t>, OperandToAutoDiffSemanticResultArgument>;

/// The partial specialization of ApplyInstBase for full applications.
/// Adds some methods relating to 'self' and to result types that don't
/// make sense for partial applications.
Expand Down Expand Up @@ -2894,6 +2913,16 @@ class ApplyInstBase<Impl, Base, true>
impl.getArgumentsWithoutIndirectResults()));
}

/// Returns all autodiff semantic result (`@inout`, `@inout_aliasable`)
/// arguments passed to the instruction.
AutoDiffSemanticResultArgumentRange getAutoDiffSemanticResultArguments() const {
auto &impl = asImpl();
return AutoDiffSemanticResultArgumentRange(
indices(getArgumentsWithoutIndirectResults()),
OperandToAutoDiffSemanticResultArgument(impl.getSubstCalleeConv().getParameters(),
impl.getArgumentsWithoutIndirectResults()));
}

bool hasSemantics(StringRef semanticsString) const {
return doesApplyCalleeHaveSemantics(getCallee(), semanticsString);
}
Expand Down
51 changes: 31 additions & 20 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,9 @@ void AnyFunctionType::getSubsetParameters(
}
}

void autodiff::getFunctionSemanticResults(
const AnyFunctionType *functionType,
const IndexSubset *parameterIndices,
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &resultTypes) {
static void getFunctionFormalResults(
const AnyFunctionType *functionType,
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &resultTypes) {
auto &ctx = functionType->getASTContext();

// Collect formal result type as a semantic result, unless it is
Expand All @@ -199,33 +198,36 @@ void autodiff::getFunctionSemanticResults(
if (formalResultType->is<TupleType>()) {
for (auto elt : formalResultType->castTo<TupleType>()->getElements()) {
resultTypes.emplace_back(elt.getType(), resultIdx++,
/*isInout*/ false, /*isWrt*/ false);
/*isParameter*/ false);
}
} else {
resultTypes.emplace_back(formalResultType, resultIdx++,
/*isInout*/ false, /*isWrt*/ false);
/*isParameter*/ false);
}
}
}

bool addNonWrts = resultTypes.empty();

// Collect wrt `inout` parameters as semantic results
// As an extention, collect all (including non-wrt) inouts as results for
// functions returning void.
void autodiff::getFunctionSemanticResults(
const AnyFunctionType *functionType,
const IndexSubset *parameterIndices,
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &resultTypes) {
getFunctionFormalResults(functionType, resultTypes);
unsigned numResults = resultTypes.size();

// Collect wrt semantic result (`inout` and class references) parameters as
// semantic results
auto collectSemanticResults = [&](const AnyFunctionType *functionType,
unsigned curryOffset = 0) {
for (auto paramAndIndex : enumerate(functionType->getParams())) {
if (!paramAndIndex.value().isInOut())
if (!paramAndIndex.value().isAutoDiffSemanticResult())
continue;

unsigned idx = paramAndIndex.index() + curryOffset;
assert(idx < parameterIndices->getCapacity() &&
"invalid parameter index");
bool isWrt = parameterIndices->contains(idx);
if (addNonWrts || isWrt)
if (parameterIndices->contains(idx))
resultTypes.emplace_back(paramAndIndex.value().getPlainType(),
resultIdx, /*isInout*/ true, isWrt);
resultIdx += 1;
numResults + idx, /*isParameter*/ true);
}
};

Expand All @@ -245,17 +247,26 @@ autodiff::getFunctionSemanticResultIndices(const AnyFunctionType *functionType,
const IndexSubset *parameterIndices) {
auto &ctx = functionType->getASTContext();

SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
SmallVector<AutoDiffSemanticFunctionResultType, 1> formalResults, semanticResults;
getFunctionFormalResults(functionType, formalResults);
autodiff::getFunctionSemanticResults(functionType, parameterIndices,
semanticResults);
unsigned numSemanticResults = formalResults.size();
if (auto *resultFnType =
functionType->getResult()->getAs<AnyFunctionType>()) {
assert(functionType->getNumParams() == 1 && "unexpected function type");
numSemanticResults += 1 + resultFnType->getNumParams();
} else {
numSemanticResults += functionType->getNumParams();
}

SmallVector<unsigned> resultIndices;
unsigned cap = 0;
for (const auto& result : semanticResults) {
assert(result.index < numSemanticResults);
resultIndices.push_back(result.index);
cap = std::max(cap, result.index + 1U);
}

return IndexSubset::get(ctx, cap, resultIndices);
return IndexSubset::get(ctx, numSemanticResults, resultIndices);
}

IndexSubset *
Expand Down
51 changes: 20 additions & 31 deletions lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5558,7 +5558,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
return llvm::make_error<DerivativeFunctionTypeError>(
this, DerivativeFunctionTypeError::Kind::NoSemanticResults);

// Accumulate non-inout result tangent spaces.
// Accumulate non-semantic result tangent spaces.
SmallVector<Type, 1> resultTanTypes, inoutTanTypes;
for (auto i : range(originalResults.size())) {
auto originalResult = originalResults[i];
Expand All @@ -5577,31 +5577,22 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult,
std::make_pair(originalResultType, unsigned(originalResult.index)));

if (!originalResult.isInout)
if (!originalResult.isSemanticResultParameter)
resultTanTypes.push_back(resultTan->getType());
else if (originalResult.isInout && !originalResult.isWrtParam)
inoutTanTypes.push_back(resultTan->getType());
}

// Treat non-wrt inouts as semantic results for functions returning Void
if (resultTanTypes.empty())
resultTanTypes = inoutTanTypes;

// Compute the result linear map function type.
FunctionType *linearMapType;
switch (kind) {
case AutoDiffLinearMapKind::Differential: {
// Compute the differential type, returned by JVP functions.
//
// Case 1: original function has no `inout` parameters.
// Case 1: original function has no semantic result parameters.
// - Original: `(T0, T1, ...) -> R`
// - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
//
// Case 2: original function has a non-wrt `inout` parameter.
// - Original: `(T0, inout T1, ...) -> Void`
// - Differential: `(T0.Tan, ...) -> T1.Tan`
//
// Case 3: original function has a wrt `inout` parameter.
// Case 2: original function has a wrt semantic result parameter
// (e.g. `inout`)
// - Original: `(T0, inout T1, ...) -> Void`
// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
SmallVector<AnyFunctionType::Param, 4> differentialParams;
Expand Down Expand Up @@ -5644,19 +5635,16 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
case AutoDiffLinearMapKind::Pullback: {
// Compute the pullback type, returned by VJP functions.
//
// Case 1: original function has no `inout` parameters.
// Case 1: original function has no semantic result parameters.
// - Original: `(T0, T1, ...) -> R`
// - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
//
// Case 2: original function has a non-wrt `inout` parameter.
// - Original: `(T0, inout T1, ...) -> Void`
// - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
//
// Case 3: original function has wrt `inout` parameters.
// Case 2: original function has wrt semantic result parameters
// (e.g. an `inout` one)
// - Original: `(T0, inout T1, ...) -> R`
// - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)`
SmallVector<TupleTypeElt, 4> pullbackResults;
SmallVector<AnyFunctionType::Param, 2> inoutParams;
SmallVector<AnyFunctionType::Param, 2> semanticResultParams;
for (auto i : range(diffParams.size())) {
auto diffParam = diffParams[i];
auto paramType = diffParam.getPlainType();
Expand All @@ -5669,10 +5657,10 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
NonDifferentiableDifferentiabilityParameter,
std::make_pair(paramType, i));

if (diffParam.isInOut()) {
if (diffParam.isAutoDiffSemanticResult()) {
if (paramType->isVoid())
continue;
inoutParams.push_back(diffParam);
semanticResultParams.push_back(diffParam);
continue;
}
pullbackResults.emplace_back(paramTan->getType());
Expand All @@ -5685,30 +5673,31 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
} else {
pullbackResult = TupleType::get(pullbackResults, ctx);
}
// First accumulate non-inout results as pullback parameters.
// First accumulate results as pullback parameters.
SmallVector<FunctionType::Param, 2> pullbackParams;
for (auto i : range(resultTanTypes.size())) {
auto resultTanType = resultTanTypes[i];
auto flags = ParameterTypeFlags().withInOut(false);
pullbackParams.push_back(AnyFunctionType::Param(
resultTanType, Identifier(), flags));
}
// Then append inout parameters.
for (auto i : range(inoutParams.size())) {
auto inoutParam = inoutParams[i];
auto inoutParamType = inoutParam.getPlainType();
auto inoutParamTan =
inoutParamType->getAutoDiffTangentSpace(lookupConformance);
// Then append semantic result parameters.
for (auto i : range(semanticResultParams.size())) {
auto semanticResultParam = semanticResultParams[i];
auto semanticResultParamType = semanticResultParam.getPlainType();
auto semanticResultParamTan =
semanticResultParamType->getAutoDiffTangentSpace(lookupConformance);
auto flags = ParameterTypeFlags().withInOut(true);
pullbackParams.push_back(AnyFunctionType::Param(
inoutParamTan->getType(), Identifier(), flags));
semanticResultParamTan->getType(), Identifier(), flags));
}
// FIXME: Verify ExtInfo state is correct, not working by accident.
FunctionType::ExtInfo info;
linearMapType = FunctionType::get(pullbackParams, pullbackResult, info);
break;
}
}

assert(linearMapType && "Expected linear map type");
return linearMapType;
}
Expand Down
Loading