Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoDiff] Rename @transposing to @transpose(of:). #28488

Merged
merged 5 commits into from
Nov 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/swift/AST/Attr.def
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ SIMPLE_DECL_ATTR(noDerivative, NoDerivative,
OnVar |
ABIBreakingToAdd | ABIBreakingToRemove | APIBreakingToAdd | APIBreakingToRemove,
94)
DECL_ATTR(transposing, Transposing,
DECL_ATTR(transpose, Transpose,
OnFunc | LongAttribute | AllowMultipleAttributes |
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
NotSerialized, 96)
Expand Down
48 changes: 23 additions & 25 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1743,12 +1743,11 @@ using DifferentiatingAttr = DerivativeAttr;
/// Attribute that registers a function as a transpose of another function.
///
/// Examples:
/// @transposing(foo)
/// @transposing(+, wrt: (lhs, rhs))
class TransposingAttr final
: public DeclAttribute,
private llvm::TrailingObjects<TransposingAttr,
ParsedAutoDiffParameter> {
/// @transpose(of: foo)
/// @transpose(of: +, wrt: (lhs, rhs))
class TransposeAttr final
: public DeclAttribute,
private llvm::TrailingObjects<TransposeAttr, ParsedAutoDiffParameter> {
friend TrailingObjects;

/// The base type of the original function.
Expand All @@ -1761,28 +1760,27 @@ class TransposingAttr final
AbstractFunctionDecl *OriginalFunction = nullptr;
/// The number of parsed parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The differentiation parameters' indices, resolved by the type checker.
/// The transposed parameters' indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;

explicit TransposingAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseType,
DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);
explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);

explicit TransposingAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseType,
DeclNameWithLoc original, IndexSubset *indices);
explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameWithLoc original,
IndexSubset *indices);

public:
static TransposingAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);
static TransposeAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);

static TransposingAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameWithLoc original,
IndexSubset *indices);
static TransposeAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameWithLoc original,
IndexSubset *indices);

TypeRepr *getBaseType() const { return BaseType; }
DeclNameWithLoc getOriginalFunctionName() const {
Expand All @@ -1795,8 +1793,8 @@ class TransposingAttr final
OriginalFunction = decl;
}

/// The parsed transposing parameters, i.e. the list of parameters
/// specified in 'wrt:'.
/// The parsed transposed parameters, i.e. the list of parameters specified in
/// 'wrt:'.
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
}
Expand All @@ -1815,7 +1813,7 @@ class TransposingAttr final
}

static bool classof(const DeclAttribute *DA) {
return DA->getKind() == DAK_Transposing;
return DA->getKind() == DAK_Transpose;
}
};

Expand Down
10 changes: 5 additions & 5 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1572,14 +1572,14 @@ WARNING(attr_differentiating_deprecated,PointsToFirstBadToken,
"'@differentiating' attribute is deprecated; use '@derivative(of:)' "
"instead", ())

