From c69d3faa2a7ddb9663793196baa726a72ff51371 Mon Sep 17 00:00:00 2001 From: yuhao <1171760467@qq.com> Date: Thu, 11 May 2023 15:39:16 +0800 Subject: [PATCH 1/7] init --- oneflow/ir/include/Transform/CMakeLists.txt | 12 +- .../TestTransformDialectExtension.td | 452 ------------ .../Transform/TestTransformStateExtension.h | 51 -- ...xtension.h => TransformDialectExtension.h} | 29 +- .../Transform/TransformDialectExtension.td | 71 ++ .../Transform/TransformStateExtension.h | 43 ++ oneflow/ir/lib/Transform/CMakeLists.txt | 10 +- .../TestTransformDialectExtension.cpp | 663 ------------------ .../Transform/TransformDialectExtension.cpp | 132 ++++ ...er.cpp => TransformDialectInterpreter.cpp} | 34 +- ...ension.cpp => TransformStateExtension.cpp} | 6 +- oneflow/ir/oneflow-opt/CMakeLists.txt | 2 +- oneflow/ir/oneflow-opt/oneflow-opt.cpp | 8 +- oneflow/ir/test/Transform/apply-patterns.mlir | 20 + 14 files changed, 320 insertions(+), 1213 deletions(-) delete mode 100644 oneflow/ir/include/Transform/TestTransformDialectExtension.td delete mode 100644 oneflow/ir/include/Transform/TestTransformStateExtension.h rename oneflow/ir/include/Transform/{TestTransformDialectExtension.h => TransformDialectExtension.h} (66%) create mode 100644 oneflow/ir/include/Transform/TransformDialectExtension.td create mode 100644 oneflow/ir/include/Transform/TransformStateExtension.h delete mode 100644 oneflow/ir/lib/Transform/TestTransformDialectExtension.cpp create mode 100644 oneflow/ir/lib/Transform/TransformDialectExtension.cpp rename oneflow/ir/lib/Transform/{TestTransformDialectInterpreter.cpp => TransformDialectInterpreter.cpp} (89%) rename oneflow/ir/lib/Transform/{TestTransformStateExtension.cpp => TransformStateExtension.cpp} (88%) create mode 100644 oneflow/ir/test/Transform/apply-patterns.mlir diff --git a/oneflow/ir/include/Transform/CMakeLists.txt b/oneflow/ir/include/Transform/CMakeLists.txt index 8b9b41c4e3b..f635a0f61de 100644 --- a/oneflow/ir/include/Transform/CMakeLists.txt +++ b/oneflow/ir/include/Transform/CMakeLists.txt @@ -1,8 +1,8 @@ -set(LLVM_TARGET_DEFINITIONS TestTransformDialectExtension.td) -mlir_tablegen(TestTransformDialectExtension.h.inc -gen-op-decls) -mlir_tablegen(TestTransformDialectExtension.cpp.inc -gen-op-defs) -mlir_tablegen(TestTransformDialectExtensionTypes.h.inc -gen-typedef-decls +set(LLVM_TARGET_DEFINITIONS TransformDialectExtension.td) +mlir_tablegen(TransformDialectExtension.h.inc -gen-op-decls) +mlir_tablegen(TransformDialectExtension.cpp.inc -gen-op-defs) +mlir_tablegen(TransformDialectExtensionTypes.h.inc -gen-typedef-decls -typedefs-dialect=transform) -mlir_tablegen(TestTransformDialectExtensionTypes.cpp.inc -gen-typedef-defs +mlir_tablegen(TransformDialectExtensionTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=transform) -add_public_tablegen_target(MLIRTestTransformDialectExtensionIncGen) +add_public_tablegen_target(MLIROneFlowTransformDialectExtensionIncGen) diff --git a/oneflow/ir/include/Transform/TestTransformDialectExtension.td b/oneflow/ir/include/Transform/TestTransformDialectExtension.td deleted file mode 100644 index 04056bddf09..00000000000 --- a/oneflow/ir/include/Transform/TestTransformDialectExtension.td +++ /dev/null @@ -1,452 +0,0 @@ -//===- TestTransformDialectExtension.td --------------------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file declares the operations that are injected into the Transform -// dialect through the extension mechanism, as a test. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_TD -#define MLIR_TESTTRANSFORMDIALECTEXTENSION_TD - -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/IR/AttrTypeBase.td" -include "mlir/IR/OpBase.td" -include "mlir/Dialect/Transform/IR/MatchInterfaces.td" -include "mlir/Dialect/Transform/IR/TransformDialect.td" -include "mlir/Dialect/Transform/IR/TransformInterfaces.td" -include "mlir/Dialect/PDL/IR/PDLTypes.td" - -def TestTransformTestDialectHandleType - : TypeDef]> { - let description = [{Handle pointing to an op from the Test dialect.}]; - let mnemonic = "test_dialect_op"; - let assemblyFormat = ""; -} - -def TestTransformTestDialectParamType - : TypeDef]> { - let description = [{ - Parameter associated with an i32 attribute for testing purposes. - }]; - let mnemonic = "test_dialect_param"; - let assemblyFormat = ""; -} - -def TestProduceSelfHandleOrForwardOperandOp - : Op, - DeclareOpInterfaceMethods]> { - let arguments = (ins Optional:$operand); - let results = (outs PDL_Operation:$res); - let assemblyFormat = "($operand^)? attr-dict"; - let cppNamespace = "::mlir::transform"; -} - -def TestProduceValueHandleToSelfOperand - : Op, - DeclareOpInterfaceMethods]> { - let arguments = (ins TransformHandleTypeInterface:$in); - let results = (outs TransformValueHandleTypeInterface:$out); - let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)"; - let cppNamespace = "::mlir::transform"; - -} - -def TestProduceValueHandleToResult - : Op]> { - let arguments = (ins TransformHandleTypeInterface:$in, - I64Attr:$number); - let results = (outs TransformValueHandleTypeInterface:$out); - let assemblyFormat = "$in `,` $number attr-dict `:` functional-type(operands, results)"; - let cppNamespace = "::mlir::transform"; - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::Operation *target, - ::mlir::transform::ApplyToEachResultList &results, - ::mlir::transform::TransformState &state); - }]; -} - -def TestProduceValueHandleToArgumentOfParentBlock - : Op]> { - let arguments = (ins TransformHandleTypeInterface:$in, - I64Attr:$number); - let results = (outs TransformValueHandleTypeInterface:$out); - let assemblyFormat = "$in `,` $number attr-dict `:` functional-type(operands, results)"; - let cppNamespace = "::mlir::transform"; - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::Operation *target, - ::mlir::transform::ApplyToEachResultList &results, - ::mlir::transform::TransformState &state); - }]; -} - -def TestConsumeOperand : Op, - DeclareOpInterfaceMethods]> { - let arguments = (ins - Transform_AnyHandleOrParamType:$operand, - Optional:$second_operand); - let assemblyFormat = "$operand (`,` $second_operand^)? attr-dict `:` type($operand)"; - let cppNamespace = "::mlir::transform"; -} - -def TestConsumeOperandOfOpKindOrFail - : Op, - DeclareOpInterfaceMethods]> { - let arguments = (ins - PDL_Operation:$operand, - StrAttr:$op_kind); - let assemblyFormat = "$operand `,` $op_kind attr-dict"; - let cppNamespace = "::mlir::transform"; -} - -def TestSucceedIfOperandOfOpKind - : Op]> { - let arguments = (ins - TransformHandleTypeInterface:$operand_handle, - StrAttr:$op_kind); - let assemblyFormat = - "$operand_handle `,` $op_kind attr-dict `:` type($operand_handle)"; - let extraClassDeclaration = SingleOpMatcher.extraDeclaration; - let cppNamespace = "::mlir::transform"; -} - -def TestPrintRemarkAtOperandOp - : Op, - DeclareOpInterfaceMethods]> { - let arguments = (ins - TransformHandleTypeInterface:$operand, - StrAttr:$message); - let assemblyFormat = - "$operand `,` $message attr-dict `:` type($operand)"; - let cppNamespace = "::mlir::transform"; -} - -def TestPrintRemarkAtOperandValue - : Op, - DeclareOpInterfaceMethods]> { - let arguments = (ins TransformValueHandleTypeInterface:$in, - StrAttr:$message); - let assemblyFormat = "$in `,` $message attr-dict `:` type($in)"; - let cppNamespace = "::mlir::transform"; -} - -def TestAddTestExtensionOp - : Op, - NoMemoryEffect]> { - let arguments = (ins StrAttr:$message); - let assemblyFormat = "$message attr-dict"; - let cppNamespace = "::mlir::transform"; -} - -def TestCheckIfTestExtensionPresentOp - : Op, - DeclareOpInterfaceMethods]> { - let arguments = (ins PDL_Operation:$operand); - let assemblyFormat = "$operand attr-dict"; - let cppNamespace = "::mlir::transform"; -} - -def TestRemapOperandPayloadToSelfOp - : Op, - DeclareOpInterfaceMethods]> { - let arguments = (ins PDL_Operation:$operand); - let results = (outs Optional:$out); - let assemblyFormat = "$operand attr-dict (`:` type($out)^)?"; - let cppNamespace = "::mlir::transform"; -} - -def TestRemoveTestExtensionOp - : Op, - NoMemoryEffect]> { - let assemblyFormat = "attr-dict"; - let cppNamespace = "::mlir::transform"; -} - -def TestReversePayloadOpsOp - : Op]> { - let arguments = (ins PDL_Operation:$target); - let results = (outs PDL_Operation:$result); - let assemblyFormat = "$target attr-dict"; - let cppNamespace = "::mlir::transform"; -} - -def TestTransformOpWithRegions - : Op, - DeclareOpInterfaceMethods]> { - let regions = (region AnyRegion:$first, AnyRegion:$second); - let assemblyFormat = "attr-dict-with-keyword regions"; - let cppNamespace = "::mlir::transform"; -} - -def TestBranchingTransformOpTerminator - : Op, - DeclareOpInterfaceMethods]> { - let successors = (successor VariadicSuccessor:$succ); - let cppNamespace = "::mlir::transform"; -} - -def TestEmitRemarkAndEraseOperandOp - : Op, - DeclareOpInterfaceMethods, - FunctionalStyleTransformOpTrait]> { - let arguments = (ins PDL_Operation:$target, StrAttr:$remark, - UnitAttr:$fail_after_erase); - let assemblyFormat = "$target `,` $remark attr-dict"; - let cppNamespace = "::mlir::transform"; -} - -def TestWrongNumberOfResultsOp - : Op { - let arguments = (ins PDL_Operation:$target); - let results = (outs PDL_Operation:$a, - PDL_Operation:$b, - PDL_Operation:$c); - let assemblyFormat = "$target attr-dict"; - let cppNamespace = "::mlir::transform"; - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::Operation * target, - ::mlir::transform::ApplyToEachResultList &results, - ::mlir::transform::TransformState &state); - }]; -} - -def TestWrongNumberOfMultiResultsOp - : Op { - let arguments = (ins PDL_Operation:$target); - let results = (outs PDL_Operation:$result); - let assemblyFormat = "$target attr-dict"; - let cppNamespace = "::mlir::transform"; - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::Operation * target, - ::mlir::transform::ApplyToEachResultList &results, - ::mlir::transform::TransformState &state); - }]; -} - -def TestCorrectNumberOfMultiResultsOp - : Op { - let arguments = (ins PDL_Operation:$target); - let results = (outs PDL_Operation:$result1, - PDL_Operation:$result2); - let assemblyFormat = "$target attr-dict"; - let cppNamespace = "::mlir::transform"; - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::Operation * target, - ::mlir::transform::ApplyToEachResultList &results, - ::mlir::transform::TransformState &state); - }]; -} - -def TestMixedNullAndNonNullResultsOp - : Op { - let arguments = (ins PDL_Operation:$target); - let results = (outs PDL_Operation:$null, - PDL_Operation:$non_null); - let assemblyFormat = "$target attr-dict"; - let cppNamespace = "::mlir::transform"; - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::Operation * target, - ::mlir::transform::ApplyToEachResultList &results, - ::mlir::transform::TransformState &state); - }]; -} - -def TestMixedSuccessAndSilenceableOp - : Op { - let arguments = (ins PDL_Operation:$target); - let results = (outs); - let assemblyFormat = "$target attr-dict"; - let cppNamespace = "::mlir::transform"; - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::Operation * target, - ::mlir::transform::ApplyToEachResultList &results, - ::mlir::transform::TransformState &state); - }]; -} - -def TestPrintNumberOfAssociatedPayloadIROps - : Op, - DeclareOpInterfaceMethods]> { - let arguments = (ins PDL_Operation:$handle); - let assemblyFormat = "$handle attr-dict"; - let cppNamespace = "::mlir::transform"; -} - -def TestCopyPayloadOp - : Op, - DeclareOpInterfaceMethods]> { - let arguments = (ins PDL_Operation:$handle); - let results = (outs PDL_Operation:$copy); - let cppNamespace = "::mlir::transform"; - let assemblyFormat = "$handle attr-dict"; -} - -def TestReportNumberOfTrackedHandlesNestedUnder - : Op, - DeclareOpInterfaceMethods]> { - let arguments = (ins PDL_Operation:$target); - let assemblyFormat = "$target attr-dict"; - let cppNamespace = "::mlir::transform"; -} - -def TestPrintParamOp - : Op, - DeclareOpInterfaceMethods]> { - let arguments = (ins TransformParamTypeInterface:$param, - Optional:$anchor, - OptionalAttr:$message); - let assemblyFormat = "$param (`,` $message^)? (`at` $anchor^)?" - "attr-dict `:` type($param) (`,` type($anchor)^)?"; - let cppNamespace = "::mlir::transform"; -} - -def TestAddToParamOp - : Op]> { - let arguments = (ins Optional:$param, - I32Attr:$addendum); - let results = (outs TestTransformTestDialectParamType:$result); - let assemblyFormat = "($param^ `,`)? $addendum attr-dict"; - let cppNamespace = "::mlir::transform"; -} - -def TestProduceParamWithNumberOfTestOps - : Op]> { - let arguments = (ins TransformHandleTypeInterface:$handle); - let results = (outs TestTransformTestDialectParamType:$result); - let assemblyFormat = "$handle attr-dict `:` type($handle)"; - let cppNamespace = "::mlir::transform"; -} - -def TestProduceIntegerParamWithTypeOp - : Op]> { - let arguments = (ins TypeAttr:$type); - let results = (outs TransformParamTypeInterface:$result); - let assemblyFormat = "$type attr-dict `:` type($result)"; - let cppNamespace = "::mlir::transform"; - let hasVerifier = 1; -} - -def TestProduceTransformParamOrForwardOperandOp - : Op, - TransformEachOpTrait, TransformOpInterface]> { - let arguments = (ins TransformHandleTypeInterface:$in, - UnitAttr:$first_result_is_param, - UnitAttr:$first_result_is_null, - UnitAttr:$second_result_is_handle); - let results = (outs AnyType:$out, - TransformParamTypeInterface:$param); - let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)"; - let cppNamespace = "::mlir::transform"; - - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::Operation *target, - ::mlir::transform::ApplyToEachResultList &results, - ::mlir::transform::TransformState &state); - }]; -} - -def TestProduceNullPayloadOp - : Op, - DeclareOpInterfaceMethods]> { - let results = (outs TransformHandleTypeInterface:$out); - let assemblyFormat = "attr-dict `:` type($out)"; - let cppNamespace = "::mlir::transform"; -} - -def TestProduceNullParamOp - : Op, - DeclareOpInterfaceMethods]> { - let results = (outs TransformParamTypeInterface:$out); - let assemblyFormat = "attr-dict `:` type($out)"; - let cppNamespace = "::mlir::transform"; -} - -def TestProduceNullValueOp - : Op, - DeclareOpInterfaceMethods]> { - let results = (outs TransformValueHandleTypeInterface:$out); - let assemblyFormat = "attr-dict `:` type($out)"; - let cppNamespace = "::mlir::transform"; -} - -def TestRequiredMemoryEffectsOp - : Op, - DeclareOpInterfaceMethods]> { - let arguments = (ins TransformHandleTypeInterface:$in, - UnitAttr:$has_operand_effect, - UnitAttr:$has_result_effect, - UnitAttr:$modifies_payload); - let results = (outs TransformHandleTypeInterface:$out); - let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)"; - let cppNamespace = "::mlir::transform"; -} - -#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD diff --git a/oneflow/ir/include/Transform/TestTransformStateExtension.h b/oneflow/ir/include/Transform/TestTransformStateExtension.h deleted file mode 100644 index 84173e9500c..00000000000 --- a/oneflow/ir/include/Transform/TestTransformStateExtension.h +++ /dev/null @@ -1,51 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -//===- TestTransformStateExtension.h - Test Utility -------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines an TransformState extension for the purpose of testing the -// relevant APIs. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H -#define MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H - -#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" - -namespace mlir { -namespace transform { -class TestTransformStateExtension : public transform::TransformState::Extension { - public: - TestTransformStateExtension(transform::TransformState& state, StringAttr message) - : Extension(state), message(message) {} - - StringRef getMessage() const { return message.getValue(); } - - LogicalResult updateMapping(Operation* previous, Operation* updated); - - private: - StringAttr message; -}; -} // namespace transform -} // namespace mlir - -#endif // MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H diff --git a/oneflow/ir/include/Transform/TestTransformDialectExtension.h b/oneflow/ir/include/Transform/TransformDialectExtension.h similarity index 66% rename from oneflow/ir/include/Transform/TestTransformDialectExtension.h rename to oneflow/ir/include/Transform/TransformDialectExtension.h index 1d794771cce..47dbd7dab65 100644 --- a/oneflow/ir/include/Transform/TestTransformDialectExtension.h +++ b/oneflow/ir/include/Transform/TransformDialectExtension.h @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -//===- TestTransformDialectExtension.h --------------------------*- C++ -*-===// +//===- TransformDialectExtension.h --------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -26,8 +26,8 @@ limitations under the License. // //===----------------------------------------------------------------------===// -#ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_H -#define MLIR_TESTTRANSFORMDIALECTEXTENSION_H +#ifndef ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_H_ +#define ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_H_ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" @@ -38,19 +38,26 @@ limitations under the License. namespace mlir { class DialectRegistry; -namespace transform { +namespace oneflow { +namespace transform_dialect { /// Registers the test extension to the Transform dialect. -void registerTestTransformDialectExtension(::mlir::DialectRegistry& registry); -void registerTestTransformDialectEraseSchedulePass(); -void registerTestTransformDialectInterpreterPass(); -} // namespace transform +void registerTransformDialectExtension(::mlir::DialectRegistry& registry); +void registerTransformDialectEraseSchedulePass(); +void registerTransformDialectInterpreterPass(); +struct ApplyPatternsOpPatterns { + bool canonicalization = false; +}; + +} // namespace transform_dialect + +} // namespace oneflow } // namespace mlir #define GET_TYPEDEF_CLASSES -#include "Transform/TestTransformDialectExtensionTypes.h.inc" +#include "Transform/TransformDialectExtensionTypes.h.inc" #define GET_OP_CLASSES -#include "Transform/TestTransformDialectExtension.h.inc" +#include "Transform/TransformDialectExtension.h.inc" -#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_H +#endif // ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_H_ diff --git a/oneflow/ir/include/Transform/TransformDialectExtension.td b/oneflow/ir/include/Transform/TransformDialectExtension.td new file mode 100644 index 00000000000..6d386cc316e --- /dev/null +++ b/oneflow/ir/include/Transform/TransformDialectExtension.td @@ -0,0 +1,71 @@ +#ifndef ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_TD_ +#define ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_TD_ + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Transform/IR/MatchInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Dialect/PDL/IR/PDLTypes.td" + +def ApplyPatternsOp : Op, + TransformEachOpTrait, + TransformOpInterface]> { + let description = [{ + Modified from iree project: https://github.com/openxla/iree + Greedily applies patterns as specified by its attributes. + + Must be applied to an op with trait IsolatedFromAbove since the + GreedyPatternRewriter asserts those. Internally, uses the tracking rewriter + to preserve handles to payload operations nested within operations + associated with `target`. Fails if tracking cannot find replacement for a + payload operation. This may become controllable with an attribute in the + future. + + Returns the IsolatedFromAbove op whose content it has modified for better + chaining APIs. + + The following additive attributes can be set, they add patterns in an + unspecified order: + - tiling_canonicalization: adds specific tiling-related canonicalization + patterns. + + + #### Return modes: + + This operation applies a set of patterns specified by attributes. To apply + these patterns, this operation must target an operation that is isolated + from above, otherwise the transform definitely fails. + + If the pattern application fails, or if the underlying listener fails to + capture op handles, the transformation definitely fails. + + Otherwise the transformation is successful. + + This operation does not consume the target handle and does not produce any + handle. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + UnitAttr:$canonicalization); + let results = (outs); + + let assemblyFormat = "$target attr-dict `:` functional-type($target, results)"; + let cppNamespace = "mlir::oneflow::transform_dialect"; + + let builders = [ + OpBuilder<(ins "Value":$target, + "const ApplyPatternsOpPatterns &":$patterns)> + ]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +#endif // ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_TD_ diff --git a/oneflow/ir/include/Transform/TransformStateExtension.h b/oneflow/ir/include/Transform/TransformStateExtension.h new file mode 100644 index 00000000000..2d9929895b3 --- /dev/null +++ b/oneflow/ir/include/Transform/TransformStateExtension.h @@ -0,0 +1,43 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_STATE_EXTENSION_H_ +#define ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_STATE_EXTENSION_H_ + +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" + +namespace mlir { +namespace oneflow { + +namespace transform_dialect { +class TestTransformStateExtension : public ::mlir::transform::TransformState::Extension { + public: + TestTransformStateExtension(::mlir::transform::TransformState& state, StringAttr message) + : Extension(state), message(message) {} + + StringRef getMessage() const { return message.getValue(); } + + LogicalResult updateMapping(Operation* previous, Operation* updated); + + private: + StringAttr message; +}; + +} // namespace transform_dialect +} // namespace oneflow +} // namespace mlir + +#endif // ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_STATE_EXTENSION_H_ diff --git a/oneflow/ir/lib/Transform/CMakeLists.txt b/oneflow/ir/lib/Transform/CMakeLists.txt index b03c0f4bb87..64bd0a4a8bc 100644 --- a/oneflow/ir/lib/Transform/CMakeLists.txt +++ b/oneflow/ir/lib/Transform/CMakeLists.txt @@ -1,11 +1,11 @@ add_mlir_library( - MLIRTestTransformDialect - TestTransformDialectExtension.cpp - TestTransformDialectInterpreter.cpp - TestTransformStateExtension.cpp + MLIROneFlowTransformDialect + TransformDialectExtension.cpp + TransformDialectInterpreter.cpp + TransformStateExtension.cpp EXCLUDE_FROM_LIBMLIR DEPENDS - MLIRTestTransformDialectExtensionIncGen + MLIROneFlowTransformDialectExtensionIncGen LINK_LIBS PUBLIC MLIRIR diff --git a/oneflow/ir/lib/Transform/TestTransformDialectExtension.cpp b/oneflow/ir/lib/Transform/TestTransformDialectExtension.cpp deleted file mode 100644 index 571309e21cd..00000000000 --- a/oneflow/ir/lib/Transform/TestTransformDialectExtension.cpp +++ /dev/null @@ -1,663 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -//===- TestTransformDialectExtension.cpp ----------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines an extension of the MLIR Transform dialect for testing -// purposes. -// -//===----------------------------------------------------------------------===// - -#include "Transform/TestTransformDialectExtension.h" -#include "Transform/TestTransformStateExtension.h" -#include "mlir/Dialect/PDL/IR/PDL.h" -#include "mlir/Dialect/Transform/IR/TransformDialect.h" -#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/IR/OpImplementation.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Compiler.h" -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; - -namespace { -/// Simple transform op defined outside of the dialect. Just emits a remark when -/// applied. This op is defined in C++ to test that C++ definitions also work -/// for op injection into the Transform dialect. -class TestTransformOp : public Op { - public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp) - - using Op::Op; - - static ArrayRef getAttributeNames() { return {}; } - - static constexpr llvm::StringLiteral getOperationName() { - return llvm::StringLiteral("transform.test_transform_op"); - } - - DiagnosedSilenceableFailure apply(transform::TransformResults& results, - transform::TransformState& state) { - InFlightDiagnostic remark = emitRemark() << "applying transformation"; - if (Attribute message = getMessage()) remark << " " << message; - - return DiagnosedSilenceableFailure::success(); - } - - Attribute getMessage() { return getOperation()->getAttr("message"); } - - static ParseResult parse(OpAsmParser& parser, OperationState& state) { - StringAttr message; - OptionalParseResult result = parser.parseOptionalAttribute(message); - if (!result.has_value()) return success(); - - if (result.value().succeeded()) state.addAttribute("message", message); - return result.value(); - } - - void print(OpAsmPrinter& printer) { - if (getMessage()) printer << " " << getMessage(); - } - - // No side effects. - void getEffects(SmallVectorImpl& effects) {} -}; - -/// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait -/// in cases where it is attached to ops that do not comply with the trait -/// requirements. This op cannot be defined in ODS because ODS generates strict -/// verifiers that overalp with those in the trait and run earlier. -class TestTransformUnrestrictedOpNoInterface - : public Op { - public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformUnrestrictedOpNoInterface) - - using Op::Op; - - static ArrayRef getAttributeNames() { return {}; } - - static constexpr llvm::StringLiteral getOperationName() { - return llvm::StringLiteral("transform.test_transform_unrestricted_op_no_interface"); - } - - DiagnosedSilenceableFailure apply(transform::TransformResults& results, - transform::TransformState& state) { - return DiagnosedSilenceableFailure::success(); - } - - // No side effects. - void getEffects(SmallVectorImpl& effects) {} -}; -} // namespace - -DiagnosedSilenceableFailure mlir::transform::TestProduceSelfHandleOrForwardOperandOp::apply( - transform::TransformResults& results, transform::TransformState& state) { - if (getOperation()->getNumOperands() != 0) { - results.set(getResult().cast(), getOperation()->getOperand(0).getDefiningOp()); - } else { - results.set(getResult().cast(), getOperation()); - } - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestProduceSelfHandleOrForwardOperandOp::getEffects( - SmallVectorImpl& effects) { - if (getOperand()) transform::onlyReadsHandle(getOperand(), effects); - transform::producesHandle(getRes(), effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestProduceValueHandleToSelfOperand::apply( - transform::TransformResults& results, transform::TransformState& state) { - results.setValues(getOut().cast(), getIn()); - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestProduceValueHandleToSelfOperand::getEffects( - SmallVectorImpl& effects) { - transform::onlyReadsHandle(getIn(), effects); - transform::producesHandle(getOut(), effects); - transform::onlyReadsPayload(effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestProduceValueHandleToResult::applyToOne( - Operation* target, transform::ApplyToEachResultList& results, - transform::TransformState& state) { - if (target->getNumResults() <= getNumber()) - return emitSilenceableError() << "payload has no result #" << getNumber(); - results.push_back(target->getResult(getNumber())); - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestProduceValueHandleToResult::getEffects( - SmallVectorImpl& effects) { - transform::onlyReadsHandle(getIn(), effects); - transform::producesHandle(getOut(), effects); - transform::onlyReadsPayload(effects); -} - -DiagnosedSilenceableFailure -mlir::transform::TestProduceValueHandleToArgumentOfParentBlock::applyToOne( - Operation* target, transform::ApplyToEachResultList& results, - transform::TransformState& state) { - if (!target->getBlock()) return emitSilenceableError() << "payload has no parent block"; - if (target->getBlock()->getNumArguments() <= getNumber()) - return emitSilenceableError() << "parent of the payload has no argument #" << getNumber(); - results.push_back(target->getBlock()->getArgument(getNumber())); - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestProduceValueHandleToArgumentOfParentBlock::getEffects( - SmallVectorImpl& effects) { - transform::onlyReadsHandle(getIn(), effects); - transform::producesHandle(getOut(), effects); - transform::onlyReadsPayload(effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestConsumeOperand::apply( - transform::TransformResults& results, transform::TransformState& state) { - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestConsumeOperand::getEffects( - SmallVectorImpl& effects) { - transform::consumesHandle(getOperand(), effects); - if (getSecondOperand()) transform::consumesHandle(getSecondOperand(), effects); - transform::modifiesPayload(effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestConsumeOperandOfOpKindOrFail::apply( - transform::TransformResults& results, transform::TransformState& state) { - ArrayRef payload = state.getPayloadOps(getOperand()); - assert(payload.size() == 1 && "expected a single target op"); - if (payload[0]->getName().getStringRef() != getOpKind()) { - return emitSilenceableError() - << "op expected the operand to be associated a payload op of kind " << getOpKind() - << " got " << payload[0]->getName().getStringRef(); - } - - emitRemark() << "succeeded"; - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestConsumeOperandOfOpKindOrFail::getEffects( - SmallVectorImpl& effects) { - transform::consumesHandle(getOperand(), effects); - transform::modifiesPayload(effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestSucceedIfOperandOfOpKind::matchOperation( - Operation* op, transform::TransformResults& results, transform::TransformState& state) { - if (op->getName().getStringRef() != getOpKind()) { - return emitSilenceableError() - << "op expected the operand to be associated with a payload op of " - "kind " - << getOpKind() << " got " << op->getName().getStringRef(); - } - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestSucceedIfOperandOfOpKind::getEffects( - SmallVectorImpl& effects) { - transform::onlyReadsHandle(getOperand(), effects); - transform::onlyReadsPayload(effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestPrintRemarkAtOperandOp::apply( - transform::TransformResults& results, transform::TransformState& state) { - ArrayRef payload = state.getPayloadOps(getOperand()); - for (Operation* op : payload) op->emitRemark() << getMessage(); - - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestPrintRemarkAtOperandOp::getEffects( - SmallVectorImpl& effects) { - transform::onlyReadsHandle(getOperand(), effects); - transform::onlyReadsPayload(effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestPrintRemarkAtOperandValue::apply( - transform::TransformResults& results, transform::TransformState& state) { - ArrayRef values = state.getPayloadValues(getIn()); - for (Value value : values) { - std::string note; - llvm::raw_string_ostream os(note); - if (auto arg = value.dyn_cast()) { - os << "a block argument #" << arg.getArgNumber() << " in block #" - << std::distance(arg.getOwner()->getParent()->begin(), arg.getOwner()->getIterator()) - << " in region #" << arg.getOwner()->getParent()->getRegionNumber(); - } else { - os << "an op result #" << value.cast().getResultNumber(); - } - InFlightDiagnostic diag = ::emitRemark(value.getLoc()) << getMessage(); - diag.attachNote() << "value handle points to " << os.str(); - } - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestPrintRemarkAtOperandValue::getEffects( - SmallVectorImpl& effects) { - transform::onlyReadsHandle(getIn(), effects); - transform::onlyReadsPayload(effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestAddTestExtensionOp::apply( - transform::TransformResults& results, transform::TransformState& state) { - state.addExtension(getMessageAttr()); - return DiagnosedSilenceableFailure::success(); -} - -DiagnosedSilenceableFailure mlir::transform::TestCheckIfTestExtensionPresentOp::apply( - transform::TransformResults& results, transform::TransformState& state) { - auto* extension = state.getExtension(); - if (!extension) { - emitRemark() << "extension absent"; - return DiagnosedSilenceableFailure::success(); - } - - InFlightDiagnostic diag = emitRemark() << "extension present, " << extension->getMessage(); - for (Operation* payload : state.getPayloadOps(getOperand())) { - diag.attachNote(payload->getLoc()) << "associated payload op"; -#ifndef NDEBUG - SmallVector handles; - assert(succeeded(state.getHandlesForPayloadOp(payload, handles))); - assert(llvm::is_contained(handles, getOperand()) - && "inconsistent mapping between transform IR handles and payload IR " - "operations"); -#endif // NDEBUG - } - - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestCheckIfTestExtensionPresentOp::getEffects( - SmallVectorImpl& effects) { - transform::onlyReadsHandle(getOperand(), effects); - transform::onlyReadsPayload(effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestRemapOperandPayloadToSelfOp::apply( - transform::TransformResults& results, transform::TransformState& state) { - auto* extension = state.getExtension(); - if (!extension) return emitDefiniteFailure("TestTransformStateExtension missing"); - - if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(), getOperation()))) - return DiagnosedSilenceableFailure::definiteFailure(); - if (getNumResults() > 0) results.set(getResult(0).cast(), getOperation()); - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestRemapOperandPayloadToSelfOp::getEffects( - SmallVectorImpl& effects) { - transform::onlyReadsHandle(getOperand(), effects); - transform::producesHandle(getOut(), effects); - transform::onlyReadsPayload(effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestRemoveTestExtensionOp::apply( - transform::TransformResults& results, transform::TransformState& state) { - state.removeExtension(); - return DiagnosedSilenceableFailure::success(); -} - -DiagnosedSilenceableFailure mlir::transform::TestReversePayloadOpsOp::apply( - transform::TransformResults& results, transform::TransformState& state) { - ArrayRef payloadOps = state.getPayloadOps(getTarget()); - auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); - results.set(getResult().cast(), reversedOps); - return DiagnosedSilenceableFailure::success(); -} - -DiagnosedSilenceableFailure mlir::transform::TestTransformOpWithRegions::apply( - transform::TransformResults& results, transform::TransformState& state) { - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestTransformOpWithRegions::getEffects( - SmallVectorImpl& effects) {} - -DiagnosedSilenceableFailure mlir::transform::TestBranchingTransformOpTerminator::apply( - transform::TransformResults& results, transform::TransformState& state) { - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestBranchingTransformOpTerminator::getEffects( - SmallVectorImpl& effects) {} - -DiagnosedSilenceableFailure mlir::transform::TestEmitRemarkAndEraseOperandOp::apply( - transform::TransformResults& results, transform::TransformState& state) { - emitRemark() << getRemark(); - for (Operation* op : state.getPayloadOps(getTarget())) op->erase(); - - if (getFailAfterErase()) return emitSilenceableError() << "silenceable error"; - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestEmitRemarkAndEraseOperandOp::getEffects( - SmallVectorImpl& effects) { - transform::consumesHandle(getTarget(), effects); - transform::modifiesPayload(effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestWrongNumberOfResultsOp::applyToOne( - Operation* target, transform::ApplyToEachResultList& results, - transform::TransformState& state) { - OperationState opState(target->getLoc(), "foo"); - results.push_back(OpBuilder(target).create(opState)); - return DiagnosedSilenceableFailure::success(); -} - -DiagnosedSilenceableFailure mlir::transform::TestWrongNumberOfMultiResultsOp::applyToOne( - Operation* target, transform::ApplyToEachResultList& results, - transform::TransformState& state) { - static int count = 0; - if (count++ == 0) { - OperationState opState(target->getLoc(), "foo"); - results.push_back(OpBuilder(target).create(opState)); - } - return DiagnosedSilenceableFailure::success(); -} - -DiagnosedSilenceableFailure mlir::transform::TestCorrectNumberOfMultiResultsOp::applyToOne( - Operation* target, transform::ApplyToEachResultList& results, - transform::TransformState& state) { - OperationState opState(target->getLoc(), "foo"); - results.push_back(OpBuilder(target).create(opState)); - results.push_back(OpBuilder(target).create(opState)); - return DiagnosedSilenceableFailure::success(); -} - -DiagnosedSilenceableFailure mlir::transform::TestMixedNullAndNonNullResultsOp::applyToOne( - Operation* target, transform::ApplyToEachResultList& results, - transform::TransformState& state) { - OperationState opState(target->getLoc(), "foo"); - results.push_back(nullptr); - results.push_back(OpBuilder(target).create(opState)); - return DiagnosedSilenceableFailure::success(); -} - -DiagnosedSilenceableFailure mlir::transform::TestMixedSuccessAndSilenceableOp::applyToOne( - Operation* target, transform::ApplyToEachResultList& results, - transform::TransformState& state) { - if (target->hasAttr("target_me")) return DiagnosedSilenceableFailure::success(); - return emitDefaultSilenceableFailure(target); -} - -DiagnosedSilenceableFailure mlir::transform::TestPrintNumberOfAssociatedPayloadIROps::apply( - transform::TransformResults& results, transform::TransformState& state) { - if (!getHandle()) emitRemark() << 0; - emitRemark() << state.getPayloadOps(getHandle()).size(); - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestPrintNumberOfAssociatedPayloadIROps::getEffects( - SmallVectorImpl& effects) { - transform::onlyReadsHandle(getHandle(), effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestCopyPayloadOp::apply( - transform::TransformResults& results, transform::TransformState& state) { - results.set(getCopy().cast(), state.getPayloadOps(getHandle())); - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestCopyPayloadOp::getEffects( - SmallVectorImpl& effects) { - transform::onlyReadsHandle(getHandle(), effects); - transform::producesHandle(getCopy(), effects); - transform::onlyReadsPayload(effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload( - Location loc, ArrayRef payload) const { - if (payload.empty()) return DiagnosedSilenceableFailure::success(); - - for (Operation* op : payload) { - if (op->getName().getDialectNamespace() != "test") { - return emitSilenceableError(loc) << "expected the payload operation to " - "belong to the 'test' dialect"; - } - } - - return DiagnosedSilenceableFailure::success(); -} - -DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload( - Location loc, ArrayRef payload) const { - for (Attribute attr : payload) { - auto integerAttr = attr.dyn_cast(); - if (integerAttr && integerAttr.getType().isSignlessInteger(32)) continue; - return emitSilenceableError(loc) << "expected the parameter to be a i32 integer attribute"; - } - - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestReportNumberOfTrackedHandlesNestedUnder::getEffects( - SmallVectorImpl& effects) { - transform::onlyReadsHandle(getTarget(), effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestReportNumberOfTrackedHandlesNestedUnder::apply( - transform::TransformResults& results, transform::TransformState& state) { - int64_t count = 0; - for (Operation* op : state.getPayloadOps(getTarget())) { - op->walk([&](Operation* nested) { - SmallVector handles; - (void)state.getHandlesForPayloadOp(nested, handles); - count += handles.size(); - }); - } - emitRemark() << count << " handles nested under"; - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestPrintParamOp::getEffects( - SmallVectorImpl& effects) { - transform::onlyReadsHandle(getParam(), effects); - if (getAnchor()) transform::onlyReadsHandle(getAnchor(), effects); - transform::onlyReadsPayload(effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestPrintParamOp::apply( - transform::TransformResults& results, transform::TransformState& state) { - std::string str; - llvm::raw_string_ostream os(str); - if (getMessage()) os << *getMessage() << " "; - llvm::interleaveComma(state.getParams(getParam()), os); - if (!getAnchor()) { - emitRemark() << os.str(); - return DiagnosedSilenceableFailure::success(); - } - for (Operation* payload : state.getPayloadOps(getAnchor())) - ::mlir::emitRemark(payload->getLoc()) << os.str(); - return DiagnosedSilenceableFailure::success(); -} - -DiagnosedSilenceableFailure mlir::transform::TestAddToParamOp::apply( - transform::TransformResults& results, transform::TransformState& state) { - SmallVector values(/*Size=*/1, /*Value=*/0); - if (Value param = getParam()) { - values = - llvm::to_vector(llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t { - return attr.cast().getValue().getLimitedValue(UINT32_MAX); - })); - } - - Builder builder(getContext()); - SmallVector result = - llvm::to_vector(llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute { - return builder.getI32IntegerAttr(value + getAddendum()); - })); - results.setParams(getResult().cast(), result); - return DiagnosedSilenceableFailure::success(); -} - -DiagnosedSilenceableFailure mlir::transform::TestProduceParamWithNumberOfTestOps::apply( - transform::TransformResults& results, transform::TransformState& state) { - Builder builder(getContext()); - SmallVector result = llvm::to_vector(llvm::map_range( - state.getPayloadOps(getHandle()), [&builder](Operation* payload) -> Attribute { - int32_t count = 0; - payload->walk([&count](Operation* op) { - if (op->getName().getDialectNamespace() == "test") ++count; - }); - return builder.getI32IntegerAttr(count); - })); - results.setParams(getResult().cast(), result); - return DiagnosedSilenceableFailure::success(); -} - -DiagnosedSilenceableFailure mlir::transform::TestProduceIntegerParamWithTypeOp::apply( - transform::TransformResults& results, transform::TransformState& state) { - Attribute zero = IntegerAttr::get(getType(), 0); - results.setParams(getResult().cast(), zero); - return DiagnosedSilenceableFailure::success(); -} - -LogicalResult mlir::transform::TestProduceIntegerParamWithTypeOp::verify() { - if (!getType().isa()) { return emitOpError() << "expects an integer type"; } - return success(); -} - -void mlir::transform::TestProduceTransformParamOrForwardOperandOp::getEffects( - SmallVectorImpl& effects) { - transform::onlyReadsHandle(getIn(), effects); - transform::producesHandle(getOut(), effects); - transform::producesHandle(getParam(), effects); -} - -DiagnosedSilenceableFailure -mlir::transform::TestProduceTransformParamOrForwardOperandOp::applyToOne( - Operation* target, ::transform::ApplyToEachResultList& results, - ::transform::TransformState& state) { - Builder builder(getContext()); - if (getFirstResultIsParam()) { - results.push_back(builder.getI64IntegerAttr(0)); - } else if (getFirstResultIsNull()) { - results.push_back(nullptr); - } else { - results.push_back(state.getPayloadOps(getIn()).front()); - } - - if (getSecondResultIsHandle()) { - results.push_back(state.getPayloadOps(getIn()).front()); - } else { - results.push_back(builder.getI64IntegerAttr(42)); - } - - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestProduceNullPayloadOp::getEffects( - SmallVectorImpl& effects) { - transform::producesHandle(getOut(), effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestProduceNullPayloadOp::apply( - transform::TransformResults& results, transform::TransformState& state) { - SmallVector null({nullptr}); - results.set(getOut().cast(), null); - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestProduceNullParamOp::getEffects( - SmallVectorImpl& effects) { - transform::producesHandle(getOut(), effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestProduceNullParamOp::apply( - transform::TransformResults& results, transform::TransformState& state) { - results.setParams(getOut().cast(), Attribute()); - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestProduceNullValueOp::getEffects( - SmallVectorImpl& effects) { - transform::producesHandle(getOut(), effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestProduceNullValueOp::apply( - transform::TransformResults& results, transform::TransformState& state) { - results.setValues(getOut().cast(), Value()); - return DiagnosedSilenceableFailure::success(); -} - -void mlir::transform::TestRequiredMemoryEffectsOp::getEffects( - SmallVectorImpl& effects) { - if (getHasOperandEffect()) transform::consumesHandle(getIn(), effects); - - if (getHasResultEffect()) - transform::producesHandle(getOut(), effects); - else - transform::onlyReadsHandle(getOut(), effects); - - if (getModifiesPayload()) transform::modifiesPayload(effects); -} - -DiagnosedSilenceableFailure mlir::transform::TestRequiredMemoryEffectsOp::apply( - transform::TransformResults& results, transform::TransformState& state) { - results.set(getOut().cast(), state.getPayloadOps(getIn())); - return DiagnosedSilenceableFailure::success(); -} - -namespace { -/// Test extension of the Transform dialect. Registers additional ops and -/// declares PDL as dependent dialect since the additional ops are using PDL -/// types for operands and results. -class TestTransformDialectExtension - : public transform::TransformDialectExtension { - public: - using Base::Base; - - void init() { - declareDependentDialect(); - registerTransformOps(); - registerTypes< -#define GET_TYPEDEF_LIST -#include "Transform/TestTransformDialectExtensionTypes.cpp.inc" - >(); - } -}; -} // namespace - -// These are automatically generated by ODS but are not used as the Transform -// dialect uses a different dispatch mechanism to support dialect extensions. -LLVM_ATTRIBUTE_UNUSED static OptionalParseResult generatedTypeParser(AsmParser& parser, - StringRef* mnemonic, - Type& value); -LLVM_ATTRIBUTE_UNUSED static LogicalResult generatedTypePrinter(Type def, AsmPrinter& printer); - -#define GET_TYPEDEF_CLASSES -#include "Transform/TestTransformDialectExtensionTypes.cpp.inc" - -#define GET_OP_CLASSES -#include "Transform/TestTransformDialectExtension.cpp.inc" - -void ::transform::registerTestTransformDialectExtension(DialectRegistry& registry) { - registry.addExtensions(); -} diff --git a/oneflow/ir/lib/Transform/TransformDialectExtension.cpp b/oneflow/ir/lib/Transform/TransformDialectExtension.cpp new file mode 100644 index 00000000000..7c83014e6a5 --- /dev/null +++ b/oneflow/ir/lib/Transform/TransformDialectExtension.cpp @@ -0,0 +1,132 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +//===- TransformDialectExtension.cpp ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines an extension of the MLIR Transform dialect for testing +// purposes. +// +//===----------------------------------------------------------------------===// + +#include "Transform/TransformDialectExtension.h" +#include "Transform/TransformStateExtension.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::oneflow; + +//===---------------------------------------------------------------------===// +// ApplyPatternsOp +//===---------------------------------------------------------------------===// +void transform_dialect::ApplyPatternsOp::build( + OpBuilder& builder, OperationState& result, Value target, + const transform_dialect::ApplyPatternsOpPatterns& patterns) { + result.addOperands(target); + + auto unitAttr = builder.getUnitAttr(); + +#define ADD_PATTERN(NAME, ATTR) \ + if (patterns.NAME) result.addAttribute(ApplyPatternsOp::ATTR(result.name), unitAttr); + + ADD_PATTERN(canonicalization, getCanonicalizationAttrName) +#undef ADD_PATTERN +} + +static void addAllRegisteredCanonicalizationPatterns(RewritePatternSet& patterns) { + MLIRContext* ctx = patterns.getContext(); + for (Dialect* dialect : ctx->getLoadedDialects()) dialect->getCanonicalizationPatterns(patterns); + for (RegisteredOperationName op : ctx->getRegisteredOperations()) + op.getCanonicalizationPatterns(patterns, ctx); +} + +DiagnosedSilenceableFailure transform_dialect::ApplyPatternsOp::applyToOne( + Operation* target, transform::ApplyToEachResultList& results, + transform::TransformState& state) { + if (!target->hasTrait()) { + return mlir::emitDefiniteFailure( + target, "applies only to isolated-from-above targets because it needs to apply " + "patterns greedily"); + } + MLIRContext* ctx = target->getContext(); + RewritePatternSet patterns(ctx); + if (getCanonicalization()) addAllRegisteredCanonicalizationPatterns(patterns); + SmallVector ops; + GreedyRewriteConfig config; + target->walk([&](Operation* nestedOp) { + if (target != nestedOp) ops.push_back(nestedOp); + }); + LogicalResult result = applyOpPatternsAndFold(ops, std::move(patterns), config); + if (failed(result)) { return DiagnosedSilenceableFailure::definiteFailure(); } + return DiagnosedSilenceableFailure::success(); +} + +void transform_dialect::ApplyPatternsOp::getEffects( + SmallVectorImpl& effects) { + transform::onlyReadsHandle(getTarget(), effects); + transform::modifiesPayload(effects); +} + +namespace { +class OneFlowTransformDialectExtension + : public transform::TransformDialectExtension { + public: + using Base::Base; + + void init() { + declareDependentDialect(); + registerTransformOps< +#define GET_OP_LIST +#include "Transform/TransformDialectExtension.cpp.inc" + >(); + registerTypes< +#define GET_TYPEDEF_LIST +#include "Transform/TransformDialectExtensionTypes.cpp.inc" + >(); + } +}; +} // namespace + +// These are automatically generated by ODS but are not used as the Transform +// dialect uses a different dispatch mechanism to support dialect extensions. +LLVM_ATTRIBUTE_UNUSED static OptionalParseResult generatedTypeParser(AsmParser& parser, + StringRef* mnemonic, + Type& value); +LLVM_ATTRIBUTE_UNUSED static LogicalResult generatedTypePrinter(Type def, AsmPrinter& printer); + +#define GET_TYPEDEF_CLASSES +#include "Transform/TransformDialectExtensionTypes.cpp.inc" + +#define GET_OP_CLASSES +#include "Transform/TransformDialectExtension.cpp.inc" + +void mlir::oneflow::transform_dialect::registerTransformDialectExtension( + DialectRegistry& registry) { + registry.addExtensions(); +} diff --git a/oneflow/ir/lib/Transform/TestTransformDialectInterpreter.cpp b/oneflow/ir/lib/Transform/TransformDialectInterpreter.cpp similarity index 89% rename from oneflow/ir/lib/Transform/TestTransformDialectInterpreter.cpp rename to oneflow/ir/lib/Transform/TransformDialectInterpreter.cpp index 4bd02daf94a..4a112d3a417 100644 --- a/oneflow/ir/lib/Transform/TestTransformDialectInterpreter.cpp +++ b/oneflow/ir/lib/Transform/TransformDialectInterpreter.cpp @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -//===- TestTransformDialectInterpreter.cpp --------------------------------===// +//===- TransformDialectInterpreter.cpp --------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -41,17 +41,17 @@ namespace { template class OpPassWrapper : public PassWrapper> {}; -class TestTransformDialectInterpreterPass - : public transform::TransformInterpreterPassBase { public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformDialectInterpreterPass) + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformDialectInterpreterPass) - TestTransformDialectInterpreterPass() = default; - TestTransformDialectInterpreterPass(const TestTransformDialectInterpreterPass& pass) + TransformDialectInterpreterPass() = default; + TransformDialectInterpreterPass(const TransformDialectInterpreterPass& pass) : TransformInterpreterPassBase(pass) {} - StringRef getArgument() const override { return "test-transform-dialect-interpreter"; } + StringRef getArgument() const override { return "oneflow-transform-dialect-interpreter"; } StringRef getDescription() const override { return "apply transform dialect operations one by one"; @@ -195,11 +195,11 @@ class TestTransformDialectInterpreterPass "definitions to be injected into the transform module.")}; }; -struct TestTransformDialectEraseSchedulePass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformDialectEraseSchedulePass) +struct TransformDialectEraseSchedulePass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformDialectEraseSchedulePass) - StringRef getArgument() const final { return "test-transform-dialect-erase-schedule"; } + StringRef getArgument() const final { return "oneflow-transform-dialect-erase-schedule"; } StringRef getDescription() const final { return "erase transform dialect schedule from the IR"; } @@ -216,14 +216,14 @@ struct TestTransformDialectEraseSchedulePass } // namespace namespace mlir { -namespace transform { +namespace transform_dialect { /// Registers the test pass for erasing transform dialect ops. -void registerTestTransformDialectEraseSchedulePass() { - PassRegistration reg; +void registerTransformDialectEraseSchedulePass() { + PassRegistration reg; } /// Registers the test pass for applying transform dialect ops. -void registerTestTransformDialectInterpreterPass() { - PassRegistration reg; +void registerTransformDialectInterpreterPass() { + PassRegistration reg; } -} // namespace transform +} // namespace transform_dialect } // namespace mlir diff --git a/oneflow/ir/lib/Transform/TestTransformStateExtension.cpp b/oneflow/ir/lib/Transform/TransformStateExtension.cpp similarity index 88% rename from oneflow/ir/lib/Transform/TestTransformStateExtension.cpp rename to oneflow/ir/lib/Transform/TransformStateExtension.cpp index 5b3fc77059f..8a130a8d66e 100644 --- a/oneflow/ir/lib/Transform/TestTransformStateExtension.cpp +++ b/oneflow/ir/lib/Transform/TransformStateExtension.cpp @@ -21,12 +21,12 @@ limitations under the License. // //===----------------------------------------------------------------------===// -#include "Transform/TestTransformStateExtension.h" +#include "Transform/TransformStateExtension.h" using namespace mlir; -LogicalResult transform::TestTransformStateExtension::updateMapping(Operation* previous, - Operation* updated) { +LogicalResult mlir::oneflow::transform_dialect::TestTransformStateExtension::updateMapping( + Operation* previous, Operation* updated) { // Update value handles. The new ops should have at least as many results as // the replacement op. Fewer results are acceptable, if those results are not // mapped to any handle. diff --git a/oneflow/ir/oneflow-opt/CMakeLists.txt b/oneflow/ir/oneflow-opt/CMakeLists.txt index 5746cb01bb8..e45998047b6 100644 --- a/oneflow/ir/oneflow-opt/CMakeLists.txt +++ b/oneflow/ir/oneflow-opt/CMakeLists.txt @@ -18,6 +18,6 @@ target_link_libraries( MLIROptLib $ MLIROneFlowExtension - MLIRTestTransformDialect) + MLIROneFlowTransformDialect) mlir_check_all_link_libraries(oneflow-opt) diff --git a/oneflow/ir/oneflow-opt/oneflow-opt.cpp b/oneflow/ir/oneflow-opt/oneflow-opt.cpp index c69ca3b6d07..187ad3cc14a 100644 --- a/oneflow/ir/oneflow-opt/oneflow-opt.cpp +++ b/oneflow/ir/oneflow-opt/oneflow-opt.cpp @@ -23,7 +23,7 @@ limitations under the License. #include "OneFlow/OKM/OKMDialect.h" #include "OneFlow/OKL/passes.h" #include "OneFlow/OKM/passes.h" -#include "Transform/TestTransformDialectExtension.h" +#include "Transform/TransformDialectExtension.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" @@ -38,13 +38,13 @@ int32_t main(int32_t argc, char** argv) { mlir::oneflow::registerPasses(); mlir::okm::registerPasses(); mlir::okl::registerPasses(); - mlir::transform::registerTestTransformDialectEraseSchedulePass(); - mlir::transform::registerTestTransformDialectInterpreterPass(); + mlir::oneflow::transform_dialect::registerTransformDialectEraseSchedulePass(); + mlir::oneflow::transform_dialect::registerTransformDialectInterpreterPass(); mlir::DialectRegistry registry; // Note: register all mlir dialect and their extension. mlir::registerAllDialects(registry); - mlir::transform::registerTestTransformDialectExtension(registry); + mlir::oneflow::transform_dialect::registerTransformDialectExtension(registry); registry.insert(); registry.insert(); registry.insert(); diff --git a/oneflow/ir/test/Transform/apply-patterns.mlir b/oneflow/ir/test/Transform/apply-patterns.mlir new file mode 100644 index 00000000000..4e8c0b783ae --- /dev/null +++ b/oneflow/ir/test/Transform/apply-patterns.mlir @@ -0,0 +1,20 @@ +// RUN: oneflow-opt %s -oneflow-transform-dialect-interpreter --split-input-file | FileCheck %s + +// CHECK-LABEL: @select_cmp_eq_select +// CHECK: return %arg1 +func.func @select_cmp_eq_select(%arg0: i64, %arg1: i64) -> i64 { + %0 = arith.cmpi eq, %arg0, %arg1 : i64 + %1 = arith.select %0, %arg0, %arg1 : i64 + return %1 : i64 +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 : !pdl.operation failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + transform.oneflow.apply_patterns %0 { canonicalization } : (!pdl.operation) -> () + } +} + +// ----- \ No newline at end of file From 6d1f0efb44f354f9c882f55d170f844a94dabfa8 Mon Sep 17 00:00:00 2001 From: yuhao <1171760467@qq.com> Date: Thu, 11 May 2023 15:47:41 +0800 Subject: [PATCH 2/7] fix --- oneflow/ir/include/Transform/TransformStateExtension.h | 4 ++-- oneflow/ir/lib/Transform/TransformStateExtension.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/oneflow/ir/include/Transform/TransformStateExtension.h b/oneflow/ir/include/Transform/TransformStateExtension.h index 2d9929895b3..c3d0589dbda 100644 --- a/oneflow/ir/include/Transform/TransformStateExtension.h +++ b/oneflow/ir/include/Transform/TransformStateExtension.h @@ -23,9 +23,9 @@ namespace mlir { namespace oneflow { namespace transform_dialect { -class TestTransformStateExtension : public ::mlir::transform::TransformState::Extension { +class TransformStateExtension : public ::mlir::transform::TransformState::Extension { public: - TestTransformStateExtension(::mlir::transform::TransformState& state, StringAttr message) + TransformStateExtension(::mlir::transform::TransformState& state, StringAttr message) : Extension(state), message(message) {} StringRef getMessage() const { return message.getValue(); } diff --git a/oneflow/ir/lib/Transform/TransformStateExtension.cpp b/oneflow/ir/lib/Transform/TransformStateExtension.cpp index 8a130a8d66e..b5f0fab36c8 100644 --- a/oneflow/ir/lib/Transform/TransformStateExtension.cpp +++ b/oneflow/ir/lib/Transform/TransformStateExtension.cpp @@ -25,7 +25,7 @@ limitations under the License. using namespace mlir; -LogicalResult mlir::oneflow::transform_dialect::TestTransformStateExtension::updateMapping( +LogicalResult mlir::oneflow::transform_dialect::TransformStateExtension::updateMapping( Operation* previous, Operation* updated) { // Update value handles. The new ops should have at least as many results as // the replacement op. Fewer results are acceptable, if those results are not From 56b698a8f41841e977f8001f3c885bd0c8ea91c1 Mon Sep 17 00:00:00 2001 From: yuhao <1171760467@qq.com> Date: Thu, 11 May 2023 15:52:41 +0800 Subject: [PATCH 3/7] fix --- oneflow/ir/include/Transform/TransformDialectExtension.td | 4 ++-- oneflow/ir/lib/Transform/TransformDialectInterpreter.cpp | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/oneflow/ir/include/Transform/TransformDialectExtension.td b/oneflow/ir/include/Transform/TransformDialectExtension.td index 6d386cc316e..3afd123a59c 100644 --- a/oneflow/ir/include/Transform/TransformDialectExtension.td +++ b/oneflow/ir/include/Transform/TransformDialectExtension.td @@ -29,8 +29,8 @@ def ApplyPatternsOp : Op reg; } } // namespace transform_dialect +} // namespace oneflow } // namespace mlir From 4911ab0e12070f6bdab2d665dde68607908e2d62 Mon Sep 17 00:00:00 2001 From: yuhao <1171760467@qq.com> Date: Thu, 11 May 2023 15:57:04 +0800 Subject: [PATCH 4/7] fix --- oneflow/ir/test/Transform/apply-patterns.mlir | 4 +--- oneflow/ir/test/Transform/softmax.mlir | 2 +- oneflow/ir/test/Transform/test-dialect.mlir | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/oneflow/ir/test/Transform/apply-patterns.mlir b/oneflow/ir/test/Transform/apply-patterns.mlir index 4e8c0b783ae..0d1a6180475 100644 --- a/oneflow/ir/test/Transform/apply-patterns.mlir +++ b/oneflow/ir/test/Transform/apply-patterns.mlir @@ -15,6 +15,4 @@ transform.with_pdl_patterns { %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation transform.oneflow.apply_patterns %0 { canonicalization } : (!pdl.operation) -> () } -} - -// ----- \ No newline at end of file +} \ No newline at end of file diff --git a/oneflow/ir/test/Transform/softmax.mlir b/oneflow/ir/test/Transform/softmax.mlir index a54cf4783f7..985ce298afe 100644 --- a/oneflow/ir/test/Transform/softmax.mlir +++ b/oneflow/ir/test/Transform/softmax.mlir @@ -1,4 +1,4 @@ -// RUN: oneflow-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/softmax-trait.mlir})" +// RUN: oneflow-opt %s --pass-pipeline="builtin.module(oneflow-transform-dialect-interpreter{transform-file-name=%p/softmax-trait.mlir})" !tmp_tensor_t = tensor<16x128xf32> !in_tensor_t = tensor<16x128x128xf32> !out_tensor_t = tensor<16x128x128xf32> diff --git a/oneflow/ir/test/Transform/test-dialect.mlir b/oneflow/ir/test/Transform/test-dialect.mlir index a7b96b2bed6..9de807a278b 100644 --- a/oneflow/ir/test/Transform/test-dialect.mlir +++ b/oneflow/ir/test/Transform/test-dialect.mlir @@ -1,4 +1,4 @@ -// RUN: oneflow-opt --test-transform-dialect-interpreter %s -split-input-file -verify-diagnostics | FileCheck %s +// RUN: oneflow-opt --oneflow-transform-dialect-interpreter %s -split-input-file -verify-diagnostics | FileCheck %s // Test One-Shot Bufferize. From 787401f14e4c9eec1549de42c56f180dec408c5a Mon Sep 17 00:00:00 2001 From: yuhao <1171760467@qq.com> Date: Thu, 11 May 2023 16:00:13 +0800 Subject: [PATCH 5/7] fix --- .../include/Transform/TransformDialectExtension.h | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/oneflow/ir/include/Transform/TransformDialectExtension.h b/oneflow/ir/include/Transform/TransformDialectExtension.h index 47dbd7dab65..107758e8fc1 100644 --- a/oneflow/ir/include/Transform/TransformDialectExtension.h +++ b/oneflow/ir/include/Transform/TransformDialectExtension.h @@ -13,19 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -//===- TransformDialectExtension.h --------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines an extension of the MLIR Transform dialect for testing -// purposes. -// -//===----------------------------------------------------------------------===// - #ifndef ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_H_ #define ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_H_ From 4fe35fd7791b243fe95cb689b6e19f8e392a4947 Mon Sep 17 00:00:00 2001 From: yuhao <1171760467@qq.com> Date: Thu, 11 May 2023 16:01:28 +0800 Subject: [PATCH 6/7] fix --- .../Transform/TransformDialectExtension.cpp | 28 ------------------- .../Transform/TransformDialectInterpreter.cpp | 28 ------------------- .../lib/Transform/TransformStateExtension.cpp | 23 --------------- 3 files changed, 79 deletions(-) diff --git a/oneflow/ir/lib/Transform/TransformDialectExtension.cpp b/oneflow/ir/lib/Transform/TransformDialectExtension.cpp index 7c83014e6a5..e5ab238b1a4 100644 --- a/oneflow/ir/lib/Transform/TransformDialectExtension.cpp +++ b/oneflow/ir/lib/Transform/TransformDialectExtension.cpp @@ -1,31 +1,3 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -//===- TransformDialectExtension.cpp ----------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines an extension of the MLIR Transform dialect for testing -// purposes. -// -//===----------------------------------------------------------------------===// - #include "Transform/TransformDialectExtension.h" #include "Transform/TransformStateExtension.h" #include "mlir/Dialect/PDL/IR/PDL.h" diff --git a/oneflow/ir/lib/Transform/TransformDialectInterpreter.cpp b/oneflow/ir/lib/Transform/TransformDialectInterpreter.cpp index e1c8ee0dd49..67058732944 100644 --- a/oneflow/ir/lib/Transform/TransformDialectInterpreter.cpp +++ b/oneflow/ir/lib/Transform/TransformDialectInterpreter.cpp @@ -1,31 +1,3 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -//===- TransformDialectInterpreter.cpp --------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines a test pass that interprets Transform dialect operations in -// the module. -// -//===----------------------------------------------------------------------===// - #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h" #include "mlir/IR/Builders.h" diff --git a/oneflow/ir/lib/Transform/TransformStateExtension.cpp b/oneflow/ir/lib/Transform/TransformStateExtension.cpp index b5f0fab36c8..4a29e26ebff 100644 --- a/oneflow/ir/lib/Transform/TransformStateExtension.cpp +++ b/oneflow/ir/lib/Transform/TransformStateExtension.cpp @@ -1,26 +1,3 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -//===- TestTransformStateExtension.cpp ------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - #include "Transform/TransformStateExtension.h" using namespace mlir; From 5a1b9ac80a9a5a578fb92325affd8c98a0e7c714 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Thu, 11 May 2023 08:03:09 +0000 Subject: [PATCH 7/7] auto format by CI --- oneflow/ir/include/Transform/CMakeLists.txt | 6 ++---- .../lib/Transform/TransformDialectExtension.cpp | 15 +++++++++++++++ .../lib/Transform/TransformDialectInterpreter.cpp | 15 +++++++++++++++ .../ir/lib/Transform/TransformStateExtension.cpp | 15 +++++++++++++++ 4 files changed, 47 insertions(+), 4 deletions(-) diff --git a/oneflow/ir/include/Transform/CMakeLists.txt b/oneflow/ir/include/Transform/CMakeLists.txt index f635a0f61de..2477f35aeb5 100644 --- a/oneflow/ir/include/Transform/CMakeLists.txt +++ b/oneflow/ir/include/Transform/CMakeLists.txt @@ -1,8 +1,6 @@ set(LLVM_TARGET_DEFINITIONS TransformDialectExtension.td) mlir_tablegen(TransformDialectExtension.h.inc -gen-op-decls) mlir_tablegen(TransformDialectExtension.cpp.inc -gen-op-defs) -mlir_tablegen(TransformDialectExtensionTypes.h.inc -gen-typedef-decls - -typedefs-dialect=transform) -mlir_tablegen(TransformDialectExtensionTypes.cpp.inc -gen-typedef-defs - -typedefs-dialect=transform) +mlir_tablegen(TransformDialectExtensionTypes.h.inc -gen-typedef-decls -typedefs-dialect=transform) +mlir_tablegen(TransformDialectExtensionTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=transform) add_public_tablegen_target(MLIROneFlowTransformDialectExtensionIncGen) diff --git a/oneflow/ir/lib/Transform/TransformDialectExtension.cpp b/oneflow/ir/lib/Transform/TransformDialectExtension.cpp index e5ab238b1a4..bff5ab524d6 100644 --- a/oneflow/ir/lib/Transform/TransformDialectExtension.cpp +++ b/oneflow/ir/lib/Transform/TransformDialectExtension.cpp @@ -1,3 +1,18 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ #include "Transform/TransformDialectExtension.h" #include "Transform/TransformStateExtension.h" #include "mlir/Dialect/PDL/IR/PDL.h" diff --git a/oneflow/ir/lib/Transform/TransformDialectInterpreter.cpp b/oneflow/ir/lib/Transform/TransformDialectInterpreter.cpp index 67058732944..c5051ea9eb4 100644 --- a/oneflow/ir/lib/Transform/TransformDialectInterpreter.cpp +++ b/oneflow/ir/lib/Transform/TransformDialectInterpreter.cpp @@ -1,3 +1,18 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h" #include "mlir/IR/Builders.h" diff --git a/oneflow/ir/lib/Transform/TransformStateExtension.cpp b/oneflow/ir/lib/Transform/TransformStateExtension.cpp index 4a29e26ebff..0f5e960e654 100644 --- a/oneflow/ir/lib/Transform/TransformStateExtension.cpp +++ b/oneflow/ir/lib/Transform/TransformStateExtension.cpp @@ -1,3 +1,18 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ #include "Transform/TransformStateExtension.h" using namespace mlir;