Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoDiff] forbid derivative registration using @differentiable #30001

Merged
merged 1 commit into from
Mar 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 4 additions & 40 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1675,12 +1675,11 @@ struct DeclNameRefWithLoc {
DeclNameLoc Loc;
};

/// Attribute that marks a function as differentiable and optionally specifies
/// custom associated derivative functions: 'jvp' and 'vjp'.
/// Attribute that marks a function as differentiable.
///
/// Examples:
/// @differentiable(jvp: jvpFoo where T : FloatingPoint)
/// @differentiable(wrt: (self, x, y), jvp: jvpFoo)
/// @differentiable(where T : FloatingPoint)
/// @differentiable(wrt: (self, x, y))
class DifferentiableAttr final
: public DeclAttribute,
private llvm::TrailingObjects<DifferentiableAttr,
Expand All @@ -1696,16 +1695,6 @@ class DifferentiableAttr final
bool Linear;
/// The number of parsed differentiability parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The JVP function.
Optional<DeclNameRefWithLoc> JVP;
/// The VJP function.
Optional<DeclNameRefWithLoc> VJP;
/// The JVP function (optional), resolved by the type checker if JVP name is
/// specified.
FuncDecl *JVPFunction = nullptr;
/// The VJP function (optional), resolved by the type checker if VJP name is
/// specified.
FuncDecl *VJPFunction = nullptr;
/// The differentiability parameter indices, resolved by the type checker.
/// The bit stores whether the parameter indices have been computed.
///
Expand All @@ -1724,32 +1713,24 @@ class DifferentiableAttr final
explicit DifferentiableAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
ArrayRef<ParsedAutoDiffParameter> parameters,
Optional<DeclNameRefWithLoc> jvp,
Optional<DeclNameRefWithLoc> vjp,
TrailingWhereClause *clause);

explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
IndexSubset *parameterIndices,
Optional<DeclNameRefWithLoc> jvp,
Optional<DeclNameRefWithLoc> vjp,
GenericSignature derivativeGenericSignature);

public:
static DifferentiableAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
ArrayRef<ParsedAutoDiffParameter> params,
Optional<DeclNameRefWithLoc> jvp,
Optional<DeclNameRefWithLoc> vjp,
TrailingWhereClause *clause);

static DifferentiableAttr *create(AbstractFunctionDecl *original,
bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
IndexSubset *parameterIndices,
Optional<DeclNameRefWithLoc> jvp,
Optional<DeclNameRefWithLoc> vjp,
GenericSignature derivativeGenSig);

Decl *getOriginalDeclaration() const { return OriginalDeclaration; }
Expand All @@ -1758,16 +1739,6 @@ class DifferentiableAttr final
/// Should only be used by parsing and deserialization.
void setOriginalDeclaration(Decl *originalDeclaration);

/// Get the optional 'jvp:' function name and location.
/// Use this instead of `getJVPFunction` to check whether the attribute has a
/// registered JVP.
Optional<DeclNameRefWithLoc> getJVP() const { return JVP; }

/// Get the optional 'vjp:' function name and location.
/// Use this instead of `getVJPFunction` to check whether the attribute has a
/// registered VJP.
Optional<DeclNameRefWithLoc> getVJP() const { return VJP; }

private:
/// Returns true if the given `@differentiable` attribute has been
/// type-checked.
Expand Down Expand Up @@ -1800,21 +1771,14 @@ class DifferentiableAttr final
DerivativeGenericSignature = derivativeGenSig;
}

FuncDecl *getJVPFunction() const { return JVPFunction; }
void setJVPFunction(FuncDecl *decl);
FuncDecl *getVJPFunction() const { return VJPFunction; }
void setVJPFunction(FuncDecl *decl);

/// Get the derivative generic environment for the given `@differentiable`
/// attribute and original function.
GenericEnvironment *
getDerivativeGenericEnvironment(AbstractFunctionDecl *original) const;

// Print the attribute to the given stream.
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
// If `omitDerivativeFunctions` is true, omit printing derivative functions.
void print(llvm::raw_ostream &OS, const Decl *D, bool omitWrtClause = false,
bool omitDerivativeFunctions = false) const;
void print(llvm::raw_ostream &OS, const Decl *D, bool omitWrtClause = false) const;

static bool classof(const DeclAttribute *DA) {
return DA->getKind() == DAK_Differentiable;
Expand Down
12 changes: 3 additions & 9 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1582,21 +1582,15 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken,
"expected a member name as second parameter in '_implements' attribute", ())

