Skip to content

Commit

Permalink
Merge pull request #2170 from igchor/enqueueWaitTest
Browse files Browse the repository at this point in the history
[CTS] extend tests for urEnqueueEventsWait
  • Loading branch information
pbalcer authored Oct 22, 2024
2 parents c32a78c + 2c60671 commit d8cc532
Show file tree
Hide file tree
Showing 5 changed files with 288 additions and 49 deletions.
1 change: 1 addition & 0 deletions test/conformance/enqueue/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ add_conformance_test_with_kernels_environment(enqueue
urEnqueueDeviceGlobalVariableRead.cpp
urEnqueueDeviceGlobalVariableWrite.cpp
urEnqueueEventsWait.cpp
urEnqueueEventsWaitMultiDevice.cpp
urEnqueueEventsWaitWithBarrier.cpp
urEnqueueKernelLaunch.cpp
urEnqueueKernelLaunchAndMemcpyInOrder.cpp
Expand Down
9 changes: 9 additions & 0 deletions test/conformance/enqueue/enqueue_adapter_native_cpu.match
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
{{NONDETERMINISTIC}}
{{OPT}}urEnqueueEventsWaitMultiDeviceTest.EmptyWaitList
{{OPT}}urEnqueueEventsWaitMultiDeviceTest.EmptyWaitListWithEvent
{{OPT}}urEnqueueEventsWaitMultiDeviceTest.EnqueueWaitOnADifferentQueue
{{OPT}}urEnqueueDeviceGetGlobalVariableReadTest.Success/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
{{OPT}}urEnqueueDeviceGetGlobalVariableReadTest.InvalidNullHandleQueue/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
{{OPT}}urEnqueueDeviceGetGlobalVariableReadTest.InvalidNullHandleProgram/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
Expand All @@ -16,6 +19,12 @@
{{OPT}}urEnqueueDeviceGetGlobalVariableWriteTest.InvalidEventWaitInvalidEvent/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
{{OPT}}urEnqueueEventsWaitTest.Success/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
{{OPT}}urEnqueueEventsWaitTest.InvalidNullPtrEventWaitList/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
{{OPT}}urEnqueueEventsWaitMultiDeviceMTTest.EnqueueWaitSingleQueueMultiOps/MultiThread
{{OPT}}urEnqueueEventsWaitMultiDeviceMTTest.EnqueueWaitSingleQueueMultiOps/NoMultiThread
{{OPT}}urEnqueueEventsWaitMultiDeviceMTTest.EnqueueWaitOnAllQueues/MultiThread
{{OPT}}urEnqueueEventsWaitMultiDeviceMTTest.EnqueueWaitOnAllQueues/NoMultiThread
{{OPT}}urEnqueueEventsWaitMultiDeviceMTTest.EnqueueWaitOnAllQueuesCommonDependency/MultiThread
{{OPT}}urEnqueueEventsWaitMultiDeviceMTTest.EnqueueWaitOnAllQueuesCommonDependency/NoMultiThread
{{OPT}}urEnqueueEventsWaitWithBarrierTest.Success/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
{{OPT}}urEnqueueEventsWaitWithBarrierTest.InvalidNullPtrEventWaitList/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
urEnqueueEventsWaitWithBarrierOrderingTest.SuccessEventDependenciesBarrierOnly/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}_
Expand Down
44 changes: 44 additions & 0 deletions test/conformance/enqueue/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,50 @@ printFillTestString(const testing::TestParamInfo<typename T::ParamType> &info) {
return test_name.str();
}

struct urMultiQueueMultiDeviceTest : uur::urMultiDeviceContextTestTemplate<1> {
void initQueues(std::vector<ur_device_handle_t> srcDevices,
size_t numDuplicate) {
for (size_t i = 0; i < numDuplicate; i++) {
devices.insert(devices.end(), srcDevices.begin(), srcDevices.end());
}

for (auto &device : devices) {
ur_queue_handle_t queue = nullptr;
ASSERT_SUCCESS(urQueueCreate(context, device, nullptr, &queue));
queues.push_back(queue);
}
}

// Default implementation that uses all available devices
void SetUp() override {
UUR_RETURN_ON_FATAL_FAILURE(
uur::urMultiDeviceContextTestTemplate<1>::SetUp());
initQueues(uur::KernelsEnvironment::instance->devices, 1);
}

// Specialized implementation that duplicates all devices and queues
void SetUp(std::vector<ur_device_handle_t> srcDevices,
size_t numDuplicate) {
UUR_RETURN_ON_FATAL_FAILURE(
uur::urMultiDeviceContextTestTemplate<1>::SetUp());
initQueues(srcDevices, numDuplicate);
}

void TearDown() override {
for (auto &queue : queues) {
EXPECT_SUCCESS(urQueueRelease(queue));
}
UUR_RETURN_ON_FATAL_FAILURE(
uur::urMultiDeviceContextTestTemplate<1>::TearDown());
}
std::function<std::tuple<std::vector<ur_device_handle_t>,
std::vector<ur_queue_handle_t>>(void)>
makeQueues;

std::vector<ur_device_handle_t> devices;
std::vector<ur_queue_handle_t> queues;
};

} // namespace uur

