Skip to content

Commit ee1a135

Browse files
authored
[STABLE ABI] Port forced_align (#4079)
1 parent 1be5c8f commit ee1a135

File tree

7 files changed

+809
-154
lines changed

7 files changed

+809
-154
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 99 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,28 @@
1-
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
1+
#include <libtorchaudio/stable/ops.h>
2+
#include <libtorchaudio/utils.h>
23
#include <torch/csrc/stable/library.h>
3-
#include <torch/csrc/stable/ops.h>
44
#include <torch/csrc/stable/tensor.h>
5-
#include <torch/script.h>
6-
#include <torch/torch.h>
7-
8-
using namespace std;
5+
#include <torch/headeronly/core/Dispatch_v2.h>
6+
#include <torch/headeronly/core/ScalarType.h>
97

108
namespace torchaudio {
119
namespace alignment {
1210
namespace cpu {
11+
12+
using torch::headeronly::ScalarType;
13+
using torch::stable::Tensor;
14+
1315
// Inspired from
1416
// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
15-
template <typename scalar_t, at::ScalarType target_scalar_type>
17+
template <typename scalar_t, ScalarType target_scalar_type>
1618
void forced_align_impl(
17-
const torch::Tensor& logProbs,
18-
const torch::Tensor& targets,
19+
const Tensor& logProbs,
20+
const Tensor& targets,
1921
const int64_t blank,
20-
torch::Tensor& paths) {
22+
Tensor& paths) {
2123
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
2224
using target_t = typename std::
23-
conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
25+
conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type;
2426
const auto batchIndex =
2527
0; // TODO: support batch version and use the real batch index
2628
const auto T = logProbs.size(1);
@@ -36,17 +38,16 @@ void forced_align_impl(
3638
for (int i = 0; i < T * S; i++) {
3739
backPtr_a[i] = -1;
3840
}
39-
40-
auto logProbs_a = logProbs.accessor<scalar_t, 3>();
41-
auto targets_a = targets.accessor<target_t, 2>();
42-
auto paths_a = paths.accessor<target_t, 2>();
41+
auto logProbs_a = torchaudio::stable::accessor<scalar_t, 3>(logProbs);
42+
auto targets_a = torchaudio::stable::accessor<target_t, 2>(targets);
43+
auto paths_a = torchaudio::stable::accessor<target_t, 2>(paths);
4344
auto R = 0;
4445
for (auto i = 1; i < L; i++) {
4546
if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1]) {
4647
++R;
4748
}
4849
}
49-
TORCH_CHECK(
50+
STD_TORCH_CHECK(
5051
T >= L + R,
5152
"targets length is too long for CTC. Found log_probs length: ",
5253
T,
@@ -138,73 +139,109 @@ void forced_align_impl(
138139
delete[] backPtr_a;
139140
}
140141

141-
std::tuple<torch::Tensor, torch::Tensor> compute(
142-
const torch::Tensor& logProbs,
143-
const torch::Tensor& targets,
144-
const torch::Tensor& inputLengths,
145-
const torch::Tensor& targetLengths,
142+
// Convenience aliases binding the target dtype template parameter, so the
// dispatch lambda below only has to supply the floating-point scalar_t.
template <typename scalar_t>
const auto forced_align_long_impl =
    forced_align_impl<scalar_t, ScalarType::Long>;

template <typename scalar_t>
const auto forced_align_int_impl = forced_align_impl<scalar_t, ScalarType::Int>;

// CPU entry point for forced alignment (batch size 1 only, see checks below).
// Validates all inputs, allocates the output `paths` tensor with the same
// dtype/device as `targets`, dispatches on (log-prob dtype x target dtype),
// and returns (paths, logProbs).
std::tuple<Tensor, Tensor> compute(
    const Tensor& logProbs,
    const Tensor& targets,
    const Tensor& inputLengths,
    const Tensor& targetLengths,
    const int64_t blank) {
  // All four tensors must live on CPU for this kernel.
  STD_TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
  STD_TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
  STD_TORCH_CHECK(inputLengths.is_cpu(), "input_lengths must be a CPU tensor");
  STD_TORCH_CHECK(
      targetLengths.is_cpu(), "target_lengths must be a CPU tensor");
  // Dtype constraints mirror the dispatch list at the bottom of this function.
  STD_TORCH_CHECK(
      logProbs.scalar_type() == ScalarType::Double ||
          logProbs.scalar_type() == ScalarType::Float ||
          logProbs.scalar_type() == ScalarType::Half,
      "log_probs must be float64, float32 or float16 (half) type");
  STD_TORCH_CHECK(
      targets.scalar_type() == ScalarType::Int ||
          targets.scalar_type() == ScalarType::Long,
      "targets must be int32 or int64 type");
  // The impl indexes raw storage via accessors, so inputs must be contiguous.
  STD_TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
  STD_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
  STD_TORCH_CHECK(
      logProbs.dim() == 3,
      "log_probs must be 3-D (batch_size, input length, num classes)");
  STD_TORCH_CHECK(
      targets.dim() == 2, "targets must be 2-D (batch_size, target length,)");
  STD_TORCH_CHECK(
      inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)");
  STD_TORCH_CHECK(
      targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)");
  // Batched alignment is not implemented yet; only B == 1 is accepted.
  STD_TORCH_CHECK(
      logProbs.size(0) == 1,
      "The batch dimension for log_probs must be 1 at the current version.")
  STD_TORCH_CHECK(
      targets.size(0) == 1,
      "The batch dimension for targets must be 1 at the current version.")
  STD_TORCH_CHECK(
      blank >= 0 && blank < logProbs.size(-1),
      "blank must be within [0, num classes)");
  // The length tensors may be int32 or int64; dispatch so the max() helper
  // reads them with the correct element type before comparing to the sizes.
  THO_DISPATCH_V2(
      inputLengths.scalar_type(),
      "forced_align_impl",
      AT_WRAP([&] {
        STD_TORCH_CHECK(
            logProbs.size(1) == torchaudio::util::max<scalar_t>(inputLengths),
            "input length mismatch");
      }),
      ScalarType::Int,
      ScalarType::Long);
  THO_DISPATCH_V2(
      targetLengths.scalar_type(),
      "forced_align_impl",
      AT_WRAP([&] {
        STD_TORCH_CHECK(
            targets.size(1) == torchaudio::util::max<scalar_t>(targetLengths),
            "target length mismatch");
      }),
      ScalarType::Int,
      ScalarType::Long);
  const auto B = logProbs.size(0);
  const auto T = logProbs.size(1);
  // Output alignment path: one label per input frame, same dtype as targets.
  Tensor paths = torchaudio::stable::new_zeros(targets, {B, T});
  // Dispatch over the log-prob floating type; inside, branch on the target
  // integer type (Long vs Int) to pick the matching impl instantiation.
  THO_DISPATCH_V2(
      logProbs.scalar_type(),
      "forced_align_impl",
      AT_WRAP([&] {
        if (targets.scalar_type() == ScalarType::Long) {
          forced_align_long_impl<scalar_t>(logProbs, targets, blank, paths);
        } else {
          forced_align_int_impl<scalar_t>(logProbs, targets, blank, paths);
        }
      }),
      AT_EXPAND(AT_FLOATING_TYPES),
      ScalarType::Half);
  return std::make_tuple(paths, logProbs);
}
205226

206-
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
207-
m.impl("forced_align", &compute);
227+
// Boxed (stable-ABI) wrapper around compute(): unpacks the StableIValue
// stack, forwards the five arguments, and writes the two results back onto
// the stack in place.
void boxed_forced_align_cpu(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
  STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
  std::tuple<Tensor, Tensor> res = compute(
      /*logProbs*/ torch::stable::detail::to<Tensor>(stack[0]),
      /*targets*/ torch::stable::detail::to<Tensor>(stack[1]),
      /*inputLengths*/ torch::stable::detail::to<Tensor>(stack[2]),
      /*targetLengths*/ torch::stable::detail::to<Tensor>(stack[3]),
      // Pass `blank` through as int64_t — compute() takes int64_t, so the
      // former float(...) round-trip was a lossy no-op (it would silently
      // corrupt values with magnitude > 2^24).
      /*blank*/ torch::stable::detail::to<int64_t>(stack[4]));
  stack[0] = torch::stable::detail::from(std::get<0>(res));
  stack[1] = torch::stable::detail::from(std::get<1>(res));
}
242+
243+
// Register the boxed CPU kernel for the torchaudio::forced_align operator
// via the stable-ABI registration macro.
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("forced_align", &boxed_forced_align_cpu);
}
209246

210247
} // namespace cpu

0 commit comments

Comments
 (0)