Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3449,8 +3449,8 @@ NOTE(autodiff_attr_original_decl_not_same_type_context,none,
ERROR(autodiff_attr_original_void_result,none,
"cannot differentiate void function %0", (DeclName))
ERROR(autodiff_attr_original_multiple_semantic_results,none,
"cannot differentiate functions with both an 'inout' parameter and a "
"result", ())
"cannot differentiate functions with both a differentiable 'inout' "
"parameter and a result", ())
ERROR(autodiff_attr_result_not_differentiable,none,
"can only differentiate functions with results that conform to "
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
Expand Down Expand Up @@ -5017,12 +5017,19 @@ ERROR(differentiable_function_type_invalid_result,none,
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
"function type is '@differentiable%select{|(_linear)}1'",
(StringRef, bool))
ERROR(differentiable_function_type_no_differentiability_parameters,
none,
ERROR(differentiable_function_type_multiple_semantic_results,none,
"'@differentiable' function type cannot have both a differentiable "
"'inout' parameter and a differentiable result", ())
ERROR(differentiable_function_type_no_differentiability_parameters, none,
"'@differentiable' function type requires at least one differentiability "
"parameter, i.e. a non-'@noDerivative' parameter whose type conforms to "
"'Differentiable'%select{| with its 'TangentVector' equal to itself}0",
(/*isLinear*/ bool))
ERROR(differentiable_function_type_no_differentiable_result, none,
"'@differentiable' function type requires a differentiable result, "
"i.e. a non-'Void' type that conforms to 'Differentiable'%select{| with "
"its 'TangentVector' equal to itself}0",
(/*isLinear*/ bool))

// SIL
ERROR(opened_non_protocol,none,
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6019,7 +6019,7 @@ TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
return tangentSpace;
};

// For tuple types: the tangent space is a tuple of the elements' tangent
// For tuple types: the tangent space is a tuple of the elements' tangent
// space types, for the elements that have a tangent space.
if (auto *tupleTy = getAs<TupleType>()) {
SmallVector<TupleTypeElt, 8> newElts;
Expand Down
47 changes: 41 additions & 6 deletions lib/Sema/TypeChecker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -596,13 +596,17 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
dc, stage);
}) != params.end();
bool alreadyDiagnosedOneParam = false;
bool hasInoutDiffParameter = false;
for (unsigned i = 0, end = fnTy->getNumParams(); i != end; ++i) {
auto param = params[i];
if (param.isNoDerivative())
continue;
auto paramType = param.getPlainType();
if (TypeChecker::isDifferentiable(paramType, isLinear, dc, stage))
if (TypeChecker::isDifferentiable(paramType, isLinear, dc, stage)) {
if (param.isInOut())
hasInoutDiffParameter = true;
continue;
}
auto diagLoc =
repr ? (*repr)->getArgsTypeRepr()->getElement(i).Type->getLoc() : loc;
auto paramTypeString = paramType->getString();
Expand All @@ -614,6 +618,7 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
if (hasValidDifferentiabilityParam)
diagnostic.fixItInsert(diagLoc, "@noDerivative ");
}

// Reject the case where all parameters have '@noDerivative'.
if (!alreadyDiagnosedOneParam && !hasValidDifferentiabilityParam) {
auto diagLoc = repr ? (*repr)->getArgsTypeRepr()->getLoc() : loc;
Expand All @@ -628,11 +633,28 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
}
}

// Check the result
bool differentiable = isDifferentiable(result,
/*tangentVectorEqualsSelf*/ isLinear,
dc, stage);
if (!differentiable) {
// Check the result.
bool resultExists = !(result->isVoid());
bool resultIsDifferentiable =
isDifferentiable(result, /*tangentVectorEqualsSelf*/ isLinear, dc,
stage);
bool differentiableResultExists = resultExists && resultIsDifferentiable;

// Reject the case where there are multiple semantic results.
if (differentiableResultExists && hasInoutDiffParameter) {
auto diagLoc = repr ? (*repr)->getArgsTypeRepr()->getLoc() : loc;
auto diag = ctx.Diags.diagnose(
diagLoc,
diag::differentiable_function_type_multiple_semantic_results);
hadAnyError = true;

if (repr) {
diag.highlight((*repr)->getSourceRange());
}
}

// Reject the case where the semantic result is not differentiable.
if (!resultIsDifferentiable && !hasInoutDiffParameter) {
auto diagLoc = repr ? (*repr)->getResultTypeRepr()->getLoc() : loc;
auto resultStr = fnTy->getResult()->getString();
auto diag = ctx.Diags.diagnose(
Expand All @@ -644,6 +666,19 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
diag.highlight((*repr)->getResultTypeRepr()->getSourceRange());
}
}

// Reject the case where there are no semantic results.
if (!resultExists && !hasInoutDiffParameter) {
auto diagLoc = repr ? (*repr)->getResultTypeRepr()->getLoc() : loc;
auto diag = ctx.Diags.diagnose(
diagLoc, diag::differentiable_function_type_no_differentiable_result,
isLinear);
hadAnyError = true;

if (repr) {
diag.highlight((*repr)->getResultTypeRepr()->getSourceRange());
}
}
}

