From 7caab56ac1db50227eb878971f4adbfda6679e5b Mon Sep 17 00:00:00 2001 From: Serge Lu Date: Wed, 13 Nov 2024 13:25:52 +0800 Subject: [PATCH] Added missing set_device calls for nvidia_transform --- bitsandbytes/backends/cuda.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py index ad478431c..af3c044b8 100644 --- a/bitsandbytes/backends/cuda.py +++ b/bitsandbytes/backends/cuda.py @@ -188,7 +188,10 @@ def transform( if HIP_ENVIRONMENT: # transform kernel formats (col32/col_turing/col_ampere) are not applicable to ROCm # Use nvidia_transform instead - return nvidia_transform(A, to_order, from_order, out, transpose, state, ld) + prev_device = pre_call(A.device) + out, new_state = nvidia_transform(A, to_order, from_order, out, transpose, state, ld) + post_call(prev_device) + return out, new_state prev_device = pre_call(A.device) if state is None: