Skip to content

[Offload] Add MAX_WORK_GROUP_SIZE device info query #143718

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion offload/liboffload/API/Device.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def : Enum {
TaggedEtor<"PLATFORM", "ol_platform_handle_t", "the platform associated with the device">,
TaggedEtor<"NAME", "char[]", "Device name">,
TaggedEtor<"VENDOR", "char[]", "Device vendor">,
TaggedEtor<"DRIVER_VERSION", "char[]", "Driver version">
TaggedEtor<"DRIVER_VERSION", "char[]", "Driver version">,
TaggedEtor<"MAX_WORK_GROUP_SIZE", "ol_dimensions_t", "Maximum work group size in each dimension">,
];
}

Expand Down
35 changes: 35 additions & 0 deletions offload/liboffload/src/OffloadImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "OffloadImpl.hpp"
#include "Helpers.hpp"
#include "OffloadPrint.hpp" // Required for operator<< implementation of ol_device_info_t
#include "PluginManager.h"
#include "llvm/Support/FormatVariadic.h"
#include <OffloadAPI.h>
Expand Down Expand Up @@ -264,6 +265,37 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,

return "";
};
auto GetInfoXyz = [&](std::vector<std::string> Names) -> Error {
if (Device == OffloadContext::get().HostDevice())
return ReturnValue(ol_dimensions_t{0u, 0u, 0u});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's ReturnValue used for? Isn't this was Expected is supposed to be used for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ReturnValue is a helper object with an operator() that automatically writes the value to PropValue and PropSizeRet as appropriate (as well as doing some validation checks).

If this were to return Expected<ol_dimensions_t>, then the callers would have to do the Error unwrapping themselves.

@callumfare For the sake of clarity, do you mind if (in a separate patch) I rename this to something like SetOutput?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems like a weird way to work around just using Expected<T>.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@RossBrunton Yeah I'm happy if that gets renamed. I think the naming is a little inconsistent with the info functions as well ('info' vs 'prop' etc).

@jhuber6 I think the wrappers are helpful to have since the code to output the value and underlying size is a little verbose - https://github.com/llvm/llvm-project/blob/main/offload/liboffload/src/Helpers.hpp

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless I'm missing an easier way, that would result in code like this:

case OL_DEVICE_INFO_NAME:
  auto Res = GetInfoWhatever(...);
  if (Res) {
    return ReturnValue(*Res);
  } else {
    return Res.takeError();
  }

For every device info that uses this interface, which feels unnecessarily verbose.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because that function also doesn't use Expected<T>, it would just be a return otherwise and then you unpack the value at the bottom (Where we actually set the result.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would be the value of T?


assert(Device->Device &&
"liboffload device handle contains a null plugin device");

auto Info = Device->Device->obtainInfoImpl();
if (auto Err = Info.takeError())
return Err;

for (auto Name : Names) {
if (auto Entry = Info->get(Name)) {
auto Node = *Entry;
ol_dimensions_t Out{0, 0, 0};

if (auto X = Node->get("x"))
Out.x = std::get<size_t>((*X)->Value);
if (auto Y = Node->get("y"))
Out.y = std::get<size_t>((*Y)->Value);
if (auto Z = Node->get("z"))
Out.z = std::get<size_t>((*Z)->Value);
return ReturnValue(Out);
}
}

std::string ErrBuffer;
llvm::raw_string_ostream(ErrBuffer)
<< "plugin did not provide information for " << PropName;
return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str());
};

switch (PropName) {
case OL_DEVICE_INFO_PLATFORM:
Expand All @@ -279,6 +311,9 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
case OL_DEVICE_INFO_DRIVER_VERSION:
return ReturnValue(
GetInfoString({"CUDA Driver Version", "HSA Runtime Version"}));
case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
return GetInfoXyz({"Workgroup Max Size per Dimension" /*AMD*/,
"Maximum Block Dimensions" /*CUDA*/});
default:
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
"getDeviceInfo enum '%i' is invalid", PropName);
Expand Down
5 changes: 5 additions & 0 deletions offload/tools/offload-tblgen/PrintGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,11 @@ template <typename T> inline void printTagged(llvm::raw_ostream &os, const void
"enum {0} value);\n",
EnumRec{R}.getName());
}
for (auto *R : Records.getAllDerivedDefinitions("Struct")) {
OS << formatv("inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, "
"const struct {0} param);\n",
StructRec{R}.getName());
}
OS << "\n";

// Create definitions
Expand Down
9 changes: 9 additions & 0 deletions offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ TEST_P(olGetDeviceInfoTest, SuccessDriverVersion) {
ASSERT_EQ(std::strlen(DriverVersion.data()), Size - 1);
}

TEST_P(olGetDeviceInfoTest, SuccessMaxWorkGroupSize) {
ol_dimensions_t Value{0, 0, 0};
ASSERT_SUCCESS(olGetDeviceInfo(Device, OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE,
sizeof(Value), &Value));
ASSERT_GT(Value.x, 0u);
ASSERT_GT(Value.y, 0u);
ASSERT_GT(Value.z, 0u);
}

TEST_P(olGetDeviceInfoTest, InvalidNullHandleDevice) {
ol_device_type_t DeviceType;
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
Expand Down
8 changes: 8 additions & 0 deletions offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ TEST_P(olGetDeviceInfoSizeTest, SuccessDriverVersion) {
ASSERT_NE(Size, 0ul);
}

TEST_P(olGetDeviceInfoSizeTest, SuccessMaxWorkGroupSize) {
size_t Size = 0;
ASSERT_SUCCESS(
olGetDeviceInfoSize(Device, OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE, &Size));
ASSERT_EQ(Size, sizeof(ol_dimensions_t));
ASSERT_EQ(Size, sizeof(uint32_t) * 3);
}

TEST_P(olGetDeviceInfoSizeTest, InvalidNullHandle) {
size_t Size = 0;
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
Expand Down
Loading