diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt index e70479b2a39f2..eb91ceccd4ef2 100644 --- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt @@ -4,5 +4,6 @@ add_subdirectory(IR) add_subdirectory(IRDLExtension) add_subdirectory(LoopExtension) add_subdirectory(PDLExtension) +add_subdirectory(SMTExtension) add_subdirectory(Transforms) add_subdirectory(TuneExtension) diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/SMTExtension/CMakeLists.txt new file mode 100644 index 0000000000000..da037c1e809de --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS SMTExtensionOps.td) +mlir_tablegen(SMTExtensionOps.h.inc -gen-op-decls) +mlir_tablegen(SMTExtensionOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRTransformDialectSMTExtensionOpsIncGen) + +add_mlir_doc(SMTExtensionOps SMTExtensionOps Dialects/ -gen-op-doc) diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtension.h b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtension.h new file mode 100644 index 0000000000000..7079873cec048 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtension.h @@ -0,0 +1,27 @@ +//===- SMTExtension.h - SMT extension for Transform dialect -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H +#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" + +namespace mlir { +class DialectRegistry; + +namespace transform { +/// Registers the SMT extension of the Transform dialect in the given registry. +void registerSMTExtension(DialectRegistry &dialectRegistry); +} // namespace transform +} // namespace mlir + +#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h new file mode 100644 index 0000000000000..fc69b039f24ff --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h @@ -0,0 +1,21 @@ +//===- SMTExtensionOps.h - SMT extension for Transform dialect --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H +#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h.inc" + +#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td new file mode 100644 index 0000000000000..b987cb31e54bb --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td @@ -0,0 +1,52 @@ +//===- SMTExtensionOps.td - Transform dialect operations ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS +#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def ConstrainParamsOp : Op, + DeclareOpInterfaceMethods, + NoTerminator +]> { + let cppNamespace = [{ mlir::transform::smt }]; + + let summary = "Express contraints on params interpreted as symbolic values"; + let description = [{ + Allows expressing constraints on params using the SMT dialect. + + Each Transform dialect param provided as an operand has a corresponding + argument of SMT-type in the region. The SMT-Dialect ops in the region use + these arguments as operands. + + The semantics of this op is that all the ops in the region together express + a constraint on the params-interpreted-as-smt-vars. The op fails in case the + expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the + op succeeds. + + --- + + TODO: currently the operational semantics per the Transform interpreter is + to always fail. The intention is build out support for hooking in your own + operational semantics so you can invoke your favourite solver to determine + satisfiability of the corresponding constraint problem. + }]; + + let arguments = (ins Variadic:$params); + let regions = (region SizedRegion<1>:$body); + let assemblyFormat = + "`(` $params `)` attr-dict `:` type(operands) $body"; + + let hasVerifier = 1; +} + +#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp index 3123e3bdda496..0d1d9e89f92f6 100644 --- a/mlir/lib/Bindings/Python/DialectSMT.cpp +++ b/mlir/lib/Bindings/Python/DialectSMT.cpp @@ -26,21 +26,26 @@ using namespace mlir::python::nanobind_adaptors; static void populateDialectSMTSubmodule(nanobind::module_ &m) { - auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool) - .def_classmethod( - "get", - [](const nb::object &, MlirContext context) { - return mlirSMTTypeGetBool(context); - }, - "cls"_a, "context"_a = nb::none()); + auto smtBoolType = + mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool) + .def_staticmethod( + "get", + [](MlirContext context) { return mlirSMTTypeGetBool(context); }, + "context"_a = nb::none()); auto smtBitVectorType = mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector) - .def_classmethod( + .def_staticmethod( "get", - [](const nb::object &, int32_t width, MlirContext context) { + [](int32_t width, MlirContext context) { return mlirSMTTypeGetBitVector(context, width); }, - "cls"_a, "width"_a, "context"_a = nb::none()); + "width"_a, "context"_a = nb::none()); + auto smtIntType = + mlir_type_subclass(m, "IntType", mlirSMTTypeIsAInt) + .def_staticmethod( + "get", + [](MlirContext context) { return mlirSMTTypeGetInt(context); }, + "context"_a = nb::none()); auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues, bool indentLetBody) { diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt index 6e628353258d6..123c4b92271fe 100644 --- a/mlir/lib/Dialect/Transform/CMakeLists.txt +++ b/mlir/lib/Dialect/Transform/CMakeLists.txt @@ -4,6 +4,7 @@ add_subdirectory(IR) add_subdirectory(IRDLExtension) add_subdirectory(LoopExtension) add_subdirectory(PDLExtension) +add_subdirectory(SMTExtension) add_subdirectory(Transforms) add_subdirectory(TuneExtension) add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt b/mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt new file mode 100644 index 0000000000000..ba1cc464e506d --- /dev/null +++ b/mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_dialect_library(MLIRTransformSMTExtension + SMTExtension.cpp + SMTExtensionOps.cpp + + DEPENDS + MLIRTransformDialectSMTExtensionOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRTransformDialect + MLIRSMT +) diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp new file mode 100644 index 0000000000000..228e8d342a1f6 --- /dev/null +++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp @@ -0,0 +1,35 @@ +//===- SMTExtension.cpp - SMT extension for the Transform dialect ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h" +#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h" +#include "mlir/IR/DialectRegistry.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Transform op registration +//===----------------------------------------------------------------------===// + +namespace { +class SMTExtension : public transform::TransformDialectExtension { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SMTExtension) + + SMTExtension() { + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp.inc" + >(); + } +}; +} // namespace + +void mlir::transform::registerSMTExtension(DialectRegistry &dialectRegistry) { + dialectRegistry.addExtensions(); +} diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp new file mode 100644 index 0000000000000..8e7af05353de7 --- /dev/null +++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp @@ -0,0 +1,55 @@ +//===- SMTExtensionOps.cpp - SMT extension for the Transform dialect ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h" +#include "mlir/Dialect/SMT/IR/SMTDialect.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h" + +using namespace mlir; + +#define GET_OP_CLASSES +#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp.inc" + +//===----------------------------------------------------------------------===// +// ConstrainParamsOp +//===----------------------------------------------------------------------===// + +void transform::smt::ConstrainParamsOp::getEffects( + SmallVectorImpl &effects) { + onlyReadsHandle(getParamsMutable(), effects); +} + +DiagnosedSilenceableFailure +transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + // TODO: Proper operational semantics are to check the SMT problem in the body + // with a SMT solver with the arguments of the body constrained to the + // values passed into the op. Success or failure is then determined by + // the solver's result. + // One way to support this is to just promise the TransformOpInterface + // and allow for users to attach their own implementation, which would, + // e.g., translate the ops to SMTLIB and hand that over to the user's + // favourite solver. This requires changes to the dialect's verifier. + return emitDefiniteFailure() << "op does not have interpreted semantics yet"; +} + +LogicalResult transform::smt::ConstrainParamsOp::verify() { + if (getOperands().size() != getBody().getNumArguments()) + return emitOpError( + "must have the same number of block arguments as operands"); + + for (auto &op : getBody().getOps()) { + if (!isa(op.getDialect())) + return emitOpError( + "ops contained in region should belong to SMT-dialect"); + } + + return success(); +} diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp index 69a85dbe141ce..3839172fd0b42 100644 --- a/mlir/lib/RegisterAllExtensions.cpp +++ b/mlir/lib/RegisterAllExtensions.cpp @@ -53,6 +53,7 @@ #include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h" #include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h" #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" +#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" @@ -108,6 +109,7 @@ void mlir::registerAllExtensions(DialectRegistry ®istry) { transform::registerIRDLExtension(registry); transform::registerLoopExtension(registry); transform::registerPDLExtension(registry); + transform::registerSMTExtension(registry); transform::registerTuneExtension(registry); vector::registerTransformDialectExtension(registry); arm_neon::registerTransformDialectExtension(registry); diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 97f0778071ef9..d6686bb89ce4e 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" DIALECT_NAME transform EXTENSION_NAME transform_pdl_extension) +declare_mlir_dialect_extension_python_bindings( +ADD_TO_PARENT MLIRPythonSources.Dialects +ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TransformSMTExtensionOps.td + SOURCES + dialects/transform/smt.py + DIALECT_NAME transform + EXTENSION_NAME transform_smt_extension) + declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/TransformSMTExtensionOps.td b/mlir/python/mlir/dialects/TransformSMTExtensionOps.td new file mode 100644 index 0000000000000..3e92417a35d13 --- /dev/null +++ b/mlir/python/mlir/dialects/TransformSMTExtensionOps.td @@ -0,0 +1,19 @@ +//===-- TransformSMTExtensionOps.td - Binding entry point --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the generated Python bindings for the SMT extension of the +// Transform dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS +#define PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS + +include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td" + +#endif // PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS diff --git a/mlir/python/mlir/dialects/smt.py b/mlir/python/mlir/dialects/smt.py index ae7a4c41cbc3a..38970d17abd47 100644 --- a/mlir/python/mlir/dialects/smt.py +++ b/mlir/python/mlir/dialects/smt.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._smt_ops_gen import * +from ._smt_enum_gen import * from .._mlir_libs._mlirDialectsSMT import * from ..extras.meta import region_op diff --git a/mlir/python/mlir/dialects/transform/smt.py b/mlir/python/mlir/dialects/transform/smt.py new file mode 100644 index 0000000000000..1f0b7f066118c --- /dev/null +++ b/mlir/python/mlir/dialects/transform/smt.py @@ -0,0 +1,38 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Sequence + +from ...ir import Type, Block +from .._transform_smt_extension_ops_gen import * +from .._transform_smt_extension_ops_gen import _Dialect +from ...dialects import transform + +try: + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ConstrainParamsOp(ConstrainParamsOp): + def __init__( + self, + params: Sequence[transform.AnyParamType], + arg_types: Sequence[Type], + loc=None, + ip=None, + ): + if len(params) != len(arg_types): + raise ValueError(f"{params=} not same length as {arg_types=}") + super().__init__( + params, + loc=loc, + ip=ip, + ) + self.regions[0].blocks.append(*arg_types) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] diff --git a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir new file mode 100644 index 0000000000000..314b8d493c5d4 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir @@ -0,0 +1,30 @@ +// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics + +// CHECK-LABEL: @constraint_not_using_smt_ops +module attributes {transform.with_named_sequence} { + transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) { + %param_as_param = transform.param.constant 42 -> !transform.param + // expected-error@below {{ops contained in region should belong to SMT-dialect}} + transform.smt.constrain_params(%param_as_param) : !transform.param { + ^bb0(%param_as_smt_var: !smt.int): + %c4 = arith.constant 4 : i32 + // This is the kind of thing one might think works: + //arith.remsi %param_as_smt_var, %c4 : i32 + } + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @operands_not_one_to_one_with_vars +module attributes {transform.with_named_sequence} { + transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) { + %param_as_param = transform.param.constant 42 -> !transform.param + // expected-error@below {{must have the same number of block arguments as operands}} + transform.smt.constrain_params(%param_as_param) : !transform.param { + ^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int): + } + transform.yield + } +} diff --git a/mlir/test/Dialect/Transform/test-smt-extension.mlir b/mlir/test/Dialect/Transform/test-smt-extension.mlir new file mode 100644 index 0000000000000..29d15175ae4ec --- /dev/null +++ b/mlir/test/Dialect/Transform/test-smt-extension.mlir @@ -0,0 +1,87 @@ +// RUN: mlir-opt %s --split-input-file | FileCheck %s + +// CHECK-LABEL: @schedule_with_constrained_param +module attributes {transform.with_named_sequence} { + transform.named_sequence @schedule_with_constrained_param(%arg0: !transform.any_op {transform.readonly}) { + // CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant + %param_as_param = transform.param.constant 42 -> !transform.param + + // CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]]) + transform.smt.constrain_params(%param_as_param) : !transform.param { + // CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int): + ^bb0(%param_as_smt_var: !smt.int): + // CHECK: %[[C0:.*]] = smt.int.constant 0 + %c0 = smt.int.constant 0 + // CHECK: %[[C43:.*]] = smt.int.constant 43 + %c43 = smt.int.constant 43 + // CHECK: %[[LOWER_BOUND:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]] + %lower_bound = smt.int.cmp le %c0, %param_as_smt_var + // CHECK: smt.assert %[[LOWER_BOUND]] + smt.assert %lower_bound + // CHECK: %[[UPPER_BOUND:.*]] = smt.int.cmp le %[[PARAM_AS_SMT_SYMB]], %[[C43]] + %upper_bound = smt.int.cmp le %param_as_smt_var, %c43 + // CHECK: smt.assert %[[UPPER_BOUND]] + smt.assert %upper_bound + } + // NB: from here can rely on that 0 <= %param_as_param <= 43, even if its + // definition changes. + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @schedule_with_constraint_on_multiple_params +module attributes {transform.with_named_sequence} { + transform.named_sequence @schedule_with_constraint_on_multiple_params(%arg0: !transform.any_op {transform.readonly}) { + // CHECK: %[[PARAM_A:.*]] = transform.param.constant + %param_a = transform.param.constant 4 -> !transform.param + // CHECK: %[[PARAM_B:.*]] = transform.param.constant + %param_b = transform.param.constant 16 -> !transform.param + + // CHECK: transform.smt.constrain_params(%[[PARAM_A]], %[[PARAM_B]]) + transform.smt.constrain_params(%param_a, %param_b) : !transform.param, !transform.param { + // CHECK: ^bb{{.*}}(%[[VAR_A:.*]]: !smt.int, %[[VAR_B:.*]]: !smt.int): + ^bb0(%var_a: !smt.int, %var_b: !smt.int): + // CHECK: %[[C0:.*]] = smt.int.constant 0 + %c0 = smt.int.constant 0 + // CHECK: %[[REMAINDER:.*]] = smt.int.mod %[[VAR_B]], %[[VAR_A]] + %remainder = smt.int.mod %var_b, %var_a + // CHECK: %[[EQ:.*]] = smt.eq %[[REMAINDER]], %[[C0]] + %eq = smt.eq %remainder, %c0 : !smt.int + // CHECK: smt.assert %[[EQ]] + smt.assert %eq + } + // NB: from here can rely on that %param_a is a divisor of %param_b + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @schedule_with_param_as_a_bool +module attributes {transform.with_named_sequence} { + transform.named_sequence @schedule_with_param_as_a_bool(%arg0: !transform.any_op {transform.readonly}) { + // CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant + %param_as_param = transform.param.constant true -> !transform.any_param + + // CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]]) + transform.smt.constrain_params(%param_as_param) : !transform.any_param { + // CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_VAR:.*]]: !smt.bool): + ^bb0(%param_as_smt_var: !smt.bool): + // CHECK: %[[C0:.*]] = smt.int.constant 0 + %c0 = smt.int.constant 0 + // CHECK: %[[C1:.*]] = smt.int.constant 1 + %c1 = smt.int.constant 1 + // CHECK: %[[FALSEHOOD:.*]] = smt.eq %[[C0]], %[[C1]] + %falsehood = smt.eq %c0, %c1 : !smt.int + // CHECK: %[[TRUE_IFF_PARAM_IS:.*]] = smt.or %[[PARAM_AS_SMT_VAR]], %[[FALSEHOOD]] + %true_iff_param_is = smt.or %param_as_smt_var, %falsehood + // CHECK: smt.assert %[[TRUE_IFF_PARAM_IS]] + smt.assert %true_iff_param_is + } + // NB: from here can rely on that %param_as_param holds true, even if its + // definition changes. + transform.yield + } +} diff --git a/mlir/test/python/dialects/transform_smt_ext.py b/mlir/test/python/dialects/transform_smt_ext.py new file mode 100644 index 0000000000000..3692fd92344a6 --- /dev/null +++ b/mlir/test/python/dialects/transform_smt_ext.py @@ -0,0 +1,50 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir import ir +from mlir.dialects import transform, smt +from mlir.dialects.transform import smt as transform_smt + + +def run(f): + print("\nTEST:", f.__name__) + with ir.Context(), ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.AnyOpType.get(), + ) + with ir.InsertionPoint(sequence.body): + f(sequence.bodyTarget) + transform.YieldOp() + print(module) + return f + + +# CHECK-LABEL: TEST: testConstrainParamsOp +@run +def testConstrainParamsOp(target): + dummy_value = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42) + # CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant + symbolic_value = transform.ParamConstantOp( + transform.AnyParamType.get(), dummy_value + ) + # CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]]) + constrain_params = transform_smt.ConstrainParamsOp( + [symbolic_value], [smt.IntType.get()] + ) + # CHECK-NEXT: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int): + with ir.InsertionPoint(constrain_params.body): + # CHECK: %[[C0:.*]] = smt.int.constant 0 + c0 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0)) + # CHECK: %[[C43:.*]] = smt.int.constant 43 + c43 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 43)) + # CHECK: %[[LB:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]] + lb = smt.IntCmpOp(smt.IntPredicate.le, c0, constrain_params.body.arguments[0]) + # CHECK: %[[UB:.*]] = smt.int.cmp le %[[PARAM_AS_SMT_SYMB]], %[[C43]] + ub = smt.IntCmpOp(smt.IntPredicate.le, constrain_params.body.arguments[0], c43) + # CHECK: %[[BOUNDED:.*]] = smt.and %[[LB]], %[[UB]] + bounded = smt.AndOp([lb, ub]) + # CHECK: smt.assert %[[BOUNDED:.*]] + smt.AssertOp(bounded)