Skip to content

Commit d05e130

Browse files
WillFroomGoogle-ML-Automation
authored andcommitted
[XLA:CPU] Port scatter to kernel API
PiperOrigin-RevId: 750108264
1 parent 5dc05f0 commit d05e130

20 files changed

+483
-213
lines changed

xla/backends/cpu/codegen/BUILD

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ cc_library(
9999
"//xla:util",
100100
"//xla/backends/cpu/codegen/emitters/ir:xla_cpu",
101101
"//xla/backends/cpu/codegen/emitters/transforms:passes",
102+
"//xla/codegen:llvm_ir_kernel_source",
103+
"//xla/codegen:mlir_kernel_source",
102104
"//xla/codegen/emitters/ir:xla",
103105
"//xla/codegen/emitters/ir:xla_attrs_inc_gen",
104106
"//xla/codegen/emitters/transforms:passes",
@@ -107,12 +109,12 @@ cc_library(
107109
"//xla/mlir_hlo:mhlo_passes",
108110
"//xla/tsl/framework/mlir:status_scoped_diagnostic_handler",
109111
"//xla/tsl/platform:errors",
112+
"//xla/tsl/platform:statusor",
110113
"@com_google_absl//absl/log",
111114
"@com_google_absl//absl/status",
112115
"@com_google_absl//absl/status:statusor",
113116
"@com_google_absl//absl/strings",
114117
"@llvm-project//llvm:Core",
115-
"@llvm-project//llvm:Support",
116118
"@llvm-project//mlir:AffineDialect",
117119
"@llvm-project//mlir:AffineToStandard",
118120
"@llvm-project//mlir:AffineTransforms",
@@ -485,6 +487,22 @@ cc_library(
485487
],
486488
)
487489

490+
py_strict_test(
491+
name = "scatter_kernel_emitter_test",
492+
srcs = ["scatter_kernel_emitter_test.py"],
493+
tags = [
494+
"no_oss",
495+
],
496+
deps = [
497+
"//third_party/py/numpy",
498+
"//xla/backends/cpu/testlib",
499+
"//xla/codegen/testlib",
500+
"//xla/python:xla_extension",
501+
"@absl_py//absl/testing:absltest",
502+
"@absl_py//absl/testing:parameterized",
503+
],
504+
)
505+
488506
xla_cc_test(
489507
name = "object_loader_test",
490508
srcs = ["object_loader_test.cc"],

xla/backends/cpu/codegen/emitters/BUILD

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,13 @@ cc_library(
3434
"//xla:util",
3535
"//xla:xla_data_proto_cc",
3636
"//xla/backends/cpu:alignment",
37+
"//xla/backends/cpu/codegen:fusion_compiler",
3738
"//xla/backends/cpu/codegen:kernel_api_ir_builder",
3839
"//xla/backends/cpu/codegen/emitters/ir:xla_cpu",
40+
"//xla/codegen:kernel_definition",
41+
"//xla/codegen:kernel_emitter",
42+
"//xla/codegen:kernel_spec",
43+
"//xla/codegen:mlir_kernel_source",
3944
"//xla/codegen/emitters:computation_partitioner",
4045
"//xla/codegen/emitters:elemental_hlo_to_mlir",
4146
"//xla/codegen/emitters:type_util",
@@ -51,6 +56,7 @@ cc_library(
5156
"//xla/service:scatter_simplifier",
5257
"//xla/service/cpu:backend_config_proto_cc",
5358
"//xla/service/llvm_ir:llvm_util",
59+
"//xla/stream_executor:launch_dim",
5460
"//xla/tsl/platform:errors",
5561
"//xla/tsl/platform:statusor",
5662
"@com_google_absl//absl/algorithm:container",
@@ -93,6 +99,9 @@ xla_cc_test(
9399
deps = [
94100
":cpu_fusion_emitters",
95101
"//xla/backends/cpu/codegen:fusion_compiler",
102+
"//xla/codegen:kernel_definition",
103+
"//xla/codegen:llvm_ir_kernel_source",
104+
"//xla/codegen:mlir_kernel_source",
96105
"//xla/hlo/analysis:hlo_ordering",
97106
"//xla/hlo/ir:hlo",
98107
"//xla/hlo/testlib:filecheck",
@@ -101,7 +110,6 @@ xla_cc_test(
101110
"//xla/tests:hlo_test_base",
102111
"//xla/tests:xla_internal_test_main",
103112
"//xla/tsl/platform:statusor",
104-
"@com_google_absl//absl/container:flat_hash_set",
105113
"@com_google_absl//absl/status:statusor",
106114
"@com_google_absl//absl/strings:string_view",
107115
"@com_google_googletest//:gtest",
@@ -111,5 +119,6 @@ xla_cc_test(
111119
"@llvm-project//mlir:IR",
112120
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
113121
"@llvm-project//mlir:Pass",
122+
"@tsl//tsl/platform:casts",
114123
],
115124
)

xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ limitations under the License.
2626
#include "absl/container/flat_hash_set.h"
2727
#include "absl/log/check.h"
2828
#include "absl/log/log.h"
29+
#include "absl/status/status.h"
2930
#include "absl/status/statusor.h"
3031
#include "absl/strings/str_cat.h"
3132
#include "absl/types/span.h"
@@ -39,6 +40,7 @@ limitations under the License.
3940
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
4041
#include "mlir/Dialect/DLTI/DLTI.h"
4142
#include "mlir/Dialect/Func/IR/FuncOps.h"
43+
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
4244
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
4345
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
4446
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -128,6 +130,33 @@ bool Needs64BitIndices(const HloComputation* computation) {
128130
}
129131
return false;
130132
}
133+
134+
absl::Status SetKernelFunctionAttributes(const HloFusionInstruction& fusion,
135+
mlir::Builder& builder,
136+
mlir::func::FuncOp& func) {
137+
const HloModule* hlo_module = fusion.GetModule();
138+
if (hlo_module == nullptr) {
139+
return Internal("HloModule is null");
140+
}
141+
142+
mlir::MLIRContext* context = func->getContext();
143+
144+
// This is a hack until https://github.com/llvm/llvm-project/pull/135811 is
145+
// merged, the value "2" corresponds to the default enum value.
146+
mlir::ArrayAttr uwtable_attr = builder.getStrArrayAttr({"uwtable", "2"});
147+
int32_t vector_width =
148+
hlo_module->config().debug_options().xla_cpu_prefer_vector_width();
149+
mlir::ArrayAttr prefer_vector_width_attr = builder.getStrArrayAttr(
150+
{"prefer-vector-width", absl::StrCat(vector_width)});
151+
func->setAttr("passthrough",
152+
builder.getArrayAttr({uwtable_attr, prefer_vector_width_attr}));
153+
func->setAttr(
154+
"frame_pointer",
155+
mlir::LLVM::FramePointerKindAttr::get(
156+
context, mlir::LLVM::framePointerKind::FramePointerKind::All));
157+
158+
return absl::OkStatus();
159+
}
131160
} // namespace
132161

133162
using mlir::AffineExpr;
@@ -239,6 +268,10 @@ absl::StatusOr<mlir::func::FuncOp> EmitFusionKernelApi(
239268
loc, fusion.name(),
240269
builder.getFunctionType(/*arg_types=*/{call_frame_type},
241270
/*result_types=*/{error_type}));
271+
272+
TF_RETURN_IF_ERROR(
273+
SetKernelFunctionAttributes(fusion, builder, call_frame_func));
274+
242275
builder.setInsertionPointToStart(call_frame_func.addEntryBlock());
243276
mlir::Value call_frame_arg = call_frame_func.getArgument(0);
244277
SmallVector<mlir::Value> extracted_values;
@@ -324,39 +357,6 @@ void SetDataLayoutAttribute(mlir::ModuleOp module,
324357
mlir::DataLayoutSpecAttr::get(module->getContext(), {index_layout}));
325358
}
326359

327-
absl::StatusOr<absl::flat_hash_set<int64_t>> SetKernelFunctionAttributes(
328-
llvm::Module& module, const BufferAssignment& buffer_assignment,
329-
const HloFusionInstruction* fusion) {
330-
const HloModule* hlo_module = fusion->GetModule();
331-
if (hlo_module == nullptr) {
332-
return Internal("HloModule is null");
333-
}
334-
335-
// Create a Kernel API Builder and a throwaway kernel prototype in order to
336-
// extract useful info from them, e.g. noalias, invariant_arguments and
337-
// entry function attributes.
338-
// TODO(ecg): find a way to obtain the same info without wasting work by
339-
// creating a throwaway module. All of this additional info should probably be
340-
// explicit in the generated MLIR, not added afterwards like we're doing here.
341-
// TODO(ecg): some attributes on the final loads are missing wrt those
342-
// generated via KernelApiIrBuilder, e.g. noalias. Add them.
343-
llvm::LLVMContext& context = module.getContext();
344-
KernelApiIrBuilder kernel_api_ir_builder(
345-
context,
346-
KernelApiIrBuilder::Options::FromHloModuleConfig(hlo_module->config()));
347-
std::unique_ptr<llvm::Module> throwaway_llvm_module =
348-
KernelApiIrBuilder::CreateModule(
349-
absl::StrCat(fusion->name(), "_throwaway_module"), context);
350-
TF_ASSIGN_OR_RETURN(KernelApiIrBuilder::KernelPrototype kernel_prototype,
351-
kernel_api_ir_builder.EmitKernelPrototype(
352-
*throwaway_llvm_module, fusion, &buffer_assignment,
353-
"_throwaway_kernel_prototype"));
354-
llvm::Function* kernel_function = module.getFunction(fusion->name());
355-
kernel_api_ir_builder.SetKernelFunctionAttributes(kernel_function);
356-
357-
return kernel_prototype.invariant_arguments;
358-
}
359-
360360
int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }
361361
} // namespace cpu
362362
} // namespace xla

xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,6 @@ absl::StatusOr<emitters::CallTargetProvider> EmitCallTargets(
6060
void SetDataLayoutAttribute(mlir::ModuleOp module,
6161
const HloFusionInstruction& fusion);
6262

63-
absl::StatusOr<absl::flat_hash_set<int64_t>> SetKernelFunctionAttributes(
64-
llvm::Module& module, const BufferAssignment& buffer_assignment,
65-
const HloFusionInstruction* fusion);
66-
6763
class CpuFusionEmitterBase {
6864
public:
6965
virtual ~CpuFusionEmitterBase() = default;

xla/backends/cpu/codegen/emitters/cpu_fusion_emitter_test.cc

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,11 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include "xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h"
17-
18-
#include <cstdint>
1916
#include <memory>
2017
#include <string>
18+
#include <utility>
2119

2220
#include <gtest/gtest.h>
23-
#include "absl/container/flat_hash_set.h"
2421
#include "absl/status/statusor.h"
2522
#include "absl/strings/string_view.h"
2623
#include "llvm/IR/LLVMContext.h"
@@ -32,6 +29,9 @@ limitations under the License.
3229
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
3330
#include "xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.h"
3431
#include "xla/backends/cpu/codegen/fusion_compiler.h"
32+
#include "xla/codegen/kernel_definition.h"
33+
#include "xla/codegen/llvm_ir_kernel_source.h"
34+
#include "xla/codegen/mlir_kernel_source.h"
3535
#include "xla/hlo/analysis/hlo_ordering.h"
3636
#include "xla/hlo/ir/hlo_casting_utils.h"
3737
#include "xla/hlo/ir/hlo_instructions.h"
@@ -41,6 +41,7 @@ limitations under the License.
4141
#include "xla/service/logical_buffer.h"
4242
#include "xla/tests/hlo_test_base.h"
4343
#include "xla/tsl/platform/statusor.h"
44+
#include "tsl/platform/casts.h"
4445

4546
namespace xla {
4647
namespace cpu {
@@ -62,18 +63,13 @@ std::string MlirModuleToString(const mlir::ModuleOp& module) {
6263

6364
class CpuFusionEmitterTest : public HloTestBase {
6465
protected:
65-
CpuFusionEmitterTest() : mlir_context_(FusionCompiler::CreateContext()) {}
66-
6766
absl::StatusOr<std::unique_ptr<BufferAssignment>> RunBufferAssignment(
6867
const HloModule& hlo) {
6968
return BufferAssigner::Run(
7069
&hlo, std::make_unique<DependencyHloOrdering>(&hlo),
7170
backend().compiler()->BufferSizeBytesFunction(),
7271
[](LogicalBuffer::Color) { return /*alignment=*/1; });
7372
}
74-
75-
std::unique_ptr<mlir::MLIRContext> mlir_context_;
76-
llvm::LLVMContext llvm_context_;
7773
};
7874

7975
static constexpr absl::string_view kScatterHlo = R"(
@@ -136,10 +132,12 @@ TEST_F(CpuFusionEmitterTest, ScatterMlir) {
136132
RunBufferAssignment(*hlo_module));
137133
auto fusion = Cast<HloFusionInstruction>(
138134
hlo_module->entry_computation()->root_instruction());
139-
CpuScatterFusion emitter(mlir_context_.get(), &llvm_context_,
140-
*buffer_assignment, fusion);
141-
TF_ASSERT_OK_AND_ASSIGN(auto mlir_module, emitter.Emit());
142-
auto mlir_dump = MlirModuleToString(*mlir_module);
135+
CpuScatterFusion emitter(*buffer_assignment, fusion);
136+
TF_ASSERT_OK_AND_ASSIGN(KernelDefinition kernel_definition,
137+
emitter.EmitKernelDefinition());
138+
const auto& mlir_source =
139+
tsl::down_cast<const MlirKernelSource&>(kernel_definition.source());
140+
auto mlir_dump = mlir_source.ToString();
143141
TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
144142
RunFileCheck(mlir_dump, kExpected));
145143
EXPECT_TRUE(filecheck_matched);
@@ -162,17 +160,15 @@ TEST_F(CpuFusionEmitterTest, ScatterLlvm) {
162160
RunBufferAssignment(*hlo_module));
163161
auto fusion = Cast<HloFusionInstruction>(
164162
hlo_module->entry_computation()->root_instruction());
165-
CpuScatterFusion emitter(mlir_context_.get(), &llvm_context_,
166-
*buffer_assignment, fusion);
167-
TF_ASSERT_OK_AND_ASSIGN(auto mlir_module, emitter.Emit());
163+
CpuScatterFusion emitter(*buffer_assignment, fusion);
164+
TF_ASSERT_OK_AND_ASSIGN(KernelDefinition kernel_definition,
165+
emitter.EmitKernelDefinition());
166+
auto [spec, source] = std::move(kernel_definition).release();
167+
auto& mlir_source = tsl::down_cast<MlirKernelSource&>(*source);
168168
FusionCompiler compiler(FusionCompiler::Options{});
169-
llvm::LLVMContext llvm_context;
170-
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<llvm::Module> llvm_module,
171-
compiler.Compile(llvm_context, mlir_module.get()));
172-
TF_ASSERT_OK_AND_ASSIGN(
173-
absl::flat_hash_set<int64_t> invariant_arguments,
174-
SetKernelFunctionAttributes(*llvm_module, *buffer_assignment, fusion));
175-
auto llvm_dump = LlvmModuleToString(*llvm_module);
169+
TF_ASSERT_OK_AND_ASSIGN(LlvmIrKernelSource llvm_source,
170+
compiler.Compile(std::move(mlir_source)));
171+
auto llvm_dump = llvm_source.ToString();
176172
TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
177173
RunFileCheck(llvm_dump, kExpected));
178174
EXPECT_TRUE(filecheck_matched);

0 commit comments

Comments
 (0)