@@ -23,16 +23,17 @@ limitations under the License.
2323
2424#include " absl/algorithm/container.h"
2525#include " absl/container/flat_hash_map.h"
26+ #include " absl/container/flat_hash_set.h"
2627#include " absl/log/check.h"
2728#include " absl/log/log.h"
28- #include " absl/status/status.h"
2929#include " absl/status/statusor.h"
3030#include " absl/strings/str_cat.h"
3131#include " absl/types/span.h"
3232#include " llvm/ADT/STLExtras.h"
3333#include " llvm/ADT/SmallVector.h"
3434#include " llvm/IR/Function.h"
3535#include " llvm/IR/Instructions.h"
36+ #include " llvm/IR/LLVMContext.h"
3637#include " llvm/Linker/Linker.h"
3738#include " mlir/Conversion/AffineToStandard/AffineToStandard.h"
3839#include " mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
@@ -60,7 +61,6 @@ limitations under the License.
6061#include " xla/backends/cpu/alignment.h"
6162#include " xla/backends/cpu/codegen/emitters/ir/xla_cpu_ops.h"
6263#include " xla/backends/cpu/codegen/emitters/ir/xla_cpu_types.h"
63- #include " xla/backends/cpu/codegen/fusion_compiler.h"
6464#include " xla/backends/cpu/codegen/kernel_api_ir_builder.h"
6565#include " xla/codegen/emitters/computation_partitioner.h"
6666#include " xla/codegen/emitters/elemental_hlo_to_mlir.h"
@@ -76,7 +76,6 @@ limitations under the License.
7676#include " xla/mlir_hlo/mhlo/transforms/passes.h"
7777#include " xla/service/buffer_assignment.h"
7878#include " xla/service/dump.h"
79- #include " xla/service/llvm_ir/llvm_util.h"
8079#include " xla/shape.h"
8180#include " xla/shape_util.h"
8281#include " xla/tsl/platform/errors.h"
@@ -131,86 +130,6 @@ bool Needs64BitIndices(const HloComputation* computation) {
131130}
132131} // namespace
133132
// Deletion side of this diff: CpuFusionEmitterBase::Emit.
// Builds the LLVM module for fusion_ through CreateLLVMModule, applies the
// kernel function attributes supplied by a KernelApiIrBuilder to the function
// looked up by the fusion's name, and returns the module together with the
// invariant_arguments taken from a throwaway kernel prototype.
// Returns an Internal error when the fusion has no HloModule; errors from
// CreateLLVMModule and EmitKernelPrototype are propagated by the
// TF_ASSIGN_OR_RETURN macros.
// NOTE(review): the result of getFunction() is handed to
// SetKernelFunctionAttributes without a null check — confirm the function is
// always present under that name.
134- absl::StatusOr<CpuFusionEmissionResult> CpuFusionEmitterBase::Emit () const {
135- // Single-threaded for now.
136- TF_ASSIGN_OR_RETURN (auto module ,
137- CreateLLVMModule (*mlir_context_, *llvm_context_, *fusion_,
138- buffer_assignment_));
139-
140- const HloModule* hlo_module = fusion_->GetModule ();
141- if (hlo_module == nullptr ) {
142- return Internal (" HloModule is null" );
143- }
144- // Create a Kernel API Builder and a throwaway kernel prototype in order to
145- // extract useful info from them, e.g. noalias, invariant_arguments and
146- // entry function attributes.
147- // TODO(ecg): find a way to obtain the same info without wasting work by
148- // creating a throwaway module. All of this additional info should probably be
149- // explicit in the generated MLIR, not added afterwards like we're doing here.
150- // TODO(ecg): some attributes on the final loads are missing wrt those
151- // generated via KernelApiIrBuilder, e.g. noalias. Add them.
152- KernelApiIrBuilder kernel_api_ir_builder (
153- *llvm_context_,
154- KernelApiIrBuilder::Options::FromHloModuleConfig (hlo_module->config ()));
155- std::unique_ptr<llvm::Module> throwaway_llvm_module =
156- KernelApiIrBuilder::CreateModule (
157- absl::StrCat (fusion_->name (), " _throwaway_module" ), *llvm_context_);
158- TF_ASSIGN_OR_RETURN (KernelApiIrBuilder::KernelPrototype kernel_prototype,
159- kernel_api_ir_builder.EmitKernelPrototype (
160- *throwaway_llvm_module, fusion_, &buffer_assignment_,
161- " _throwaway_kernel_prototype" ));
// The attributes are applied to the function in the real module, not to the
// throwaway prototype created above.
162- llvm::Function* kernel_function = module ->getFunction (fusion_->name ());
163- kernel_api_ir_builder.SetKernelFunctionAttributes (kernel_function);
164-
165- CpuFusionEmissionResult result;
166- result.llvm_module = std::move (module );
167- result.invariant_arguments = std::move (kernel_prototype.invariant_arguments );
168- return result;
169- }
170-
// Deletion side of this diff: CpuFusionEmitterBase::CreateLLVMModule.
// Builds the MLIR module for `fusion` via CreateMLIRModule, using the fusion
// name with an "_entry" suffix as the entry function name, then compiles it
// into an LLVM module in `llvm_context` with a FusionCompiler constructed from
// default FusionCompiler::Options. An error from CreateMLIRModule is
// propagated by TF_ASSIGN_OR_RETURN; the return value is whatever
// FusionCompiler::Compile yields.
171- absl::StatusOr<std::unique_ptr<llvm::Module>>
172- CpuFusionEmitterBase::CreateLLVMModule (
173- mlir::MLIRContext& mlir_context, llvm::LLVMContext& llvm_context,
174- const HloFusionInstruction& fusion,
175- const BufferAssignment& buffer_assignment) const {
176- TF_ASSIGN_OR_RETURN (auto module ,
177- CreateMLIRModule (mlir_context, fusion,
178- std::string (fusion.name ()) + " _entry" ,
179- buffer_assignment));
180-
181- FusionCompiler compiler (FusionCompiler::Options{});
182- return compiler.Compile (llvm_context, module .get ());
183- }
184-
// Deletion side of this diff: CpuFusionEmitterBase::CreateMLIRModule.
// Creates an MLIR module whose location is named after the fusion, sets its
// data layout attribute, emits the kernel-API entry function named
// `entry_function_name`, partitions the fused computation together with the
// epilogues from GetEpilogues, emits the call targets, and finally emits the
// body of the entry function. Errors from EmitFusionKernelApi,
// EmitCallTargets and EmitEntryFunction are propagated.
// NOTE(review): the `trace` parameter is not referenced anywhere in this body.
185- absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
186- CpuFusionEmitterBase::CreateMLIRModule (
187- mlir::MLIRContext& context, const HloFusionInstruction& fusion,
188- const std::string& entry_function_name,
189- const BufferAssignment& buffer_assignment,
190- mlir::interpreter::MlirCompilationTrace* trace) const {
191- mlir::OpBuilder builder (&context);
192- auto loc = mlir::NameLoc::get (builder.getStringAttr (fusion.name ()));
193- mlir::OwningOpRef<mlir::ModuleOp> module = llvm_ir::CreateMlirModuleOp (loc);
194- SetDataLayoutAttribute (module .get (), fusion);
195-
196- TF_ASSIGN_OR_RETURN (
197- mlir::func::FuncOp entry_func,
198- EmitFusionKernelApi (module .get (), fusion, entry_function_name,
199- buffer_assignment));
200-
// The same epilogue list is used both to partition the computation and to
// emit the call targets.
201- std::vector<emitters::EpilogueSpecification> epilogues =
202- GetEpilogues (fusion, &context);
203- emitters::PartitionedComputations computations (
204- fusion.fused_instructions_computation (), &context, epilogues);
205- TF_ASSIGN_OR_RETURN (
206- emitters::CallTargetProvider call_targets,
207- EmitCallTargets (module .get (), fusion, computations, epilogues));
208-
209- TF_RETURN_IF_ERROR (
210- EmitEntryFunction (computations, call_targets, entry_func, fusion));
211- return module ;
212- }
213-
214133using mlir::AffineExpr;
215134
216135IndexingMap GetDefaultIndexingMap (absl::Span<const int64_t > thread_tile_sizes,
@@ -405,6 +324,39 @@ void SetDataLayoutAttribute(mlir::ModuleOp module,
405324 mlir::DataLayoutSpecAttr::get (module ->getContext (), {index_layout}));
406325}
407326
327+ absl::StatusOr<absl::flat_hash_set<int64_t >> SetKernelFunctionAttributes (
328+ llvm::Module& module , const BufferAssignment& buffer_assignment,
329+ const HloFusionInstruction* fusion) {
330+ const HloModule* hlo_module = fusion->GetModule ();
331+ if (hlo_module == nullptr ) {
332+ return Internal (" HloModule is null" );
333+ }
334+
335+ // Create a Kernel API Builder and a throwaway kernel prototype in order to
336+ // extract useful info from them, e.g. noalias, invariant_arguments and
337+ // entry function attributes.
338+ // TODO(ecg): find a way to obtain the same info without wasting work by
339+ // creating a throwaway module. All of this additional info should probably be
340+ // explicit in the generated MLIR, not added afterwards like we're doing here.
341+ // TODO(ecg): some attributes on the final loads are missing wrt those
342+ // generated via KernelApiIrBuilder, e.g. noalias. Add them.
343+ llvm::LLVMContext& context = module .getContext ();
344+ KernelApiIrBuilder kernel_api_ir_builder (
345+ context,
346+ KernelApiIrBuilder::Options::FromHloModuleConfig (hlo_module->config ()));
347+ std::unique_ptr<llvm::Module> throwaway_llvm_module =
348+ KernelApiIrBuilder::CreateModule (
349+ absl::StrCat (fusion->name (), " _throwaway_module" ), context);
350+ TF_ASSIGN_OR_RETURN (KernelApiIrBuilder::KernelPrototype kernel_prototype,
351+ kernel_api_ir_builder.EmitKernelPrototype (
352+ *throwaway_llvm_module, fusion, &buffer_assignment,
353+ " _throwaway_kernel_prototype" ));
354+ llvm::Function* kernel_function = module .getFunction (fusion->name ());
355+ kernel_api_ir_builder.SetKernelFunctionAttributes (kernel_function);
356+
357+ return kernel_prototype.invariant_arguments ;
358+ }
359+
// Returns the ceiling of a / b, i.e. the smallest integer not less than the
// exact quotient. `b` must be non-zero.
//
// Computed from the truncated quotient and remainder rather than as
// (a + b - 1) / b, so the intermediate sum cannot overflow when `a` is close
// to INT64_MAX. For a >= 0 and b > 0 the result matches the
// (a + b - 1) / b form wherever that form does not overflow.
int64_t CeilDiv(int64_t a, int64_t b) {
  const int64_t quotient = a / b;
  const int64_t remainder = a % b;
  // C++ division truncates toward zero, so the quotient needs rounding up
  // only when the exact result is positive and inexact, i.e. when the
  // remainder is non-zero and has the same sign as the divisor.
  const bool round_up = remainder != 0 && ((remainder < 0) == (b < 0));
  return round_up ? quotient + 1 : quotient;
}
409361} // namespace cpu
410362} // namespace xla
0 commit comments