Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenACC] Implement 'reduction' sema for compute constructs #92808

Merged
merged 2 commits into from
May 21, 2024

Conversation

erichkeane
Copy link
Collaborator

'reduction' has a few restrictions over normal 'var-list' clauses:

1- On parallel, a num_gangs can only have 1 argument when combined with reduction. These two aren't able to be combined on any other of the compute constructs however.

2- The vars all must be 'numerical data types' types of some sort, or a 'composite of numerical data types'. A list of types is given in the standard as a minimum, so we choose 'isScalar', which covers all of these types and keeps types that are actually numeric. Other compilers don't seem to implement the 'composite of numerical data types', though we do.

3- Because of the above restrictions, member-of-composite is not allowed, so any access via a memberexpr is disallowed. Array-element and sub-arrays (aka array sections) are both permitted, so long as they meet the requirements of #2.

This patch implements all of these for compute constructs.

'reduction' has a few restrictions over normal 'var-list' clauses:

1- On parallel, a num_gangs can only have 1 argument when combined with
reduction. These two aren't able to be combined on any other of the
compute constructs however.

2- The vars all must be 'numerical data types' types of some sort, or a
'composite of numerical data types'.  A list of types is given in the standard
as a minimum, so we choose 'isScalar', which covers all of these types
and keeps types that are actually numeric. Other compilers don't seem to
implement the 'composite of numerical data types', though we do.

3- Because of the above restrictions, member-of-composite is not
allowed, so any access via a memberexpr is disallowed. Array-element and
sub-arrays (aka array sections) are both permitted, so long as they meet
the requirements of #2.

This patch implements all of these for compute constructs.
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules labels May 20, 2024
@llvmbot
Copy link
Member

llvmbot commented May 20, 2024

@llvm/pr-subscribers-clang

Author: Erich Keane (erichkeane)

Changes

'reduction' has a few restrictions over normal 'var-list' clauses:

1- On parallel, a num_gangs can only have 1 argument when combined with reduction. These two aren't able to be combined on any other of the compute constructs however.

2- The vars all must be 'numerical data types' types of some sort, or a 'composite of numerical data types'. A list of types is given in the standard as a minimum, so we choose 'isScalar', which covers all of these types and keeps types that are actually numeric. Other compilers don't seem to implement the 'composite of numerical data types', though we do.

3- Because of the above restrictions, member-of-composite is not allowed, so any access via a memberexpr is disallowed. Array-element and sub-arrays (aka array sections) are both permitted, so long as they meet the requirements of #2.

This patch implements all of these for compute constructs.


Patch is 107.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92808.diff

39 Files Affected:

  • (modified) clang/include/clang/AST/OpenACCClause.h (+29)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+16-2)
  • (modified) clang/include/clang/Basic/OpenACCClauses.def (+1)
  • (modified) clang/include/clang/Basic/OpenACCKinds.h (+36)
  • (modified) clang/include/clang/Parse/Parser.h (+2-2)
  • (modified) clang/include/clang/Sema/SemaOpenACC.h (+27-2)
  • (modified) clang/lib/AST/OpenACCClause.cpp (+19-1)
  • (modified) clang/lib/AST/StmtProfile.cpp (+6)
  • (modified) clang/lib/AST/TextNodeDumper.cpp (+4)
  • (modified) clang/lib/Parse/ParseOpenACC.cpp (+16-14)
  • (modified) clang/lib/Sema/SemaOpenACC.cpp (+147-10)
  • (modified) clang/lib/Sema/TreeTransform.h (+20-1)
  • (modified) clang/lib/Serialization/ASTReader.cpp (+7-1)
  • (modified) clang/lib/Serialization/ASTWriter.cpp (+7-1)
  • (modified) clang/test/AST/ast-print-openacc-compute-construct.cpp (+28)
  • (modified) clang/test/ParserOpenACC/parse-clauses.c (+6-20)
  • (modified) clang/test/SemaOpenACC/compute-construct-attach-clause.c (+1-1)
  • (modified) clang/test/SemaOpenACC/compute-construct-clause-ast.cpp (+248)
  • (modified) clang/test/SemaOpenACC/compute-construct-copy-clause.c (+4-4)
  • (modified) clang/test/SemaOpenACC/compute-construct-copy-clause.cpp (+8-8)
  • (modified) clang/test/SemaOpenACC/compute-construct-copyin-clause.c (+5-5)
  • (modified) clang/test/SemaOpenACC/compute-construct-copyin-clause.cpp (+8-8)
  • (modified) clang/test/SemaOpenACC/compute-construct-copyout-clause.c (+5-5)
  • (modified) clang/test/SemaOpenACC/compute-construct-copyout-clause.cpp (+8-8)
  • (modified) clang/test/SemaOpenACC/compute-construct-create-clause.c (+5-5)
  • (modified) clang/test/SemaOpenACC/compute-construct-create-clause.cpp (+8-8)
  • (modified) clang/test/SemaOpenACC/compute-construct-device_type-clause.c (+1-1)
  • (modified) clang/test/SemaOpenACC/compute-construct-deviceptr-clause.c (+1-1)
  • (modified) clang/test/SemaOpenACC/compute-construct-firstprivate-clause.c (+4-4)
  • (modified) clang/test/SemaOpenACC/compute-construct-firstprivate-clause.cpp (+8-8)
  • (modified) clang/test/SemaOpenACC/compute-construct-no_create-clause.c (+4-4)
  • (modified) clang/test/SemaOpenACC/compute-construct-no_create-clause.cpp (+8-8)
  • (modified) clang/test/SemaOpenACC/compute-construct-present-clause.c (+4-4)
  • (modified) clang/test/SemaOpenACC/compute-construct-present-clause.cpp (+8-8)
  • (modified) clang/test/SemaOpenACC/compute-construct-private-clause.c (+5-5)
  • (modified) clang/test/SemaOpenACC/compute-construct-private-clause.cpp (+8-8)
  • (added) clang/test/SemaOpenACC/compute-construct-reduction-clause.c (+107)
  • (added) clang/test/SemaOpenACC/compute-construct-reduction-clause.cpp (+175)
  • (modified) clang/tools/libclang/CIndex.cpp (+4)
diff --git a/clang/include/clang/AST/OpenACCClause.h b/clang/include/clang/AST/OpenACCClause.h
index 607a2b9d65367..28ff8c44bd256 100644
--- a/clang/include/clang/AST/OpenACCClause.h
+++ b/clang/include/clang/AST/OpenACCClause.h
@@ -677,6 +677,35 @@ class OpenACCCreateClause final
          ArrayRef<Expr *> VarList, SourceLocation EndLoc);
 };
 
