Skip to content

Commit

Permalink
[xla:runtime] NFC: Extract calling_convention library from jitrt and …
Browse files Browse the repository at this point in the history
…move it to xla/runtime

PiperOrigin-RevId: 467360160
  • Loading branch information
ezhulenev authored and copybara-github committed Aug 13, 2022
1 parent b20ec05 commit 0a042cb
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 139 deletions.
1 change: 1 addition & 0 deletions backends/jitrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ licenses(["notice"])
# "@llvm-project//mlir:ToLLVMIRTranslation",
# "@llvm-project//mlir:Transforms",
# "@llvm-project//mlir:mlir_c_runner_utils",
# "//third_party/tensorflow/compiler/xla/mlir/transforms/runtime:calling_convention",
# "//third_party/tensorflow/compiler/xla/mlir/transforms/runtime:passes",
# "//third_party/tensorflow/compiler/xla/mlir/transforms/runtime:specialization",
# "//third_party/tensorflow/compiler/xla/mlir/transforms/runtime:type_converter",
Expand Down
6 changes: 3 additions & 3 deletions backends/jitrt/cpp_tests/calling_convention_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ static const char* entrypoint = "log2_1d";

struct CallingConventionTestCase {
std::string test_name;
CompilationOptions::CallingConvention calling_convention;
CallingConvention calling_convention;
int expected_num_results;
int expected_num_operands;
};
Expand Down Expand Up @@ -117,11 +117,11 @@ INSTANTIATE_TEST_SUITE_P(
CallingConventionTest, CallingConventionTest,
testing::ValuesIn<CallingConventionTestCase>({
{"DefaultCallingConvention",
CompilationOptions::DefaultCallingConvention(
xla::runtime::DefaultCallingConvention(
mlir::bufferization::BufferizeTypeConverter()),
/*expected_num_results=*/1, /*expected_num_operands=*/2},
{"ResultsToOutsCallingConvention",
CompilationOptions::ResultsToOutsCallingConvention(
xla::runtime::ResultsToOutsCallingConvention(
mlir::bufferization::BufferizeTypeConverter()),
/*expected_num_results=*/0, /*expected_num_operands=*/3},
}),
Expand Down
2 changes: 1 addition & 1 deletion backends/jitrt/cpp_tests/end_to_end_example_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ TEST(EndToEndExampleTest, CompiledAndExecute) {
// the ABI boundary. The expectation is that compiler pipeline will act
// according to this calling convention, and the entrypoint will have the same
// function signature.
opts.calling_convention = CompilationOptions::DefaultCallingConvention(
opts.calling_convention = xla::runtime::DefaultCallingConvention(
mlir::bufferization::BufferizeTypeConverter());

// Add a conversion from the `!testlib.custom_arg` MLIR type to the run time
Expand Down
65 changes: 5 additions & 60 deletions backends/jitrt/include/tfrt/jitrt/jitrt.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "tfrt/host_context/task_function.h"
#include "tfrt/jitrt/results.h"
#include "tfrt/support/forward_decls.h"
#include "third_party/tensorflow/compiler/xla/mlir/transforms/runtime/calling_convention.h"
#include "third_party/tensorflow/compiler/xla/mlir/transforms/runtime/specialization.h"
#include "third_party/tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.h"
#include "third_party/tensorflow/compiler/xla/runtime/arguments.h"
Expand Down Expand Up @@ -75,6 +76,8 @@ class Tensor;

namespace jitrt {

using xla::runtime::CallingConvention;

// Compiled module example:
//
// module @kernel attributes { tfrt.compiled } {
Expand Down Expand Up @@ -175,24 +178,6 @@ namespace jitrt {
class JitExecutable;

struct CompilationOptions {
// Calling convention defines an ABI for JitRt to call a compiled kernel. See
// documentation and example below.
using CallingConvention =
std::function<mlir::FunctionType(mlir::FunctionType)>;

// Returns a calling convention that only adds the kernel context argument.
static CallingConvention DefaultCallingConvention();

// Returns a calling convention that uses user-provided type converter to
// convert all inputs and results types, and adds the kernel context argument.
static CallingConvention DefaultCallingConvention(mlir::TypeConverter);

// Returns a calling convention that (1) prepends the kernel context argument,
// (2) uses the user-provided type converter to convert all inputs and results
// types, and (3) converts result types into out-params by appending them
// to the arguments.
static CallingConvention ResultsToOutsCallingConvention(mlir::TypeConverter);

// Compiled kernel can be specialized and recompiled at runtime to the
// concrete input shapes and sometimes values (e.g. reduction dimension).
enum class Specialization {
Expand Down Expand Up @@ -237,48 +222,8 @@ struct CompilationOptions {
// `rt-to-kernel-function` pass to convert regular functions to "kernels".
std::function<void(mlir::PassManager&)> create_compilation_pipeline;

// Calling convention converts the compiled module entrypoint function type to
// the function type with a well defined ABI (e.g. tensors do not have an ABI,
// and must be passed across the function boundary as memrefs). In a nutshell
// it tells the JitRt how to call the compiled kernel at run time, and how to
// return results back to the JitRt.
//
// All types in the converted function signature should have a registered
// type conversion (see `type_converter` below) to a type with defined
// argument or result ABI (see Type::ArgumentAbi and Type::ResultAbi).
//
// If conversion is not possible, calling convention must return a null value.
//
// Example: abstract kernel defined in high level dialect, e.g. MHLO
//
// ```mlir
// func @kernel(%arg0: tensor<?xf32>,
// %arg1: tensor<?xf32>) -> tensor<?x?xf32> { ... }
// ```
//
// after calling convention conversion becomes:
//
// ```mlir
// func @kernel(%ctx: !rt.kernel_context,
// %arg0: memref<?xf32>,
// %arg1: memref<?xf32>) -> memref<?x?xf32> { ... }
// ```
//
// Calling convention function type is not the same as the entrypoint function
// type produced by the compilation pipeline for several reasons:
//
// 1) Compilation pipeline produces LLVM functions with LLVM types, and high
// level information is lost, e.g. all memrefs are deconstructed into
// primitive fields when passed as inputs.
//
// 2) Compiled kernel function always returns void, and uses runtime API to
// return results back to the caller (see `rt-convert-to-entrypoint` pass).
//
// Calling convention function type is a JitRt-compatible description of the
// compiled kernel ABI, so that JitRt can correctly initialize CallFrame
// arguments, allocate memory for returned results, and then correctly decode
// results memory into the high level types (e.g. convert returned memref
// descriptor to a Tensorfow tensor).
// Calling convention defines an ABI for XLA runtime to call an executable.
// See `CallingConvention` documentation for details.
CallingConvention calling_convention = DefaultCallingConvention();

// Type converter converts MLIR types to the corresponding run time types.
Expand Down
75 changes: 0 additions & 75 deletions backends/jitrt/lib/jitrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,81 +142,6 @@ Expected<MemrefDesc> ConvertTensorToMemrefDesc(const Tensor& tensor) {
return MakeStringError("unsupported tensor type: ", tensor.tensor_type());
}

//----------------------------------------------------------------------------//
// Default calling convention for kernels compiled for JitRt.
//----------------------------------------------------------------------------//

using CallingConvention = CompilationOptions::CallingConvention;

// Returns a calling convention that prepends the kernel context argument to
// the entrypoint signature and passes all other inputs and results through
// unchanged.
/*static*/ CallingConvention CompilationOptions::DefaultCallingConvention() {
  return [](mlir::FunctionType func) {
    mlir::MLIRContext* context = func.getContext();

    // Kernel context goes first, followed by the original inputs.
    llvm::SmallVector<mlir::Type> arguments;
    arguments.reserve(1 + func.getNumInputs());
    arguments.push_back(KernelContextType::get(context));
    llvm::append_range(arguments, func.getInputs());

    return mlir::FunctionType::get(context, arguments, func.getResults());
  };
}

// Returns a calling convention that prepends the kernel context argument and
// applies the user-provided type converter to every input and result type.
// Produces a null mlir::FunctionType if any type has no registered conversion.
/*static*/ CallingConvention CompilationOptions::DefaultCallingConvention(
    mlir::TypeConverter type_converter) {
  return [converter = std::move(type_converter)](
             mlir::FunctionType func) mutable {
    mlir::MLIRContext* context = func.getContext();

    // Convert a single type; remember if any conversion fails so we can
    // signal the failure with a null function type at the end.
    bool conversion_failed = false;
    auto convert_type = [&](mlir::Type type) -> mlir::Type {
      mlir::Type converted = converter.convertType(type);
      if (!converted) conversion_failed = true;
      return converted;
    };

    // Kernel context is always the first argument; converted original inputs
    // follow it.
    llvm::SmallVector<mlir::Type> arguments;
    arguments.reserve(1 + func.getNumInputs());
    arguments.push_back(KernelContextType::get(context));
    llvm::transform(func.getInputs(), std::back_inserter(arguments),
                    convert_type);

    // Convert all result types as well.
    llvm::SmallVector<mlir::Type> converted_results;
    converted_results.reserve(func.getNumResults());
    llvm::transform(func.getResults(), std::back_inserter(converted_results),
                    convert_type);

    // A null function type tells the caller the conversion is not possible.
    if (conversion_failed) return mlir::FunctionType();

    return mlir::FunctionType::get(context, arguments, converted_results);
  };
}

// Returns a calling convention that (1) prepends the kernel context argument,
// (2) converts all input and result types with the user-provided converter,
// and (3) turns results into out-params by appending the converted result
// types to the argument list, leaving the function with no results. Produces
// a null mlir::FunctionType if any type has no registered conversion.
/*static*/ CallingConvention CompilationOptions::ResultsToOutsCallingConvention(
    mlir::TypeConverter type_converter) {
  return [converter = std::move(type_converter)](
             mlir::FunctionType func) mutable {
    mlir::MLIRContext* context = func.getContext();

    // Convert a single type; remember if any conversion fails so we can
    // signal the failure with a null function type at the end.
    bool conversion_failed = false;
    auto convert_type = [&](mlir::Type type) -> mlir::Type {
      mlir::Type converted = converter.convertType(type);
      if (!converted) conversion_failed = true;
      return converted;
    };

    // Layout: kernel context, converted inputs, then converted results as
    // out-params.
    llvm::SmallVector<mlir::Type> arguments;
    arguments.reserve(1 + func.getNumInputs() + func.getNumResults());
    arguments.push_back(KernelContextType::get(context));
    llvm::transform(func.getInputs(), std::back_inserter(arguments),
                    convert_type);
    llvm::transform(func.getResults(), std::back_inserter(arguments),
                    convert_type);

    // A null function type tells the caller the conversion is not possible.
    if (conversion_failed) return mlir::FunctionType();

    return mlir::FunctionType::get(context, arguments, /*results=*/{});
  };
}

//----------------------------------------------------------------------------//
// Setup MLIR pass pipeline to lower to LLVM dialect, and use ORC JIT to codegen
// functions at runtime.
Expand Down

0 comments on commit 0a042cb

Please sign in to comment.