@@ -15,25 +15,102 @@ limitations under the License.
1515
1616#include " xla/service/hlo_runner_pjrt.h"
1717
18+ #include < cstdint>
1819#include < functional>
1920#include < memory>
2021#include < optional>
2122#include < utility>
2223#include < vector>
2324
25+ #include " absl/algorithm/container.h"
26+ #include " absl/types/span.h"
2427#include " xla/client/xla_computation.h"
2528#include " xla/hlo/ir/hlo_module.h"
29+ #include " xla/layout.h"
30+ #include " xla/pjrt/host_memory_spaces.h"
2631#include " xla/pjrt/pjrt_client.h"
2732#include " xla/pjrt/pjrt_executable.h"
2833#include " xla/pjrt/pjrt_future.h"
34+ #include " xla/service/computation_layout.h"
2935#include " xla/service/executable.h"
3036#include " xla/service/hlo_module_util.h"
37+ #include " xla/shape_layout.h"
38+ #include " xla/shape_util.h"
39+ #include " xla/status.h"
40+ #include " xla/status_macros.h"
3141#include " xla/statusor.h"
42+ #include " xla/util.h"
3243#include " tsl/platform/errors.h"
3344#include " tsl/platform/statusor.h"
3445
3546namespace xla {
3647
48+ namespace {
49+
50+ absl::Status SanityCheckParameterLayouts (
51+ const ComputationLayout& entry_layout) {
52+ const std::vector<ShapeLayout>& layouts = entry_layout.parameter_layouts ();
53+ bool has_nested_tuples =
54+ absl::c_any_of (layouts, [](const auto & shape_layout) {
55+ return ShapeUtil::IsNestedTuple (shape_layout.shape ());
56+ });
57+ if (has_nested_tuples) {
58+ return InvalidArgument (
59+ " PJRT does not support nested tuples as input parameters" );
60+ }
61+ int num_tuples = absl::c_count_if (layouts, [](const auto & shape_layout) {
62+ return shape_layout.shape ().IsTuple ();
63+ });
64+ if (num_tuples > 1 ) {
65+ return InvalidArgument (
66+ " PJRT does not support more than one tuple as input parameters"
67+ " (found %d tuples)" ,
68+ num_tuples);
69+ }
70+ if (num_tuples == 1 && num_tuples != layouts.size ()) {
71+ return InvalidArgument (
72+ " PJRT does not support mixing tuples and non-tuples as input "
73+ " parameters (found 1 tuple out of %d arguments)" ,
74+ layouts.size ());
75+ }
76+ return OkStatus ();
77+ }
78+
79+ absl::StatusOr<bool > MustFlattenInputTuple (
80+ const ComputationLayout& entry_layout) {
81+ TF_RETURN_IF_ERROR (SanityCheckParameterLayouts (entry_layout));
82+ // Strictly, we only need to flatten tuples with mixed host/device leaves
83+ // because mixed host/device PjRtBuffer's are not supported.
84+ // However, splitting all tuples makes the code simpler and is the way
85+ // PJRT is commonly used by JAX.
86+ return entry_layout.parameter_count () == 1 &&
87+ entry_layout.parameter_shape (0 ).IsTuple ();
88+ }
89+
90+ absl::StatusOr<ExecuteOptions> GenerateExecuteOptions (const HloModule& module ) {
91+ ExecuteOptions execute_options;
92+
93+ // If any output leaf buffer is in host memory, PJRT requires untuple_result.
94+ auto output_has_tuple_leaf_in_host_memory_space =
95+ [&module ]() -> absl::StatusOr<bool > {
96+ if (!module .result_shape ().IsTuple ()) {
97+ return false ;
98+ }
99+ TF_ASSIGN_OR_RETURN (
100+ std::vector<Layout> output_layouts,
101+ module .entry_computation_layout ().FlattenedResultLayouts ());
102+ return absl::c_any_of (output_layouts, [](const Layout& layout) {
103+ return layout.memory_space () == Layout::kHostMemorySpace ;
104+ });
105+ };
106+ TF_ASSIGN_OR_RETURN (execute_options.untuple_result ,
107+ output_has_tuple_leaf_in_host_memory_space ());
108+
109+ return execute_options;
110+ }
111+
112+ } // namespace
113+
37114// TODO(b/245550554): Remove the use of PjRtWrappedExecutable.
38115class PjRtWrappedExecutable : public Executable {
39116 public:
@@ -102,6 +179,13 @@ absl::StatusOr<CompileOptions> HloRunnerPjRt::GenerateDefaultCompileOptions(
102179 }
103180 compile_options.argument_layouts = parameter_shapes;
104181
182+ TF_ASSIGN_OR_RETURN (
183+ bool flatten, MustFlattenInputTuple (module ->entry_computation_layout ()));
184+ compile_options.parameter_is_tupled_arguments = flatten;
185+
186+ compile_options.executable_build_options .set_result_layout (
187+ module ->entry_computation_layout ().result_shape ());
188+
105189 return compile_options;
106190}
107191
@@ -114,28 +198,68 @@ absl::StatusOr<Literal> HloRunnerPjRt::TransferLiteralFromDevice(
114198}
115199
116200absl::StatusOr<std::unique_ptr<PjRtBuffer>>
117- HloRunnerPjRt::TransferLiteralToDevice (const Literal& literal) {
201+ HloRunnerPjRt::TransferLiteralToDevice (const Literal& literal,
202+ int64_t memory_space) {
118203 auto devices = pjrt_client_->addressable_devices ();
204+ PjRtDevice* device = devices[kDeviceIdx ];
119205
120- TF_ASSIGN_OR_RETURN (auto assignment, pjrt_client_->BufferFromHostLiteral (
121- literal, devices[kDeviceIdx ]));
206+ if (pjrt_client_->memory_spaces ().empty ()) {
207+ TF_ASSIGN_OR_RETURN (auto assignment,
208+ pjrt_client_->BufferFromHostLiteral (literal, device));
209+ return std::move (assignment);
210+ }
122211
212+ auto get_pjrt_memory_space = [](PjRtDevice* pjrt_device,
213+ int64_t xla_memory_space) {
214+ if (xla_memory_space == Layout::kHostMemorySpace ) {
215+ return pjrt_device->memory_space_by_kind (PinnedHostMemorySpace::kKind );
216+ }
217+ return pjrt_device->default_memory_space ();
218+ };
219+ TF_ASSIGN_OR_RETURN (PjRtMemorySpace * pjrt_memory_space,
220+ get_pjrt_memory_space (device, memory_space));
221+ TF_ASSIGN_OR_RETURN (auto assignment, pjrt_client_->BufferFromHostLiteral (
222+ literal, pjrt_memory_space));
123223 return std::move (assignment);
124224}
125225
126226absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
127227HloRunnerPjRt::TransferLiteralsToDevice (
228+ const ComputationLayout& entry_layout,
128229 absl::Span<const Literal* const > literals) {
129- std::vector<std::unique_ptr<PjRtBuffer>> buffers;
130- buffers.reserve (literals.size ());
131- for (const Literal* literal : literals) {
132- TF_RET_CHECK (literal != nullptr );
133- TF_ASSIGN_OR_RETURN (std::unique_ptr<PjRtBuffer> buffer,
134- TransferLiteralToDevice (*literal));
135- TF_RETURN_IF_ERROR (buffer->GetReadyFuture ().Await ());
136- buffers.push_back (std::move (buffer));
230+ TF_ASSIGN_OR_RETURN (bool flatten, MustFlattenInputTuple (entry_layout));
231+ TF_ASSIGN_OR_RETURN (std::vector<Layout> parameter_layouts,
232+ entry_layout.FlattenedParameterLayouts ());
233+
234+ auto transfer_literals = [¶meter_layouts, this ](
235+ absl::Span<const Literal* const > input_literals)
236+ -> absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> {
237+ TF_RET_CHECK (parameter_layouts.size () == input_literals.size ());
238+ std::vector<std::unique_ptr<PjRtBuffer>> buffers;
239+ buffers.reserve (input_literals.size ());
240+ for (int i = 0 ; i < input_literals.size (); ++i) {
241+ const Literal* literal = input_literals[i];
242+ TF_RET_CHECK (literal != nullptr );
243+ int64_t memory_space = parameter_layouts[i].memory_space ();
244+ TF_ASSIGN_OR_RETURN (std::unique_ptr<PjRtBuffer> buffer,
245+ TransferLiteralToDevice (*literal, memory_space));
246+ TF_RETURN_IF_ERROR (buffer->GetReadyFuture ().Await ());
247+ buffers.push_back (std::move (buffer));
248+ }
249+ return std::move (buffers);
250+ };
251+
252+ if (flatten) {
253+ Literal cloned_literal = literals[0 ]->Clone ();
254+ std::vector<Literal> flattened = cloned_literal.DecomposeTuple ();
255+ std::vector<const Literal*> flattened_ptrs;
256+ flattened_ptrs.reserve (flattened.size ());
257+ for (const Literal& literal : flattened) {
258+ flattened_ptrs.push_back (&literal);
259+ }
260+ return transfer_literals (flattened_ptrs);
137261 }
138- return std::move (buffers );
262+ return transfer_literals (literals );
139263}
140264
141265absl::StatusOr<Literal> HloRunnerPjRt::Execute (
@@ -180,15 +304,13 @@ HloRunnerPjRt::CreateExecutable(HloModule* module,
180304 CompileOptions compile_options) {
181305 XlaComputation computation (module ->ToProto ());
182306
183- return pjrt_client_->Compile (computation, compile_options);
307+ return pjrt_client_->Compile (computation, std::move ( compile_options) );
184308}
185309
186310absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
187311HloRunnerPjRt::ExecuteWithDeviceBuffers (
188- PjRtLoadedExecutable* executable,
312+ PjRtLoadedExecutable* executable, const ExecuteOptions& execute_options,
189313 const std::vector<std::unique_ptr<PjRtBuffer>>& arguments) {
190- ExecuteOptions execute_options;
191-
192314 std::vector<PjRtBuffer*> argument_ptrs = BufferVecToPointerVec (arguments);
193315
194316 auto devices = pjrt_client_->addressable_devices ();
@@ -209,17 +331,33 @@ absl::StatusOr<Literal> HloRunnerPjRt::ExecuteWithExecutable(
209331 PjRtWrappedExecutable* wrapped_executable =
210332 static_cast <PjRtWrappedExecutable*>(executable);
211333
212- TF_ASSIGN_OR_RETURN (auto argument_handles,
213- TransferLiteralsToDevice (arguments));
334+ auto * pjrt_executable = wrapped_executable->GetPjRtLoadedExecutable ();
335+ TF_ASSIGN_OR_RETURN (std::vector<std::shared_ptr<HloModule>> hlo_modules,
336+ pjrt_executable->GetHloModules ());
337+ TF_RET_CHECK (hlo_modules.size () == 1 );
338+ const HloModule& module = *hlo_modules.front ();
214339
340+ TF_ASSIGN_OR_RETURN (ExecuteOptions execute_options,
341+ GenerateExecuteOptions (module ));
215342 TF_ASSIGN_OR_RETURN (
216- auto output_buffer,
217- ExecuteWithDeviceBuffers (wrapped_executable->GetPjRtLoadedExecutable (),
218- std::move (argument_handles)));
219- // TODO (b/245550554): Support more than 1 output.
220- CHECK_EQ (output_buffer.size (), 1 );
343+ auto argument_handles,
344+ TransferLiteralsToDevice (module .entry_computation_layout (), arguments));
221345
222- return TransferLiteralFromDevice (*output_buffer[0 ]);
346+ TF_ASSIGN_OR_RETURN (
347+ std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
348+ ExecuteWithDeviceBuffers (wrapped_executable->GetPjRtLoadedExecutable (),
349+ execute_options, std::move (argument_handles)));
350+ if (!execute_options.untuple_result ) {
351+ TF_RET_CHECK (output_buffers.size () == 1 )
352+ << " , got " << output_buffers.size ();
353+ return TransferLiteralFromDevice (*output_buffers[0 ]);
354+ }
355+ std::vector<Literal> result_leaves;
356+ for (const auto & leaf_buffer : output_buffers) {
357+ TF_ASSIGN_OR_RETURN (Literal leaf, TransferLiteralFromDevice (*leaf_buffer));
358+ result_leaves.push_back (std::move (leaf));
359+ }
360+ return Literal::MoveIntoTuple (absl::MakeSpan (result_leaves));
223361}
224362
225363absl::StatusOr<std::unique_ptr<Executable>> HloRunnerPjRt::CreateExecutable (
0 commit comments