Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoDiff upstream] Add @derivative(of:) attribute. #28321

Merged
merged 1 commit into from
Dec 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/swift/AST/Attr.def
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,11 @@ DECL_ATTR(_originallyDefinedIn, OriginallyDefinedIn,
ABIBreakingToAdd | ABIBreakingToRemove | APIStableToAdd | APIStableToRemove,
96)

DECL_ATTR(derivative, Derivative,
OnFunc | LongAttribute | AllowMultipleAttributes |
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
97)

#undef TYPE_ATTR
#undef DECL_ATTR_ALIAS
#undef CONTEXTUAL_DECL_ATTR_ALIAS
Expand Down
78 changes: 78 additions & 0 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1747,6 +1747,84 @@ class DifferentiableAttr final
}
};

/// Attribute that registers a function as a derivative of another function.
///
/// Examples:
/// @derivative(of: sin(_:))
/// @derivative(of: +, wrt: (lhs, rhs))
class DerivativeAttr final
: public DeclAttribute,
private llvm::TrailingObjects<DerivativeAttr, ParsedAutoDiffParameter> {
friend TrailingObjects;

/// The original function name.
DeclNameWithLoc OriginalFunctionName;
/// The original function declaration, resolved by the type checker.
AbstractFunctionDecl *OriginalFunction = nullptr;
/// The number of parsed parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The differentiation parameters' indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;
/// The derivative function kind (JVP or VJP), resolved by the type checker.
Optional<AutoDiffDerivativeFunctionKind> Kind = None;

explicit DerivativeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);

explicit DerivativeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, IndexSubset *indices);

public:
static DerivativeAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);

static DerivativeAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, IndexSubset *indices);

DeclNameWithLoc getOriginalFunctionName() const {
return OriginalFunctionName;
}
AbstractFunctionDecl *getOriginalFunction() const {
return OriginalFunction;
}
void setOriginalFunction(AbstractFunctionDecl *decl) {
OriginalFunction = decl;
}

AutoDiffDerivativeFunctionKind getDerivativeKind() const {
assert(Kind && "Derivative function kind has not yet been resolved");
return *Kind;
}
void setDerivativeKind(AutoDiffDerivativeFunctionKind kind) { Kind = kind; }

/// The parsed differentiation parameters, i.e. the list of parameters
/// specified in 'wrt:'.
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
}
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
}
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
return NumParsedParameters;
}

IndexSubset *getParameterIndices() const {
return ParameterIndices;
}
void setParameterIndices(IndexSubset *parameterIndices) {
ParameterIndices = parameterIndices;
}

static bool classof(const DeclAttribute *DA) {
return DA->getKind() == DAK_Derivative;
}
};

