From 684f12775016ed93426fa9a0f34fee1081e3bcba Mon Sep 17 00:00:00 2001
From: Dan Zheng <danielzheng@google.com>
Date: Tue, 19 Nov 2019 20:58:36 -0800
Subject: [PATCH] [AutoDiff upstream] Clean up parsing and printing.

- Use general `Parser::isIdentifier(Token, StringRef)` function.
  - Remove specialized `isWRTIdentifier`, `isJVPIdentifier`, `isVJPIdentifier`
    functions from `Parser`.
- Clarify doc comments and parameter nullability for attribute printing code:
  `getDifferentiationParametersClauseString`.
- Minor formatting and naming updates.
---
 include/swift/Parse/Parser.h |  9 ---------
 lib/AST/Attr.cpp             | 31 ++++++++++++++++++++-----------
 lib/Parse/ParseDecl.cpp      | 23 ++++++++++++-----------
 3 files changed, 32 insertions(+), 31 deletions(-)

diff --git a/include/swift/Parse/Parser.h b/include/swift/Parse/Parser.h
index 754fac8d45bd2..abf1af7580048 100644
--- a/include/swift/Parse/Parser.h
+++ b/include/swift/Parse/Parser.h
@@ -717,15 +717,6 @@ class Parser {
     return Tok.is(tok::identifier) && Tok.getText() == value;
   }
 
-  /// Returns true if token is the identifier "wrt".
-  bool isWRTIdentifier(Token tok) { return isIdentifier(Tok, "wrt"); }
-
-  /// Returns true if token is the identifier "jvp".
-  bool isJVPIdentifier(Token Tok) { return isIdentifier(Tok, "jvp"); }
-
-  /// Returns true if token is the identifier "vjp".
-  bool isVJPIdentifier(Token Tok) { return isIdentifier(Tok, "vjp"); }
-
   /// Consume the starting '<' of the current token, which may either
   /// be a complete '<' token or some kind of operator token starting with '<',
   /// e.g., '<>'.
diff --git a/lib/AST/Attr.cpp b/lib/AST/Attr.cpp
index 87c5bca84dce2..3fdb973a67805 100644
--- a/lib/AST/Attr.cpp
+++ b/lib/AST/Attr.cpp
@@ -353,16 +353,19 @@ static void printShortFormAvailable(ArrayRef<const DeclAttribute *> Attrs,
   Printer.printNewline();
 }
 
