Skip to content

Commit

Permalink
[ROCm] Add custom call handling by Triton.
Browse files Browse the repository at this point in the history
  • Loading branch information
zoranjovanovic-ns committed May 30, 2024
1 parent 72e7d77 commit 8a15082
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions xla/service/gpu/ir_emitter_unnested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1590,7 +1590,7 @@ absl::Status IrEmitterUnnested::EmitTopKCustomCall(

absl::Status IrEmitterUnnested::EmitTritonCustomCall(
const HloCustomCallInstruction* instr) {
#if !GOOGLE_CUDA
#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM
return absl::UnimplementedError("Triton support requires CUDA");
#else
auto generate = [this, &instr]() -> absl::StatusOr<KernelReuseCache::Entry> {
Expand All @@ -1615,7 +1615,7 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall(
TF_ASSIGN_OR_RETURN(
auto result,
CompileTritonToLLVM(hlo_module->config(), hlo_module->name(),
ir_emitter_context_->cuda_compute_capability(),
ir_emitter_context_->gpu_compute_capability(),
ir_emitter_context_->gpu_device_info(), gemm_config,
triton_module.get(),
ir_emitter_context_->llvm_module(), mlir_context));
Expand Down

0 comments on commit 8a15082

Please sign in to comment.