Skip to content

Commit ad59fdf

Browse files
WillFroomGoogle-ML-Automation
authored andcommitted
[XLA:CPU] Make CpuFusionEmitterBase a pure interface class
PiperOrigin-RevId: 745944794
1 parent de0d40e commit ad59fdf

File tree

7 files changed

+125
-157
lines changed

7 files changed

+125
-157
lines changed

xla/backends/cpu/codegen/emitters/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ cc_library(
3434
"//xla:util",
3535
"//xla:xla_data_proto_cc",
3636
"//xla/backends/cpu:alignment",
37-
"//xla/backends/cpu/codegen:fusion_compiler",
3837
"//xla/backends/cpu/codegen:kernel_api_ir_builder",
3938
"//xla/backends/cpu/codegen/emitters/ir:xla_cpu",
4039
"//xla/codegen/emitters:computation_partitioner",
@@ -101,6 +100,7 @@ xla_cc_test(
101100
"//xla/tests:hlo_test_base",
102101
"//xla/tests:xla_internal_test_main",
103102
"//xla/tsl/platform:statusor",
103+
"@com_google_absl//absl/container:flat_hash_set",
104104
"@com_google_absl//absl/status:statusor",
105105
"@com_google_absl//absl/strings:string_view",
106106
"@com_google_googletest//:gtest",

xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc

Lines changed: 35 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,17 @@ limitations under the License.
2323

2424
#include "absl/algorithm/container.h"
2525
#include "absl/container/flat_hash_map.h"
26+
#include "absl/container/flat_hash_set.h"
2627
#include "absl/log/check.h"
2728
#include "absl/log/log.h"
28-
#include "absl/status/status.h"
2929
#include "absl/status/statusor.h"
3030
#include "absl/strings/str_cat.h"
3131
#include "absl/types/span.h"
3232
#include "llvm/ADT/STLExtras.h"
3333
#include "llvm/ADT/SmallVector.h"
3434
#include "llvm/IR/Function.h"
3535
#include "llvm/IR/Instructions.h"
36+
#include "llvm/IR/LLVMContext.h"
3637
#include "llvm/Linker/Linker.h"
3738
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
3839
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
@@ -60,7 +61,6 @@ limitations under the License.
6061
#include "xla/backends/cpu/alignment.h"
6162
#include "xla/backends/cpu/codegen/emitters/ir/xla_cpu_ops.h"
6263
#include "xla/backends/cpu/codegen/emitters/ir/xla_cpu_types.h"
63-
#include "xla/backends/cpu/codegen/fusion_compiler.h"
6464
#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h"
6565
#include "xla/codegen/emitters/computation_partitioner.h"
6666
#include "xla/codegen/emitters/elemental_hlo_to_mlir.h"
@@ -76,7 +76,6 @@ limitations under the License.
7676
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
7777
#include "xla/service/buffer_assignment.h"
7878
#include "xla/service/dump.h"
79-
#include "xla/service/llvm_ir/llvm_util.h"
8079
#include "xla/shape.h"
8180
#include "xla/shape_util.h"
8281
#include "xla/tsl/platform/errors.h"
@@ -131,86 +130,6 @@ bool Needs64BitIndices(const HloComputation* computation) {
131130
}
132131
} // namespace
133132

134-
absl::StatusOr<CpuFusionEmissionResult> CpuFusionEmitterBase::Emit() const {
135-
// Single-threaded for now.
136-
TF_ASSIGN_OR_RETURN(auto module,
137-
CreateLLVMModule(*mlir_context_, *llvm_context_, *fusion_,
138-
buffer_assignment_));
139-
140-
const HloModule* hlo_module = fusion_->GetModule();
141-
if (hlo_module == nullptr) {
142-
return Internal("HloModule is null");
143-
}
144-
// Create a Kernel API Builder and a throwaway kernel prototype in order to
145-
// extract useful info from them, e.g. noalias, invariant_arguments and
146-
// entry function attributes.
147-
// TODO(ecg): find a way to obtain the same info without wasting work by
148-
// creating a throwaway module. All of this additional info should probably be
149-
// explicit in the generated MLIR, not added afterwards like we're doing here.
150-
// TODO(ecg): some attributes on the final loads are missing wrt those
151-
// generated via KernelApiIrBuilder, e.g. noalias. Add them.
152-
KernelApiIrBuilder kernel_api_ir_builder(
153-
*llvm_context_,
154-
KernelApiIrBuilder::Options::FromHloModuleConfig(hlo_module->config()));
155-
std::unique_ptr<llvm::Module> throwaway_llvm_module =
156-
KernelApiIrBuilder::CreateModule(
157-
absl::StrCat(fusion_->name(), "_throwaway_module"), *llvm_context_);
158-
TF_ASSIGN_OR_RETURN(KernelApiIrBuilder::KernelPrototype kernel_prototype,
159-
kernel_api_ir_builder.EmitKernelPrototype(
160-
*throwaway_llvm_module, fusion_, &buffer_assignment_,
161-
"_throwaway_kernel_prototype"));
162-
llvm::Function* kernel_function = module->getFunction(fusion_->name());
163-
kernel_api_ir_builder.SetKernelFunctionAttributes(kernel_function);
164-
165-
CpuFusionEmissionResult result;
166-
result.llvm_module = std::move(module);
167-
result.invariant_arguments = std::move(kernel_prototype.invariant_arguments);
168-
return result;
169-
}
170-
171-
absl::StatusOr<std::unique_ptr<llvm::Module>>
172-
CpuFusionEmitterBase::CreateLLVMModule(
173-
mlir::MLIRContext& mlir_context, llvm::LLVMContext& llvm_context,
174-
const HloFusionInstruction& fusion,
175-
const BufferAssignment& buffer_assignment) const {
176-
TF_ASSIGN_OR_RETURN(auto module,
177-
CreateMLIRModule(mlir_context, fusion,
178-
std::string(fusion.name()) + "_entry",
179-
buffer_assignment));
180-
181-
FusionCompiler compiler(FusionCompiler::Options{});
182-
return compiler.Compile(llvm_context, module.get());
183-
}
184-
185-
absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
186-
CpuFusionEmitterBase::CreateMLIRModule(
187-
mlir::MLIRContext& context, const HloFusionInstruction& fusion,
188-
const std::string& entry_function_name,
189-
const BufferAssignment& buffer_assignment,
190-
mlir::interpreter::MlirCompilationTrace* trace) const {
191-
mlir::OpBuilder builder(&context);
192-
auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion.name()));
193-
mlir::OwningOpRef<mlir::ModuleOp> module = llvm_ir::CreateMlirModuleOp(loc);
194-
SetDataLayoutAttribute(module.get(), fusion);
195-
196-
TF_ASSIGN_OR_RETURN(
197-
mlir::func::FuncOp entry_func,
198-
EmitFusionKernelApi(module.get(), fusion, entry_function_name,
199-
buffer_assignment));
200-
201-
std::vector<emitters::EpilogueSpecification> epilogues =
202-
GetEpilogues(fusion, &context);
203-
emitters::PartitionedComputations computations(
204-
fusion.fused_instructions_computation(), &context, epilogues);
205-
TF_ASSIGN_OR_RETURN(
206-
emitters::CallTargetProvider call_targets,
207-
EmitCallTargets(module.get(), fusion, computations, epilogues));
208-
209-
TF_RETURN_IF_ERROR(
210-
EmitEntryFunction(computations, call_targets, entry_func, fusion));
211-
return module;
212-
}
213-
214133
using mlir::AffineExpr;
215134