+// Returns the differentiation parameters clause string for the given function,
+// parameter indices, and parsed parameters.
 static std::string getDifferentiationParametersClauseString(
-    const AbstractFunctionDecl *function, IndexSubset *indices,
+    const AbstractFunctionDecl *function, IndexSubset *paramIndices,
     ArrayRef<ParsedAutoDiffParameter> parsedParams) {
-  bool isInstanceMethod = function && function->isInstanceMember();
+  assert(function);
+  bool isInstanceMethod = function->isInstanceMember();
   std::string result;
   llvm::raw_string_ostream printer(result);
 
-  // Use parameter indices from `IndexSubset`, if specified.
-  if (indices) {
-    auto parameters = indices->getBitVector();
+  // Use the parameter indices, if specified.
+  if (paramIndices) {
+    auto parameters = paramIndices->getBitVector();
     auto parameterCount = parameters.count();
     printer << "wrt: ";
     if (parameterCount > 1)
@@ -410,19 +413,25 @@ static std::string getDifferentiationParametersClauseString(
 }
 
 // Print the arguments of the given `@differentiable` attribute.
+// - If `omitWrtClause` is true, omit printing the `wrt:` differentiation
+//   parameters clause.
+// - If `omitDerivativeFunctions` is true, omit printing the JVP/VJP derivative
+//   functions.
 static void printDifferentiableAttrArguments(
     const DifferentiableAttr *attr, ASTPrinter &printer, PrintOptions Options,
     const Decl *D, bool omitWrtClause = false,
-    bool omitAssociatedFunctions = false) {
+    bool omitDerivativeFunctions = false) {
+  assert(D);
   // Create a temporary string for the attribute argument text.
   std::string attrArgText;
   llvm::raw_string_ostream stream(attrArgText);
 
   // Get original function.
-  auto *original = dyn_cast_or_null<AbstractFunctionDecl>(D);
+  auto *original = dyn_cast<AbstractFunctionDecl>(D);
   // Handle stored/computed properties and subscript methods.
-  if (auto *asd = dyn_cast_or_null<AbstractStorageDecl>(D))
+  if (auto *asd = dyn_cast<AbstractStorageDecl>(D))
     original = asd->getAccessor(AccessorKind::Get);
+  assert(original && "Must resolve original declaration");
 
   // Print comma if not leading clause.
   bool isLeadingClause = true;
@@ -440,7 +449,7 @@ static void printDifferentiableAttrArguments(
     stream << "linear";
   }
 
-  // Print differentiation parameters, unless they are to be omitted.
+  // Print differentiation parameters clause, unless it is to be omitted.
   if (!omitWrtClause) {
     auto diffParamsString = getDifferentiationParametersClauseString(
         original, attr->getParameterIndices(), attr->getParsedParameters());
@@ -453,8 +462,8 @@ static void printDifferentiableAttrArguments(
       stream << diffParamsString;
     }
   }
-  // Print associated function names, unless they are to be omitted.
-  if (!omitAssociatedFunctions) {
+  // Print derivative function names, unless they are to be omitted.
+  if (!omitDerivativeFunctions) {
     // Print jvp function name, if specified.
     if (auto jvp = attr->getJVP()) {
       printCommaIfNecessary();
diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp
index 8e8efc9a30a0d..436b9c1bb6b03 100644
--- a/lib/Parse/ParseDecl.cpp
+++ b/lib/Parse/ParseDecl.cpp
@@ -941,13 +941,13 @@ bool Parser::parseDifferentiableAttributeArguments(
       diagnose(Tok, diag::unexpected_separator, ",");
       return true;
     }
-    // Check that token after comma is 'wrt:' or a function specifier label.
-    if (!(isWRTIdentifier(Tok) || isJVPIdentifier(Tok) ||
-          isVJPIdentifier(Tok))) {
-      diagnose(Tok, diag::attr_differentiable_expected_label);
-      return true;
+    // Check that token after comma is 'wrt' or a function specifier label.
+    if (isIdentifier(Tok, "wrt") || isIdentifier(Tok, "jvp") ||
+        isIdentifier(Tok, "vjp")) {
+      return false;
     }
-    return false;
+    diagnose(Tok, diag::attr_differentiable_expected_label);
+    return true;
   };
 
   // Store starting parser position.
@@ -958,7 +958,7 @@ bool Parser::parseDifferentiableAttributeArguments(
   // Parse optional differentiation parameters.
   // Parse 'linear' label (optional).
   linear = false;
-  if (Tok.is(tok::identifier) && Tok.getText() == "linear") {
+  if (isIdentifier(Tok, "linear")) {
     linear = true;
     consumeToken(tok::identifier);
     // If no trailing comma or 'where' clause, terminate parsing arguments.
@@ -969,14 +969,15 @@ bool Parser::parseDifferentiableAttributeArguments(
   }
 
   // If 'withRespectTo' is used, make the user change it to 'wrt'.
-  if (Tok.is(tok::identifier) && Tok.getText() == "withRespectTo") {
+  if (isIdentifier(Tok, "withRespectTo")) {
     SourceRange withRespectToRange(Tok.getLoc(), peekToken().getLoc());
     diagnose(Tok, diag::attr_differentiable_use_wrt_not_withrespectto)
         .highlight(withRespectToRange)
         .fixItReplace(withRespectToRange, "wrt:");
     return errorAndSkipToEnd();
   }
-  if (isWRTIdentifier(Tok)) {
+  // Parse differentiation parameters' clause.
+  if (isIdentifier(Tok, "wrt")) {
     if (parseDifferentiationParametersClause(params, AttrName))
       return true;
     // If no trailing comma or 'where' clause, terminate parsing arguments.
@@ -1014,7 +1015,7 @@ bool Parser::parseDifferentiableAttributeArguments(
   bool terminateParsingArgs = false;
 
   // Parse 'jvp: <func_name>' (optional).
-  if (isJVPIdentifier(Tok)) {
+  if (isIdentifier(Tok, "jvp")) {
     SyntaxParsingContext JvpContext(
         SyntaxContext, SyntaxKind::DifferentiableAttributeFuncSpecifier);
     jvpSpec = DeclNameWithLoc();
@@ -1027,7 +1028,7 @@ bool Parser::parseDifferentiableAttributeArguments(
   }
 
   // Parse 'vjp: <func_name>' (optional).
-  if (isVJPIdentifier(Tok)) {
+  if (isIdentifier(Tok, "vjp")) {
     SyntaxParsingContext VjpContext(
         SyntaxContext, SyntaxKind::DifferentiableAttributeFuncSpecifier);
     vjpSpec = DeclNameWithLoc();