Skip to content

Conversation

rolfmorel
Copy link
Contributor

@rolfmorel rolfmorel commented Sep 25, 2025

This op enables expressing uncertainty regarding what should be happening at particular places in transform-dialect schedules. In particular, it enables representing a choice among alternative regions. This choice is resolved through providing a selected_region argument. When this argument is provided, the semantics are such that it is valid to rewrite the op through substituting in the selected region -- with the op's interpreted semantics corresponding to exactly this.

This op represents another piece of the puzzle w.r.t. a toolkit for expressing autotuning problems with the transform dialect. Note that this goes beyond tuning knobs on transforms, going further by making it tunable which (sequences of) transforms are to be applied.

@llvmbot
Copy link
Member

llvmbot commented Sep 25, 2025

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

Changes

This op enables expressing uncertainty regarding what should be happening at particular places in transform-dialect schedules. In particular, it enables representing a choice among alternative regions. This choice is resolved through providing a selected_region argument. When this argument is provided, the semantics are such that it is valid to rewrite the op through substituting in the selected region -- with its interpreted semantics corresponding to exactly this.

This op represents another piece of the puzzle w.r.t. a toolkit for expressing autotuning problems with the transform dialect. Note that this goes beyond tuning knobs on transforms -- going further by making it tunable which (sequences of) transforms are to be applied.


Patch is 30.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160724.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h (+1)
  • (modified) mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td (+54)
  • (modified) mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp (+185)
  • (modified) mlir/python/mlir/dialects/transform/tune.py (+63-3)
  • (modified) mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir (+85)
  • (modified) mlir/test/Dialect/Transform/test-tune-extension.mlir (+99)
  • (modified) mlir/test/python/dialects/transform_tune_ext.py (+73-14)
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
index 74e1d28ffac82..ba11259790676 100644
--- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
 #define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
 
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
index d68d451afac40..d095659fc4838 100644
--- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
@@ -11,10 +11,15 @@
 
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinAttributes.td"
 include "mlir/IR/CommonAttrConstraints.td"
 
+//===----------------------------------------------------------------------===//
+// KnobOp
+//===----------------------------------------------------------------------===//
+
 def KnobOp : Op<Transform_Dialect, "tune.knob", [
   DeclareOpInterfaceMethods<TransformOpInterface>,
   DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -52,4 +57,53 @@ def KnobOp : Op<Transform_Dialect, "tune.knob", [
       "`<` $name `>` (`=` $selected^ `from`)? `options` `=` $options attr-dict `->` type(results)";
 }
 
+//===----------------------------------------------------------------------===//
+// AlternativesOp
+//===----------------------------------------------------------------------===//
+
+def AlternativesOp : Op<Transform_Dialect, "tune.alternatives", [
+  DeclareOpInterfaceMethods<RegionBranchOpInterface,
+        ["getEntrySuccessorOperands", "getSuccessorRegions",
+         "getRegionInvocationBounds"]>,
+  DeclareOpInterfaceMethods<TransformOpInterface>,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">,
+  NoRegionArguments
+]> {
+  let summary = "Represents a choice among its regions, i.e. sub-schedules";
+
+  let description = [{
+    This op represents a choice over which of its regions is to be used.
+
+    When `selected_region` is provided, the semantics are that this op is to be
+    substituted for by the selected region, meaning the region's results become
+    the results of this op. Without a provided `selected_region`, the semantics
+    are that this non-deterministic choice is yet to be resolved -- which in
+    terms of the op's interpreted semantics is a failure.
+
+    The `selected_region` argument is either an `IntegerAttr` or a param holding
+    an `IntegerAttr`, which should provide a valid zero-based index with respect
+    to the number of alternatives, i.e. regions.
+  }];
+  let cppNamespace = [{ mlir::transform::tune }];
+
+  let arguments = (ins Builtin_StringAttr:$name,
+                       OptionalAttr<APIntAttr>:$selected_region_attr,
+                       Optional<TransformParamTypeInterface>:$selected_region_param);
+  let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
+  let regions = (region VariadicRegion<SizedRegion<1>>:$alternatives);
+
+  let assemblyFormat = [{
+    `<` $name `>`
+    (`selected_region` `=` custom<AlternativesOpSelectedRegion>(
+        $selected_region_attr, $selected_region_param)^)?
+    attr-dict-with-keyword
+    (`:` type($selected_region_param)^)?
+    (`->` type($results)^)?
+    regions
+  }];
+
+  let hasVerifier = 1;
+}
+
 #endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
index 842e880ca9150..dad63586bc8d4 100644
--- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
@@ -6,13 +6,25 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
 #include "llvm/Support/Debug.h"
+#include <cstddef>
 
 #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
 
 using namespace mlir;
 
+static ParseResult parseAlternativesOpSelectedRegion(
+    OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
+    std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam);
+
+static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
+                                              Operation *op,
+                                              IntegerAttr selectedRegionAttr,
+                                              Value selectedRegionParam);
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
 