// differentiable
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
ERROR(attr_differentiable_expected_function_name,PointsToFirstBadToken,
"expected a %0 function name", (StringRef))
ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
"expected a list of parameters to differentiate with respect to", ())
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
ERROR(attr_differentiable_use_wrt_not_withrespectto,none,
"use 'wrt:' to specify parameters to differentiate with respect to", ())
ERROR(attr_differentiable_expected_label,none,
"expected either 'wrt:' or a function specifier label, e.g. 'jvp:', "
"or 'vjp:'", ())
ERROR(attr_differentiable_expected_label,none,"expected 'wrt:' or 'where'", ())
ERROR(attr_differentiable_unexpected_argument,none,
"unexpected argument '%0' in '@differentiable' attribute", (StringRef))
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
WARNING(attr_differentiable_jvp_vjp_deprecated_warning,none,
// TODO(TF-893): Remove this error after the 0.8 release.
ERROR(attr_differentiable_jvp_vjp_deprecated_error,none,
"'jvp:' and 'vjp:' arguments in '@differentiable' attribute are "
"deprecated; use '@derivative' attribute for derivative registration "
"instead", ())
Expand Down
8 changes: 2 additions & 6 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2914,6 +2914,8 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
ERROR(differentiable_attr_no_vjp_or_jvp_when_linear,none,
"cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' "
"attribute for transpose registration instead", ())
ERROR(differentiable_attr_void_result,none,
"cannot differentiate void function %0", (DeclName))
ERROR(differentiable_attr_overload_not_found,none,
"%0 does not have expected type %1", (DeclNameRef, Type))
// TODO(TF-482): Change duplicate `@differentiable` attribute diagnostic to also
Expand All @@ -2938,12 +2940,6 @@ ERROR(differentiable_attr_result_not_differentiable,none,
ERROR(differentiable_attr_protocol_req_where_clause,none,
"'@differentiable' attribute on protocol requirement cannot specify "
"'where' clause", ())
ERROR(differentiable_attr_protocol_req_assoc_func,none,
"'@differentiable' attribute on protocol requirement cannot specify "
"'jvp:' or 'vjp:'", ())
ERROR(differentiable_attr_stored_property_variable_unsupported,none,
"'@differentiable' attribute on stored property cannot specify "
"'jvp:' or 'vjp:'", ())
ERROR(differentiable_attr_class_member_dynamic_self_result_unsupported,none,
"'@differentiable' attribute cannot be declared on class members "
"returning 'Self'", ())
Expand Down
2 changes: 0 additions & 2 deletions include/swift/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -1010,8 +1010,6 @@ class Parser {
/// Parse the arguments inside the @differentiable attribute.
bool parseDifferentiableAttributeArguments(
bool &linear, SmallVectorImpl<ParsedAutoDiffParameter> &params,
Optional<DeclNameRefWithLoc> &jvpSpec,
Optional<DeclNameRefWithLoc> &vjpSpec,
TrailingWhereClause *&whereClause);

/// Parse a differentiability parameters clause, i.e. the 'wrt:' clause in
Expand Down
56 changes: 7 additions & 49 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,12 +526,9 @@ static std::string getDifferentiationParametersClauseString(
/// Print the arguments of the given `@differentiable` attribute.
/// - If `omitWrtClause` is true, omit printing the `wrt:` differentiation
/// parameters clause.
/// - If `omitDerivativeFunctions` is true, omit printing the JVP/VJP derivative
/// functions.
static void printDifferentiableAttrArguments(
const DifferentiableAttr *attr, ASTPrinter &printer, PrintOptions Options,
const Decl *D, bool omitWrtClause = false,
bool omitDerivativeFunctions = false) {
const Decl *D, bool omitWrtClause = false) {
assert(D);
// Create a temporary string for the attribute argument text.
std::string attrArgText;
Expand Down Expand Up @@ -574,19 +571,6 @@ static void printDifferentiableAttrArguments(
stream << diffParamsString;
}
}
// Print derivative function names, unless they are to be omitted.
if (!omitDerivativeFunctions) {
// Print jvp function name, if specified.
if (auto jvp = attr->getJVP()) {
printCommaIfNecessary();
stream << "jvp: " << jvp->Name;
}
// Print vjp function name, if specified.
if (auto vjp = attr->getVJP()) {
printCommaIfNecessary();
stream << "vjp: " << vjp->Name;
}
}
// Print 'where' clause, if any.
// First, filter out requirements satisfied by the original function's
// generic signature. They should not be printed.
Expand Down Expand Up @@ -1616,12 +1600,9 @@ SPIAccessControlAttr::create(ASTContext &context,
DifferentiableAttr::DifferentiableAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
ArrayRef<ParsedAutoDiffParameter> params,
Optional<DeclNameRefWithLoc> jvp,
Optional<DeclNameRefWithLoc> vjp,
TrailingWhereClause *clause)
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
Linear(linear), NumParsedParameters(params.size()), JVP(std::move(jvp)),
VJP(std::move(vjp)), WhereClause(clause) {
Linear(linear), NumParsedParameters(params.size()), WhereClause(clause) {
std::copy(params.begin(), params.end(),
getTrailingObjects<ParsedAutoDiffParameter>());
}
Expand All @@ -1630,12 +1611,9 @@ DifferentiableAttr::DifferentiableAttr(Decl *original, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
IndexSubset *parameterIndices,
Optional<DeclNameRefWithLoc> jvp,
Optional<DeclNameRefWithLoc> vjp,
GenericSignature derivativeGenSig)
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
OriginalDeclaration(original), Linear(linear), JVP(std::move(jvp)),
VJP(std::move(vjp)) {
OriginalDeclaration(original), Linear(linear) {
setParameterIndices(parameterIndices);
setDerivativeGenericSignature(derivativeGenSig);
}
Expand All @@ -1645,29 +1623,23 @@ DifferentiableAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
ArrayRef<ParsedAutoDiffParameter> parameters,
Optional<DeclNameRefWithLoc> jvp,
Optional<DeclNameRefWithLoc> vjp,
TrailingWhereClause *clause) {
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(parameters.size());
void *mem = context.Allocate(size, alignof(DifferentiableAttr));
return new (mem) DifferentiableAttr(implicit, atLoc, baseRange, linear,
parameters, std::move(jvp),
std::move(vjp), clause);
parameters, clause);
}

DifferentiableAttr *
DifferentiableAttr::create(AbstractFunctionDecl *original, bool implicit,
SourceLoc atLoc, SourceRange baseRange, bool linear,
IndexSubset *parameterIndices,
Optional<DeclNameRefWithLoc> jvp,
Optional<DeclNameRefWithLoc> vjp,
GenericSignature derivativeGenSig) {
auto &ctx = original->getASTContext();
void *mem = ctx.Allocate(sizeof(DifferentiableAttr),
alignof(DifferentiableAttr));
return new (mem) DifferentiableAttr(original, implicit, atLoc, baseRange,
linear, parameterIndices, std::move(jvp),
std::move(vjp), derivativeGenSig);
linear, parameterIndices, derivativeGenSig);
}

void DifferentiableAttr::setOriginalDeclaration(Decl *originalDeclaration) {
Expand Down Expand Up @@ -1701,18 +1673,6 @@ void DifferentiableAttr::setParameterIndices(IndexSubset *paramIndices) {
std::move(paramIndices));
}

void DifferentiableAttr::setJVPFunction(FuncDecl *decl) {
JVPFunction = decl;
if (decl && !JVP)
JVP = {decl->createNameRef(), DeclNameLoc(decl->getNameLoc())};
}

void DifferentiableAttr::setVJPFunction(FuncDecl *decl) {
VJPFunction = decl;
if (decl && !VJP)
VJP = {decl->createNameRef(), DeclNameLoc(decl->getNameLoc())};
}

GenericEnvironment *DifferentiableAttr::getDerivativeGenericEnvironment(
AbstractFunctionDecl *original) const {
GenericEnvironment *derivativeGenEnv = original->getGenericEnvironment();
Expand All @@ -1722,12 +1682,10 @@ GenericEnvironment *DifferentiableAttr::getDerivativeGenericEnvironment(
}

void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
bool omitWrtClause,
bool omitDerivativeFunctions) const {
bool omitWrtClause) const {
StreamPrinter P(OS);
P << "@" << getAttrName();
printDifferentiableAttrArguments(this, P, PrintOptions(), D, omitWrtClause,
omitDerivativeFunctions);
printDifferentiableAttrArguments(this, P, PrintOptions(), D, omitWrtClause);
}

DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc,
Expand Down
Loading