216135
IndexingMap GetDefaultIndexingMap(absl::Span<const int64_t> thread_tile_sizes,
@@ -405,6 +324,39 @@ void SetDataLayoutAttribute(mlir::ModuleOp module,
405324
mlir::DataLayoutSpecAttr::get(module->getContext(), {index_layout}));
406325
}
407326

327+
absl::StatusOr<absl::flat_hash_set<int64_t>> SetKernelFunctionAttributes(
328+
llvm::Module& module, const BufferAssignment& buffer_assignment,
329+
const HloFusionInstruction* fusion) {
330+
const HloModule* hlo_module = fusion->GetModule();
331+
if (hlo_module == nullptr) {
332+
return Internal("HloModule is null");
333+
}
334+
335+
// Create a Kernel API Builder and a throwaway kernel prototype in order to
336+
// extract useful info from them, e.g. noalias, invariant_arguments and
337+
// entry function attributes.
338+
// TODO(ecg): find a way to obtain the same info without wasting work by
339+
// creating a throwaway module. All of this additional info should probably be
340+
// explicit in the generated MLIR, not added afterwards like we're doing here.
341+
// TODO(ecg): some attributes on the final loads are missing wrt those
342+
// generated via KernelApiIrBuilder, e.g. noalias. Add them.
343+
llvm::LLVMContext& context = module.getContext();
344+
KernelApiIrBuilder kernel_api_ir_builder(
345+
context,
346+
KernelApiIrBuilder::Options::FromHloModuleConfig(hlo_module->config()));
347+
std::unique_ptr<llvm::Module> throwaway_llvm_module =
348+
KernelApiIrBuilder::CreateModule(
349+
absl::StrCat(fusion->name(), "_throwaway_module"), context);
350+
TF_ASSIGN_OR_RETURN(KernelApiIrBuilder::KernelPrototype kernel_prototype,
351+
kernel_api_ir_builder.EmitKernelPrototype(
352+
*throwaway_llvm_module, fusion, &buffer_assignment,
353+
"_throwaway_kernel_prototype"));
354+
llvm::Function* kernel_function = module.getFunction(fusion->name());
355+
kernel_api_ir_builder.SetKernelFunctionAttributes(kernel_function);
356+
357+
return kernel_prototype.invariant_arguments;
358+
}
359+
408360
int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }
409361
} // namespace cpu
410362
} // namespace xla

xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h

Lines changed: 5 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,20 @@ limitations under the License.
1616
#define XLA_BACKENDS_CPU_CODEGEN_EMITTERS_CPU_FUSION_EMITTER_H_
1717

1818
#include <cstdint>
19-
#include <memory>
2019
#include <optional>
2120
#include <string>
2221
#include <vector>
2322

2423
#include "absl/container/flat_hash_set.h"
25-
#include "absl/status/status.h"
2624
#include "absl/status/statusor.h"
2725
#include "absl/types/span.h"
2826
#include "llvm/IR/LLVMContext.h"
2927
#include "llvm/IR/Module.h"
3028
#include "mlir/Dialect/Func/IR/FuncOps.h"
3129
#include "mlir/IR/AffineMap.h"
3230
#include "mlir/IR/BuiltinOps.h"
33-
#include "mlir/IR/ImplicitLocOpBuilder.h"
3431
#include "mlir/IR/MLIRContext.h"
3532
#include "mlir/IR/OwningOpRef.h"
36-
#include "mlir/IR/Value.h"
3733
#include "mlir/Pass/PassManager.h"
3834
#include "xla/codegen/emitters/computation_partitioner.h"
3935
#include "xla/hlo/analysis/indexing_map.h"
@@ -44,11 +40,6 @@ limitations under the License.
4440
namespace xla {
4541
namespace cpu {
4642

47-
struct CpuFusionEmissionResult {
48-
std::unique_ptr<llvm::Module> llvm_module;
49-
absl::flat_hash_set<int64_t> invariant_arguments;
50-
};
51-
5243
IndexingMap GetDefaultIndexingMap(absl::Span<const int64_t> thread_tile_sizes,
5344
absl::Span<const int64_t> shape,
5445
mlir::MLIRContext* mlir_context);
@@ -69,17 +60,12 @@ absl::StatusOr<emitters::CallTargetProvider> EmitCallTargets(
6960
void SetDataLayoutAttribute(mlir::ModuleOp module,
7061
const HloFusionInstruction& fusion);
7162

63+
absl::StatusOr<absl::flat_hash_set<int64_t>> SetKernelFunctionAttributes(
64+
llvm::Module& module, const BufferAssignment& buffer_assignment,
65+
const HloFusionInstruction* fusion);
66+
7267
class CpuFusionEmitterBase {
7368
public:
74-
CpuFusionEmitterBase(mlir::MLIRContext* mlir_context,
75-
llvm::LLVMContext* llvm_context,
76-
const BufferAssignment& buffer_assignment,
77-
const HloFusionInstruction* fusion)
78-
: mlir_context_(mlir_context),
79-
llvm_context_(llvm_context),
80-
buffer_assignment_(buffer_assignment),
81-
fusion_(fusion) {}
82-
8369
virtual ~CpuFusionEmitterBase() = default;
8470

8571
virtual int64_t num_threads() const = 0;
@@ -92,42 +78,7 @@ class CpuFusionEmitterBase {
9278

9379
virtual std::string BackendExtraOptions() { return {}; }
9480

95-
absl::StatusOr<CpuFusionEmissionResult> Emit() const;
96-
97-
// Visible for testing.
98-
absl::StatusOr<std::unique_ptr<llvm::Module>> CreateLLVMModule(
99-
mlir::MLIRContext& mlir_context, llvm::LLVMContext& llvm_context,
100-
const HloFusionInstruction& fusion,
101-
const BufferAssignment& buffer_assignment) const;
102-
103-
// Visible for testing.
104-
absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateMLIRModule(
105-
mlir::MLIRContext& context, const HloFusionInstruction& fusion,
106-
const std::string& entry_function_name,
107-
const BufferAssignment& buffer_assignment,
108-
mlir::interpreter::MlirCompilationTrace* trace = nullptr) const;
109-
110-
protected:
111-
virtual absl::Status EmitEntryFunction(
112-
const emitters::PartitionedComputations& computations,
113-
const emitters::CallTargetProvider& call_targets,
114-
mlir::func::FuncOp entry_function,
115-
const HloFusionInstruction& fusion) const = 0;
116-
117-
virtual std::vector<emitters::EpilogueSpecification> GetEpilogues(
118-
const HloFusionInstruction& fusion,
119-
mlir::MLIRContext* mlir_context) const {
120-
// We don't actually support epilogues for scatter, but this is how we tell
121-
// the base class that we don't want it to generate code for the scatter.
122-
return {};
123-
}
124-
125-
mlir::Value EmitThreadId(mlir::ImplicitLocOpBuilder& builder, int dim) const;
126-
127-
mlir::MLIRContext* mlir_context_;
128-
llvm::LLVMContext* llvm_context_;
129-
const BufferAssignment& buffer_assignment_;
130-
const HloFusionInstruction* fusion_;
81+
virtual absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> Emit() const = 0;
13182
};
13283

13384
int64_t CeilDiv(int64_t a, int64_t b);

xla/backends/cpu/codegen/emitters/cpu_fusion_emitter_test.cc

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@ limitations under the License.
1515

1616
#include "xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h"
1717

18+
#include <cstdint>
1819
#include <memory>
1920
#include <string>
2021

2122
#include <gtest/gtest.h>
23+
#include "absl/container/flat_hash_set.h"
2224
#include "absl/status/statusor.h"
2325
#include "absl/strings/string_view.h"
2426
#include "llvm/IR/LLVMContext.h"
@@ -135,11 +137,7 @@ TEST_F(CpuFusionEmitterTest, ScatterMlir) {
135137
hlo_module->entry_computation()->root_instruction());
136138
CpuScatterFusion emitter(mlir_context_.get(), &llvm_context_,
137139
*buffer_assignment, fusion);
138-
TF_ASSERT_OK_AND_ASSIGN(
139-
auto mlir_module,
140-
emitter.CreateMLIRModule(*mlir_context_, *fusion,
141-
std::string(fusion->name()) + "_entry",
142-
*buffer_assignment));
140+
TF_ASSERT_OK_AND_ASSIGN(auto mlir_module, emitter.Emit());
143141
auto mlir_dump = MlirModuleToString(*mlir_module);
144142
TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
145143
RunFileCheck(mlir_dump, kExpected));
@@ -165,8 +163,15 @@ TEST_F(CpuFusionEmitterTest, ScatterLlvm) {
165163
hlo_module->entry_computation()->root_instruction());
166164
CpuScatterFusion emitter(mlir_context_.get(), &llvm_context_,
167165
*buffer_assignment, fusion);
168-
TF_ASSERT_OK_AND_ASSIGN(auto result, emitter.Emit());
169-
auto llvm_dump = LlvmModuleToString(*result.llvm_module);
166+
TF_ASSERT_OK_AND_ASSIGN(auto mlir_module, emitter.Emit());
167+
FusionCompiler compiler(FusionCompiler::Options{});
168+
llvm::LLVMContext llvm_context;
169+
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<llvm::Module> llvm_module,
170+
compiler.Compile(llvm_context, mlir_module.get()));
171+
TF_ASSERT_OK_AND_ASSIGN(
172+
absl::flat_hash_set<int64_t> invariant_arguments,
173+
SetKernelFunctionAttributes(*llvm_module, *buffer_assignment, fusion));
174+
auto llvm_dump = LlvmModuleToString(*llvm_module);
170175
TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
171176
RunFileCheck(llvm_dump, kExpected));
172177
EXPECT_TRUE(filecheck_matched);

xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,10 @@ limitations under the License.
4040
#include "mlir/IR/BuiltinOps.h"
4141
#include "mlir/IR/ImplicitLocOpBuilder.h"
4242
#include "mlir/IR/Location.h"
43+
#include "mlir/IR/OwningOpRef.h"
4344
#include "mlir/IR/Value.h"
4445
#include "mlir/IR/ValueRange.h"
46+
#include "mlir/Interfaces/DataLayoutInterfaces.h"
4547
#include "xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h"
4648
#include "xla/codegen/emitters/computation_partitioner.h"
4749
#include "xla/codegen/emitters/elemental_hlo_to_mlir.h"
@@ -55,9 +57,12 @@ limitations under the License.
5557
#include "xla/primitive_util.h"
5658
#include "xla/service/buffer_assignment.h"
5759
#include "xla/service/cpu/backend_config.pb.h"
60+
#include "xla/service/llvm_ir/llvm_util.h"
5861
#include "xla/service/scatter_simplifier.h"
5962
#include "xla/shape.h"
6063
#include "xla/shape_util.h"
64+
#include "xla/tsl/platform/errors.h"
65+
#include "xla/tsl/platform/statusor.h"
6166
#include "xla/util.h"
6267
#include "xla/xla_data.pb.h"
6368

@@ -173,8 +178,10 @@ CpuScatterFusion::CpuScatterFusion(mlir::MLIRContext* mlir_context,
173178
llvm::LLVMContext* llvm_context,
174179
const BufferAssignment& buffer_assignment,
175180
const HloFusionInstruction* fusion)
176-
: CpuFusionEmitterBase{mlir_context, llvm_context, buffer_assignment,
177-
fusion} {
181+
: mlir_context_(mlir_context),
182+
llvm_context_(llvm_context),
183+
buffer_assignment_(buffer_assignment),
184+
fusion_(fusion) {
178185
const auto* scatter = Cast<HloScatterInstruction>(
179186
fusion->fused_instructions_computation()->root_instruction());
180187
auto update_shape = scatter->scatter_updates().front()->shape();
@@ -236,6 +243,32 @@ IndexingMap GetScatterIndexingMap(
236243
{}, constraints);
237244
}
238245

246+
absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CpuScatterFusion::Emit()
247+
const {
248+
mlir::OpBuilder builder(mlir_context_);
249+
auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_->name()));
250+
mlir::OwningOpRef<mlir::ModuleOp> module = llvm_ir::CreateMlirModuleOp(loc);
251+
SetDataLayoutAttribute(module.get(), *fusion_);
252+
253+
TF_ASSIGN_OR_RETURN(
254+
mlir::func::FuncOp entry_func,
255+
EmitFusionKernelApi(module.get(), *fusion_,
256+
std::string(fusion_->name()) + "_entry",
257+
buffer_assignment_));
258+
259+
std::vector<emitters::EpilogueSpecification> epilogues =
260+
GetEpilogues(*fusion_, mlir_context_);
261+
emitters::PartitionedComputations computations(
262+
fusion_->fused_instructions_computation(), mlir_context_, epilogues);
263+
TF_ASSIGN_OR_RETURN(
264+
emitters::CallTargetProvider call_targets,
265+
EmitCallTargets(module.get(), *fusion_, computations, epilogues));
266+
267+
TF_RETURN_IF_ERROR(
268+
EmitEntryFunction(computations, call_targets, entry_func, *fusion_));
269+
return module;
270+
}
271+
239272
absl::Status CpuScatterFusion::EmitEntryFunction(
240273
const emitters::PartitionedComputations& computations,
241274
const emitters::CallTargetProvider& call_targets,

0 commit comments

Comments
 (0)