/// Attributes that may be applied to declarations.
class DeclAttributes {
/// Linked list of declaration attributes.
Expand Down
47 changes: 41 additions & 6 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,47 @@

namespace swift {

/// A function type differentiability kind.
enum class DifferentiabilityKind : uint8_t {
NonDifferentiable = 0,
Normal = 1,
Linear = 2
};

/// The kind of an linear map.
struct AutoDiffLinearMapKind {
enum innerty : uint8_t {
// The differential function.
Differential = 0,
// The pullback function.
Pullback = 1
} rawValue;

AutoDiffLinearMapKind() = default;
AutoDiffLinearMapKind(innerty rawValue) : rawValue(rawValue) {}
operator innerty() const { return rawValue; }
};

/// The kind of a derivative function.
struct AutoDiffDerivativeFunctionKind {
enum innerty : uint8_t {
// The Jacobian-vector products function.
JVP = 0,
// The vector-Jacobian products function.
VJP = 1
} rawValue;

AutoDiffDerivativeFunctionKind() = default;
AutoDiffDerivativeFunctionKind(innerty rawValue) : rawValue(rawValue) {}
AutoDiffDerivativeFunctionKind(AutoDiffLinearMapKind linMapKind)
: rawValue(static_cast<innerty>(linMapKind.rawValue)) {}
explicit AutoDiffDerivativeFunctionKind(StringRef string);
operator innerty() const { return rawValue; }
AutoDiffLinearMapKind getLinearMapKind() {
return (AutoDiffLinearMapKind::innerty)rawValue;
}
};

class ParsedAutoDiffParameter {
public:
enum class Kind { Named, Ordered, Self };
Expand Down Expand Up @@ -89,12 +130,6 @@ class ParsedAutoDiffParameter {
}
};

enum class DifferentiabilityKind : uint8_t {
NonDifferentiable = 0,
Normal = 1,
Linear = 2
};

} // end namespace swift

#endif // SWIFT_AST_AUTODIFF_H
9 changes: 9 additions & 0 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,11 @@ ERROR(attr_expected_comma,none,
ERROR(attr_expected_string_literal,none,
"expected string literal in '%0' attribute", (StringRef))

ERROR(attr_missing_label,PointsToFirstBadToken,
"missing label '%0:' in '@%1' attribute", (StringRef, StringRef))
ERROR(attr_expected_label,none,
"expected label '%0:' in '@%1' attribute", (StringRef, StringRef))

ERROR(alignment_must_be_positive_integer,none,
"alignment value must be a positive integer literal", ())

Expand Down Expand Up @@ -1550,6 +1555,10 @@ ERROR(diff_params_clause_expected_parameter,PointsToFirstBadToken,
"expected a parameter, which can be a function parameter name, "
"parameter index, or 'self'", ())

// derivative
ERROR(attr_derivative_expected_original_name,PointsToFirstBadToken,
"expected an original function name", ())

//------------------------------------------------------------------------------
// MARK: Generics parsing diagnostics
//------------------------------------------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions include/swift/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,10 @@ class Parser {
bool parseDifferentiationParametersClause(
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);

/// Parse the @derivative attribute.
ParserResult<DerivativeAttr> parseDerivativeAttribute(SourceLoc AtLoc,
SourceLoc Loc);

/// Parse a specific attribute.
ParserStatus parseDeclAttribute(DeclAttributes &Attributes, SourceLoc AtLoc);

Expand Down
44 changes: 42 additions & 2 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1027,10 +1027,12 @@ StringRef DeclAttribute::getAttrName() const {
return "<<custom>>";
case DAK_ProjectedValueProperty:
return "_projectedValueProperty";
case DAK_Differentiable:
return "differentiable";
case DAK_OriginallyDefinedIn:
return "_originallyDefinedIn";
case DAK_Differentiable:
return "differentiable";
case DAK_Derivative:
return "derivative";
}
llvm_unreachable("bad DeclAttrKind");
}
Expand Down Expand Up @@ -1450,6 +1452,44 @@ void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
omitAssociatedFunctions);
}

DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange,
DeclNameWithLoc originalName,
ArrayRef<ParsedAutoDiffParameter> params)
: DeclAttribute(DAK_Derivative, atLoc, baseRange, implicit),
OriginalFunctionName(std::move(originalName)),
NumParsedParameters(params.size()) {
std::copy(params.begin(), params.end(),
getTrailingObjects<ParsedAutoDiffParameter>());
}

DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange,
DeclNameWithLoc originalName,
IndexSubset *indices)
: DeclAttribute(DAK_Derivative, atLoc, baseRange, implicit),
OriginalFunctionName(std::move(originalName)), ParameterIndices(indices) {
}

DerivativeAttr *
DerivativeAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
SourceRange baseRange, DeclNameWithLoc originalName,
ArrayRef<ParsedAutoDiffParameter> params) {
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
void *mem = context.Allocate(size, alignof(DerivativeAttr));
return new (mem) DerivativeAttr(implicit, atLoc, baseRange,
std::move(originalName), params);
}

DerivativeAttr *DerivativeAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc originalName,
IndexSubset *indices) {
void *mem = context.Allocate(sizeof(DerivativeAttr), alignof(DerivativeAttr));
return new (mem) DerivativeAttr(implicit, atLoc, baseRange,
std::move(originalName), indices);
}

ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,
TypeLoc ProtocolType,
DeclName MemberName,
Expand Down
Loading