From 0a042cbb5275e6ff9a3a7c2748c74df6dcede09e Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Fri, 12 Aug 2022 22:41:27 -0700
Subject: [PATCH] [xla:runtime] NFC: Extract calling_convention library from
 jitrt and move it to xla/runtime

PiperOrigin-RevId: 467360160
---
 backends/jitrt/BUILD                      |  1 +
 .../cpp_tests/calling_convention_test.cc  |  6 +-
 .../cpp_tests/end_to_end_example_test.cc  |  2 +-
 backends/jitrt/include/tfrt/jitrt/jitrt.h | 65 ++--------
 backends/jitrt/lib/jitrt.cc               | 75 -------------------
 5 files changed, 10 insertions(+), 139 deletions(-)

diff --git a/backends/jitrt/BUILD b/backends/jitrt/BUILD
index c13c1af461f..d85f74539ce 100644
--- a/backends/jitrt/BUILD
+++ b/backends/jitrt/BUILD
@@ -98,6 +98,7 @@ licenses(["notice"])
 # "@llvm-project//mlir:ToLLVMIRTranslation",
 # "@llvm-project//mlir:Transforms",
 # "@llvm-project//mlir:mlir_c_runner_utils",
+# "//third_party/tensorflow/compiler/xla/mlir/transforms/runtime:calling_convention",
 # "//third_party/tensorflow/compiler/xla/mlir/transforms/runtime:passes",
 # "//third_party/tensorflow/compiler/xla/mlir/transforms/runtime:specialization",
 # "//third_party/tensorflow/compiler/xla/mlir/transforms/runtime:type_converter",
diff --git a/backends/jitrt/cpp_tests/calling_convention_test.cc b/backends/jitrt/cpp_tests/calling_convention_test.cc
index 3ec3b63bff7..03c301c4e52 100644
--- a/backends/jitrt/cpp_tests/calling_convention_test.cc
+++ b/backends/jitrt/cpp_tests/calling_convention_test.cc
@@ -47,7 +47,7 @@ static const char* entrypoint = "log2_1d";
 
 struct CallingConventionTestCase {
   std::string test_name;
-  CompilationOptions::CallingConvention calling_convention;
+  CallingConvention calling_convention;
   int expected_num_results;
   int expected_num_operands;
 };
@@ -117,11 +117,11 @@ INSTANTIATE_TEST_SUITE_P(
     CallingConventionTest, CallingConventionTest,
     testing::ValuesIn<CallingConventionTestCase>({
         {"DefaultCallingConvention",
-         CompilationOptions::DefaultCallingConvention(
+         xla::runtime::DefaultCallingConvention(
             mlir::bufferization::BufferizeTypeConverter()),
         /*expected_num_results=*/1, /*expected_num_operands=*/2},
        {"ResultsToOutsCallingConvention",
-         CompilationOptions::ResultsToOutsCallingConvention(
+         xla::runtime::ResultsToOutsCallingConvention(
            mlir::bufferization::BufferizeTypeConverter()),
        /*expected_num_results=*/0, /*expected_num_operands=*/3},
     }),
diff --git a/backends/jitrt/cpp_tests/end_to_end_example_test.cc b/backends/jitrt/cpp_tests/end_to_end_example_test.cc
index cf4093dc6e8..3d711effc13 100644
--- a/backends/jitrt/cpp_tests/end_to_end_example_test.cc
+++ b/backends/jitrt/cpp_tests/end_to_end_example_test.cc
@@ -272,7 +272,7 @@ TEST(EndToEndExampleTest, CompiledAndExecute) {
   // the ABI boundary. The expectation is that the compiler pipeline will act
   // according to this calling convention, and the entrypoint will have the same
   // function signature.
-  opts.calling_convention = CompilationOptions::DefaultCallingConvention(
+  opts.calling_convention = xla::runtime::DefaultCallingConvention(
       mlir::bufferization::BufferizeTypeConverter());
 
   // Add a conversion from the `!testlib.custom_arg` MLIR type to the run time
diff --git a/backends/jitrt/include/tfrt/jitrt/jitrt.h b/backends/jitrt/include/tfrt/jitrt/jitrt.h
index 642f915e91c..74377fc4694 100644
--- a/backends/jitrt/include/tfrt/jitrt/jitrt.h
+++ b/backends/jitrt/include/tfrt/jitrt/jitrt.h
@@ -39,6 +39,7 @@
 #include "tfrt/host_context/task_function.h"
 #include "tfrt/jitrt/results.h"
 #include "tfrt/support/forward_decls.h"
+#include "third_party/tensorflow/compiler/xla/mlir/transforms/runtime/calling_convention.h"
 #include "third_party/tensorflow/compiler/xla/mlir/transforms/runtime/specialization.h"
 #include "third_party/tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.h"
 #include "third_party/tensorflow/compiler/xla/runtime/arguments.h"
@@ -75,6 +76,8 @@ class Tensor;
 
 namespace jitrt {
 
+using xla::runtime::CallingConvention;
+
 // Compiled module example:
 //
 //   module @kernel attributes { tfrt.compiled } {
@@ -175,24 +178,6 @@
 class JitExecutable;
 
 struct CompilationOptions {
-  // Calling convention defines an ABI for JitRt to call a compiled kernel. See
-  // documentation and example below.
-  using CallingConvention =
-      std::function<mlir::FunctionType(mlir::FunctionType)>;
-
-  // Returns a calling convention that only adds the kernel context argument.
-  static CallingConvention DefaultCallingConvention();
-
-  // Returns a calling convention that uses the user-provided type converter to
-  // convert all input and result types, and adds the kernel context argument.
-  static CallingConvention DefaultCallingConvention(mlir::TypeConverter);
-
-  // Returns a calling convention that (1) prepends the kernel context argument,
-  // (2) uses the user-provided type converter to convert all input and result
-  // types, and (3) converts result types into out-params by appending them
-  // to the arguments.
-  static CallingConvention ResultsToOutsCallingConvention(mlir::TypeConverter);
-
   // Compiled kernel can be specialized and recompiled at runtime to the
   // concrete input shapes and sometimes values (e.g. reduction dimension).
   enum class Specialization {
@@ -237,48 +222,8 @@ struct CompilationOptions {
   // `rt-to-kernel-function` pass to convert regular functions to "kernels".
   std::function create_compilation_pipeline;
 
-  // Calling convention converts the compiled module entrypoint function type to
-  // the function type with a well-defined ABI (e.g. tensors do not have an ABI,
-  // and must be passed across the function boundary as memrefs). In a nutshell
-  // it tells JitRt how to call the compiled kernel at run time, and how to
-  // return results back to JitRt.
-  //
-  // All types in the converted function signature should have a registered
-  // type conversion (see `type_converter` below) to a type with a defined
-  // argument or result ABI (see Type::ArgumentAbi and Type::ResultAbi).
-  //
-  // If conversion is not possible, the calling convention must return a null value.
-  //
-  // Example: abstract kernel defined in a high level dialect, e.g. MHLO
-  //
-  //   ```mlir
-  //     func @kernel(%arg0: tensor<?xf32>,
-  //                  %arg1: tensor<?x4xi1>) -> tensor<?x4xi1> { ... }
-  //   ```
-  //
-  // after calling convention conversion becomes:
-  //
-  //   ```mlir
-  //     func @kernel(%ctx: !rt.kernel_context,
-  //                  %arg0: memref<?xf32>,
-  //                  %arg1: memref<?x4xi1>) -> memref<?x4xi1> { ... }
-  //   ```
-  //
-  // Calling convention function type is not the same as the entrypoint function
-  // type produced by the compilation pipeline for several reasons:
-  //
-  // 1) Compilation pipeline produces LLVM functions with LLVM types, and high
-  //    level information is lost, e.g. all memrefs are deconstructed into
-  //    primitive fields when passed as inputs.
-  //
-  // 2) Compiled kernel function always returns void, and uses the runtime API to
-  //    return results back to the caller (see `rt-convert-to-entrypoint` pass).
-  //
-  // Calling convention function type is a JitRt-compatible description of the
-  // compiled kernel ABI, so that JitRt can correctly initialize CallFrame
-  // arguments, allocate memory for returned results, and then correctly decode
-  // results memory into the high level types (e.g. convert returned memref
-  // descriptor to a Tensorflow tensor).
+  // Calling convention defines an ABI for XLA runtime to call an executable.
+  // See `CallingConvention` documentation for details.
   CallingConvention calling_convention = DefaultCallingConvention();
 
   // Type converter converts MLIR types to the corresponding run time types.
diff --git a/backends/jitrt/lib/jitrt.cc b/backends/jitrt/lib/jitrt.cc
index 05d0e56ba84..ec81894e1c6 100644
--- a/backends/jitrt/lib/jitrt.cc
+++ b/backends/jitrt/lib/jitrt.cc
@@ -142,81 +142,6 @@ Expected<MemrefDesc> ConvertTensorToMemrefDesc(const Tensor& tensor) {
   return MakeStringError("unsupported tensor type: ", tensor.tensor_type());
 }
 
-//----------------------------------------------------------------------------//
-// Default calling convention for kernels compiled for JitRt.
-//----------------------------------------------------------------------------//
-
-using CallingConvention = CompilationOptions::CallingConvention;
-
-/*static*/ CallingConvention CompilationOptions::DefaultCallingConvention() {
-  return [](mlir::FunctionType func) {
-    mlir::MLIRContext* ctx = func.getContext();
-
-    llvm::SmallVector<mlir::Type> inputs = {KernelContextType::get(ctx)};
-    inputs.reserve(1 + func.getNumInputs());
-    llvm::append_range(inputs, func.getInputs());
-
-    return mlir::FunctionType::get(ctx, inputs, func.getResults());
-  };
-}
-
-/*static*/ CallingConvention CompilationOptions::DefaultCallingConvention(
-    mlir::TypeConverter type_converter) {
-  return [c = std::move(type_converter)](mlir::FunctionType func) mutable {
-    mlir::MLIRContext* ctx = func.getContext();
-
-    // Track if all type conversions were successful.
-    bool failed_conversion = false;
-    auto convert = [&](mlir::Type type) -> mlir::Type {
-      auto converted = c.convertType(type);
-      if (!converted) failed_conversion = true;
-      return converted;
-    };
-
-    // Add kernel context as the first argument.
-    llvm::SmallVector<mlir::Type> inputs = {KernelContextType::get(ctx)};
-    inputs.reserve(1 + func.getNumInputs());
-    llvm::transform(func.getInputs(), std::back_inserter(inputs), convert);
-
-    // Apply type conversion to all result types.
-    llvm::SmallVector<mlir::Type> results;
-    results.reserve(func.getNumResults());
-    llvm::transform(func.getResults(), std::back_inserter(results), convert);
-
-    // Return null if any of the type conversions failed.
-    if (failed_conversion) return mlir::FunctionType();
-
-    return mlir::FunctionType::get(ctx, inputs, results);
-  };
-}
-
-/*static*/ CallingConvention CompilationOptions::ResultsToOutsCallingConvention(
-    mlir::TypeConverter type_converter) {
-  return [c = std::move(type_converter)](mlir::FunctionType func) mutable {
-    mlir::MLIRContext* ctx = func.getContext();
-
-    // Track if all type conversions were successful.
-    bool failed_conversion = false;
-
-    auto convert = [&](mlir::Type type) -> mlir::Type {
-      auto converted = c.convertType(type);
-      if (!converted) failed_conversion = true;
-      return converted;
-    };
-
-    llvm::SmallVector<mlir::Type> inputs;
-    inputs.reserve(1 + func.getNumInputs() + func.getNumResults());
-    inputs.push_back(KernelContextType::get(ctx));
-    llvm::transform(func.getInputs(), std::back_inserter(inputs), convert);
-    llvm::transform(func.getResults(), std::back_inserter(inputs), convert);
-
-    // Return null if any of the type conversions failed.
-    if (failed_conversion) return mlir::FunctionType();
-
-    return mlir::FunctionType::get(ctx, inputs, {});
-  };
-}
-
 //----------------------------------------------------------------------------//
 // Setup MLIR pass pipeline to lower to LLVM dialect, and use ORC JIT to codegen
 // functions at runtime.
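The removed header comments illustrate only the default convention. For comparison, here is a minimal sketch of the signature produced by `ResultsToOutsCallingConvention`, derived from the removed implementation above (kernel context prepended, converted results appended as out-params, void return); the argument names and element types are illustrative only and not taken from the patch:

```mlir
// Abstract kernel, same illustrative example as in the removed comments.
func @kernel(%arg0: tensor<?xf32>,
             %arg1: tensor<?x4xi1>) -> tensor<?x4xi1> { ... }

// After ResultsToOutsCallingConvention: the kernel context is prepended, the
// result type is converted and appended as a trailing out-param, and the
// function returns no results.
func @kernel(%ctx: !rt.kernel_context,
             %arg0: memref<?xf32>,
             %arg1: memref<?x4xi1>,
             %out0: memref<?x4xi1>) { ... }
```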