+class OpenACCReductionClause final
+    : public OpenACCClauseWithVarList,
+      public llvm::TrailingObjects<OpenACCReductionClause, Expr *> {
+  OpenACCReductionOperator Op;
+
+  OpenACCReductionClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
+                         OpenACCReductionOperator Operator,
+                         ArrayRef<Expr *> VarList, SourceLocation EndLoc)
+      : OpenACCClauseWithVarList(OpenACCClauseKind::Reduction, BeginLoc,
+                                 LParenLoc, EndLoc),
+        Op(Operator) {
+    std::uninitialized_copy(VarList.begin(), VarList.end(),
+                            getTrailingObjects<Expr *>());
+    setExprs(MutableArrayRef(getTrailingObjects<Expr *>(), VarList.size()));
+  }
+
+public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::Reduction;
+  }
+
+  static OpenACCReductionClause *
+  Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
+         OpenACCReductionOperator Operator, ArrayRef<Expr *> VarList,
+         SourceLocation EndLoc);
+
+  OpenACCReductionOperator getReductionOp() const { return Op; }
+};
+
 template <class Impl> class OpenACCClauseVisitor {
   Impl &getDerived() { return static_cast<Impl &>(*this); }
 
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index e3c65cba4886a..c7dea1d54d063 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12343,7 +12343,8 @@ def err_acc_num_gangs_num_args
             "provided}0">;
 def err_acc_not_a_var_ref
     : Error<"OpenACC variable is not a valid variable name, sub-array, array "
-            "element, or composite variable member">;
+            "element,%select{| member of a composite variable,}0 or composite "
+            "variable member">;
 def err_acc_typecheck_subarray_value
     : Error<"OpenACC sub-array subscripted value is not an array or pointer">;
 def err_acc_subarray_function_type
@@ -12374,5 +12375,18 @@ def note_acc_expected_pointer_var : Note<"expected variable of pointer type">;
 def err_acc_clause_after_device_type
     : Error<"OpenACC clause '%0' may not follow a '%1' clause in a "
             "compute construct">;
-
+def err_acc_reduction_num_gangs_conflict
+    : Error<
+          "OpenACC 'reduction' clause may not appear on a 'parallel' construct "
+          "with a 'num_gangs' clause with more than 1 argument, have %0">;
+def err_acc_reduction_type
+    : Error<"OpenACC 'reduction' variable must be of scalar type, sub-array, or a "
+            "composite of scalar types;%select{| sub-array base}1 type is %0">;
+def err_acc_reduction_composite_type
+    : Error<"OpenACC 'reduction' variable must be a composite of scalar types; "
+            "%1 %select{is not a class or struct|is incomplete|is not an "
+            "aggregate}0">;
+def err_acc_reduction_composite_member_type :Error<
+    "OpenACC 'reduction' composite variable must not have non-scalar field">;
+def note_acc_reduction_composite_member_loc : Note<"invalid field is here">;
 } // end of sema component.
diff --git a/clang/include/clang/Basic/OpenACCClauses.def b/clang/include/clang/Basic/OpenACCClauses.def
index 7ecc51799468c..3e464abaafd92 100644
--- a/clang/include/clang/Basic/OpenACCClauses.def
+++ b/clang/include/clang/Basic/OpenACCClauses.def
@@ -46,6 +46,7 @@ VISIT_CLAUSE(NumGangs)
 VISIT_CLAUSE(NumWorkers)
 VISIT_CLAUSE(Present)
 VISIT_CLAUSE(Private)
+VISIT_CLAUSE(Reduction)
 VISIT_CLAUSE(Self)
 VISIT_CLAUSE(VectorLength)
 VISIT_CLAUSE(Wait)
diff --git a/clang/include/clang/Basic/OpenACCKinds.h b/clang/include/clang/Basic/OpenACCKinds.h
index 0e38a04e7164b..7b9d619a8aec6 100644
--- a/clang/include/clang/Basic/OpenACCKinds.h
+++ b/clang/include/clang/Basic/OpenACCKinds.h
@@ -514,6 +514,42 @@ enum class OpenACCReductionOperator {
   /// Invalid Reduction Clause Kind.
   Invalid,
 };
+
+template <typename StreamTy>
+inline StreamTy &printOpenACCReductionOperator(StreamTy &Out,
+                                               OpenACCReductionOperator Op) {
+  switch (Op) {
+  case OpenACCReductionOperator::Addition:
+    return Out << "+";
+  case OpenACCReductionOperator::Multiplication:
+    return Out << "*";
+  case OpenACCReductionOperator::Max:
+    return Out << "max";
+  case OpenACCReductionOperator::Min:
+    return Out << "min";
+  case OpenACCReductionOperator::BitwiseAnd:
+    return Out << "&";
+  case OpenACCReductionOperator::BitwiseOr:
+    return Out << "|";
+  case OpenACCReductionOperator::BitwiseXOr:
+    return Out << "^";
+  case OpenACCReductionOperator::And:
+    return Out << "&&";
+  case OpenACCReductionOperator::Or:
+    return Out << "||";
+  case OpenACCReductionOperator::Invalid:
+    return Out << "<invalid>";
+  }
+  llvm_unreachable("Unknown reduction operator kind");
+}
+inline const StreamingDiagnostic &operator<<(const StreamingDiagnostic &Out,
+                                             OpenACCReductionOperator Op) {
+  return printOpenACCReductionOperator(Out, Op);
+}
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &Out,
+                                     OpenACCReductionOperator Op) {
+  return printOpenACCReductionOperator(Out, Op);
+}
 } // namespace clang
 
 #endif // LLVM_CLANG_BASIC_OPENACCKINDS_H
diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
index 5f04664141d29..3c4ab649e3b4c 100644
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -3686,9 +3686,9 @@ class Parser : public CodeCompletionHandler {
 
   using OpenACCVarParseResult = std::pair<ExprResult, OpenACCParseCanContinue>;
   /// Parses a single variable in a variable list for OpenACC.
-  OpenACCVarParseResult ParseOpenACCVar();
+  OpenACCVarParseResult ParseOpenACCVar(OpenACCClauseKind CK);
   /// Parses the variable list for the variety of places that take a var-list.
-  llvm::SmallVector<Expr *> ParseOpenACCVarList();
+  llvm::SmallVector<Expr *> ParseOpenACCVarList(OpenACCClauseKind CK);
   /// Parses any parameters for an OpenACC Clause, including required/optional
   /// parens.
   OpenACCClauseParseResult
diff --git a/clang/include/clang/Sema/SemaOpenACC.h b/clang/include/clang/Sema/SemaOpenACC.h
index f838fa97d33a2..6f69fa08939b8 100644
--- a/clang/include/clang/Sema/SemaOpenACC.h
+++ b/clang/include/clang/Sema/SemaOpenACC.h
@@ -66,9 +66,14 @@ class SemaOpenACC : public SemaBase {
     struct DeviceTypeDetails {
       SmallVector<DeviceTypeArgument> Archs;
     };
+    struct ReductionDetails {
+      OpenACCReductionOperator Op;
+      SmallVector<Expr *> VarList;
+    };
 
     std::variant<std::monostate, DefaultDetails, ConditionDetails,
-                 IntExprDetails, VarListDetails, WaitDetails, DeviceTypeDetails>
+                 IntExprDetails, VarListDetails, WaitDetails, DeviceTypeDetails,
+                 ReductionDetails>
         Details = std::monostate{};
 
   public:
@@ -170,6 +175,10 @@ class SemaOpenACC : public SemaBase {
       return const_cast<OpenACCParsedClause *>(this)->getIntExprs();
     }
 
+    OpenACCReductionOperator getReductionOp() const {
+      return std::get<ReductionDetails>(Details).Op;
+    }
+
     ArrayRef<Expr *> getVarList() {
       assert((ClauseKind == OpenACCClauseKind::Private ||
               ClauseKind == OpenACCClauseKind::NoCreate ||
@@ -188,8 +197,13 @@ class SemaOpenACC : public SemaBase {
               ClauseKind == OpenACCClauseKind::PresentOrCreate ||
               ClauseKind == OpenACCClauseKind::Attach ||
               ClauseKind == OpenACCClauseKind::DevicePtr ||
+              ClauseKind == OpenACCClauseKind::Reduction ||
               ClauseKind == OpenACCClauseKind::FirstPrivate) &&
              "Parsed clause kind does not have a var-list");
+
+      if (ClauseKind == OpenACCClauseKind::Reduction)
+        return std::get<ReductionDetails>(Details).VarList;
+
       return std::get<VarListDetails>(Details).VarList;
     }
 
@@ -334,6 +348,13 @@ class SemaOpenACC : public SemaBase {
       Details = VarListDetails{std::move(VarList), IsReadOnly, IsZero};
     }
 
+    void setReductionDetails(OpenACCReductionOperator Op,
+                             llvm::SmallVector<Expr *> &&VarList) {
+      assert(ClauseKind == OpenACCClauseKind::Reduction &&
+             "reduction details only valid on reduction");
+      Details = ReductionDetails{Op, std::move(VarList)};
+    }
+
     void setWaitDetails(Expr *DevNum, SourceLocation QueuesLoc,
                         llvm::SmallVector<Expr *> &&IntExprs) {
       assert(ClauseKind == OpenACCClauseKind::Wait &&
@@ -394,7 +415,11 @@ class SemaOpenACC : public SemaBase {
 
   /// Called when encountering a 'var' for OpenACC, ensures it is actually a
   /// declaration reference to a variable of the correct type.
-  ExprResult ActOnVar(Expr *VarExpr);
+  ExprResult ActOnVar(OpenACCClauseKind CK, Expr *VarExpr);
+
+  /// Called while semantically analyzing the reduction clause, ensuring the var
+  /// is the correct kind of reference.
+  ExprResult CheckReductionVar(Expr *VarExpr);
 
   /// Called to check the 'var' type is a variable of pointer type, necessary
   /// for 'deviceptr' and 'attach' clauses. Returns true on success.
diff --git a/clang/lib/AST/OpenACCClause.cpp b/clang/lib/AST/OpenACCClause.cpp
index 8ff6dabcbc48e..cb2c7f98be75c 100644
--- a/clang/lib/AST/OpenACCClause.cpp
+++ b/clang/lib/AST/OpenACCClause.cpp
@@ -35,7 +35,7 @@ bool OpenACCClauseWithVarList::classof(const OpenACCClause *C) {
          OpenACCAttachClause::classof(C) || OpenACCNoCreateClause::classof(C) ||
          OpenACCPresentClause::classof(C) || OpenACCCopyClause::classof(C) ||
          OpenACCCopyInClause::classof(C) || OpenACCCopyOutClause::classof(C) ||
-         OpenACCCreateClause::classof(C);
+         OpenACCReductionClause::classof(C) || OpenACCCreateClause::classof(C);
 }
 bool OpenACCClauseWithCondition::classof(const OpenACCClause *C) {
   return OpenACCIfClause::classof(C) || OpenACCSelfClause::classof(C);
@@ -310,6 +310,16 @@ OpenACCDeviceTypeClause *OpenACCDeviceTypeClause::Create(
       OpenACCDeviceTypeClause(K, BeginLoc, LParenLoc, Archs, EndLoc);
 }
 
+OpenACCReductionClause *OpenACCReductionClause::Create(
+    const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
+    OpenACCReductionOperator Operator, ArrayRef<Expr *> VarList,
+    SourceLocation EndLoc) {
+  void *Mem = C.Allocate(
+      OpenACCReductionClause::totalSizeToAlloc<Expr *>(VarList.size()));
+  return new (Mem)
+      OpenACCReductionClause(BeginLoc, LParenLoc, Operator, VarList, EndLoc);
+}
+
 //===----------------------------------------------------------------------===//
 //  OpenACC clauses printing methods
 //===----------------------------------------------------------------------===//
@@ -445,6 +455,14 @@ void OpenACCClausePrinter::VisitCreateClause(const OpenACCCreateClause &C) {
   OS << ")";
 }
 
+void OpenACCClausePrinter::VisitReductionClause(
+    const OpenACCReductionClause &C) {
+  OS << "reduction(" << C.getReductionOp() << ": ";
+  llvm::interleaveComma(C.getVarList(), OS,
+                        [&](const Expr *E) { printExpr(E); });
+  OS << ")";
+}
+
 void OpenACCClausePrinter::VisitWaitClause(const OpenACCWaitClause &C) {
   OS << "wait";
   if (!C.getLParenLoc().isInvalid()) {
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index caab4ab0ef160..00b8c43af035c 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2588,6 +2588,12 @@ void OpenACCClauseProfiler::VisitWaitClause(const OpenACCWaitClause &Clause) {
 /// Nothing to do here, there are no sub-statements.
 void OpenACCClauseProfiler::VisitDeviceTypeClause(
     const OpenACCDeviceTypeClause &Clause) {}
+
+void OpenACCClauseProfiler::VisitReductionClause(
+    const OpenACCReductionClause &Clause) {
+  for (auto *E : Clause.getVarList())
+    Profiler.VisitStmt(E);
+}
 } // namespace
 
 void StmtProfiler::VisitOpenACCComputeConstruct(
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index efcd74717a4e2..4a1e94ffe283b 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -457,6 +457,10 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
           });
       OS << ")";
       break;
+    case OpenACCClauseKind::Reduction:
+      OS << " clause Operator: "
+         << cast<OpenACCReductionClause>(C)->getReductionOp();
+      break;
     default:
       // Nothing to do here.
       break;
diff --git a/clang/lib/Parse/ParseOpenACC.cpp b/clang/lib/Parse/ParseOpenACC.cpp
index 5db3036b00030..e9c60f76165b6 100644
--- a/clang/lib/Parse/ParseOpenACC.cpp
+++ b/clang/lib/Parse/ParseOpenACC.cpp
@@ -920,7 +920,8 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
     case OpenACCClauseKind::PresentOrCopyIn: {
       bool IsReadOnly = tryParseAndConsumeSpecialTokenKind(
           *this, OpenACCSpecialTokenKind::ReadOnly, ClauseKind);
-      ParsedClause.setVarListDetails(ParseOpenACCVarList(), IsReadOnly,
+      ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
+                                     IsReadOnly,
                                      /*IsZero=*/false);
       break;
     }
@@ -932,16 +933,17 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
     case OpenACCClauseKind::PresentOrCopyOut: {
       bool IsZero = tryParseAndConsumeSpecialTokenKind(
           *this, OpenACCSpecialTokenKind::Zero, ClauseKind);
-      ParsedClause.setVarListDetails(ParseOpenACCVarList(),
+      ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
                                      /*IsReadOnly=*/false, IsZero);
       break;
     }
-    case OpenACCClauseKind::Reduction:
+    case OpenACCClauseKind::Reduction: {
       // If we're missing a clause-kind (or it is invalid), see if we can parse
       // the var-list anyway.
-      ParseReductionOperator(*this);
-      ParseOpenACCVarList();
+      OpenACCReductionOperator Op = ParseReductionOperator(*this);
+      ParsedClause.setReductionDetails(Op, ParseOpenACCVarList(ClauseKind));
       break;
+    }
     case OpenACCClauseKind::Self:
       // The 'self' clause is a var-list instead of a 'condition' in the case of
       // the 'update' clause, so we have to handle it here.  U se an assert to
@@ -955,11 +957,11 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
     case OpenACCClauseKind::Host:
     case OpenACCClauseKind::Link:
     case OpenACCClauseKind::UseDevice:
-      ParseOpenACCVarList();
+      ParseOpenACCVarList(ClauseKind);
       break;
     case OpenACCClauseKind::Attach:
     case OpenACCClauseKind::DevicePtr:
-      ParsedClause.setVarListDetails(ParseOpenACCVarList(),
+      ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
                                      /*IsReadOnly=*/false, /*IsZero=*/false);
       break;
     case OpenACCClauseKind::Copy:
@@ -969,7 +971,7 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
     case OpenACCClauseKind::NoCreate:
     case OpenACCClauseKind::Present:
     case OpenACCClauseKind::Private:
-      ParsedClause.setVarListDetails(ParseOpenACCVarList(),
+      ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
                                      /*IsReadOnly=*/false, /*IsZero=*/false);
       break;
     case OpenACCClauseKind::Collapse: {
@@ -1278,7 +1280,7 @@ ExprResult Parser::ParseOpenACCBindClauseArgument() {
 /// - an array element
 /// - a member of a composite variable
 /// - a common block name between slashes (fortran only)
-Parser::OpenACCVarParseResult Parser::ParseOpenACCVar() {
+Parser::OpenACCVarParseResult Parser::ParseOpenACCVar(OpenACCClauseKind CK) {
   OpenACCArraySectionRAII ArraySections(*this);
 
   ExprResult Res = ParseAssignmentExpression();
@@ -1289,15 +1291,15 @@ Parser::OpenACCVarParseResult Parser::ParseOpenACCVar() {
   if (!Res.isUsable())
     return {Res, OpenACCParseCanContinue::Can};
 
-  Res = getActions().OpenACC().ActOnVar(Res.get());
+  Res = getActions().OpenACC().ActOnVar(CK, Res.get());
 
   return {Res, OpenACCParseCanContinue::Can};
 }
 
-llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList() {
+llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList(OpenACCClauseKind CK) {
   llvm::SmallVector<Expr *> Vars;
 
-  auto [Res, CanContinue] = ParseOpenACCVar();
+  auto [Res, CanContinue] = ParseOpenACCVar(CK);
   if (Res.isUsable()) {
     Vars.push_back(Res.get());
   } else if (CanContinue == OpenACCParseCanContinue::Cannot) {
@@ -1308,7 +1310,7 @@ llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList() {
   while (!getCurToken().isOneOf(tok::r_paren, tok::annot_pragma_openacc_end)) {
     ExpectAndConsume(tok::comma);
 
-    auto [Res, CanContinue] = ParseOpenACCVar();
+    auto [Res, CanContinue] = ParseOpenACCVar(CK);
 
     if (Res.isUsable()) {
       Vars.push_back(Res.get());
@@ -1342,7 +1344,7 @@ void Parser::ParseOpenACCCacheVarList() {
 
   // ParseOpenACCVarList should leave us before a r-paren, so no need to skip
   // anything here.
-  ParseOpenACCVarList();
+  ParseOpenACCVarList(OpenACCClauseKind::Invalid);
 }
 
 Parser::OpenACCDirectiveParseInfo Parser::ParseOpenACCDirective() {
diff --git a/clang/lib/Sema/SemaOpenACC.cpp b/clang/lib/Sema/SemaOpenACC.cpp
index f174b2fa63c6a..49847bd997161 100644
--- a/clang/lib/Sema/SemaOpenACC.cpp
+++ b/clang/lib/Sema/SemaOpenACC.cpp
@@ -233,6 +233,19 @@ bool doesClauseApplyToDirective(OpenACCDirectiveKind DirectiveKind,
       return false;
     }
 
+  case OpenACCClauseKind::Reduction:
+    switch (DirectiveKind) {
+    case OpenACCDirectiveKind::Parallel:
+    case OpenACCDirectiveKind::Serial:
+    case OpenACCDirectiveKind::Loop:
+    case OpenACCDirectiveKind::ParallelLoop:
+    case OpenACCDirectiveKind::SerialLoop:
+    case OpenACCDirectiveKind::KernelsLoop:
+      return true;
+    default:
+      return false;
+    }
+
   default:
     // Do nothing so we can go to the 'unimplemented' diagnostic instead.
     return true;
@@ -281,7 +294,6 @@ bool checkValidAfterDeviceType(
     return true;
   }
 }
-
 } // namespace
 
 SemaOpenACC::SemaOpenACC(Sema &S) : SemaBase(S) {}
@@ -426,6 +438,24 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClauses,
           << /*NoArgs=*/1 << Clause.getDirectiveKind() << MaxArgs
           << Clause.getIntExprs().size();
 
+    // OpenACC 3.3 Section 2.5.4:
+    // A reduction clause may not appear on a parallel construct with a
+    // num_gangs clause that has more than one argument.
+    if (Clause.getDirectiveKind() == OpenACCDirectiveKind::Parallel &&
+        Clause.getIntExprs().size() > 1) {
+      auto *Parallel =
+          llvm::find_if(ExistingClauses, [](const OpenACCClause *C) {
+            return C->getClauseKind() == OpenACCClauseKind::Reduction;
+          });
+
+      if (Parallel != ExistingClauses.end()) {
+        Diag(Clause.getBeginLoc(), diag::err_acc_reduction_num_gangs_conflict)
+            << Clause.getIntExprs().size();
+        Diag((*Parallel)->getBeginLoc(), diag::note_acc_previous_clause_here);
+        return nullptr;
+      }
+    }
+
     // Create the AST node for the clause even if the number of expressions is
     // incorrect.
     return OpenACCNumGangsClause::Create(
@@ -706,6 +736,48 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClauses,
         Clause.getLParenLoc(), Clause.getDeviceTypeArchitectures(),
         Clause.getEndLoc());
   }
+  case OpenACCClauseKind::Reduction: {
+    // Restrictions only properly implemented on 'compute' constructs, and
+    // 'compute' constructs are the only construct that can do anything with
+    // this yet, so skip/treat as unimplemented in this case.
+    if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()))
+      break;
+
+    // OpenACC 3.3 Section 2.5.4:
+    // A reduction clause may not appear on a parallel construct with a
+...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 20, 2024

@llvm/pr-subscribers-clang-modules

Author: Erich Keane (erichkeane)

Changes

'reduction' has a few restrictions over normal 'var-list' clauses:

1- On parallel, a num_gangs can only have 1 argument when combined with reduction. These two aren't able to be combined on any other of the compute constructs however.

2- The vars all must be 'numerical data types' types of some sort, or a 'composite of numerical data types'. A list of types is given in the standard as a minimum, so we choose 'isScalar', which covers all of these types and keeps types that are actually numeric. Other compilers don't seem to implement the 'composite of numerical data types', though we do.

3- Because of the above restrictions, member-of-composite is not allowed, so any access via a memberexpr is disallowed. Array-element and sub-arrays (aka array sections) are both permitted, so long as they meet the requirements of #2.

This patch implements all of these for compute constructs.


Patch is 107.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92808.diff

39 Files Affected:

  • (modified) clang/include/clang/AST/OpenACCClause.h (+29)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+16-2)
  • (modified) clang/include/clang/Basic/OpenACCClauses.def (+1)
  • (modified) clang/include/clang/Basic/OpenACCKinds.h (+36)
  • (modified) clang/include/clang/Parse/Parser.h (+2-2)
  • (modified) clang/include/clang/Sema/SemaOpenACC.h (+27-2)
  • (modified) clang/lib/AST/OpenACCClause.cpp (+19-1)
  • (modified) clang/lib/AST/StmtProfile.cpp (+6)
  • (modified) clang/lib/AST/TextNodeDumper.cpp (+4)
  • (modified) clang/lib/Parse/ParseOpenACC.cpp (+16-14)
  • (modified) clang/lib/Sema/SemaOpenACC.cpp (+147-10)
  • (modified) clang/lib/Sema/TreeTransform.h (+20-1)
  • (modified) clang/lib/Serialization/ASTReader.cpp (+7-1)
  • (modified) clang/lib/Serialization/ASTWriter.cpp (+7-1)
  • (modified) clang/test/AST/ast-print-openacc-compute-construct.cpp (+28)
  • (modified) clang/test/ParserOpenACC/parse-clauses.c (+6-20)
  • (modified) clang/test/SemaOpenACC/compute-construct-attach-clause.c (+1-1)
  • (modified) clang/test/SemaOpenACC/compute-construct-clause-ast.cpp (+248)
  • (modified) clang/test/SemaOpenACC/compute-construct-copy-clause.c (+4-4)
  • (modified) clang/test/SemaOpenACC/compute-construct-copy-clause.cpp (+8-8)
  • (modified) clang/test/SemaOpenACC/compute-construct-copyin-clause.c (+5-5)
  • (modified) clang/test/SemaOpenACC/compute-construct-copyin-clause.cpp (+8-8)
  • (modified) clang/test/SemaOpenACC/compute-construct-copyout-clause.c (+5-5)
  • (modified) clang/test/SemaOpenACC/compute-construct-copyout-clause.cpp (+8-8)
  • (modified) clang/test/SemaOpenACC/compute-construct-create-clause.c (+5-5)
  • (modified) clang/test/SemaOpenACC/compute-construct-create-clause.cpp (+8-8)
  • (modified) clang/test/SemaOpenACC/compute-construct-device_type-clause.c (+1-1)
  • (modified) clang/test/SemaOpenACC/compute-construct-deviceptr-clause.c (+1-1)
  • (modified) clang/test/SemaOpenACC/compute-construct-firstprivate-clause.c (+4-4)
  • (modified) clang/test/SemaOpenACC/compute-construct-firstprivate-clause.cpp (+8-8)
  • (modified) clang/test/SemaOpenACC/compute-construct-no_create-clause.c (+4-4)
  • (modified) clang/test/SemaOpenACC/compute-construct-no_create-clause.cpp (+8-8)
  • (modified) clang/test/SemaOpenACC/compute-construct-present-clause.c (+4-4)
  • (modified) clang/test/SemaOpenACC/compute-construct-present-clause.cpp (+8-8)
  • (modified) clang/test/SemaOpenACC/compute-construct-private-clause.c (+5-5)
  • (modified) clang/test/SemaOpenACC/compute-construct-private-clause.cpp (+8-8)
  • (added) clang/test/SemaOpenACC/compute-construct-reduction-clause.c (+107)
  • (added) clang/test/SemaOpenACC/compute-construct-reduction-clause.cpp (+175)
  • (modified) clang/tools/libclang/CIndex.cpp (+4)
diff --git a/clang/include/clang/AST/OpenACCClause.h b/clang/include/clang/AST/OpenACCClause.h
index 607a2b9d65367..28ff8c44bd256 100644
--- a/clang/include/clang/AST/OpenACCClause.h
+++ b/clang/include/clang/AST/OpenACCClause.h
@@ -677,6 +677,35 @@ class OpenACCCreateClause final
          ArrayRef<Expr *> VarList, SourceLocation EndLoc);
 };
 
+class OpenACCReductionClause final
+    : public OpenACCClauseWithVarList,
+      public llvm::TrailingObjects<OpenACCReductionClause, Expr *> {
+  OpenACCReductionOperator Op;
+
+  OpenACCReductionClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
+                         OpenACCReductionOperator Operator,
+                         ArrayRef<Expr *> VarList, SourceLocation EndLoc)
+      : OpenACCClauseWithVarList(OpenACCClauseKind::Reduction, BeginLoc,
+                                 LParenLoc, EndLoc),
+        Op(Operator) {
+    std::uninitialized_copy(VarList.begin(), VarList.end(),
+                            getTrailingObjects<Expr *>());
+    setExprs(MutableArrayRef(getTrailingObjects<Expr *>(), VarList.size()));
+  }
+
+public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::Reduction;
+  }
+
+  static OpenACCReductionClause *
+  Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
+         OpenACCReductionOperator Operator, ArrayRef<Expr *> VarList,
+         SourceLocation EndLoc);
+
+  OpenACCReductionOperator getReductionOp() const { return Op; }
+};
+
 template <class Impl> class OpenACCClauseVisitor {
   Impl &getDerived() { return static_cast<Impl &>(*this); }
 
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index e3c65cba4886a..c7dea1d54d063 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12343,7 +12343,8 @@ def err_acc_num_gangs_num_args
             "provided}0">;
 def err_acc_not_a_var_ref
     : Error<"OpenACC variable is not a valid variable name, sub-array, array "
-            "element, or composite variable member">;
+            "element,%select{| member of a composite variable,}0 or composite "
+            "variable member">;
 def err_acc_typecheck_subarray_value
     : Error<"OpenACC sub-array subscripted value is not an array or pointer">;
 def err_acc_subarray_function_type
@@ -12374,5 +12375,18 @@ def note_acc_expected_pointer_var : Note<"expected variable of pointer type">;
 def err_acc_clause_after_device_type
     : Error<"OpenACC clause '%0' may not follow a '%1' clause in a "
             "compute construct">;
-
+def err_acc_reduction_num_gangs_conflict
+    : Error<
+          "OpenACC 'reduction' clause may not appear on a 'parallel' construct "
+          "with a 'num_gangs' clause with more than 1 argument, have %0">;
+def err_acc_reduction_type
+    : Error<"OpenACC 'reduction' variable must be of scalar type, sub-array, or a "
+            "composite of scalar types;%select{| sub-array base}1 type is %0">;
+def err_acc_reduction_composite_type
+    : Error<"OpenACC 'reduction' variable must be a composite of scalar types; "
+            "%1 %select{is not a class or struct|is incomplete|is not an "
+            "aggregate}0">;
+def err_acc_reduction_composite_member_type :Error<
+    "OpenACC 'reduction' composite variable must not have non-scalar field">;
+def note_acc_reduction_composite_member_loc : Note<"invalid field is here">;
 } // end of sema component.
diff --git a/clang/include/clang/Basic/OpenACCClauses.def b/clang/include/clang/Basic/OpenACCClauses.def
index 7ecc51799468c..3e464abaafd92 100644
--- a/clang/include/clang/Basic/OpenACCClauses.def
+++ b/clang/include/clang/Basic/OpenACCClauses.def
@@ -46,6 +46,7 @@ VISIT_CLAUSE(NumGangs)
 VISIT_CLAUSE(NumWorkers)
 VISIT_CLAUSE(Present)
 VISIT_CLAUSE(Private)
+VISIT_CLAUSE(Reduction)
 VISIT_CLAUSE(Self)
 VISIT_CLAUSE(VectorLength)
 VISIT_CLAUSE(Wait)
diff --git a/clang/include/clang/Basic/OpenACCKinds.h b/clang/include/clang/Basic/OpenACCKinds.h
index 0e38a04e7164b..7b9d619a8aec6 100644
--- a/clang/include/clang/Basic/OpenACCKinds.h
+++ b/clang/include/clang/Basic/OpenACCKinds.h
@@ -514,6 +514,42 @@ enum class OpenACCReductionOperator {
   /// Invalid Reduction Clause Kind.
   Invalid,
 };
+
+template <typename StreamTy>
+inline StreamTy &printOpenACCReductionOperator(StreamTy &Out,
+                                               OpenACCReductionOperator Op) {
+  switch (Op) {
+  case OpenACCReductionOperator::Addition:
+    return Out << "+";
+  case OpenACCReductionOperator::Multiplication:
+    return Out << "*";
+  case OpenACCReductionOperator::Max:
+    return Out << "max";
+  case OpenACCReductionOperator::Min:
+    return Out << "min";
+  case OpenACCReductionOperator::BitwiseAnd:
+    return Out << "&";
+  case OpenACCReductionOperator::BitwiseOr:
+    return Out << "|";
+  case OpenACCReductionOperator::BitwiseXOr:
+    return Out << "^";
+  case OpenACCReductionOperator::And:
+    return Out << "&&";
+  case OpenACCReductionOperator::Or:
+    return Out << "||";
+  case OpenACCReductionOperator::Invalid:
+    return Out << "<invalid>";
+  }
+  llvm_unreachable("Unknown reduction operator kind");
+}
+inline const StreamingDiagnostic &operator<<(const StreamingDiagnostic &Out,
+                                             OpenACCReductionOperator Op) {
+  return printOpenACCReductionOperator(Out, Op);
+}
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &Out,
+                                     OpenACCReductionOperator Op) {
+  return printOpenACCReductionOperator(Out, Op);
+}
 } // namespace clang
 
 #endif // LLVM_CLANG_BASIC_OPENACCKINDS_H
diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
index 5f04664141d29..3c4ab649e3b4c 100644
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -3686,9 +3686,9 @@ class Parser : public CodeCompletionHandler {
 
   using OpenACCVarParseResult = std::pair<ExprResult, OpenACCParseCanContinue>;
   /// Parses a single variable in a variable list for OpenACC.
-  OpenACCVarParseResult ParseOpenACCVar();
+  OpenACCVarParseResult ParseOpenACCVar(OpenACCClauseKind CK);
   /// Parses the variable list for the variety of places that take a var-list.
-  llvm::SmallVector<Expr *> ParseOpenACCVarList();
+  llvm::SmallVector<Expr *> ParseOpenACCVarList(OpenACCClauseKind CK);
   /// Parses any parameters for an OpenACC Clause, including required/optional
   /// parens.
   OpenACCClauseParseResult
diff --git a/clang/include/clang/Sema/SemaOpenACC.h b/clang/include/clang/Sema/SemaOpenACC.h
index f838fa97d33a2..6f69fa08939b8 100644
--- a/clang/include/clang/Sema/SemaOpenACC.h
+++ b/clang/include/clang/Sema/SemaOpenACC.h
@@ -66,9 +66,14 @@ class SemaOpenACC : public SemaBase {
     struct DeviceTypeDetails {
       SmallVector<DeviceTypeArgument> Archs;
     };
+    struct ReductionDetails {
+      OpenACCReductionOperator Op;
+      SmallVector<Expr *> VarList;
+    };
 
     std::variant<std::monostate, DefaultDetails, ConditionDetails,
-                 IntExprDetails, VarListDetails, WaitDetails, DeviceTypeDetails>
+                 IntExprDetails, VarListDetails, WaitDetails, DeviceTypeDetails,
+                 ReductionDetails>
         Details = std::monostate{};
 
   public:
@@ -170,6 +175,10 @@ class SemaOpenACC : public SemaBase {
       return const_cast<OpenACCParsedClause *>(this)->getIntExprs();
     }
 
+    OpenACCReductionOperator getReductionOp() const {
+      return std::get<ReductionDetails>(Details).Op;
+    }
+
     ArrayRef<Expr *> getVarList() {
       assert((ClauseKind == OpenACCClauseKind::Private ||
               ClauseKind == OpenACCClauseKind::NoCreate ||
@@ -188,8 +197,13 @@ class SemaOpenACC : public SemaBase {
               ClauseKind == OpenACCClauseKind::PresentOrCreate ||
               ClauseKind == OpenACCClauseKind::Attach ||
               ClauseKind == OpenACCClauseKind::DevicePtr ||
+              ClauseKind == OpenACCClauseKind::Reduction ||
               ClauseKind == OpenACCClauseKind::FirstPrivate) &&
              "Parsed clause kind does not have a var-list");
+
+      if (ClauseKind == OpenACCClauseKind::Reduction)
+        return std::get<ReductionDetails>(Details).VarList;
+
       return std::get<VarListDetails>(Details).VarList;
     }
 
@@ -334,6 +348,13 @@ class SemaOpenACC : public SemaBase {
       Details = VarListDetails{std::move(VarList), IsReadOnly, IsZero};
     }
 
+    void setReductionDetails(OpenACCReductionOperator Op,
+                             llvm::SmallVector<Expr *> &&VarList) {
+      assert(ClauseKind == OpenACCClauseKind::Reduction &&
+             "reduction details only valid on reduction");
+      Details = ReductionDetails{Op, std::move(VarList)};
+    }
+
     void setWaitDetails(Expr *DevNum, SourceLocation QueuesLoc,
                         llvm::SmallVector<Expr *> &&IntExprs) {
       assert(ClauseKind == OpenACCClauseKind::Wait &&
@@ -394,7 +415,11 @@ class SemaOpenACC : public SemaBase {
 
   /// Called when encountering a 'var' for OpenACC, ensures it is actually a
   /// declaration reference to a variable of the correct type.
-  ExprResult ActOnVar(Expr *VarExpr);
+  ExprResult ActOnVar(OpenACCClauseKind CK, Expr *VarExpr);
+
+  /// Called while semantically analyzing the reduction clause, ensuring the var
+  /// is the correct kind of reference.
+  ExprResult CheckReductionVar(Expr *VarExpr);
 
   /// Called to check the 'var' type is a variable of pointer type, necessary
   /// for 'deviceptr' and 'attach' clauses. Returns true on success.
diff --git a/clang/lib/AST/OpenACCClause.cpp b/clang/lib/AST/OpenACCClause.cpp
index 8ff6dabcbc48e..cb2c7f98be75c 100644
--- a/clang/lib/AST/OpenACCClause.cpp
+++ b/clang/lib/AST/OpenACCClause.cpp
@@ -35,7 +35,7 @@ bool OpenACCClauseWithVarList::classof(const OpenACCClause *C) {
          OpenACCAttachClause::classof(C) || OpenACCNoCreateClause::classof(C) ||
          OpenACCPresentClause::classof(C) || OpenACCCopyClause::classof(C) ||
          OpenACCCopyInClause::classof(C) || OpenACCCopyOutClause::classof(C) ||
-         OpenACCCreateClause::classof(C);
+         OpenACCReductionClause::classof(C) || OpenACCCreateClause::classof(C);
 }
 bool OpenACCClauseWithCondition::classof(const OpenACCClause *C) {
   return OpenACCIfClause::classof(C) || OpenACCSelfClause::classof(C);
@@ -310,6 +310,16 @@ OpenACCDeviceTypeClause *OpenACCDeviceTypeClause::Create(
       OpenACCDeviceTypeClause(K, BeginLoc, LParenLoc, Archs, EndLoc);
 }
 
+OpenACCReductionClause *OpenACCReductionClause::Create(
+    const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
+    OpenACCReductionOperator Operator, ArrayRef<Expr *> VarList,
+    SourceLocation EndLoc) {
+  void *Mem = C.Allocate(
+      OpenACCReductionClause::totalSizeToAlloc<Expr *>(VarList.size()));
+  return new (Mem)
+      OpenACCReductionClause(BeginLoc, LParenLoc, Operator, VarList, EndLoc);
+}
+
 //===----------------------------------------------------------------------===//
 //  OpenACC clauses printing methods
 //===----------------------------------------------------------------------===//
@@ -445,6 +455,14 @@ void OpenACCClausePrinter::VisitCreateClause(const OpenACCCreateClause &C) {
   OS << ")";
 }
 
+void OpenACCClausePrinter::VisitReductionClause(
+    const OpenACCReductionClause &C) {
+  OS << "reduction(" << C.getReductionOp() << ": ";
+  llvm::interleaveComma(C.getVarList(), OS,
+                        [&](const Expr *E) { printExpr(E); });
+  OS << ")";
+}
+
 void OpenACCClausePrinter::VisitWaitClause(const OpenACCWaitClause &C) {
   OS << "wait";
   if (!C.getLParenLoc().isInvalid()) {
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index caab4ab0ef160..00b8c43af035c 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2588,6 +2588,12 @@ void OpenACCClauseProfiler::VisitWaitClause(const OpenACCWaitClause &Clause) {
 /// Nothing to do here, there are no sub-statements.
 void OpenACCClauseProfiler::VisitDeviceTypeClause(
     const OpenACCDeviceTypeClause &Clause) {}
+
+void OpenACCClauseProfiler::VisitReductionClause(
+    const OpenACCReductionClause &Clause) {
+  for (auto *E : Clause.getVarList())
+    Profiler.VisitStmt(E);
+}
 } // namespace
 
 void StmtProfiler::VisitOpenACCComputeConstruct(
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index efcd74717a4e2..4a1e94ffe283b 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -457,6 +457,10 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
           });
       OS << ")";
       break;
+    case OpenACCClauseKind::Reduction:
+      OS << " clause Operator: "
+         << cast<OpenACCReductionClause>(C)->getReductionOp();
+      break;
     default:
       // Nothing to do here.
       break;
diff --git a/clang/lib/Parse/ParseOpenACC.cpp b/clang/lib/Parse/ParseOpenACC.cpp
index 5db3036b00030..e9c60f76165b6 100644
--- a/clang/lib/Parse/ParseOpenACC.cpp
+++ b/clang/lib/Parse/ParseOpenACC.cpp
@@ -920,7 +920,8 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
     case OpenACCClauseKind::PresentOrCopyIn: {
       bool IsReadOnly = tryParseAndConsumeSpecialTokenKind(
           *this, OpenACCSpecialTokenKind::ReadOnly, ClauseKind);
-      ParsedClause.setVarListDetails(ParseOpenACCVarList(), IsReadOnly,
+      ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
+                                     IsReadOnly,
                                      /*IsZero=*/false);
       break;
     }
@@ -932,16 +933,17 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
     case OpenACCClauseKind::PresentOrCopyOut: {
       bool IsZero = tryParseAndConsumeSpecialTokenKind(
           *this, OpenACCSpecialTokenKind::Zero, ClauseKind);
-      ParsedClause.setVarListDetails(ParseOpenACCVarList(),
+      ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
                                      /*IsReadOnly=*/false, IsZero);
       break;
     }
-    case OpenACCClauseKind::Reduction:
+    case OpenACCClauseKind::Reduction: {
       // If we're missing a clause-kind (or it is invalid), see if we can parse
       // the var-list anyway.
-      ParseReductionOperator(*this);
-      ParseOpenACCVarList();
+      OpenACCReductionOperator Op = ParseReductionOperator(*this);
+      ParsedClause.setReductionDetails(Op, ParseOpenACCVarList(ClauseKind));
       break;
+    }
     case OpenACCClauseKind::Self:
       // The 'self' clause is a var-list instead of a 'condition' in the case of
       // the 'update' clause, so we have to handle it here.  U se an assert to
@@ -955,11 +957,11 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
     case OpenACCClauseKind::Host:
     case OpenACCClauseKind::Link:
     case OpenACCClauseKind::UseDevice:
-      ParseOpenACCVarList();
+      ParseOpenACCVarList(ClauseKind);
       break;
     case OpenACCClauseKind::Attach:
     case OpenACCClauseKind::DevicePtr:
-      ParsedClause.setVarListDetails(ParseOpenACCVarList(),
+      ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
                                      /*IsReadOnly=*/false, /*IsZero=*/false);
       break;
     case OpenACCClauseKind::Copy:
@@ -969,7 +971,7 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
     case OpenACCClauseKind::NoCreate:
     case OpenACCClauseKind::Present:
     case OpenACCClauseKind::Private:
-      ParsedClause.setVarListDetails(ParseOpenACCVarList(),
+      ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
                                      /*IsReadOnly=*/false, /*IsZero=*/false);
       break;
     case OpenACCClauseKind::Collapse: {
@@ -1278,7 +1280,7 @@ ExprResult Parser::ParseOpenACCBindClauseArgument() {
 /// - an array element
 /// - a member of a composite variable
 /// - a common block name between slashes (fortran only)
-Parser::OpenACCVarParseResult Parser::ParseOpenACCVar() {
+Parser::OpenACCVarParseResult Parser::ParseOpenACCVar(OpenACCClauseKind CK) {
   OpenACCArraySectionRAII ArraySections(*this);
 
   ExprResult Res = ParseAssignmentExpression();
@@ -1289,15 +1291,15 @@ Parser::OpenACCVarParseResult Parser::ParseOpenACCVar() {
   if (!Res.isUsable())
     return {Res, OpenACCParseCanContinue::Can};
 
-  Res = getActions().OpenACC().ActOnVar(Res.get());
+  Res = getActions().OpenACC().ActOnVar(CK, Res.get());
 
   return {Res, OpenACCParseCanContinue::Can};
 }
 
-llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList() {
+llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList(OpenACCClauseKind CK) {
   llvm::SmallVector<Expr *> Vars;
 
-  auto [Res, CanContinue] = ParseOpenACCVar();
+  auto [Res, CanContinue] = ParseOpenACCVar(CK);
   if (Res.isUsable()) {
     Vars.push_back(Res.get());
   } else if (CanContinue == OpenACCParseCanContinue::Cannot) {
@@ -1308,7 +1310,7 @@ llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList() {
   while (!getCurToken().isOneOf(tok::r_paren, tok::annot_pragma_openacc_end)) {
     ExpectAndConsume(tok::comma);
 
-    auto [Res, CanContinue] = ParseOpenACCVar();
+    auto [Res, CanContinue] = ParseOpenACCVar(CK);
 
     if (Res.isUsable()) {
       Vars.push_back(Res.get());
@@ -1342,7 +1344,7 @@ void Parser::ParseOpenACCCacheVarList() {
 
   // ParseOpenACCVarList should leave us before a r-paren, so no need to skip
   // anything here.
-  ParseOpenACCVarList();
+  ParseOpenACCVarList(OpenACCClauseKind::Invalid);
 }
 
 Parser::OpenACCDirectiveParseInfo Parser::ParseOpenACCDirective() {
diff --git a/clang/lib/Sema/SemaOpenACC.cpp b/clang/lib/Sema/SemaOpenACC.cpp
index f174b2fa63c6a..49847bd997161 100644
--- a/clang/lib/Sema/SemaOpenACC.cpp
+++ b/clang/lib/Sema/SemaOpenACC.cpp
@@ -233,6 +233,19 @@ bool doesClauseApplyToDirective(OpenACCDirectiveKind DirectiveKind,
       return false;
     }
 
+  case OpenACCClauseKind::Reduction:
+    switch (DirectiveKind) {
+    case OpenACCDirectiveKind::Parallel:
+    case OpenACCDirectiveKind::Serial:
+    case OpenACCDirectiveKind::Loop:
+    case OpenACCDirectiveKind::ParallelLoop:
+    case OpenACCDirectiveKind::SerialLoop:
+    case OpenACCDirectiveKind::KernelsLoop:
+      return true;
+    default:
+      return false;
+    }
+
   default:
     // Do nothing so we can go to the 'unimplemented' diagnostic instead.
     return true;
@@ -281,7 +294,6 @@ bool checkValidAfterDeviceType(
     return true;
   }
 }
-
 } // namespace
 
 SemaOpenACC::SemaOpenACC(Sema &S) : SemaBase(S) {}
@@ -426,6 +438,24 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClauses,
           << /*NoArgs=*/1 << Clause.getDirectiveKind() << MaxArgs
           << Clause.getIntExprs().size();
 
+    // OpenACC 3.3 Section 2.5.4:
+    // A reduction clause may not appear on a parallel construct with a
+    // num_gangs clause that has more than one argument.
+    if (Clause.getDirectiveKind() == OpenACCDirectiveKind::Parallel &&
+        Clause.getIntExprs().size() > 1) {
+      auto *Parallel =
+          llvm::find_if(ExistingClauses, [](const OpenACCClause *C) {
+            return C->getClauseKind() == OpenACCClauseKind::Reduction;
+          });
+
+      if (Parallel != ExistingClauses.end()) {
+        Diag(Clause.getBeginLoc(), diag::err_acc_reduction_num_gangs_conflict)
+            << Clause.getIntExprs().size();
+        Diag((*Parallel)->getBeginLoc(), diag::note_acc_previous_clause_here);
+        return nullptr;
+      }
+    }
+
     // Create the AST node for the clause even if the number of expressions is
     // incorrect.
     return OpenACCNumGangsClause::Create(
@@ -706,6 +736,48 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClauses,
         Clause.getLParenLoc(), Clause.getDeviceTypeArchitectures(),
         Clause.getEndLoc());
   }
+  case OpenACCClauseKind::Reduction: {
+    // Restrictions only properly implemented on 'compute' constructs, and
+    // 'compute' constructs are the only construct that can do anything with
+    // this yet, so skip/treat as unimplemented in this case.
+    if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()))
+      break;
+
+    // OpenACC 3.3 Section 2.5.4:
+    // A reduction clause may not appear on a parallel construct with a
+...
[truncated]

Copy link
Member

@alexey-bataev alexey-bataev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ast printing tests?

Comment on lines 447 to 449
llvm::find_if(ExistingClauses, [](const OpenACCClause *C) {
return C->getClauseKind() == OpenACCClauseKind::Reduction;
});
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

llvm::find_if(ExitingClauses, IsaPred<OpenACCReductionClause>)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is awesome, thank you for showing me this!

Comment on lines 751 to 753
llvm::make_filter_range(ExistingClauses, [](const OpenACCClause *C) {
return C->getClauseKind() == OpenACCClauseKind::NumGangs;
});
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

llvm::make_filter_range(ExistingClauses, IsaPred<OpenACCNumGangsClause>)?

<< /*incomplete*/ 1 << VarExpr->getType();
return ExprError();
}
if (isa<CXXRecordDecl>(RD) && !cast<CXXRecordDecl>(RD)->isAggregate()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD); CXXRD && !CXXRD->?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, better.

@@ -130,5 +130,33 @@ void foo() {
//CHECK: #pragma acc parallel device_type(SomeStructImpl)
#pragma acc parallel device_type (SomeStructImpl)
while(true);

//CHECK: #pragma acc parallel reduction(+: iPtr)
#pragma acc parallel reduction(+: iPtr)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AST print tests are here

Copy link
Member

@alexey-bataev alexey-bataev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG

@erichkeane erichkeane merged commit a15b685 into llvm:main May 21, 2024
5 of 6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules clang Clang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants