Skip to content

Commit d5c179b

Browse files
cota and tensorflower-gardener
authored and committed
hlo_runner_pjrt: support parameter and output streaming
PiperOrigin-RevId: 630165970
1 parent 0368aba commit d5c179b

File tree

3 files changed

+175
-27
lines changed

3 files changed

+175
-27
lines changed

third_party/xla/xla/service/BUILD

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5701,12 +5701,21 @@ cc_library(
57015701
":executable",
57025702
":hlo_module_util",
57035703
":hlo_runner_interface",
5704+
"//xla:shape_layout",
5705+
"//xla:shape_util",
5706+
"//xla:status",
5707+
"//xla:status_macros",
57045708
"//xla:statusor",
5709+
"//xla:util",
57055710
"//xla/client:xla_computation",
57065711
"//xla/hlo/ir:hlo",
5712+
"//xla/pjrt:host_memory_spaces",
57075713
"//xla/pjrt:pjrt_client",
57085714
"//xla/pjrt:pjrt_executable",
57095715
"//xla/pjrt:pjrt_future",
5716+
"//xla/service:computation_layout",
5717+
"@com_google_absl//absl/algorithm:container",
5718+
"@com_google_absl//absl/types:span",
57105719
"@local_tsl//tsl/platform:errors",
57115720
"@local_tsl//tsl/platform:statusor",
57125721
],

third_party/xla/xla/service/hlo_runner_pjrt.cc

Lines changed: 162 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,102 @@ limitations under the License.
1515

1616
#include "xla/service/hlo_runner_pjrt.h"
1717

18+
#include <cstdint>
1819
#include <functional>
1920
#include <memory>
2021
#include <optional>
2122
#include <utility>
2223
#include <vector>
2324

25+
#include "absl/algorithm/container.h"
26+
#include "absl/types/span.h"
2427
#include "xla/client/xla_computation.h"
2528
#include "xla/hlo/ir/hlo_module.h"
29+
#include "xla/layout.h"
30+
#include "xla/pjrt/host_memory_spaces.h"
2631
#include "xla/pjrt/pjrt_client.h"
2732
#include "xla/pjrt/pjrt_executable.h"
2833
#include "xla/pjrt/pjrt_future.h"
34+
#include "xla/service/computation_layout.h"
2935
#include "xla/service/executable.h"
3036
#include "xla/service/hlo_module_util.h"
37+
#include "xla/shape_layout.h"
38+
#include "xla/shape_util.h"
39+
#include "xla/status.h"
40+
#include "xla/status_macros.h"
3141
#include "xla/statusor.h"
42+
#include "xla/util.h"
3243
#include "tsl/platform/errors.h"
3344
#include "tsl/platform/statusor.h"
3445

