diff --git a/lib/Dialect/OM/Transforms/FreezePaths.cpp b/lib/Dialect/OM/Transforms/FreezePaths.cpp index 7f627d75a03d..2ac161d28afa 100644 --- a/lib/Dialect/OM/Transforms/FreezePaths.cpp +++ b/lib/Dialect/OM/Transforms/FreezePaths.cpp @@ -41,7 +41,7 @@ struct PathVisitor { LogicalResult process(BasePathCreateOp pathOp); LogicalResult process(PathCreateOp pathOp); LogicalResult process(EmptyPathOp pathOp); - LogicalResult process(ListCreateOp listCreateOp); + LogicalResult processListCreator(Operation *listCreateOp); LogicalResult process(ObjectFieldOp objectFieldOp); LogicalResult run(ModuleOp module); hw::InstanceGraph &instanceGraph; @@ -257,8 +257,8 @@ LogicalResult PathVisitor::process(EmptyPathOp path) { } /// Replace a ListCreateOp of path types with frozen path types. -LogicalResult PathVisitor::process(ListCreateOp listCreateOp) { - ListType listType = listCreateOp.getResult().getType(); +LogicalResult PathVisitor::processListCreator(Operation *listCreateOp) { + ListType listType = cast(listCreateOp->getResult(0).getType()); // Check if there are any path types in the list(s). if (!hasPathType(listType)) @@ -269,9 +269,10 @@ LogicalResult PathVisitor::process(ListCreateOp listCreateOp) { // Create a new op with the result type updated to replace path types. OpBuilder builder(listCreateOp); - auto newListCreateOp = builder.create( - listCreateOp.getLoc(), newListType, listCreateOp.getOperands()); - listCreateOp.replaceAllUsesWith(newListCreateOp.getResult()); + auto *newListCreateOp = builder.create( + listCreateOp->getLoc(), listCreateOp->getName().getIdentifier(), + listCreateOp->getOperands(), {newListType}); + listCreateOp->getResult(0).replaceAllUsesWith(newListCreateOp->getResult(0)); listCreateOp->erase(); return success(); } @@ -316,8 +317,8 @@ LogicalResult PathVisitor::run(ModuleOp module) { } else if (auto path = dyn_cast(op)) { if (failed(process(path))) return WalkResult::interrupt(); - } else if (auto listCreate = dyn_cast(op)) { - if (failed(process(listCreate))) + } else if (isa(op)) { + if (failed(processListCreator(op))) return WalkResult::interrupt(); } else if (auto objectField = dyn_cast(op)) { if (failed(process(objectField))) diff --git a/test/Dialect/OM/freeze-paths.mlir b/test/Dialect/OM/freeze-paths.mlir index a27a676b49b8..4370c9ffd989 100644 --- a/test/Dialect/OM/freeze-paths.mlir +++ b/test/Dialect/OM/freeze-paths.mlir @@ -73,8 +73,8 @@ om.class @PathTest(%basepath : !om.basepath, %path : !om.path) { } // CHECK-LABEL: om.class @ListCreateTest -// CHECK-SAME: -> (notpath: !om.list, basepath: !om.list, path: !om.list, nestedpath: !om.list>) -om.class @ListCreateTest(%notpath: i1, %basepath : !om.basepath, %path : !om.path) -> (notpath: !om.list, basepath: !om.list, path: !om.list, nestedpath: !om.list>) { +// CHECK-SAME: -> (notpath: !om.list, basepath: !om.list, path: !om.list, nestedpath: !om.list>, concatpath: !om.list>) +om.class @ListCreateTest(%notpath: i1, %basepath : !om.basepath, %path : !om.path) -> (notpath: !om.list, basepath: !om.list, path: !om.list, nestedpath: !om.list>, concatpath: !om.list>) { // CHECK: [[NOT_PATH_LIST:%.+]] = om.list_create %notpath : i1 %0 = om.list_create %notpath : i1 @@ -87,8 +87,11 @@ om.class @ListCreateTest(%notpath: i1, %basepath : !om.basepath, %path : !om.pat // CHECK: [[NESTED_PATH_LIST:%.+]] = om.list_create [[PATH_LIST]] : !om.list %3 = om.list_create %2 : !om.list - // CHECK: om.class.fields [[NOT_PATH_LIST]], [[BASE_PATH_LIST]], [[PATH_LIST]], [[NESTED_PATH_LIST]] : !om.list, !om.list, !om.list, !om.list> - om.class.fields %0, %1, %2, %3 : !om.list, !om.list, !om.list, !om.list> + // CHECK: [[CONCAT_PATH_LIST:%.+]] = om.list_concat [[NESTED_PATH_LIST]] : > + %4 = om.list_concat %3 : !om.list> + + // CHECK: om.class.fields [[NOT_PATH_LIST]], [[BASE_PATH_LIST]], [[PATH_LIST]], [[NESTED_PATH_LIST]], [[CONCAT_PATH_LIST]] : !om.list, !om.list, !om.list, !om.list> + om.class.fields %0, %1, %2, %3, %4 : !om.list, !om.list, !om.list, !om.list>, !om.list> } // CHECK-LABEL om.class @PathListClass(%pathList: !om.list) -> (pathList: !om.list