Skip to content

Commit

Permalink
[OM] Generalize handling for list creation ops in FreezePaths. (#7965)
Browse files Browse the repository at this point in the history
We previously handled the ListCreateOp specifically, because lists of
paths need to become lists of frozen paths. But there are other list
creation ops that need to be considered and handled similarly, like
the recently added ListConcatOp.

To handle this, the process method for ListCreateOp was updated to
work on any generic Operation * that returns a ListType. The
typeswitch that dispatches to the process methods was updated to use
this generic processor for both ListCreateOp and ListConcatOp. I
thought about writing a generic check instead of listing out the
supported Operation classes, but that seems like a fragile tradeoff
that might not be worth the cost relative to keeping this list up to
date.
  • Loading branch information
mikeurbach authored Dec 10, 2024
1 parent 2aaf978 commit 116507a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
17 changes: 9 additions & 8 deletions lib/Dialect/OM/Transforms/FreezePaths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct PathVisitor {
LogicalResult process(BasePathCreateOp pathOp);
LogicalResult process(PathCreateOp pathOp);
LogicalResult process(EmptyPathOp pathOp);
LogicalResult process(ListCreateOp listCreateOp);
LogicalResult processListCreator(Operation *listCreateOp);
LogicalResult process(ObjectFieldOp objectFieldOp);
LogicalResult run(ModuleOp module);
hw::InstanceGraph &instanceGraph;
Expand Down Expand Up @@ -257,8 +257,8 @@ LogicalResult PathVisitor::process(EmptyPathOp path) {
}

/// Replace a ListCreateOp of path types with frozen path types.
LogicalResult PathVisitor::process(ListCreateOp listCreateOp) {
ListType listType = listCreateOp.getResult().getType();
LogicalResult PathVisitor::processListCreator(Operation *listCreateOp) {
ListType listType = cast<ListType>(listCreateOp->getResult(0).getType());

// Check if there are any path types in the list(s).
if (!hasPathType(listType))
Expand All @@ -269,9 +269,10 @@ LogicalResult PathVisitor::process(ListCreateOp listCreateOp) {

// Create a new op with the result type updated to replace path types.
OpBuilder builder(listCreateOp);
auto newListCreateOp = builder.create<ListCreateOp>(
listCreateOp.getLoc(), newListType, listCreateOp.getOperands());
listCreateOp.replaceAllUsesWith(newListCreateOp.getResult());
auto *newListCreateOp = builder.create(
listCreateOp->getLoc(), listCreateOp->getName().getIdentifier(),
listCreateOp->getOperands(), {newListType});
listCreateOp->getResult(0).replaceAllUsesWith(newListCreateOp->getResult(0));
listCreateOp->erase();
return success();
}
Expand Down Expand Up @@ -316,8 +317,8 @@ LogicalResult PathVisitor::run(ModuleOp module) {
} else if (auto path = dyn_cast<EmptyPathOp>(op)) {
if (failed(process(path)))
return WalkResult::interrupt();
} else if (auto listCreate = dyn_cast<ListCreateOp>(op)) {
if (failed(process(listCreate)))
} else if (isa<ListCreateOp, ListConcatOp>(op)) {
if (failed(processListCreator(op)))
return WalkResult::interrupt();
} else if (auto objectField = dyn_cast<ObjectFieldOp>(op)) {
if (failed(process(objectField)))
Expand Down
11 changes: 7 additions & 4 deletions test/Dialect/OM/freeze-paths.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ om.class @PathTest(%basepath : !om.basepath, %path : !om.path) {
}

// CHECK-LABEL: om.class @ListCreateTest
// CHECK-SAME: -> (notpath: !om.list<i1>, basepath: !om.list<!om.frozenbasepath>, path: !om.list<!om.frozenpath>, nestedpath: !om.list<!om.list<!om.frozenpath>>)
om.class @ListCreateTest(%notpath: i1, %basepath : !om.basepath, %path : !om.path) -> (notpath: !om.list<i1>, basepath: !om.list<!om.basepath>, path: !om.list<!om.path>, nestedpath: !om.list<!om.list<!om.path>>) {
// CHECK-SAME: -> (notpath: !om.list<i1>, basepath: !om.list<!om.frozenbasepath>, path: !om.list<!om.frozenpath>, nestedpath: !om.list<!om.list<!om.frozenpath>>, concatpath: !om.list<!om.list<!om.frozenpath>>)
om.class @ListCreateTest(%notpath: i1, %basepath : !om.basepath, %path : !om.path) -> (notpath: !om.list<i1>, basepath: !om.list<!om.basepath>, path: !om.list<!om.path>, nestedpath: !om.list<!om.list<!om.path>>, concatpath: !om.list<!om.list<!om.path>>) {
// CHECK: [[NOT_PATH_LIST:%.+]] = om.list_create %notpath : i1
%0 = om.list_create %notpath : i1

Expand All @@ -87,8 +87,11 @@ om.class @ListCreateTest(%notpath: i1, %basepath : !om.basepath, %path : !om.pat
// CHECK: [[NESTED_PATH_LIST:%.+]] = om.list_create [[PATH_LIST]] : !om.list<!om.frozenpath>
%3 = om.list_create %2 : !om.list<!om.path>

// CHECK: om.class.fields [[NOT_PATH_LIST]], [[BASE_PATH_LIST]], [[PATH_LIST]], [[NESTED_PATH_LIST]] : !om.list<i1>, !om.list<!om.frozenbasepath>, !om.list<!om.frozenpath>, !om.list<!om.list<!om.frozenpath>>
om.class.fields %0, %1, %2, %3 : !om.list<i1>, !om.list<!om.basepath>, !om.list<!om.path>, !om.list<!om.list<!om.path>>
// CHECK: [[CONCAT_PATH_LIST:%.+]] = om.list_concat [[NESTED_PATH_LIST]] : <!om.list<!om.frozenpath>>
%4 = om.list_concat %3 : !om.list<!om.list<!om.path>>

// CHECK: om.class.fields [[NOT_PATH_LIST]], [[BASE_PATH_LIST]], [[PATH_LIST]], [[NESTED_PATH_LIST]], [[CONCAT_PATH_LIST]] : !om.list<i1>, !om.list<!om.frozenbasepath>, !om.list<!om.frozenpath>, !om.list<!om.list<!om.frozenpath>>
om.class.fields %0, %1, %2, %3, %4 : !om.list<i1>, !om.list<!om.basepath>, !om.list<!om.path>, !om.list<!om.list<!om.path>>, !om.list<!om.list<!om.path>>
}

// CHECK-LABEL om.class @PathListClass(%pathList: !om.list<!om.frozenpath>) -> (pathList: !om.list<!om.path>
Expand Down

0 comments on commit 116507a

Please sign in to comment.