3546
namespace xla {
3647

48+
namespace {
49+
50+
absl::Status SanityCheckParameterLayouts(
51+
const ComputationLayout& entry_layout) {
52+
const std::vector<ShapeLayout>& layouts = entry_layout.parameter_layouts();
53+
bool has_nested_tuples =
54+
absl::c_any_of(layouts, [](const auto& shape_layout) {
55+
return ShapeUtil::IsNestedTuple(shape_layout.shape());
56+
});
57+
if (has_nested_tuples) {
58+
return InvalidArgument(
59+
"PJRT does not support nested tuples as input parameters");
60+
}
61+
int num_tuples = absl::c_count_if(layouts, [](const auto& shape_layout) {
62+
return shape_layout.shape().IsTuple();
63+
});
64+
if (num_tuples > 1) {
65+
return InvalidArgument(
66+
"PJRT does not support more than one tuple as input parameters"
67+
" (found %d tuples)",
68+
num_tuples);
69+
}
70+
if (num_tuples == 1 && num_tuples != layouts.size()) {
71+
return InvalidArgument(
72+
"PJRT does not support mixing tuples and non-tuples as input "
73+
"parameters (found 1 tuple out of %d arguments)",
74+
layouts.size());
75+
}
76+
return OkStatus();
77+
}
78+
79+
absl::StatusOr<bool> MustFlattenInputTuple(
80+
const ComputationLayout& entry_layout) {
81+
TF_RETURN_IF_ERROR(SanityCheckParameterLayouts(entry_layout));
82+
// Strictly, we only need to flatten tuples with mixed host/device leaves
83+
// because mixed host/device PjRtBuffer's are not supported.
84+
// However, splitting all tuples makes the code simpler and is the way
85+
// PJRT is commonly used by JAX.
86+
return entry_layout.parameter_count() == 1 &&
87+
entry_layout.parameter_shape(0).IsTuple();
88+
}
89+
90+
absl::StatusOr<ExecuteOptions> GenerateExecuteOptions(const HloModule& module) {
91+
ExecuteOptions execute_options;
92+
93+
// If any output leaf buffer is in host memory, PJRT requires untuple_result.
94+
auto output_has_tuple_leaf_in_host_memory_space =
95+
[&module]() -> absl::StatusOr<bool> {
96+
if (!module.result_shape().IsTuple()) {
97+
return false;
98+
}
99+
TF_ASSIGN_OR_RETURN(
100+
std::vector<Layout> output_layouts,
101+
module.entry_computation_layout().FlattenedResultLayouts());
102+
return absl::c_any_of(output_layouts, [](const Layout& layout) {
103+
return layout.memory_space() == Layout::kHostMemorySpace;
104+
});
105+
};
106+
TF_ASSIGN_OR_RETURN(execute_options.untuple_result,
107+
output_has_tuple_leaf_in_host_memory_space());
108+
109+
return execute_options;
110+
}
111+
112+
} // namespace
113+
37114
// TODO(b/245550554): Remove the use of PjRtWrappedExecutable.
38115
class PjRtWrappedExecutable : public Executable {
39116
public:
@@ -102,6 +179,13 @@ absl::StatusOr<CompileOptions> HloRunnerPjRt::GenerateDefaultCompileOptions(
102179
}
103180
compile_options.argument_layouts = parameter_shapes;
104181

182+
TF_ASSIGN_OR_RETURN(
183+
bool flatten, MustFlattenInputTuple(module->entry_computation_layout()));
184+
compile_options.parameter_is_tupled_arguments = flatten;
185+
186+
compile_options.executable_build_options.set_result_layout(
187+
module->entry_computation_layout().result_shape());
188+
105189
return compile_options;
106190
}
107191

@@ -114,28 +198,68 @@ absl::StatusOr<Literal> HloRunnerPjRt::TransferLiteralFromDevice(
114198
}
115199

116200
absl::StatusOr<std::unique_ptr<PjRtBuffer>>
117-
HloRunnerPjRt::TransferLiteralToDevice(const Literal& literal) {
201+
HloRunnerPjRt::TransferLiteralToDevice(const Literal& literal,
202+
int64_t memory_space) {
118203
auto devices = pjrt_client_->addressable_devices();
204+
PjRtDevice* device = devices[kDeviceIdx];
119205

120-
TF_ASSIGN_OR_RETURN(auto assignment, pjrt_client_->BufferFromHostLiteral(
121-
literal, devices[kDeviceIdx]));
206+
if (pjrt_client_->memory_spaces().empty()) {
207+
TF_ASSIGN_OR_RETURN(auto assignment,
208+
pjrt_client_->BufferFromHostLiteral(literal, device));
209+
return std::move(assignment);
210+
}
122211

212+
auto get_pjrt_memory_space = [](PjRtDevice* pjrt_device,
213+
int64_t xla_memory_space) {
214+
if (xla_memory_space == Layout::kHostMemorySpace) {
215+
return pjrt_device->memory_space_by_kind(PinnedHostMemorySpace::kKind);
216+
}
217+
return pjrt_device->default_memory_space();
218+
};
219+
TF_ASSIGN_OR_RETURN(PjRtMemorySpace * pjrt_memory_space,
220+
get_pjrt_memory_space(device, memory_space));
221+
TF_ASSIGN_OR_RETURN(auto assignment, pjrt_client_->BufferFromHostLiteral(
222+
literal, pjrt_memory_space));
123223
return std::move(assignment);
124224
}
125225

126226
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
127227
HloRunnerPjRt::TransferLiteralsToDevice(
228+
const ComputationLayout& entry_layout,
128229
absl::Span<const Literal* const> literals) {
129-
std::vector<std::unique_ptr<PjRtBuffer>> buffers;
130-
buffers.reserve(literals.size());
131-
for (const Literal* literal : literals) {
132-
TF_RET_CHECK(literal != nullptr);
133-
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> buffer,
134-
TransferLiteralToDevice(*literal));
135-
TF_RETURN_IF_ERROR(buffer->GetReadyFuture().Await());
136-
buffers.push_back(std::move(buffer));
230+
TF_ASSIGN_OR_RETURN(bool flatten, MustFlattenInputTuple(entry_layout));
231+
TF_ASSIGN_OR_RETURN(std::vector<Layout> parameter_layouts,
232+
entry_layout.FlattenedParameterLayouts());
233+
234+
auto transfer_literals = [&parameter_layouts, this](
235+
absl::Span<const Literal* const> input_literals)
236+
-> absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> {
237+
TF_RET_CHECK(parameter_layouts.size() == input_literals.size());
238+
std::vector<std::unique_ptr<PjRtBuffer>> buffers;
239+
buffers.reserve(input_literals.size());
240+
for (int i = 0; i < input_literals.size(); ++i) {
241+
const Literal* literal = input_literals[i];
242+
TF_RET_CHECK(literal != nullptr);
243+
int64_t memory_space = parameter_layouts[i].memory_space();
244+
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> buffer,
245+
TransferLiteralToDevice(*literal, memory_space));
246+
TF_RETURN_IF_ERROR(buffer->GetReadyFuture().Await());
247+
buffers.push_back(std::move(buffer));
248+
}
249+
return std::move(buffers);
250+
};
251+
252+
if (flatten) {
253+
Literal cloned_literal = literals[0]->Clone();
254+
std::vector<Literal> flattened = cloned_literal.DecomposeTuple();
255+
std::vector<const Literal*> flattened_ptrs;
256+
flattened_ptrs.reserve(flattened.size());
257+
for (const Literal& literal : flattened) {
258+
flattened_ptrs.push_back(&literal);
259+
}
260+
return transfer_literals(flattened_ptrs);
137261
}
138-
return std::move(buffers);
262+
return transfer_literals(literals);
139263
}
140264

141265
absl::StatusOr<Literal> HloRunnerPjRt::Execute(
@@ -180,15 +304,13 @@ HloRunnerPjRt::CreateExecutable(HloModule* module,
180304
CompileOptions compile_options) {
181305
XlaComputation computation(module->ToProto());
182306

183-
return pjrt_client_->Compile(computation, compile_options);
307+
return pjrt_client_->Compile(computation, std::move(compile_options));
184308
}
185309

186310
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
187311
HloRunnerPjRt::ExecuteWithDeviceBuffers(
188-
PjRtLoadedExecutable* executable,
312+
PjRtLoadedExecutable* executable, const ExecuteOptions& execute_options,
189313
const std::vector<std::unique_ptr<PjRtBuffer>>& arguments) {
190-
ExecuteOptions execute_options;
191-
192314
std::vector<PjRtBuffer*> argument_ptrs = BufferVecToPointerVec(arguments);
193315

194316
auto devices = pjrt_client_->addressable_devices();
@@ -209,17 +331,33 @@ absl::StatusOr<Literal> HloRunnerPjRt::ExecuteWithExecutable(
209331
PjRtWrappedExecutable* wrapped_executable =
210332
static_cast<PjRtWrappedExecutable*>(executable);
211333

212-
TF_ASSIGN_OR_RETURN(auto argument_handles,
213-
TransferLiteralsToDevice(arguments));
334+
auto* pjrt_executable = wrapped_executable->GetPjRtLoadedExecutable();
335+
TF_ASSIGN_OR_RETURN(std::vector<std::shared_ptr<HloModule>> hlo_modules,
336+
pjrt_executable->GetHloModules());
337+
TF_RET_CHECK(hlo_modules.size() == 1);
338+
const HloModule& module = *hlo_modules.front();
214339

340+
TF_ASSIGN_OR_RETURN(ExecuteOptions execute_options,
341+
GenerateExecuteOptions(module));
215342
TF_ASSIGN_OR_RETURN(
216-
auto output_buffer,
217-
ExecuteWithDeviceBuffers(wrapped_executable->GetPjRtLoadedExecutable(),
218-
std::move(argument_handles)));
219-
// TODO (b/245550554): Support more than 1 output.
220-
CHECK_EQ(output_buffer.size(), 1);
343+
auto argument_handles,
344+
TransferLiteralsToDevice(module.entry_computation_layout(), arguments));
221345

222-
return TransferLiteralFromDevice(*output_buffer[0]);
346+
TF_ASSIGN_OR_RETURN(
347+
std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
348+
ExecuteWithDeviceBuffers(wrapped_executable->GetPjRtLoadedExecutable(),
349+
execute_options, std::move(argument_handles)));
350+
if (!execute_options.untuple_result) {
351+
TF_RET_CHECK(output_buffers.size() == 1)
352+
<< ", got " << output_buffers.size();
353+
return TransferLiteralFromDevice(*output_buffers[0]);
354+
}
355+
std::vector<Literal> result_leaves;
356+
for (const auto& leaf_buffer : output_buffers) {
357+
TF_ASSIGN_OR_RETURN(Literal leaf, TransferLiteralFromDevice(*leaf_buffer));
358+
result_leaves.push_back(std::move(leaf));
359+
}
360+
return Literal::MoveIntoTuple(absl::MakeSpan(result_leaves));
223361
}
224362

225363
absl::StatusOr<std::unique_ptr<Executable>> HloRunnerPjRt::CreateExecutable(

third_party/xla/xla/service/hlo_runner_pjrt.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,10 @@ class HloRunnerPjRt : public HloRunnerInterface {
4040

4141
// Transfers data between the host and device.
4242
absl::StatusOr<std::unique_ptr<PjRtBuffer>> TransferLiteralToDevice(
43-
const Literal& literal);
43+
const Literal& literal, int64_t memory_space);
4444
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
45-
TransferLiteralsToDevice(absl::Span<const Literal* const> literals);
45+
TransferLiteralsToDevice(const ComputationLayout& entry_layout,
46+
absl::Span<const Literal* const> literals);
4647
absl::StatusOr<Literal> TransferLiteralFromDevice(PjRtBuffer& buffer);
4748

4849
// Executes the given module with given literals as input and returns the
@@ -56,7 +57,7 @@ class HloRunnerPjRt : public HloRunnerInterface {
5657
// buffers.
5758
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
5859
ExecuteWithDeviceBuffers(
59-
PjRtLoadedExecutable* executable,
60+
PjRtLoadedExecutable* executable, const ExecuteOptions& execute_options,
6061
const std::vector<std::unique_ptr<PjRtBuffer>>& arguments);
6162

6263
// Creates an executable object for an HloModule.

0 commit comments

Comments (0)