Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][openacc] Update verifier to catch missing device type attribute #111586

Merged
merged 1 commit into from
Oct 9, 2024

Conversation

clementval
Copy link
Contributor

Operands with device_type support need the corresponding attribute but this was not catches in the verifier if it was missing. The custom parser usually constructs it but creating the op from python could lead to a segfault in the printer. This patch updates the verifier so we catch this early on.

Operands with device_type support need the corresponding attribute but this
was not catches in the verifier if it was missing. The custom parser usually
constructs it but creating the op from python could lead to a segfault in the
printer. This patch updates the verifier so we catch this early on.
@llvmbot
Copy link
Member

llvmbot commented Oct 8, 2024

@llvm/pr-subscribers-openacc
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-openacc

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Operands with device_type support need the corresponding attribute but this was not catches in the verifier if it was missing. The custom parser usually constructs it but creating the op from python could lead to a segfault in the printer. This patch updates the verifier so we catch this early on.


Full diff: https://github.com/llvm/llvm-project/pull/111586.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+14-11)
  • (modified) mlir/test/Dialect/OpenACC/invalid.mlir (+7)
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 877bd226a03528..919a0853fb6049 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -759,20 +759,23 @@ static LogicalResult verifyDeviceTypeAndSegmentCountMatch(
     Op op, OperandRange operands, DenseI32ArrayAttr segments,
     ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
   std::size_t numOperandsInSegments = 0;
-
-  if (!segments)
-    return success();
-
-  for (auto segCount : segments.asArrayRef()) {
-    if (maxInSegment != 0 && segCount > maxInSegment)
-      return op.emitOpError() << keyword << " expects a maximum of "
-                              << maxInSegment << " values per segment";
-    numOperandsInSegments += segCount;
+  std::size_t nbOfSegments = 0;
+
+  if (segments) {
+    for (auto segCount : segments.asArrayRef()) {
+      if (maxInSegment != 0 && segCount > maxInSegment)
+        return op.emitOpError() << keyword << " expects a maximum of "
+                                << maxInSegment << " values per segment";
+      numOperandsInSegments += segCount;
+      ++nbOfSegments;
+    }
   }
-  if (numOperandsInSegments != operands.size())
+
+  if ((numOperandsInSegments != operands.size()) ||
+      (!deviceTypes && !operands.empty()))
     return op.emitOpError()
            << keyword << " operand count does not match count in segments";
-  if (deviceTypes.getValue().size() != (size_t)segments.size())
+  if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
     return op.emitOpError()
            << keyword << " segment count does not match device_type count";
   return success();
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index ec5430420524ce..96edb585ae21a2 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -507,6 +507,13 @@ acc.parallel num_gangs({%i64value: i64, %i64value : i64, %i64value : i64, %i64va
 
 // -----
 
+%0 = "arith.constant"() <{value = 1 : i64}> : () -> i64
+// expected-error@+1 {{num_gangs operand count does not match count in segments}}
+"acc.parallel"(%0) <{numGangsSegments = array<i32: 1>, operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>}> ({
+}) : (i64) -> ()
+
+// -----
+
 %i64value = arith.constant 1 : i64
 acc.parallel {
 // expected-error@+1 {{'acc.set' op cannot be nested in a compute operation}}

Copy link
Contributor

@vzakhari vzakhari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@clementval clementval merged commit 65bd5ed into llvm:main Oct 9, 2024
13 checks passed
@clementval clementval deleted the acc_num_gang_verifier branch October 9, 2024 20:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants