@@ -353,16 +353,19 @@ static void printShortFormAvailable(ArrayRef<const DeclAttribute *> Attrs,
353
353
Printer.printNewline ();
354
354
}
355
355
356
+ // Returns the differentiation parameters clause string for the given function,
357
+ // parameter indices, and parsed parameters.
356
358
static std::string getDifferentiationParametersClauseString (
357
- const AbstractFunctionDecl *function, IndexSubset *indices ,
359
+ const AbstractFunctionDecl *function, IndexSubset *paramIndices ,
358
360
ArrayRef<ParsedAutoDiffParameter> parsedParams) {
359
- bool isInstanceMethod = function && function->isInstanceMember ();
361
+ assert (function);
362
+ bool isInstanceMethod = function->isInstanceMember ();
360
363
std::string result;
361
364
llvm::raw_string_ostream printer (result);
362
365
363
- // Use parameter indices from `IndexSubset` , if specified.
364
- if (indices ) {
365
- auto parameters = indices ->getBitVector ();
366
+ // Use the parameter indices, if specified.
367
+ if (paramIndices ) {
368
+ auto parameters = paramIndices ->getBitVector ();
366
369
auto parameterCount = parameters.count ();
367
370
printer << " wrt: " ;
368
371
if (parameterCount > 1 )
@@ -410,19 +413,25 @@ static std::string getDifferentiationParametersClauseString(
410
413
}
411
414
412
415
// Print the arguments of the given `@differentiable` attribute.
416
+ // - If `omitWrtClause` is true, omit printing the `wrt:` differentiation
417
+ // parameters clause.
418
+ // - If `omitDerivativeFunctions` is true, omit printing the JVP/VJP derivative
419
+ // functions.
413
420
static void printDifferentiableAttrArguments (
414
421
const DifferentiableAttr *attr, ASTPrinter &printer, PrintOptions Options,
415
422
const Decl *D, bool omitWrtClause = false ,
416
- bool omitAssociatedFunctions = false ) {
423
+ bool omitDerivativeFunctions = false ) {
424
+ assert (D);
417
425
// Create a temporary string for the attribute argument text.
418
426
std::string attrArgText;
419
427
llvm::raw_string_ostream stream (attrArgText);
420
428
421
429
// Get original function.
422
- auto *original = dyn_cast_or_null <AbstractFunctionDecl>(D);
430
+ auto *original = dyn_cast <AbstractFunctionDecl>(D);
423
431
// Handle stored/computed properties and subscript methods.
424
- if (auto *asd = dyn_cast_or_null <AbstractStorageDecl>(D))
432
+ if (auto *asd = dyn_cast <AbstractStorageDecl>(D))
425
433
original = asd->getAccessor (AccessorKind::Get);
434
+ assert (original && " Must resolve original declaration" );
426
435
427
436
// Print comma if not leading clause.
428
437
bool isLeadingClause = true ;
@@ -440,7 +449,7 @@ static void printDifferentiableAttrArguments(
440
449
stream << " linear" ;
441
450
}
442
451
443
- // Print differentiation parameters, unless they are to be omitted.
452
+ // Print differentiation parameters clause , unless it is to be omitted.
444
453
if (!omitWrtClause) {
445
454
auto diffParamsString = getDifferentiationParametersClauseString (
446
455
original, attr->getParameterIndices (), attr->getParsedParameters ());
@@ -453,8 +462,8 @@ static void printDifferentiableAttrArguments(
453
462
stream << diffParamsString;
454
463
}
455
464
}
456
- // Print associated function names, unless they are to be omitted.
457
- if (!omitAssociatedFunctions ) {
465
+ // Print derivative function names, unless they are to be omitted.
466
+ if (!omitDerivativeFunctions ) {
458
467
// Print jvp function name, if specified.
459
468
if (auto jvp = attr->getJVP ()) {
460
469
printCommaIfNecessary ();
0 commit comments