diff --git a/tests/dialects/test_gpu.py b/tests/dialects/test_gpu.py index 597540cba8..8849f1f5f6 100644 --- a/tests/dialects/test_gpu.py +++ b/tests/dialects/test_gpu.py @@ -30,6 +30,7 @@ SubgroupSizeOp, TerminatorOp, ThreadIdOp, + WaitOp, YieldOp, ) from xdsl.ir import Block, Operation, Region, SSAValue @@ -400,6 +401,21 @@ def test_terminator(): assert isinstance(terminator, TerminatorOp) +def test_wait(): + waitOp = WaitOp() + + assert isinstance(waitOp, WaitOp) + assert waitOp.asyncToken is not None + assert isinstance(waitOp.asyncToken.type, AsyncTokenType) + + waitOp1 = WaitOp() + + waitOpWithDep = WaitOp([waitOp, waitOp1]) + assert waitOpWithDep.asyncToken is not None + assert waitOpWithDep.asyncDependencies[0] is waitOp.asyncToken + assert waitOpWithDep.asyncDependencies[1] is waitOp1.asyncToken + + def test_yield(): operands: list[SSAValue | Operation] = [ o diff --git a/tests/filecheck/dialects/gpu/ops.mlir b/tests/filecheck/dialects/gpu/ops.mlir index 3f940f5702..e563f6bf03 100644 --- a/tests/filecheck/dialects/gpu/ops.mlir +++ b/tests/filecheck/dialects/gpu/ops.mlir @@ -11,6 +11,8 @@ builtin.module attributes {"gpu.container_module"} { "gpu.host_register"(%unranked) : (memref<*xi32>) -> () "gpu.host_unregister"(%unranked) : (memref<*xi32>) -> () + %wait_token = "gpu.wait"() : () -> !gpu.async.token + %threadidx = "gpu.thread_id"() {"dimension" = #gpu} : () -> index %threadidy = "gpu.thread_id"() {"dimension" = #gpu} : () -> index %threadidz = "gpu.thread_id"() {"dimension" = #gpu} : () -> index @@ -89,6 +91,8 @@ builtin.module attributes {"gpu.container_module"} { // CHECK-NEXT: "gpu.host_register"(%{{.*}}) : (memref<*xi32>) -> () // CHECK-NEXT: "gpu.host_unregister"(%{{.*}}) : (memref<*xi32>) -> () + // CHECK-NEXT: %{{.*}} = "gpu.wait"() : () -> !gpu.async.token + // CHECK-NEXT: %{{.*}} = "gpu.thread_id"() <{"dimension" = #gpu}> : () -> index // CHECK-NEXT: %{{.*}} = "gpu.thread_id"() <{"dimension" = #gpu}> : () -> index // CHECK-NEXT: %{{.*}} = "gpu.thread_id"() <{"dimension" = #gpu}> : () -> index diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/gpu/ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/gpu/ops.mlir index c47f12fbe5..02f788f3cc 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/dialects/gpu/ops.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/gpu/ops.mlir @@ -11,6 +11,8 @@ "gpu.host_register"(%unranked) : (memref<*xi32>) -> () "gpu.host_unregister"(%unranked) : (memref<*xi32>) -> () + %wait_token = "gpu.wait"() : () -> !gpu.async.token + %threadidx = "gpu.thread_id"() {"dimension" = #gpu} : () -> index %threadidy = "gpu.thread_id"() {"dimension" = #gpu} : () -> index %threadidz = "gpu.thread_id"() {"dimension" = #gpu} : () -> index @@ -88,6 +90,8 @@ // CHECK-NEXT: "gpu.host_register"(%{{.*}}) : (memref<*xi32>) -> () // CHECK-NEXT: "gpu.host_unregister"(%{{.*}}) : (memref<*xi32>) -> () + // CHECK-NEXT: %{{.*}} = "gpu.wait"() : () -> !gpu.async.token + // CHECK-NEXT: %{{.*}} = "gpu.thread_id"() <{"dimension" = #gpu}> : () -> index // CHECK-NEXT: %{{.*}} = "gpu.thread_id"() <{"dimension" = #gpu}> : () -> index // CHECK-NEXT: %{{.*}} = "gpu.thread_id"() <{"dimension" = #gpu}> : () -> index diff --git a/xdsl/dialects/gpu.py b/xdsl/dialects/gpu.py index 50834ee35c..67d1519540 100644 --- a/xdsl/dialects/gpu.py +++ b/xdsl/dialects/gpu.py @@ -730,6 +730,22 @@ def __init__(self, dim: DimensionAttr): super().__init__(result_types=[IndexType()], properties={"dimension": dim}) +@irdl_op_definition +class WaitOp(IRDLOperation): + name = "gpu.wait" + asyncDependencies: VarOperand = var_operand_def(AsyncTokenType) + asyncToken: OptOpResult = opt_result_def(AsyncTokenType) + + def __init__( + self, + async_dependencies: Sequence[SSAValue | Operation] | None = None, + ): + super().__init__( + operands=[async_dependencies], + result_types=[[AsyncTokenType()]], + ) + + @irdl_op_definition class YieldOp(IRDLOperation): name = "gpu.yield" @@ -779,9 +795,11 @@ def verify_(self) -> None: SubgroupSizeOp, TerminatorOp, ThreadIdOp, + WaitOp, YieldOp, ], [ + AsyncTokenType, AllReduceOpAttr, DimensionAttr, ProcessorAttr,