// transposing
ERROR(attr_transposing_expected_original_name,PointsToFirstBadToken,
// transpose
ERROR(attr_transpose_expected_original_name,PointsToFirstBadToken,
"expected an original function name", ())
ERROR(attr_transposing_expected_label_linear_or_wrt,none,
ERROR(attr_transpose_expected_label_linear_or_wrt,none,
"expected 'wrt:'", ())

// transposing `wrt` parameters clause
ERROR(transposing_params_clause_expected_parameter,PointsToFirstBadToken,
// transpose `wrt` parameters clause
ERROR(transpose_params_clause_expected_parameter,PointsToFirstBadToken,
"expected a parameter, which can be a 'unsigned int' parameter number "
"or 'self'", ())

Expand Down
14 changes: 7 additions & 7 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3024,17 +3024,17 @@ ERROR(derivative_attr_original_stored_property_unsupported,none,
ERROR(derivative_attr_original_already_has_derivative,none,
"a derivative already exists for %0", (DeclName))

// transposing
// @transpose
ERROR(transpose_params_clause_param_not_differentiable,none,
"can only transpose with respect to parameters that conform to "
"'Differentiable' and where '%0 == %0.TangentVector'", (StringRef))
ERROR(transposing_attr_overload_not_found,none,
ERROR(transpose_attr_overload_not_found,none,
"could not find function %0 with expected type %1", (DeclName, Type))
ERROR(transposing_attr_cannot_use_named_wrt_params,none,
"cannot use named 'wrt' parameters in '@transposing' attribute, found %0",
(Identifier))
ERROR(transposing_attr_result_value_not_differentiable,none,
"'@transposing' attribute requires original function result %0 to "
ERROR(transpose_attr_cannot_use_named_wrt_params,none,
"cannot use named 'wrt' parameters in '@transpose(of:)' attribute, found "
"%0", (Identifier))
ERROR(transpose_attr_result_value_not_differentiable,none,
"'@transpose(of:)' attribute requires original function result %0 to "
"conform to 'Differentiable'", (Type))

// differentiation `wrt` parameters clause
Expand Down
9 changes: 4 additions & 5 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -3142,7 +3142,7 @@ class AnyFunctionType : public TypeBase {
///
/// If `makeSelfParamFirst` is true, self's tangent is reordered to appear
/// first. This should be used during type-checking, e.g. type-checking
/// `@differentiable`, `@derivative`, and `@transposing` attributes.
/// `@differentiable`, `@derivative`, and `@transpose` attributes.
///
/// \note The original function type (`self`) need not be `@differentiable`.
/// The resulting function will preserve all `ExtInfo` of the original
Expand All @@ -3158,11 +3158,10 @@ class AnyFunctionType : public TypeBase {
/// corresponding original function type.
AnyFunctionType *getAutoDiffOriginalFunctionType();

/// Given the type of a transposing derivative function, returns the
/// corresponding original function type.
/// Given the type of a transpose function, returns the corresponding original
/// function type.
AnyFunctionType *
getTransposeOriginalFunctionType(IndexSubset *wrtParamIndices,
bool wrtSelf);
getTransposeOriginalFunctionType(IndexSubset *wrtParamIndices, bool wrtSelf);

AnyFunctionType *getWithoutDifferentiability() const;

Expand Down
18 changes: 10 additions & 8 deletions include/swift/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -996,13 +996,10 @@ class Parser {
Optional<DeclNameWithLoc> &jvpSpec, Optional<DeclNameWithLoc> &vjpSpec,
TrailingWhereClause *&whereClause);

/// Parse a differentiation parameters clause.
/// Parse a differentiation parameters clause, i.e. the "wrt:" clause in
/// @differentiable and @derivative attributes.
bool parseDifferentiationParametersClause(
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);

/// Parse a transposing parameters clause.
bool parseTransposingParametersClause(
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);

/// Parse the @derivative attribute.
ParserResult<DerivativeAttr> parseDerivativeAttribute(SourceLoc AtLoc,
Expand All @@ -1013,9 +1010,14 @@ class Parser {
ParserResult<DerivativeAttr> parseDifferentiatingAttribute(SourceLoc AtLoc,
SourceLoc Loc);

/// Parse the @transposing attribute.
ParserResult<TransposingAttr> parseTransposingAttribute(SourceLoc AtLoc,
SourceLoc Loc);
/// Parse a transposed parameters clause, i.e. the "wrt:" clause in @transpose
/// attributes.
bool parseTransposedParametersClause(
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);

/// Parse the @transpose attribute.
ParserResult<TransposeAttr> parseTransposeAttribute(SourceLoc AtLoc,
SourceLoc Loc);

/// Parse the @quoted attribute.
ParserResult<QuotedAttr> parseQuotedAttribute(SourceLoc AtLoc, SourceLoc Loc);
Expand Down
67 changes: 33 additions & 34 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -936,10 +936,10 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
}

// SWIFT_ENABLE_TENSORFLOW
case DAK_Transposing: {
Printer.printAttrName("@transposing");
case DAK_Transpose: {
Printer.printAttrName("@transpose");
Printer << '(';
auto *attr = cast<TransposingAttr>(this);
auto *attr = cast<TransposeAttr>(this);
auto *transpose = cast<AbstractFunctionDecl>(D);
Printer << attr->getOriginalFunctionName().Name;
auto transParamsString = getTransposedParametersClauseString(
Expand Down Expand Up @@ -1110,12 +1110,13 @@ StringRef DeclAttribute::getAttrName() const {
return "differentiable";
case DAK_Derivative:
return "derivative";
case DAK_Transpose:
return "transpose";
case DAK_Differentiating:
return "differentiating";
case DAK_Transposing:
return "transposing";
case DAK_Quoted:
return "quoted";
// SWIFT_ENABLE_TENSORFLOW END
}
llvm_unreachable("bad DeclAttrKind");
}
Expand Down Expand Up @@ -1608,45 +1609,43 @@ DerivativeAttr *DerivativeAttr::create(ASTContext &context, bool implicit,
std::move(originalName), indices);
}

TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseType,
DeclNameWithLoc originalName,
ArrayRef<ParsedAutoDiffParameter> params)
: DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit),
TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseType,
DeclNameWithLoc originalName,
ArrayRef<ParsedAutoDiffParameter> params)
: DeclAttribute(DAK_Transpose, atLoc, baseRange, implicit),
BaseType(baseType), OriginalFunctionName(std::move(originalName)),
NumParsedParameters(params.size()) {
std::uninitialized_copy(params.begin(), params.end(),
getTrailingObjects<ParsedAutoDiffParameter>());
}

TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseType,
DeclNameWithLoc originalName,
IndexSubset *indices)
: DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit),
TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseType,
DeclNameWithLoc originalName, IndexSubset *indices)
: DeclAttribute(DAK_Transpose, atLoc, baseRange, implicit),
BaseType(baseType), OriginalFunctionName(std::move(originalName)),
ParameterIndices(indices) {}

TransposingAttr *
TransposingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseType,
DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params) {
TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType,
DeclNameWithLoc originalName,
ArrayRef<ParsedAutoDiffParameter> params) {
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
void *mem = context.Allocate(size, alignof(TransposingAttr));
return new (mem) TransposingAttr(implicit, atLoc, baseRange, baseType,
std::move(original), params);
}

TransposingAttr *
TransposingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseType,
DeclNameWithLoc original,
IndexSubset *indices) {
void *mem =
context.Allocate(sizeof(TransposingAttr), alignof(TransposingAttr));
return new (mem) TransposingAttr(implicit, atLoc, baseRange, baseType,
std::move(original), indices);
void *mem = context.Allocate(size, alignof(TransposeAttr));
return new (mem) TransposeAttr(implicit, atLoc, baseRange, baseType,
std::move(originalName), params);
}

TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType,
DeclNameWithLoc originalName,
IndexSubset *indices) {
void *mem = context.Allocate(sizeof(TransposeAttr), alignof(TransposeAttr));
return new (mem) TransposeAttr(implicit, atLoc, baseRange, baseType,
std::move(originalName), indices);
}

ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,
Expand Down
3 changes: 1 addition & 2 deletions lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4897,8 +4897,7 @@ AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
assert(originalResult);

SmallVector<TupleTypeElt, 4> transposeResultTypes;
// Return type of '@transposing' function can have single type or tuples
// of types.
// Return type of transpose function can be a singular type or a tuple type.
if (auto transposeResultTupleType = transposeResult->getAs<TupleType>()) {
transposeResultTypes.append(transposeResultTupleType->getElements().begin(),
transposeResultTupleType->getElements().end());
Expand Down
Loading