-#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <libtorchaudio/stable/ops.h>
+#include <libtorchaudio/utils.h>
 #include <torch/csrc/stable/library.h>
-#include <torch/csrc/stable/ops.h>
 #include <torch/csrc/stable/tensor.h>
-#include <torch/script.h>
-#include <torch/torch.h>
-
-using namespace std;
+#include <torch/headeronly/core/Dispatch_v2.h>
+#include <torch/headeronly/core/ScalarType.h>
 
 namespace torchaudio {
 namespace alignment {
 namespace cpu {
+
+using torch::headeronly::ScalarType;
+using torch::stable::Tensor;
+
 // Inspired from
 // https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
-template <typename scalar_t, at::ScalarType target_scalar_type>
+template <typename scalar_t, ScalarType target_scalar_type>
 void forced_align_impl(
-    const torch::Tensor& logProbs,
-    const torch::Tensor& targets,
+    const Tensor& logProbs,
+    const Tensor& targets,
     const int64_t blank,
-    torch::Tensor& paths) {
+    Tensor& paths) {
   const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
   using target_t = typename std::
-      conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
+      conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type;
   const auto batchIndex =
       0; // TODO: support batch version and use the real batch index
   const auto T = logProbs.size(1);
@@ -36,17 +38,16 @@ void forced_align_impl(
   for (int i = 0; i < T * S; i++) {
     backPtr_a[i] = -1;
   }
-
-  auto logProbs_a = logProbs.accessor<scalar_t, 3>();
-  auto targets_a = targets.accessor<target_t, 2>();
-  auto paths_a = paths.accessor<target_t, 2>();
+  auto logProbs_a = torchaudio::stable::accessor<scalar_t, 3>(logProbs);
+  auto targets_a = torchaudio::stable::accessor<target_t, 2>(targets);
+  auto paths_a = torchaudio::stable::accessor<target_t, 2>(paths);
   auto R = 0;
   for (auto i = 1; i < L; i++) {
     if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1]) {
       ++R;
     }
   }
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
       T >= L + R,
       "targets length is too long for CTC. Found log_probs length: ",
       T,
@@ -138,73 +139,109 @@ void forced_align_impl(
   delete[] backPtr_a;
 }
 
-std::tuple<torch::Tensor, torch::Tensor> compute(
-    const torch::Tensor& logProbs,
-    const torch::Tensor& targets,
-    const torch::Tensor& inputLengths,
-    const torch::Tensor& targetLengths,
+template <typename scalar_t>
+const auto forced_align_long_impl =
+    forced_align_impl<scalar_t, ScalarType::Long>;
+
+template <typename scalar_t>
+const auto forced_align_int_impl = forced_align_impl<scalar_t, ScalarType::Int>;
+
+std::tuple<Tensor, Tensor> compute(
+    const Tensor& logProbs,
+    const Tensor& targets,
+    const Tensor& inputLengths,
+    const Tensor& targetLengths,
     const int64_t blank) {
-  TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
-  TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
-  TORCH_CHECK(
-      logProbs.device() == targets.device(),
-      "log_probs and targets need to be on the same device");
-  TORCH_CHECK(
-      logProbs.dtype() == torch::kFloat64 ||
-          logProbs.dtype() == torch::kFloat32 ||
-          logProbs.dtype() == torch::kFloat16,
+  STD_TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
+  STD_TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
+  STD_TORCH_CHECK(inputLengths.is_cpu(), "input_lengths must be a CPU tensor");
+  STD_TORCH_CHECK(
+      targetLengths.is_cpu(), "target_lengths must be a CPU tensor");
+  STD_TORCH_CHECK(
+      logProbs.scalar_type() == ScalarType::Double ||
+          logProbs.scalar_type() == ScalarType::Float ||
+          logProbs.scalar_type() == ScalarType::Half,
       "log_probs must be float64, float32 or float16 (half) type");
-  TORCH_CHECK(
-      targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64,
+  STD_TORCH_CHECK(
+      targets.scalar_type() == ScalarType::Int ||
+          targets.scalar_type() == ScalarType::Long,
       "targets must be int32 or int64 type");
-  TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
-  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
-  TORCH_CHECK(
+  STD_TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
+  STD_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
+  STD_TORCH_CHECK(
       logProbs.dim() == 3,
       "log_probs must be 3-D (batch_size, input length, num classes)");
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
       targets.dim() == 2, "targets must be 2-D (batch_size, target length,)");
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
       inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)");
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
       targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)");
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
       logProbs.size(0) == 1,
       "The batch dimension for log_probs must be 1 at the current version.")
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
       targets.size(0) == 1,
       "The batch dimension for targets must be 1 at the current version.")
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
       blank >= 0 && blank < logProbs.size(-1),
       "blank must be within [0, num classes)");
-
-  TORCH_CHECK(
-      logProbs.size(1) == at::max(inputLengths).item().toInt(),
-      "input length mismatch");
-  TORCH_CHECK(
-      targets.size(1) == at::max(targetLengths).item().toInt(),
-      "target length mismatch");
-
+  THO_DISPATCH_V2(
+      inputLengths.scalar_type(),
+      "forced_align_impl",
+      AT_WRAP([&] {
+        STD_TORCH_CHECK(
+            logProbs.size(1) == torchaudio::util::max<scalar_t>(inputLengths),
+            "input length mismatch");
+      }),
+      ScalarType::Int,
+      ScalarType::Long);
+  THO_DISPATCH_V2(
+      targetLengths.scalar_type(),
+      "forced_align_impl",
+      AT_WRAP([&] {
+        STD_TORCH_CHECK(
+            targets.size(1) == torchaudio::util::max<scalar_t>(targetLengths),
+            "target length mismatch");
+      }),
+      ScalarType::Int,
+      ScalarType::Long);
   const auto B = logProbs.size(0);
   const auto T = logProbs.size(1);
-  auto paths = torch::zeros(
-      {B, T},
-      torch::TensorOptions().device(targets.device()).dtype(targets.dtype()));
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      logProbs.scalar_type(), "forced_align_impl", [&] {
-        if (targets.scalar_type() == torch::kInt64) {
-          forced_align_impl<scalar_t, torch::kInt64>(
-              logProbs, targets, blank, paths);
+  Tensor paths = torchaudio::stable::new_zeros(targets, {B, T});
+  THO_DISPATCH_V2(
+      logProbs.scalar_type(),
+      "forced_align_impl",
+      AT_WRAP([&] {
+        if (targets.scalar_type() == ScalarType::Long) {
+          forced_align_long_impl<scalar_t>(logProbs, targets, blank, paths);
         } else {
-          forced_align_impl<scalar_t, torch::kInt32>(
-              logProbs, targets, blank, paths);
+          forced_align_int_impl<scalar_t>(logProbs, targets, blank, paths);
         }
-      });
+      }),
+      AT_EXPAND(AT_FLOATING_TYPES),
+      ScalarType::Half);
   return std::make_tuple(paths, logProbs);
 }
 
-TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
-  m.impl("forced_align", &compute);
+void boxed_forced_align_cpu(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
+  STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
+  std::tuple<Tensor, Tensor> res = compute(
+      /*logProbs*/ torch::stable::detail::to<Tensor>(stack[0]),
+      /*targets*/ torch::stable::detail::to<Tensor>(stack[1]),
+      /*logit_lengths*/ torch::stable::detail::to<Tensor>(stack[2]),
+      /*target_lengths*/ torch::stable::detail::to<Tensor>(stack[3]),
+      /*blank*/ float(torch::stable::detail::to<int64_t>(stack[4])));
+  stack[0] = torch::stable::detail::from(std::get<0>(res));
+  stack[1] = torch::stable::detail::from(std::get<1>(res));
+}
+
+STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
+  m.impl("forced_align", &boxed_forced_align_cpu);
 }
 
 } // namespace cpu
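
Side note on the compile-time target-type selection used by forced_align_impl above: the integer element type that the targets accessor reads through is chosen from the ScalarType template parameter via std::conditional. Below is a minimal, self-contained sketch of that selection, assuming only the header-only <torch/headeronly/core/ScalarType.h>; the target_type_t alias name is illustrative and not part of the library.

#include <torch/headeronly/core/ScalarType.h>

#include <cstdint>
#include <type_traits>

using torch::headeronly::ScalarType;

// Int32 target tensors are read through `int`; anything else (here, Long)
// is read through `int64_t`, mirroring target_t in forced_align_impl.
template <ScalarType target_scalar_type>
using target_type_t = typename std::
    conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type;

static_assert(
    std::is_same<target_type_t<ScalarType::Int>, int>::value,
    "int32 targets are read as int");
static_assert(
    std::is_same<target_type_t<ScalarType::Long>, int64_t>::value,
    "int64 targets are read as int64_t");

This is what lets the forced_align_long_impl and forced_align_int_impl aliases introduced in the diff share a single implementation while the accessors still dereference the correct element width.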