diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h index 74e1d28ffac82..ba11259790676 100644 --- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h +++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H #define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H +#include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td index d68d451afac40..d095659fc4838 100644 --- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td +++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td @@ -11,10 +11,15 @@ include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/CommonAttrConstraints.td" +//===----------------------------------------------------------------------===// +// KnobOp +//===----------------------------------------------------------------------===// + def KnobOp : Op, DeclareOpInterfaceMethods, @@ -52,4 +57,53 @@ def KnobOp : Op` (`=` $selected^ `from`)? `options` `=` $options attr-dict `->` type(results)"; } +//===----------------------------------------------------------------------===// +// AlternativesOp +//===----------------------------------------------------------------------===// + +def AlternativesOp : Op, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">, + NoRegionArguments +]> { + let summary = "Represents a choice among its regions, i.e. sub-schedules"; + + let description = [{ + This op represents a choice over which of its regions is to be used. + + When `selected_region` is provided, the semantics are that this op is to be + substituted for by the selected region, meaning the region's results become + the results of this op. Without a provided `selected_region`, the semantics + are that this non-deterministic choice is yet to be resolved -- which in + terms of the op's interpreted semantics is a failure. + + The `selected_region` argument is either an `IntegerAttr` or a param holding + an `IntegerAttr`, which should provide a valid zero-based index with respect + to the number of alternatives, i.e. regions. + }]; + let cppNamespace = [{ mlir::transform::tune }]; + + let arguments = (ins Builtin_StringAttr:$name, + OptionalAttr:$selected_region_attr, + Optional:$selected_region_param); + let results = (outs Variadic:$results); + let regions = (region VariadicRegion>:$alternatives); + + let assemblyFormat = [{ + `<` $name `>` + (`selected_region` `=` custom( + $selected_region_attr, $selected_region_param)^)? + attr-dict-with-keyword + (`:` type($selected_region_param)^)? + (`->` type($results)^)? + regions + }]; + + let hasVerifier = 1; +} + #endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp index 842e880ca9150..c627158e999ed 100644 --- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp @@ -6,13 +6,24 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h" using namespace mlir; +static ParseResult parseAlternativesOpSelectedRegion( + OpAsmParser &parser, IntegerAttr &selectedRegionAttr, + std::optional &selectedRegionParam); + +static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, + Operation *op, + IntegerAttr selectedRegionAttr, + Value selectedRegionParam); + #define GET_OP_CLASSES #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc" @@ -57,3 +68,176 @@ LogicalResult transform::tune::KnobOp::verify() { return success(); } + +//===----------------------------------------------------------------------===// +// AlternativesOp +//===----------------------------------------------------------------------===// + +static ParseResult parseAlternativesOpSelectedRegion( + OpAsmParser &parser, IntegerAttr &selectedRegionAttr, + std::optional &selectedRegionParam) { + size_t selectedRegionIdx; + OptionalParseResult attrParseRes = + parser.parseOptionalInteger(selectedRegionIdx); + if (attrParseRes.has_value()) { + if (failed(*attrParseRes)) + return failure(); + + selectedRegionAttr = parser.getBuilder().getIndexAttr(selectedRegionIdx); + return success(); + } + + OpAsmParser::UnresolvedOperand param; + auto paramParseRes = parser.parseOptionalOperand(param); + if (paramParseRes.has_value()) { + if (failed(*paramParseRes)) + return failure(); + + selectedRegionParam = param; + return success(); + } + + return parser.emitError(parser.getCurrentLocation()) + << "expected either an integer attribute or a transform.param operand"; +} + +static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, + Operation *op, + IntegerAttr selectedRegionAttr, + Value selectedRegionParam) { + if (selectedRegionAttr) + printer << selectedRegionAttr.getValue(); + if (selectedRegionParam) + printer << selectedRegionParam; +} + +OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands( + RegionBranchPoint point) { + // No operands will be forwarded to the region(s). + return getOperands().slice(0, 0); +} + +void transform::tune::AlternativesOp::getSuccessorRegions( + RegionBranchPoint point, SmallVectorImpl ®ions) { + if (point.isParent()) + if (auto selectedRegionIdx = getSelectedRegionAttr()) + regions.emplace_back( + &getAlternatives()[selectedRegionIdx->getSExtValue()], + Block::BlockArgListType()); + else + for (Region &alternative : getAlternatives()) + regions.emplace_back(&alternative, Block::BlockArgListType()); + else + regions.emplace_back(getOperation()->getResults()); +} + +void transform::tune::AlternativesOp::getRegionInvocationBounds( + ArrayRef operands, SmallVectorImpl &bounds) { + (void)operands; + bounds.reserve(getNumRegions()); + + if (auto selectedRegionIdx = getSelectedRegionAttr()) { + bounds.resize(getNumRegions(), InvocationBounds(0, 0)); + bounds[selectedRegionIdx->getSExtValue()] = InvocationBounds(1, 1); + } else { + bounds.resize(getNumRegions(), InvocationBounds(0, 1)); + } +} + +void transform::tune::AlternativesOp::getEffects( + SmallVectorImpl &effects) { + onlyReadsHandle(getSelectedRegionParamMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + // TODO: should effects from regions be forwarded? +} + +DiagnosedSilenceableFailure +transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + std::optional selectedRegionIdx; + + if (auto selectedRegionAttr = getSelectedRegionAttr()) + selectedRegionIdx = selectedRegionAttr->getSExtValue(); + + if (Value selectedRegionParam = getSelectedRegionParam()) { + ArrayRef associatedAttrs = state.getParams(selectedRegionParam); + IntegerAttr selectedRegionAttr; + if (associatedAttrs.size() != 1 || + !(selectedRegionAttr = dyn_cast(associatedAttrs[0]))) + return emitDefiniteFailure() + << "param should hold exactly one integer attribute, got: " + << associatedAttrs[0]; + selectedRegionIdx = selectedRegionAttr.getValue().getSExtValue(); + } + + if (!selectedRegionIdx) + return emitDefiniteFailure() << "non-deterministic choice " << getName() + << " is only resolved through providing a " + "`selected_region` attr/param"; + + if (*selectedRegionIdx < 0 || *selectedRegionIdx >= getNumRegions()) + return emitDefiniteFailure() + << "'selected_region' attribute/param specifies region at index " + << *selectedRegionIdx << " while op has only " << getNumRegions() + << " regions"; + + Region &selectedRegion = getRegion(*selectedRegionIdx); + auto scope = state.make_region_scope(selectedRegion); + Block &block = selectedRegion.front(); + // Apply the region's ops one by one. + for (Operation &transform : block.without_terminator()) { + DiagnosedSilenceableFailure result = + state.applyTransform(cast(transform)); + if (result.isDefiniteFailure()) + return result; + + if (result.isSilenceableFailure()) { + for (const auto &res : getResults()) + results.set(res, {}); + return result; + } + } + // Forward the operation mapping for values yielded from the region to the + // values produced by the alternatives op. + transform::detail::forwardTerminatorOperands(&block, state, results); + return DiagnosedSilenceableFailure::success(); +} + +LogicalResult transform::tune::AlternativesOp::verify() { + for (auto *region : getRegions()) { + auto yieldTerminator = + llvm::dyn_cast_if_present(region->front().back()); + if (!yieldTerminator) + return emitOpError() << "expected '" + << transform::YieldOp::getOperationName() + << "' as terminator"; + + if (yieldTerminator->getNumOperands() != getNumResults()) + return yieldTerminator.emitOpError() + << "expected terminator to have as many operands as the parent op " + "has results"; + + for (auto [i, operandType, resultType] : llvm::zip_equal( + llvm::seq(0, yieldTerminator->getNumOperands()), + yieldTerminator->getOperands().getType(), getResultTypes())) { + if (operandType == resultType) + continue; + return yieldTerminator.emitOpError() + << "the type of the terminator operand #" << i + << " must match the type of the corresponding parent op result (" + << operandType << " vs " << resultType << ")"; + } + } + + if (auto selectedRegionAttr = getSelectedRegionAttr()) { + size_t regionIdx = selectedRegionAttr->getSExtValue(); + if (regionIdx < 0 || regionIdx >= getNumRegions()) + return emitOpError() + << "'selected_region' attribute specifies region at index " + << regionIdx << " while op has only " << getNumRegions() + << " regions"; + } + + return success(); +} diff --git a/mlir/python/mlir/dialects/transform/tune.py b/mlir/python/mlir/dialects/transform/tune.py index f63f88a382422..b3bfa8015c4d8 100644 --- a/mlir/python/mlir/dialects/transform/tune.py +++ b/mlir/python/mlir/dialects/transform/tune.py @@ -6,6 +6,9 @@ from ...ir import ( Type, + Value, + Operation, + OpView, Attribute, ArrayAttr, StringAttr, @@ -19,7 +22,10 @@ from .._transform_tune_extension_ops_gen import _Dialect try: - from .._ods_common import _cext as _ods_cext + from .._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + _cext as _ods_cext, + ) except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e @@ -36,7 +42,7 @@ def __init__( ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute ], *, - selected: Optional[Attribute] = None, + selected: Optional[Union[Attribute, bool, int, float, str]] = None, loc=None, ip=None, ): @@ -75,8 +81,62 @@ def knob( ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute ], *, - selected: Optional[Attribute] = None, + selected: Optional[Union[Attribute, bool, int, float, str]] = None, loc=None, ip=None, ): return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AlternativesOp(AlternativesOp): + def __init__( + self, + results: Sequence[Type], + name: Union[StringAttr, str], + num_alternatives: int, + *, + selected_region: Optional[ + Union[int, IntegerAttr, Value, Operation, OpView] + ] = None, + loc=None, + ip=None, + ): + if isinstance(name, str): + name = StringAttr.get(name) + + selected_region_attr = selected_region_param = None + if isinstance(selected_region, IntegerAttr): + selected_region_attr = selected_region + elif isinstance(selected_region, int): + selected_region_attr = IntegerAttr.get( + IntegerType.get_signless(32), selected_region + ) + elif isinstance(selected_region, (Value, Operation, OpView)): + selected_region_param = _get_op_result_or_value(selected_region) + + super().__init__( + results, + name, + num_alternatives, + selected_region_attr=selected_region_attr, + selected_region_param=selected_region_param, + loc=loc, + ip=ip, + ) + for region in self.regions: + region.blocks.append() + + +def alternatives( + results: Sequence[Type], + name: Union[StringAttr, str], + num_alternatives: int, + *, + selected_region: Optional[Union[int, IntegerAttr, Value, Operation, OpView]] = None, + loc=None, + ip=None, +): + return AlternativesOp( + results, name, num_alternatives, selected_region=selected_region, loc=loc, ip=ip + ) diff --git a/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir index 2e5f433abeb71..efc3890288456 100644 --- a/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir +++ b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir @@ -19,3 +19,88 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +func.func private @f() + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + // expected-error@below {{'selected_region' attribute specifies region at index 2 while op has only 2 regions}} + transform.tune.alternatives<"bifurcation"> selected_region = 2 { + transform.yield + }, { + transform.yield + } + transform.yield + } +} + +// ----- + +func.func private @f() + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %singleton_of_c0 = transform.param.constant [0] -> !transform.any_param + // expected-error@below {{param should hold exactly one integer attribute, got: [0]}} + transform.tune.alternatives<"bifurcation"> selected_region = %singleton_of_c0 : !transform.any_param { + transform.yield + }, { + transform.yield + } + transform.yield + } +} + +// ----- + +func.func private @f() + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %c0 = transform.param.constant 0 -> !transform.any_param + %c1 = transform.param.constant 1 -> !transform.any_param + %c0_and_c1 = transform.merge_handles %c0, %c1 : !transform.any_param + // expected-error@below {{param should hold exactly one integer attribute}} + transform.tune.alternatives<"bifurcation"> selected_region = %c0_and_c1 : !transform.any_param { + transform.yield + }, { + transform.yield + } + transform.yield + } +} + +// ----- + +func.func private @f() + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %c2 = transform.param.constant 2 -> !transform.any_param + // expected-error@below {{'selected_region' attribute/param specifies region at index 2 while op has only 2 regions}} + transform.tune.alternatives<"bifurcation"> selected_region = %c2 : !transform.any_param { + transform.yield + }, { + transform.yield + } + transform.yield + } +} + +// ----- + +func.func private @f() + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + // expected-error@below {{non-deterministic choice "bifurcation" is only resolved through providing a `selected_region` attr/param}} + transform.tune.alternatives<"bifurcation"> { + transform.yield + }, { + transform.yield + } + transform.yield + } +} diff --git a/mlir/test/Dialect/Transform/test-tune-extension.mlir b/mlir/test/Dialect/Transform/test-tune-extension.mlir index 0a253c6d5f837..5da48a2218ec6 100644 --- a/mlir/test/Dialect/Transform/test-tune-extension.mlir +++ b/mlir/test/Dialect/Transform/test-tune-extension.mlir @@ -59,3 +59,129 @@ module attributes {transform.with_named_sequence} { transform.yield } } + + +// ----- + +// CHECK-LABEL: schedule_with_two_independent_choices_already_made +func.func @schedule_with_two_independent_choices_already_made( + %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) + -> tensor<128x128xf32> { +// CHECK-NOT: scf.forall +// CHECK: scf.for +// CHECK-NOT: scf.for +// CHECK: scf.forall +// CHECK-NOT: scf.for +// CHECK: tensor.extract_slice +// CHECK: tensor.extract_slice +// CHECK: tensor.extract_slice +// CHECK: linalg.matmul +// CHECK: scf.forall.in_parallel +// CHECK: tensor.parallel_insert_slice +// CHECK: tensor.insert_slice +// CHECK: scf.yield + %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) + outs(%arg2: tensor<128x128xf32>) -> tensor<128x128xf32> + return %0 : tensor<128x128xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op + + %tiled_matmul = transform.tune.alternatives<"outer_par_or_seq_tiling"> selected_region = 0 -> !transform.any_op + { // First alternative/region, with index = 0 + %contained_matmul, %loop = transform.structured.tile_using_for %matmul tile_sizes [8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield %contained_matmul : !transform.any_op + }, { // Second alternative/region, with index = 1 + %contained_matmul, %loop = transform.structured.tile_using_forall %matmul tile_sizes [8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield %contained_matmul : !transform.any_op + } + + transform.tune.alternatives<"inner_par_or_seq_tiling"> selected_region = 1 -> !transform.any_op { + %contained_matmul, %loop = transform.structured.tile_using_for %tiled_matmul tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield %contained_matmul : !transform.any_op + }, { + %contained_matmul, %loop = transform.structured.tile_using_forall %tiled_matmul tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield %contained_matmul : !transform.any_op + } + + transform.yield + } +} + +// ----- + +// CHECK-LABEL: subschedule_with_choice_resolved_in_main_schedule +func.func @subschedule_with_choice_resolved_in_main_schedule( + %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) + -> tensor<128x128xf32> { +// CHECK-NOT: scf.for +// CHECK: scf.forall +// CHECK-NOT: scf.forall +// CHECK: scf.for +// CHECK-NOT: scf.forall +// CHECK: tensor.extract_slice +// CHECK: tensor.extract_slice +// CHECK: tensor.extract_slice +// CHECK: linalg.matmul +// CHECK: tensor.insert_slice +// CHECK: scf.yield +// CHECK: scf.forall.in_parallel +// CHECK: tensor.parallel_insert_slice + %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) + outs(%arg2: tensor<128x128xf32>) -> tensor<128x128xf32> + return %0 : tensor<128x128xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @subschedule_with_embedded_choice(%matmul: !transform.any_op {transform.readonly}, + %par_or_seq: !transform.param {transform.readonly}, + %tile_size: !transform.param {transform.readonly}) -> !transform.any_op { + %tiled_matmul = transform.tune.alternatives<"par_or_seq_tiling"> selected_region = %par_or_seq : !transform.param -> !transform.any_op { + %contained_matmul, %loop = transform.structured.tile_using_for %matmul tile_sizes [%tile_size] : (!transform.any_op, !transform.param) -> (!transform.any_op, !transform.any_op) + transform.yield %contained_matmul : !transform.any_op + }, { + %contained_matmul, %loop = transform.structured.tile_using_forall %matmul tile_sizes [%tile_size] : (!transform.any_op, !transform.param) -> (!transform.any_op, !transform.any_op) + transform.yield %contained_matmul : !transform.any_op + } + transform.yield %tiled_matmul : !transform.any_op + } + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %outer_par = transform.param.constant 1 -> !transform.param + %outer_tile_size = transform.param.constant 32 -> !transform.param + %inner_seq = transform.tune.knob<"inner_par_or_seq"> = 0 from options = [0, 1] -> !transform.param + %inner_tile_size = transform.param.constant 8 -> !transform.param + %tiled_matmul = transform.include @subschedule_with_embedded_choice failures(propagate) (%matmul, %outer_par, %outer_tile_size) : (!transform.any_op, !transform.param, !transform.param) -> !transform.any_op + %tiled_tiled_matmul = transform.include @subschedule_with_embedded_choice failures(propagate) (%tiled_matmul, %inner_seq, %inner_tile_size) : (!transform.any_op, !transform.param, !transform.param) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: eeny_meeny_miny_moe +func.func private @eeny_meeny_miny_moe() + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op + + %tiled_matmul = transform.tune.alternatives<"4way"> selected_region = 3 -> !transform.any_param + { // First alternative/region, with index = 0 + %out = transform.param.constant "eeny" -> !transform.any_param + transform.yield %out : !transform.any_param + }, { // Second alternative/region, with index = 1 + %out = transform.param.constant "meeny" -> !transform.any_param + transform.yield %out : !transform.any_param + }, { // Third alternative/region, with index = 2 + %out = transform.param.constant "miny" -> !transform.any_param + transform.yield %out : !transform.any_param + }, { // Fourth alternative/region, with index = 3 + %out = transform.param.constant "moe" -> !transform.any_param + transform.yield %out : !transform.any_param + } + transform.yield + } +} \ No newline at end of file diff --git a/mlir/test/python/dialects/transform_tune_ext.py b/mlir/test/python/dialects/transform_tune_ext.py index dfb93594bca52..eb2a083211ef7 100644 --- a/mlir/test/python/dialects/transform_tune_ext.py +++ b/mlir/test/python/dialects/transform_tune_ext.py @@ -1,21 +1,21 @@ # RUN: %PYTHON %s | FileCheck %s -from mlir.ir import * +from mlir import ir from mlir.dialects import transform from mlir.dialects.transform import tune, debug def run(f): - print("\nTEST:", f.__name__) - with Context(), Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): + print("\n// TEST:", f.__name__) + with ir.Context(), ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): sequence = transform.SequenceOp( transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get(), ) - with InsertionPoint(sequence.body): + with ir.InsertionPoint(sequence.body): f(sequence.bodyTarget) transform.YieldOp() print(module) @@ -29,10 +29,10 @@ def testKnobOp(target): # CHECK: %[[HEADS_OR_TAILS:.*]] = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param heads_or_tails = tune.KnobOp( - result=any_param, name=StringAttr.get("coin"), options=[True, False] + result=any_param, name=ir.StringAttr.get("coin"), options=[True, False] ) # CHECK: transform.tune.knob<"animal"> options = ["cat", "dog", unit] -> !transform.any_param - tune.KnobOp(any_param, name="animal", options=["cat", "dog", UnitAttr.get()]) + tune.KnobOp(any_param, name="animal", options=["cat", "dog", ir.UnitAttr.get()]) # CHECK: transform.tune.knob<"tile_size"> options = [2, 4, 8, 16, 24, 32] -> !transform.any_param tune.KnobOp(any_param, "tile_size", [2, 4, 8, 16, 24, 32]) # CHECK: transform.tune.knob<"magic_value"> options = [2.000000e+00, 2.250000e+00, 2.500000e+00, 2.750000e+00, 3.000000e+00] -> !transform.any_param @@ -45,7 +45,10 @@ def testKnobOp(target): heads = tune.KnobOp(any_param, "coin", options=[True, False], selected=True) # CHECK: transform.tune.knob<"animal"> = "dog" from options = ["cat", "dog", unit] -> !transform.any_param tune.KnobOp( - any_param, name="animal", options=["cat", "dog", UnitAttr.get()], selected="dog" + any_param, + name="animal", + options=["cat", "dog", ir.UnitAttr.get()], + selected="dog", ) # CHECK: transform.tune.knob<"tile_size"> = 8 : i64 from options = [2, 4, 8, 16, 24, 32] -> !transform.any_param tune.KnobOp(any_param, "tile_size", [2, 4, 8, 16, 24, 32], selected=8) @@ -57,16 +60,90 @@ def testKnobOp(target): # CHECK: transform.tune.knob<"range_as_a_dict"> = 4 : i64 from options = {start = 2 : i64, step = 2 : i64, stop = 16 : i64} -> !transform.any_param # NB: Membership of `selected` in non-ArrayAttr `options` is _not_ verified. - i64 = IntegerType.get_signless(64) + i64 = ir.IntegerType.get_signless(64) tune.knob( any_param, "range_as_a_dict", - DictAttr.get( + ir.DictAttr.get( { - "start": IntegerAttr.get(i64, 2), - "stop": IntegerAttr.get(i64, 16), - "step": IntegerAttr.get(i64, 2), + "start": ir.IntegerAttr.get(i64, 2), + "stop": ir.IntegerAttr.get(i64, 16), + "step": ir.IntegerAttr.get(i64, 2), } ), selected=4, ) + + +# CHECK-LABEL: TEST: testAlternativesOp +@run +def testAlternativesOp(target): + any_param = transform.AnyParamType.get() + + # CHECK: %[[LEFT_OR_RIGHT_OUTCOME:.*]] = transform.tune.alternatives<"left_or_right"> -> !transform.any_param { + left_or_right = tune.AlternativesOp( + [transform.AnyParamType.get()], "left_or_right", 2 + ) + idx_for_left, idx_for_right = 0, 1 + with ir.InsertionPoint(left_or_right.alternatives[idx_for_left].blocks[0]): + # CHECK: %[[C0:.*]] = transform.param.constant 0 + i32_0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0) + c0 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_0) + # CHECK: transform.yield %[[C0]] + transform.yield_(c0) + # CHECK-NEXT: }, { + with ir.InsertionPoint(left_or_right.alternatives[idx_for_right].blocks[0]): + # CHECK: %[[C1:.*]] = transform.param.constant 1 + i32_1 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1) + c1 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1) + # CHECK: transform.yield %[[C1]] + transform.yield_(c1) + # CHECK-NEXT: } + outcome_of_left_or_right_decision = left_or_right.results[0] + + # CHECK: transform.tune.alternatives<"fork_in_the_road"> selected_region = 0 -> !transform.any_param { + fork_in_the_road = tune.AlternativesOp( + [transform.AnyParamType.get()], "fork_in_the_road", 2, selected_region=0 + ) + with ir.InsertionPoint(fork_in_the_road.alternatives[idx_for_left].blocks[0]): + # CHECK: %[[C0:.*]] = transform.param.constant 0 + i32_0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0) + c0 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_0) + # CHECK: transform.yield %[[C0]] + transform.yield_(c0) + # CHECK-NEXT: }, { + with ir.InsertionPoint(fork_in_the_road.alternatives[idx_for_right].blocks[0]): + # CHECK: %[[C1:.*]] = transform.param.constant 1 + i32_1 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1) + c1 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1) + # CHECK: transform.yield %[[C1]] + transform.yield_(c1) + # CHECK-NEXT: } + + # CHECK: transform.tune.alternatives<"left_or_right_as_before"> selected_region = %[[LEFT_OR_RIGHT_OUTCOME]] : !transform.any_param { + left_or_right_as_before = tune.AlternativesOp( + [], + "left_or_right_as_before", + 2, + selected_region=outcome_of_left_or_right_decision, + ) + with ir.InsertionPoint( + left_or_right_as_before.alternatives[idx_for_left].blocks[0] + ): + # CHECK: transform.param.constant 1337 + i32_1337 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1337) + c1337 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1337) + # CHECK: transform.debug.emit_param_as_remark + debug.emit_param_as_remark(c1337) + transform.yield_([]) + # CHECK-NEXT: }, { + with ir.InsertionPoint( + left_or_right_as_before.alternatives[idx_for_right].blocks[0] + ): + # CHECK: transform.param.constant 42 + i32_42 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42) + c42 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_42) + # CHECK: transform.debug.emit_param_as_remark + debug.emit_param_as_remark(c42) + transform.yield_([]) + # CHECK-NEXT: }