Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoDiff upstream] Clean up parsing and printing. #28377

Merged
merged 1 commit into from
Nov 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions include/swift/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -717,15 +717,6 @@ class Parser {
return Tok.is(tok::identifier) && Tok.getText() == value;
}

/// Returns true if token is the identifier "wrt".
bool isWRTIdentifier(Token tok) { return isIdentifier(Tok, "wrt"); }

/// Returns true if token is the identifier "jvp".
bool isJVPIdentifier(Token Tok) { return isIdentifier(Tok, "jvp"); }

/// Returns true if token is the identifier "vjp".
bool isVJPIdentifier(Token Tok) { return isIdentifier(Tok, "vjp"); }

/// Consume the starting '<' of the current token, which may either
/// be a complete '<' token or some kind of operator token starting with '<',
/// e.g., '<>'.
Expand Down
31 changes: 20 additions & 11 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,16 +353,19 @@ static void printShortFormAvailable(ArrayRef<const DeclAttribute *> Attrs,
Printer.printNewline();
}

// Returns the differentiation parameters clause string for the given function,
// parameter indices, and parsed parameters.
static std::string getDifferentiationParametersClauseString(
const AbstractFunctionDecl *function, IndexSubset *indices,
const AbstractFunctionDecl *function, IndexSubset *paramIndices,
ArrayRef<ParsedAutoDiffParameter> parsedParams) {
bool isInstanceMethod = function && function->isInstanceMember();
assert(function);
bool isInstanceMethod = function->isInstanceMember();
std::string result;
llvm::raw_string_ostream printer(result);

// Use parameter indices from `IndexSubset`, if specified.
if (indices) {
auto parameters = indices->getBitVector();
// Use the parameter indices, if specified.
if (paramIndices) {
auto parameters = paramIndices->getBitVector();
auto parameterCount = parameters.count();
printer << "wrt: ";
if (parameterCount > 1)
Expand Down Expand Up @@ -410,19 +413,25 @@ static std::string getDifferentiationParametersClauseString(
}

// Print the arguments of the given `@differentiable` attribute.
// - If `omitWrtClause` is true, omit printing the `wrt:` differentiation
// parameters clause.
// - If `omitDerivativeFunctions` is true, omit printing the JVP/VJP derivative
// functions.
static void printDifferentiableAttrArguments(
const DifferentiableAttr *attr, ASTPrinter &printer, PrintOptions Options,
const Decl *D, bool omitWrtClause = false,
bool omitAssociatedFunctions = false) {
bool omitDerivativeFunctions = false) {
assert(D);
// Create a temporary string for the attribute argument text.
std::string attrArgText;
llvm::raw_string_ostream stream(attrArgText);

// Get original function.
auto *original = dyn_cast_or_null<AbstractFunctionDecl>(D);
auto *original = dyn_cast<AbstractFunctionDecl>(D);
// Handle stored/computed properties and subscript methods.
if (auto *asd = dyn_cast_or_null<AbstractStorageDecl>(D))
if (auto *asd = dyn_cast<AbstractStorageDecl>(D))
original = asd->getAccessor(AccessorKind::Get);
assert(original && "Must resolve original declaration");

// Print comma if not leading clause.
bool isLeadingClause = true;
Expand All @@ -440,7 +449,7 @@ static void printDifferentiableAttrArguments(
stream << "linear";
}

// Print differentiation parameters, unless they are to be omitted.
// Print differentiation parameters clause, unless it is to be omitted.
if (!omitWrtClause) {
auto diffParamsString = getDifferentiationParametersClauseString(
original, attr->getParameterIndices(), attr->getParsedParameters());
Expand All @@ -453,8 +462,8 @@ static void printDifferentiableAttrArguments(
stream << diffParamsString;
}
}
// Print associated function names, unless they are to be omitted.
if (!omitAssociatedFunctions) {
// Print derivative function names, unless they are to be omitted.
if (!omitDerivativeFunctions) {
// Print jvp function name, if specified.
if (auto jvp = attr->getJVP()) {
printCommaIfNecessary();
Expand Down
23 changes: 12 additions & 11 deletions lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -941,13 +941,13 @@ bool Parser::parseDifferentiableAttributeArguments(
diagnose(Tok, diag::unexpected_separator, ",");
return true;
}
// Check that token after comma is 'wrt:' or a function specifier label.
if (!(isWRTIdentifier(Tok) || isJVPIdentifier(Tok) ||
isVJPIdentifier(Tok))) {
diagnose(Tok, diag::attr_differentiable_expected_label);
return true;
// Check that token after comma is 'wrt' or a function specifier label.
if (isIdentifier(Tok, "wrt") || isIdentifier(Tok, "jvp") ||
isIdentifier(Tok, "vjp")) {
return false;
}
return false;
diagnose(Tok, diag::attr_differentiable_expected_label);
return true;
};

// Store starting parser position.
Expand All @@ -958,7 +958,7 @@ bool Parser::parseDifferentiableAttributeArguments(
// Parse optional differentiation parameters.
// Parse 'linear' label (optional).
linear = false;
if (Tok.is(tok::identifier) && Tok.getText() == "linear") {
if (isIdentifier(Tok, "linear")) {
linear = true;
consumeToken(tok::identifier);
// If no trailing comma or 'where' clause, terminate parsing arguments.
Expand All @@ -969,14 +969,15 @@ bool Parser::parseDifferentiableAttributeArguments(
}

// If 'withRespectTo' is used, make the user change it to 'wrt'.
if (Tok.is(tok::identifier) && Tok.getText() == "withRespectTo") {
if (isIdentifier(Tok, "withRespectTo")) {
SourceRange withRespectToRange(Tok.getLoc(), peekToken().getLoc());
diagnose(Tok, diag::attr_differentiable_use_wrt_not_withrespectto)
.highlight(withRespectToRange)
.fixItReplace(withRespectToRange, "wrt:");
return errorAndSkipToEnd();
}
if (isWRTIdentifier(Tok)) {
// Parse differentiation parameters' clause.
if (isIdentifier(Tok, "wrt")) {
if (parseDifferentiationParametersClause(params, AttrName))
return true;
// If no trailing comma or 'where' clause, terminate parsing arguments.
Expand Down Expand Up @@ -1014,7 +1015,7 @@ bool Parser::parseDifferentiableAttributeArguments(
bool terminateParsingArgs = false;

// Parse 'jvp: <func_name>' (optional).
if (isJVPIdentifier(Tok)) {
if (isIdentifier(Tok, "jvp")) {
SyntaxParsingContext JvpContext(
SyntaxContext, SyntaxKind::DifferentiableAttributeFuncSpecifier);
jvpSpec = DeclNameWithLoc();
Expand All @@ -1027,7 +1028,7 @@ bool Parser::parseDifferentiableAttributeArguments(
}

// Parse 'vjp: <func_name>' (optional).
if (isVJPIdentifier(Tok)) {
if (isIdentifier(Tok, "vjp")) {
SyntaxParsingContext VjpContext(
SyntaxContext, SyntaxKind::DifferentiableAttributeFuncSpecifier);
vjpSpec = DeclNameWithLoc();
Expand Down