Implement HloRunnerPjRt::ExecuteReplicated w/ executable_provider overload.

This is mostly modeled after the implementation in the `HloRunner` class, with
a few modifications.

PiperOrigin-RevId: 707683230
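
For context, a minimal sketch of how a caller might drive the new overload (not part of the commit): `runner`, `client`, and `executable` are assumed to exist, and the provider signatures follow the `ExecuteReplicated` overload visible in the diff below.

absl::StatusOr<std::vector<Literal>> RunFourReplicas(HloRunnerPjRt& runner,
                                                     PjRtClient& client,
                                                     Executable* executable) {
  HloRunnerInterface::ReplicatedExecuteOptions options;
  options.num_replicas = 4;
  options.use_threads = true;  // Required by this implementation.

  // The implementation below only supports single-computation assignments.
  TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
                      client.GetDefaultDeviceAssignment(
                          /*num_replicas=*/4, /*num_partitions=*/1));

  // Every replica runs the same executable with one shared argument.
  Literal argument = LiteralUtil::CreateR0<float>(42.0f);
  return runner.ExecuteReplicated(
      /*executable_provider=*/[&](int64_t) { return executable; },
      /*argument_count_provider=*/[](int64_t) { return 1; },
      /*argument_provider=*/[&](int64_t, int64_t) { return &argument; },
      options, &device_assignment);
}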
nvgrw authored and Google-ML-Automation committed Dec 19, 2024
1 parent 47c5b30 commit 93a2296
Showing 2 changed files with 105 additions and 16 deletions.
11 changes: 9 additions & 2 deletions xla/service/BUILD
@@ -4612,9 +4612,11 @@ cc_library(
hdrs = ["hlo_runner_pjrt.h"],
deps = [
":computation_layout",
":computation_placer_hdr",
":executable",
":hlo_module_util",
":hlo_runner_interface",
"//xla:literal",
"//xla:shape_layout",
"//xla:shape_util",
"//xla:status_macros",
@@ -4624,15 +4626,20 @@ cc_library(
"//xla/hlo/ir:hlo",
"//xla/pjrt:host_memory_spaces",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_common",
"//xla/pjrt:pjrt_executable",
"//xla/pjrt:pjrt_future",
"//xla/tsl/platform:env",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:casts",
],
)

110 changes: 96 additions & 14 deletions xla/service/hlo_runner_pjrt.cc
@@ -23,25 +23,35 @@ limitations under the License.
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/layout.h"
#include "xla/literal.h"
#include "xla/pjrt/host_memory_spaces.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_common.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/pjrt_future.h"
#include "xla/service/computation_layout.h"
#include "xla/service/computation_placer.h"
#include "xla/service/executable.h"
#include "xla/service/hlo_module_util.h"
#include "xla/service/hlo_runner_interface.h"
#include "xla/service/service_executable_run_options.h"
#include "xla/shape_layout.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/threadpool.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/casts.h"

namespace xla {

@@ -109,6 +119,22 @@ absl::StatusOr<ExecuteOptions> GenerateExecuteOptions(const HloModule& module) {
return execute_options;
}

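// Maps a flat invocation index onto the (replica, computation) grid of the
// device assignment, with computations varying fastest. For example, with
// computation_count == 2, invocation i == 5 resolves to replica 2,
// computation 1.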
inline PjRtGlobalDeviceId DeviceIdForInvocation(
const DeviceAssignment& device_assignment, const int64_t i) {
const int64_t computation_count = device_assignment.computation_count();
return PjRtGlobalDeviceId(
device_assignment(i / computation_count, i % computation_count));
}

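// Prefers a device assignment pinned into the module's config; only when none
// is present does it fall back to the client's default assignment for the
// configured replica and partition counts.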
absl::StatusOr<DeviceAssignment> GetStaticDeviceAssignmentOrComputeDefault(
const HloModule& module, PjRtClient& client) {
if (module.config().has_static_device_assignment()) {
return module.config().static_device_assignment();
}
return client.GetDefaultDeviceAssignment(module.config().replica_count(),
module.config().num_partitions());
}

} // namespace

// TODO(b/245550554): Remove the use of PjRtWrappedExecutable.
@@ -156,9 +182,8 @@ HloRunnerPjRt::~HloRunnerPjRt() = default;
absl::StatusOr<CompileOptions> HloRunnerPjRt::GenerateDefaultCompileOptions(
HloModule* module, bool run_hlo_passes) {
TF_ASSIGN_OR_RETURN(
auto device_assignment,
pjrt_client_->GetDefaultDeviceAssignment(
module->config().replica_count(), module->config().num_partitions()));
const DeviceAssignment device_assignment,
GetStaticDeviceAssignmentOrComputeDefault(*module, *pjrt_client_));

CompileOptions compile_options;

@@ -448,7 +473,67 @@ absl::StatusOr<std::vector<Literal>> HloRunnerPjRt::ExecuteReplicated(
std::function<const Literal*(int64_t, int64_t)> argument_provider,
const HloRunnerInterface::ReplicatedExecuteOptions& options,
DeviceAssignment* device_assignment) {
return Unimplemented("Unimplemeneted ExecuteReplicated");
TF_RET_CHECK(device_assignment->computation_count() == 1)
<< "Only single-computation execution is supported.";
return ExecuteReplicatedImpl(
[&](absl::Span<const std::vector<PjRtBuffer*>>& argument_buffer_slices)
-> absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> {
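// Each replica is launched on its own thread below, so the caller must
// have opted in to threaded execution.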
TF_RET_CHECK(options.use_threads);

// The underlying data is modified concurrently. We don't need to
// protect access as each replica writes only to its own slot.
std::vector<absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>>
per_replica_results(options.num_replicas);
absl::c_fill(per_replica_results,
absl::InternalError("No result for replica."));

{
// NB: `pool` is joined on destruction.
tsl::thread::ThreadPool pool(tsl::Env::Default(), "replicas",
options.num_replicas);
for (int64_t i = 0; i < options.num_replicas; ++i) {
for (const PjRtBuffer* const buffer : argument_buffer_slices[i]) {
TF_RET_CHECK(buffer != nullptr);
}
PjRtWrappedExecutable* executable =
tensorflow::down_cast<PjRtWrappedExecutable*>(
executable_provider(i));
if (executable == nullptr) {
return absl::InternalError(
absl::StrFormat("Failed to cast executable for replica %d "
"to PjRtWrappedExecutable.",
i));
}
TF_ASSIGN_OR_RETURN(
PjRtDevice * device_ptr,
pjrt_client_->LookupDevice(
DeviceIdForInvocation(*device_assignment, i)));
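// The task runs on a pool thread; the argument list and device pointer
// are captured by value so they stay valid until the task completes.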
pool.Schedule([&per_replica_results, i, executable,
args = argument_buffer_slices[i], device_ptr]() {
per_replica_results[i] =
executable->GetPjRtLoadedExecutable()->ExecuteSharded(
args, device_ptr, {});
});
}
}
// Aggregate results.
std::vector<std::unique_ptr<PjRtBuffer>> results;
for (int64_t i = 0; i < options.num_replicas; ++i) {
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>&
replica_result = per_replica_results[i];
if (!replica_result.ok()) {
return replica_result.status();
}
if (replica_result->size() != 1) {
return absl::InternalError(absl::StrFormat(
"Expected a single result for replica %d, got %d results.", i,
replica_result->size()));
}
results.push_back(std::move(replica_result->front()));
}
return results;
},
argument_count_provider, argument_provider, options, device_assignment);
}

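// ExecuteReplicatedImpl (below) resolves each replica's device from the
// device assignment, transfers that replica's argument literals onto the
// device, and then invokes the callback above with the per-replica buffer
// slices.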
absl::StatusOr<std::vector<Literal>> HloRunnerPjRt::ExecuteReplicatedImpl(
@@ -459,16 +544,13 @@ absl::StatusOr<std::vector<Literal>> HloRunnerPjRt::ExecuteReplicatedImpl(
std::function<const Literal*(int64_t, int64_t)> argument_provider,
const ReplicatedExecuteOptions& options,
DeviceAssignment* device_assignment) {
const int64_t num_computations = device_assignment->computation_count();
absl::Span<PjRtDevice* const> devices = pjrt_client_->devices();

std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> argument_buffer_slices;
argument_buffer_slices.reserve(pjrt_client_->addressable_device_count());
argument_buffer_slices.reserve(options.num_replicas);

for (int64_t i = 0; i < options.num_replicas; ++i) {
const int64_t device_index =
(*device_assignment)(i / num_computations, i % num_computations);
PjRtDevice* device_ptr = devices[device_index];
TF_ASSIGN_OR_RETURN(PjRtDevice * device_ptr,
pjrt_client_->LookupDevice(
DeviceIdForInvocation(*device_assignment, i)));

// Transfer literals to device.
const int64_t argument_count = argument_count_provider(i);
