Skip to content
15 changes: 11 additions & 4 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3462,8 +3462,8 @@ NOTE(autodiff_attr_original_decl_not_same_type_context,none,
ERROR(autodiff_attr_original_void_result,none,
"cannot differentiate void function %0", (DeclName))
ERROR(autodiff_attr_original_multiple_semantic_results,none,
"cannot differentiate functions with both an 'inout' parameter and a "
"result", ())
"cannot differentiate functions with both a differentiable 'inout' "
"parameter and a differentiable result", ())
ERROR(autodiff_attr_result_not_differentiable,none,
"can only differentiate functions with results that conform to "
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
Expand Down Expand Up @@ -5040,12 +5040,19 @@ ERROR(differentiable_function_type_invalid_result,none,
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
"function type is '@differentiable%select{|(_linear)}1'",
(StringRef, bool))
ERROR(differentiable_function_type_no_differentiability_parameters,
none,
ERROR(differentiable_function_type_multiple_semantic_results,none,
"'@differentiable' function type cannot have both a differentiable "
"'inout' parameter and a differentiable result", ())
ERROR(differentiable_function_type_no_differentiability_parameters,none,
"'@differentiable' function type requires at least one differentiability "
"parameter, i.e. a non-'@noDerivative' parameter whose type conforms to "
"'Differentiable'%select{| with its 'TangentVector' equal to itself}0",
(/*isLinear*/ bool))
ERROR(differentiable_function_type_no_differentiable_result,none,
"'@differentiable' function type requires a differentiable result, i.e. "
"a non-'Void' type that conforms to 'Differentiable'%select{| with its "
"'TangentVector' equal to itself}0",
(/*isLinear*/ bool))

// SIL
ERROR(opened_non_protocol,none,
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6178,7 +6178,7 @@ TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
return tangentSpace;
};

// For tuple types: the tangent space is a tuple of the elements' tangent
// For tuple types: the tangent space is a tuple of the elements' tangent
// space types, for the elements that have a tangent space.
if (auto *tupleTy = getAs<TupleType>()) {
SmallVector<TupleTypeElt, 8> newElts;
Expand Down
46 changes: 40 additions & 6 deletions lib/Sema/TypeChecker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,13 +619,17 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
dc, stage);
}) != params.end();
bool alreadyDiagnosedOneParam = false;
bool hasDifferentiableInoutParameter = false;
for (unsigned i = 0, end = fnTy->getNumParams(); i != end; ++i) {
auto param = params[i];
if (param.isNoDerivative())
continue;
auto paramType = param.getPlainType();
if (TypeChecker::isDifferentiable(paramType, isLinear, dc, stage))
if (TypeChecker::isDifferentiable(paramType, isLinear, dc, stage)) {
if (param.isInOut())
hasDifferentiableInoutParameter = true;
continue;
}
auto diagLoc =
repr ? (*repr)->getArgsTypeRepr()->getElement(i).Type->getLoc() : loc;
auto paramTypeString = paramType->getString();
Expand All @@ -637,6 +641,7 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
if (hasValidDifferentiabilityParam)
diagnostic.fixItInsert(diagLoc, "@noDerivative ");
}

// Reject the case where all parameters have '@noDerivative'.
if (!alreadyDiagnosedOneParam && !hasValidDifferentiabilityParam) {
auto diagLoc = repr ? (*repr)->getArgsTypeRepr()->getLoc() : loc;
Expand All @@ -651,11 +656,27 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
}
}

// Check the result
bool differentiable = isDifferentiable(result,
/*tangentVectorEqualsSelf*/ isLinear,
dc, stage);
if (!differentiable) {
// Check the result.
bool resultExists = !(result->isVoid());
bool resultIsDifferentiable = TypeChecker::isDifferentiable(
result, /*tangentVectorEqualsSelf*/ isLinear, dc, stage);
bool differentiableResultExists = resultExists && resultIsDifferentiable;

// Reject the case where there are multiple semantic results.
if (differentiableResultExists && hasDifferentiableInoutParameter) {
auto diagLoc = repr ? (*repr)->getArgsTypeRepr()->getLoc() : loc;
auto diag = ctx.Diags.diagnose(
diagLoc,
diag::differentiable_function_type_multiple_semantic_results);
hadAnyError = true;

if (repr) {
diag.highlight((*repr)->getSourceRange());
}
}

// Reject the case where the semantic result is not differentiable.
if (!resultIsDifferentiable && !hasDifferentiableInoutParameter) {
auto diagLoc = repr ? (*repr)->getResultTypeRepr()->getLoc() : loc;
auto resultStr = fnTy->getResult()->getString();
auto diag = ctx.Diags.diagnose(
Expand All @@ -667,6 +688,19 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
diag.highlight((*repr)->getResultTypeRepr()->getSourceRange());
}
}

// Reject the case where there are no semantic results.
if (!resultExists && !hasDifferentiableInoutParameter) {
auto diagLoc = repr ? (*repr)->getResultTypeRepr()->getLoc() : loc;
auto diag = ctx.Diags.diagnose(
diagLoc, diag::differentiable_function_type_no_differentiable_result,
isLinear);
hadAnyError = true;

if (repr) {
diag.highlight((*repr)->getResultTypeRepr()->getSourceRange());
}
}
}

