diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD index e47a7ed8fd04ce..01c1b03e3cb57f 100644 --- a/xla/pjrt/c/BUILD +++ b/xla/pjrt/c/BUILD @@ -188,12 +188,14 @@ cc_library( deps = [ ":pjrt_c_api_hdrs", ":pjrt_c_api_layouts_extension_hdrs", + ":pjrt_c_api_memory_descriptions_extension_hdrs", ":pjrt_c_api_profiler_extension_hdrs", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", + "//xla/pjrt:pjrt_device_description", "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", "//xla/pjrt/distributed:key_value_store_interface", @@ -479,6 +481,7 @@ cc_library( "//xla/hlo/parser:hlo_parser", "//xla/pjrt:compile_options_proto_cc", "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_device_description", "//xla/pjrt:pjrt_future", "//xla/service:computation_placer_hdr", "//xla/service:hlo_proto_cc", diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index cf92041af497d5..2060a73a634a48 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -37,10 +37,12 @@ limitations under the License. #include "xla/layout.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_layouts_extension.h" +#include "xla/pjrt/c/pjrt_c_api_memory_descriptions_extension.h" #include "xla/pjrt/c/pjrt_c_api_profiler_extension.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/primitive_util.h" @@ -1101,4 +1103,35 @@ xla::PjRtClient::ShapeSpec ConvertFromPjrtShapeSpec( return shape_spec; } +std::vector GetMemorySpaceDescriptions( + PJRT_DeviceDescription* device_description, const PJRT_Api* c_api) { + const PJRT_MemoryDescriptions_Extension* extension = + pjrt::FindExtension( + c_api, PJRT_Extension_Type::PJRT_Extension_Type_MemoryDescriptions); + if (!extension) return {}; + + PJRT_DeviceDescription_MemoryDescriptions_Args mem_desc_args; + mem_desc_args.struct_size = + PJRT_DeviceDescription_MemoryDescriptions_Args_STRUCT_SIZE; + mem_desc_args.extension_start = nullptr; + mem_desc_args.device_description = device_description; + pjrt::LogFatalIfPjrtError( + extension->PJRT_DeviceDescription_MemoryDescriptions(&mem_desc_args), + c_api); + + std::vector memory_space_descriptions; + for (int i = 0; i < mem_desc_args.num_memory_descriptions; i++) { + PJRT_MemoryDescription_Kind_Args kind_args; + kind_args.struct_size = PJRT_MemoryDescription_Kind_Args_STRUCT_SIZE; + kind_args.extension_start = nullptr; + kind_args.memory_description = mem_desc_args.memory_descriptions[i]; + pjrt::LogFatalIfPjrtError( + extension->PJRT_MemoryDescription_Kind(&kind_args), c_api); + xla::PjRtMemorySpaceDescription description( + std::string(kind_args.kind, kind_args.kind_size), kind_args.kind_id); + memory_space_descriptions.push_back(description); + } + return memory_space_descriptions; +} + } // namespace pjrt diff --git a/xla/pjrt/c/pjrt_c_api_helpers.h b/xla/pjrt/c/pjrt_c_api_helpers.h index f530b82f423573..709558fba465af 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.h +++ b/xla/pjrt/c/pjrt_c_api_helpers.h @@ -350,6 +350,9 @@ int64_t GetTracemeContextId(InputType* args) { return traceme_context_id; } +std::vector GetMemorySpaceDescriptions( + PJRT_DeviceDescription* device_description, const PJRT_Api* c_api); + } // namespace pjrt #endif // XLA_PJRT_C_PJRT_C_API_HELPERS_H_ diff --git a/xla/pjrt/c/pjrt_c_api_test.cc b/xla/pjrt/c/pjrt_c_api_test.cc index 5fb77870d55a4f..d47a0c059eae65 100644 --- a/xla/pjrt/c/pjrt_c_api_test.cc +++ b/xla/pjrt/c/pjrt_c_api_test.cc @@ -48,6 +48,7 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api_test_base.h" #include "xla/pjrt/compile_options.pb.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_future.h" #include "xla/service/computation_placer.h" #include "xla/service/hlo.pb.h" @@ -562,37 +563,12 @@ TEST_F(PjrtCApiTest, DeviceDescriptionAndMemoryDescriptionss) { PJRT_Error* error = api_->PJRT_Device_GetDescription(&get_description); EXPECT_EQ(error, nullptr); - PJRT_DeviceDescription_MemoryDescriptions_Args memory_descriptions = - PJRT_DeviceDescription_MemoryDescriptions_Args{ - .struct_size = - PJRT_DeviceDescription_MemoryDescriptions_Args_STRUCT_SIZE, - .extension_start = nullptr, - .device_description = get_description.device_description, - }; + std::vector memory_descriptions = + GetMemorySpaceDescriptions(get_description.device_description, api_); - const PJRT_MemoryDescriptions_Extension* extension = - FindExtension( - api_, PJRT_Extension_Type::PJRT_Extension_Type_MemoryDescriptions); - - if (extension != nullptr) { - error = extension->PJRT_DeviceDescription_MemoryDescriptions( - &memory_descriptions); - EXPECT_EQ(error, nullptr); - - for (int i = 0; i < memory_descriptions.num_memory_descriptions; i++) { - PJRT_MemoryDescription_Kind_Args memory_description = - PJRT_MemoryDescription_Kind_Args{ - .struct_size = - PJRT_DeviceDescription_MemoryDescriptions_Args_STRUCT_SIZE, - .extension_start = nullptr, - .memory_description = memory_descriptions.memory_descriptions[i], - }; - error = extension->PJRT_MemoryDescription_Kind(&memory_description); - EXPECT_EQ(error, nullptr); - EXPECT_NE(memory_description.kind, nullptr); - EXPECT_GT(memory_description.kind_size, 0); - EXPECT_GE(memory_description.kind_id, 0); - } + for (int i = 0; i < memory_descriptions.size(); i++) { + EXPECT_NE(memory_descriptions[i].kind().size(), 0); + EXPECT_GE(memory_descriptions[i].kind_id(), 0); } } diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc index a6ebe3a39dfe31..a1b8966bd34e9b 100644 --- a/xla/pjrt/pjrt_c_api_client.cc +++ b/xla/pjrt/pjrt_c_api_client.cc @@ -1021,27 +1021,10 @@ PjRtCApiDeviceDescription::memory_spaces() const { if (!extension) return {}; if (memory_space_description_pointers_.empty()) { - PJRT_DeviceDescription_MemoryDescriptions_Args mem_desc_args; - mem_desc_args.struct_size = - PJRT_DeviceDescription_MemoryDescriptions_Args_STRUCT_SIZE, - mem_desc_args.extension_start = nullptr, - mem_desc_args.device_description = device_description_, - pjrt::LogFatalIfPjrtError( - extension->PJRT_DeviceDescription_MemoryDescriptions(&mem_desc_args), - c_api_); - - for (int i = 0; i < mem_desc_args.num_memory_descriptions; i++) { - PJRT_MemoryDescription_Kind_Args kind_args; - kind_args.struct_size = PJRT_MemoryDescription_Kind_Args_STRUCT_SIZE, - kind_args.extension_start = nullptr, - kind_args.memory_description = mem_desc_args.memory_descriptions[i], - pjrt::LogFatalIfPjrtError( - extension->PJRT_MemoryDescription_Kind(&kind_args), c_api_); - PjRtMemorySpaceDescription description( - std::string(kind_args.kind, kind_args.kind_size), kind_args.kind_id); - memory_space_descriptions_.push_back(description); - memory_space_description_pointers_.push_back( - &memory_space_descriptions_[i]); + memory_space_descriptions_ = + pjrt::GetMemorySpaceDescriptions(device_description_, c_api_); + for (int i = 0; i < memory_space_descriptions_.size(); i++) { + memory_space_description_pointers_[i] = &memory_space_descriptions_[i]; } } return memory_space_description_pointers_;