remove sequence loop
wcshds committed Dec 8, 2023
1 parent 94d52d2 commit 7121852
Showing 1 changed file with 124 additions and 89 deletions.
burn-core/src/nn/loss/ctc.rs
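This commit removes the per-symbol loop over s in the CTC forward pass and instead updates every position of the blank-interspersed target l' with one set of tensor operations per time step. For reference, the recursion being vectorized is the standard CTC forward (alpha) recursion in log space:

log α_t(s) = log( exp(log α_{t-1}(s)) + exp(log α_{t-1}(s-1)) + [l'_s ≠ l'_{s-2}] · exp(log α_{t-1}(s-2)) ) + log p_t(l'_s)

The indicator [l'_s ≠ l'_{s-2}] disables the two-step transition at blanks and repeated labels; in the new code it is the mask_la3 tensor, and the running maximum lamax keeps the logsumexp numerically stable.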
@@ -1,11 +1,11 @@
#![allow(clippy::single_range_in_vec_init)]
use core::marker::PhantomData;

use burn_tensor::{backend::Backend, ElementConversion, Int, Tensor};
use burn_tensor::{backend::Backend, Element, ElementConversion, Int, Numeric, Tensor};

use super::Reduction;

const NEG_INF: f32 = -10000.0;
const NEG_INF: f32 = -1e5;

/// The Connectionist Temporal Classification loss.
#[derive(Clone, Debug)]
@@ -78,106 +78,94 @@ impl<B: Backend> CTCLoss<B> {
let target_with_blank_length = 2 * max_target_length + 1;

let targets_pad = Self::pad_target(
targets.clone(),
targets,
target_lengths.clone(),
max_target_length,
self.blank,
&device,
);
let targets_intersperse = intersperse(targets_pad.clone(), self.blank as u32);
println!("{}", targets_intersperse.clone());
let targets_one_hot = one_hot(targets_intersperse.clone(), num_classes);

let mut log_alphas = Tensor::<B, 3>::empty_device(
let log_alphas = Tensor::<B, 3>::empty_device(
[batch_size, seq_length, target_with_blank_length],
&device,
);
// initialize value at t0
log_alphas = log_alphas.slice_assign(
let log_alphas = log_alphas.slice_assign(
[0..batch_size, 0..1, 0..target_with_blank_length],
Tensor::<B, 3>::full_device(
[batch_size, 1, target_with_blank_length],
NEG_INF,
&device,
),
);
log_alphas = log_alphas.slice_assign(
let log_alphas = log_alphas.slice_assign(
[0..batch_size, 0..1, 0..1],
log_probs
.clone()
.slice([0..batch_size, 0..1, self.blank..(self.blank + 1)]),
);
let target_primes = Self::get_target_primes(targets_pad.clone(), 1, self.blank);
log_alphas = log_alphas.slice_assign(
let target_primes: Tensor<B, 3, Int> = targets_pad
.slice([0..batch_size, 0..1])
.reshape([batch_size, 1, 1]);
let mut log_alphas = log_alphas.slice_assign(
[0..batch_size, 0..1, 1..2],
log_probs
.clone()
.slice([0..batch_size, 0..1, 0..num_classes])
.gather(2, target_primes.reshape([batch_size, 1, 1])),
.gather(2, target_primes),
);
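// Note: targets_one_hot is [batch, 2L+1, num_classes] and log_probs after
// swap_dims(1, 2) is [batch, num_classes, seq], so the matmul below gathers
// log p_t(l'_s) for every position s and time step t in one shot, giving
// log_probs_available the shape [batch, 2L+1, seq].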
let log_probs_available = targets_one_hot.matmul(log_probs.swap_dims(1, 2));
let mut neg_log_likelihood = Tensor::<B, 1>::zeros_device([batch_size], &device);

for s in 0..target_with_blank_length {
let current_target_primes = Self::get_target_primes(targets_pad.clone(), s, self.blank);

for t in 1..seq_length {
// \alpha_{t-1}(s)
let la1 = log_alphas
// s != s-2
let mask_la3 = targets_intersperse
.clone()
.slice([0..batch_size, 0..(target_with_blank_length - 2)])
.equal(targets_intersperse.slice([0..batch_size, 2..target_with_blank_length]))
.bool_not()
.float();
let mask_la3 = pad(mask_la3, [(0, 0), (2, 0)], 0.0).unsqueeze_dim(1);
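// mask_la3 is 1.0 where l'_s != l'_{s-2}, i.e. where the two-step transition
// is allowed, and 0.0 where they coincide (every blank position and repeated
// labels); the two zero columns padded on the left cover s = 0 and s = 1,
// which have no s-2 predecessor.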

for t in 1..seq_length {
// \alpha_{t-1}(s)
let la1 =
log_alphas
.clone()
.slice([0..batch_size, (t - 1)..t, s..(s + 1)])
.reshape([batch_size]);

// for the logsumexp calculation
let mut lamax = la1.clone();

// \alpha_{t-1}(s-1)
let mut la2 = Tensor::<B, 1>::full_device([batch_size], NEG_INF, &device);
if s > 0 {
la2 = log_alphas
.clone()
.slice([0..batch_size, (t - 1)..t, (s - 1)..s])
.reshape([batch_size]);

lamax = lamax
.clone()
.mask_where(la2.clone().greater(lamax.clone()), la2.clone());
}

// \alpha_{t-1}(s-2)
let mut la3 = Tensor::<B, 1>::full_device([batch_size], NEG_INF, &device);
if s > 1 {
la3 = la3.mask_where(
Self::get_target_primes(targets_pad.clone(), s - 2, self.blank)
.equal(current_target_primes.clone())
.bool_not(),
log_alphas
.clone()
.slice([0..batch_size, (t - 1)..t, (s - 2)..(s - 1)])
.reshape([batch_size]),
);

lamax = lamax
.slice([0..batch_size, (t - 1)..t, 0..target_with_blank_length]);
// \alpha_{t-1}(s-1)
let la2 = la1
.clone()
.slice([0..batch_size, 0..1, 0..(target_with_blank_length - 1)])
.clamp_min(NEG_INF);
let la2 = pad(la2, [(0, 0), (0, 0), (1, 0)], NEG_INF);
// \alpha_{t-1}(s-2)
let la3 = la1
.clone()
.slice([0..batch_size, 0..1, 0..(target_with_blank_length - 2)])
.clamp_min(NEG_INF);
let la3 = pad(la3, [(0, 0), (0, 0), (2, 0)], NEG_INF);
// for the logsumexp calculation
let lamax: Tensor<B, 3> =
Tensor::stack::<4>([la1.clone(), la2.clone(), la3.clone()].to_vec(), 3)
.max_dim(3)
.squeeze(3);
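// lamax is the element-wise maximum over the three predecessor paths;
// subtracting it before exp() keeps the logsumexp below numerically stable.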

log_alphas = log_alphas.slice_assign(
[0..batch_size, t..(t + 1), 0..target_with_blank_length],
((la1 - lamax.clone()).exp()
+ (la2 - lamax.clone()).exp()
+ (la3 - lamax.clone()).exp().mul(mask_la3.clone()))
.log()
.clamp_min(NEG_INF)
+ lamax
+ log_probs_available
.clone()
.mask_where(la3.clone().greater(lamax.clone()), la3.clone());
}

lamax = lamax
.clone()
.mask_fill(lamax.clone().lower_equal_elem(NEG_INF), 0.0);

log_alphas = log_alphas.slice_assign(
[0..batch_size, t..(t + 1), s..(s + 1)],
(((la1.clone() - lamax.clone()).exp()
+ (la2.clone() - lamax.clone()).exp()
+ (la3.clone() - lamax.clone()).exp())
.log()
.clamp_min(NEG_INF)
+ lamax.clone()
+ log_probs
.clone()
.slice([0..batch_size, t..(t + 1), 0..num_classes])
.gather(2, current_target_primes.clone().reshape([batch_size, 1, 1]))
.reshape([batch_size]))
.reshape([batch_size, 1, 1]),
);
}
.slice([0..batch_size, 0..target_with_blank_length, t..(t + 1)])
.swap_dims(1, 2),
);
}

let l1 = log_alphas
@@ -194,7 +182,7 @@ impl<B: Backend> CTCLoss<B> {
.clone()
.gather(
1,
(input_lengths.clone() - 1)
(input_lengths - 1)
.reshape([batch_size, 1, 1])
.repeat(2, target_with_blank_length),
)
@@ -219,23 +207,6 @@ impl<B: Backend> CTCLoss<B> {
}
}

fn get_target_primes(
targets_pad: Tensor<B, 2, Int>,
idx: usize,
blank: usize,
) -> Tensor<B, 1, Int> {
let device = targets_pad.device();
let [batch_size, _] = targets_pad.dims();

if idx % 2 == 0 {
Tensor::<B, 1, Int>::full_device([batch_size], blank as i32, &device)
} else {
targets_pad
.slice([0..batch_size, (idx / 2)..(idx / 2 + 1)])
.squeeze(1)
}
}

fn pad_target(
targets: Tensor<B, 1, Int>,
target_lengths: Tensor<B, 1, Int>,
@@ -313,6 +284,62 @@ impl<B: Backend> CTCLoss<B> {
}
}

fn pad<const D: usize, K, E, B>(
tensor: Tensor<B, D, K>,
pad_width: [(usize, usize); D],
fill_value: E,
) -> Tensor<B, D, K>
where
B: Backend,
K: Numeric<B>,
K::Elem: Element,
E: ElementConversion,
{
let device = tensor.device();
let origin_shape = tensor.dims();

let mut pad_shape = [0; D];
let mut assign_range = Vec::with_capacity(D);
for (idx, (&origin_len, (left_pad, right_pad))) in
origin_shape.iter().zip(pad_width).enumerate()
{
pad_shape[idx] = origin_len + left_pad + right_pad;
assign_range.push(left_pad..(left_pad + origin_len));
}

let padded = Tensor::<B, D, K>::full_device(pad_shape, fill_value, &device);

padded.slice_assign::<D>(assign_range.try_into().unwrap(), tensor)
}
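// Illustrative use (values assumed, not from the source): pad one leading
// column of NEG_INF onto the last dimension of a [batch, n] tensor:
// let padded = pad(tensor, [(0, 0), (1, 0)], NEG_INF); // shape [batch, n + 1]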

fn intersperse<B, K, E>(tensor: Tensor<B, 2, K>, value: E) -> Tensor<B, 2, K>
where
B: Backend,
K: Numeric<B>,
K::Elem: Element,
E: ElementConversion + Clone,
{
let device = tensor.device();
let mut shape = tensor.dims();
let constants: Tensor<B, 2, K> = Tensor::full_device(shape, value.clone(), &device);
shape[1] = shape[1] * 2;
let stack = Tensor::stack::<3>([tensor, constants].to_vec(), 2).reshape(shape);
pad(stack, [(0, 0), (1, 0)], value)
}
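// Illustrative example (values assumed): intersperse([[1, 2, 3]], 0) yields
// [[0, 1, 0, 2, 0, 3, 0]]: a separator before each element plus one trailing
// separator from the stack pairing, for a length of 2 * n + 1, matching the
// blank-interspersed CTC target layout.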

fn one_hot<B: Backend>(tensor: Tensor<B, 2, Int>, num_classes: usize) -> Tensor<B, 3> {
let device = tensor.device();
let shape = tensor.dims();

let labels: Tensor<B, 3, Int> = tensor.unsqueeze_dim(2).repeat(2, num_classes);
let indices = Tensor::<B, 1, Int>::arange_device(0..num_classes, &device)
.reshape([1, 1, num_classes])
.repeat(1, shape[1])
.repeat(0, shape[0]);

labels.equal(indices).float()
}
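// Illustrative example (values assumed): one_hot([[1, 0]], 3) yields
// [[[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]], shape [batch, seq, num_classes].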

#[cfg(test)]
mod test {
use burn_tensor::Data;
@@ -321,6 +348,14 @@ mod test {

use super::*;

#[test]
fn test_intersperse() {
let tensor = Tensor::<TestBackend, 1, Int>::arange(1..25).reshape([4, 6]);
let tensor = intersperse(tensor, 0);
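// expected: each row is interleaved with the 0 separator and bracketed by it,
// e.g. row 0 becomes [0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0] (length 13)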

println!("{}", tensor);
}

#[test]
fn test_ctc_loss() {
let input = Tensor::<TestBackend, 3>::from_data([[
