Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,16 +246,16 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s,
return s;
}

/// A semantic function result type: either a formal function result type or
/// an `inout` parameter type. Used in derivative function type calculation.
/// A semantic function result type: either a formal function result type or a
/// semantic result (an `inout`) parameter type. Used in derivative function type
/// calculation.
struct AutoDiffSemanticFunctionResultType {
Type type;
unsigned index : 30;
bool isInout : 1;
bool isWrtParam : 1;
bool isSemanticResultParameter : 1;

AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool inout, bool wrt)
: type(t), index(idx), isInout(inout), isWrtParam(wrt) { }
AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool param)
: type(t), index(idx), isSemanticResultParameter(param) { }
};

/// Key for caching SIL derivative function types.
Expand Down
54 changes: 44 additions & 10 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -3192,6 +3192,12 @@ class AnyFunctionType : public TypeBase {
/// Whether the parameter is marked '@noDerivative'.
bool isNoDerivative() const { return Flags.isNoDerivative(); }

/// Whether the parameter might be a semantic result for autodiff purposes.
/// This includes inout parameters.
bool isAutoDiffSemanticResult() const {
return isInOut();
}

ValueOwnership getValueOwnership() const {
return Flags.getValueOwnership();
}
Expand Down Expand Up @@ -3509,19 +3515,16 @@ class AnyFunctionType : public TypeBase {
/// Preconditions:
/// - Parameters corresponding to parameter indices must conform to
/// `Differentiable`.
/// - There is one semantic function result type: either the formal original
/// result or an `inout` parameter. It must conform to `Differentiable`.
/// - There are semantic function result type: either the formal original
/// result or a "wrt" semantic result parameter.
///
/// Differential typing rules: takes "wrt" parameter derivatives and returns a
/// "wrt" result derivative.
///
/// - Case 1: original function has no `inout` parameters.
/// - Original: `(T0, T1, ...) -> R`
/// - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
/// - Case 2: original function has a non-wrt `inout` parameter.
/// - Original: `(T0, inout T1, ...) -> Void`
/// - Differential: `(T0.Tan, ...) -> T1.Tan`
/// - Case 3: original function has a wrt `inout` parameter.
/// - Case 2: original function has a wrt `inout` parameter.
/// - Original: `(T0, inout T1, ...) -> Void`
/// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
///
Expand All @@ -3531,10 +3534,7 @@ class AnyFunctionType : public TypeBase {
/// - Case 1: original function has no `inout` parameters.
/// - Original: `(T0, T1, ...) -> R`
/// - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
/// - Case 2: original function has a non-wrt `inout` parameter.
/// - Original: `(T0, inout T1, ...) -> Void`
/// - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
/// - Case 3: original function has a wrt `inout` parameter.
/// - Case 2: original function has a wrt `inout` parameter.
/// - Original: `(T0, inout T1, ...) -> Void`
/// - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
///
Expand Down Expand Up @@ -4101,6 +4101,9 @@ class SILParameterInfo {
return getConvention() == ParameterConvention::Indirect_Inout
|| getConvention() == ParameterConvention::Indirect_InoutAliasable;
}
bool isAutoDiffSemanticResult() const {
return isIndirectMutating();
}

bool isPack() const {
return isPackParameter(getConvention());
Expand Down Expand Up @@ -4836,6 +4839,37 @@ class SILFunctionType final
return llvm::count_if(getParameters(), IndirectMutatingParameterFilter());
}

struct AutoDiffSemanticResultsParameterFilter {
bool operator()(SILParameterInfo param) const {
return param.isAutoDiffSemanticResult();
}
};

using AutoDiffSemanticResultsParameterIter =
llvm::filter_iterator<const SILParameterInfo *,
AutoDiffSemanticResultsParameterFilter>;
using AutoDiffSemanticResultsParameterRange =
iterator_range<AutoDiffSemanticResultsParameterIter>;

/// A range of SILParameterInfo for all semantic results parameters.
AutoDiffSemanticResultsParameterRange
getAutoDiffSemanticResultsParameters() const {
return llvm::make_filter_range(getParameters(),
AutoDiffSemanticResultsParameterFilter());
}

/// Returns the number of semantic results parameters.
unsigned getNumAutoDiffSemanticResultsParameters() const {
return llvm::count_if(getParameters(), AutoDiffSemanticResultsParameterFilter());
}

/// Returns the number of function potential semantic results:
/// * Usual results
/// * Inout parameters
unsigned getNumAutoDiffSemanticResults() const {
return getNumResults() + getNumAutoDiffSemanticResultsParameters();
}

/// Get the generic signature that the component types are specified
/// in terms of, if any.
CanGenericSignature getSubstGenericSignature() const {
Expand Down
12 changes: 12 additions & 0 deletions include/swift/SIL/ApplySite.h
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,18 @@ class FullApplySite : public ApplySite {
llvm_unreachable("invalid apply kind");
}

AutoDiffSemanticResultArgumentRange getAutoDiffSemanticResultArguments() const {
switch (getKind()) {
case FullApplySiteKind::ApplyInst:
return cast<ApplyInst>(getInstruction())->getAutoDiffSemanticResultArguments();
case FullApplySiteKind::TryApplyInst:
return cast<TryApplyInst>(getInstruction())->getAutoDiffSemanticResultArguments();
case FullApplySiteKind::BeginApplyInst:
return cast<BeginApplyInst>(getInstruction())->getAutoDiffSemanticResultArguments();
}
llvm_unreachable("invalid apply kind");
}

/// Returns true if \p op is the callee operand of this apply site
/// and not an argument operand.
bool isCalleeOperand(const Operand &op) const {
Expand Down
29 changes: 29 additions & 0 deletions include/swift/SIL/SILInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2785,6 +2785,25 @@ struct OperandToInoutArgument {
using InoutArgumentRange =
OptionalTransformRange<IntRange<size_t>, OperandToInoutArgument>;

/// Predicate used to filter AutoDiffSemanticResultArgumentRange.
struct OperandToAutoDiffSemanticResultArgument {
ArrayRef<SILParameterInfo> paramInfos;
OperandValueArrayRef arguments;
OperandToAutoDiffSemanticResultArgument(ArrayRef<SILParameterInfo> paramInfos,
OperandValueArrayRef arguments)
: paramInfos(paramInfos), arguments(arguments) {
assert(paramInfos.size() == arguments.size());
}
llvm::Optional<SILValue> operator()(size_t i) const {
if (paramInfos[i].isAutoDiffSemanticResult())
return arguments[i];
return llvm::None;
}
};

using AutoDiffSemanticResultArgumentRange =
OptionalTransformRange<IntRange<size_t>, OperandToAutoDiffSemanticResultArgument>;

/// The partial specialization of ApplyInstBase for full applications.
/// Adds some methods relating to 'self' and to result types that don't
/// make sense for partial applications.
Expand Down Expand Up @@ -2894,6 +2913,16 @@ class ApplyInstBase<Impl, Base, true>
impl.getArgumentsWithoutIndirectResults()));
}

/// Returns all autodiff semantic result (`@inout`, `@inout_aliasable`)
/// arguments passed to the instruction.
AutoDiffSemanticResultArgumentRange getAutoDiffSemanticResultArguments() const {
auto &impl = asImpl();
return AutoDiffSemanticResultArgumentRange(
indices(getArgumentsWithoutIndirectResults()),
OperandToAutoDiffSemanticResultArgument(impl.getSubstCalleeConv().getParameters(),
impl.getArgumentsWithoutIndirectResults()));
}

bool hasSemantics(StringRef semanticsString) const {
return doesApplyCalleeHaveSemantics(getCallee(), semanticsString);
}
Expand Down
18 changes: 7 additions & 11 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,32 +199,28 @@ void autodiff::getFunctionSemanticResults(
if (formalResultType->is<TupleType>()) {
for (auto elt : formalResultType->castTo<TupleType>()->getElements()) {
resultTypes.emplace_back(elt.getType(), resultIdx++,
/*isInout*/ false, /*isWrt*/ false);
/*isParameter*/ false);
}
} else {
resultTypes.emplace_back(formalResultType, resultIdx++,
/*isInout*/ false, /*isWrt*/ false);
/*isParameter*/ false);
}
}

bool addNonWrts = resultTypes.empty();

// Collect wrt `inout` parameters as semantic results
// As an extention, collect all (including non-wrt) inouts as results for
// functions returning void.
// Collect wrt semantic result (`inout`) parameters as
// semantic results
auto collectSemanticResults = [&](const AnyFunctionType *functionType,
unsigned curryOffset = 0) {
for (auto paramAndIndex : enumerate(functionType->getParams())) {
if (!paramAndIndex.value().isInOut())
if (!paramAndIndex.value().isAutoDiffSemanticResult())
continue;

unsigned idx = paramAndIndex.index() + curryOffset;
assert(idx < parameterIndices->getCapacity() &&
"invalid parameter index");
bool isWrt = parameterIndices->contains(idx);
if (addNonWrts || isWrt)
if (parameterIndices->contains(idx))
resultTypes.emplace_back(paramAndIndex.value().getPlainType(),
resultIdx, /*isInout*/ true, isWrt);
resultIdx, /*isParameter*/ true);
resultIdx += 1;
}
};
Expand Down
43 changes: 15 additions & 28 deletions lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5558,7 +5558,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
return llvm::make_error<DerivativeFunctionTypeError>(
this, DerivativeFunctionTypeError::Kind::NoSemanticResults);

