Skip to content

Commit dca045a

Browse files
authored
[STABLE ABI] Eliminate ATen/cuda/ and c10/cuda/ includes. (#4140)
* Eliminate include c10/cuda/CUDAException.h
* Eliminate ATen/cuda/CUDAContext.h and c10/cuda/CUDAGuard.h
1 parent 06776f8 commit dca045a

File tree

6 files changed

+98
-30
lines changed

6 files changed

+98
-30
lines changed

src/libtorchaudio/cuda_utils.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#pragma once
2+
3+
#include <cuda_runtime_api.h>
4+
#include <torch/csrc/stable/c/shim.h>
5+
#include <torch/csrc/stable/device.h>
6+
7+
namespace libtorchaudio::cuda {
8+
9+
inline cudaStream_t getCurrentCUDAStream(
10+
torch::stable::DeviceIndex device_index = -1) {
11+
void* stream_ptr = nullptr;
12+
TORCH_ERROR_CODE_CHECK(
13+
aoti_torch_get_current_cuda_stream(device_index, &stream_ptr));
14+
return static_cast<cudaStream_t>(stream_ptr);
15+
}
16+
17+
inline void setCurrentCUDAStream(
18+
cudaStream_t stream,
19+
torch::stable::DeviceIndex device_index = -1) {
20+
TORCH_ERROR_CODE_CHECK(
21+
torch_set_current_cuda_stream(static_cast<void*>(stream), device_index));
22+
}
23+
24+
inline cudaStream_t getStreamFromPool(
25+
const bool isHighPriority = false,
26+
torch::stable::DeviceIndex device_index = -1) {
27+
void* stream_ptr = nullptr;
28+
TORCH_ERROR_CODE_CHECK(torch_get_cuda_stream_from_pool(
29+
isHighPriority, device_index, &stream_ptr));
30+
return static_cast<cudaStream_t>(stream_ptr);
31+
}
32+
33+
inline void synchronize(
34+
cudaStream_t stream,
35+
torch::stable::DeviceIndex device_index = -1) {
36+
TORCH_ERROR_CODE_CHECK(
37+
torch_cuda_stream_synchronize(static_cast<void*>(stream), device_index));
38+
}
39+
40+
} // namespace libtorchaudio::cuda

src/libtorchaudio/forced_align/gpu/compute.cu

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
#include <libtorchaudio/cuda_utils.h>
12
#include <libtorchaudio/utils.h>
23
#include <torch/csrc/stable/library.h>
4+
#include <torch/csrc/stable/macros.h>
35
#include <torch/headeronly/core/Dispatch_v2.h>
46
#include <torch/headeronly/core/ScalarType.h>
57

@@ -119,8 +121,9 @@ void forced_align_impl(
119121
const Tensor& targets,
120122
const int64_t blank,
121123
Tensor& paths) {
122-
auto defaultStream = at::cuda::getCurrentCUDAStream();
123-
auto cpuDataTranferStream = at::cuda::getStreamFromPool();
124+
auto device_index = logProbs.get_device_index();
125+
auto defaultStream = libtorchaudio::cuda::getCurrentCUDAStream(device_index);
126+
auto cpuDataTranferStream = libtorchaudio::cuda::getStreamFromPool(false, device_index);
124127
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
125128
using target_t = typename std::
126129
conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type;
@@ -204,29 +207,29 @@ void forced_align_impl(
204207
backPtrBufferLen,
205208
torchaudio::packed_accessor32<scalar_t, 2>(alphas),
206209
torchaudio::packed_accessor32<int8_t, 2>(backPtrBuffer));
207-
C10_CUDA_KERNEL_LAUNCH_CHECK();
210+
STD_CUDA_KERNEL_LAUNCH_CHECK();
208211
++backPtrBufferLen;
209212
if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1) {
210-
cpuDataTranferStream.synchronize();
213+
libtorchaudio::cuda::synchronize(cpuDataTranferStream, device_index);
211214
// GPU -> GPU copy
212215
bufferCopy = torch::stable::clone(backPtrBuffer);
213216
STD_TORCH_CHECK(bufferCopy.is_contiguous(), "unexpected fail, need to implement stable::Tensor::contiguous()")
214-
defaultStream.synchronize();
215-
at::cuda::setCurrentCUDAStream(cpuDataTranferStream);
217+
libtorchaudio::cuda::synchronize(defaultStream, device_index);
218+
libtorchaudio::cuda::setCurrentCUDAStream(cpuDataTranferStream, device_index);
216219
// Copy ASYNC from GPU to CPU
217220
int64_t offset =
218221
static_cast<int64_t>(t + 1 - backPtrBufferLen) * S * sizeof(int8_t);
219-
C10_CUDA_CHECK(cudaMemcpyAsync(
222+
STD_CUDA_CHECK(cudaMemcpyAsync(
220223
static_cast<int8_t*>(backPtrCpu.data_ptr()) + offset,
221224
bufferCopy.data_ptr(),
222225
backPtrBufferLen * S * sizeof(int8_t),
223226
cudaMemcpyDeviceToHost,
224227
cpuDataTranferStream));
225-
at::cuda::setCurrentCUDAStream(defaultStream);
228+
libtorchaudio::cuda::setCurrentCUDAStream(defaultStream, device_index);
226229
backPtrBufferLen = 0;
227230
}
228231
}
229-
cpuDataTranferStream.synchronize();
232+
libtorchaudio::cuda::synchronize(cpuDataTranferStream, device_index);
230233
auto alphasCpu = torchaudio::stable::cpu(alphas);
231234
auto alphasCpu_a = torchaudio::accessor<scalar_t, 2>(alphasCpu);
232235
int curIdxOffset = ((T - 1) % 2);

src/libtorchaudio/iir_cuda.cu

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
#include <libtorchaudio/utils.h>
2+
#include <torch/csrc/stable/accelerator.h>
3+
#include <torch/csrc/stable/macros.h>
24
#include <torch/headeronly/core/Dispatch_v2.h>
35
#include <torch/headeronly/core/ScalarType.h>
4-
#include <c10/cuda/CUDAException.h>
5-
#include <c10/cuda/CUDAGuard.h>
6-
#include <c10/core/DeviceGuard.h>
76

87
using torch::headeronly::ScalarType;
98
using torch::stable::Tensor;
@@ -65,8 +64,7 @@ Tensor cuda_lfilter_core_loop(
6564

6665
STD_TORCH_CHECK(in.size(2) + a_flipped.size(1) - 1 == padded_out.size(2));
6766

68-
const at::cuda::OptionalCUDAGuard device_guard(in.get_device_index());
69-
67+
const torch::stable::accelerator::DeviceGuard device_guard(in.get_device_index());
7068
const dim3 threads(256);
7169
const dim3 blocks((N * C + threads.x - 1) / threads.x);
7270

@@ -76,7 +74,7 @@ Tensor cuda_lfilter_core_loop(
7674
torchaudio::packed_accessor_size_t<scalar_t, 3>(in),
7775
torchaudio::packed_accessor_size_t<scalar_t, 2>(a_flipped),
7876
torchaudio::packed_accessor_size_t<scalar_t, 3>(padded_out)));
79-
C10_CUDA_KERNEL_LAUNCH_CHECK();
77+
STD_CUDA_KERNEL_LAUNCH_CHECK();
8078
}), AT_FLOATING_TYPES);
8179
return padded_out;
8280
}

