diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h index 531773c22b667..f886df23b92c0 100644 --- a/include/swift/AST/Attr.h +++ b/include/swift/AST/Attr.h @@ -1908,10 +1908,6 @@ class DifferentiableAttr final FuncDecl *getVJPFunction() const { return VJPFunction; } void setVJPFunction(FuncDecl *decl); - bool parametersMatch(const DifferentiableAttr &other) const { - return getParameterIndices() == other.getParameterIndices(); - } - /// Get the derivative generic environment for the given `@differentiable` /// attribute and original function. GenericEnvironment * @@ -1999,8 +1995,8 @@ class DerivativeAttr final IndexSubset *getParameterIndices() const { return ParameterIndices; } - void setParameterIndices(IndexSubset *pi) { - ParameterIndices = pi; + void setParameterIndices(IndexSubset *parameterIndices) { + ParameterIndices = parameterIndices; } static bool classof(const DeclAttribute *DA) { @@ -2079,8 +2075,8 @@ class TransposeAttr final IndexSubset *getParameterIndices() const { return ParameterIndices; } - void setParameterIndices(IndexSubset *pi) { - ParameterIndices = pi; + void setParameterIndices(IndexSubset *parameterIndices) { + ParameterIndices = parameterIndices; } static bool classof(const DeclAttribute *DA) { diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 1b429b3111073..1c9e14e6ae2f5 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -26,6 +26,46 @@ namespace swift { +enum class DifferentiabilityKind : uint8_t { + NonDifferentiable = 0, + Normal = 1, + Linear = 2 +}; + +/// The kind of an linear map. +struct AutoDiffLinearMapKind { + enum innerty : uint8_t { + // The differential function. + Differential = 0, + // The pullback function. + Pullback = 1 + } rawValue; + + AutoDiffLinearMapKind() = default; + AutoDiffLinearMapKind(innerty rawValue) : rawValue(rawValue) {} + operator innerty() const { return rawValue; } +}; + +/// The kind of a derivative function. +struct AutoDiffDerivativeFunctionKind { + enum innerty : uint8_t { + // The Jacobian-vector products function. + JVP = 0, + // The vector-Jacobian products function. + VJP = 1 + } rawValue; + + AutoDiffDerivativeFunctionKind() = default; + AutoDiffDerivativeFunctionKind(innerty rawValue) : rawValue(rawValue) {} + AutoDiffDerivativeFunctionKind(AutoDiffLinearMapKind linMapKind) + : rawValue(static_cast<innerty>(linMapKind.rawValue)) {} + explicit AutoDiffDerivativeFunctionKind(StringRef string); + operator innerty() const { return rawValue; } + AutoDiffLinearMapKind getLinearMapKind() { + return (AutoDiffLinearMapKind::innerty)rawValue; + } +}; + class ParsedAutoDiffParameter { public: enum class Kind { Named, Ordered, Self }; @@ -89,12 +129,6 @@ class ParsedAutoDiffParameter { } }; -enum class DifferentiabilityKind : uint8_t { - NonDifferentiable = 0, - Normal = 1, - Linear = 2 -}; - } // end namespace swift // SWIFT_ENABLE_TENSORFLOW @@ -120,40 +154,6 @@ class SILFunctionType; typedef CanTypeWrapper<SILFunctionType> CanSILFunctionType; enum class SILLinkage : uint8_t; -/// The kind of an linear map. -struct AutoDiffLinearMapKind { - enum innerty : uint8_t { - // The differential function. - Differential = 0, - // The pullback function. - Pullback = 1 - } rawValue; - - AutoDiffLinearMapKind() = default; - AutoDiffLinearMapKind(innerty rawValue) : rawValue(rawValue) {} - operator innerty() const { return rawValue; } -}; - -/// The kind of a derivative function. -struct AutoDiffDerivativeFunctionKind { - enum innerty : uint8_t { - // The Jacobian-vector products function. - JVP = 0, - // The vector-Jacobian products function. - VJP = 1 - } rawValue; - - AutoDiffDerivativeFunctionKind() = default; - AutoDiffDerivativeFunctionKind(innerty rawValue) : rawValue(rawValue) {} - AutoDiffDerivativeFunctionKind(AutoDiffLinearMapKind linMapKind) - : rawValue(static_cast<innerty>(linMapKind.rawValue)) {} - explicit AutoDiffDerivativeFunctionKind(StringRef string); - operator innerty() const { return rawValue; } - AutoDiffLinearMapKind getLinearMapKind() { - return (AutoDiffLinearMapKind::innerty)rawValue; - } -}; - /// The kind of a differentiability witness function. struct DifferentiabilityWitnessFunctionKind { enum innerty : uint8_t { diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index 83f277753dbde..60e6dff34f8f3 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -1377,6 +1377,11 @@ ERROR(attr_expected_comma,none, ERROR(attr_expected_string_literal,none, "expected string literal in '%0' attribute", (StringRef)) +ERROR(attr_missing_label,PointsToFirstBadToken, + "missing label '%0:' in '@%1' attribute", (StringRef, StringRef)) +ERROR(attr_expected_label,none, + "expected label '%0:' in '@%1' attribute", (StringRef, StringRef)) + ERROR(alignment_must_be_positive_integer,none, "alignment value must be a positive integer literal", ()) @@ -1570,10 +1575,6 @@ ERROR(diff_params_clause_expected_parameter,PointsToFirstBadToken, // derivative ERROR(attr_derivative_expected_original_name,PointsToFirstBadToken, "expected an original function name", ()) -ERROR(attr_missing_label,PointsToFirstBadToken, - "missing label '%0:' in '@%1' attribute", (StringRef, StringRef)) -ERROR(attr_expected_label,none, - "expected label '%0:' in '@%1' attribute", (StringRef, StringRef)) WARNING(attr_differentiating_deprecated,PointsToFirstBadToken, "'@differentiating' attribute is deprecated; use '@derivative(of:)' " "instead", ()) diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index cb1e0439e2599..12063d56ae5b8 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -804,6 +804,18 @@ Parser::parseImplementsAttribute(SourceLoc AtLoc, SourceLoc Loc) { ProtocolType.get(), MemberName, MemberNameLoc)); } +/// Parse a `@differentiable` attribute, returning true on error. +/// +/// \verbatim +/// differentiable-attribute-arguments: +/// '(' (differentiation-params-clause ',')? +/// (differentiable-attr-func-specifier ',')? +/// differentiable-attr-func-specifier? +/// where-clause? +/// ')' +/// differentiable-attr-func-specifier: +/// ('jvp' | 'vjp') ':' decl-name +/// \endverbatim ParserResult<DifferentiableAttr> Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) { StringRef AttrName = "differentiable"; @@ -852,6 +864,16 @@ static bool errorAndSkipUntilConsumeRightParen(Parser &P, StringRef attrName, return true; }; +/// Parse a differentiation parameters 'wrt:' clause, returning true on error. +/// +/// \verbatim +/// differentiation-params-clause: +/// 'wrt' ':' (differentiation-param | differentiation-params) +/// differentiation-params: +/// '(' differentiation-param (',' differentiation-param)* ')' +/// differentiation-param: +/// 'self' | identifier +/// \endverbatim bool Parser::parseDifferentiationParametersClause( SmallVectorImpl<ParsedAutoDiffParameter> ¶ms, StringRef attrName) { SyntaxParsingContext DiffParamsClauseContext( @@ -929,6 +951,16 @@ bool Parser::parseDifferentiationParametersClause( } // SWIFT_ENABLE_TENSORFLOW +/// Parse a transposed parameters 'wrt:' clause, returning true on error. +/// +/// \verbatim +/// transposed-params-clause: +/// 'wrt' ':' (transposed-param | transposed-params) +/// transposed-params: +/// '(' transposed-param (',' transposed-param)* ')' +/// transposed-param: +/// 'self' | [0-9]+ +/// \endverbatim bool Parser::parseTransposedParametersClause( SmallVectorImpl<ParsedAutoDiffParameter> ¶ms, StringRef attrName) { SyntaxParsingContext TransposeParamsClauseContext( @@ -1130,7 +1162,13 @@ bool Parser::parseDifferentiableAttributeArguments( return false; } -/// SWIFT_ENABLE_TENSORFLOW +// SWIFT_ENABLE_TENSORFLOW +/// Parse a `@derivative(of:)` attribute, returning true on error. +/// +/// \verbatim +/// derivative-attribute-arguments: +/// '(' 'of' ':' decl-name (',' differentiation-params-clause)? ')' +/// \endverbatim ParserResult<DerivativeAttr> Parser::parseDerivativeAttribute(SourceLoc atLoc, SourceLoc loc) { StringRef AttrName = "derivative"; @@ -1196,7 +1234,7 @@ ParserResult<DerivativeAttr> Parser::parseDerivativeAttribute(SourceLoc atLoc, SourceRange(loc, rParenLoc), original, params)); } -/// SWIFT_ENABLE_TENSORFLOW +// SWIFT_ENABLE_TENSORFLOW ParserResult<DerivativeAttr> Parser::parseDifferentiatingAttribute(SourceLoc atLoc, SourceLoc loc) { StringRef AttrName = "differentiating"; @@ -1315,6 +1353,12 @@ bool parseQualifiedDeclName(Parser &P, Diag<> nameParseError, // SWIFT_ENABLE_TENSORFLOW END // SWIFT_ENABLE_TENSORFLOW +/// Parse a `@transpose(of:)` attribute, returning true on error. +/// +/// \verbatim +/// transpose-attribute-arguments: +/// '(' 'of' ':' decl-name (',' transposed-params-clause)? ')' +/// \endverbatim ParserResult<TransposeAttr> Parser::parseTransposeAttribute(SourceLoc atLoc, SourceLoc loc) { StringRef AttrName = "transpose"; diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h index 1de674be6bd71..a09aa02540081 100644 --- a/lib/Serialization/ModuleFormat.h +++ b/lib/Serialization/ModuleFormat.h @@ -237,7 +237,7 @@ using DifferentiabilityKindField = BCFixed<2>; // module version. enum class AutoDiffDerivativeFunctionKind : uint8_t { JVP = 0, - VJP = 1 + VJP }; using AutoDiffDerivativeFunctionKindField = BCFixed<1>; // SWIFT_ENABLE_TENSORFLOW END diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index 26184824eb6ed..b3b96be993def 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -2030,8 +2030,8 @@ static uint8_t getRawStableVarDeclIntroducer(swift::VarDecl::Introducer intr) { } // SWIFT_ENABLE_TENSORFLOW -/// Translate from the AST differentiability kind enum to the Serialization enum -/// values, which are guaranteed to be stable. +/// Translate from the AST derivative function kind enum to the Serialization +/// enum values, which are guaranteed to be stable. static uint8_t getRawStableAutoDiffDerivativeFunctionKind( swift::AutoDiffDerivativeFunctionKind kind) { switch (kind) { diff --git a/lib/TBDGen/TBDGen.cpp b/lib/TBDGen/TBDGen.cpp index 26828af587f02..afcbd238e064f 100644 --- a/lib/TBDGen/TBDGen.cpp +++ b/lib/TBDGen/TBDGen.cpp @@ -183,11 +183,11 @@ void TBDGenVisitor::addAutoDiffLinearMapFunction(AbstractFunctionDecl *original, AutoDiffLinearMapKind kind) { auto declRef = SILDeclRef(original); - // Linear maps are only public when the original function is serialized. + // Linear maps are public only when the original function is serialized. if (!declRef.isSerialized()) return; - // Differentials are only emitted when forward mode is turned on. + // Linear maps are emitted only when forward mode is enabled. if (kind == AutoDiffLinearMapKind::Differential && !original->getASTContext() .LangOpts.EnableExperimentalForwardModeDifferentiation) diff --git a/test/AutoDiff/derivative_attr_parse.swift b/test/AutoDiff/derivative_attr_parse.swift index fc48499c5c52a..eab75c32b75c4 100644 --- a/test/AutoDiff/derivative_attr_parse.swift +++ b/test/AutoDiff/derivative_attr_parse.swift @@ -2,12 +2,6 @@ /// Good -@derivative(of: sin) // ok -func jvpSin(x: @nondiff Float) --> (value: Float, differential: (Float)-> (Float)) { - return (x, { $0 }) -} - @derivative(of: sin, wrt: x) // ok func vjpSin(x: Float) -> (value: Float, pullback: (Float) -> Float) { return (x, { $0 }) @@ -15,54 +9,54 @@ func vjpSin(x: Float) -> (value: Float, pullback: (Float) -> Float) { @derivative(of: add, wrt: (x, y)) // ok func vjpAdd(x: Float, y: Float) --> (value: Float, pullback: (Float) -> (Float, Float)) { + -> (value: Float, pullback: (Float) -> (Float, Float)) { return (x + y, { ($0, $0) }) } -extension AdditiveArithmetic where Self : Differentiable { +extension AdditiveArithmetic where Self: Differentiable { @derivative(of: +) // ok - static func vjpPlus(x: Self, y: Self) -> (value: Self, - pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)) { + static func vjpAdd(x: Self, y: Self) + -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) { return (x + y, { v in (v, v) }) } } -@derivative(of: linear) // ok +@derivative(of: foo) // ok func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) { return (x, { $0 }) } -// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}} -@derivative(of: linear, linear) // ok +/// Bad + +// expected-error @+3 {{expected an original function name}} +// expected-error @+2 {{expected ')' in 'derivative' attribute}} +// expected-error @+1 {{expected declaration}} +@derivative(of: 3) func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) { return (x, { $0 }) } // expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}} -@derivative(of: foo, linear, wrt: x) // ok +@derivative(of: wrt, foo) func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) { return (x, { $0 }) } -/// Bad - -// expected-error @+3 {{expected an original function name}} -// expected-error @+2 {{expected ')' in 'derivative' attribute}} -// expected-error @+1 {{expected declaration}} -@derivative(of: 3) +// expected-error @+1 {{expected a colon ':' after 'wrt'}} +@derivative(of: foo, wrt) func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) { return (x, { $0 }) } // expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}} -@derivative(of: linear, foo) +@derivative(of: foo, blah, wrt: x) func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) { return (x, { $0 }) } // expected-error @+2 {{expected ')' in 'derivative' attribute}} // expected-error @+1 {{expected declaration}} -@derivative(of: foo, wrt: x, linear) +@derivative(of: foo, wrt: x, blah) func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) { return (x, { $0 }) } @@ -81,13 +75,13 @@ func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) { } // expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}} -@derivative(of: linear, foo,) +@derivative(of: foo, foo,) func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) { return (x, { $0 }) } // expected-error @+1 {{unexpected ',' separator}} -@derivative(of: linear,) +@derivative(of: foo,) func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) { return (x, { $0 }) }