From 116507a01d61b120f46dcba5644634ebbad103aa Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Tue, 10 Dec 2024 15:58:44 -0700 Subject: [PATCH] [OM] Generalize handling for list creation ops in FreezePaths. (#7965) We previously handled the ListCreateOp specifically, because lists of paths need to become lists of frozen paths. But there are other list creation ops that need to be considered and handled similarly, like the recently added ListConcatOp. To handle this, the process method for ListCreateOp was updated to work on any generic Operation * that returns a ListType. The typeswitch that dispatches to the process methods was updated to use this generic processor for both ListCreateOp and ListConcatOp. I thought about writing a generic check instead of listing out the supported Operation classes, but that seems like a fragile tradeoff that might not be worth the cost relative to keeping this list up to date. --- lib/Dialect/OM/Transforms/FreezePaths.cpp | 17 +++++++++-------- test/Dialect/OM/freeze-paths.mlir | 11 +++++++---- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/lib/Dialect/OM/Transforms/FreezePaths.cpp b/lib/Dialect/OM/Transforms/FreezePaths.cpp index 7f627d75a03d..2ac161d28afa 100644 --- a/lib/Dialect/OM/Transforms/FreezePaths.cpp +++ b/lib/Dialect/OM/Transforms/FreezePaths.cpp @@ -41,7 +41,7 @@ struct PathVisitor { LogicalResult process(BasePathCreateOp pathOp); LogicalResult process(PathCreateOp pathOp); LogicalResult process(EmptyPathOp pathOp); - LogicalResult process(ListCreateOp listCreateOp); + LogicalResult processListCreator(Operation *listCreateOp); LogicalResult process(ObjectFieldOp objectFieldOp); LogicalResult run(ModuleOp module); hw::InstanceGraph &instanceGraph; @@ -257,8 +257,8 @@ LogicalResult PathVisitor::process(EmptyPathOp path) { } /// Replace a ListCreateOp of path types with frozen path types. -LogicalResult PathVisitor::process(ListCreateOp listCreateOp) { - ListType listType = listCreateOp.getResult().getType(); +LogicalResult PathVisitor::processListCreator(Operation *listCreateOp) { + ListType listType = cast(listCreateOp->getResult(0).getType()); // Check if there are any path types in the list(s). if (!hasPathType(listType)) @@ -269,9 +269,10 @@ LogicalResult PathVisitor::process(ListCreateOp listCreateOp) { // Create a new op with the result type updated to replace path types. OpBuilder builder(listCreateOp); - auto newListCreateOp = builder.create( - listCreateOp.getLoc(), newListType, listCreateOp.getOperands()); - listCreateOp.replaceAllUsesWith(newListCreateOp.getResult()); + auto *newListCreateOp = builder.create( + listCreateOp->getLoc(), listCreateOp->getName().getIdentifier(), + listCreateOp->getOperands(), {newListType}); + listCreateOp->getResult(0).replaceAllUsesWith(newListCreateOp->getResult(0)); listCreateOp->erase(); return success(); } @@ -316,8 +317,8 @@ LogicalResult PathVisitor::run(ModuleOp module) { } else if (auto path = dyn_cast(op)) { if (failed(process(path))) return WalkResult::interrupt(); - } else if (auto listCreate = dyn_cast(op)) { - if (failed(process(listCreate))) + } else if (isa(op)) { + if (failed(processListCreator(op))) return WalkResult::interrupt(); } else if (auto objectField = dyn_cast(op)) { if (failed(process(objectField))) diff --git a/test/Dialect/OM/freeze-paths.mlir b/test/Dialect/OM/freeze-paths.mlir index a27a676b49b8..4370c9ffd989 100644 --- a/test/Dialect/OM/freeze-paths.mlir +++ b/test/Dialect/OM/freeze-paths.mlir @@ -73,8 +73,8 @@ om.class @PathTest(%basepath : !om.basepath, %path : !om.path) { } // CHECK-LABEL: om.class @ListCreateTest -// CHECK-SAME: -> (notpath: !om.list, basepath: !om.list, path: !om.list, nestedpath: !om.list>) -om.class @ListCreateTest(%notpath: i1, %basepath : !om.basepath, %path : !om.path) -> (notpath: !om.list, basepath: !om.list, path: !om.list, nestedpath: !om.list>) { +// CHECK-SAME: -> (notpath: !om.list, basepath: !om.list, path: !om.list, nestedpath: !om.list>, concatpath: !om.list>) +om.class @ListCreateTest(%notpath: i1, %basepath : !om.basepath, %path : !om.path) -> (notpath: !om.list, basepath: !om.list, path: !om.list, nestedpath: !om.list>, concatpath: !om.list>) { // CHECK: [[NOT_PATH_LIST:%.+]] = om.list_create %notpath : i1 %0 = om.list_create %notpath : i1 @@ -87,8 +87,11 @@ om.class @ListCreateTest(%notpath: i1, %basepath : !om.basepath, %path : !om.pat // CHECK: [[NESTED_PATH_LIST:%.+]] = om.list_create [[PATH_LIST]] : !om.list %3 = om.list_create %2 : !om.list - // CHECK: om.class.fields [[NOT_PATH_LIST]], [[BASE_PATH_LIST]], [[PATH_LIST]], [[NESTED_PATH_LIST]] : !om.list, !om.list, !om.list, !om.list> - om.class.fields %0, %1, %2, %3 : !om.list, !om.list, !om.list, !om.list> + // CHECK: [[CONCAT_PATH_LIST:%.+]] = om.list_concat [[NESTED_PATH_LIST]] : > + %4 = om.list_concat %3 : !om.list> + + // CHECK: om.class.fields [[NOT_PATH_LIST]], [[BASE_PATH_LIST]], [[PATH_LIST]], [[NESTED_PATH_LIST]], [[CONCAT_PATH_LIST]] : !om.list, !om.list, !om.list, !om.list> + om.class.fields %0, %1, %2, %3, %4 : !om.list, !om.list, !om.list, !om.list>, !om.list> } // CHECK-LABEL om.class @PathListClass(%pathList: !om.list) -> (pathList: !om.list