From c3d5769f5a3a579e603d2ae46c0def63ba3968c7 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Wed, 30 Oct 2024 02:22:14 -0700 Subject: [PATCH] Add tests for CudaStream::WaitFor PiperOrigin-RevId: 691339308 --- xla/stream_executor/cuda/BUILD | 1 + xla/stream_executor/cuda/cuda_stream_test.cc | 71 +++++++++++++++++++ xla/stream_executor/rocm/BUILD | 1 + xla/stream_executor/rocm/rocm_stream_test.cc | 73 ++++++++++++++++++++ 4 files changed, 146 insertions(+) diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD index 11a678b874f48..d825184c9ce25 100644 --- a/xla/stream_executor/cuda/BUILD +++ b/xla/stream_executor/cuda/BUILD @@ -1243,6 +1243,7 @@ xla_test( backends = ["gpu"], tags = ["cuda-only"], deps = [ + ":cuda_event", ":cuda_executor", ":cuda_platform_id", ":cuda_stream", diff --git a/xla/stream_executor/cuda/cuda_stream_test.cc b/xla/stream_executor/cuda/cuda_stream_test.cc index 2e905e990e648..c0b09644722d7 100644 --- a/xla/stream_executor/cuda/cuda_stream_test.cc +++ b/xla/stream_executor/cuda/cuda_stream_test.cc @@ -21,12 +21,14 @@ limitations under the License. #include #include #include +#include #include #include #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/stream_executor/cuda/cuda_event.h" #include "xla/stream_executor/cuda/cuda_executor.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/device_memory.h" @@ -47,7 +49,9 @@ namespace gpu { namespace { using ::testing::Each; +using ::testing::ElementsAre; using ::testing::ElementsAreArray; +using ::testing::IsEmpty; using ::tsl::testing::IsOk; class CudaStreamTest : public ::testing::Test { @@ -236,6 +240,73 @@ TEST_F(CudaStreamTest, SetName) { EXPECT_EQ(stream->GetName(), kStreamName); } +TEST_F(CudaStreamTest, WaitForEvent) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + CudaStream::Create(executor_, + /*priority=*/std::nullopt)); + + TF_ASSERT_OK_AND_ASSIGN(CudaEvent event, + CudaEvent::Create(executor_, /*allow_timing=*/false)); + + EXPECT_THAT(stream->WaitFor(&event), IsOk()); + + bool callback_called = false; + EXPECT_THAT( + stream->DoHostCallback([&callback_called]() { callback_called = true; }), + IsOk()); + + EXPECT_FALSE(callback_called); + EXPECT_THAT(stream->RecordEvent(&event), IsOk()); + EXPECT_THAT(stream->BlockHostUntilDone(), IsOk()); + EXPECT_TRUE(callback_called); +} + +TEST_F(CudaStreamTest, WaitForOtherStream) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream1, + CudaStream::Create(executor_, + /*priority=*/std::nullopt)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream2, + CudaStream::Create(executor_, + /*priority=*/std::nullopt)); + + TF_ASSERT_OK_AND_ASSIGN(CudaEvent event, + CudaEvent::Create(executor_, /*allow_timing=*/false)); + + enum class ExecutionStage { + kBeforeWaitForEvent, + kAfterWaitForEvent, + kAfterWaitForStream + }; + + std::vector execution_order; + + // - stream1 waits for the event to be recorded and + // - stream2 waits for stream1 to be done. + // - Afterwards stream2 invokes the host callback. + EXPECT_THAT(stream1->DoHostCallback([&execution_order]() { + execution_order.push_back(ExecutionStage::kBeforeWaitForEvent); + }), + IsOk()); + EXPECT_THAT(stream1->WaitFor(&event), IsOk()); + EXPECT_THAT(stream1->DoHostCallback([&execution_order]() { + execution_order.push_back(ExecutionStage::kAfterWaitForEvent); + }), + IsOk()); + EXPECT_THAT(stream2->WaitFor(stream1.get()), IsOk()); + EXPECT_THAT(stream2->DoHostCallback([&execution_order]() { + execution_order.push_back(ExecutionStage::kAfterWaitForStream); + }), + IsOk()); + + EXPECT_THAT(execution_order, IsEmpty()); + EXPECT_THAT(stream1->RecordEvent(&event), IsOk()); + EXPECT_THAT(stream2->BlockHostUntilDone(), IsOk()); + EXPECT_THAT(execution_order, + ElementsAre(ExecutionStage::kBeforeWaitForEvent, + ExecutionStage::kAfterWaitForEvent, + ExecutionStage::kAfterWaitForStream)); +} + } // namespace } // namespace gpu } // namespace stream_executor diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD index 6d405bb619e30..36daa421b4e20 100644 --- a/xla/stream_executor/rocm/BUILD +++ b/xla/stream_executor/rocm/BUILD @@ -1045,6 +1045,7 @@ xla_test( "manual", ]), deps = [ + ":rocm_event", ":rocm_executor", ":rocm_platform_id", ":rocm_stream", diff --git a/xla/stream_executor/rocm/rocm_stream_test.cc b/xla/stream_executor/rocm/rocm_stream_test.cc index 70acb6fdb7306..1bb7fcb71e06a 100644 --- a/xla/stream_executor/rocm/rocm_stream_test.cc +++ b/xla/stream_executor/rocm/rocm_stream_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -34,6 +35,7 @@ limitations under the License. #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/rocm/rocm_event.h" #include "xla/stream_executor/rocm/rocm_executor.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/typed_kernel_factory.h" @@ -46,7 +48,9 @@ namespace gpu { namespace { using ::testing::Each; +using ::testing::ElementsAre; using ::testing::ElementsAreArray; +using ::testing::IsEmpty; using ::tsl::testing::IsOk; class RocmStreamTest : public ::testing::Test { @@ -235,6 +239,75 @@ TEST_F(RocmStreamTest, SetName) { EXPECT_EQ(stream->GetName(), kStreamName); } +TEST_F(RocmStreamTest, WaitForEvent) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + RocmStream::Create(&executor_.value(), + /*priority=*/std::nullopt)); + + TF_ASSERT_OK_AND_ASSIGN( + RocmEvent event, + RocmEvent::Create(&executor_.value(), /*allow_timing=*/false)); + + EXPECT_THAT(stream->WaitFor(&event), IsOk()); + + bool callback_called = false; + EXPECT_THAT( + stream->DoHostCallback([&callback_called]() { callback_called = true; }), + IsOk()); + + EXPECT_FALSE(callback_called); + EXPECT_THAT(stream->RecordEvent(&event), IsOk()); + EXPECT_THAT(stream->BlockHostUntilDone(), IsOk()); + EXPECT_TRUE(callback_called); +} + +TEST_F(RocmStreamTest, WaitForOtherStream) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream1, + RocmStream::Create(&executor_.value(), + /*priority=*/std::nullopt)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream2, + RocmStream::Create(&executor_.value(), + /*priority=*/std::nullopt)); + + TF_ASSERT_OK_AND_ASSIGN( + RocmEvent event, + RocmEvent::Create(&executor_.value(), /*allow_timing=*/false)); + + enum class ExecutionStage { + kBeforeWaitForEvent, + kAfterWaitForEvent, + kAfterWaitForStream + }; + + std::vector execution_order; + + // - stream1 waits for the event to be recorded and + // - stream2 waits for stream1 to be done. + // - Afterwards stream2 invokes the host callback. + EXPECT_THAT(stream1->DoHostCallback([&execution_order]() { + execution_order.push_back(ExecutionStage::kBeforeWaitForEvent); + }), + IsOk()); + EXPECT_THAT(stream1->WaitFor(&event), IsOk()); + EXPECT_THAT(stream1->DoHostCallback([&execution_order]() { + execution_order.push_back(ExecutionStage::kAfterWaitForEvent); + }), + IsOk()); + EXPECT_THAT(stream2->WaitFor(stream1.get()), IsOk()); + EXPECT_THAT(stream2->DoHostCallback([&execution_order]() { + execution_order.push_back(ExecutionStage::kAfterWaitForStream); + }), + IsOk()); + + EXPECT_THAT(execution_order, IsEmpty()); + EXPECT_THAT(stream1->RecordEvent(&event), IsOk()); + EXPECT_THAT(stream2->BlockHostUntilDone(), IsOk()); + EXPECT_THAT(execution_order, + ElementsAre(ExecutionStage::kBeforeWaitForEvent, + ExecutionStage::kAfterWaitForEvent, + ExecutionStage::kAfterWaitForStream)); +} + } // namespace } // namespace gpu } // namespace stream_executor