diff --git a/components/segmenter/src/complex/lstm/matrix.rs b/components/segmenter/src/complex/lstm/matrix.rs index 3e777d6e5a5..6b8314f73f6 100644 --- a/components/segmenter/src/complex/lstm/matrix.rs +++ b/components/segmenter/src/complex/lstm/matrix.rs @@ -211,6 +211,7 @@ impl<'a, const D: usize> MatrixBorrowedMut<'a, D> { Some(()) } + #[allow(dead_code)] // maybe needed for more complicated bies calculations /// Mutates this matrix by applying a softmax transformation. pub(super) fn softmax_transform(&mut self) { for v in self.data.iter_mut() { diff --git a/components/segmenter/src/complex/lstm/mod.rs b/components/segmenter/src/complex/lstm/mod.rs index c68458db431..c9f72cbd740 100644 --- a/components/segmenter/src/complex/lstm/mod.rs +++ b/components/segmenter/src/complex/lstm/mod.rs @@ -4,8 +4,6 @@ use crate::grapheme::GraphemeClusterSegmenter; use crate::provider::*; -use alloc::boxed::Box; -use alloc::string::String; use alloc::vec::Vec; use core::char::{decode_utf16, REPLACEMENT_CHARACTER}; use zerovec::{maps::ZeroMapBorrowed, ule::UnvalidatedStr}; @@ -17,9 +15,8 @@ use matrix::*; struct LstmSegmenterIterator<'s> { input: &'s str, - bies_str: Box<[Bies]>, - pos: usize, pos_utf8: usize, + bies: BiesIterator<'s>, } impl Iterator for LstmSegmenterIterator<'_> { @@ -28,29 +25,27 @@ impl Iterator for LstmSegmenterIterator<'_> { fn next(&mut self) -> Option { #[allow(clippy::indexing_slicing)] // pos_utf8 in range loop { - let bies = *self.bies_str.get(self.pos)?; + let is_e = self.bies.next()?; self.pos_utf8 += self.input[self.pos_utf8..].chars().next()?.len_utf8(); - self.pos += 1; - if bies == Bies::E || self.pos == self.bies_str.len() { + if is_e || self.bies.len() == 0 { return Some(self.pos_utf8); } } } } -struct LstmSegmenterIteratorUtf16 { - bies_str: Box<[Bies]>, +struct LstmSegmenterIteratorUtf16<'s> { + bies: BiesIterator<'s>, pos: usize, } -impl Iterator for LstmSegmenterIteratorUtf16 { +impl Iterator for LstmSegmenterIteratorUtf16<'_> { type Item = usize; fn next(&mut self) -> Option { loop { - let bies = *self.bies_str.get(self.pos)?; self.pos += 1; - if bies == Bies::E || self.pos == self.bies_str.len() { + if self.bies.next()? || self.bies.len() == 0 { return Some(self.pos); } } @@ -66,7 +61,8 @@ pub(super) struct LstmSegmenter<'l> { bw_w: MatrixZero<'l, 3>, bw_u: MatrixZero<'l, 3>, bw_b: MatrixZero<'l, 2>, - time_w: MatrixZero<'l, 3>, + timew_fw: MatrixZero<'l, 2>, + timew_bw: MatrixZero<'l, 2>, time_b: MatrixZero<'l, 1>, grapheme: Option<&'l RuleBreakDataV1<'l>>, } @@ -75,6 +71,11 @@ impl<'l> LstmSegmenter<'l> { /// Returns `Err` if grapheme data is required but not present pub(super) fn new(lstm: &'l LstmDataV1<'l>, grapheme: &'l RuleBreakDataV1<'l>) -> Self { let LstmDataV1::Float32(lstm) = lstm; + let time_w = MatrixZero::from(&lstm.time_w); + #[allow(clippy::unwrap_used)] // shape (2, 4, hunits) + let timew_fw = time_w.submatrix(0).unwrap(); + #[allow(clippy::unwrap_used)] // shape (2, 4, hunits) + let timew_bw = time_w.submatrix(1).unwrap(); Self { dic: lstm.dic.as_borrowed(), embedding: MatrixZero::from(&lstm.embedding), @@ -84,42 +85,21 @@ impl<'l> LstmSegmenter<'l> { bw_w: MatrixZero::from(&lstm.bw_w), bw_u: MatrixZero::from(&lstm.bw_u), bw_b: MatrixZero::from(&lstm.bw_b), - time_w: MatrixZero::from(&lstm.time_w), + timew_fw, + timew_bw, time_b: MatrixZero::from(&lstm.time_b), grapheme: (lstm.model == ModelType::GraphemeClusters).then(|| grapheme), } } /// Create an LSTM based break iterator for an `str` (a UTF-8 string). - pub(super) fn segment_str<'s>(&self, input: &'s str) -> impl Iterator + 's { - let lstm_output = self.produce_bies(input); - LstmSegmenterIterator { - input, - bies_str: lstm_output, - pos: 0, - pos_utf8: 0, - } + pub(super) fn segment_str(&'l self, input: &'l str) -> impl Iterator + 'l { + self.segment_str_p(input) } - /// Create an LSTM based break iterator for a UTF-16 string. - pub(super) fn segment_utf16(&self, input: &[u16]) -> impl Iterator { - let input: String = decode_utf16(input.iter().copied()) - .map(|r| r.unwrap_or(REPLACEMENT_CHARACTER)) - .collect(); - let lstm_output = self.produce_bies(&input); - LstmSegmenterIteratorUtf16 { - bies_str: lstm_output, - pos: 0, - } - } - - /// `produce_bies` is a function that gets a "clean" unsegmented string as its input and returns a BIES (B: Beginning, I: Inside, E: End, - /// S: Single) sequence for grapheme clusters. The boundaries of words can be found easily using this BIES sequence. - fn produce_bies(&self, input: &str) -> Box<[Bies]> { - // input_seq is a sequence of id numbers that represents grapheme clusters or code points in the input line. These ids are used later - // in the embedding layer of the model. - // Already checked that the name of the model is either "codepoints" or "graphclsut" - let input_seq: Vec = if let Some(grapheme) = self.grapheme { + // For unit testing as we cannot inspect the opaque type's bies + fn segment_str_p(&'l self, input: &'l str) -> LstmSegmenterIterator<'l> { + let input_seq = if let Some(grapheme) = self.grapheme { GraphemeClusterSegmenter::new_and_segment_str(input, grapheme) .collect::>() .windows(2) @@ -129,8 +109,14 @@ impl<'l> LstmSegmenter<'l> { } else { unreachable!() }; + let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) { + grapheme_cluster + } else { + return self.dic.len() as u16; + }; + self.dic - .get_copied(UnvalidatedStr::from_str(input.get(range).unwrap_or(input))) + .get_copied(UnvalidatedStr::from_str(grapheme_cluster)) .unwrap_or_else(|| self.dic.len() as u16) }) .collect() @@ -144,162 +130,191 @@ impl<'l> LstmSegmenter<'l> { }) .collect() }; + LstmSegmenterIterator { + input, + pos_utf8: 0, + bies: BiesIterator::new(self, input_seq), + } + } - /// `compute_hc1` implemens the evaluation of one LSTM layer. - fn compute_hc<'a>( - x_t: MatrixZero<'a, 1>, - mut h_tm1: MatrixBorrowedMut<'a, 1>, - mut c_tm1: MatrixBorrowedMut<'a, 1>, - w: MatrixZero<'a, 3>, - u: MatrixZero<'a, 3>, - b: MatrixZero<'a, 2>, - ) { - #[cfg(debug_assertions)] - { - let hunits = h_tm1.dim(); - let embedd_dim = x_t.dim(); - c_tm1.as_borrowed().debug_assert_dims([hunits]); - w.debug_assert_dims([4, hunits, embedd_dim]); - u.debug_assert_dims([4, hunits, hunits]); - b.debug_assert_dims([4, hunits]); - } - - let mut s_t = b.to_owned(); - - s_t.as_mut().add_dot_3d_2(x_t, w); - s_t.as_mut().add_dot_3d_1(h_tm1.as_borrowed(), u); - - #[allow(clippy::unwrap_used)] // first dimension is 4 - s_t.submatrix_mut::<1>(0).unwrap().sigmoid_transform(); - #[allow(clippy::unwrap_used)] // first dimension is 4 - s_t.submatrix_mut::<1>(1).unwrap().sigmoid_transform(); - #[allow(clippy::unwrap_used)] // first dimension is 4 - s_t.submatrix_mut::<1>(2).unwrap().tanh_transform(); - #[allow(clippy::unwrap_used)] // first dimension is 4 - s_t.submatrix_mut::<1>(3).unwrap().sigmoid_transform(); - - #[allow(clippy::unwrap_used)] // first dimension is 4 - c_tm1.convolve( - s_t.as_borrowed().submatrix(0).unwrap(), - s_t.as_borrowed().submatrix(2).unwrap(), - s_t.as_borrowed().submatrix(1).unwrap(), - ); + /// Create an LSTM based break iterator for a UTF-16 string. + pub(super) fn segment_utf16(&'l self, input: &[u16]) -> impl Iterator + 'l { + let input_seq = if let Some(grapheme) = self.grapheme { + GraphemeClusterSegmenter::new_and_segment_utf16(input, grapheme) + .collect::>() + .windows(2) + .map(|chunk| { + let range = if let [first, second, ..] = chunk { + *first..*second + } else { + unreachable!() + }; + let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) { + grapheme_cluster + } else { + return self.dic.len() as u16; + }; - #[allow(clippy::unwrap_used)] // first dimension is 4 - h_tm1.mul_tanh(s_t.as_borrowed().submatrix(3).unwrap(), c_tm1.as_borrowed()); + self.dic + .get_copied_by(|key| { + key.as_bytes().iter().copied().cmp( + decode_utf16(grapheme_cluster.iter().copied()).flat_map(|c| { + let mut buf = [0; 4]; + let len = c + .unwrap_or(REPLACEMENT_CHARACTER) + .encode_utf8(&mut buf) + .len(); + buf.into_iter().take(len) + }), + ) + }) + .unwrap_or_else(|| self.dic.len() as u16) + }) + .collect() + } else { + decode_utf16(input.iter().copied()) + .map(|c| c.unwrap_or(REPLACEMENT_CHARACTER)) + .map(|c| { + self.dic + .get_copied(UnvalidatedStr::from_str(c.encode_utf8(&mut [0; 4]))) + .unwrap_or_else(|| self.dic.len() as u16) + }) + .collect() + }; + LstmSegmenterIteratorUtf16 { + bies: BiesIterator::new(self, input_seq), + pos: 0, } + } +} - let hunits = self.fw_u.dim().1; +struct BiesIterator<'l> { + segmenter: &'l LstmSegmenter<'l>, + input_seq: core::iter::Enumerate>, + h_bw: MatrixOwned<2>, + curr_fw: MatrixOwned<1>, + c_fw: MatrixOwned<1>, +} - // Forward LSTM - let mut c_fw = MatrixOwned::<1>::new_zero([hunits]); - let mut all_h_fw = MatrixOwned::<2>::new_zero([input_seq.len(), hunits]); - for (i, &g_id) in input_seq.iter().enumerate() { - #[allow(clippy::unwrap_used)] - // embedding has shape (dict.len() + 1, hunit), g_id is at most dict.len() - let x_t = self.embedding.submatrix::<1>(g_id as usize).unwrap(); - if i > 0 { - all_h_fw.as_mut().copy_submatrix::<1>(i - 1, i); - } - #[allow(clippy::unwrap_used)] - compute_hc( - x_t, - all_h_fw.submatrix_mut(i).unwrap(), // shape (input_seq.len(), hunits) - c_fw.as_mut(), - self.fw_w, - self.fw_u, - self.fw_b, - ); - } +impl<'l> BiesIterator<'l> { + // input_seq is a sequence of id numbers that represents grapheme clusters or code points in the input line. These ids are used later + // in the embedding layer of the model. + fn new(segmenter: &'l LstmSegmenter<'l>, input_seq: Vec) -> Self { + let hunits = segmenter.fw_u.dim().1; // Backward LSTM let mut c_bw = MatrixOwned::<1>::new_zero([hunits]); - let mut all_h_bw = MatrixOwned::<2>::new_zero([input_seq.len(), hunits]); + let mut h_bw = MatrixOwned::<2>::new_zero([input_seq.len(), hunits]); for (i, &g_id) in input_seq.iter().enumerate().rev() { - #[allow(clippy::unwrap_used)] - // embedding has shape (dict.len() + 1, hunit), g_id is at most dict.len() - let x_t = self.embedding.submatrix::<1>(g_id as usize).unwrap(); if i + 1 < input_seq.len() { - all_h_bw.as_mut().copy_submatrix::<1>(i + 1, i); + h_bw.as_mut().copy_submatrix::<1>(i + 1, i); } #[allow(clippy::unwrap_used)] compute_hc( - x_t, - all_h_bw.submatrix_mut(i).unwrap(), // shape (input_seq.len(), hunits) + segmenter.embedding.submatrix::<1>(g_id as usize).unwrap(), // shape (dict.len() + 1, hunit), g_id is at most dict.len() + h_bw.submatrix_mut(i).unwrap(), // shape (input_seq.len(), hunits) c_bw.as_mut(), - self.bw_w, - self.bw_u, - self.bw_b, + segmenter.bw_w, + segmenter.bw_u, + segmenter.bw_b, ); } - #[allow(clippy::unwrap_used)] // shape (2, 4, hunits) - let timew_fw = self.time_w.submatrix(0).unwrap(); - #[allow(clippy::unwrap_used)] // shape (2, 4, hunits) - let timew_bw = self.time_w.submatrix(1).unwrap(); - - // Combining forward and backward LSTMs using the dense time-distributed layer - (0..input_seq.len()) - .map(|i| { - #[allow(clippy::unwrap_used)] // shape (input_seq.len(), hunits) - let curr_fw = all_h_fw.submatrix::<1>(i).unwrap(); - #[allow(clippy::unwrap_used)] // shape (input_seq.len(), hunits) - let curr_bw = all_h_bw.submatrix::<1>(i).unwrap(); - let mut weights = [0.0; 4]; - let mut curr_est = MatrixBorrowedMut { - data: &mut weights, - dims: [4], - }; - curr_est.add_dot_2d(curr_fw, timew_fw); - curr_est.add_dot_2d(curr_bw, timew_bw); - #[allow(clippy::unwrap_used)] // both shape (4) - curr_est.add(self.time_b).unwrap(); - curr_est.softmax_transform(); - Bies::from_probabilities(weights) - }) - .collect() + Self { + input_seq: input_seq.into_iter().enumerate(), + h_bw, + c_fw: MatrixOwned::<1>::new_zero([hunits]), + curr_fw: MatrixOwned::<1>::new_zero([hunits]), + segmenter, + } } } -// TODO(#421): Use common BIES normalizer code -#[derive(Debug, PartialEq, Copy, Clone)] -enum Bies { - B, - I, - E, - S, +impl ExactSizeIterator for BiesIterator<'_> { + fn len(&self) -> usize { + self.input_seq.len() + } } -impl Bies { - /// Returns the value the largest probability - fn from_probabilities(arr: [f32; 4]) -> Bies { - let [b, i, e, s] = arr; - let mut result = Bies::B; - let mut max = b; - if i > max { - result = Bies::I; - max = i; - } - if e > max { - result = Bies::E; - max = e; - } - if s > max { - result = Bies::S; - // max = s; - } - result +impl Iterator for BiesIterator<'_> { + type Item = bool; + + fn next(&mut self) -> Option { + let (i, g_id) = self.input_seq.next()?; + + #[allow(clippy::unwrap_used)] + compute_hc( + self.segmenter + .embedding + .submatrix::<1>(g_id as usize) + .unwrap(), // shape (dict.len() + 1, hunit), g_id is at most dict.len() + self.curr_fw.as_mut(), + self.c_fw.as_mut(), + self.segmenter.fw_w, + self.segmenter.fw_u, + self.segmenter.fw_b, + ); + + #[allow(clippy::unwrap_used)] // shape (input_seq.len(), hunits) + let curr_bw = self.h_bw.submatrix::<1>(i).unwrap(); + let mut weights = [0.0; 4]; + let mut curr_est = MatrixBorrowedMut { + data: &mut weights, + dims: [4], + }; + curr_est.add_dot_2d(self.curr_fw.as_borrowed(), self.segmenter.timew_fw); + curr_est.add_dot_2d(curr_bw, self.segmenter.timew_bw); + #[allow(clippy::unwrap_used)] // both shape (4) + curr_est.add(self.segmenter.time_b).unwrap(); + // For correct BIES weight calculation we'd now have to apply softmax, however + // we're only doing a naive argmax, so a monotonic function doesn't make a difference. + + Some(weights[2] > weights[0] && weights[2] > weights[1] && weights[2] > weights[3]) } +} - #[cfg(test)] - fn as_char(&self) -> char { - match self { - Bies::B => 'b', - Bies::I => 'i', - Bies::E => 'e', - Bies::S => 's', - } +/// `compute_hc1` implemens the evaluation of one LSTM layer. +fn compute_hc<'a>( + x_t: MatrixZero<'a, 1>, + mut h_tm1: MatrixBorrowedMut<'a, 1>, + mut c_tm1: MatrixBorrowedMut<'a, 1>, + w: MatrixZero<'a, 3>, + u: MatrixZero<'a, 3>, + b: MatrixZero<'a, 2>, +) { + #[cfg(debug_assertions)] + { + let hunits = h_tm1.dim(); + let embedd_dim = x_t.dim(); + c_tm1.as_borrowed().debug_assert_dims([hunits]); + w.debug_assert_dims([4, hunits, embedd_dim]); + u.debug_assert_dims([4, hunits, hunits]); + b.debug_assert_dims([4, hunits]); } + + let mut s_t = b.to_owned(); + + s_t.as_mut().add_dot_3d_2(x_t, w); + s_t.as_mut().add_dot_3d_1(h_tm1.as_borrowed(), u); + + #[allow(clippy::unwrap_used)] // first dimension is 4 + s_t.submatrix_mut::<1>(0).unwrap().sigmoid_transform(); + #[allow(clippy::unwrap_used)] // first dimension is 4 + s_t.submatrix_mut::<1>(1).unwrap().sigmoid_transform(); + #[allow(clippy::unwrap_used)] // first dimension is 4 + s_t.submatrix_mut::<1>(2).unwrap().tanh_transform(); + #[allow(clippy::unwrap_used)] // first dimension is 4 + s_t.submatrix_mut::<1>(3).unwrap().sigmoid_transform(); + + #[allow(clippy::unwrap_used)] // first dimension is 4 + c_tm1.convolve( + s_t.as_borrowed().submatrix(0).unwrap(), + s_t.as_borrowed().submatrix(2).unwrap(), + s_t.as_borrowed().submatrix(1).unwrap(), + ); + + #[allow(clippy::unwrap_used)] // first dimension is 4 + h_tm1.mul_tanh(s_t.as_borrowed().submatrix(3).unwrap(), c_tm1.as_borrowed()); } #[cfg(test)] @@ -371,16 +386,20 @@ mod tests { }; // Testing - for test_case in test_text.data.testcases { - let lstm_output = lstm.produce_bies(&test_case.unseg); + for test_case in &test_text.data.testcases { + let lstm_output = lstm + .segment_str_p(&test_case.unseg) + .bies + .map(|is_e| if is_e { 'e' } else { '?' }) + .collect::(); println!("Test case : {}", test_case.unseg); println!("Expected bies : {}", test_case.expected_bies); - println!("Estimated bies : {lstm_output:?}"); + println!("Estimated bies : {lstm_output}"); println!("True bies : {}", test_case.true_bies); println!("****************************************************"); assert_eq!( - test_case.expected_bies, - lstm_output.iter().map(Bies::as_char).collect::() + test_case.expected_bies.replace(['b', 'i', 's'], "?"), + lstm_output ); } } diff --git a/provider/datagen/src/transform/segmenter/lstm.rs b/provider/datagen/src/transform/segmenter/lstm.rs index d6f6de7cf44..78fd54f0835 100644 --- a/provider/datagen/src/transform/segmenter/lstm.rs +++ b/provider/datagen/src/transform/segmenter/lstm.rs @@ -235,7 +235,6 @@ mod tests { #[test] fn thai_word_break_with_grapheme_model() { - const TEST_STR: &str = "ภาษาไทยภาษาไทย"; let provider = crate::DatagenProvider::for_test(); let raw_data = provider .source @@ -249,12 +248,16 @@ mod tests { ), provider, ); + let segmenter = LineSegmenter::try_new_lstm_with_any_provider(&provider).unwrap(); + + const TEST_STR: &str = "ภาษาไทยภาษาไทย"; + let utf16: Vec = TEST_STR.encode_utf16().collect(); + let breaks: Vec = segmenter.segment_str(TEST_STR).collect(); - assert_eq!( - breaks, - [0, 6, 12, 21, 27, 33, TEST_STR.len()], - "Thai test with grapheme model" - ); + assert_eq!(breaks, [0, 6, 12, 21, 27, 33, TEST_STR.len()],); + + let breaks: Vec = segmenter.segment_utf16(&utf16).collect(); + assert_eq!(breaks, [0, 2, 4, 7, 9, 11, utf16.len()],); } } diff --git a/utils/zerovec/src/map/borrowed.rs b/utils/zerovec/src/map/borrowed.rs index bc93ee49795..b6307990d1f 100644 --- a/utils/zerovec/src/map/borrowed.rs +++ b/utils/zerovec/src/map/borrowed.rs @@ -252,6 +252,12 @@ where self.values.get(index) } + /// For cases when `V` is fixed-size, obtain a direct copy of `V` instead of `V::ULE` + pub fn get_copied_by(&self, predicate: impl FnMut(&K) -> Ordering) -> Option { + let index = self.keys.zvl_binary_search_by(predicate).ok()?; + self.values.get(index) + } + /// Similar to [`Self::iter()`] except it returns a direct copy of the values instead of references /// to `V::ULE`, in cases when `V` is fixed-size pub fn iter_copied_values<'b>(