Skip to content

Commit

Permalink
Support JitGlobals on inline and outline flow dispatches. (iree-org#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
benvanik authored May 1, 2024
1 parent 652cb79 commit 463fed4
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 1 deletion.
46 changes: 46 additions & 0 deletions compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,45 @@ struct JitFunctionDesc {
llvm::SmallVector<ResultBinding> resultBindings;
};

// Copies each object-like symbol referenced inside |funcOp| from the source
// symbol table into the target one.
//
// A symbol name that already resolves in the target is treated as an earlier
// clone of the same source op and is left untouched, so an object shared by
// several functions is cloned only once. Objects are cloned wholesale
// (including their nested contents); removing nested symbols that turn out to
// be unused is left to a later symbol DCE.
//
// Returns failure, after emitting an error on the referencing op, when the
// root of a referenced symbol is not defined in |sourceSymbolTable|.
// References to symbols that are not object-like are ignored.
static LogicalResult cloneUsedObjects(FunctionOpInterface funcOp,
                                      SymbolTable &sourceSymbolTable,
                                      SymbolTable &targetSymbolTable,
                                      OpBuilder &moduleBuilder) {
  // When the symbol uses cannot be collected there is nothing to clone.
  auto symbolUses = SymbolTable::getSymbolUses(funcOp);
  if (symbolUses.has_value()) {
    for (auto symbolUse : *symbolUses) {
      // Resolve the root of the reference against the source module; a miss
      // is reported on the op holding the reference.
      auto rootNameAttr = symbolUse.getSymbolRef().getRootReference();
      auto *sourceOp = sourceSymbolTable.lookup(rootNameAttr);
      if (!sourceOp) {
        return symbolUse.getUser()->emitOpError()
               << "references undefined symbol " << symbolUse.getSymbolRef();
      }

      // Only object-like ops are cloned, and only when the target does not
      // already contain a symbol with that name. The target module is created
      // by us, so a same-named symbol can only be a previously cloned copy of
      // this very source op.
      const bool needsClone =
          sourceOp->hasTrait<OpTrait::IREE::Util::ObjectLike>() &&
          !targetSymbolTable.lookup(rootNameAttr);
      if (!needsClone)
        continue;

      // The object is isolated, so copying it in one piece is safe.
      auto *clonedOp = moduleBuilder.clone(*sourceOp);
      targetSymbolTable.insert(clonedOp);
    }
  }

  return success();
}

class ProgramBuilder {
public:
ProgramBuilder(ModuleOp sourceModuleOp,
Expand Down Expand Up @@ -437,6 +476,13 @@ class ProgramBuilder {
return failure();

OpBuilder moduleBuilder = OpBuilder::atBlockEnd(targetModuleOp.getBody());

// Find any object-like symbol references used by the initializer and
// clone them.
if (failed(cloneUsedObjects(initializerOp, sourceSymbolTable,
targetSymbolTable, moduleBuilder)))
return failure();

auto funcOp = moduleBuilder.create<IREE::Util::FuncOp>(
initializerOp.getLoc(), "jit_eval",
moduleBuilder.getFunctionType({}, {}));
Expand Down
82 changes: 82 additions & 0 deletions compiler/src/iree/compiler/ConstEval/test/jit_globals.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,85 @@ module @eval_op_with_no_inputs_currently_broken {
util.return
}
}

// -----

// Tests that dispatches to inlined dispatch regions are JITed.
// This calculates 42 + 4 + 4 to ensure we can handle primitive and tensor arg
// constants.
//
// The initializer below feeds an i8 scalar (42) and a splat tensor (4) into an
// inline dispatch region that computes input + input + scalar per element.
// The expectations require @hoisted to end up with the folded value 50 and the
// initializer to disappear.

// CHECK-LABEL: @dispatch_inline
module @dispatch_inline {
// Expected after evaluation: the global carries dense<50> as initial value.
// CHECK: util.global private @hoisted = dense<50> : tensor<4xi8>
util.global private @hoisted : tensor<4xi8>
// No initializer may remain in the output.
// CHECK-NOT: util.initializer
util.initializer {
// Primitive (scalar) constant operand.
%cst0 = arith.constant 42 : i8
// Tensor constant operand: four elements, each 4.
%cst1 = arith.constant dense<4> : tensor<4xi8>
// Workload %x is dimension 0 of %cst1.
%c0 = arith.constant 0 : index
%x = tensor.dim %cst1, %c0 : tensor<4xi8>
// Inline dispatch region: %arg0 is the scalar, %arg1 the readonly input
// tensor and %arg2 the writeonly output tensor.
%0 = flow.dispatch.workgroups[%x](%cst0, %cst1) : (i8, tensor<4xi8>) -> tensor<4xi8> =
(%arg0: i8, %arg1: !flow.dispatch.tensor<readonly:tensor<4xi8>>, %arg2: !flow.dispatch.tensor<writeonly:tensor<4xi8>>) {
%empty = tensor.empty() : tensor<4xi8>
%input = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[4], strides=[1] : !flow.dispatch.tensor<readonly:tensor<4xi8>> -> tensor<4xi8>
// %input is passed as both ins operands, so each element becomes
// input + input; the scalar %arg0 is then added on top.
%output = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%input, %input : tensor<4xi8>, tensor<4xi8>) outs(%empty : tensor<4xi8>) {
^bb0(%arg3: i8, %arg4: i8, %arg5: i8):
%addi_x2 = arith.addi %arg3, %arg4 : i8
%result = arith.addi %addi_x2, %arg0 : i8
linalg.yield %result : i8
} -> tensor<4xi8>
flow.dispatch.tensor.store %output, %arg2, offsets=[0], sizes=[4], strides=[1] : tensor<4xi8> -> !flow.dispatch.tensor<writeonly:tensor<4xi8>>
flow.return
}
// The dispatch result is what gets stored into the global.
util.global.store %0, @hoisted : tensor<4xi8>
util.return
}
}

// -----

// Tests that dispatches to executable functions are JITed by cloning referenced
// executables to the JIT module. This calculates 42 + 4 + 4 to ensure we can
// handle primitive and tensor arg constants.
//
// Unlike an inline dispatch region, the computation here lives in the
// separate executable @exe and the initializer only references it by symbol
// (@exe::@dispatch_fn).

// CHECK-LABEL: @dispatch_executable
module @dispatch_executable {
// Executable holding the dispatched function; the initializer refers to it by
// symbol.
flow.executable private @exe {
// Export entry: derives the workgroup count from the workload %arg0.
flow.executable.export public @dispatch_fn workgroups(%arg0: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
flow.return %x, %y, %z : index, index, index
}
builtin.module {
// %arg0 is the scalar, %arg1 the readonly input tensor and %arg2 the
// writeonly output tensor.
func.func public @dispatch_fn(%arg0: i8, %arg1: !flow.dispatch.tensor<readonly:tensor<4xi8>>, %arg2: !flow.dispatch.tensor<writeonly:tensor<4xi8>>) {
%empty = tensor.empty() : tensor<4xi8>
%input = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[4], strides=[1] : !flow.dispatch.tensor<readonly:tensor<4xi8>> -> tensor<4xi8>
// %input is passed as both ins operands, so each element becomes
// input + input; the scalar %arg0 is then added on top.
%output = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%input, %input : tensor<4xi8>, tensor<4xi8>) outs(%empty : tensor<4xi8>) {
^bb0(%arg3: i8, %arg4: i8, %arg5: i8):
%addi_x2 = arith.addi %arg3, %arg4 : i8
%result = arith.addi %addi_x2, %arg0 : i8
linalg.yield %result : i8
} -> tensor<4xi8>
flow.dispatch.tensor.store %output, %arg2, offsets=[0], sizes=[4], strides=[1] : tensor<4xi8> -> !flow.dispatch.tensor<writeonly:tensor<4xi8>>
return
}
}
}
// Expected after evaluation: the global carries dense<50> as initial value.
// CHECK: util.global private @hoisted = dense<50> : tensor<4xi8>
util.global private @hoisted : tensor<4xi8>
// No initializer may remain in the output.
// CHECK-NOT: util.initializer
util.initializer {
// Primitive (scalar) constant operand.
%cst0 = arith.constant 42 : i8
// Tensor constant operand: four elements, each 4.
%cst1 = arith.constant dense<4> : tensor<4xi8>
// Workload %x is dimension 0 of %cst1.
%c0 = arith.constant 0 : index
%x = tensor.dim %cst1, %c0 : tensor<4xi8>
// Dispatch to the exported function of @exe with the scalar and tensor
// constants as operands.
%0 = flow.dispatch @exe::@dispatch_fn[%x](%cst0, %cst1) : (i8, tensor<4xi8>) -> tensor<4xi8>
// The dispatch result is what gets stored into the global.
util.global.store %0, @hoisted : tensor<4xi8>
util.return
}
}
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,12 @@ bool DispatchWorkgroupsOp::canClosureContainOp(Operation *op) {
return false;
}

// Hoistable-op hook: dispatch.workgroups always reports itself as atomically
// hoistable.
bool DispatchWorkgroupsOp::isAtomicallyHoistableOp() {
  return true;
}

// Hoistable-op hook: an operand is reported as hoistable only when no result
// of this op is tied to it.
// NOTE(review): presumably tied operands are excluded because their storage is
// reused for the tied result - confirm against the interface docs.
bool DispatchWorkgroupsOp::isOperandHoistable(OpOperand *operand) {
  const unsigned operandIndex = operand->getOperandNumber();
  auto tiedResults = getOperandTiedResults(operandIndex);
  return tiedResults.empty();
}

// Refines the tensor access from what is declared on |type| based on actual
// usage. We expect that the access was set correctly to begin with but today
// we sometimes specify things too wide.
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ def FLOW_DispatchWorkgroupsOp : FLOW_PureOp<"dispatch.workgroups", [
IsolatedFromAbove,
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<Util_ClosureOpInterface>,
DeclareOpInterfaceMethods<Util_HoistableOpInterface, [
"isAtomicallyHoistableOp",
"isOperandHoistable",
]>,
DeclareOpInterfaceMethods<Util_TiedOpInterface, [
"getTiedOperandsIndexAndLength",
]>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ static void populateEscapingProducers(Operation *parentOp,
if (itOp == parentOp) {
info.producers.insert(itOp->getOperands().begin(),
itOp->getOperands().end());
return;
return itOp->hasTrait<OpTrait::IsIsolatedFromAbove>()
? WalkResult::interrupt()
: WalkResult::advance();
}

// For nested operations, only consider that they escape if they are
Expand All @@ -39,6 +41,8 @@ static void populateEscapingProducers(Operation *parentOp,
info.producers.insert(operand);
}
}

return WalkResult::advance();
});
}

Expand Down

0 comments on commit 463fed4

Please sign in to comment.