diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 8ccbd9cfb1595..6ac0912f6f706 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -791,7 +791,8 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, RecursiveMemoryEffects, DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"scf::YieldOp">]> { diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 94e7dd4a0bf44..20a7b283c938d 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -2936,6 +2936,30 @@ void ParallelOp::print(OpAsmPrinter &p) { SmallVector ParallelOp::getLoopRegions() { return {&getRegion()}; } +std::optional ParallelOp::getSingleInductionVar() { + if (getNumLoops() != 1) + return std::nullopt; + return getBody()->getArgument(0); +} + +std::optional ParallelOp::getSingleLowerBound() { + if (getNumLoops() != 1) + return std::nullopt; + return getLowerBound()[0]; +} + +std::optional ParallelOp::getSingleUpperBound() { + if (getNumLoops() != 1) + return std::nullopt; + return getUpperBound()[0]; +} + +std::optional ParallelOp::getSingleStep() { + if (getNumLoops() != 1) + return std::nullopt; + return getStep()[0]; +} + ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) { auto ivArg = llvm::dyn_cast(val); if (!ivArg) diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt index 2d2835c64b984..fbb73e8f499a3 100644 --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -9,6 +9,7 @@ target_link_libraries(MLIRDialectTests add_subdirectory(Index) add_subdirectory(LLVMIR) add_subdirectory(MemRef) +add_subdirectory(SCF) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) add_subdirectory(Transform) diff --git a/mlir/unittests/Dialect/SCF/CMakeLists.txt b/mlir/unittests/Dialect/SCF/CMakeLists.txt new file mode 100644 index 0000000000000..4d23392af1f88 --- /dev/null +++ b/mlir/unittests/Dialect/SCF/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_unittest(MLIRSCFTests + LoopLikeSCFOpsTest.cpp +) +target_link_libraries(MLIRSCFTests + PRIVATE + MLIRIR + MLIRSCFDialect +) diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp new file mode 100644 index 0000000000000..f75b84f12b6f1 --- /dev/null +++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp @@ -0,0 +1,89 @@ +//===- LoopLikeSCFOpsTest.cpp - SCF LoopLikeOpInterface Tests -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::scf; + +//===----------------------------------------------------------------------===// +// Test Fixture +//===----------------------------------------------------------------------===// + +class SCFLoopLikeTest : public ::testing::Test { +protected: + SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) { + context.loadDialect(); + } + + void checkUnidimensional(LoopLikeOpInterface loopLikeOp) { + std::optional maybeLb = loopLikeOp.getSingleLowerBound(); + EXPECT_TRUE(maybeLb.has_value()); + std::optional maybeUb = loopLikeOp.getSingleUpperBound(); + EXPECT_TRUE(maybeUb.has_value()); + std::optional maybeStep = loopLikeOp.getSingleStep(); + EXPECT_TRUE(maybeStep.has_value()); + std::optional maybeIndVar = + loopLikeOp.getSingleInductionVar(); + EXPECT_TRUE(maybeIndVar.has_value()); + } + + void checkMultidimensional(LoopLikeOpInterface loopLikeOp) { + std::optional maybeLb = loopLikeOp.getSingleLowerBound(); + EXPECT_FALSE(maybeLb.has_value()); + std::optional maybeUb = loopLikeOp.getSingleUpperBound(); + EXPECT_FALSE(maybeUb.has_value()); + std::optional maybeStep = loopLikeOp.getSingleStep(); + EXPECT_FALSE(maybeStep.has_value()); + std::optional maybeIndVar = + loopLikeOp.getSingleInductionVar(); + EXPECT_FALSE(maybeIndVar.has_value()); + } + + MLIRContext context; + OpBuilder b; + Location loc; +}; + +TEST_F(SCFLoopLikeTest, queryUnidimensionalLooplikes) { + Value lb = b.create(loc, 0); + Value ub = b.create(loc, 10); + Value step = b.create(loc, 2); + + auto forOp = b.create(loc, lb, ub, step); + checkUnidimensional(forOp); + + auto forallOp = b.create( + loc, ArrayRef(lb), ArrayRef(ub), + ArrayRef(step), ValueRange(), std::nullopt); + checkUnidimensional(forallOp); + + auto parallelOp = b.create( + loc, ValueRange(lb), ValueRange(ub), ValueRange(step), ValueRange()); + checkUnidimensional(parallelOp); +} + +TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) { + Value lb = b.create(loc, 0); + Value ub = b.create(loc, 10); + Value step = b.create(loc, 2); + + auto forallOp = b.create( + loc, ArrayRef({lb, lb}), ArrayRef({ub, ub}), + ArrayRef({step, step}), ValueRange(), std::nullopt); + checkMultidimensional(forallOp); + + auto parallelOp = + b.create(loc, ValueRange({lb, lb}), ValueRange({ub, ub}), + ValueRange({step, step}), ValueRange()); + checkMultidimensional(parallelOp); +}