// Accumulate non-inout result tangent spaces.
// Accumulate non-semantic result tangent spaces.
SmallVector<Type, 1> resultTanTypes, inoutTanTypes;
for (auto i : range(originalResults.size())) {
auto originalResult = originalResults[i];
Expand All @@ -5577,16 +5577,10 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult,
std::make_pair(originalResultType, unsigned(originalResult.index)));

if (!originalResult.isInout)
if (!originalResult.isSemanticResultParameter)
resultTanTypes.push_back(resultTan->getType());
else if (originalResult.isInout && !originalResult.isWrtParam)
inoutTanTypes.push_back(resultTan->getType());
}

// Treat non-wrt inouts as semantic results for functions returning Void
if (resultTanTypes.empty())
resultTanTypes = inoutTanTypes;

// Compute the result linear map function type.
FunctionType *linearMapType;
switch (kind) {
Expand All @@ -5597,11 +5591,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
// - Original: `(T0, T1, ...) -> R`
// - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
//
// Case 2: original function has a non-wrt `inout` parameter.
// - Original: `(T0, inout T1, ...) -> Void`
// - Differential: `(T0.Tan, ...) -> T1.Tan`
//
// Case 3: original function has a wrt `inout` parameter.
// Case 2: original function has a wrt `inout` parameter.
// - Original: `(T0, inout T1, ...) -> Void`
// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
SmallVector<AnyFunctionType::Param, 4> differentialParams;
Expand Down Expand Up @@ -5648,15 +5638,11 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
// - Original: `(T0, T1, ...) -> R`
// - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
//
// Case 2: original function has a non-wrt `inout` parameter.
// - Original: `(T0, inout T1, ...) -> Void`
// - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
//
// Case 3: original function has wrt `inout` parameters.
// Case 2: original function has wrt `inout` parameters.
// - Original: `(T0, inout T1, ...) -> R`
// - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)`
SmallVector<TupleTypeElt, 4> pullbackResults;
SmallVector<AnyFunctionType::Param, 2> inoutParams;
SmallVector<AnyFunctionType::Param, 2> semanticResultParams;
for (auto i : range(diffParams.size())) {
auto diffParam = diffParams[i];
auto paramType = diffParam.getPlainType();
Expand All @@ -5669,10 +5655,10 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
NonDifferentiableDifferentiabilityParameter,
std::make_pair(paramType, i));

if (diffParam.isInOut()) {
if (diffParam.isAutoDiffSemanticResult()) {
if (paramType->isVoid())
continue;
inoutParams.push_back(diffParam);
semanticResultParams.push_back(diffParam);
continue;
}
pullbackResults.emplace_back(paramTan->getType());
Expand All @@ -5693,22 +5679,23 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
pullbackParams.push_back(AnyFunctionType::Param(
resultTanType, Identifier(), flags));
}
// Then append inout parameters.
for (auto i : range(inoutParams.size())) {
auto inoutParam = inoutParams[i];
auto inoutParamType = inoutParam.getPlainType();
auto inoutParamTan =
inoutParamType->getAutoDiffTangentSpace(lookupConformance);
// Then append semantic result parameters.
for (auto i : range(semanticResultParams.size())) {
auto semanticResultParam = semanticResultParams[i];
auto semanticResultParamType = semanticResultParam.getPlainType();
auto semanticResultParamTan =
semanticResultParamType->getAutoDiffTangentSpace(lookupConformance);
auto flags = ParameterTypeFlags().withInOut(true);
pullbackParams.push_back(AnyFunctionType::Param(
inoutParamTan->getType(), Identifier(), flags));
semanticResultParamTan->getType(), Identifier(), flags));
}
// FIXME: Verify ExtInfo state is correct, not working by accident.
FunctionType::ExtInfo info;
linearMapType = FunctionType::get(pullbackParams, pullbackResult, info);
break;
}
}

assert(linearMapType && "Expected linear map type");
return linearMapType;
}
Expand Down
Loading