Skip to content

Commit baed359

Browse files
authored
[AutoDiff upstream] Clean up parsing and printing. (#28377)
- Use general `Parser::isIdentifier(Token, StringRef)` function. - Remove specialized `isWRTIdentifier`, `isJVPIdentifier`, `isVJPIdentifier` functions from `Parser`. - Clarify doc comments and parameter nullability for attribute printing code: `getDifferentiationParametersClauseString`. - Minor formatting and naming updates.
1 parent 2e65a14 commit baed359

File tree

3 files changed

+32
-31
lines changed

3 files changed

+32
-31
lines changed

include/swift/Parse/Parser.h

-9
Original file line numberDiff line numberDiff line change
@@ -717,15 +717,6 @@ class Parser {
717717
return Tok.is(tok::identifier) && Tok.getText() == value;
718718
}
719719

720-
/// Returns true if token is the identifier "wrt".
721-
bool isWRTIdentifier(Token tok) { return isIdentifier(Tok, "wrt"); }
722-
723-
/// Returns true if token is the identifier "jvp".
724-
bool isJVPIdentifier(Token Tok) { return isIdentifier(Tok, "jvp"); }
725-
726-
/// Returns true if token is the identifier "vjp".
727-
bool isVJPIdentifier(Token Tok) { return isIdentifier(Tok, "vjp"); }
728-
729720
/// Consume the starting '<' of the current token, which may either
730721
/// be a complete '<' token or some kind of operator token starting with '<',
731722
/// e.g., '<>'.

lib/AST/Attr.cpp

+20-11
Original file line numberDiff line numberDiff line change
@@ -353,16 +353,19 @@ static void printShortFormAvailable(ArrayRef<const DeclAttribute *> Attrs,
353353
Printer.printNewline();
354354
}
355355

356+
// Returns the differentiation parameters clause string for the given function,
357+
// parameter indices, and parsed parameters.
356358
static std::string getDifferentiationParametersClauseString(
357-
const AbstractFunctionDecl *function, IndexSubset *indices,
359+
const AbstractFunctionDecl *function, IndexSubset *paramIndices,
358360
ArrayRef<ParsedAutoDiffParameter> parsedParams) {
359-
bool isInstanceMethod = function && function->isInstanceMember();
361+
assert(function);
362+
bool isInstanceMethod = function->isInstanceMember();
360363
std::string result;
361364
llvm::raw_string_ostream printer(result);
362365

363-
// Use parameter indices from `IndexSubset`, if specified.
364-
if (indices) {
365-
auto parameters = indices->getBitVector();
366+
// Use the parameter indices, if specified.
367+
if (paramIndices) {
368+
auto parameters = paramIndices->getBitVector();
366369
auto parameterCount = parameters.count();
367370
printer << "wrt: ";
368371
if (parameterCount > 1)
@@ -410,19 +413,25 @@ static std::string getDifferentiationParametersClauseString(
410413
}
411414

412415
// Print the arguments of the given `@differentiable` attribute.
416+
// - If `omitWrtClause` is true, omit printing the `wrt:` differentiation
417+
// parameters clause.
418+
// - If `omitDerivativeFunctions` is true, omit printing the JVP/VJP derivative
419+
// functions.
413420
static void printDifferentiableAttrArguments(
414421
const DifferentiableAttr *attr, ASTPrinter &printer, PrintOptions Options,
415422
const Decl *D, bool omitWrtClause = false,
416-
bool omitAssociatedFunctions = false) {
423+
bool omitDerivativeFunctions = false) {
424+
assert(D);
417425
// Create a temporary string for the attribute argument text.
418426
std::string attrArgText;
419427
llvm::raw_string_ostream stream(attrArgText);
420428

421429
// Get original function.
422-
auto *original = dyn_cast_or_null<AbstractFunctionDecl>(D);
430+
auto *original = dyn_cast<AbstractFunctionDecl>(D);
423431
// Handle stored/computed properties and subscript methods.
424-
if (auto *asd = dyn_cast_or_null<AbstractStorageDecl>(D))
432+
if (auto *asd = dyn_cast<AbstractStorageDecl>(D))
425433
original = asd->getAccessor(AccessorKind::Get);
434+
assert(original && "Must resolve original declaration");
426435

427436
// Print comma if not leading clause.
428437
bool isLeadingClause = true;
@@ -440,7 +449,7 @@ static void printDifferentiableAttrArguments(
440449
stream << "linear";
441450
}
442451

443-
// Print differentiation parameters, unless they are to be omitted.
452+
// Print differentiation parameters clause, unless it is to be omitted.
444453
if (!omitWrtClause) {
445454
auto diffParamsString = getDifferentiationParametersClauseString(
446455
original, attr->getParameterIndices(), attr->getParsedParameters());
@@ -453,8 +462,8 @@ static void printDifferentiableAttrArguments(
453462
stream << diffParamsString;
454463
}
455464
}
456-
// Print associated function names, unless they are to be omitted.
457-
if (!omitAssociatedFunctions) {
465+
// Print derivative function names, unless they are to be omitted.
466+
if (!omitDerivativeFunctions) {
458467
// Print jvp function name, if specified.
459468
if (auto jvp = attr->getJVP()) {
460469
printCommaIfNecessary();

lib/Parse/ParseDecl.cpp

+12-11
Original file line numberDiff line numberDiff line change
@@ -941,13 +941,13 @@ bool Parser::parseDifferentiableAttributeArguments(
941941
diagnose(Tok, diag::unexpected_separator, ",");
942942
return true;
943943
}
944-
// Check that token after comma is 'wrt:' or a function specifier label.
945-
if (!(isWRTIdentifier(Tok) || isJVPIdentifier(Tok) ||
946-
isVJPIdentifier(Tok))) {
947-
diagnose(Tok, diag::attr_differentiable_expected_label);
948-
return true;
944+
// Check that token after comma is 'wrt' or a function specifier label.
945+
if (isIdentifier(Tok, "wrt") || isIdentifier(Tok, "jvp") ||
946+
isIdentifier(Tok, "vjp")) {
947+
return false;
949948
}
950-
return false;
949+
diagnose(Tok, diag::attr_differentiable_expected_label);
950+
return true;
951951
};
952952

953953
// Store starting parser position.
@@ -958,7 +958,7 @@ bool Parser::parseDifferentiableAttributeArguments(
958958
// Parse optional differentiation parameters.
959959
// Parse 'linear' label (optional).
960960
linear = false;
961-
if (Tok.is(tok::identifier) && Tok.getText() == "linear") {
961+
if (isIdentifier(Tok, "linear")) {
962962
linear = true;
963963
consumeToken(tok::identifier);
964964
// If no trailing comma or 'where' clause, terminate parsing arguments.
@@ -969,14 +969,15 @@ bool Parser::parseDifferentiableAttributeArguments(
969969
}
970970

971971
// If 'withRespectTo' is used, make the user change it to 'wrt'.
972-
if (Tok.is(tok::identifier) && Tok.getText() == "withRespectTo") {
972+
if (isIdentifier(Tok, "withRespectTo")) {
973973
SourceRange withRespectToRange(Tok.getLoc(), peekToken().getLoc());
974974
diagnose(Tok, diag::attr_differentiable_use_wrt_not_withrespectto)
975975
.highlight(withRespectToRange)
976976
.fixItReplace(withRespectToRange, "wrt:");
977977
return errorAndSkipToEnd();
978978
}
979-
if (isWRTIdentifier(Tok)) {
979+
// Parse differentiation parameters' clause.
980+
if (isIdentifier(Tok, "wrt")) {
980981
if (parseDifferentiationParametersClause(params, AttrName))
981982
return true;
982983
// If no trailing comma or 'where' clause, terminate parsing arguments.
@@ -1014,7 +1015,7 @@ bool Parser::parseDifferentiableAttributeArguments(
10141015
bool terminateParsingArgs = false;
10151016

10161017
// Parse 'jvp: <func_name>' (optional).
1017-
if (isJVPIdentifier(Tok)) {
1018+
if (isIdentifier(Tok, "jvp")) {
10181019
SyntaxParsingContext JvpContext(
10191020
SyntaxContext, SyntaxKind::DifferentiableAttributeFuncSpecifier);
10201021
jvpSpec = DeclNameWithLoc();
@@ -1027,7 +1028,7 @@ bool Parser::parseDifferentiableAttributeArguments(
10271028
}
10281029

10291030
// Parse 'vjp: <func_name>' (optional).
1030-
if (isVJPIdentifier(Tok)) {
1031+
if (isIdentifier(Tok, "vjp")) {
10311032
SyntaxParsingContext VjpContext(
10321033
SyntaxContext, SyntaxKind::DifferentiableAttributeFuncSpecifier);
10331034
vjpSpec = DeclNameWithLoc();

0 commit comments

Comments
 (0)