Skip to content

Commit

Permalink
Use the new PjRt InterpreterClient in test base and PjRt test client …
Browse files Browse the repository at this point in the history
…registry.

PiperOrigin-RevId: 707247154
  • Loading branch information
nvgrw authored and Google-ML-Automation committed Dec 19, 2024
1 parent ca3ddd2 commit f7b1c20
Showing 3 changed files with 17 additions and 27 deletions.
9 changes: 2 additions & 7 deletions xla/tests/BUILD
Original file line number Diff line number Diff line change
@@ -163,10 +163,8 @@ cc_library(
],
deps = [
":pjrt_client_registry",
"//xla/pjrt:interpreter_device",
"//xla/pjrt:pjrt_client",
"@com_google_absl//absl/status:statusor",
"@tsl//tsl/platform:status",
"//xla/pjrt/interpreter:interpreter_client",
],
)

@@ -278,15 +276,12 @@ cc_library(
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/pjrt:pjrt_client",
"//xla/service:hlo_runner",
"//xla/pjrt/interpreter:interpreter_client",
"//xla/service:hlo_runner_interface",
"//xla/service:hlo_runner_pjrt",
"//xla/service:interpreter_plugin", # reference backend
"//xla/service:platform_util",
"//xla/stream_executor:platform",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:logging",
],
)
15 changes: 6 additions & 9 deletions xla/tests/hlo_pjrt_test_base.cc
Original file line number Diff line number Diff line change
@@ -21,13 +21,10 @@ limitations under the License.

#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/pjrt/interpreter/interpreter_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/service/hlo_runner.h"
#include "xla/service/hlo_runner_interface.h"
#include "xla/service/hlo_runner_pjrt.h"
#include "xla/service/platform_util.h"
#include "xla/stream_executor/platform.h"
#include "xla/tests/hlo_runner_agnostic_test_base.h"
#include "xla/tests/pjrt_client_registry.h"
#include "xla/util.h"
@@ -56,11 +53,11 @@ std::unique_ptr<HloRunnerInterface> GetHloRunnerForTest() {
}

std::unique_ptr<HloRunnerInterface> GetHloRunnerForReference() {
absl::StatusOr<se::Platform*> platform =
PlatformUtil::GetPlatform("interpreter");
CHECK_OK(platform.status())
<< "Failed to get interpreter platform. " << platform.status();
return std::make_unique<HloRunner>(*platform);
return std::make_unique<HloRunnerPjRt>(
std::make_unique<InterpreterClient>(),
InterpreterClient::DeviceShapeRepresentation,
InterpreterClient::ShapeSizeBytes,
/*use_parameter_layout_on_device=*/true);
}

} // namespace
20 changes: 9 additions & 11 deletions xla/tests/pjrt_interpreter_client_registry.cc
Original file line number Diff line number Diff line change
@@ -14,25 +14,23 @@ limitations under the License.
==============================================================================*/

#include <memory>
#include <utility>

#include "absl/status/statusor.h"
#include "xla/pjrt/interpreter_device.h"
#include "xla/pjrt/interpreter/interpreter_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/tests/pjrt_client_registry.h"
#include "tsl/platform/status.h"

namespace xla {
namespace {

// Register an interpreter PjRt client for tests.
const bool kUnused = (RegisterPjRtClientTestFactory([]() {
absl::StatusOr<std::unique_ptr<PjRtClient>> client =
GetInterpreterClient();
TF_CHECK_OK(client.status());
return *std::move(client);
}),
true);
const bool kUnused =
(RegisterPjRtClientTestFactory(
[]() { return std::make_unique<InterpreterClient>(); },
[](PjRtClient* client) {
return InterpreterClient::DeviceShapeRepresentation;
},
[](PjRtClient* client) { return InterpreterClient::ShapeSizeBytes; }),
true);

} // namespace
} // namespace xla

0 comments on commit f7b1c20

Please sign in to comment.