Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoDiff] NFC: Gardening. #28673

Merged
merged 1 commit into from
Dec 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1908,10 +1908,6 @@ class DifferentiableAttr final
FuncDecl *getVJPFunction() const { return VJPFunction; }
void setVJPFunction(FuncDecl *decl);

bool parametersMatch(const DifferentiableAttr &other) const {
return getParameterIndices() == other.getParameterIndices();
}

/// Get the derivative generic environment for the given `@differentiable`
/// attribute and original function.
GenericEnvironment *
Expand Down Expand Up @@ -1999,8 +1995,8 @@ class DerivativeAttr final
IndexSubset *getParameterIndices() const {
return ParameterIndices;
}
void setParameterIndices(IndexSubset *pi) {
ParameterIndices = pi;
void setParameterIndices(IndexSubset *parameterIndices) {
ParameterIndices = parameterIndices;
}

static bool classof(const DeclAttribute *DA) {
Expand Down Expand Up @@ -2079,8 +2075,8 @@ class TransposeAttr final
IndexSubset *getParameterIndices() const {
return ParameterIndices;
}
void setParameterIndices(IndexSubset *pi) {
ParameterIndices = pi;
void setParameterIndices(IndexSubset *parameterIndices) {
ParameterIndices = parameterIndices;
}

static bool classof(const DeclAttribute *DA) {
Expand Down
80 changes: 40 additions & 40 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,46 @@

namespace swift {

enum class DifferentiabilityKind : uint8_t {
NonDifferentiable = 0,
Normal = 1,
Linear = 2
};

/// The kind of an linear map.
struct AutoDiffLinearMapKind {
enum innerty : uint8_t {
// The differential function.
Differential = 0,
// The pullback function.
Pullback = 1
} rawValue;

AutoDiffLinearMapKind() = default;
AutoDiffLinearMapKind(innerty rawValue) : rawValue(rawValue) {}
operator innerty() const { return rawValue; }
};

/// The kind of a derivative function.
struct AutoDiffDerivativeFunctionKind {
enum innerty : uint8_t {
// The Jacobian-vector products function.
JVP = 0,
// The vector-Jacobian products function.
VJP = 1
} rawValue;

AutoDiffDerivativeFunctionKind() = default;
AutoDiffDerivativeFunctionKind(innerty rawValue) : rawValue(rawValue) {}
AutoDiffDerivativeFunctionKind(AutoDiffLinearMapKind linMapKind)
: rawValue(static_cast<innerty>(linMapKind.rawValue)) {}
explicit AutoDiffDerivativeFunctionKind(StringRef string);
operator innerty() const { return rawValue; }
AutoDiffLinearMapKind getLinearMapKind() {
return (AutoDiffLinearMapKind::innerty)rawValue;
}
};

class ParsedAutoDiffParameter {
public:
enum class Kind { Named, Ordered, Self };
Expand Down Expand Up @@ -89,12 +129,6 @@ class ParsedAutoDiffParameter {
}
};

enum class DifferentiabilityKind : uint8_t {
NonDifferentiable = 0,
Normal = 1,
Linear = 2
};

} // end namespace swift

// SWIFT_ENABLE_TENSORFLOW
Expand All @@ -120,40 +154,6 @@ class SILFunctionType;
typedef CanTypeWrapper<SILFunctionType> CanSILFunctionType;
enum class SILLinkage : uint8_t;

/// The kind of an linear map.
struct AutoDiffLinearMapKind {
enum innerty : uint8_t {
// The differential function.
Differential = 0,
// The pullback function.
Pullback = 1
} rawValue;

AutoDiffLinearMapKind() = default;
AutoDiffLinearMapKind(innerty rawValue) : rawValue(rawValue) {}
operator innerty() const { return rawValue; }
};

/// The kind of a derivative function.
struct AutoDiffDerivativeFunctionKind {
enum innerty : uint8_t {
// The Jacobian-vector products function.
JVP = 0,
// The vector-Jacobian products function.
VJP = 1
} rawValue;

AutoDiffDerivativeFunctionKind() = default;
AutoDiffDerivativeFunctionKind(innerty rawValue) : rawValue(rawValue) {}
AutoDiffDerivativeFunctionKind(AutoDiffLinearMapKind linMapKind)
: rawValue(static_cast<innerty>(linMapKind.rawValue)) {}
explicit AutoDiffDerivativeFunctionKind(StringRef string);
operator innerty() const { return rawValue; }
AutoDiffLinearMapKind getLinearMapKind() {
return (AutoDiffLinearMapKind::innerty)rawValue;
}
};

/// The kind of a differentiability witness function.
struct DifferentiabilityWitnessFunctionKind {
enum innerty : uint8_t {
Expand Down
9 changes: 5 additions & 4 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1377,6 +1377,11 @@ ERROR(attr_expected_comma,none,
ERROR(attr_expected_string_literal,none,
"expected string literal in '%0' attribute", (StringRef))

ERROR(attr_missing_label,PointsToFirstBadToken,
"missing label '%0:' in '@%1' attribute", (StringRef, StringRef))
ERROR(attr_expected_label,none,
"expected label '%0:' in '@%1' attribute", (StringRef, StringRef))

ERROR(alignment_must_be_positive_integer,none,
"alignment value must be a positive integer literal", ())

Expand Down Expand Up @@ -1570,10 +1575,6 @@ ERROR(diff_params_clause_expected_parameter,PointsToFirstBadToken,
// derivative
ERROR(attr_derivative_expected_original_name,PointsToFirstBadToken,
"expected an original function name", ())
ERROR(attr_missing_label,PointsToFirstBadToken,
"missing label '%0:' in '@%1' attribute", (StringRef, StringRef))
ERROR(attr_expected_label,none,
"expected label '%0:' in '@%1' attribute", (StringRef, StringRef))
WARNING(attr_differentiating_deprecated,PointsToFirstBadToken,
"'@differentiating' attribute is deprecated; use '@derivative(of:)' "
"instead", ())
Expand Down
48 changes: 46 additions & 2 deletions lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,18 @@ Parser::parseImplementsAttribute(SourceLoc AtLoc, SourceLoc Loc) {
ProtocolType.get(), MemberName, MemberNameLoc));
}

/// Parse a `@differentiable` attribute, returning true on error.
///
/// \verbatim
/// differentiable-attribute-arguments:
/// '(' (differentiation-params-clause ',')?
/// (differentiable-attr-func-specifier ',')?
/// differentiable-attr-func-specifier?
/// where-clause?
/// ')'
/// differentiable-attr-func-specifier:
/// ('jvp' | 'vjp') ':' decl-name
/// \endverbatim
ParserResult<DifferentiableAttr>
Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) {
StringRef AttrName = "differentiable";
Expand Down Expand Up @@ -852,6 +864,16 @@ static bool errorAndSkipUntilConsumeRightParen(Parser &P, StringRef attrName,
return true;
};

/// Parse a differentiation parameters 'wrt:' clause, returning true on error.
///
/// \verbatim
/// differentiation-params-clause:
/// 'wrt' ':' (differentiation-param | differentiation-params)
/// differentiation-params:
/// '(' differentiation-param (',' differentiation-param)* ')'
/// differentiation-param:
/// 'self' | identifier
/// \endverbatim
bool Parser::parseDifferentiationParametersClause(
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName) {
SyntaxParsingContext DiffParamsClauseContext(
Expand Down Expand Up @@ -929,6 +951,16 @@ bool Parser::parseDifferentiationParametersClause(
}

// SWIFT_ENABLE_TENSORFLOW
/// Parse a transposed parameters 'wrt:' clause, returning true on error.
///
/// \verbatim
/// transposed-params-clause:
/// 'wrt' ':' (transposed-param | transposed-params)
/// transposed-params:
/// '(' transposed-param (',' transposed-param)* ')'
/// transposed-param:
/// 'self' | [0-9]+
/// \endverbatim
bool Parser::parseTransposedParametersClause(
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName) {
SyntaxParsingContext TransposeParamsClauseContext(
Expand Down Expand Up @@ -1130,7 +1162,13 @@ bool Parser::parseDifferentiableAttributeArguments(
return false;
}

/// SWIFT_ENABLE_TENSORFLOW
// SWIFT_ENABLE_TENSORFLOW
/// Parse a `@derivative(of:)` attribute, returning true on error.
///
/// \verbatim
/// derivative-attribute-arguments:
/// '(' 'of' ':' decl-name (',' differentiation-params-clause)? ')'
/// \endverbatim
ParserResult<DerivativeAttr> Parser::parseDerivativeAttribute(SourceLoc atLoc,
SourceLoc loc) {
StringRef AttrName = "derivative";
Expand Down Expand Up @@ -1196,7 +1234,7 @@ ParserResult<DerivativeAttr> Parser::parseDerivativeAttribute(SourceLoc atLoc,
SourceRange(loc, rParenLoc), original, params));
}

/// SWIFT_ENABLE_TENSORFLOW
// SWIFT_ENABLE_TENSORFLOW
ParserResult<DerivativeAttr>
Parser::parseDifferentiatingAttribute(SourceLoc atLoc, SourceLoc loc) {
StringRef AttrName = "differentiating";
Expand Down Expand Up @@ -1315,6 +1353,12 @@ bool parseQualifiedDeclName(Parser &P, Diag<> nameParseError,
// SWIFT_ENABLE_TENSORFLOW END

// SWIFT_ENABLE_TENSORFLOW
/// Parse a `@transpose(of:)` attribute, returning true on error.
///
/// \verbatim
/// transpose-attribute-arguments:
/// '(' 'of' ':' decl-name (',' transposed-params-clause)? ')'
/// \endverbatim
ParserResult<TransposeAttr> Parser::parseTransposeAttribute(SourceLoc atLoc,
SourceLoc loc) {
StringRef AttrName = "transpose";
Expand Down
2 changes: 1 addition & 1 deletion lib/Serialization/ModuleFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ using DifferentiabilityKindField = BCFixed<2>;
// module version.
enum class AutoDiffDerivativeFunctionKind : uint8_t {
JVP = 0,
VJP = 1
VJP
};
using AutoDiffDerivativeFunctionKindField = BCFixed<1>;
// SWIFT_ENABLE_TENSORFLOW END
Expand Down
4 changes: 2 additions & 2 deletions lib/Serialization/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2030,8 +2030,8 @@ static uint8_t getRawStableVarDeclIntroducer(swift::VarDecl::Introducer intr) {
}

// SWIFT_ENABLE_TENSORFLOW
/// Translate from the AST differentiability kind enum to the Serialization enum
/// values, which are guaranteed to be stable.
/// Translate from the AST derivative function kind enum to the Serialization
/// enum values, which are guaranteed to be stable.
static uint8_t getRawStableAutoDiffDerivativeFunctionKind(
swift::AutoDiffDerivativeFunctionKind kind) {
switch (kind) {
Expand Down
4 changes: 2 additions & 2 deletions lib/TBDGen/TBDGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,11 @@ void TBDGenVisitor::addAutoDiffLinearMapFunction(AbstractFunctionDecl *original,
AutoDiffLinearMapKind kind) {
auto declRef = SILDeclRef(original);

// Linear maps are only public when the original function is serialized.
// Linear maps are public only when the original function is serialized.
if (!declRef.isSerialized())
return;

// Differentials are only emitted when forward mode is turned on.
// Linear maps are emitted only when forward mode is enabled.
if (kind == AutoDiffLinearMapKind::Differential &&
!original->getASTContext()
.LangOpts.EnableExperimentalForwardModeDifferentiation)
Expand Down
42 changes: 18 additions & 24 deletions test/AutoDiff/derivative_attr_parse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,67 +2,61 @@

/// Good

@derivative(of: sin) // ok
func jvpSin(x: @nondiff Float)
-> (value: Float, differential: (Float)-> (Float)) {
return (x, { $0 })
}

@derivative(of: sin, wrt: x) // ok
func vjpSin(x: Float) -> (value: Float, pullback: (Float) -> Float) {
return (x, { $0 })
}

@derivative(of: add, wrt: (x, y)) // ok
func vjpAdd(x: Float, y: Float)
-> (value: Float, pullback: (Float) -> (Float, Float)) {
-> (value: Float, pullback: (Float) -> (Float, Float)) {
return (x + y, { ($0, $0) })
}

extension AdditiveArithmetic where Self : Differentiable {
extension AdditiveArithmetic where Self: Differentiable {
@derivative(of: +) // ok
static func vjpPlus(x: Self, y: Self) -> (value: Self,
pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)) {
static func vjpAdd(x: Self, y: Self)
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
return (x + y, { v in (v, v) })
}
}

@derivative(of: linear) // ok
@derivative(of: foo) // ok
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}

// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
@derivative(of: linear, linear) // ok
/// Bad

// expected-error @+3 {{expected an original function name}}
// expected-error @+2 {{expected ')' in 'derivative' attribute}}
// expected-error @+1 {{expected declaration}}
@derivative(of: 3)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}

// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
@derivative(of: foo, linear, wrt: x) // ok
@derivative(of: wrt, foo)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}

/// Bad

// expected-error @+3 {{expected an original function name}}
// expected-error @+2 {{expected ')' in 'derivative' attribute}}
// expected-error @+1 {{expected declaration}}
@derivative(of: 3)
// expected-error @+1 {{expected a colon ':' after 'wrt'}}
@derivative(of: foo, wrt)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}

// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
@derivative(of: linear, foo)
@derivative(of: foo, blah, wrt: x)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}

// expected-error @+2 {{expected ')' in 'derivative' attribute}}
// expected-error @+1 {{expected declaration}}
@derivative(of: foo, wrt: x, linear)
@derivative(of: foo, wrt: x, blah)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
Expand All @@ -81,13 +75,13 @@ func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
}

// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
@derivative(of: linear, foo,)
@derivative(of: foo, foo,)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}

// expected-error @+1 {{unexpected ',' separator}}
@derivative(of: linear,)
@derivative(of: foo,)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
Expand Down