File tree Expand file tree Collapse file tree 1 file changed +5
-0
lines changed
src/libtorchaudio/forced_align/gpu Expand file tree Collapse file tree 1 file changed +5
-0
lines changed Original file line number Diff line number Diff line change 11#include < ATen/core/TensorAccessor.h>
22#include < ATen/cuda/CUDAContext.h>
33#include < c10/cuda/CUDAException.h>
4+ #include < c10/cuda/CUDAGuard.h>
45#include < limits.h>
56#include < torch/torch.h>
67#include < cub/cub.cuh>
@@ -115,6 +116,10 @@ void forced_align_impl(
115116 const torch::Tensor& targets,
116117 const int64_t blank,
117118 torch::Tensor& paths) {
119+
120+ // only guard logProbs, since in L263 it'll be verified on targets
121+ const at::cuda::OptionalCUDAGuard device_guard (logProbs.device ());
122+
118123 auto defaultStream = at::cuda::getCurrentCUDAStream ();
119124 auto cpuDataTranferStream = at::cuda::getStreamFromPool ();
120125 const scalar_t kNegInfinity = -std::numeric_limits<scalar_t >::infinity ();
You can’t perform that action at this time.
0 commit comments