diff --git a/include/swift/Parse/Parser.h b/include/swift/Parse/Parser.h index 754fac8d45bd2..abf1af7580048 100644 --- a/include/swift/Parse/Parser.h +++ b/include/swift/Parse/Parser.h @@ -717,15 +717,6 @@ class Parser { return Tok.is(tok::identifier) && Tok.getText() == value; } - /// Returns true if token is the identifier "wrt". - bool isWRTIdentifier(Token tok) { return isIdentifier(Tok, "wrt"); } - - /// Returns true if token is the identifier "jvp". - bool isJVPIdentifier(Token Tok) { return isIdentifier(Tok, "jvp"); } - - /// Returns true if token is the identifier "vjp". - bool isVJPIdentifier(Token Tok) { return isIdentifier(Tok, "vjp"); } - /// Consume the starting '<' of the current token, which may either /// be a complete '<' token or some kind of operator token starting with '<', /// e.g., '<>'. diff --git a/lib/AST/Attr.cpp b/lib/AST/Attr.cpp index 87c5bca84dce2..3fdb973a67805 100644 --- a/lib/AST/Attr.cpp +++ b/lib/AST/Attr.cpp @@ -353,16 +353,19 @@ static void printShortFormAvailable(ArrayRef Attrs, Printer.printNewline(); } +// Returns the differentiation parameters clause string for the given function, +// parameter indices, and parsed parameters. static std::string getDifferentiationParametersClauseString( - const AbstractFunctionDecl *function, IndexSubset *indices, + const AbstractFunctionDecl *function, IndexSubset *paramIndices, ArrayRef parsedParams) { - bool isInstanceMethod = function && function->isInstanceMember(); + assert(function); + bool isInstanceMethod = function->isInstanceMember(); std::string result; llvm::raw_string_ostream printer(result); - // Use parameter indices from `IndexSubset`, if specified. - if (indices) { - auto parameters = indices->getBitVector(); + // Use the parameter indices, if specified. + if (paramIndices) { + auto parameters = paramIndices->getBitVector(); auto parameterCount = parameters.count(); printer << "wrt: "; if (parameterCount > 1) @@ -410,19 +413,25 @@ static std::string getDifferentiationParametersClauseString( } // Print the arguments of the given `@differentiable` attribute. +// - If `omitWrtClause` is true, omit printing the `wrt:` differentiation +// parameters clause. +// - If `omitDerivativeFunctions` is true, omit printing the JVP/VJP derivative +// functions. static void printDifferentiableAttrArguments( const DifferentiableAttr *attr, ASTPrinter &printer, PrintOptions Options, const Decl *D, bool omitWrtClause = false, - bool omitAssociatedFunctions = false) { + bool omitDerivativeFunctions = false) { + assert(D); // Create a temporary string for the attribute argument text. std::string attrArgText; llvm::raw_string_ostream stream(attrArgText); // Get original function. - auto *original = dyn_cast_or_null(D); + auto *original = dyn_cast(D); // Handle stored/computed properties and subscript methods. - if (auto *asd = dyn_cast_or_null(D)) + if (auto *asd = dyn_cast(D)) original = asd->getAccessor(AccessorKind::Get); + assert(original && "Must resolve original declaration"); // Print comma if not leading clause. bool isLeadingClause = true; @@ -440,7 +449,7 @@ static void printDifferentiableAttrArguments( stream << "linear"; } - // Print differentiation parameters, unless they are to be omitted. + // Print differentiation parameters clause, unless it is to be omitted. if (!omitWrtClause) { auto diffParamsString = getDifferentiationParametersClauseString( original, attr->getParameterIndices(), attr->getParsedParameters()); @@ -453,8 +462,8 @@ static void printDifferentiableAttrArguments( stream << diffParamsString; } } - // Print associated function names, unless they are to be omitted. - if (!omitAssociatedFunctions) { + // Print derivative function names, unless they are to be omitted. + if (!omitDerivativeFunctions) { // Print jvp function name, if specified. if (auto jvp = attr->getJVP()) { printCommaIfNecessary(); diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index 8e8efc9a30a0d..436b9c1bb6b03 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -941,13 +941,13 @@ bool Parser::parseDifferentiableAttributeArguments( diagnose(Tok, diag::unexpected_separator, ","); return true; } - // Check that token after comma is 'wrt:' or a function specifier label. - if (!(isWRTIdentifier(Tok) || isJVPIdentifier(Tok) || - isVJPIdentifier(Tok))) { - diagnose(Tok, diag::attr_differentiable_expected_label); - return true; + // Check that token after comma is 'wrt' or a function specifier label. + if (isIdentifier(Tok, "wrt") || isIdentifier(Tok, "jvp") || + isIdentifier(Tok, "vjp")) { + return false; } - return false; + diagnose(Tok, diag::attr_differentiable_expected_label); + return true; }; // Store starting parser position. @@ -958,7 +958,7 @@ bool Parser::parseDifferentiableAttributeArguments( // Parse optional differentiation parameters. // Parse 'linear' label (optional). linear = false; - if (Tok.is(tok::identifier) && Tok.getText() == "linear") { + if (isIdentifier(Tok, "linear")) { linear = true; consumeToken(tok::identifier); // If no trailing comma or 'where' clause, terminate parsing arguments. @@ -969,14 +969,15 @@ bool Parser::parseDifferentiableAttributeArguments( } // If 'withRespectTo' is used, make the user change it to 'wrt'. - if (Tok.is(tok::identifier) && Tok.getText() == "withRespectTo") { + if (isIdentifier(Tok, "withRespectTo")) { SourceRange withRespectToRange(Tok.getLoc(), peekToken().getLoc()); diagnose(Tok, diag::attr_differentiable_use_wrt_not_withrespectto) .highlight(withRespectToRange) .fixItReplace(withRespectToRange, "wrt:"); return errorAndSkipToEnd(); } - if (isWRTIdentifier(Tok)) { + // Parse differentiation parameters' clause. + if (isIdentifier(Tok, "wrt")) { if (parseDifferentiationParametersClause(params, AttrName)) return true; // If no trailing comma or 'where' clause, terminate parsing arguments. @@ -1014,7 +1015,7 @@ bool Parser::parseDifferentiableAttributeArguments( bool terminateParsingArgs = false; // Parse 'jvp: ' (optional). - if (isJVPIdentifier(Tok)) { + if (isIdentifier(Tok, "jvp")) { SyntaxParsingContext JvpContext( SyntaxContext, SyntaxKind::DifferentiableAttributeFuncSpecifier); jvpSpec = DeclNameWithLoc(); @@ -1027,7 +1028,7 @@ bool Parser::parseDifferentiableAttributeArguments( } // Parse 'vjp: ' (optional). - if (isVJPIdentifier(Tok)) { + if (isIdentifier(Tok, "vjp")) { SyntaxParsingContext VjpContext( SyntaxContext, SyntaxKind::DifferentiableAttributeFuncSpecifier); vjpSpec = DeclNameWithLoc();