return hadAnyError;
Expand Down
6 changes: 3 additions & 3 deletions test/AutoDiff/Sema/derivative_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ extension ProtocolRequirementDerivative {
func multipleSemanticResults(_ x: inout Float) -> Float {
return x
}
// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
@derivative(of: multipleSemanticResults)
func vjpMultipleSemanticResults(x: inout Float) -> (
value: Float, pullback: (Float) -> Float
Expand Down Expand Up @@ -885,14 +885,14 @@ func vjpNoSemanticResults(_ x: Float) -> (value: Void, pullback: Void) {}

extension InoutParameters {
func multipleSemanticResults(_ x: inout Float) -> Float { x }
// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
@derivative(of: multipleSemanticResults)
func vjpMultipleSemanticResults(_ x: inout Float) -> (
value: Float, pullback: (inout Float) -> Void
) { fatalError() }

func inoutVoid(_ x: Float, _ void: inout Void) -> Float {}
// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
@derivative(of: inoutVoid)
func vjpInoutVoidParameter(_ x: Float, _ void: inout Void) -> (
value: Float, pullback: (inout Float) -> Void
Expand Down
10 changes: 5 additions & 5 deletions test/AutoDiff/Sema/differentiable_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ func two9(x: Float, y: Float) -> Float {
func inout1(x: Float, y: inout Float) -> Void {
let _ = x + y
}
// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
@differentiable(reverse, wrt: y)
func inout2(x: Float, y: inout Float) -> Float {
let _ = x + y
Expand Down Expand Up @@ -670,11 +670,11 @@ final class FinalClass: Differentiable {
@differentiable(reverse, wrt: y)
func inoutVoid(x: Float, y: inout Float) {}

// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
@differentiable(reverse)
func multipleSemanticResults(_ x: inout Float) -> Float { x }

// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
@differentiable(reverse, wrt: y)
func swap(x: inout Float, y: inout Float) {}

Expand All @@ -687,7 +687,7 @@ extension InoutParameters {
@differentiable(reverse)
static func staticMethod(_ lhs: inout Self, rhs: Self) {}

// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
@differentiable(reverse)
static func multipleSemanticResults(_ lhs: inout Self, rhs: Self) -> Self {}
}
Expand All @@ -696,7 +696,7 @@ extension InoutParameters {
@differentiable(reverse)
mutating func mutatingMethod(_ other: Self) {}

// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
@differentiable(reverse)
mutating func mutatingMethod(_ other: Self) -> Self {}
}
Expand Down
7 changes: 7 additions & 0 deletions test/AutoDiff/Sema/differentiable_func_type.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ let _: @differentiable(reverse) (Float, NonDiffType) -> Float
// expected-error @+1 {{result type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(_linear)'}}
let _: @differentiable(_linear) (Float) -> NonDiffType

// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
let _: @differentiable(reverse) (inout Float) -> Float

// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
let _: @differentiable(_linear) (inout Float) -> Float


// Emit `@noDerivative` fixit iff there is at least one valid linearity parameter.
// expected-error @+1 {{parameter type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(_linear)'; did you want to add '@noDerivative' to this parameter?}} {{41-41=@noDerivative }}
let _: @differentiable(_linear) (Float, NonDiffType) -> Float
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// RUN: %target-swift-frontend -emit-sil -verify %s
// SR-15808: In AST, type checking skips a closure with non-differentiable input
// where `Void` is included as a parameter without being marked `@noDerivative`.
// It also crashes when the output is `Void` and no input is `inout`. As a
// result, the compiler crashes during Sema.
import _Differentiation

// expected-error @+1 {{'@differentiable' function type requires a differentiable result, i.e. a non-'Void' type that conforms to 'Differentiable'}}
func helloWorld(_ x: @differentiable(reverse) (()) -> Void) {}

func helloWorld(_ x: @differentiable(reverse) (()) -> Float) {}

// expected-error @+1 {{'@differentiable' function type requires a differentiable result, i.e. a non-'Void' type that conforms to 'Differentiable'}}
func helloWorld(_ x: @differentiable(reverse) (Float) -> Void) {}

func helloWorld(_ x: @differentiable(reverse) (@noDerivative Float, Void) -> Float) {}

// Original crash:
// Assertion failed: (!parameterIndices->isEmpty() && "Parameter indices must not be empty"), function getAutoDiffDerivativeFunctionType, file SILFunctionType.cpp, line 800.
// Stack dump:
// ...
// 1. Apple Swift version 5.6-dev (LLVM 7b20e61dd04138a, Swift 9438cf6b2e83c5f)
// 2. Compiling with the current language version
// 3. While evaluating request ASTLoweringRequest(Lowering AST to SIL for file "/Users/philipturner/Desktop/Experimentation4/Experimentation4/main.swift")
// Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
// 0 swift-frontend 0x0000000108d7a5c0 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 56
// 1 swift-frontend 0x0000000108d79820 llvm::sys::RunSignalHandlers() + 128
// 2 swift-frontend 0x0000000108d7ac24 SignalHandler(int) + 304
// 3 libsystem_platform.dylib 0x00000001bb5304e4 _sigtramp + 56
// 4 libsystem_pthread.dylib 0x00000001bb518eb0 pthread_kill + 288
// 5 libsystem_c.dylib 0x00000001bb456314 abort + 164
// 6 libsystem_c.dylib 0x00000001bb45572c err + 0
// 7 swift-frontend 0x0000000108d9ae3c swift::SILFunctionType::getAutoDiffDerivativeFunctionType(swift::IndexSubset*, swift::IndexSubset*, swift::AutoDiffDerivativeFunctionKind, swift::Lowering::TypeConverter&, llvm::function_ref<swift::ProtocolConformanceRef (swift::CanType, swift::Type, swift::ProtocolDecl*)>, swift::CanGenericSignature, bool, swift::CanType) (.cold.3) + 0
// 8 swift-frontend 0x0000000104abc35c swift::SILFunctionType::getAutoDiffDerivativeFunctionType(swift::IndexSubset*, swift::IndexSubset*, swift::AutoDiffDerivativeFunctionKind, swift::Lowering::TypeConverter&, llvm::function_ref<swift::ProtocolConformanceRef (swift::CanType, swift::Type, swift::ProtocolDecl*)>, swift::CanGenericSignature, bool, swift::CanType) + 152
// 9 swift-frontend 0x0000000104b496cc (anonymous namespace)::TypeClassifierBase<(anonymous namespace)::LowerType, swift::Lowering::TypeLowering*>::getNormalDifferentiableSILFunctionTypeRecursiveProperties(swift::CanTypeWrapper<swift::SILFunctionType>, swift::Lowering::AbstractionPattern) + 184
// 10 swift-frontend 0x0000000104b3b72c swift::CanTypeVisitor<(anonymous namespace)::LowerType, swift::Lowering::TypeLowering*, swift::Lowering::AbstractionPattern, swift::Lowering::IsTypeExpansionSensitive_t>::visit(swift::CanType, swift::Lowering::AbstractionPattern, swift::Lowering::IsTypeExpansionSensitive_t) + 1980
// 11 swift-frontend 0x0000000104b3c0e0 swift::Lowering::TypeConverter::getTypeLoweringForLoweredType(swift::Lowering::AbstractionPattern, swift::CanType, swift::TypeExpansionContext, swift::Lowering::IsTypeExpansionSensitive_t) + 648
// 12 swift-frontend 0x0000000104b3ae08 swift::Lowering::TypeConverter::getTypeLowering(swift::Lowering::AbstractionPattern, swift::Type, swift::TypeExpansionContext) + 708
// 13 swift-frontend 0x0000000104ac8544 (anonymous namespace)::DestructureInputs::visit(swift::ValueOwnership, bool, swift::Lowering::AbstractionPattern, swift::CanType, bool, bool) + 184
// 14 swift-frontend 0x0000000104ac6a1c getSILFunctionType(swift::Lowering::TypeConverter&, swift::TypeExpansionContext, swift::Lowering::AbstractionPattern, swift::CanTypeWrapper<swift::AnyFunctionType>, swift::SILExtInfoBuilder, (anonymous namespace)::Conventions const&, swift::ForeignInfo const&, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SubstitutionMap>, swift::ProtocolConformanceRef, llvm::Optional<llvm::SmallBitVector>) + 2584
// 15 swift-frontend 0x0000000104ac5f98 getNativeSILFunctionType(swift::Lowering::TypeConverter&, swift::TypeExpansionContext, swift::Lowering::AbstractionPattern, swift::CanTypeWrapper<swift::AnyFunctionType>, swift::SILExtInfoBuilder, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SubstitutionMap>, swift::ProtocolConformanceRef, llvm::Optional<llvm::SmallBitVector>)::$_12::operator()((anonymous namespace)::Conventions const&) const + 316
// 16 swift-frontend 0x0000000104abf55c getNativeSILFunctionType(swift::Lowering::TypeConverter&, swift::TypeExpansionContext, swift::Lowering::AbstractionPattern, swift::CanTypeWrapper<swift::AnyFunctionType>, swift::SILExtInfoBuilder, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SubstitutionMap>, swift::ProtocolConformanceRef, llvm::Optional<llvm::SmallBitVector>) + 508
// 17 swift-frontend 0x0000000104ac0b44 getUncachedSILFunctionTypeForConstant(swift::Lowering::TypeConverter&, swift::TypeExpansionContext, swift::SILDeclRef, swift::Lowering::TypeConverter::LoweredFormalTypes) + 1920
// 18 swift-frontend 0x0000000104ac1474 swift::Lowering::TypeConverter::getConstantInfo(swift::TypeExpansionContext, swift::SILDeclRef) + 216
// 19 swift-frontend 0x0000000104ab9808 swift::SILFunctionBuilder::getOrCreateFunction(swift::SILLocation, swift::SILDeclRef, swift::ForDefinition_t, llvm::function_ref<swift::SILFunction* (swift::SILLocation, swift::SILDeclRef)>, swift::ProfileCounter) + 132
// 20 swift-frontend 0x0000000104f1d120 swift::Lowering::SILGenModule::getFunction(swift::SILDeclRef, swift::ForDefinition_t) + 328
// 21 swift-frontend 0x0000000104f2086c emitOrDelayFunction(swift::Lowering::SILGenModule&, swift::SILDeclRef, bool) + 344
// 22 swift-frontend 0x0000000104f1d828 swift::Lowering::SILGenModule::emitFunction(swift::FuncDecl*) + 140
// 23 swift-frontend 0x0000000104f2294c swift::ASTLoweringRequest::evaluate(swift::Evaluator&, swift::ASTLoweringDescriptor) const + 1612
// 24 swift-frontend 0x0000000104fcbca4 swift::SimpleRequest<swift::ASTLoweringRequest, std::__1::unique_ptr<swift::SILModule, std::__1::default_delete<swift::SILModule> > (swift::ASTLoweringDescriptor), (swift::RequestFlags)9>::evaluateRequest(swift::ASTLoweringRequest const&, swift::Evaluator&) + 156
// 25 swift-frontend 0x0000000104f2647c llvm::Expected<swift::ASTLoweringRequest::OutputType> swift::Evaluator::getResultUncached<swift::ASTLoweringRequest>(swift::ASTLoweringRequest const&) + 408
// 26 swift-frontend 0x0000000104f233b0 swift::performASTLowering(swift::FileUnit&, swift::Lowering::TypeConverter&, swift::SILOptions const&) + 104
// 27 swift-frontend 0x0000000104a12088 swift::performCompileStepsPostSema(swift::CompilerInstance&, int&, swift::FrontendObserver*) + 496
// 28 swift-frontend 0x0000000104a13d08 swift::performFrontend(llvm::ArrayRef<char const*>, char const*, void*, swift::FrontendObserver*) + 2936
// 29 swift-frontend 0x00000001049b213c swift::mainEntry(int, char const**) + 500
// 30 dyld 0x00000001113c90f4 start + 520
Loading