Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][Python] Python binding support for AffineIfOp #108323

Merged
merged 1 commit into from
Nov 13, 2024

Conversation

kaitingwang
Copy link
Contributor

Fix the AffineIfOp's default builder such that it takes in an IntegerSetAttr. AffineIfOp has skipDefaultBuilders=1 which effectively skips the creation of the default AffineIfOp::builder on the C++ side. (AffineIfOp has two custom OpBuilder defined in the extraClassDeclaration.) However, on the python side, _affine_ops_gen.py shows that the default builder is being created, but it does not accept IntegerSet and thus is useless. This fix at line 411 makes the default python AffineIfOp builder take in an IntegerSet input and does not impact the C++ side of things.

@llvmbot
Copy link

llvmbot commented Sep 12, 2024

@llvm/pr-subscribers-mlir-affine
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-ods

@llvm/pr-subscribers-mlir

Author: Amy Wang (kaitingwang)

Changes

Fix the AffineIfOp's default builder such that it takes in an IntegerSetAttr. AffineIfOp has skipDefaultBuilders=1 which effectively skips the creation of the default AffineIfOp::builder on the C++ side. (AffineIfOp has two custom OpBuilder defined in the extraClassDeclaration.) However, on the python side, _affine_ops_gen.py shows that the default builder is being created, but it does not accept IntegerSet and thus is useless. This fix at line 411 makes the default python AffineIfOp builder take in an IntegerSet input and does not impact the C++ side of things.


