Skip to content

Commit 69b7c1c

Browse files
ezhulenev authored and tensorflower-gardener committed
[xla:cpu] NFC: Remove deprecated XLA:CPU mlir based codegen part #4
PiperOrigin-RevId: 630139768
1 parent b470d92 commit 69b7c1c

File tree

6 files changed

+16
-308
lines changed

6 files changed

+16
-308
lines changed

third_party/xla/xla/service/cpu/BUILD

Lines changed: 0 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -245,19 +245,10 @@ cc_library(
245245
"//xla:xla_proto_cc",
246246
"//xla/hlo/ir:hlo",
247247
"//xla/hlo/ir:hlo_module_group",
248-
"//xla/mlir/framework/ir:xla_framework",
249-
"//xla/mlir/runtime/ir:rt",
250-
"//xla/mlir/runtime/transforms:calling_convention",
251-
"//xla/mlir/runtime/transforms:compilation_pipeline_cpu",
252-
"//xla/mlir/runtime/transforms:compiler",
253-
"//xla/mlir/runtime/transforms:jit_compiler",
254248
"//xla/mlir_hlo",
255249
"//xla/mlir_hlo:all_passes",
256250
"//xla/mlir_hlo:mhlo_passes",
257251
"//xla/mlir_hlo:transforms_passes",
258-
"//xla/runtime:custom_call_registry",
259-
"//xla/runtime:executable",
260-
"//xla/runtime:jit_executable",
261252
"//xla/service:algebraic_simplifier",
262253
"//xla/service:all_reduce_promotion",
263254
"//xla/service:all_to_all_decomposer",
@@ -564,9 +555,6 @@ cc_library(
564555
"//xla:util",
565556
"//xla:xla_data_proto_cc",
566557
"//xla/hlo/ir:hlo",
567-
"//xla/mlir/runtime/transforms:compiler",
568-
"//xla/runtime:executable",
569-
"//xla/runtime:jit_executable",
570558
"//xla/service:buffer_assignment",
571559
"//xla/service:computation_layout",
572560
"//xla/service:custom_call_status_internal",
@@ -808,7 +796,6 @@ cc_library(
808796
"//xla:statusor",
809797
"//xla:types",
810798
"//xla:util",
811-
"//xla/runtime:execution_engine",
812799
"//xla/service:llvm_compiler",
813800
"//xla/service/llvm_ir:llvm_util",
814801
"@com_google_absl//absl/functional:any_invocable",

third_party/xla/xla/service/cpu/compiler_functor.cc

Lines changed: 0 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,6 @@ limitations under the License.
3535
#include "llvm/Support/raw_ostream.h"
3636
#include "llvm/Target/TargetMachine.h"
3737
#include "llvm/Transforms/Instrumentation/DataFlowSanitizer.h"
38-
#include "xla/runtime/execution_engine.h"
3938
#include "xla/service/cpu/cpu_runtime.h"
4039
#include "xla/service/cpu/llvm_ir_runtime.h"
4140
#include "xla/service/llvm_ir/llvm_util.h"
@@ -160,20 +159,6 @@ llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> CompilerFunctor::operator()(
160159

161160
llvm::ModulePassManager pm;
162161

163-
for (const auto& func_name : convert_to_xla_runtime_abi_) {
164-
llvm::Function* func = module.getFunction(func_name);
165-
// Create a new function with the XLA Runtime ABI and inline the original
166-
// (i.e. with ctx + memref args) into it.
167-
std::string inlined_func_name =
168-
absl::StrCat(func_name, "__orig_xla_runtime_abi");
169-
func->setName(inlined_func_name);
170-
absl::Status status = xla::runtime::ExportWithXlaRuntimeAbi(
171-
module, inlined_func_name, func_name);
172-
if (!status.ok()) {
173-
LOG(FATAL) << status.message();
174-
}
175-
}
176-
177162
if (dfsan_enabled_) {
178163
pm.addPass(llvm::DataFlowSanitizerPass(dfsan_abi_list_files_));
179164
}

third_party/xla/xla/service/cpu/compiler_functor.h

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -46,8 +46,7 @@ class CompilerFunctor : public llvm::orc::IRCompileLayer::IRCompiler {
4646
absl::AnyInvocable<void(const llvm::object::ObjectFile&)>
4747
post_codegen_hook = nullptr,
4848
bool dfsan_enabled = false,
49-
const std::vector<std::string>& dfsan_abi_list_files = {},
50-
const std::vector<std::string>& convert_to_xla_runtime_abi = {})
49+
const std::vector<std::string>& dfsan_abi_list_files = {})
5150
: IRCompiler(llvm::orc::IRSymbolMapper::ManglingOptions()),
5251
target_machine_(target_machine),
5352
opt_level_(opt_level),
@@ -59,8 +58,7 @@ class CompilerFunctor : public llvm::orc::IRCompileLayer::IRCompiler {
5958
post_optimization_hook_(std::move(post_optimization_hook)),
6059
post_codegen_hook_(std::move(post_codegen_hook)),
6160
dfsan_enabled_(dfsan_enabled),
62-
dfsan_abi_list_files_(dfsan_abi_list_files),
63-
convert_to_xla_runtime_abi_(convert_to_xla_runtime_abi) {}
61+
dfsan_abi_list_files_(dfsan_abi_list_files) {}
6462

6563
// Compile a Module to an ObjectFile.
6664
llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> operator()(
@@ -78,7 +76,6 @@ class CompilerFunctor : public llvm::orc::IRCompileLayer::IRCompiler {
7876
absl::AnyInvocable<void(const llvm::object::ObjectFile&)> post_codegen_hook_;
7977
const bool dfsan_enabled_ = false;
8078
const std::vector<std::string> dfsan_abi_list_files_;
81-
const std::vector<std::string> convert_to_xla_runtime_abi_;
8279
};
8380

8481
} // namespace cpu

third_party/xla/xla/service/cpu/cpu_compiler.cc

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1412,11 +1412,6 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
14121412
obj_file.getData().size()));
14131413
};
14141414

1415-
std::vector<std::string> xla_runtime_abi_conversions;
1416-
if (options.use_mlir_hlo_lowering()) {
1417-
xla_runtime_abi_conversions.push_back(options.entry_point_name());
1418-
}
1419-
14201415
CompilerFunctor compiler_functor(
14211416
target_machine.get(), static_cast<int>(opt_level),
14221417
options::OptimizeForSizeRequested(module->config()),
@@ -1425,8 +1420,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
14251420
llvm_ir::GetCpuFastMathFlags(module->config()),
14261421
pre_optimization_ir_hook, post_optimization_ir_hook,
14271422
post_codegen_hook, aot_options.sanitize_dataflow(),
1428-
aot_options.sanitize_abilists_dataflow(),
1429-
xla_runtime_abi_conversions);
1423+
aot_options.sanitize_abilists_dataflow());
14301424
std::unique_ptr<llvm::MemoryBuffer> object_file =
14311425
cantFail(compiler_functor(*llvm_module));
14321426
ObjectFileData object_file_data(object_file->getBufferStart(),

third_party/xla/xla/service/cpu/cpu_executable.cc

Lines changed: 12 additions & 194 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,6 @@ limitations under the License.
3535
#include "mlir/Parser/Parser.h" // from @llvm-project
3636
#include "xla/hlo/ir/hlo_computation.h"
3737
#include "xla/hlo/ir/hlo_module.h"
38-
#include "xla/mlir/runtime/transforms/compiler.h"
3938
#include "xla/service/buffer_assignment.h"
4039
#include "xla/service/computation_layout.h"
4140
#include "xla/service/logical_buffer.h"
@@ -56,8 +55,6 @@ limitations under the License.
5655
namespace xla {
5756
namespace cpu {
5857

59-
namespace runtime = ::xla::runtime;
60-
6158
absl::StatusOr<std::unique_ptr<CpuExecutable>> CpuExecutable::Create(
6259
std::unique_ptr<SimpleOrcJIT> jit,
6360
std::unique_ptr<const BufferAssignment> assignment,
@@ -95,15 +92,11 @@ absl::StatusOr<std::unique_ptr<CpuExecutable>> CpuExecutable::Create(
9592
std::unique_ptr<HloModule> hlo_module,
9693
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
9794
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map,
98-
std::unique_ptr<const BufferAssignment> assignment,
99-
std::unique_ptr<XlaRuntimeCpuExecutable> xla_runtime_executable) {
95+
std::unique_ptr<const BufferAssignment> assignment) {
10096
std::unique_ptr<CpuExecutable> executable(new CpuExecutable(
10197
std::move(hlo_module), std::move(hlo_profile_printer_data),
10298
std::move(hlo_profile_index_map), std::move(assignment)));
103-
executable->set_ir_module_string(
104-
xla_runtime_executable->GetExecutable().take_ir_module_string());
10599
executable->module_name_ = "main";
106-
executable->xla_runtime_executable_ = std::move(xla_runtime_executable);
107100
return executable;
108101
}
109102

@@ -237,33 +230,17 @@ Status CpuExecutable::ExecuteComputeFunction(
237230
}
238231
};
239232

240-
if (IsXlaRuntime()) {
241-
std::vector<BufferDesc> descriptor_table;
242-
descriptor_table.reserve(buffers.size());
243-
for (const auto& buffer : buffers) {
244-
const tensorflow::se::DeviceMemoryBase& base =
245-
buffer.AsDeviceMemoryBase();
246-
BufferDesc desc(const_cast<void*>(base.opaque()), base.size());
247-
descriptor_table.push_back(std::move(desc));
248-
}
249-
Status status = ExecuteXlaRuntime(descriptor_table, run_options);
250-
record_profile();
251-
if (!status.ok()) {
252-
return status;
253-
}
254-
} else {
255-
XlaCustomCallStatus status;
256-
// For the entry computation (like all global computations), all inputs and
257-
// outputs are in the buffer table, and both the result pointer and args
258-
// array pointers are unused (so we set them to 'nullptr').
259-
compute_function_(nullptr, run_options, nullptr, buffer_pointers.data(),
260-
&status, profile_counters);
261-
record_profile();
262-
std::optional<absl::string_view> error_message =
263-
CustomCallStatusGetMessage(&status);
264-
if (error_message) {
265-
return Internal("CustomCall failed: %s", *error_message);
266-
}
233+
XlaCustomCallStatus status;
234+
// For the entry computation (like all global computations), all inputs and
235+
// outputs are in the buffer table, and both the result pointer and args
236+
// array pointers are unused (so we set them to 'nullptr').
237+
compute_function_(nullptr, run_options, nullptr, buffer_pointers.data(),
238+
&status, profile_counters);
239+
record_profile();
240+
std::optional<absl::string_view> error_message =
241+
CustomCallStatusGetMessage(&status);
242+
if (error_message) {
243+
return Internal("CustomCall failed: %s", *error_message);
267244
}
268245

269246
return OkStatus();
@@ -369,162 +346,6 @@ absl::StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
369346
return std::move(result);
370347
}
371348

372-
// Converts a BufferDesc to a MemrefDesc according to the given 'operand_type',
373-
// which should point to a runtime::MemrefType.
374-
// Note: 'descriptor_index' and 'operand_index' are just used for error
375-
// reporting.
376-
static absl::StatusOr<runtime::MemrefDesc> BufferToMemref(
377-
const BufferDesc& descriptor, const runtime::Type& operand_type,
378-
size_t descriptor_index, size_t operand_index) {
379-
auto* memref = llvm::dyn_cast<runtime::MemrefType>(&operand_type);
380-
if (!memref) {
381-
return Internal(
382-
"Cannot convert descriptor %zu (operand_index %zu): "
383-
"the corresponding type in the signature is a %s, "
384-
"not a MemrefType.",
385-
descriptor_index, operand_index, operand_type.ToString());
386-
}
387-
388-
absl::Span<const int64_t> dims = memref->sizes();
389-
390-
// Verify that the provided descriptor size matches that of the memref.
391-
size_t n_elem = absl::c_accumulate(dims, size_t{1}, std::multiplies<>());
392-
size_t expected_size =
393-
primitive_util::ByteWidth(memref->element_type()) * n_elem;
394-
if (LLVM_UNLIKELY(expected_size != descriptor.size())) {
395-
return InvalidArgument(
396-
"Cannot convert descriptor %zu (operand_index %zu): "
397-
"buffer size is not equal to that expected from the element type: "
398-
"got %zu vs expected %zu.",
399-
descriptor_index, operand_index, descriptor.size(), expected_size);
400-
}
401-
402-
auto fill_sizes_and_strides = [&](auto sizes, auto strides) {
403-
size_t multiplier = 1;
404-
for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
405-
size_t size = dims[i];
406-
sizes[i] = size;
407-
strides[i] = multiplier;
408-
multiplier *= size;
409-
}
410-
};
411-
return runtime::MemrefDesc(memref->rank(), memref->element_type(),
412-
descriptor.data(), /*offset=*/0,
413-
fill_sizes_and_strides);
414-
}
415-
416-
// Executes from an XLA Runtime CPU executable, given a buffer descriptor table.
417-
// Relevant elements of the descriptor table (i.e. arguments and results) are
418-
// converted to MemrefDesc's according to the corresponding operands in the
419-
// runtime signature.
420-
Status XlaRuntimeCpuExecutable::Execute(
421-
const std::vector<BufferDesc>& descriptor_table,
422-
const ExecutableRunOptions* run_options) {
423-
const runtime::FunctionType& signature = GetExecutable().runtime_signature();
424-
425-
size_t num_arguments = xla_framework_mapping_.inputs.size();
426-
if (xla_framework_mapping_.output_is_tuple) {
427-
num_arguments += xla_framework_mapping_.flattened_outputs.size();
428-
} else if (xla_framework_mapping_.result != -1) {
429-
num_arguments += 1;
430-
}
431-
432-
// Verify that the number of arguments in the mapping matches the signature.
433-
// Add one to num_arguments to account for the signature's execution context.
434-
if (num_arguments + 1 != signature.num_operands()) {
435-
return Internal(
436-
"Wrong number of arguments: got %zu via XLA FrameworkMapping, expected "
437-
"%d.",
438-
num_arguments, static_cast<int>(signature.num_operands()) - 1);
439-
}
440-
441-
std::vector<runtime::MemrefDesc> arguments;
442-
arguments.reserve(num_arguments);
443-
444-
auto append_converted_buffer = [&](size_t descriptor_index) -> Status {
445-
const BufferDesc& descriptor = descriptor_table[descriptor_index];
446-
447-
// Use 1-based index to account for the execution context.
448-
size_t operand_index = arguments.size() + 1;
449-
const runtime::Type* operand_type = signature.operand(operand_index);
450-
451-
absl::StatusOr<runtime::MemrefDesc> memref = BufferToMemref(
452-
descriptor, *operand_type, descriptor_index, operand_index);
453-
if (!memref.ok()) {
454-
return memref.status();
455-
}
456-
arguments.push_back(std::move(*memref));
457-
return OkStatus();
458-
};
459-
460-
// Inputs come first; results come last.
461-
for (int64_t index : xla_framework_mapping_.inputs) {
462-
TF_RETURN_IF_ERROR(append_converted_buffer(index));
463-
}
464-
465-
int64_t result_index = xla_framework_mapping_.result;
466-
if (xla_framework_mapping_.output_is_tuple) {
467-
size_t num_outputs = xla_framework_mapping_.flattened_outputs.size();
468-
for (size_t i = 0; i < num_outputs; ++i) {
469-
int64_t output_index = xla_framework_mapping_.flattened_outputs[i];
470-
471-
TF_RETURN_IF_ERROR(append_converted_buffer(output_index));
472-
473-
// Populate the output tuple with a pointer to this result.
474-
// TODO(b/249078472): make this work with nested tuples, if needed.
475-
assert(result_index != -1);
476-
void** results =
477-
static_cast<void**>(descriptor_table[result_index].data());
478-
results[i] = descriptor_table[output_index].data();
479-
}
480-
} else if (result_index != -1) {
481-
TF_RETURN_IF_ERROR(append_converted_buffer(result_index));
482-
}
483-
484-
runtime::Executable::CallFrame call_frame;
485-
// Skip verification. The MemrefDesc's we created above come from the runtime
486-
// signature; verifying them against the same signature would be redundant.
487-
if (auto status =
488-
GetExecutable().InitializeCallFrame(arguments, &call_frame,
489-
/*verify_arguments=*/false);
490-
!status.ok()) {
491-
return Internal("Failed to initialize call frame: %s.",
492-
status.message());
493-
}
494-
495-
// No results to return; they are returned via out params.
496-
runtime::NoResultConverter converter;
497-
498-
// Collect all emitted diagnostic messages.
499-
std::string diagnostic;
500-
runtime::DiagnosticEngine diagnostic_engine;
501-
diagnostic_engine.AddHandler([&](runtime::Diagnostic& d) {
502-
absl::StrAppend(&diagnostic, d.status().message());
503-
return runtime::success();
504-
});
505-
506-
runtime::CustomCall::UserData user_data(run_options);
507-
508-
runtime::Executable::ExecuteOpts opts;
509-
opts.custom_call_data = &user_data;
510-
opts.diagnostic_engine = &diagnostic_engine;
511-
opts.custom_call_registry = &dynamic_custom_calls_;
512-
513-
// We don't expect to see any async tasks in the XLA Runtime executable.
514-
opts.async_task_runner =
515-
reinterpret_cast<runtime::AsyncTaskRunner*>(0xdeadbeef);
516-
517-
// Execute with the prepared call frame.
518-
GetExecutable().Execute(call_frame, opts);
519-
if (auto status = GetExecutable().ReturnResults(converter, &call_frame);
520-
!status.ok()) {
521-
return Internal("Failed to execute XLA Runtime executable: %s%s%s.",
522-
status.message(), diagnostic.empty() ? "" : ": ",
523-
diagnostic);
524-
}
525-
return OkStatus();
526-
}
527-
528349
absl::StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
529350
const ServiceExecutableRunOptions* run_options,
530351
std::vector<ExecutionInput> arguments,
@@ -612,9 +433,6 @@ const InstructionValueSet& CpuExecutable::GetRootValueSet() const {
612433
}
613434

614435
int64_t CpuExecutable::SizeOfGeneratedCodeInBytes() const {
615-
// TODO(b/233850967): support profiling in XLA:CPU-Next, instead of
616-
// punting on it as we are doing here.
617-
if (IsXlaRuntime()) return 0;
618436
return jit_->SizeOfGeneratedCodeInBytes();
619437
}
620438

0 commit comments

Comments
 (0)