Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H

#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonAttrConstraints.td"

//===----------------------------------------------------------------------===//
// KnobOp
//===----------------------------------------------------------------------===//

def KnobOp : Op<Transform_Dialect, "tune.knob", [
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
Expand Down Expand Up @@ -52,4 +57,53 @@ def KnobOp : Op<Transform_Dialect, "tune.knob", [
"`<` $name `>` (`=` $selected^ `from`)? `options` `=` $options attr-dict `->` type(results)";
}

//===----------------------------------------------------------------------===//
// AlternativesOp
//===----------------------------------------------------------------------===//

def AlternativesOp : Op<Transform_Dialect, "tune.alternatives", [
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands", "getSuccessorRegions",
"getRegionInvocationBounds"]>,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">,
NoRegionArguments
]> {
let summary = "Represents a choice among its regions, i.e. sub-schedules";

let description = [{
This op represents a choice over which of its regions is to be used.

When `selected_region` is provided, the semantics are that this op is to be
substituted for by the selected region, meaning the region's results become
the results of this op. Without a provided `selected_region`, the semantics
are that this non-deterministic choice is yet to be resolved -- which in
terms of the op's interpreted semantics is a failure.

The `selected_region` argument is either an `IntegerAttr` or a param holding
an `IntegerAttr`, which should provide a valid zero-based index with respect
to the number of alternatives, i.e. regions.
}];
let cppNamespace = [{ mlir::transform::tune }];

let arguments = (ins Builtin_StringAttr:$name,
OptionalAttr<APIntAttr>:$selected_region_attr,
Optional<TransformParamTypeInterface>:$selected_region_param);
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
let regions = (region VariadicRegion<SizedRegion<1>>:$alternatives);

let assemblyFormat = [{
`<` $name `>`
(`selected_region` `=` custom<AlternativesOpSelectedRegion>(
$selected_region_attr, $selected_region_param)^)?
attr-dict-with-keyword
(`:` type($selected_region_param)^)?
(`->` type($results)^)?
regions
}];

let hasVerifier = 1;
}

#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
184 changes: 184 additions & 0 deletions mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,24 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/Support/Debug.h"

#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"

using namespace mlir;

static ParseResult parseAlternativesOpSelectedRegion(
OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam);

static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
Operation *op,
IntegerAttr selectedRegionAttr,
Value selectedRegionParam);

#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"

Expand Down Expand Up @@ -57,3 +68,176 @@ LogicalResult transform::tune::KnobOp::verify() {

return success();
}

//===----------------------------------------------------------------------===//
// AlternativesOp
//===----------------------------------------------------------------------===//

static ParseResult parseAlternativesOpSelectedRegion(
OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam) {
size_t selectedRegionIdx;
OptionalParseResult attrParseRes =
parser.parseOptionalInteger(selectedRegionIdx);
if (attrParseRes.has_value()) {
if (failed(*attrParseRes))
return failure();

selectedRegionAttr = parser.getBuilder().getIndexAttr(selectedRegionIdx);
return success();
}

OpAsmParser::UnresolvedOperand param;
auto paramParseRes = parser.parseOptionalOperand(param);
if (paramParseRes.has_value()) {
if (failed(*paramParseRes))
return failure();

selectedRegionParam = param;
return success();
}

return parser.emitError(parser.getCurrentLocation())
<< "expected either an integer attribute or a transform.param operand";
}

static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
Operation *op,
IntegerAttr selectedRegionAttr,
Value selectedRegionParam) {
if (selectedRegionAttr)
printer << selectedRegionAttr.getValue();
if (selectedRegionParam)
printer << selectedRegionParam;
}

OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
RegionBranchPoint point) {
// No operands will be forwarded to the region(s).
return getOperands().slice(0, 0);
}

void transform::tune::AlternativesOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
if (point.isParent())
if (auto selectedRegionIdx = getSelectedRegionAttr())
regions.emplace_back(
&getAlternatives()[selectedRegionIdx->getSExtValue()],
Block::BlockArgListType());
else
for (Region &alternative : getAlternatives())
regions.emplace_back(&alternative, Block::BlockArgListType());
else
regions.emplace_back(getOperation()->getResults());
}

void transform::tune::AlternativesOp::getRegionInvocationBounds(
ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
(void)operands;
bounds.reserve(getNumRegions());

if (auto selectedRegionIdx = getSelectedRegionAttr()) {
bounds.resize(getNumRegions(), InvocationBounds(0, 0));
bounds[selectedRegionIdx->getSExtValue()] = InvocationBounds(1, 1);
} else {
bounds.resize(getNumRegions(), InvocationBounds(0, 1));
}
}

void transform::tune::AlternativesOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getSelectedRegionParamMutable(), effects);
producesHandle(getOperation()->getOpResults(), effects);
// TODO: should effects from regions be forwarded?
}

DiagnosedSilenceableFailure
transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
std::optional<size_t> selectedRegionIdx;

if (auto selectedRegionAttr = getSelectedRegionAttr())
selectedRegionIdx = selectedRegionAttr->getSExtValue();