@@ -57,3 +69,176 @@ LogicalResult transform::tune::KnobOp::verify() {
 
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// AlternativesOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseAlternativesOpSelectedRegion(
+    OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
+    std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam) {
+  size_t selectedRegionIdx;
+  OptionalParseResult attrParseRes =
+      parser.parseOptionalInteger(selectedRegionIdx);
+  if (attrParseRes.has_value()) {
+    if (failed(*attrParseRes))
+      return failure();
+
+    selectedRegionAttr = parser.getBuilder().getIndexAttr(selectedRegionIdx);
+    return success();
+  }
+
+  OpAsmParser::UnresolvedOperand param;
+  auto paramParseRes = parser.parseOptionalOperand(param);
+  if (paramParseRes.has_value()) {
+    if (failed(*paramParseRes))
+      return failure();
+
+    selectedRegionParam = param;
+    return success();
+  }
+
+  return parser.emitError(parser.getCurrentLocation())
+         << "expected either an integer attribute or a transform.param operand";
+}
+
+static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
+                                              Operation *op,
+                                              IntegerAttr selectedRegionAttr,
+                                              Value selectedRegionParam) {
+  if (selectedRegionAttr)
+    printer << selectedRegionAttr.getValue();
+  if (selectedRegionParam)
+    printer << selectedRegionParam;
+}
+
+OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
+    RegionBranchPoint point) {
+  // No operands will be forwarded to the region(s).
+  return getOperands().slice(0, 0);
+}
+
+void transform::tune::AlternativesOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (point.isParent())
+    if (auto selectedRegionIdx = getSelectedRegionAttr())
+      regions.emplace_back(
+          &getAlternatives()[selectedRegionIdx->getSExtValue()],
+          Block::BlockArgListType());
+    else
+      for (Region &alternative : getAlternatives())
+        regions.emplace_back(&alternative, Block::BlockArgListType());
+  else
+    regions.emplace_back(getOperation()->getResults());
+}
+
+void transform::tune::AlternativesOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
+  (void)operands;
+  bounds.reserve(getNumRegions());
+
+  if (auto selectedRegionIdx = getSelectedRegionAttr()) {
+    bounds.resize(getNumRegions(), InvocationBounds(0, 0));
+    bounds[selectedRegionIdx->getSExtValue()] = InvocationBounds(1, 1);
+  } else {
+    bounds.resize(getNumRegions(), InvocationBounds(0, 1));
+  }
+}
+
+void transform::tune::AlternativesOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getSelectedRegionParamMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  // TODO: should effects from regions be forwarded?
+}
+
+DiagnosedSilenceableFailure
+transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter,
+                                       transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  std::optional<size_t> selectedRegionIdx;
+
+  if (auto selectedRegionAttr = getSelectedRegionAttr())
+    selectedRegionIdx = selectedRegionAttr->getSExtValue();
+
+  if (Value selectedRegionParam = getSelectedRegionParam()) {
+    ArrayRef<Attribute> associatedAttrs = state.getParams(selectedRegionParam);
+    IntegerAttr selectedRegionAttr;
+    if (associatedAttrs.size() != 1 ||
+        !(selectedRegionAttr = dyn_cast<IntegerAttr>(associatedAttrs[0])))
+      return emitDefiniteFailure()
+             << "param should hold exactly one integer attribute, got: "
+             << associatedAttrs[0];
+    selectedRegionIdx = selectedRegionAttr.getValue().getSExtValue();
+  }
+
+  if (!selectedRegionIdx)
+    return emitDefiniteFailure() << "non-deterministic choice " << getName()
+                                 << " is only resolved through providing a "
+                                    "`selected_region` attr/param";
+
+  if (*selectedRegionIdx < 0 || *selectedRegionIdx >= getNumRegions())
+    return emitDefiniteFailure()
+           << "'selected_region' attribute/param specifies region at index "
+           << *selectedRegionIdx << " while op has only " << getNumRegions()
+           << " regions";
+
+  Region &selectedRegion = getRegion(*selectedRegionIdx);
+  auto scope = state.make_region_scope(selectedRegion);
+  Block &block = selectedRegion.front();
+  // Apply the region's ops one by one.
+  for (Operation &transform : block.without_terminator()) {
+    DiagnosedSilenceableFailure result =
+        state.applyTransform(cast<transform::TransformOpInterface>(transform));
+    if (result.isDefiniteFailure())
+      return result;
+
+    if (result.isSilenceableFailure()) {
+      for (const auto &res : getResults())
+        results.set(res, {});
+      return result;
+    }
+  }
+  // Forward the operation mapping for values yielded from the region to the
+  // values produced by the alternatives op.
+  transform::detail::forwardTerminatorOperands(&block, state, results);
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::tune::AlternativesOp::verify() {
+  for (auto *region : getRegions()) {
+    auto yieldTerminator =
+        llvm::dyn_cast_if_present<transform::YieldOp>(region->front().back());
+    if (!yieldTerminator)
+      return emitOpError() << "expected '"
+                           << transform::YieldOp::getOperationName()
+                           << "' as terminator";
+
+    if (yieldTerminator->getNumOperands() != getNumResults())
+      return yieldTerminator.emitOpError()
+             << "expected terminator to have as many operands as the parent op "
+                "has results";
+
+    for (auto [i, operandType, resultType] : llvm::zip_equal(
+             llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
+             yieldTerminator->getOperands().getType(), getResultTypes())) {
+      if (operandType == resultType)
+        continue;
+      return yieldTerminator.emitOpError()
+             << "the type of the terminator operand #" << i
+             << " must match the type of the corresponding parent op result ("
+             << operandType << " vs " << resultType << ")";
+    }
+  }
+
+  if (auto selectedRegionAttr = getSelectedRegionAttr()) {
+    size_t regionIdx = selectedRegionAttr->getSExtValue();
+    if (regionIdx < 0 || regionIdx >= getNumRegions())
+      return emitOpError()
+             << "'selected_region' attribute specifies region at index "
+             << regionIdx << " while op has only " << getNumRegions()
+             << " regions";
+  }
+
+  return success();
+}
diff --git a/mlir/python/mlir/dialects/transform/tune.py b/mlir/python/mlir/dialects/transform/tune.py
index f63f88a382422..b3bfa8015c4d8 100644
--- a/mlir/python/mlir/dialects/transform/tune.py
+++ b/mlir/python/mlir/dialects/transform/tune.py
@@ -6,6 +6,9 @@
 
 from ...ir import (
     Type,
+    Value,
+    Operation,
+    OpView,
     Attribute,
     ArrayAttr,
     StringAttr,
@@ -19,7 +22,10 @@
 from .._transform_tune_extension_ops_gen import _Dialect
 
 try:
-    from .._ods_common import _cext as _ods_cext
+    from .._ods_common import (
+        get_op_result_or_value as _get_op_result_or_value,
+        _cext as _ods_cext,
+    )
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
@@ -36,7 +42,7 @@ def __init__(
             ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
         ],
         *,
-        selected: Optional[Attribute] = None,
+        selected: Optional[Union[Attribute, bool, int, float, str]] = None,
         loc=None,
         ip=None,
     ):