#endif // UUR_ENQUEUE_RECT_HELPERS_H_INCLUDED
218 changes: 218 additions & 0 deletions test/conformance/enqueue/urEnqueueEventsWaitMultiDevice.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
// Copyright (C) 2024 Intel Corporation
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
// See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "helpers.h"

#include <thread>

#include <uur/fixtures.h>
#include <uur/raii.h>

struct urEnqueueEventsWaitMultiDeviceTest : uur::urMultiQueueMultiDeviceTest {
void SetUp() override { SetUp(2); /* we need at least 2 devices */ }

void SetUp(size_t numDuplicateDevices) {
UUR_RETURN_ON_FATAL_FAILURE(uur::urMultiQueueMultiDeviceTest::SetUp(
uur::KernelsEnvironment::instance->devices, numDuplicateDevices));

for (auto device : devices) {
ur_device_usm_access_capability_flags_t shared_usm_single = 0;
EXPECT_SUCCESS(uur::GetDeviceUSMSingleSharedSupport(
device, shared_usm_single));
if (!shared_usm_single) {
GTEST_SKIP() << "Shared USM is not supported by the device.";
}
}

ptrs.resize(devices.size());
for (size_t i = 0; i < devices.size(); i++) {
EXPECT_SUCCESS(urUSMSharedAlloc(context, devices[i], nullptr,
nullptr, size, &ptrs[i]));
}
}

void TearDown() override {
for (auto ptr : ptrs) {
if (ptr) {
EXPECT_SUCCESS(urUSMFree(context, ptr));
}
}
UUR_RETURN_ON_FATAL_FAILURE(
uur::urMultiQueueMultiDeviceTest::TearDown());
}

void initData() {
EXPECT_SUCCESS(urEnqueueUSMFill(queues[0], ptrs[0], sizeof(pattern),
&pattern, size, 0, nullptr, nullptr));
EXPECT_SUCCESS(urQueueFinish(queues[0]));
}

void verifyData(void *ptr, uint32_t pattern) {
for (size_t i = 0; i < count; i++) {
ASSERT_EQ(reinterpret_cast<uint32_t *>(ptr)[i], pattern);
}
}

uint32_t pattern = 42;
const size_t count = 1024;
const size_t size = sizeof(uint32_t) * count;

std::vector<void *> ptrs;
};

TEST_F(urEnqueueEventsWaitMultiDeviceTest, EmptyWaitList) {
initData();

ASSERT_SUCCESS(urEnqueueUSMMemcpy(queues[0], false, ptrs[1], ptrs[0], size,
0, nullptr, nullptr));
ASSERT_SUCCESS(urEnqueueEventsWait(queues[0], 0, nullptr, nullptr));
ASSERT_SUCCESS(urQueueFinish(queues[0]));

verifyData(ptrs[1], pattern);
}

TEST_F(urEnqueueEventsWaitMultiDeviceTest, EmptyWaitListWithEvent) {
initData();

ASSERT_SUCCESS(urEnqueueUSMMemcpy(queues[0], false, ptrs[1], ptrs[0], size,
0, nullptr, nullptr));

uur::raii::Event event;
ASSERT_SUCCESS(urEnqueueEventsWait(queues[0], 0, nullptr, event.ptr()));
ASSERT_SUCCESS(urEventWait(1, event.ptr()));

verifyData(ptrs[1], pattern);
}

TEST_F(urEnqueueEventsWaitMultiDeviceTest, EnqueueWaitOnADifferentQueue) {
initData();

uur::raii::Event event;
ASSERT_SUCCESS(urEnqueueUSMMemcpy(queues[0], false, ptrs[1], ptrs[0], size,
0, nullptr, event.ptr()));
ASSERT_SUCCESS(urEnqueueEventsWait(queues[0], 1, event.ptr(), nullptr));
ASSERT_SUCCESS(urQueueFinish(queues[0]));

verifyData(ptrs[1], pattern);
}

struct urEnqueueEventsWaitMultiDeviceMTTest
: urEnqueueEventsWaitMultiDeviceTest,
testing::WithParamInterface<uur::BoolTestParam> {
void doComputation(std::function<void(size_t)> work) {
auto multiThread = GetParam().value;
std::vector<std::thread> threads;
for (size_t i = 0; i < devices.size(); i++) {
if (multiThread) {
threads.emplace_back(work, i);
} else {
work(i);
}
}
for (auto &thread : threads) {
thread.join();
}
}

void SetUp() override {
const size_t numDuplicateDevices = 8;
UUR_RETURN_ON_FATAL_FAILURE(
urEnqueueEventsWaitMultiDeviceTest::SetUp(numDuplicateDevices));
}

void TearDown() override { urEnqueueEventsWaitMultiDeviceTest::TearDown(); }
};

template <typename T>
inline std::string
printParams(const testing::TestParamInfo<typename T::ParamType> &info) {
std::stringstream ss;

auto param1 = info.param;
ss << (param1.value ? "" : "No") << param1.name;

return ss.str();
}

INSTANTIATE_TEST_SUITE_P(
, urEnqueueEventsWaitMultiDeviceMTTest,
testing::ValuesIn(uur::BoolTestParam::makeBoolParam("MultiThread")),
printParams<urEnqueueEventsWaitMultiDeviceMTTest>);

TEST_P(urEnqueueEventsWaitMultiDeviceMTTest, EnqueueWaitSingleQueueMultiOps) {
std::vector<uint32_t> data(count, pattern);

auto work = [this, &data](size_t i) {
ASSERT_SUCCESS(urEnqueueUSMMemcpy(
queues[0], false, ptrs[i], data.data(), size, 0, nullptr, nullptr));
};

doComputation(work);

auto verify = [this](size_t i) {
uur::raii::Event event;
ASSERT_SUCCESS(urEnqueueEventsWait(queues[0], 0, nullptr, event.ptr()));
ASSERT_SUCCESS(urEventWait(1, event.ptr()));

verifyData(ptrs[i], pattern);
};

doComputation(verify);
}

TEST_P(urEnqueueEventsWaitMultiDeviceMTTest, EnqueueWaitOnAllQueues) {
std::vector<uur::raii::Event> eventsRaii(devices.size());
std::vector<ur_event_handle_t> events(devices.size());
auto work = [this, &events, &eventsRaii](size_t i) {
ASSERT_SUCCESS(urEnqueueUSMFill(queues[i], ptrs[i], sizeof(pattern),
&pattern, size, 0, nullptr,
eventsRaii[i].ptr()));
events[i] = eventsRaii[i].get();
};

doComputation(work);

uur::raii::Event gatherEvent;
ASSERT_SUCCESS(urEnqueueEventsWait(queues[0], devices.size(), events.data(),
gatherEvent.ptr()));
ASSERT_SUCCESS(urEventWait(1, gatherEvent.ptr()));

for (size_t i = 0; i < devices.size(); i++) {
verifyData(ptrs[i], pattern);
}
}

TEST_P(urEnqueueEventsWaitMultiDeviceMTTest,
EnqueueWaitOnAllQueuesCommonDependency) {
uur::raii::Event event;
ASSERT_SUCCESS(urEnqueueUSMFill(queues[0], ptrs[0], sizeof(pattern),
&pattern, size, 0, nullptr, event.ptr()));

std::vector<uur::raii::Event> perQueueEvents(devices.size());
std::vector<ur_event_handle_t> eventHandles(devices.size());
auto work = [this, &event, &perQueueEvents, &eventHandles](size_t i) {
ASSERT_SUCCESS(urEnqueueEventsWait(queues[i], 1, event.ptr(),
perQueueEvents[i].ptr()));
eventHandles[i] = perQueueEvents[i].get();
};

doComputation(work);

uur::raii::Event hGatherEvent;
ASSERT_SUCCESS(urEnqueueEventsWait(queues[0], eventHandles.size(),
eventHandles.data(),
hGatherEvent.ptr()));
ASSERT_SUCCESS(urEventWait(1, hGatherEvent.ptr()));

for (auto &event : eventHandles) {
ur_event_status_t status;
ASSERT_SUCCESS(
urEventGetInfo(event, UR_EVENT_INFO_COMMAND_EXECUTION_STATUS,
sizeof(ur_event_status_t), &status, nullptr));
ASSERT_EQ(status, UR_EVENT_STATUS_COMPLETE);
}

verifyData(ptrs[0], pattern);
}
Loading

0 comments on commit d8cc532

Please sign in to comment.