if (Value selectedRegionParam = getSelectedRegionParam()) {
ArrayRef<Attribute> associatedAttrs = state.getParams(selectedRegionParam);
IntegerAttr selectedRegionAttr;
if (associatedAttrs.size() != 1 ||
!(selectedRegionAttr = dyn_cast<IntegerAttr>(associatedAttrs[0])))
return emitDefiniteFailure()
<< "param should hold exactly one integer attribute, got: "
<< associatedAttrs[0];
selectedRegionIdx = selectedRegionAttr.getValue().getSExtValue();
}

if (!selectedRegionIdx)
return emitDefiniteFailure() << "non-deterministic choice " << getName()
<< " is only resolved through providing a "
"`selected_region` attr/param";

if (*selectedRegionIdx < 0 || *selectedRegionIdx >= getNumRegions())
return emitDefiniteFailure()
<< "'selected_region' attribute/param specifies region at index "
<< *selectedRegionIdx << " while op has only " << getNumRegions()
<< " regions";

Region &selectedRegion = getRegion(*selectedRegionIdx);
auto scope = state.make_region_scope(selectedRegion);
Block &block = selectedRegion.front();
// Apply the region's ops one by one.
for (Operation &transform : block.without_terminator()) {
DiagnosedSilenceableFailure result =
state.applyTransform(cast<transform::TransformOpInterface>(transform));
if (result.isDefiniteFailure())
return result;

if (result.isSilenceableFailure()) {
for (const auto &res : getResults())
results.set(res, {});
return result;
}
}
// Forward the operation mapping for values yielded from the region to the
// values produced by the alternatives op.
transform::detail::forwardTerminatorOperands(&block, state, results);
return DiagnosedSilenceableFailure::success();
}

LogicalResult transform::tune::AlternativesOp::verify() {
for (auto *region : getRegions()) {
auto yieldTerminator =
llvm::dyn_cast_if_present<transform::YieldOp>(region->front().back());
if (!yieldTerminator)
return emitOpError() << "expected '"
<< transform::YieldOp::getOperationName()
<< "' as terminator";

if (yieldTerminator->getNumOperands() != getNumResults())
return yieldTerminator.emitOpError()
<< "expected terminator to have as many operands as the parent op "
"has results";

for (auto [i, operandType, resultType] : llvm::zip_equal(
llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
yieldTerminator->getOperands().getType(), getResultTypes())) {
if (operandType == resultType)
continue;
return yieldTerminator.emitOpError()
<< "the type of the terminator operand #" << i
<< " must match the type of the corresponding parent op result ("
<< operandType << " vs " << resultType << ")";
}
}

if (auto selectedRegionAttr = getSelectedRegionAttr()) {
size_t regionIdx = selectedRegionAttr->getSExtValue();
if (regionIdx < 0 || regionIdx >= getNumRegions())
return emitOpError()
<< "'selected_region' attribute specifies region at index "
<< regionIdx << " while op has only " << getNumRegions()
<< " regions";
}

return success();
}
66 changes: 63 additions & 3 deletions mlir/python/mlir/dialects/transform/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

from ...ir import (
Type,
Value,
Operation,
OpView,
Attribute,
ArrayAttr,
StringAttr,
Expand All @@ -19,7 +22,10 @@
from .._transform_tune_extension_ops_gen import _Dialect

try:
from .._ods_common import _cext as _ods_cext
from .._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
_cext as _ods_cext,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e

Expand All @@ -36,7 +42,7 @@ def __init__(
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
],
*,
selected: Optional[Attribute] = None,
selected: Optional[Union[Attribute, bool, int, float, str]] = None,
loc=None,
ip=None,
):
Expand Down Expand Up @@ -75,8 +81,62 @@ def knob(
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
],
*,
selected: Optional[Attribute] = None,
selected: Optional[Union[Attribute, bool, int, float, str]] = None,
loc=None,
ip=None,
):
return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class AlternativesOp(AlternativesOp):
def __init__(
self,
results: Sequence[Type],
name: Union[StringAttr, str],
num_alternatives: int,
*,
selected_region: Optional[
Union[int, IntegerAttr, Value, Operation, OpView]
] = None,
loc=None,
ip=None,
):
if isinstance(name, str):
name = StringAttr.get(name)

selected_region_attr = selected_region_param = None
if isinstance(selected_region, IntegerAttr):
selected_region_attr = selected_region
elif isinstance(selected_region, int):
selected_region_attr = IntegerAttr.get(
IntegerType.get_signless(32), selected_region
)
elif isinstance(selected_region, (Value, Operation, OpView)):
selected_region_param = _get_op_result_or_value(selected_region)

super().__init__(
results,
name,
num_alternatives,
selected_region_attr=selected_region_attr,
selected_region_param=selected_region_param,
loc=loc,
ip=ip,
)
for region in self.regions:
region.blocks.append()


def alternatives(
results: Sequence[Type],
name: Union[StringAttr, str],
num_alternatives: int,
*,
selected_region: Optional[Union[int, IntegerAttr, Value, Operation, OpView]] = None,
loc=None,
ip=None,
):
return AlternativesOp(
results, name, num_alternatives, selected_region=selected_region, loc=loc, ip=ip
)
Loading