Full diff: https://github.com/llvm/llvm-project/pull/108323.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+2-1)
  • (modified) mlir/include/mlir/IR/CommonAttrConstraints.td (+9)
  • (modified) mlir/python/mlir/dialects/affine.py (+58)
  • (modified) mlir/test/python/dialects/affine.py (+70)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index dbec741cf1b1f3..a7fbebed9aa586 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -407,7 +407,8 @@ def AffineIfOp : Affine_Op<"if",
     }
     ```
   }];
-  let arguments = (ins Variadic<AnyType>);
+  let arguments = (ins Variadic<AnyType>,
+                       IntegerSetAttr:$condition);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
 
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 853fb318c76e71..9fe90c622147ae 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -557,6 +557,15 @@ CPred<"::llvm::isa<::mlir::AffineMapAttr>($_self)">, "AffineMap attribute"> {
   let constBuilderCall = "::mlir::AffineMapAttr::get($0)";
 }
 
+// Attributes containing integer sets.
+def IntegerSetAttr : Attr<
+CPred<"::llvm::isa<::mlir::IntegerSetAttr>($_self)">, "IntegerSet attribute"> {
+  let storageType = [{::mlir::IntegerSetAttr }];
+  let returnType = [{ ::mlir::IntegerSet }];
+  let valueType = NoneType;
+  let constBuilderCall = "::mlir::IntegerSetAttr::get($0)";
+}
+
 // Base class for array attributes.
 class ArrayAttrBase<Pred condition, string summary> : Attr<condition, summary> {
   let storageType = [{ ::mlir::ArrayAttr }];
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 913cea61105cee..dcd4e7d2c232c2 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -156,3 +156,61 @@ def for_(
             yield iv, iter_args[0]
         else:
             yield iv
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AffineIfOp(AffineIfOp):
+    """Specialization for the Affine if op class."""
+
+    def __init__(
+        self,
+        cond: IntegerSet,
+        results_: Optional[Type] = None,
+        *,
+        cond_operands: Optional[_VariadicResultValueT] = None,
+        hasElse: bool = False,
+        loc=None,
+        ip=None,
+    ):
+        """Creates an Affine `if` operation.
+
+        - `cond` is the integer set used to determine which regions of code
+          will be executed.
+        - `results` are the list of types to be yielded by the operand.
+        - `cond_operands` is the list of arguments to substitute the
+          dimensions, then symbols in the `cond` integer set expression to
+          determine whether they are in the set.
+        - `hasElse` determines whether the affine if operation has the else
+          branch.
+        """
+        if results_ is None:
+            results_ = []
+        if cond_operands is None:
+            cond_operands = []
+
+        if not (actual_n_inputs := len(cond_operands)) == (
+            exp_n_inputs := cond.n_inputs
+        ):
+            raise ValueError(
+                f"expected {exp_n_inputs} condition operands, got {actual_n_inputs}"
+            )
+
+        operands = []
+        operands.extend(cond_operands)
+        results = []
+        results.extend(results_)
+
+        super().__init__(results, cond_operands, cond)
+        self.regions[0].blocks.append(*[])
+        if hasElse:
+            self.regions[1].blocks.append(*[])
+
+    @property
+    def then_block(self):
+        """Returns the then block of the if operation."""
+        return self.regions[0].blocks[0]
+
+    @property
+    def else_block(self):
+        """Returns the else block of the if operation."""
+        return self.regions[1].blocks[0]
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index 6f39e1348fcd57..24fe4f398f4e87 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -265,3 +265,73 @@ def range_loop_8(lb, ub, memref_v):
             add = arith.addi(i, i)
             memref.store(add, it, [i])
             affine.yield_([it])
+
+
+# CHECK-LABEL: TEST: testAffineIfWithoutElse
+@constructAndPrintInModule
+def testAffineIfWithoutElse():
+    index = IndexType.get()
+    i32 = IntegerType.get_signless(32)
+    d0 = AffineDimExpr.get(0)
+
+    # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
+    cond = IntegerSet.get(1, 0, [d0 - 5], [False])
+
+    # CHECK-LABEL:  func.func @simple_affine_if(
+    # CHECK-SAME:                        %[[VAL_0:.*]]: index) {
+    # CHECK:            affine.if #[[$SET0]](%[[VAL_0]]) {
+    # CHECK:                %[[VAL_1:.*]] = arith.constant 1 : i32
+    # CHECK:                %[[VAL_2:.*]] = arith.addi %[[VAL_1]], %[[VAL_1]] : i32
+    # CHECK:            }
+    # CHECK:            return
+    # CHECK:        }
+    @func.FuncOp.from_py_func(index)
+    def simple_affine_if(cond_operands):
+        if_op = affine.AffineIfOp(cond, cond_operands=[cond_operands])
+        with InsertionPoint(if_op.then_block):
+            one = arith.ConstantOp(i32, 1)
+            add = arith.AddIOp(one, one)
+            affine.AffineYieldOp([])
+        return
+
+
+# CHECK-LABEL: TEST: testAffineIfWithElse
+@constructAndPrintInModule
+def testAffineIfWithElse():
+    index = IndexType.get()
+    i32 = IntegerType.get_signless(32)
+    d0 = AffineDimExpr.get(0)
+
+    # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
+    cond = IntegerSet.get(1, 0, [d0 - 5], [False])
+
+    # CHECK-LABEL:  func.func @simple_affine_if_else(
+    # CHECK-SAME:                                    %[[VAL_0:.*]]: index) {
+    # CHECK:            %[[VAL_IF:.*]]:2 = affine.if #[[$SET0]](%[[VAL_0]]) -> (i32, i32) {
+    # CHECK:                %[[VAL_XT:.*]] = arith.constant 0 : i32
+    # CHECK:                %[[VAL_YT:.*]] = arith.constant 1 : i32
+    # CHECK:                affine.yield %[[VAL_XT]], %[[VAL_YT]] : i32, i32
+    # CHECK:            } else {
+    # CHECK:                %[[VAL_XF:.*]] = arith.constant 2 : i32
+    # CHECK:                %[[VAL_YF:.*]] = arith.constant 3 : i32
+    # CHECK:                affine.yield %[[VAL_XF]], %[[VAL_YF]] : i32, i32
+    # CHECK:            }
+    # CHECK:            %[[VAL_ADD:.*]] = arith.addi %[[VAL_IF]]#0, %[[VAL_IF]]#1 : i32
+    # CHECK:            return
+    # CHECK:        }
+
+    @func.FuncOp.from_py_func(index)
+    def simple_affine_if_else(cond_operands):
+        if_op = affine.AffineIfOp(
+            cond, [i32, i32], cond_operands=[cond_operands], hasElse=True
+        )
+        with InsertionPoint(if_op.then_block):
+            x_true = arith.ConstantOp(i32, 0)
+            y_true = arith.ConstantOp(i32, 1)
+            affine.AffineYieldOp([x_true, y_true])
+        with InsertionPoint(if_op.else_block):
+            x_false = arith.ConstantOp(i32, 2)
+            y_false = arith.ConstantOp(i32, 3)
+            affine.AffineYieldOp([x_false, y_false])
+        add = arith.AddIOp(if_op.results[0], if_op.results[1])
+        return

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with comments addressed

mlir/python/mlir/dialects/affine.py Outdated Show resolved Hide resolved
mlir/python/mlir/dialects/affine.py Outdated Show resolved Hide resolved
mlir/python/mlir/dialects/affine.py Outdated Show resolved Hide resolved
mlir/python/mlir/dialects/affine.py Outdated Show resolved Hide resolved
mlir/python/mlir/dialects/affine.py Outdated Show resolved Hide resolved
mlir/python/mlir/dialects/affine.py Show resolved Hide resolved
@kaitingwang
Copy link
Contributor Author

@ftynse Much appreciated the reviewed! Nice meeting you at the llvm developer conference last month! In summary, I addressed all the points you raised: removed the walrus assignment for readability, renamed the has_else variable, checked for empty region[1].

@kaitingwang kaitingwang merged commit d50fbe4 into llvm:main Nov 13, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:affine mlir:core MLIR Core Infrastructure mlir:ods mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants