@@ -35,7 +35,6 @@ limitations under the License.
3535#include " mlir/Parser/Parser.h" // from @llvm-project
3636#include " xla/hlo/ir/hlo_computation.h"
3737#include " xla/hlo/ir/hlo_module.h"
38- #include " xla/mlir/runtime/transforms/compiler.h"
3938#include " xla/service/buffer_assignment.h"
4039#include " xla/service/computation_layout.h"
4140#include " xla/service/logical_buffer.h"
@@ -56,8 +55,6 @@ limitations under the License.
5655namespace xla {
5756namespace cpu {
5857
59- namespace runtime = ::xla::runtime;
60-
6158absl::StatusOr<std::unique_ptr<CpuExecutable>> CpuExecutable::Create (
6259 std::unique_ptr<SimpleOrcJIT> jit,
6360 std::unique_ptr<const BufferAssignment> assignment,
@@ -95,15 +92,11 @@ absl::StatusOr<std::unique_ptr<CpuExecutable>> CpuExecutable::Create(
9592 std::unique_ptr<HloModule> hlo_module,
9693 std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
9794 std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map,
98- std::unique_ptr<const BufferAssignment> assignment,
99- std::unique_ptr<XlaRuntimeCpuExecutable> xla_runtime_executable) {
95+ std::unique_ptr<const BufferAssignment> assignment) {
10096 std::unique_ptr<CpuExecutable> executable (new CpuExecutable (
10197 std::move (hlo_module), std::move (hlo_profile_printer_data),
10298 std::move (hlo_profile_index_map), std::move (assignment)));
103- executable->set_ir_module_string (
104- xla_runtime_executable->GetExecutable ().take_ir_module_string ());
10599 executable->module_name_ = " main" ;
106- executable->xla_runtime_executable_ = std::move (xla_runtime_executable);
107100 return executable;
108101}
109102
@@ -237,33 +230,17 @@ Status CpuExecutable::ExecuteComputeFunction(
237230 }
238231 };
239232
240- if (IsXlaRuntime ()) {
241- std::vector<BufferDesc> descriptor_table;
242- descriptor_table.reserve (buffers.size ());
243- for (const auto & buffer : buffers) {
244- const tensorflow::se::DeviceMemoryBase& base =
245- buffer.AsDeviceMemoryBase ();
246- BufferDesc desc (const_cast <void *>(base.opaque ()), base.size ());
247- descriptor_table.push_back (std::move (desc));
248- }
249- Status status = ExecuteXlaRuntime (descriptor_table, run_options);
250- record_profile ();
251- if (!status.ok ()) {
252- return status;
253- }
254- } else {
255- XlaCustomCallStatus status;
256- // For the entry computation (like all global computations), all inputs and
257- // outputs are in the buffer table, and both the result pointer and args
258- // array pointers are unused (so we set them to 'nullptr').
259- compute_function_ (nullptr , run_options, nullptr , buffer_pointers.data (),
260- &status, profile_counters);
261- record_profile ();
262- std::optional<absl::string_view> error_message =
263- CustomCallStatusGetMessage (&status);
264- if (error_message) {
265- return Internal (" CustomCall failed: %s" , *error_message);
266- }
233+ XlaCustomCallStatus status;
234+ // For the entry computation (like all global computations), all inputs and
235+ // outputs are in the buffer table, and both the result pointer and args
236+ // array pointers are unused (so we set them to 'nullptr').
237+ compute_function_ (nullptr , run_options, nullptr , buffer_pointers.data (),
238+ &status, profile_counters);
239+ record_profile ();
240+ std::optional<absl::string_view> error_message =
241+ CustomCallStatusGetMessage (&status);
242+ if (error_message) {
243+ return Internal (" CustomCall failed: %s" , *error_message);
267244 }
268245
269246 return OkStatus ();
@@ -369,162 +346,6 @@ absl::StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
369346 return std::move (result);
370347}
371348
372- // Converts a BufferDesc to a MemrefDesc according to the given 'operand_type',
373- // which should point to a runtime::MemrefType.
374- // Note: 'descriptor_index' and 'operand_index' are just used for error
375- // reporting.
376- static absl::StatusOr<runtime::MemrefDesc> BufferToMemref (
377- const BufferDesc& descriptor, const runtime::Type& operand_type,
378- size_t descriptor_index, size_t operand_index) {
379- auto * memref = llvm::dyn_cast<runtime::MemrefType>(&operand_type);
380- if (!memref) {
381- return Internal (
382- " Cannot convert descriptor %zu (operand_index %zu): "
383- " the corresponding type in the signature is a %s, "
384- " not a MemrefType." ,
385- descriptor_index, operand_index, operand_type.ToString ());
386- }
387-
388- absl::Span<const int64_t > dims = memref->sizes ();
389-
390- // Verify that the provided descriptor size matches that of the memref.
391- size_t n_elem = absl::c_accumulate (dims, size_t {1 }, std::multiplies<>());
392- size_t expected_size =
393- primitive_util::ByteWidth (memref->element_type ()) * n_elem;
394- if (LLVM_UNLIKELY (expected_size != descriptor.size ())) {
395- return InvalidArgument (
396- " Cannot convert descriptor %zu (operand_index %zu): "
397- " buffer size is not equal to that expected from the element type: "
398- " got %zu vs expected %zu." ,
399- descriptor_index, operand_index, descriptor.size (), expected_size);
400- }
401-
402- auto fill_sizes_and_strides = [&](auto sizes, auto strides) {
403- size_t multiplier = 1 ;
404- for (int i = static_cast <int >(dims.size ()) - 1 ; i >= 0 ; --i) {
405- size_t size = dims[i];
406- sizes[i] = size;
407- strides[i] = multiplier;
408- multiplier *= size;
409- }
410- };
411- return runtime::MemrefDesc (memref->rank (), memref->element_type (),
412- descriptor.data (), /* offset=*/ 0 ,
413- fill_sizes_and_strides);
414- }
415-
416- // Executes from an XLA Runtime CPU executable, given a buffer descriptor table.
417- // Relevant elements of the descriptor table (i.e. arguments and results) are
418- // converted to MemrefDesc's according to the corresponding operands in the
419- // runtime signature.
420- Status XlaRuntimeCpuExecutable::Execute (
421- const std::vector<BufferDesc>& descriptor_table,
422- const ExecutableRunOptions* run_options) {
423- const runtime::FunctionType& signature = GetExecutable ().runtime_signature ();
424-
425- size_t num_arguments = xla_framework_mapping_.inputs .size ();
426- if (xla_framework_mapping_.output_is_tuple ) {
427- num_arguments += xla_framework_mapping_.flattened_outputs .size ();
428- } else if (xla_framework_mapping_.result != -1 ) {
429- num_arguments += 1 ;
430- }
431-
432- // Verify that the number of arguments in the mapping matches the signature.
433- // Add one to num_arguments to account for the signature's execution context.
434- if (num_arguments + 1 != signature.num_operands ()) {
435- return Internal (
436- " Wrong number of arguments: got %zu via XLA FrameworkMapping, expected "
437- " %d." ,
438- num_arguments, static_cast <int >(signature.num_operands ()) - 1 );
439- }
440-
441- std::vector<runtime::MemrefDesc> arguments;
442- arguments.reserve (num_arguments);
443-
444- auto append_converted_buffer = [&](size_t descriptor_index) -> Status {
445- const BufferDesc& descriptor = descriptor_table[descriptor_index];
446-
447- // Use 1-based index to account for the execution context.
448- size_t operand_index = arguments.size () + 1 ;
449- const runtime::Type* operand_type = signature.operand (operand_index);
450-
451- absl::StatusOr<runtime::MemrefDesc> memref = BufferToMemref (
452- descriptor, *operand_type, descriptor_index, operand_index);
453- if (!memref.ok ()) {
454- return memref.status ();
455- }
456- arguments.push_back (std::move (*memref));
457- return OkStatus ();
458- };
459-
460- // Inputs come first; results come last.
461- for (int64_t index : xla_framework_mapping_.inputs ) {
462- TF_RETURN_IF_ERROR (append_converted_buffer (index));
463- }
464-
465- int64_t result_index = xla_framework_mapping_.result ;
466- if (xla_framework_mapping_.output_is_tuple ) {
467- size_t num_outputs = xla_framework_mapping_.flattened_outputs .size ();
468- for (size_t i = 0 ; i < num_outputs; ++i) {
469- int64_t output_index = xla_framework_mapping_.flattened_outputs [i];
470-
471- TF_RETURN_IF_ERROR (append_converted_buffer (output_index));
472-
473- // Populate the output tuple with a pointer to this result.
474- // TODO(b/249078472): make this work with nested tuples, if needed.
475- assert (result_index != -1 );
476- void ** results =
477- static_cast <void **>(descriptor_table[result_index].data ());
478- results[i] = descriptor_table[output_index].data ();
479- }
480- } else if (result_index != -1 ) {
481- TF_RETURN_IF_ERROR (append_converted_buffer (result_index));
482- }
483-
484- runtime::Executable::CallFrame call_frame;
485- // Skip verification. The MemrefDesc's we created above come from the runtime
486- // signature; verifying them against the same signature would be redundant.
487- if (auto status =
488- GetExecutable ().InitializeCallFrame (arguments, &call_frame,
489- /* verify_arguments=*/ false );
490- !status.ok ()) {
491- return Internal (" Failed to initialize call frame: %s." ,
492- status.message ());
493- }
494-
495- // No results to return; they are returned via out params.
496- runtime::NoResultConverter converter;
497-
498- // Collect all emitted diagnostic messages.
499- std::string diagnostic;
500- runtime::DiagnosticEngine diagnostic_engine;
501- diagnostic_engine.AddHandler ([&](runtime::Diagnostic& d) {
502- absl::StrAppend (&diagnostic, d.status ().message ());
503- return runtime::success ();
504- });
505-
506- runtime::CustomCall::UserData user_data (run_options);
507-
508- runtime::Executable::ExecuteOpts opts;
509- opts.custom_call_data = &user_data;
510- opts.diagnostic_engine = &diagnostic_engine;
511- opts.custom_call_registry = &dynamic_custom_calls_;
512-
513- // We don't expect to see any async tasks in the XLA Runtime executable.
514- opts.async_task_runner =
515- reinterpret_cast <runtime::AsyncTaskRunner*>(0xdeadbeef );
516-
517- // Execute with the prepared call frame.
518- GetExecutable ().Execute (call_frame, opts);
519- if (auto status = GetExecutable ().ReturnResults (converter, &call_frame);
520- !status.ok ()) {
521- return Internal (" Failed to execute XLA Runtime executable: %s%s%s." ,
522- status.message (), diagnostic.empty () ? " " : " : " ,
523- diagnostic);
524- }
525- return OkStatus ();
526- }
527-
528349absl::StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream (
529350 const ServiceExecutableRunOptions* run_options,
530351 std::vector<ExecutionInput> arguments,
@@ -612,9 +433,6 @@ const InstructionValueSet& CpuExecutable::GetRootValueSet() const {
612433}
613434
614435int64_t CpuExecutable::SizeOfGeneratedCodeInBytes () const {
615- // TODO(b/233850967): support profiling in XLA:CPU-Next, instead of
616- // punting on it as we are doing here.
617- if (IsXlaRuntime ()) return 0 ;
618436 return jit_->SizeOfGeneratedCodeInBytes ();
619437}
620438
0 commit comments