@@ -75,8 +81,62 @@ def knob(
         ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
     ],
     *,
-    selected: Optional[Attribute] = None,
+    selected: Optional[Union[Attribute, bool, int, float, str]] = None,
     loc=None,
     ip=None,
 ):
     return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AlternativesOp(AlternativesOp):
+    def __init__(
+        self,
+        results: Sequence[Type],
+        name: Union[StringAttr, str],
+        num_alternatives: int,
+        *,
+        selected_region: Optional[
+            Union[int, IntegerAttr, Value, Operation, OpView]
+        ] = None,
+        loc=None,
+        ip=None,
+    ):
+        if isinstance(name, str):
+            name = StringAttr.get(name)
+
+        selected_region_attr = selected_region_param = None
+        if isinstance(selected_region, IntegerAttr):
+            selected_region_attr = selected_region
+        elif isinstance(selected_region, int):
+            selected_region_attr = IntegerAttr.get(
+                IntegerType.get_signless(32), selected_region
+            )
+        elif isinstance(selected_region, (Value, Operation, OpView)):
+            selected_region_param = _get_op_result_or_value(selected_region)
+
+        super().__init__(
+            results,
+            name,
+            num_alternatives,
+            selected_region_attr=selected_region_attr,
+            selected_region_param=selected_region_param,
+            loc=loc,
+            ip=ip,
+        )
+        for region in self.regions:
+            region.blocks.append()
+
+
+def alternatives(
+    results: Sequence[Type],
+    name: Union[StringAttr, str],
+    num_alternatives: int,
+    *,
+    selected_region: Optional[Union[int, IntegerAttr, Value, Operation, OpView]] = None,
+    loc=None,
+    ip=None,
+):
+    return AlternativesOp(
+        results, name, num_alternatives, selected_region=selected_region, loc=loc, ip=ip
+    )
diff --git a/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
index 2e5f433abeb71..efc3890288456 100644
--- a/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
@@ -19,3 +19,88 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    // expected-error@below {{'selected_region' attribute specifies region at index 2 while op has only 2 regions}}
+    transform.tune.alternatives<"bifurcation"> selected_region = 2 {
+      transform.yield
+    }, {
+      transform.yield
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %singleton_of_c0 = transform.param.constant [0] -> !transform.any_param
+    // expected-error@below {{param should hold exactly one integer attribute, got: [0]}}
+    transform.tune.alternatives<"bifurcation"> selected_region = %singleton_of_c0 : !transform.any_param {
+      transform.yield
+    }, {
+      transform.yield
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %c0 = transform.param.constant 0 -> !transform.any_param
+    %c1 = transform.param.constant 1 -> !transform.any_param
+    %c0_and_c1 = transform.merge_handles %c0, %c1 : !transform.any_param
+    // expected-error@below {{param should hold exactly one integer attribute}}
+    transform.tune.alternatives<"bifurcation"> selected_region = %c0_and_c1 : !transform.any_param {
+      transform.yield
+    }, {
+      transform.yield
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %c2 = transform.param.constant 2 -> !transform.any_param
+    // expected-error@below {{'selected_region' attribute/param specifies region at index 2 while op has only 2 regions}}
+    transform.tune.alternatives<"bifurcation"> selected_region = %c2 : !transform.any_param {
+      transform.yield
+    }, {
+      transform.yield
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    // expected-error@below {{non-deterministic choice "bifurcation" is only resolved through providing a `selected_region` attr/param}}
+    transform.tune.alternatives<"bifurcation"> {
+      transform.yield
+    }, {
+      transform.yield
+    }
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-tune-extension.mlir b/mlir/test/Dialect/Transform/test-tune-extension.mlir
index 0a253c6d5f837..80b7525136b33 100644
--- a/mlir/test/Dialect/Transform/test-tune-extension.mlir
+++ b/mlir/test/Dialect/Transform/test-tune-extension.mlir
@@ -59,3 +59,102 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+
+// -----
+
+// CHECK-LABEL: schedule_with_two_independent_choices_already_made
+func.func @schedule_with_two_independent_choices_already_made(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32> {
+//      CHECK-NOT: scf.forall
+//      CHECK:     scf.for
+//      CHECK-NOT:   scf.for
+//      CHECK:       scf.forall
+//      CHECK-NOT:   scf.for
+//      CHECK:         tensor.extract_slice
+//      CHECK:         tensor.extract_slice
+//      CHECK:         tensor.extract_slice
+//      CHECK:         linalg.matmul
+//      CHECK:         scf.forall.in_parallel
+//      CHECK:           tensor.parallel_insert_slice
+//      CHECK:       tensor.insert_slice
+//      CHECK:       scf.yield
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
+    %tiled_matmul = transform.tune.alternatives<"outer_par_or_seq_tiling"> selected_region = 0 -> !transform.any_op
+    { // First alternative/region, with index = 0
+...
[truncated]

@github-actions
Copy link

github-actions bot commented Sep 25, 2025

✅ With the latest revision this PR passed the Python code formatter.

Copy link
Contributor

@fschlimb fschlimb Sep 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this need to parse an optional integer? Shouldn't the assemblyFormat already deal with optionality?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The assembly format makes it optional whether there is a selected_region = ?? parse. This function is to resolve the ?? part.

I believe I need to ask the parser to do an "optional" parse of the integer as it should be allowed to fail in case the user instead provides an SSA value (i.e. on non-optional parse, seeing the "%" would cause non-catchable failure.).

If there's a more elegant way of supporting both the int as the ?? and a param SSA-value as the ??, do let me know!

Copy link
Contributor

@fschlimb fschlimb Sep 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I am missing something here.
What does selected_region = ?? mean? There is either a valid value or the whole expression does not need to be there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(`selected_region` `=` custom<AlternativesOpSelectedRegion>($selected_region_attr, $selected_region_param)^)? makes it so that the entire sequence selected_region = ?? is parsed optionally, where ?? is my notation for a wildcard. In this case, ?? only should accept one of two "tokens": either an integer literal or a SSA-value. parseAlternativesOpSelectedRegion is the function that implements this parsing, accepting either the literal or the SSA-value.

Because either an integer or SSA-value is valid, when parseAlternativesOpSelectedRegion first tries to parse the integer it needs to do so "optionally" so that it's not an outright failure when an integer is not encountered. This has little to do with the optionality of the overall clause, i.e. whether it was prefixed with "selected_region =". That is, it is equally valid to use custom<AlternativesOpSelectedRegion>($selected_region_attr, $selected_region_param) not wrapped in an optional clause, i.e. not wrapped in parentheses suffixed with a ?.

Hope that helps.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's the bit of relevant documentation: https://mlir.llvm.org/docs/DefiningDialects/Operations/#optional-groups

The bit about anchors might also be helpful (i.e. an anchor determines whether the group has been committed to, i.e. whether the rest of the parse of the group must succeed for the overall parse to succeed).

Copy link
Contributor

@fschlimb fschlimb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!
You might want to add a test with more than 2 alternatives.

@rolfmorel
Copy link
Contributor Author

Thanks, @fschlimb -- added a test with more alternatives!

If there aren't any further comments, I will merge this soon.

This op enables expressing uncertainty regarding what should be at
particular places in the transform-dialect schedules. In particular, it
enables representing a choice amond alternative region. A choice
resolved through providing a `selected_region` argument. When this
argument is provided, the semantics are such that it is valid to rewrite
the op through substituting in the selected region -- with its
interpreted semantics corresponding to exactly this.

This op represents another piece of the puzzle w.r.t. a toolkit for
expressing autotuning problems with the transform dialect. Note that
this goes beyond tuning knobs _on_ transforms, going further by making
it tunable which (sequences of) transforms are to be applied.
@rolfmorel rolfmorel force-pushed the transform.tune.alternatives branch from d44d8db to 085bdf3 Compare October 1, 2025 13:18
@rolfmorel rolfmorel enabled auto-merge (squash) October 1, 2025 13:19
@rolfmorel rolfmorel merged commit f4d18c0 into llvm:main Oct 1, 2025
8 checks passed
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
…lvm#160724)

This op enables expressing uncertainty regarding what should be
happening at particular places in transform-dialect schedules. In
particular, it enables representing a choice among alternative regions.
This choice is resolved through providing a `selected_region` argument.
When this argument is provided, the semantics are such that it is valid
to rewrite the op through substituting in the selected region -- with
the op's interpreted semantics corresponding to exactly this.

This op represents another piece of the puzzle w.r.t. a toolkit for
expressing autotuning problems with the transform dialect. Note that
this goes beyond tuning knobs _on_ transforms, going further by making
it tunable which (sequences of) transforms are to be applied.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:python MLIR Python bindings mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants