Skip to content

Commit

Permalink
Factor out GetMemorySpaceDescriptions().
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707988309
  • Loading branch information
matthiaskramm authored and Google-ML-Automation committed Dec 19, 2024
1 parent ca3ddd2 commit 2e27dda
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 51 deletions.
3 changes: 3 additions & 0 deletions xla/pjrt/c/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,14 @@ cc_library(
deps = [
":pjrt_c_api_hdrs",
":pjrt_c_api_layouts_extension_hdrs",
":pjrt_c_api_memory_descriptions_extension_hdrs",
":pjrt_c_api_profiler_extension_hdrs",
"//xla:shape_util",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_common",
"//xla/pjrt:pjrt_device_description",
"//xla/pjrt:pjrt_executable",
"//xla/pjrt:pjrt_future",
"//xla/pjrt/distributed:key_value_store_interface",
Expand Down Expand Up @@ -479,6 +481,7 @@ cc_library(
"//xla/hlo/parser:hlo_parser",
"//xla/pjrt:compile_options_proto_cc",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_device_description",
"//xla/pjrt:pjrt_future",
"//xla/service:computation_placer_hdr",
"//xla/service:hlo_proto_cc",
Expand Down
33 changes: 33 additions & 0 deletions xla/pjrt/c/pjrt_c_api_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ limitations under the License.
#include "xla/layout.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_layouts_extension.h"
#include "xla/pjrt/c/pjrt_c_api_memory_descriptions_extension.h"
#include "xla/pjrt/c/pjrt_c_api_profiler_extension.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_common.h"
#include "xla/pjrt/pjrt_device_description.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/pjrt_future.h"
#include "xla/primitive_util.h"
Expand Down Expand Up @@ -1101,4 +1103,35 @@ xla::PjRtClient::ShapeSpec ConvertFromPjrtShapeSpec(
return shape_spec;
}

std::vector<xla::PjRtMemorySpaceDescription> GetMemorySpaceDescriptions(
PJRT_DeviceDescription* device_description, const PJRT_Api* c_api) {
const PJRT_MemoryDescriptions_Extension* extension =
pjrt::FindExtension<PJRT_MemoryDescriptions_Extension>(
c_api, PJRT_Extension_Type::PJRT_Extension_Type_MemoryDescriptions);
if (!extension) return {};

PJRT_DeviceDescription_MemoryDescriptions_Args mem_desc_args;
mem_desc_args.struct_size =
PJRT_DeviceDescription_MemoryDescriptions_Args_STRUCT_SIZE;
mem_desc_args.extension_start = nullptr;
mem_desc_args.device_description = device_description;
pjrt::LogFatalIfPjrtError(
extension->PJRT_DeviceDescription_MemoryDescriptions(&mem_desc_args),
c_api);

std::vector<xla::PjRtMemorySpaceDescription> memory_space_descriptions;
for (int i = 0; i < mem_desc_args.num_memory_descriptions; i++) {
PJRT_MemoryDescription_Kind_Args kind_args;
kind_args.struct_size = PJRT_MemoryDescription_Kind_Args_STRUCT_SIZE;
kind_args.extension_start = nullptr;
kind_args.memory_description = mem_desc_args.memory_descriptions[i];
pjrt::LogFatalIfPjrtError(
extension->PJRT_MemoryDescription_Kind(&kind_args), c_api);
xla::PjRtMemorySpaceDescription description(
std::string(kind_args.kind, kind_args.kind_size), kind_args.kind_id);
memory_space_descriptions.push_back(description);
}
return memory_space_descriptions;
}

} // namespace pjrt
3 changes: 3 additions & 0 deletions xla/pjrt/c/pjrt_c_api_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,9 @@ int64_t GetTracemeContextId(InputType* args) {
return traceme_context_id;
}

std::vector<xla::PjRtMemorySpaceDescription> GetMemorySpaceDescriptions(
PJRT_DeviceDescription* device_description, const PJRT_Api* c_api);

} // namespace pjrt

#endif // XLA_PJRT_C_PJRT_C_API_HELPERS_H_
36 changes: 6 additions & 30 deletions xla/pjrt/c/pjrt_c_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ limitations under the License.
#include "xla/pjrt/c/pjrt_c_api_test_base.h"
#include "xla/pjrt/compile_options.pb.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_device_description.h"
#include "xla/pjrt/pjrt_future.h"
#include "xla/service/computation_placer.h"
#include "xla/service/hlo.pb.h"
Expand Down Expand Up @@ -562,37 +563,12 @@ TEST_F(PjrtCApiTest, DeviceDescriptionAndMemoryDescriptionss) {
PJRT_Error* error = api_->PJRT_Device_GetDescription(&get_description);
EXPECT_EQ(error, nullptr);

PJRT_DeviceDescription_MemoryDescriptions_Args memory_descriptions =
PJRT_DeviceDescription_MemoryDescriptions_Args{
.struct_size =
PJRT_DeviceDescription_MemoryDescriptions_Args_STRUCT_SIZE,
.extension_start = nullptr,
.device_description = get_description.device_description,
};
std::vector<xla::PjRtMemorySpaceDescription> memory_descriptions =
GetMemorySpaceDescriptions(get_description.device_description, api_);

const PJRT_MemoryDescriptions_Extension* extension =
FindExtension<PJRT_MemoryDescriptions_Extension>(
api_, PJRT_Extension_Type::PJRT_Extension_Type_MemoryDescriptions);

if (extension != nullptr) {
error = extension->PJRT_DeviceDescription_MemoryDescriptions(
&memory_descriptions);
EXPECT_EQ(error, nullptr);

for (int i = 0; i < memory_descriptions.num_memory_descriptions; i++) {
PJRT_MemoryDescription_Kind_Args memory_description =
PJRT_MemoryDescription_Kind_Args{
.struct_size =
PJRT_DeviceDescription_MemoryDescriptions_Args_STRUCT_SIZE,
.extension_start = nullptr,
.memory_description = memory_descriptions.memory_descriptions[i],
};
error = extension->PJRT_MemoryDescription_Kind(&memory_description);
EXPECT_EQ(error, nullptr);
EXPECT_NE(memory_description.kind, nullptr);
EXPECT_GT(memory_description.kind_size, 0);
EXPECT_GE(memory_description.kind_id, 0);
}
for (int i = 0; i < memory_descriptions.size(); i++) {
EXPECT_NE(memory_descriptions[i].kind().size(), 0);
EXPECT_GE(memory_descriptions[i].kind_id(), 0);
}
}

Expand Down
25 changes: 4 additions & 21 deletions xla/pjrt/pjrt_c_api_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1021,27 +1021,10 @@ PjRtCApiDeviceDescription::memory_spaces() const {
if (!extension) return {};

if (memory_space_description_pointers_.empty()) {
PJRT_DeviceDescription_MemoryDescriptions_Args mem_desc_args;
mem_desc_args.struct_size =
PJRT_DeviceDescription_MemoryDescriptions_Args_STRUCT_SIZE,
mem_desc_args.extension_start = nullptr,
mem_desc_args.device_description = device_description_,
pjrt::LogFatalIfPjrtError(
extension->PJRT_DeviceDescription_MemoryDescriptions(&mem_desc_args),
c_api_);

for (int i = 0; i < mem_desc_args.num_memory_descriptions; i++) {
PJRT_MemoryDescription_Kind_Args kind_args;
kind_args.struct_size = PJRT_MemoryDescription_Kind_Args_STRUCT_SIZE,
kind_args.extension_start = nullptr,
kind_args.memory_description = mem_desc_args.memory_descriptions[i],
pjrt::LogFatalIfPjrtError(
extension->PJRT_MemoryDescription_Kind(&kind_args), c_api_);
PjRtMemorySpaceDescription description(
std::string(kind_args.kind, kind_args.kind_size), kind_args.kind_id);
memory_space_descriptions_.push_back(description);
memory_space_description_pointers_.push_back(
&memory_space_descriptions_[i]);
memory_space_descriptions_ =
pjrt::GetMemorySpaceDescriptions(device_description_, c_api_);
for (int i = 0; i < memory_space_descriptions_.size(); i++) {
memory_space_description_pointers_[i] = &memory_space_descriptions_[i];
}
}
return memory_space_description_pointers_;
Expand Down

0 comments on commit 2e27dda

Please sign in to comment.