From e8de12b805c0e7ec61b86dc6a54ae206bf186dec Mon Sep 17 00:00:00 2001 From: Robert Bastian Date: Wed, 19 Apr 2023 14:52:04 +0200 Subject: [PATCH 1/8] x --- components/segmenter/src/complex.rs | 366 ++++++++++++---------------- 1 file changed, 162 insertions(+), 204 deletions(-) diff --git a/components/segmenter/src/complex.rs b/components/segmenter/src/complex.rs index d17694c52f5..9a48c2bed9d 100644 --- a/components/segmenter/src/complex.rs +++ b/components/segmenter/src/complex.rs @@ -2,6 +2,8 @@ // called LICENSE at the top level of the ICU4X source tree // (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). +#[cfg(feature = "lstm")] +use crate::lstm::LstmSegmenter; use crate::dictionary::DictionarySegmenter; use crate::language::*; use crate::provider::*; @@ -9,45 +11,58 @@ use alloc::vec::Vec; use icu_locid::{locale, Locale}; use icu_provider::prelude::*; +#[cfg(not(feature = "lstm"))] +type DictOrLstm = Result, core::convert::Infallible>; +#[cfg(not(feature = "lstm"))] +type DictOrLstmBorrowed<'a> = + Result<&'a DataPayload, &'a core::convert::Infallible>; + +#[cfg(feature = "lstm")] +type DictOrLstm = + Result, DataPayload>; +#[cfg(feature = "lstm")] +type DictOrLstmBorrowed<'a> = + Result<&'a DataPayload, &'a DataPayload>; + #[derive(Debug)] pub(crate) struct ComplexPayloads { grapheme: DataPayload, - burmese_lstm: Option>, - khmer_lstm: Option>, - lao_lstm: Option>, - thai_lstm: Option>, - burmese_dict: Option>, - khmer_dict: Option>, - lao_dict: Option>, - thai_dict: Option>, - cj_dict: Option>, + my: Option, + km: Option, + lo: Option, + th: Option, + ja: Option>, } impl ComplexPayloads { - fn select_lstm(&self, language: Language) -> Option<&DataPayload> { - match language { - Language::Burmese => self.burmese_lstm.as_ref(), - Language::Khmer => self.khmer_lstm.as_ref(), - Language::Lao => self.lao_lstm.as_ref(), - Language::Thai => self.thai_lstm.as_ref(), - Language::ChineseOrJapanese | Language::Unknown => None, - } - } - - fn select_dict( - &self, - language: Language, - ) -> Option<&DataPayload> { + fn select(&self, language: Language) -> Option { + const ERR: DataError = DataError::custom("No segmentation model for language"); match language { - Language::Burmese => self.burmese_dict.as_ref(), - Language::Khmer => self.khmer_dict.as_ref(), - Language::Lao => self.lao_dict.as_ref(), - Language::Thai => self.thai_dict.as_ref(), - Language::ChineseOrJapanese => self.cj_dict.as_ref(), + Language::Burmese => self.my.as_ref().map(Result::as_ref).or_else(|| { + ERR.with_display_context("my"); + None + }), + Language::Khmer => self.km.as_ref().map(Result::as_ref).or_else(|| { + ERR.with_display_context("km"); + None + }), + Language::Lao => self.lo.as_ref().map(Result::as_ref).or_else(|| { + ERR.with_display_context("lo"); + None + }), + Language::Thai => self.th.as_ref().map(Result::as_ref).or_else(|| { + ERR.with_display_context("th"); + None + }), + Language::ChineseOrJapanese => self.ja.as_ref().map(Ok).or_else(|| { + ERR.with_display_context("ja"); + None + }), Language::Unknown => None, } } + #[cfg(feature = "lstm")] pub(crate) fn try_new_lstm(provider: &D) -> Result where D: DataProvider @@ -56,19 +71,19 @@ impl ComplexPayloads { { Ok(Self { grapheme: provider.load(Default::default())?.take_payload()?, - burmese_lstm: try_load::(provider, locale!("my"))? - .map(DataPayload::cast), - khmer_lstm: try_load::(provider, locale!("km"))? - .map(DataPayload::cast), - lao_lstm: try_load::(provider, locale!("lo"))? - .map(DataPayload::cast), - thai_lstm: try_load::(provider, locale!("th"))? - .map(DataPayload::cast), - burmese_dict: None, - khmer_dict: None, - lao_dict: None, - thai_dict: None, - cj_dict: None, + my: try_load::(provider, locale!("my"))? + .map(DataPayload::cast) + .map(Err), + km: try_load::(provider, locale!("km"))? + .map(DataPayload::cast) + .map(Err), + lo: try_load::(provider, locale!("lo"))? + .map(DataPayload::cast) + .map(Err), + th: try_load::(provider, locale!("th"))? + .map(DataPayload::cast) + .map(Err), + ja: None, }) } @@ -81,31 +96,19 @@ impl ComplexPayloads { { Ok(Self { grapheme: provider.load(Default::default())?.take_payload()?, - burmese_lstm: None, - khmer_lstm: None, - lao_lstm: None, - thai_lstm: None, - burmese_dict: try_load::( - provider, - locale!("my"), - )? - .map(DataPayload::cast), - khmer_dict: try_load::( - provider, - locale!("km"), - )? - .map(DataPayload::cast), - lao_dict: try_load::( - provider, - locale!("lo"), - )? - .map(DataPayload::cast), - thai_dict: try_load::( - provider, - locale!("th"), - )? - .map(DataPayload::cast), - cj_dict: try_load::(provider, locale!("ja"))? + my: try_load::(provider, locale!("my"))? + .map(DataPayload::cast) + .map(Ok), + km: try_load::(provider, locale!("km"))? + .map(DataPayload::cast) + .map(Ok), + lo: try_load::(provider, locale!("lo"))? + .map(DataPayload::cast) + .map(Ok), + th: try_load::(provider, locale!("th"))? + .map(DataPayload::cast) + .map(Ok), + ja: try_load::(provider, locale!("ja"))? .map(DataPayload::cast), }) } @@ -120,19 +123,19 @@ impl ComplexPayloads { { Ok(Self { grapheme: provider.load(Default::default())?.take_payload()?, - burmese_lstm: try_load::(provider, locale!("my"))? - .map(DataPayload::cast), - khmer_lstm: try_load::(provider, locale!("km"))? - .map(DataPayload::cast), - lao_lstm: try_load::(provider, locale!("lo"))? - .map(DataPayload::cast), - thai_lstm: try_load::(provider, locale!("th"))? - .map(DataPayload::cast), - burmese_dict: None, - khmer_dict: None, - lao_dict: None, - thai_dict: None, - cj_dict: try_load::(provider, locale!("ja"))? + my: try_load::(provider, locale!("my"))? + .map(DataPayload::cast) + .map(Err), + km: try_load::(provider, locale!("km"))? + .map(DataPayload::cast) + .map(Err), + lo: try_load::(provider, locale!("lo"))? + .map(DataPayload::cast) + .map(Err), + th: try_load::(provider, locale!("th"))? + .map(DataPayload::cast) + .map(Err), + ja: try_load::(provider, locale!("ja"))? .map(DataPayload::cast), }) } @@ -145,31 +148,19 @@ impl ComplexPayloads { { Ok(Self { grapheme: provider.load(Default::default())?.take_payload()?, - burmese_lstm: None, - khmer_lstm: None, - lao_lstm: None, - thai_lstm: None, - burmese_dict: try_load::( - provider, - locale!("my"), - )? - .map(DataPayload::cast), - khmer_dict: try_load::( - provider, - locale!("km"), - )? - .map(DataPayload::cast), - lao_dict: try_load::( - provider, - locale!("lo"), - )? - .map(DataPayload::cast), - thai_dict: try_load::( - provider, - locale!("th"), - )? - .map(DataPayload::cast), - cj_dict: None, + my: try_load::(provider, locale!("my"))? + .map(DataPayload::cast) + .map(Ok), + km: try_load::(provider, locale!("km"))? + .map(DataPayload::cast) + .map(Ok), + lo: try_load::(provider, locale!("lo"))? + .map(DataPayload::cast) + .map(Ok), + th: try_load::(provider, locale!("th"))? + .map(DataPayload::cast) + .map(Ok), + ja: None, }) } } @@ -196,88 +187,68 @@ fn try_load + ?Sized>( } /// Return UTF-16 segment offset array using dictionary or lstm segmenter. -#[allow(unused_variables)] pub(crate) fn complex_language_segment_utf16( payloads: &ComplexPayloads, input: &[u16], ) -> Vec { - let mut result: Vec = Vec::new(); - let lang_iter = LanguageIteratorUtf16::new(input); + let mut result = Vec::new(); let mut offset = 0; - for (str_per_lang, lang) in lang_iter { - if lang == Language::Unknown { - offset += str_per_lang.len(); - result.push(offset); - } else if let Some(lstm) = payloads.select_lstm(lang) { + for (slice, lang) in LanguageIteratorUtf16::new(input) { + match payloads.select(lang) { + Some(Ok(dict)) => { + result.extend( + DictionarySegmenter::new(dict, &payloads.grapheme) + .segment_utf16(slice) + .map(|n| offset + n), + ); + } #[cfg(feature = "lstm")] - { - let segmenter = crate::lstm::LstmSegmenter::new(lstm, &payloads.grapheme); - let breaks = segmenter.segment_utf16(str_per_lang); - result.extend(breaks.map(|n| offset + n)); - offset += str_per_lang.len(); + Some(Err(lstm)) => { + result.extend( + LstmSegmenter::new(lstm, &payloads.grapheme) + .segment_utf16(slice) + .map(|n| offset + n), + ); + } + #[cfg(not(feature = "lstm"))] + Some(Err(_infallible)) => {} // should be refutable + None => { + result.push(offset + slice.len()); } - } else if let Some(dict) = payloads.select_dict(lang) { - let segmenter = DictionarySegmenter::new(dict, &payloads.grapheme); - let breaks = segmenter.segment_utf16(str_per_lang); - result.extend(breaks.map(|n| offset + n)); - offset += str_per_lang.len(); - } else { - // Create error for logging - DataError::custom("No segmentation model for language").with_display_context( - match lang { - Language::Thai => "th", - Language::Lao => "lo", - Language::Burmese => "my", - Language::Khmer => "km", - Language::ChineseOrJapanese => "ja", - Language::Unknown => unreachable!(), - }, - ); - offset += str_per_lang.len(); - result.push(offset); } + offset += slice.len(); } result } /// Return UTF-8 segment offset array using dictionary or lstm segmenter. -#[allow(unused_variables)] pub(crate) fn complex_language_segment_str(payloads: &ComplexPayloads, input: &str) -> Vec { - let mut result: Vec = Vec::new(); - let lang_iter = LanguageIterator::new(input); + let mut result = Vec::new(); let mut offset = 0; - for (str_per_lang, lang) in lang_iter { - if lang == Language::Unknown { - offset += str_per_lang.len(); - result.push(offset); - } else if let Some(lstm) = payloads.select_lstm(lang) { + for (slice, lang) in LanguageIterator::new(input) { + match payloads.select(lang) { + Some(Ok(dict)) => { + result.extend( + DictionarySegmenter::new(dict, &payloads.grapheme) + .segment_str(slice) + .map(|n| offset + n), + ); + } #[cfg(feature = "lstm")] - { - let segmenter = crate::lstm::LstmSegmenter::new(lstm, &payloads.grapheme); - let breaks = segmenter.segment_str(str_per_lang); - result.extend(breaks.map(|n| offset + n)); - offset += str_per_lang.len(); + Some(Err(lstm)) => { + result.extend( + LstmSegmenter::new(lstm, &payloads.grapheme) + .segment_str(slice) + .map(|n| offset + n), + ); + } + #[cfg(not(feature = "lstm"))] + Some(Err(_infallible)) => {} // should be refutable + None => { + result.push(offset + slice.len()); } - } else if let Some(dict) = payloads.select_dict(lang) { - let segmenter = DictionarySegmenter::new(dict, &payloads.grapheme); - let breaks = segmenter.segment_str(str_per_lang); - result.extend(breaks.map(|n| offset + n)); - offset += str_per_lang.len(); - } else { - // Create error for logging - DataError::custom("No segmentation model for language").with_display_context( - match lang { - Language::Thai => "th", - Language::Lao => "lo", - Language::Burmese => "my", - Language::Khmer => "km", - Language::ChineseOrJapanese => "ja", - Language::Unknown => unreachable!(), - }, - ); - offset += str_per_lang.len(); - result.push(offset); } + offset += slice.len(); } result } @@ -286,46 +257,33 @@ pub(crate) fn complex_language_segment_str(payloads: &ComplexPayloads, input: &s #[cfg(feature = "serde")] mod tests { use super::*; - use icu_locid::locale; #[test] fn thai_word_break() { const TEST_STR: &str = "ภาษาไทยภาษาไทย"; - let grapheme = try_load::( - &icu_testdata::buffer().as_deserializing(), - Locale::UND, - ) - .unwrap() - .unwrap(); - let dict = try_load::( - &icu_testdata::buffer().as_deserializing(), - locale!("th"), - ) - .unwrap() - .unwrap(); - let lstm = try_load::( - &icu_testdata::buffer().as_deserializing(), - locale!("th"), - ) - .unwrap() - .unwrap(); - let payloads = ComplexPayloads { - grapheme, - burmese_lstm: None, - khmer_lstm: None, - lao_lstm: None, - thai_lstm: Some(lstm.cast()), - burmese_dict: None, - khmer_dict: None, - lao_dict: None, - thai_dict: Some(dict.cast()), - cj_dict: None, - }; - let breaks = complex_language_segment_str(&payloads, TEST_STR); - assert_eq!(breaks, [12, 21, 33, 42], "Thai test by UTF-8"); - let utf16: Vec = TEST_STR.encode_utf16().collect(); - let breaks = complex_language_segment_utf16(&payloads, &utf16); - assert_eq!(breaks, [4, 7, 11, 14], "Thai test by UTF-16"); + + let lstm = + ComplexPayloads::try_new_lstm(&icu_testdata::buffer().as_deserializing()).unwrap(); + let dict = + ComplexPayloads::try_new_dict(&icu_testdata::buffer().as_deserializing()).unwrap(); + + assert_eq!( + complex_language_segment_str(&lstm, TEST_STR), + [12, 21, 33, 42] + ); + assert_eq!( + complex_language_segment_utf16(&lstm, &utf16), + [4, 7, 11, 14] + ); + + assert_eq!( + complex_language_segment_str(&dict, TEST_STR), + [12, 21, 33, 42] + ); + assert_eq!( + complex_language_segment_utf16(&dict, &utf16), + [4, 7, 11, 14] + ); } } From dd061efcfe6fcaa7f0475c904625331b14073a46 Mon Sep 17 00:00:00 2001 From: Robert Bastian Date: Wed, 19 Apr 2023 14:51:31 +0200 Subject: [PATCH 2/8] vis --- .../segmenter/src/{ => complex}/dictionary.rs | 7 ++--- .../segmenter/src/{ => complex}/language.rs | 0 .../lstm/matrix.rs} | 29 +++++++++++++++++-- .../src/{lstm.rs => complex/lstm/mod.rs} | 22 +++++++------- .../src/{complex.rs => complex/mod.rs} | 12 +++++--- components/segmenter/src/lib.rs | 7 ----- components/segmenter/src/provider/lstm.rs | 12 ++------ 7 files changed, 50 insertions(+), 39 deletions(-) rename components/segmenter/src/{ => complex}/dictionary.rs (98%) rename components/segmenter/src/{ => complex}/language.rs (100%) rename components/segmenter/src/{math_helper.rs => complex/lstm/matrix.rs} (95%) rename components/segmenter/src/{lstm.rs => complex/lstm/mod.rs} (96%) rename components/segmenter/src/{complex.rs => complex/mod.rs} (99%) diff --git a/components/segmenter/src/dictionary.rs b/components/segmenter/src/complex/dictionary.rs similarity index 98% rename from components/segmenter/src/dictionary.rs rename to components/segmenter/src/complex/dictionary.rs index ac0ccb19f43..32e1e776cfa 100644 --- a/components/segmenter/src/dictionary.rs +++ b/components/segmenter/src/complex/dictionary.rs @@ -186,11 +186,8 @@ impl<'l> DictionarySegmenter<'l> { #[cfg(test)] #[cfg(feature = "serde")] mod tests { - use crate::{ - dictionary::DictionarySegmenter, provider::DictionaryForWordOnlyAutoV1Marker, - LineSegmenter, WordSegmenter, - }; - use icu_provider::prelude::*; + use super::*; + use crate::{provider::DictionaryForWordOnlyAutoV1Marker, LineSegmenter, WordSegmenter}; use icu_provider_adapters::fork::ForkByKeyProvider; use icu_provider_fs::FsDataProvider; use std::path::PathBuf; diff --git a/components/segmenter/src/language.rs b/components/segmenter/src/complex/language.rs similarity index 100% rename from components/segmenter/src/language.rs rename to components/segmenter/src/complex/language.rs diff --git a/components/segmenter/src/math_helper.rs b/components/segmenter/src/complex/lstm/matrix.rs similarity index 95% rename from components/segmenter/src/math_helper.rs rename to components/segmenter/src/complex/lstm/matrix.rs index 545dca8d14c..093c056e8fd 100644 --- a/components/segmenter/src/math_helper.rs +++ b/components/segmenter/src/complex/lstm/matrix.rs @@ -397,11 +397,34 @@ pub struct MatrixZero<'a, const D: usize> { dims: [usize; D], } -impl<'a, const D: usize> MatrixZero<'a, D> { - pub fn from_parts_unchecked(data: &'a ZeroSlice, dims: [usize; D]) -> Self { - Self { data, dims } +impl<'a> From<&'a crate::provider::LstmMatrix1<'a>> for MatrixZero<'a, 1> { + fn from(other: &'a crate::provider::LstmMatrix1<'a>) -> Self { + Self { + data: &other.data, + dims: other.dims.map(|x| x as usize), + } + } +} + +impl<'a> From<&'a crate::provider::LstmMatrix2<'a>> for MatrixZero<'a, 2> { + fn from(other: &'a crate::provider::LstmMatrix2<'a>) -> Self { + Self { + data: &other.data, + dims: other.dims.map(|x| x as usize), + } } +} + +impl<'a> From<&'a crate::provider::LstmMatrix3<'a>> for MatrixZero<'a, 3> { + fn from(other: &'a crate::provider::LstmMatrix3<'a>) -> Self { + Self { + data: &other.data, + dims: other.dims.map(|x| x as usize), + } + } +} +impl<'a, const D: usize> MatrixZero<'a, D> { #[allow(clippy::wrong_self_convention)] // same convention as slice::to_vec pub fn to_owned(&self) -> MatrixOwned { MatrixOwned { diff --git a/components/segmenter/src/lstm.rs b/components/segmenter/src/complex/lstm/mod.rs similarity index 96% rename from components/segmenter/src/lstm.rs rename to components/segmenter/src/complex/lstm/mod.rs index 98e49a2dd27..37128160194 100644 --- a/components/segmenter/src/lstm.rs +++ b/components/segmenter/src/complex/lstm/mod.rs @@ -3,7 +3,6 @@ // (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). use crate::grapheme::GraphemeClusterSegmenter; -use crate::math_helper::{MatrixBorrowedMut, MatrixOwned, MatrixZero}; use crate::provider::*; use alloc::boxed::Box; use alloc::string::String; @@ -12,6 +11,9 @@ use core::char::{decode_utf16, REPLACEMENT_CHARACTER}; use icu_provider::DataPayload; use zerovec::{maps::ZeroMapBorrowed, ule::UnvalidatedStr}; +mod matrix; +use matrix::*; + // A word break iterator using LSTM model. Input string have to be same language. pub struct LstmSegmenterIterator<'s> { @@ -79,15 +81,15 @@ impl<'l> LstmSegmenter<'l> { let LstmDataV1::Float32(lstm) = lstm.get(); Self { dic: lstm.dic.as_borrowed(), - embedding: lstm.embedding.as_matrix_zero(), - fw_w: lstm.fw_w.as_matrix_zero(), - fw_u: lstm.fw_u.as_matrix_zero(), - fw_b: lstm.fw_b.as_matrix_zero(), - bw_w: lstm.bw_w.as_matrix_zero(), - bw_u: lstm.bw_u.as_matrix_zero(), - bw_b: lstm.bw_b.as_matrix_zero(), - time_w: lstm.time_w.as_matrix_zero(), - time_b: lstm.time_b.as_matrix_zero(), + embedding: MatrixZero::from(&lstm.embedding), + fw_w: MatrixZero::from(&lstm.fw_w), + fw_u: MatrixZero::from(&lstm.fw_u), + fw_b: MatrixZero::from(&lstm.fw_b), + bw_w: MatrixZero::from(&lstm.bw_w), + bw_u: MatrixZero::from(&lstm.bw_u), + bw_b: MatrixZero::from(&lstm.bw_b), + time_w: MatrixZero::from(&lstm.time_w), + time_b: MatrixZero::from(&lstm.time_b), grapheme: (lstm.model == ModelType::GraphemeClusters).then(|| grapheme.get()), } } diff --git a/components/segmenter/src/complex.rs b/components/segmenter/src/complex/mod.rs similarity index 99% rename from components/segmenter/src/complex.rs rename to components/segmenter/src/complex/mod.rs index 9a48c2bed9d..b1a730e6beb 100644 --- a/components/segmenter/src/complex.rs +++ b/components/segmenter/src/complex/mod.rs @@ -2,15 +2,19 @@ // called LICENSE at the top level of the ICU4X source tree // (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). -#[cfg(feature = "lstm")] -use crate::lstm::LstmSegmenter; -use crate::dictionary::DictionarySegmenter; -use crate::language::*; use crate::provider::*; use alloc::vec::Vec; use icu_locid::{locale, Locale}; use icu_provider::prelude::*; +mod dictionary; +use dictionary::*; +mod language; +use language::*; +#[cfg(feature = "lstm")] +mod lstm; +use lstm::*; + #[cfg(not(feature = "lstm"))] type DictOrLstm = Result, core::convert::Infallible>; #[cfg(not(feature = "lstm"))] diff --git a/components/segmenter/src/lib.rs b/components/segmenter/src/lib.rs index c4ca8a2e7b2..b809ed72dae 100644 --- a/components/segmenter/src/lib.rs +++ b/components/segmenter/src/lib.rs @@ -126,11 +126,9 @@ extern crate alloc; mod complex; -mod dictionary; mod error; mod indices; mod iterator_helpers; -mod language; mod rule_segmenter; mod grapheme; @@ -144,11 +142,6 @@ pub mod provider; #[doc(hidden)] pub mod symbols; -#[cfg(feature = "lstm")] -mod lstm; -#[cfg(feature = "lstm")] -mod math_helper; - // Main Segmenter and BreakIterator public types pub use crate::grapheme::GraphemeClusterBreakIterator; pub use crate::grapheme::GraphemeClusterSegmenter; diff --git a/components/segmenter/src/provider/lstm.rs b/components/segmenter/src/provider/lstm.rs index 61fc0d3e11f..6a85680e4ca 100644 --- a/components/segmenter/src/provider/lstm.rs +++ b/components/segmenter/src/provider/lstm.rs @@ -26,9 +26,9 @@ macro_rules! lstm_matrix { pub struct $name<'data> { // Invariant: dims.product() == data.len() #[allow(missing_docs)] - dims: [u16; $generic], + pub(crate) dims: [u16; $generic], #[allow(missing_docs)] - data: ZeroVec<'data, f32>, + pub(crate) data: ZeroVec<'data, f32>, } impl<'data> $name<'data> { @@ -52,14 +52,6 @@ macro_rules! lstm_matrix { ) -> Self { Self { dims, data } } - - #[cfg(feature = "lstm")] - pub(crate) fn as_matrix_zero(&self) -> crate::math_helper::MatrixZero<$generic> { - crate::math_helper::MatrixZero::from_parts_unchecked( - &self.data, - self.dims.map(|x| x as usize), - ) - } } #[cfg(feature = "serde")] From bb1067bd6dcda20a0bc72aaf899df2a2f16a7e7e Mon Sep 17 00:00:00 2001 From: Robert Bastian Date: Wed, 19 Apr 2023 15:16:19 +0200 Subject: [PATCH 3/8] pubs --- .../segmenter/src/complex/dictionary.rs | 28 ++++---- components/segmenter/src/complex/language.rs | 10 +-- .../segmenter/src/complex/lstm/matrix.rs | 72 +++++++++---------- components/segmenter/src/complex/lstm/mod.rs | 30 ++++---- 4 files changed, 70 insertions(+), 70 deletions(-) diff --git a/components/segmenter/src/complex/dictionary.rs b/components/segmenter/src/complex/dictionary.rs index 32e1e776cfa..c3e5fe3dce1 100644 --- a/components/segmenter/src/complex/dictionary.rs +++ b/components/segmenter/src/complex/dictionary.rs @@ -10,7 +10,7 @@ use icu_collections::char16trie::{Char16Trie, TrieResult}; use icu_provider::prelude::*; /// A trait for dictionary based iterator -pub trait DictionaryType<'l, 's> { +trait DictionaryType<'l, 's> { /// The iterator over characters. type IterAttr: Iterator + Clone; @@ -21,7 +21,7 @@ pub trait DictionaryType<'l, 's> { fn char_len(c: Self::CharType) -> usize; } -pub struct DictionaryBreakIterator< +struct DictionaryBreakIterator< 'l, 's, Y: DictionaryType<'l, 's> + ?Sized, @@ -137,13 +137,13 @@ impl<'l, 's> DictionaryType<'l, 's> for char { } } -pub(crate) struct DictionarySegmenter<'l> { +pub(super) struct DictionarySegmenter<'l> { dict: &'l UCharDictionaryBreakDataV1<'l>, grapheme: &'l RuleBreakDataV1<'l>, } impl<'l> DictionarySegmenter<'l> { - pub fn new( + pub(super) fn new( dict: &'l DataPayload, grapheme: &'l DataPayload, ) -> Self { @@ -155,12 +155,12 @@ impl<'l> DictionarySegmenter<'l> { } /// Create a dictionary based break iterator for an `str` (a UTF-8 string). - pub fn segment_str<'s>( - &'s self, - input: &'s str, - ) -> DictionaryBreakIterator<'l, 's, char, GraphemeClusterBreakIteratorUtf8> { + pub(super) fn segment_str( + &'l self, + input: &'l str, + ) -> impl Iterator + 'l { let grapheme_iter = GraphemeClusterSegmenter::new_and_segment_str(input, self.grapheme); - DictionaryBreakIterator { + DictionaryBreakIterator:: { trie: Char16Trie::new(self.dict.trie_data.clone()), iter: input.char_indices(), len: input.len(), @@ -169,12 +169,12 @@ impl<'l> DictionarySegmenter<'l> { } /// Create a dictionary based break iterator for a UTF-16 string. - pub fn segment_utf16<'s>( - &'s self, - input: &'s [u16], - ) -> DictionaryBreakIterator<'l, 's, u32, GraphemeClusterBreakIteratorUtf16> { + pub(super) fn segment_utf16( + &'l self, + input: &'l [u16], + ) -> impl Iterator +'l { let grapheme_iter = GraphemeClusterSegmenter::new_and_segment_utf16(input, self.grapheme); - DictionaryBreakIterator { + DictionaryBreakIterator:: { trie: Char16Trie::new(self.dict.trie_data.clone()), iter: Utf16Indices::new(input), len: input.len(), diff --git a/components/segmenter/src/complex/language.rs b/components/segmenter/src/complex/language.rs index 801ed918e6b..327eea5e20b 100644 --- a/components/segmenter/src/complex/language.rs +++ b/components/segmenter/src/complex/language.rs @@ -3,7 +3,7 @@ // (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). #[derive(PartialEq, Debug, Copy, Clone)] -pub enum Language { +pub(super) enum Language { Burmese, ChineseOrJapanese, Khmer, @@ -43,12 +43,12 @@ fn get_language(codepoint: u32) -> Language { /// This struct is an iterator that returns the string per language from the /// given string. -pub struct LanguageIterator<'s> { +pub(super) struct LanguageIterator<'s> { rest: &'s str, } impl<'s> LanguageIterator<'s> { - pub fn new(input: &'s str) -> Self { + pub(super) fn new(input: &'s str) -> Self { Self { rest: input } } } @@ -70,12 +70,12 @@ impl<'s> Iterator for LanguageIterator<'s> { } } -pub struct LanguageIteratorUtf16<'s> { +pub(super) struct LanguageIteratorUtf16<'s> { rest: &'s [u16], } impl<'s> LanguageIteratorUtf16<'s> { - pub fn new(input: &'s [u16]) -> Self { + pub(super) fn new(input: &'s [u16]) -> Self { Self { rest: input } } } diff --git a/components/segmenter/src/complex/lstm/matrix.rs b/components/segmenter/src/complex/lstm/matrix.rs index 093c056e8fd..8f97a59a7c7 100644 --- a/components/segmenter/src/complex/lstm/matrix.rs +++ b/components/segmenter/src/complex/lstm/matrix.rs @@ -14,13 +14,13 @@ use num_traits::Float; /// `tanh` computes the tanh function for a scalar value. #[inline] -pub fn tanh(x: f32) -> f32 { +fn tanh(x: f32) -> f32 { x.tanh() } /// `sigmoid` computes the sigmoid function for a scalar value. #[inline] -pub fn sigmoid(x: f32) -> f32 { +fn sigmoid(x: f32) -> f32 { 1.0 / (1.0 + (-x).exp()) } @@ -30,20 +30,20 @@ pub fn sigmoid(x: f32) -> f32 { /// submatrices. For example, indexing into a matrix of size 5x4x3 returns a /// matrix of size 4x3. For more information, see [`MatrixOwned::submatrix`]. #[derive(Debug, Clone)] -pub struct MatrixOwned { +pub(super) struct MatrixOwned { data: Vec, dims: [usize; D], } impl MatrixOwned { - pub fn as_borrowed(&self) -> MatrixBorrowed { + pub(super) fn as_borrowed(&self) -> MatrixBorrowed { MatrixBorrowed { data: &self.data, dims: self.dims, } } - pub fn new_zero(dims: [usize; D]) -> Self { + pub(super) fn new_zero(dims: [usize; D]) -> Self { let total_len = dims.iter().product::(); MatrixOwned { data: vec![0.0; total_len], @@ -58,7 +58,7 @@ impl MatrixOwned { /// /// The type parameter `M` should be `D - 1`. #[inline] - pub fn submatrix(&self, index: usize) -> Option> { + pub(super) fn submatrix(&self, index: usize) -> Option> { // This assertion is based on const generics; it should always succeed and be elided. assert_eq!(M, D - 1); let (range, dims) = self.as_borrowed().submatrix_range(index); @@ -66,7 +66,7 @@ impl MatrixOwned { Some(MatrixBorrowed { data, dims }) } - pub fn as_mut(&mut self) -> MatrixBorrowedMut { + pub(super) fn as_mut(&mut self) -> MatrixBorrowedMut { MatrixBorrowedMut { data: &mut self.data, dims: self.dims, @@ -75,7 +75,7 @@ impl MatrixOwned { /// A mutable version of [`Self::submatrix`]. #[inline] - pub fn submatrix_mut(&mut self, index: usize) -> Option> { + pub(super) fn submatrix_mut(&mut self, index: usize) -> Option> { // This assertion is based on const generics; it should always succeed and be elided. assert_eq!(M, D - 1); let (range, dims) = self.as_borrowed().submatrix_range(index); @@ -86,26 +86,26 @@ impl MatrixOwned { /// A `D`-dimensional, borrowed matrix. #[derive(Debug, Clone, Copy)] -pub struct MatrixBorrowed<'a, const D: usize> { +pub(super) struct MatrixBorrowed<'a, const D: usize> { data: &'a [f32], dims: [usize; D], } impl<'a, const D: usize> MatrixBorrowed<'a, D> { #[cfg(debug_assertions)] - pub fn debug_assert_dims(&self, dims: [usize; D]) { + pub(super) fn debug_assert_dims(&self, dims: [usize; D]) { debug_assert_eq!(dims, self.dims); let expected_len = dims.iter().product::(); debug_assert_eq!(expected_len, self.data.len()); } - pub fn as_slice(&self) -> &'a [f32] { + pub(super) fn as_slice(&self) -> &'a [f32] { self.data } /// See [`MatrixOwned::submatrix`]. #[inline] - pub fn submatrix(&self, index: usize) -> Option> { + pub(super) fn submatrix(&self, index: usize) -> Option> { // This assertion is based on const generics; it should always succeed and be elided. assert_eq!(M, D - 1); let (range, dims) = self.submatrix_range(index); @@ -129,21 +129,21 @@ macro_rules! impl_basic_dim { ($t1:path, $t2:path, $t3:path) => { impl<'a> $t1 { #[allow(dead_code)] - pub fn dim(&self) -> usize { + pub(super) fn dim(&self) -> usize { let [dim] = self.dims; dim } } impl<'a> $t2 { #[allow(dead_code)] - pub fn dim(&self) -> (usize, usize) { + pub(super) fn dim(&self) -> (usize, usize) { let [d0, d1] = self.dims; (d0, d1) } } impl<'a> $t3 { #[allow(dead_code)] - pub fn dim(&self) -> (usize, usize, usize) { + pub(super) fn dim(&self) -> (usize, usize, usize) { let [d0, d1, d2] = self.dims; (d0, d1, d2) } @@ -165,24 +165,24 @@ impl_basic_dim!( impl_basic_dim!(MatrixZero<'a, 1>, MatrixZero<'a, 2>, MatrixZero<'a, 3>); /// A `D`-dimensional, mutably borrowed matrix. -pub struct MatrixBorrowedMut<'a, const D: usize> { - pub(crate) data: &'a mut [f32], - pub(crate) dims: [usize; D], +pub(super) struct MatrixBorrowedMut<'a, const D: usize> { + pub(super) data: &'a mut [f32], + pub(super) dims: [usize; D], } impl<'a, const D: usize> MatrixBorrowedMut<'a, D> { - pub fn as_borrowed(&self) -> MatrixBorrowed { + pub(super) fn as_borrowed(&self) -> MatrixBorrowed { MatrixBorrowed { data: self.data, dims: self.dims, } } - pub fn as_mut_slice(&mut self) -> &mut [f32] { + pub(super) fn as_mut_slice(&mut self) -> &mut [f32] { self.data } - pub fn copy_submatrix(&mut self, from: usize, to: usize) { + pub(super) fn copy_submatrix(&mut self, from: usize, to: usize) { let (range_from, _) = self.as_borrowed().submatrix_range::(from); let (range_to, _) = self.as_borrowed().submatrix_range::(to); if let (Some(_), Some(_)) = ( @@ -195,7 +195,7 @@ impl<'a, const D: usize> MatrixBorrowedMut<'a, D> { } #[must_use] - pub fn add(&mut self, other: MatrixZero<'_, D>) -> Option<()> { + pub(super) fn add(&mut self, other: MatrixZero<'_, D>) -> Option<()> { debug_assert_eq!(self.dims, other.dims); // TODO: Vectorize? for i in 0..self.data.len() { @@ -205,26 +205,26 @@ impl<'a, const D: usize> MatrixBorrowedMut<'a, D> { } /// Mutates this matrix by applying a softmax transformation. - pub fn softmax_transform(&mut self) { + pub(super) fn softmax_transform(&mut self) { let sm = self.data.iter().map(|v| v.exp()).sum::(); self.data.iter_mut().for_each(|v| { *v = v.exp() / sm; }); } - pub fn sigmoid_transform(&mut self) { + pub(super) fn sigmoid_transform(&mut self) { for x in &mut self.data.iter_mut() { *x = sigmoid(*x); } } - pub fn tanh_transform(&mut self) { + pub(super) fn tanh_transform(&mut self) { for x in &mut self.data.iter_mut() { *x = tanh(*x); } } - pub fn convolve( + pub(super) fn convolve( &mut self, i: MatrixBorrowed<'_, D>, c: MatrixBorrowed<'_, D>, @@ -247,7 +247,7 @@ impl<'a, const D: usize> MatrixBorrowedMut<'a, D> { } } - pub fn mul_tanh(&mut self, o: MatrixBorrowed<'_, D>, c: MatrixBorrowed<'_, D>) { + pub(super) fn mul_tanh(&mut self, o: MatrixBorrowed<'_, D>, c: MatrixBorrowed<'_, D>) { let o = o.as_slice(); let c = c.as_slice(); let len = self.data.len(); @@ -267,7 +267,7 @@ impl<'a, const D: usize> MatrixBorrowedMut<'a, D> { impl<'a> MatrixBorrowed<'a, 1> { #[allow(dead_code)] // could be useful - pub fn dot_1d(&self, other: MatrixZero<1>) -> f32 { + pub(super) fn dot_1d(&self, other: MatrixZero<1>) -> f32 { debug_assert_eq!(self.dims, other.dims); unrolled_dot_1(self.data, other.data) } @@ -278,7 +278,7 @@ impl<'a> MatrixBorrowedMut<'a, 1> { /// /// Note: For better dot product efficiency, if `b` is MxN, then `a` should be N; /// this is the opposite of standard practice. - pub fn add_dot_2d(&mut self, a: MatrixBorrowed<1>, b: MatrixZero<2>) { + pub(super) fn add_dot_2d(&mut self, a: MatrixBorrowed<1>, b: MatrixZero<2>) { let m = a.dim(); let n = self.as_borrowed().dim(); debug_assert_eq!( @@ -312,7 +312,7 @@ impl<'a> MatrixBorrowedMut<'a, 2> { /// Calculate the dot product of a and b, adding the result to self. /// /// Self should be _MxN_; `a`, _O_; and `b`, _MxNxO_. - pub fn add_dot_3d_1(&mut self, a: MatrixBorrowed<1>, b: MatrixZero<3>) { + pub(super) fn add_dot_3d_1(&mut self, a: MatrixBorrowed<1>, b: MatrixZero<3>) { let m = a.dim(); let n = self.as_borrowed().dim().0 * self.as_borrowed().dim().1; debug_assert_eq!( @@ -352,7 +352,7 @@ impl<'a> MatrixBorrowedMut<'a, 2> { /// Calculate the dot product of a and b, adding the result to self. /// /// Self should be _MxN_; `a`, _O_; and `b`, _MxNxO_. - pub fn add_dot_3d_2(&mut self, a: MatrixZero<1>, b: MatrixZero<3>) { + pub(super) fn add_dot_3d_2(&mut self, a: MatrixZero<1>, b: MatrixZero<3>) { let m = a.dim(); let n = self.as_borrowed().dim().0 * self.as_borrowed().dim().1; debug_assert_eq!( @@ -392,7 +392,7 @@ impl<'a> MatrixBorrowedMut<'a, 2> { /// A `D`-dimensional matrix borrowed from a [`ZeroSlice`]. #[derive(Debug, Clone, Copy)] -pub struct MatrixZero<'a, const D: usize> { +pub(super) struct MatrixZero<'a, const D: usize> { data: &'a ZeroSlice, dims: [usize; D], } @@ -426,19 +426,19 @@ impl<'a> From<&'a crate::provider::LstmMatrix3<'a>> for MatrixZero<'a, 3> { impl<'a, const D: usize> MatrixZero<'a, D> { #[allow(clippy::wrong_self_convention)] // same convention as slice::to_vec - pub fn to_owned(&self) -> MatrixOwned { + pub(super) fn to_owned(&self) -> MatrixOwned { MatrixOwned { data: self.data.iter().collect(), dims: self.dims, } } - pub fn as_slice(&self) -> &ZeroSlice { + pub(super) fn as_slice(&self) -> &ZeroSlice { self.data } #[cfg(debug_assertions)] - pub fn debug_assert_dims(&self, dims: [usize; D]) { + pub(super) fn debug_assert_dims(&self, dims: [usize; D]) { debug_assert_eq!(dims, self.dims); let expected_len = dims.iter().product::(); debug_assert_eq!(expected_len, self.data.len()); @@ -446,7 +446,7 @@ impl<'a, const D: usize> MatrixZero<'a, D> { /// See [`MatrixOwned::submatrix`]. #[inline] - pub fn submatrix(&self, index: usize) -> Option> { + pub(super) fn submatrix(&self, index: usize) -> Option> { // This assertion is based on const generics; it should always succeed and be elided. assert_eq!(M, D - 1); let (range, dims) = self.submatrix_range(index); diff --git a/components/segmenter/src/complex/lstm/mod.rs b/components/segmenter/src/complex/lstm/mod.rs index 37128160194..6cb249e5457 100644 --- a/components/segmenter/src/complex/lstm/mod.rs +++ b/components/segmenter/src/complex/lstm/mod.rs @@ -16,7 +16,7 @@ use matrix::*; // A word break iterator using LSTM model. Input string have to be same language. -pub struct LstmSegmenterIterator<'s> { +struct LstmSegmenterIterator<'s> { input: &'s str, bies_str: Box<[Bies]>, pos: usize, @@ -39,7 +39,7 @@ impl Iterator for LstmSegmenterIterator<'_> { } } -pub struct LstmSegmenterIteratorUtf16 { +struct LstmSegmenterIteratorUtf16 { bies_str: Box<[Bies]>, pos: usize, } @@ -58,7 +58,7 @@ impl Iterator for LstmSegmenterIteratorUtf16 { } } -pub(crate) struct LstmSegmenter<'l> { +pub(super) struct LstmSegmenter<'l> { dic: ZeroMapBorrowed<'l, UnvalidatedStr, u16>, embedding: MatrixZero<'l, 2>, fw_w: MatrixZero<'l, 3>, @@ -74,7 +74,7 @@ pub(crate) struct LstmSegmenter<'l> { impl<'l> LstmSegmenter<'l> { /// Returns `Err` if grapheme data is required but not present - pub fn new( + pub(super) fn new( lstm: &'l DataPayload, grapheme: &'l DataPayload, ) -> Self { @@ -95,7 +95,7 @@ impl<'l> LstmSegmenter<'l> { } /// Create an LSTM based break iterator for an `str` (a UTF-8 string). - pub fn segment_str<'s>(&self, input: &'s str) -> LstmSegmenterIterator<'s> { + pub(super) fn segment_str<'s>(&self, input: &'s str) -> impl Iterator + 's { let lstm_output = self.produce_bies(input); LstmSegmenterIterator { input, @@ -106,7 +106,7 @@ impl<'l> LstmSegmenter<'l> { } /// Create an LSTM based break iterator for a UTF-16 string. - pub fn segment_utf16(&self, input: &[u16]) -> LstmSegmenterIteratorUtf16 { + pub(super) fn segment_utf16(&self, input: &[u16]) -> impl Iterator { let input: String = decode_utf16(input.iter().copied()) .map(|r| r.unwrap_or(REPLACEMENT_CHARACTER)) .collect(); @@ -267,7 +267,7 @@ impl<'l> LstmSegmenter<'l> { // TODO(#421): Use common BIES normalizer code #[derive(Debug, PartialEq, Copy, Clone)] -pub enum Bies { +enum Bies { B, I, E, @@ -320,21 +320,21 @@ mod tests { /// Each test case has two attributs: `unseg` which denots the unsegmented line, and `true_bies` which indicates the Bies /// sequence representing the true segmentation. #[derive(PartialEq, Debug, Deserialize)] - pub struct TestCase { - pub unseg: String, - pub expected_bies: String, - pub true_bies: String, + struct TestCase { + unseg: String, + expected_bies: String, + true_bies: String, } /// `TestTextData` is a struct to store a vector of `TestCase` that represents a test text. #[derive(PartialEq, Debug, Deserialize)] - pub struct TestTextData { - pub testcases: Vec, + struct TestTextData { + testcases: Vec, } #[derive(Debug)] - pub struct TestText { - pub data: TestTextData, + struct TestText { + data: TestTextData, } fn load_test_text(filename: &str) -> TestTextData { From 3705319951366631cc15308415f75d5c9917c959 Mon Sep 17 00:00:00 2001 From: Robert Bastian Date: Wed, 19 Apr 2023 15:22:31 +0200 Subject: [PATCH 4/8] fix --- components/segmenter/src/complex/dictionary.rs | 10 ++-------- components/segmenter/src/complex/lstm/matrix.rs | 5 ++++- components/segmenter/src/complex/mod.rs | 1 + 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/components/segmenter/src/complex/dictionary.rs b/components/segmenter/src/complex/dictionary.rs index c3e5fe3dce1..76f3e025c4a 100644 --- a/components/segmenter/src/complex/dictionary.rs +++ b/components/segmenter/src/complex/dictionary.rs @@ -155,10 +155,7 @@ impl<'l> DictionarySegmenter<'l> { } /// Create a dictionary based break iterator for an `str` (a UTF-8 string). - pub(super) fn segment_str( - &'l self, - input: &'l str, - ) -> impl Iterator + 'l { + pub(super) fn segment_str(&'l self, input: &'l str) -> impl Iterator + 'l { let grapheme_iter = GraphemeClusterSegmenter::new_and_segment_str(input, self.grapheme); DictionaryBreakIterator:: { trie: Char16Trie::new(self.dict.trie_data.clone()), @@ -169,10 +166,7 @@ impl<'l> DictionarySegmenter<'l> { } /// Create a dictionary based break iterator for a UTF-16 string. - pub(super) fn segment_utf16( - &'l self, - input: &'l [u16], - ) -> impl Iterator +'l { + pub(super) fn segment_utf16(&'l self, input: &'l [u16]) -> impl Iterator + 'l { let grapheme_iter = GraphemeClusterSegmenter::new_and_segment_utf16(input, self.grapheme); DictionaryBreakIterator:: { trie: Char16Trie::new(self.dict.trie_data.clone()), diff --git a/components/segmenter/src/complex/lstm/matrix.rs b/components/segmenter/src/complex/lstm/matrix.rs index 8f97a59a7c7..faeeece1394 100644 --- a/components/segmenter/src/complex/lstm/matrix.rs +++ b/components/segmenter/src/complex/lstm/matrix.rs @@ -75,7 +75,10 @@ impl MatrixOwned { /// A mutable version of [`Self::submatrix`]. #[inline] - pub(super) fn submatrix_mut(&mut self, index: usize) -> Option> { + pub(super) fn submatrix_mut( + &mut self, + index: usize, + ) -> Option> { // This assertion is based on const generics; it should always succeed and be elided. assert_eq!(M, D - 1); let (range, dims) = self.as_borrowed().submatrix_range(index); diff --git a/components/segmenter/src/complex/mod.rs b/components/segmenter/src/complex/mod.rs index b1a730e6beb..774204b7edb 100644 --- a/components/segmenter/src/complex/mod.rs +++ b/components/segmenter/src/complex/mod.rs @@ -13,6 +13,7 @@ mod language; use language::*; #[cfg(feature = "lstm")] mod lstm; +#[cfg(feature = "lstm")] use lstm::*; #[cfg(not(feature = "lstm"))] From d620b8a0e139c4e7ccb9aaa76311217029f5c7dc Mon Sep 17 00:00:00 2001 From: Robert Bastian Date: Wed, 19 Apr 2023 15:20:09 +0200 Subject: [PATCH 5/8] segiter --- .../segmenter/src/complex/lstm/matrix.rs | 1 + components/segmenter/src/complex/lstm/mod.rs | 375 +++++++++--------- 2 files changed, 197 insertions(+), 179 deletions(-) diff --git a/components/segmenter/src/complex/lstm/matrix.rs b/components/segmenter/src/complex/lstm/matrix.rs index faeeece1394..b680278a377 100644 --- a/components/segmenter/src/complex/lstm/matrix.rs +++ b/components/segmenter/src/complex/lstm/matrix.rs @@ -207,6 +207,7 @@ impl<'a, const D: usize> MatrixBorrowedMut<'a, D> { Some(()) } + #[allow(dead_code)] // maybe needed for more complicated bies calculations /// Mutates this matrix by applying a softmax transformation. pub(super) fn softmax_transform(&mut self) { let sm = self.data.iter().map(|v| v.exp()).sum::(); diff --git a/components/segmenter/src/complex/lstm/mod.rs b/components/segmenter/src/complex/lstm/mod.rs index 6cb249e5457..458e6e21dbf 100644 --- a/components/segmenter/src/complex/lstm/mod.rs +++ b/components/segmenter/src/complex/lstm/mod.rs @@ -4,8 +4,6 @@ use crate::grapheme::GraphemeClusterSegmenter; use crate::provider::*; -use alloc::boxed::Box; -use alloc::string::String; use alloc::vec::Vec; use core::char::{decode_utf16, REPLACEMENT_CHARACTER}; use icu_provider::DataPayload; @@ -18,9 +16,8 @@ use matrix::*; struct LstmSegmenterIterator<'s> { input: &'s str, - bies_str: Box<[Bies]>, - pos: usize, pos_utf8: usize, + bies: BiesIterator<'s>, } impl Iterator for LstmSegmenterIterator<'_> { @@ -29,29 +26,27 @@ impl Iterator for LstmSegmenterIterator<'_> { fn next(&mut self) -> Option { #[allow(clippy::indexing_slicing)] // pos_utf8 in range loop { - let bies = *self.bies_str.get(self.pos)?; + let is_e = self.bies.next()?; self.pos_utf8 += self.input[self.pos_utf8..].chars().next()?.len_utf8(); - self.pos += 1; - if bies == Bies::E || self.pos == self.bies_str.len() { + if is_e || self.bies.len() == 0 { return Some(self.pos_utf8); } } } } -struct LstmSegmenterIteratorUtf16 { - bies_str: Box<[Bies]>, +struct LstmSegmenterIteratorUtf16<'s> { + bies: BiesIterator<'s>, pos: usize, } -impl Iterator for LstmSegmenterIteratorUtf16 { +impl Iterator for LstmSegmenterIteratorUtf16<'_> { type Item = usize; fn next(&mut self) -> Option { loop { - let bies = *self.bies_str.get(self.pos)?; self.pos += 1; - if bies == Bies::E || self.pos == self.bies_str.len() { + if self.bies.next()? || self.bies.len() == 0 { return Some(self.pos); } } @@ -67,7 +62,8 @@ pub(super) struct LstmSegmenter<'l> { bw_w: MatrixZero<'l, 3>, bw_u: MatrixZero<'l, 3>, bw_b: MatrixZero<'l, 2>, - time_w: MatrixZero<'l, 3>, + timew_fw: MatrixZero<'l, 2>, + timew_bw: MatrixZero<'l, 2>, time_b: MatrixZero<'l, 1>, grapheme: Option<&'l RuleBreakDataV1<'l>>, } @@ -79,6 +75,11 @@ impl<'l> LstmSegmenter<'l> { grapheme: &'l DataPayload, ) -> Self { let LstmDataV1::Float32(lstm) = lstm.get(); + let time_w = MatrixZero::from(&lstm.time_w); + #[allow(clippy::unwrap_used)] // shape (2, 4, hunits) + let timew_fw = time_w.submatrix(0).unwrap(); + #[allow(clippy::unwrap_used)] // shape (2, 4, hunits) + let timew_bw = time_w.submatrix(1).unwrap(); Self { dic: lstm.dic.as_borrowed(), embedding: MatrixZero::from(&lstm.embedding), @@ -88,42 +89,21 @@ impl<'l> LstmSegmenter<'l> { bw_w: MatrixZero::from(&lstm.bw_w), bw_u: MatrixZero::from(&lstm.bw_u), bw_b: MatrixZero::from(&lstm.bw_b), - time_w: MatrixZero::from(&lstm.time_w), + timew_fw, + timew_bw, time_b: MatrixZero::from(&lstm.time_b), grapheme: (lstm.model == ModelType::GraphemeClusters).then(|| grapheme.get()), } } /// Create an LSTM based break iterator for an `str` (a UTF-8 string). - pub(super) fn segment_str<'s>(&self, input: &'s str) -> impl Iterator + 's { - let lstm_output = self.produce_bies(input); - LstmSegmenterIterator { - input, - bies_str: lstm_output, - pos: 0, - pos_utf8: 0, - } + pub(super) fn segment_str(&'l self, input: &'l str) -> impl Iterator + 'l { + self.segment_str_p(input) } - /// Create an LSTM based break iterator for a UTF-16 string. - pub(super) fn segment_utf16(&self, input: &[u16]) -> impl Iterator { - let input: String = decode_utf16(input.iter().copied()) - .map(|r| r.unwrap_or(REPLACEMENT_CHARACTER)) - .collect(); - let lstm_output = self.produce_bies(&input); - LstmSegmenterIteratorUtf16 { - bies_str: lstm_output, - pos: 0, - } - } - - /// `produce_bies` is a function that gets a "clean" unsegmented string as its input and returns a BIES (B: Beginning, I: Inside, E: End, - /// S: Single) sequence for grapheme clusters. The boundaries of words can be found easily using this BIES sequence. - fn produce_bies(&self, input: &str) -> Box<[Bies]> { - // input_seq is a sequence of id numbers that represents grapheme clusters or code points in the input line. These ids are used later - // in the embedding layer of the model. - // Already checked that the name of the model is either "codepoints" or "graphclsut" - let input_seq: Vec = if let Some(grapheme) = self.grapheme { + // For unit testing as we cannot inspect the opaque type's bies + fn segment_str_p(&'l self, input: &'l str) -> LstmSegmenterIterator<'l> { + let input_seq = if let Some(grapheme) = self.grapheme { GraphemeClusterSegmenter::new_and_segment_str(input, grapheme) .collect::>() .windows(2) @@ -133,8 +113,14 @@ impl<'l> LstmSegmenter<'l> { } else { unreachable!() }; + let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) { + grapheme_cluster + } else { + return self.dic.len() as u16; + }; + self.dic - .get_copied(UnvalidatedStr::from_str(input.get(range).unwrap_or(input))) + .get_copied(UnvalidatedStr::from_str(grapheme_cluster)) .unwrap_or_else(|| self.dic.len() as u16) }) .collect() @@ -148,162 +134,192 @@ impl<'l> LstmSegmenter<'l> { }) .collect() }; + LstmSegmenterIterator { + input, + pos_utf8: 0, + bies: BiesIterator::new(self, input_seq), + } + } - /// `compute_hc1` implemens the evaluation of one LSTM layer. - fn compute_hc<'a>( - x_t: MatrixZero<'a, 1>, - mut h_tm1: MatrixBorrowedMut<'a, 1>, - mut c_tm1: MatrixBorrowedMut<'a, 1>, - w: MatrixZero<'a, 3>, - u: MatrixZero<'a, 3>, - b: MatrixZero<'a, 2>, - ) { - #[cfg(debug_assertions)] - { - let hunits = h_tm1.dim(); - let embedd_dim = x_t.dim(); - c_tm1.as_borrowed().debug_assert_dims([hunits]); - w.debug_assert_dims([4, hunits, embedd_dim]); - u.debug_assert_dims([4, hunits, hunits]); - b.debug_assert_dims([4, hunits]); - } + /// Create an LSTM based break iterator for a UTF-16 string. + pub(super) fn segment_utf16(&'l self, input: &[u16]) -> impl Iterator + 'l { + let input_seq = if let Some(grapheme) = self.grapheme { + GraphemeClusterSegmenter::new_and_segment_utf16(input, grapheme) + .collect::>() + .windows(2) + .map(|chunk| { + let range = if let [first, second, ..] = chunk { + *first..*second + } else { + unreachable!() + }; + let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) { + grapheme_cluster + } else { + return self.dic.len() as u16; + }; - let mut s_t = b.to_owned(); - - s_t.as_mut().add_dot_3d_2(x_t, w); - s_t.as_mut().add_dot_3d_1(h_tm1.as_borrowed(), u); - - #[allow(clippy::unwrap_used)] // first dimension is 4 - s_t.submatrix_mut::<1>(0).unwrap().sigmoid_transform(); - #[allow(clippy::unwrap_used)] // first dimension is 4 - s_t.submatrix_mut::<1>(1).unwrap().sigmoid_transform(); - #[allow(clippy::unwrap_used)] // first dimension is 4 - s_t.submatrix_mut::<1>(2).unwrap().tanh_transform(); - #[allow(clippy::unwrap_used)] // first dimension is 4 - s_t.submatrix_mut::<1>(3).unwrap().sigmoid_transform(); - - #[allow(clippy::unwrap_used)] // first dimension is 4 - c_tm1.convolve( - s_t.as_borrowed().submatrix(0).unwrap(), - s_t.as_borrowed().submatrix(2).unwrap(), - s_t.as_borrowed().submatrix(1).unwrap(), - ); + // The maximum UTF-8 size of a grapheme cluster seems to be 41 bytes + let mut i = 0; + let mut buf = [0; 41]; - #[allow(clippy::unwrap_used)] // first dimension is 4 - h_tm1.mul_tanh(s_t.as_borrowed().submatrix(3).unwrap(), c_tm1.as_borrowed()); + decode_utf16(grapheme_cluster.iter().copied()).for_each(|c| { + debug_assert!(i < 37); + i += c + .unwrap_or(REPLACEMENT_CHARACTER) + .encode_utf8(&mut buf[i..]) + .len() + }); + + self.dic + .get_copied(UnvalidatedStr::from_bytes(&buf[..i])) + .unwrap_or_else(|| self.dic.len() as u16) + }) + .collect() + } else { + decode_utf16(input.iter().copied()) + .map(|c| c.unwrap_or(REPLACEMENT_CHARACTER)) + .map(|c| { + self.dic + .get_copied(UnvalidatedStr::from_str(c.encode_utf8(&mut [0; 4]))) + .unwrap_or_else(|| self.dic.len() as u16) + }) + .collect() + }; + LstmSegmenterIteratorUtf16 { + bies: BiesIterator::new(self, input_seq), + pos: 0, } + } +} - let hunits = self.fw_u.dim().1; +struct BiesIterator<'l> { + segmenter: &'l LstmSegmenter<'l>, + input_seq: core::iter::Enumerate>, + h_bw: MatrixOwned<2>, + curr_fw: MatrixOwned<1>, + c_fw: MatrixOwned<1>, +} - // Forward LSTM - let mut c_fw = MatrixOwned::<1>::new_zero([hunits]); - let mut all_h_fw = MatrixOwned::<2>::new_zero([input_seq.len(), hunits]); - for (i, &g_id) in input_seq.iter().enumerate() { - #[allow(clippy::unwrap_used)] - // embedding has shape (dict.len() + 1, hunit), g_id is at most dict.len() - let x_t = self.embedding.submatrix::<1>(g_id as usize).unwrap(); - if i > 0 { - all_h_fw.as_mut().copy_submatrix::<1>(i - 1, i); - } - #[allow(clippy::unwrap_used)] - compute_hc( - x_t, - all_h_fw.submatrix_mut(i).unwrap(), // shape (input_seq.len(), hunits) - c_fw.as_mut(), - self.fw_w, - self.fw_u, - self.fw_b, - ); - } +impl<'l> BiesIterator<'l> { + // input_seq is a sequence of id numbers that represents grapheme clusters or code points in the input line. These ids are used later + // in the embedding layer of the model. + fn new(segmenter: &'l LstmSegmenter<'l>, input_seq: Vec) -> Self { + let hunits = segmenter.fw_u.dim().1; // Backward LSTM let mut c_bw = MatrixOwned::<1>::new_zero([hunits]); - let mut all_h_bw = MatrixOwned::<2>::new_zero([input_seq.len(), hunits]); + let mut h_bw = MatrixOwned::<2>::new_zero([input_seq.len(), hunits]); for (i, &g_id) in input_seq.iter().enumerate().rev() { - #[allow(clippy::unwrap_used)] - // embedding has shape (dict.len() + 1, hunit), g_id is at most dict.len() - let x_t = self.embedding.submatrix::<1>(g_id as usize).unwrap(); if i + 1 < input_seq.len() { - all_h_bw.as_mut().copy_submatrix::<1>(i + 1, i); + h_bw.as_mut().copy_submatrix::<1>(i + 1, i); } #[allow(clippy::unwrap_used)] compute_hc( - x_t, - all_h_bw.submatrix_mut(i).unwrap(), // shape (input_seq.len(), hunits) + segmenter.embedding.submatrix::<1>(g_id as usize).unwrap(), // shape (dict.len() + 1, hunit), g_id is at most dict.len() + h_bw.submatrix_mut(i).unwrap(), // shape (input_seq.len(), hunits) c_bw.as_mut(), - self.bw_w, - self.bw_u, - self.bw_b, + segmenter.bw_w, + segmenter.bw_u, + segmenter.bw_b, ); } - #[allow(clippy::unwrap_used)] // shape (2, 4, hunits) - let timew_fw = self.time_w.submatrix(0).unwrap(); - #[allow(clippy::unwrap_used)] // shape (2, 4, hunits) - let timew_bw = self.time_w.submatrix(1).unwrap(); - - // Combining forward and backward LSTMs using the dense time-distributed layer - (0..input_seq.len()) - .map(|i| { - #[allow(clippy::unwrap_used)] // shape (input_seq.len(), hunits) - let curr_fw = all_h_fw.submatrix::<1>(i).unwrap(); - #[allow(clippy::unwrap_used)] // shape (input_seq.len(), hunits) - let curr_bw = all_h_bw.submatrix::<1>(i).unwrap(); - let mut weights = [0.0; 4]; - let mut curr_est = MatrixBorrowedMut { - data: &mut weights, - dims: [4], - }; - curr_est.add_dot_2d(curr_fw, timew_fw); - curr_est.add_dot_2d(curr_bw, timew_bw); - #[allow(clippy::unwrap_used)] // both shape (4) - curr_est.add(self.time_b).unwrap(); - curr_est.softmax_transform(); - Bies::from_probabilities(weights) - }) - .collect() + Self { + input_seq: input_seq.into_iter().enumerate(), + h_bw, + c_fw: MatrixOwned::<1>::new_zero([hunits]), + curr_fw: MatrixOwned::<1>::new_zero([hunits]), + segmenter, + } } } -// TODO(#421): Use common BIES normalizer code -#[derive(Debug, PartialEq, Copy, Clone)] -enum Bies { - B, - I, - E, - S, +impl ExactSizeIterator for BiesIterator<'_> { + fn len(&self) -> usize { + self.input_seq.len() + } } -impl Bies { - /// Returns the value the largest probability - fn from_probabilities(arr: [f32; 4]) -> Bies { - let [b, i, e, s] = arr; - let mut result = Bies::B; - let mut max = b; - if i > max { - result = Bies::I; - max = i; - } - if e > max { - result = Bies::E; - max = e; - } - if s > max { - result = Bies::S; - // max = s; - } - result +impl Iterator for BiesIterator<'_> { + type Item = bool; + + fn next(&mut self) -> Option { + let (i, g_id) = self.input_seq.next()?; + + #[allow(clippy::unwrap_used)] + compute_hc( + self.segmenter + .embedding + .submatrix::<1>(g_id as usize) + .unwrap(), // shape (dict.len() + 1, hunit), g_id is at most dict.len() + self.curr_fw.as_mut(), + self.c_fw.as_mut(), + self.segmenter.fw_w, + self.segmenter.fw_u, + self.segmenter.fw_b, + ); + + #[allow(clippy::unwrap_used)] // shape (input_seq.len(), hunits) + let curr_bw = self.h_bw.submatrix::<1>(i).unwrap(); + let mut weights = [0.0; 4]; + let mut curr_est = MatrixBorrowedMut { + data: &mut weights, + dims: [4], + }; + curr_est.add_dot_2d(self.curr_fw.as_borrowed(), self.segmenter.timew_fw); + curr_est.add_dot_2d(curr_bw, self.segmenter.timew_bw); + #[allow(clippy::unwrap_used)] // both shape (4) + curr_est.add(self.segmenter.time_b).unwrap(); + // For correct BIES weight calculation we'd now have to apply softmax, however + // we're only doing a naive argmax, so a monotonic function doesn't make a difference. + + Some(weights[2] > weights[0] && weights[2] > weights[1] && weights[2] > weights[3]) } +} - #[cfg(test)] - fn as_char(&self) -> char { - match self { - Bies::B => 'b', - Bies::I => 'i', - Bies::E => 'e', - Bies::S => 's', - } +/// `compute_hc1` implemens the evaluation of one LSTM layer. +fn compute_hc<'a>( + x_t: MatrixZero<'a, 1>, + mut h_tm1: MatrixBorrowedMut<'a, 1>, + mut c_tm1: MatrixBorrowedMut<'a, 1>, + w: MatrixZero<'a, 3>, + u: MatrixZero<'a, 3>, + b: MatrixZero<'a, 2>, +) { + #[cfg(debug_assertions)] + { + let hunits = h_tm1.dim(); + let embedd_dim = x_t.dim(); + c_tm1.as_borrowed().debug_assert_dims([hunits]); + w.debug_assert_dims([4, hunits, embedd_dim]); + u.debug_assert_dims([4, hunits, hunits]); + b.debug_assert_dims([4, hunits]); } + + let mut s_t = b.to_owned(); + + s_t.as_mut().add_dot_3d_2(x_t, w); + s_t.as_mut().add_dot_3d_1(h_tm1.as_borrowed(), u); + + #[allow(clippy::unwrap_used)] // first dimension is 4 + s_t.submatrix_mut::<1>(0).unwrap().sigmoid_transform(); + #[allow(clippy::unwrap_used)] // first dimension is 4 + s_t.submatrix_mut::<1>(1).unwrap().sigmoid_transform(); + #[allow(clippy::unwrap_used)] // first dimension is 4 + s_t.submatrix_mut::<1>(2).unwrap().tanh_transform(); + #[allow(clippy::unwrap_used)] // first dimension is 4 + s_t.submatrix_mut::<1>(3).unwrap().sigmoid_transform(); + + #[allow(clippy::unwrap_used)] // first dimension is 4 + c_tm1.convolve( + s_t.as_borrowed().submatrix(0).unwrap(), + s_t.as_borrowed().submatrix(2).unwrap(), + s_t.as_borrowed().submatrix(1).unwrap(), + ); + + #[allow(clippy::unwrap_used)] // first dimension is 4 + h_tm1.mul_tanh(s_t.as_borrowed().submatrix(3).unwrap(), c_tm1.as_borrowed()); } #[cfg(test)] @@ -379,17 +395,18 @@ mod tests { }; // Testing - for test_case in test_text.data.testcases { - let lstm_output = lstm.produce_bies(&test_case.unseg); + for test_case in &test_text.data.testcases { + let lstm_output = lstm + .segment_str_p(&test_case.unseg) + .bies + .map(|is_e| if is_e { 'e' } else { '?' }) + .collect::(); println!("Test case : {}", test_case.unseg); println!("Expected bies : {}", test_case.expected_bies); - println!("Estimated bies : {lstm_output:?}"); + println!("Estimated bies : {lstm_output}"); println!("True bies : {}", test_case.true_bies); println!("****************************************************"); - assert_eq!( - test_case.expected_bies, - lstm_output.iter().map(Bies::as_char).collect::() - ); + assert_eq!(test_case.expected_bies.replace(['b','i','s'], "?"), lstm_output); } } } From a7c5cb3006881241929f294d77fb3743576ed2d1 Mon Sep 17 00:00:00 2001 From: Robert Bastian Date: Thu, 20 Apr 2023 17:25:21 +0200 Subject: [PATCH 6/8] clippy --- components/segmenter/src/complex/lstm/mod.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/components/segmenter/src/complex/lstm/mod.rs b/components/segmenter/src/complex/lstm/mod.rs index 458e6e21dbf..fdce6cea235 100644 --- a/components/segmenter/src/complex/lstm/mod.rs +++ b/components/segmenter/src/complex/lstm/mod.rs @@ -163,6 +163,8 @@ impl<'l> LstmSegmenter<'l> { let mut i = 0; let mut buf = [0; 41]; + #[allow(clippy::unwrap_used)] + // debug_asserting whether my assumption is correct decode_utf16(grapheme_cluster.iter().copied()).for_each(|c| { debug_assert!(i < 37); i += c @@ -171,6 +173,8 @@ impl<'l> LstmSegmenter<'l> { .len() }); + #[allow(clippy::unwrap_used)] + // debug_asserting whether my assumption is correct self.dic .get_copied(UnvalidatedStr::from_bytes(&buf[..i])) .unwrap_or_else(|| self.dic.len() as u16) @@ -406,7 +410,10 @@ mod tests { println!("Estimated bies : {lstm_output}"); println!("True bies : {}", test_case.true_bies); println!("****************************************************"); - assert_eq!(test_case.expected_bies.replace(['b','i','s'], "?"), lstm_output); + assert_eq!( + test_case.expected_bies.replace(['b', 'i', 's'], "?"), + lstm_output + ); } } } From 0bee697fbae06edf672012713430ec12f5e83f1c Mon Sep 17 00:00:00 2001 From: Robert Bastian Date: Wed, 3 May 2023 20:10:03 +0200 Subject: [PATCH 7/8] cmp iters --- components/segmenter/src/complex/lstm/mod.rs | 29 ++++++++------------ utils/zerovec/src/map/borrowed.rs | 6 ++++ 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/components/segmenter/src/complex/lstm/mod.rs b/components/segmenter/src/complex/lstm/mod.rs index 2ff44e7e15f..c9f72cbd740 100644 --- a/components/segmenter/src/complex/lstm/mod.rs +++ b/components/segmenter/src/complex/lstm/mod.rs @@ -155,24 +155,19 @@ impl<'l> LstmSegmenter<'l> { return self.dic.len() as u16; }; - // The maximum UTF-8 size of a grapheme cluster seems to be 41 bytes - let mut i = 0; - let mut buf = [0; 41]; - - #[allow(clippy::unwrap_used)] - // debug_asserting whether my assumption is correct - decode_utf16(grapheme_cluster.iter().copied()).for_each(|c| { - debug_assert!(i < 37); - i += c - .unwrap_or(REPLACEMENT_CHARACTER) - .encode_utf8(&mut buf[i..]) - .len() - }); - - #[allow(clippy::unwrap_used)] - // debug_asserting whether my assumption is correct self.dic - .get_copied(UnvalidatedStr::from_bytes(&buf[..i])) + .get_copied_by(|key| { + key.as_bytes().iter().copied().cmp( + decode_utf16(grapheme_cluster.iter().copied()).flat_map(|c| { + let mut buf = [0; 4]; + let len = c + .unwrap_or(REPLACEMENT_CHARACTER) + .encode_utf8(&mut buf) + .len(); + buf.into_iter().take(len) + }), + ) + }) .unwrap_or_else(|| self.dic.len() as u16) }) .collect() diff --git a/utils/zerovec/src/map/borrowed.rs b/utils/zerovec/src/map/borrowed.rs index bc93ee49795..b6307990d1f 100644 --- a/utils/zerovec/src/map/borrowed.rs +++ b/utils/zerovec/src/map/borrowed.rs @@ -252,6 +252,12 @@ where self.values.get(index) } + /// For cases when `V` is fixed-size, obtain a direct copy of `V` instead of `V::ULE` + pub fn get_copied_by(&self, predicate: impl FnMut(&K) -> Ordering) -> Option { + let index = self.keys.zvl_binary_search_by(predicate).ok()?; + self.values.get(index) + } + /// Similar to [`Self::iter()`] except it returns a direct copy of the values instead of references /// to `V::ULE`, in cases when `V` is fixed-size pub fn iter_copied_values<'b>( From b8b1387a79e398577bfd217925b207b536d79f4f Mon Sep 17 00:00:00 2001 From: Robert Bastian Date: Wed, 3 May 2023 20:38:11 +0200 Subject: [PATCH 8/8] test --- provider/datagen/src/transform/segmenter/lstm.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/provider/datagen/src/transform/segmenter/lstm.rs b/provider/datagen/src/transform/segmenter/lstm.rs index d6f6de7cf44..78fd54f0835 100644 --- a/provider/datagen/src/transform/segmenter/lstm.rs +++ b/provider/datagen/src/transform/segmenter/lstm.rs @@ -235,7 +235,6 @@ mod tests { #[test] fn thai_word_break_with_grapheme_model() { - const TEST_STR: &str = "ภาษาไทยภาษาไทย"; let provider = crate::DatagenProvider::for_test(); let raw_data = provider .source @@ -249,12 +248,16 @@ mod tests { ), provider, ); + let segmenter = LineSegmenter::try_new_lstm_with_any_provider(&provider).unwrap(); + + const TEST_STR: &str = "ภาษาไทยภาษาไทย"; + let utf16: Vec = TEST_STR.encode_utf16().collect(); + let breaks: Vec = segmenter.segment_str(TEST_STR).collect(); - assert_eq!( - breaks, - [0, 6, 12, 21, 27, 33, TEST_STR.len()], - "Thai test with grapheme model" - ); + assert_eq!(breaks, [0, 6, 12, 21, 27, 33, TEST_STR.len()],); + + let breaks: Vec = segmenter.segment_utf16(&utf16).collect(); + assert_eq!(breaks, [0, 2, 4, 7, 9, 11, utf16.len()],); } }