return hadAnyError;
Expand Down
6 changes: 3 additions & 3 deletions test/AutoDiff/Sema/derivative_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ extension ProtocolRequirementDerivative {
func multipleSemanticResults(_ x: inout Float) -> Float {
return x
}
// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
@derivative(of: multipleSemanticResults)
func vjpMultipleSemanticResults(x: inout Float) -> (
value: Float, pullback: (Float) -> Float
Expand Down Expand Up @@ -885,14 +885,14 @@ func vjpNoSemanticResults(_ x: Float) -> (value: Void, pullback: Void) {}

extension InoutParameters {
func multipleSemanticResults(_ x: inout Float) -> Float { x }
// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
@derivative(of: multipleSemanticResults)
func vjpMultipleSemanticResults(_ x: inout Float) -> (
value: Float, pullback: (inout Float) -> Void
) { fatalError() }

func inoutVoid(_ x: Float, _ void: inout Void) -> Float {}
// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
@derivative(of: inoutVoid)
func vjpInoutVoidParameter(_ x: Float, _ void: inout Void) -> (
value: Float, pullback: (inout Float) -> Void
Expand Down
10 changes: 5 additions & 5 deletions test/AutoDiff/Sema/differentiable_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ func two9(x: Float, y: Float) -> Float {
func inout1(x: Float, y: inout Float) -> Void {
let _ = x + y
}
// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
@differentiable(reverse, wrt: y)
func inout2(x: Float, y: inout Float) -> Float {
let _ = x + y
Expand Down Expand Up @@ -670,11 +670,11 @@ final class FinalClass: Differentiable {
@differentiable(reverse, wrt: y)
func inoutVoid(x: Float, y: inout Float) {}

// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
@differentiable(reverse)
func multipleSemanticResults(_ x: inout Float) -> Float { x }

// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
@differentiable(reverse, wrt: y)
func swap(x: inout Float, y: inout Float) {}

Expand All @@ -687,7 +687,7 @@ extension InoutParameters {
@differentiable(reverse)
static func staticMethod(_ lhs: inout Self, rhs: Self) {}

// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
@differentiable(reverse)
static func multipleSemanticResults(_ lhs: inout Self, rhs: Self) -> Self {}
}
Expand All @@ -696,7 +696,7 @@ extension InoutParameters {
@differentiable(reverse)
mutating func mutatingMethod(_ other: Self) {}

// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
@differentiable(reverse)
mutating func mutatingMethod(_ other: Self) -> Self {}
}
Expand Down
13 changes: 12 additions & 1 deletion test/AutoDiff/Sema/differentiable_func_type.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ let _: @differentiable(reverse) (Float) throws -> Float

struct NonDiffType { var x: Int }

// FIXME: Properly type-check parameters and the result's differentiability
// expected-error @+1 {{parameter type 'NonDiffType' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
let _: @differentiable(reverse) (NonDiffType) -> Float

Expand All @@ -29,6 +28,12 @@ let _: @differentiable(reverse) (Float, NonDiffType) -> Float
// expected-error @+1 {{result type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(_linear)'}}
let _: @differentiable(_linear) (Float) -> NonDiffType

// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
let _: @differentiable(reverse) (inout Float) -> Float

// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
let _: @differentiable(_linear) (inout Float) -> Float

// Emit `@noDerivative` fixit iff there is at least one valid linearity parameter.
// expected-error @+1 {{parameter type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(_linear)'; did you want to add '@noDerivative' to this parameter?}} {{41-41=@noDerivative }}
let _: @differentiable(_linear) (Float, NonDiffType) -> Float
Expand All @@ -41,6 +46,12 @@ let _: @differentiable(_linear) (Float) -> NonDiffType

let _: @differentiable(_linear) (Float) -> Float

// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
let _: @differentiable(reverse) (inout Float) -> Float

// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
let _: @differentiable(_linear) (inout Float) -> Float

// expected-error @+1 {{result type '@differentiable(reverse) (U) -> Float' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
func test1<T: Differentiable, U: Differentiable>(_: @differentiable(reverse) (T) -> @differentiable(reverse) (U) -> Float) {}
// expected-error @+1 {{result type '(U) -> Float' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// RUN: %target-swift-frontend -emit-sil -verify %s
// SR-15808: In AST, type checking skips a closure with non-differentiable input
// where `Void` is included as a parameter without being marked `@noDerivative`.
// It also crashes when the output is `Void` and no input is `inout`. As a
// result, the compiler crashes during Sema.
import _Differentiation

// expected-error @+1 {{'@differentiable' function type requires a differentiable result, i.e. a non-'Void' type that conforms to 'Differentiable'}}
func helloWorld(_ x: @differentiable(reverse) (()) -> Void) {}

func helloWorld(_ x: @differentiable(reverse) (()) -> Float) {}

// expected-error @+1 {{'@differentiable' function type requires a differentiable result, i.e. a non-'Void' type that conforms to 'Differentiable'}}
func helloWorld(_ x: @differentiable(reverse) (Float) -> Void) {}

func helloWorld(_ x: @differentiable(reverse) (@noDerivative Float, Void) -> Float) {}

// Original crash:
// Assertion failed: (!parameterIndices->isEmpty() && "Parameter indices must not be empty"), function getAutoDiffDerivativeFunctionType, file SILFunctionType.cpp, line 800.
// Stack dump:
// ...
// 1. Apple Swift version 5.6-dev (LLVM 7b20e61dd04138a, Swift 9438cf6b2e83c5f)
// 2. Compiling with the current language version
// 3. While evaluating request ASTLoweringRequest(Lowering AST to SIL for file "/Users/philipturner/Desktop/Experimentation4/Experimentation4/main.swift")
// Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
// 0 swift-frontend 0x0000000108d7a5c0 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 56
// 1 swift-frontend 0x0000000108d79820 llvm::sys::RunSignalHandlers() + 128
// 2 swift-frontend 0x0000000108d7ac24 SignalHandler(int) + 304
// 3 libsystem_platform.dylib 0x00000001bb5304e4 _sigtramp + 56
// 4 libsystem_pthread.dylib 0x00000001bb518eb0 pthread_kill + 288
// 5 libsystem_c.dylib 0x00000001bb456314 abort + 164
// 6 libsystem_c.dylib 0x00000001bb45572c err + 0
// 7 swift-frontend 0x0000000108d9ae3c swift::SILFunctionType::getAutoDiffDerivativeFunctionType(swift::IndexSubset*, swift::IndexSubset*, swift::AutoDiffDerivativeFunctionKind, swift::Lowering::TypeConverter&, llvm::function_ref<swift::ProtocolConformanceRef (swift::CanType, swift::Type, swift::ProtocolDecl*)>, swift::CanGenericSignature, bool, swift::CanType) (.cold.3) + 0
// 8 swift-frontend 0x0000000104abc35c swift::SILFunctionType::getAutoDiffDerivativeFunctionType(swift::IndexSubset*, swift::IndexSubset*, swift::AutoDiffDerivativeFunctionKind, swift::Lowering::TypeConverter&, llvm::function_ref<swift::ProtocolConformanceRef (swift::CanType, swift::Type, swift::ProtocolDecl*)>, swift::CanGenericSignature, bool, swift::CanType) + 152
// 9 swift-frontend 0x0000000104b496cc (anonymous namespace)::TypeClassifierBase<(anonymous namespace)::LowerType, swift::Lowering::TypeLowering*>::getNormalDifferentiableSILFunctionTypeRecursiveProperties(swift::CanTypeWrapper<swift::SILFunctionType>, swift::Lowering::AbstractionPattern) + 184
// 10 swift-frontend 0x0000000104b3b72c swift::CanTypeVisitor<(anonymous namespace)::LowerType, swift::Lowering::TypeLowering*, swift::Lowering::AbstractionPattern, swift::Lowering::IsTypeExpansionSensitive_t>::visit(swift::CanType, swift::Lowering::AbstractionPattern, swift::Lowering::IsTypeExpansionSensitive_t) + 1980
// 11 swift-frontend 0x0000000104b3c0e0 swift::Lowering::TypeConverter::getTypeLoweringForLoweredType(swift::Lowering::AbstractionPattern, swift::CanType, swift::TypeExpansionContext, swift::Lowering::IsTypeExpansionSensitive_t) + 648
// 12 swift-frontend 0x0000000104b3ae08 swift::Lowering::TypeConverter::getTypeLowering(swift::Lowering::AbstractionPattern, swift::Type, swift::TypeExpansionContext) + 708
// 13 swift-frontend 0x0000000104ac8544 (anonymous namespace)::DestructureInputs::visit(swift::ValueOwnership, bool, swift::Lowering::AbstractionPattern, swift::CanType, bool, bool) + 184
// 14 swift-frontend 0x0000000104ac6a1c getSILFunctionType(swift::Lowering::TypeConverter&, swift::TypeExpansionContext, swift::Lowering::AbstractionPattern, swift::CanTypeWrapper<swift::AnyFunctionType>, swift::SILExtInfoBuilder, (anonymous namespace)::Conventions const&, swift::ForeignInfo const&, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SubstitutionMap>, swift::ProtocolConformanceRef, llvm::Optional<llvm::SmallBitVector>) + 2584
// 15 swift-frontend 0x0000000104ac5f98 getNativeSILFunctionType(swift::Lowering::TypeConverter&, swift::TypeExpansionContext, swift::Lowering::AbstractionPattern, swift::CanTypeWrapper<swift::AnyFunctionType>, swift::SILExtInfoBuilder, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SubstitutionMap>, swift::ProtocolConformanceRef, llvm::Optional<llvm::SmallBitVector>)::$_12::operator()((anonymous namespace)::Conventions const&) const + 316
// 16 swift-frontend 0x0000000104abf55c getNativeSILFunctionType(swift::Lowering::TypeConverter&, swift::TypeExpansionContext, swift::Lowering::AbstractionPattern, swift::CanTypeWrapper<swift::AnyFunctionType>, swift::SILExtInfoBuilder, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SubstitutionMap>, swift::ProtocolConformanceRef, llvm::Optional<llvm::SmallBitVector>) + 508
// 17 swift-frontend 0x0000000104ac0b44 getUncachedSILFunctionTypeForConstant(swift::Lowering::TypeConverter&, swift::TypeExpansionContext, swift::SILDeclRef, swift::Lowering::TypeConverter::LoweredFormalTypes) + 1920
// 18 swift-frontend 0x0000000104ac1474 swift::Lowering::TypeConverter::getConstantInfo(swift::TypeExpansionContext, swift::SILDeclRef) + 216
// 19 swift-frontend 0x0000000104ab9808 swift::SILFunctionBuilder::getOrCreateFunction(swift::SILLocation, swift::SILDeclRef, swift::ForDefinition_t, llvm::function_ref<swift::SILFunction* (swift::SILLocation, swift::SILDeclRef)>, swift::ProfileCounter) + 132
// 20 swift-frontend 0x0000000104f1d120 swift::Lowering::SILGenModule::getFunction(swift::SILDeclRef, swift::ForDefinition_t) + 328
// 21 swift-frontend 0x0000000104f2086c emitOrDelayFunction(swift::Lowering::SILGenModule&, swift::SILDeclRef, bool) + 344
// 22 swift-frontend 0x0000000104f1d828 swift::Lowering::SILGenModule::emitFunction(swift::FuncDecl*) + 140
// 23 swift-frontend 0x0000000104f2294c swift::ASTLoweringRequest::evaluate(swift::Evaluator&, swift::ASTLoweringDescriptor) const + 1612
// 24 swift-frontend 0x0000000104fcbca4 swift::SimpleRequest<swift::ASTLoweringRequest, std::__1::unique_ptr<swift::SILModule, std::__1::default_delete<swift::SILModule> > (swift::ASTLoweringDescriptor), (swift::RequestFlags)9>::evaluateRequest(swift::ASTLoweringRequest const&, swift::Evaluator&) + 156
// 25 swift-frontend 0x0000000104f2647c llvm::Expected<swift::ASTLoweringRequest::OutputType> swift::Evaluator::getResultUncached<swift::ASTLoweringRequest>(swift::ASTLoweringRequest const&) + 408
// 26 swift-frontend 0x0000000104f233b0 swift::performASTLowering(swift::FileUnit&, swift::Lowering::TypeConverter&, swift::SILOptions const&) + 104
// 27 swift-frontend 0x0000000104a12088 swift::performCompileStepsPostSema(swift::CompilerInstance&, int&, swift::FrontendObserver*) + 496
// 28 swift-frontend 0x0000000104a13d08 swift::performFrontend(llvm::ArrayRef<char const*>, char const*, void*, swift::FrontendObserver*) + 2936
// 29 swift-frontend 0x00000001049b213c swift::mainEntry(int, char const**) + 500
// 30 dyld 0x00000001113c90f4 start + 520
Loading