Skip to content

Commit 0fd1a6a

Browse files
add cuda guard for forced align module
1 parent 87ff22e commit 0fd1a6a

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

src/libtorchaudio/forced_align/gpu/compute.cu

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
#include <ATen/core/TensorAccessor.h>
22
#include <ATen/cuda/CUDAContext.h>
33
#include <c10/cuda/CUDAException.h>
4+
#include <c10/cuda/CUDAGuard.h>
45
#include <limits.h>
56
#include <torch/torch.h>
67
#include <cub/cub.cuh>
@@ -115,6 +116,10 @@ void forced_align_impl(
115116
const torch::Tensor& targets,
116117
const int64_t blank,
117118
torch::Tensor& paths) {
119+
120+
// only guard logProbs, since in L263 it'll be verified on targets
121+
const at::cuda::OptionalCUDAGuard device_guard(logProbs.device());
122+
118123
auto defaultStream = at::cuda::getCurrentCUDAStream();
119124
auto cpuDataTranferStream = at::cuda::getStreamFromPool();
120125
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();

0 commit comments

Comments (0)