diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc
index 9ba1be41febde..76fc9e64d7ae2 100644
--- a/xla/service/gpu/ir_emitter_unnested.cc
+++ b/xla/service/gpu/ir_emitter_unnested.cc
@@ -1590,7 +1590,7 @@ absl::Status IrEmitterUnnested::EmitTopKCustomCall(
 
 absl::Status IrEmitterUnnested::EmitTritonCustomCall(
     const HloCustomCallInstruction* instr) {
-#if !GOOGLE_CUDA
+#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM
   return absl::UnimplementedError("Triton support requires CUDA");
 #else
   auto generate = [this, &instr]() -> absl::StatusOr<KernelReuseCache::Entry> {
@@ -1615,7 +1615,7 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall(
     TF_ASSIGN_OR_RETURN(
         auto result,
         CompileTritonToLLVM(hlo_module->config(), hlo_module->name(),
-                            ir_emitter_context_->cuda_compute_capability(),
+                            ir_emitter_context_->gpu_compute_capability(),
                             ir_emitter_context_->gpu_device_info(), gemm_config,
                             triton_module.get(),
                             ir_emitter_context_->llvm_module(), mlir_context));