Skip to content

Commit

Permalink
Add tests for CudaStream::WaitFor
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 691339308
  • Loading branch information
beckerhe authored and Google-ML-Automation committed Oct 30, 2024
1 parent 116458b commit c3d5769
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 0 deletions.
1 change: 1 addition & 0 deletions xla/stream_executor/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1243,6 +1243,7 @@ xla_test(
backends = ["gpu"],
tags = ["cuda-only"],
deps = [
":cuda_event",
":cuda_executor",
":cuda_platform_id",
":cuda_stream",
Expand Down
71 changes: 71 additions & 0 deletions xla/stream_executor/cuda/cuda_stream_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ limitations under the License.
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/stream_executor/cuda/cuda_event.h"
#include "xla/stream_executor/cuda/cuda_executor.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/device_memory.h"
Expand All @@ -47,7 +49,9 @@ namespace gpu {
namespace {

using ::testing::Each;
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
using ::testing::IsEmpty;
using ::tsl::testing::IsOk;

class CudaStreamTest : public ::testing::Test {
Expand Down Expand Up @@ -236,6 +240,73 @@ TEST_F(CudaStreamTest, SetName) {
EXPECT_EQ(stream->GetName(), kStreamName);
}

// Checks that a stream which has been told to wait for an event still runs
// the host callback enqueued after the wait, and that the stream drains
// successfully once the event has been recorded.
TEST_F(CudaStreamTest, WaitForEvent) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<CudaStream> stream,
                          CudaStream::Create(executor_,
                                             /*priority=*/std::nullopt));

  TF_ASSERT_OK_AND_ASSIGN(CudaEvent event,
                          CudaEvent::Create(executor_, /*allow_timing=*/false));

  EXPECT_THAT(stream->WaitFor(&event), IsOk());

  bool callback_called = false;
  EXPECT_THAT(
      stream->DoHostCallback([&callback_called]() { callback_called = true; }),
      IsOk());

  // Deliberately no `EXPECT_FALSE(callback_called)` at this point. The event
  // has not been recorded yet and is only recorded further down on this very
  // stream, so the wait above cannot be what holds the callback back (a wait
  // on a never-recorded event is presumably a no-op, as documented for
  // cuStreamWaitEvent). The callback may therefore already have run, and
  // reading the flag here would race with the callback writing it.
  EXPECT_THAT(stream->RecordEvent(&event), IsOk());
  EXPECT_THAT(stream->BlockHostUntilDone(), IsOk());

  // After the stream has drained, the callback must have been executed.
  EXPECT_TRUE(callback_called);
}

// Checks that work on stream2 enqueued after `stream2->WaitFor(stream1)` runs
// only after everything previously enqueued on stream1, by recording the
// order in which three host callbacks execute.
TEST_F(CudaStreamTest, WaitForOtherStream) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<CudaStream> stream1,
                          CudaStream::Create(executor_,
                                             /*priority=*/std::nullopt));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<CudaStream> stream2,
                          CudaStream::Create(executor_,
                                             /*priority=*/std::nullopt));

  TF_ASSERT_OK_AND_ASSIGN(CudaEvent event,
                          CudaEvent::Create(executor_, /*allow_timing=*/false));

  enum class ExecutionStage {
    kBeforeWaitForEvent,
    kAfterWaitForEvent,
    kAfterWaitForStream
  };

  // Appended to by the host callbacks; only inspected after stream2 has been
  // fully drained.
  std::vector<ExecutionStage> execution_order;

  // - stream1 runs a host callback, waits for the event, and runs a second
  //   host callback.
  // - stream2 waits for stream1 to be done.
  // - Afterwards stream2 invokes its host callback.
  EXPECT_THAT(stream1->DoHostCallback([&execution_order]() {
    execution_order.push_back(ExecutionStage::kBeforeWaitForEvent);
  }),
              IsOk());
  EXPECT_THAT(stream1->WaitFor(&event), IsOk());
  EXPECT_THAT(stream1->DoHostCallback([&execution_order]() {
    execution_order.push_back(ExecutionStage::kAfterWaitForEvent);
  }),
              IsOk());
  EXPECT_THAT(stream2->WaitFor(stream1.get()), IsOk());
  EXPECT_THAT(stream2->DoHostCallback([&execution_order]() {
    execution_order.push_back(ExecutionStage::kAfterWaitForStream);
  }),
              IsOk());

  // Deliberately no `EXPECT_THAT(execution_order, IsEmpty())` at this point
  // (concern raised in review by @hsharsha, Oct 30, 2024: "execution_order
  // will be non empty"). The first callback is enqueued on stream1 before any
  // wait, so nothing prevents it from running immediately, and the event is
  // only recorded below on stream1 itself, so the wait cannot be gating the
  // later callbacks either. Inspecting the vector here would be flaky and
  // would race with the callbacks appending to it.
  EXPECT_THAT(stream1->RecordEvent(&event), IsOk());
  EXPECT_THAT(stream2->BlockHostUntilDone(), IsOk());

  // Once stream2 has drained, all three callbacks must have run in order.
  EXPECT_THAT(execution_order,
              ElementsAre(ExecutionStage::kBeforeWaitForEvent,
                          ExecutionStage::kAfterWaitForEvent,
                          ExecutionStage::kAfterWaitForStream));
}

} // namespace
} // namespace gpu
} // namespace stream_executor
1 change: 1 addition & 0 deletions xla/stream_executor/rocm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,7 @@ xla_test(
"manual",
]),
deps = [
":rocm_event",
":rocm_executor",
":rocm_platform_id",
":rocm_stream",
Expand Down
73 changes: 73 additions & 0 deletions xla/stream_executor/rocm/rocm_stream_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
Expand All @@ -34,6 +35,7 @@ limitations under the License.
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/stream_executor/rocm/rocm_event.h"
#include "xla/stream_executor/rocm/rocm_executor.h"
#include "xla/stream_executor/rocm/rocm_platform_id.h"
#include "xla/stream_executor/typed_kernel_factory.h"
Expand All @@ -46,7 +48,9 @@ namespace gpu {
namespace {

using ::testing::Each;
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
using ::testing::IsEmpty;
using ::tsl::testing::IsOk;

class RocmStreamTest : public ::testing::Test {
Expand Down Expand Up @@ -235,6 +239,75 @@ TEST_F(RocmStreamTest, SetName) {
EXPECT_EQ(stream->GetName(), kStreamName);
}

// Checks that a stream which has been told to wait for an event still runs
// the host callback enqueued after the wait, and that the stream drains
// successfully once the event has been recorded.
TEST_F(RocmStreamTest, WaitForEvent) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<RocmStream> stream,
                          RocmStream::Create(&executor_.value(),
                                             /*priority=*/std::nullopt));

  TF_ASSERT_OK_AND_ASSIGN(
      RocmEvent event,
      RocmEvent::Create(&executor_.value(), /*allow_timing=*/false));

  EXPECT_THAT(stream->WaitFor(&event), IsOk());

  bool callback_called = false;
  EXPECT_THAT(
      stream->DoHostCallback([&callback_called]() { callback_called = true; }),
      IsOk());

  // Deliberately no `EXPECT_FALSE(callback_called)` at this point. The event
  // has not been recorded yet and is only recorded further down on this very
  // stream, so the wait above cannot be what holds the callback back
  // (presumably HIP treats a wait on a never-recorded event as a no-op, like
  // CUDA does; confirm against the hipStreamWaitEvent documentation). The
  // callback may therefore already have run, and reading the flag here would
  // race with the callback writing it.
  EXPECT_THAT(stream->RecordEvent(&event), IsOk());
  EXPECT_THAT(stream->BlockHostUntilDone(), IsOk());

  // After the stream has drained, the callback must have been executed.
  EXPECT_TRUE(callback_called);
}

// Checks that work on stream2 enqueued after `stream2->WaitFor(stream1)` runs
// only after everything previously enqueued on stream1, by recording the
// order in which three host callbacks execute.
TEST_F(RocmStreamTest, WaitForOtherStream) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<RocmStream> stream1,
                          RocmStream::Create(&executor_.value(),
                                             /*priority=*/std::nullopt));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<RocmStream> stream2,
                          RocmStream::Create(&executor_.value(),
                                             /*priority=*/std::nullopt));

  TF_ASSERT_OK_AND_ASSIGN(
      RocmEvent event,
      RocmEvent::Create(&executor_.value(), /*allow_timing=*/false));

  enum class ExecutionStage {
    kBeforeWaitForEvent,
    kAfterWaitForEvent,
    kAfterWaitForStream
  };

  // Appended to by the host callbacks; only inspected after stream2 has been
  // fully drained.
  std::vector<ExecutionStage> execution_order;

  // - stream1 runs a host callback, waits for the event, and runs a second
  //   host callback.
  // - stream2 waits for stream1 to be done.
  // - Afterwards stream2 invokes its host callback.
  EXPECT_THAT(stream1->DoHostCallback([&execution_order]() {
    execution_order.push_back(ExecutionStage::kBeforeWaitForEvent);
  }),
              IsOk());
  EXPECT_THAT(stream1->WaitFor(&event), IsOk());
  EXPECT_THAT(stream1->DoHostCallback([&execution_order]() {
    execution_order.push_back(ExecutionStage::kAfterWaitForEvent);
  }),
              IsOk());
  EXPECT_THAT(stream2->WaitFor(stream1.get()), IsOk());
  EXPECT_THAT(stream2->DoHostCallback([&execution_order]() {
    execution_order.push_back(ExecutionStage::kAfterWaitForStream);
  }),
              IsOk());

  // Deliberately no `EXPECT_THAT(execution_order, IsEmpty())` at this point
  // (concern raised in review by @hsharsha, Oct 30, 2024). The first callback
  // is enqueued on stream1 before any wait, so nothing prevents it from
  // running immediately, and the event is only recorded below on stream1
  // itself, so the wait cannot be gating the later callbacks either.
  // Inspecting the vector here would be flaky and would race with the
  // callbacks appending to it.
  EXPECT_THAT(stream1->RecordEvent(&event), IsOk());
  EXPECT_THAT(stream2->BlockHostUntilDone(), IsOk());

  // Once stream2 has drained, all three callbacks must have run in order.
  EXPECT_THAT(execution_order,
              ElementsAre(ExecutionStage::kBeforeWaitForEvent,
                          ExecutionStage::kAfterWaitForEvent,
                          ExecutionStage::kAfterWaitForStream));
}

} // namespace
} // namespace gpu
} // namespace stream_executor

0 comments on commit c3d5769

Please sign in to comment.