src/libtorchaudio/rnnt/gpu/compute.cu

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#include <libtorchaudio/rnnt/gpu/gpu_transducer.h>
2+
#include <libtorchaudio/cuda_utils.h>
3+
#include <libtorchaudio/stable/ops.h>
24

3-
#include <c10/cuda/CUDAException.h>
4-
#include <c10/cuda/CUDAStream.h>
55
#include <torch/csrc/stable/library.h>
66
#include <torch/csrc/stable/ops.h>
77
#include <torch/headeronly/core/Dispatch_v2.h>
@@ -76,9 +76,8 @@ std::tuple<Tensor, Tensor> compute(
7676
"blank must be within [0, logits.shape[-1])");
7777

7878
auto max_ivalue = [](const Tensor& t) {
79-
int32_t value;
80-
C10_CUDA_CHECK(cudaMemcpy(&value, torch::stable::amax(t, {}).data_ptr(), sizeof(int32_t), cudaMemcpyDeviceToHost));
81-
return value;
79+
auto mx = torchaudio::stable::cpu(torch::stable::amax(t, {}));
80+
return reinterpret_cast<int32_t*>(mx.data_ptr())[0];
8281
};
8382

8483
STD_TORCH_CHECK(
@@ -100,7 +99,7 @@ std::tuple<Tensor, Tensor> compute(
10099
options.blank_ = blank;
101100
options.clamp_ = clamp;
102101
options.fusedLogSmax_ = fused_log_softmax;
103-
options.stream_ = at::cuda::getCurrentCUDAStream();
102+
options.stream_ = libtorchaudio::cuda::getCurrentCUDAStream(logits.get_device_index());
104103
cudaSetDevice(logits.get_device());
105104
options.device_ = GPU;
106105

src/libtorchaudio/shim_temporary.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#pragma once
2+
// TODO: remove this file once https://github.com/pytorch/pytorch/pull/169376
3+
// has landed in nightly.
4+
5+
#include <c10/cuda/CUDAStream.h>
6+
#include <torch/csrc/inductor/aoti_torch/utils.h>
7+
#include <torch/csrc/stable/c/shim.h>
8+
9+
// Temporary shim: makes `stream` the current CUDA stream.
//
// `stream` is an opaque handle that is cast back to `cudaStream_t` and wrapped
// via at::cuda::getStreamFromExternal together with `device_index` before
// being passed to at::cuda::setCurrentCUDAStream.
// NOTE(review): the AOTITorchError result comes from the
// AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE macro, which by its name turns
// exceptions thrown inside the braces into an error code -- confirm.
inline AOTITorchError tmp_torch_set_current_cuda_stream(
    void* stream,
    int32_t device_index) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    const auto external_stream = at::cuda::getStreamFromExternal(
        static_cast<cudaStream_t>(stream), device_index);
    at::cuda::setCurrentCUDAStream(external_stream);
  });
}
17+
18+
// Temporary shim: obtains a stream from at::cuda::getStreamFromPool and writes
// its raw `cudaStream_t` handle to `*ret_stream`.
//
// `isHighPriority` and `device_index` are forwarded unchanged. `ret_stream`
// is dereferenced unconditionally, so it must point to writable storage.
// NOTE(review): the AOTITorchError result comes from the
// AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE macro, which by its name turns
// exceptions thrown inside the braces into an error code -- confirm.
inline AOTITorchError tmp_torch_get_cuda_stream_from_pool(
    const bool isHighPriority,
    int32_t device_index,
    void** ret_stream) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // The stream wrapper converts to its underlying cudaStream_t handle.
    const cudaStream_t pooled_stream =
        at::cuda::getStreamFromPool(isHighPriority, device_index);
    *ret_stream = pooled_stream;
  });
}
27+
28+
// Temporary shim: synchronizes on `stream`.
//
// `stream` is an opaque handle that is cast back to `cudaStream_t`, wrapped
// via at::cuda::getStreamFromExternal together with `device_index`, and then
// synchronized through the wrapper's synchronize() member.
// NOTE(review): the AOTITorchError result comes from the
// AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE macro, which by its name turns
// exceptions thrown inside the braces into an error code -- confirm.
inline AOTITorchError tmp_torch_cuda_stream_synchronize(
    void* stream,
    int32_t device_index) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    const auto external_stream = at::cuda::getStreamFromExternal(
        static_cast<cudaStream_t>(stream), device_index);
    external_stream.synchronize();
  });
}

src/libtorchaudio/stable/ops.h

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,6 @@
1212
#include <torch/csrc/stable/ops.h>
1313
#include <torch/csrc/stable/tensor.h>
1414

15-
#ifdef USE_CUDA
16-
#include <ATen/cuda/CUDAContext.h>
17-
#include <c10/cuda/CUDAException.h>
18-
#endif
19-
2015
namespace torchaudio::stable {
2116

2217
using torch::stable::Tensor;
@@ -83,10 +78,7 @@ T item(const Tensor& self) {
8378
return reinterpret_cast<const T*>(self.const_data_ptr())[0];
8479
#ifdef USE_CUDA
8580
} else if (self.is_cuda()) {
86-
T value;
87-
C10_CUDA_CHECK(cudaMemcpyAsync(
88-
&value, self.data_ptr(), sizeof(T), cudaMemcpyDeviceToHost));
89-
return value;
81+
return torchaudio::stable::item<T>(torchaudio::stable::cpu(self));
9082
#endif
9183
} else {
9284
STD_TORCH_CHECK(false, "unreachable"); // not implemented

0 commit comments

Comments
 (0)