From b2b20485ff3de031b322b0dcada0754b5d310689 Mon Sep 17 00:00:00 2001 From: josephrocca <1167575+josephrocca@users.noreply.github.com> Date: Wed, 16 Feb 2022 06:10:15 +0000 Subject: [PATCH] Abstract regex and add `fancy_regex` backend --- .gitignore | 1 + tokenizers/Cargo.toml | 10 +- tokenizers/src/normalizers/replace.rs | 6 +- tokenizers/src/pre_tokenizers/byte_level.rs | 5 +- tokenizers/src/pre_tokenizers/split.rs | 6 +- tokenizers/src/tokenizer/pattern.rs | 16 +- tokenizers/src/utils/mod.rs | 1 + tokenizers/src/utils/regex.rs | 497 ++++++++++++++++++++ 8 files changed, 523 insertions(+), 19 deletions(-) create mode 100644 tokenizers/src/utils/regex.rs diff --git a/.gitignore b/.gitignore index f965b105c..04791633e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ .vim .env +.venv target .idea Cargo.lock diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 1ca9ba0b4..caaa6b0c8 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -36,7 +36,6 @@ harness = false [dependencies] lazy_static = "1.4" rand = "0.7" -onig = { version = "6.0", default-features = false } regex = "1.3" regex-syntax = "0.6" rayon = "1.3" @@ -59,11 +58,18 @@ cached-path = { version = "0.5", optional = true } aho-corasick = "0.7" paste = "1.0.6" proc_macros = { path = "./src/utils/proc_macros" } +once_cell = "1.8" +cfg-if = "1" +onig = { version = "6.0", default-features = false, optional = true } +fancy-regex = { version = "0.7", optional = true } [features] -default = ["progressbar", "http"] +default = ["progressbar", "http", "regex-onig"] progressbar = ["indicatif"] http = ["reqwest", "cached-path"] +regex-fancy = ["fancy-regex"] +regex-onig = ["onig"] +regex-all-test = ["regex-onig", "regex-fancy"] [dev-dependencies] criterion = "0.3" diff --git a/tokenizers/src/normalizers/replace.rs b/tokenizers/src/normalizers/replace.rs index e9efafb13..9d01a4427 100644 --- a/tokenizers/src/normalizers/replace.rs +++ b/tokenizers/src/normalizers/replace.rs @@ -1,5 +1,5 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; -use onig::Regex; +use crate::utils::regex::Regex; use serde::{Deserialize, Serialize}; /// Represents the different patterns that `Replace` can use @@ -65,8 +65,8 @@ impl Replace { pub fn new, C: Into>(pattern: I, content: C) -> Result { let pattern: ReplacePattern = pattern.into(); let regex = match &pattern { - ReplacePattern::String(s) => Regex::new(®ex::escape(s))?, - ReplacePattern::Regex(r) => Regex::new(r)?, + ReplacePattern::String(s) => Regex::new(regex::escape(s)), + ReplacePattern::Regex(r) => Regex::new(r.to_owned()), }; Ok(Self { diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 2a8a811e9..1b7994d34 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -1,6 +1,6 @@ use std::collections::{HashMap, HashSet}; -use onig::Regex; +use crate::utils::regex::Regex; use serde::{Deserialize, Serialize}; use crate::tokenizer::{ @@ -34,8 +34,7 @@ fn bytes_char() -> HashMap { lazy_static! { static ref RE: Regex = - Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+") - .unwrap(); + Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+".to_string()); static ref BYTES_CHAR: HashMap = bytes_char(); static ref CHAR_BYTES: HashMap = bytes_char().into_iter().map(|(c, b)| (b, c)).collect(); diff --git a/tokenizers/src/pre_tokenizers/split.rs b/tokenizers/src/pre_tokenizers/split.rs index 32ea25a8c..bc83c1e8e 100644 --- a/tokenizers/src/pre_tokenizers/split.rs +++ b/tokenizers/src/pre_tokenizers/split.rs @@ -1,4 +1,4 @@ -use onig::Regex; +use crate::utils::regex::Regex; use serde::{Deserialize, Deserializer, Serialize}; use crate::tokenizer::{ @@ -80,8 +80,8 @@ impl Split { ) -> Result { let pattern: SplitPattern = pattern.into(); let regex = match &pattern { - SplitPattern::String(s) => Regex::new(®ex::escape(s))?, - SplitPattern::Regex(r) => Regex::new(r)?, + SplitPattern::String(s) => Regex::new(regex::escape(s)), + SplitPattern::Regex(r) => Regex::new(r.to_owned()), }; Ok(Self { diff --git a/tokenizers/src/tokenizer/pattern.rs b/tokenizers/src/tokenizer/pattern.rs index af090df16..4a1c40c5d 100644 --- a/tokenizers/src/tokenizer/pattern.rs +++ b/tokenizers/src/tokenizer/pattern.rs @@ -59,7 +59,7 @@ impl Pattern for &Regex { } } -impl Pattern for &onig::Regex { +impl Pattern for &crate::utils::regex::Regex { fn find_matches(&self, inside: &str) -> Result> { if inside.is_empty() { return Ok(vec![((0, 0), false)]); @@ -67,12 +67,12 @@ impl Pattern for &onig::Regex { let mut prev = 0; let mut splits = Vec::with_capacity(inside.len()); - for (start, end) in self.find_iter(inside) { - if prev != start { - splits.push(((prev, start), false)); + for m in self.find_iter(inside) { + if prev != m.start() { + splits.push(((prev, m.start()), false)); } - splits.push(((start, end), true)); - prev = end; + splits.push(((m.start(), m.end()), true)); + prev = m.end(); } if prev != inside.len() { splits.push(((prev, inside.len()), false)) @@ -205,8 +205,8 @@ mod tests { } #[test] - fn onig_regex() { - let is_whitespace = onig::Regex::new(r"\s+").unwrap(); + fn abstract_regex() { + let is_whitespace = crate::utils::regex::Regex::new(r"\s+".to_string()); do_test!("a b", &is_whitespace => vec![((0, 1), false), ((1, 4), true), ((4, 5), false)]); do_test!(" a b ", &is_whitespace => vec![((0, 3), true), ((3, 4), false), ((4, 7), true), ((7, 8), false), ((8, 11), true)] diff --git a/tokenizers/src/utils/mod.rs b/tokenizers/src/utils/mod.rs index e80bf75ee..b14f172c1 100644 --- a/tokenizers/src/utils/mod.rs +++ b/tokenizers/src/utils/mod.rs @@ -6,6 +6,7 @@ pub mod padding; pub mod parallelism; pub(crate) mod progress; pub mod truncation; +pub mod regex; use serde::{Serialize, Serializer}; use std::collections::{BTreeMap, HashMap}; diff --git a/tokenizers/src/utils/regex.rs b/tokenizers/src/utils/regex.rs new file mode 100644 index 000000000..9d4ae39a5 --- /dev/null +++ b/tokenizers/src/utils/regex.rs @@ -0,0 +1,497 @@ +// All code and comments below this first line are copied from here (with some edits): https://github.com/bminixhofer/nlprule/blob/main/nlprule/src/utils/regex.rs + + + +//! An abstraction on top of Regular Expressions to add support for Serialization and +//! to modularize the Regex backend. +//! Adapts the approach from https://github.com/trishume/syntect/pull/270 with feature flags for the +//! different backends. + +use once_cell::sync::OnceCell; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::hash::{Hash, Hasher}; + +pub use regex_impl::{CaptureMatches, Captures, Match, Matches}; + +#[derive(Debug)] +pub struct Regex { + regex_str: String, + regex: OnceCell, +} + +impl Clone for Regex { + fn clone(&self) -> Self { + Regex { + regex_str: self.regex_str.clone(), + regex: OnceCell::new(), + } + } +} + +impl Serialize for Regex { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&self.regex_str) + } +} + +impl<'de> Deserialize<'de> for Regex { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let regex_str = String::deserialize(deserializer)?; + Ok(Regex::new(regex_str)) + } +} + +impl Hash for Regex { + fn hash(&self, state: &mut H) { + self.regex_str.hash(state); + } +} + +impl Regex { + /// Create a new regex from the pattern string. + /// + /// Note that the regex compilation happens on first use, which is why this method does not + /// return a result. + pub fn new(regex_str: String) -> Self { + Self { + regex_str, + regex: OnceCell::new(), + } + } + + /// Check whether the pattern compiles as a valid regex. + #[allow(dead_code)] // used only in compile module + pub fn try_compile(&self) -> Result<(), Box> { + regex_impl::Regex::new(&self.regex_str).map(|_| ()) + } + + fn regex(&self) -> ®ex_impl::Regex { + self.regex.get_or_init(|| { + regex_impl::Regex::new(&self.regex_str) + .unwrap_or_else(|_| panic!("regex string should be pre-tested: {}", self.regex_str)) + }) + } + + pub fn is_match(&self, text: &str) -> bool { + self.regex().is_match(text) + } + + pub fn captures_iter<'r, 't>(&'r self, text: &'t str) -> regex_impl::CaptureMatches<'r, 't> { + self.regex().captures_iter(text) + } + + pub fn find_iter<'r, 't>(&'r self, text: &'t str) -> regex_impl::Matches<'r, 't> { + self.regex().find_iter(text) + } + + #[allow(dead_code)] // used only in compile module + pub fn captures_len(&self) -> usize { + self.regex().captures_len() + } + + pub fn captures<'t>(&self, text: &'t str) -> Option> { + self.regex().captures(text) + } + + pub fn replace_all(&self, text: &str, replacement: &str) -> String { + self.regex().replace_all(text, replacement) + } +} + +#[cfg(feature = "regex-fancy")] +#[allow(dead_code)] +mod regex_impl_fancy { + pub use fancy_regex::{Captures, Match}; + use std::error::Error; + pub struct Matches<'r, 't>(fancy_regex::Matches<'r, 't>); + + impl<'r, 't> Iterator for Matches<'r, 't> { + type Item = Match<'t>; + + fn next(&mut self) -> Option { + match self.0.next() { + Some(Ok(mat)) => Some(mat), + // stop if an error is encountered + None | Some(Err(_)) => None, + } + } + } + + pub struct CaptureMatches<'r, 't>(fancy_regex::CaptureMatches<'r, 't>); + + impl<'r, 't> Iterator for CaptureMatches<'r, 't> { + type Item = Captures<'t>; + + fn next(&mut self) -> Option { + match self.0.next() { + Some(Ok(caps)) => Some(caps), + // stop if an error is encountered + None | Some(Err(_)) => None, + } + } + } + + #[derive(Debug)] + pub struct Regex { + regex: fancy_regex::Regex, + } + + impl Regex { + pub fn new(regex_str: &str) -> Result> { + Ok(Regex { + regex: fancy_regex::Regex::new(regex_str)?, + }) + } + + pub fn is_match(&self, text: &str) -> bool { + // errors are treated as non-matches + self.regex.is_match(text).unwrap_or(false) + } + + pub fn captures_iter<'r, 't>(&'r self, text: &'t str) -> CaptureMatches<'r, 't> { + CaptureMatches(self.regex.captures_iter(text)) + } + + pub fn find_iter<'r, 't>(&'r self, text: &'t str) -> Matches<'r, 't> { + Matches(self.regex.find_iter(text)) + } + + pub fn captures_len(&self) -> usize { + self.regex.captures_len() + } + + pub fn captures<'t>(&self, text: &'t str) -> Option> { + match self.regex.captures(text) { + Ok(Some(captures)) => Some(captures), + // errors treated as not matching + Ok(None) | Err(_) => None, + } + } + + pub fn replace_all<'t>(&self, text: &'t str, replacement: &str) -> String { + let mut index = 0; + let mut out: Vec = Vec::new(); + + for captures in self.captures_iter(text) { + let mat = captures.get(0).expect("0th capture group exists"); + + out.push(text[index..mat.start()].to_string()); + + let mut replacement = replacement.to_string(); + for i in 1..captures.len() { + replacement = replacement.replace( + &format!("${}", i), + captures.get(i).map_or("", |x| x.as_str()), + ); + } + + out.push(replacement); + index = mat.end(); + } + + if index != text.len() { + out.push(text[index..].to_string()); + } + + out.join("") + } + } +} + +#[cfg(feature = "regex-onig")] +#[allow(dead_code)] +mod regex_impl_onig { + use std::error::Error; + + pub struct CaptureMatches<'r, 't>(onig::FindCaptures<'r, 't>, &'t str); + + #[derive(Debug)] + pub struct Captures<'t>(onig::Captures<'t>, &'t str); + + pub struct Matches<'r, 't>(onig::FindMatches<'r, 't>, &'t str); + + #[derive(Debug)] + pub struct Match<'t> { + text: &'t str, + start: usize, + end: usize, + } + + impl<'t> Match<'t> { + pub fn start(&self) -> usize { + self.start + } + + pub fn end(&self) -> usize { + self.end + } + + pub fn as_str(&self) -> &'t str { + self.text + } + } + + impl<'t> Captures<'t> { + pub fn get(&self, index: usize) -> Option> { + let (start, end) = self.0.pos(index)?; + let text = self.0.at(index)?; + + Some(Match { text, start, end }) + } + + pub fn iter(&'t self) -> impl Iterator>> { + self.0.iter_pos().map(move |mat| { + mat.map(|(start, end)| Match { + text: &self.1[start..end], + start, + end, + }) + }) + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn len(&self) -> usize { + self.0.len() + } + } + + impl<'r, 't> Iterator for CaptureMatches<'r, 't> { + type Item = Captures<'t>; + + fn next(&mut self) -> Option { + self.0.next().map(|x| Captures(x, self.1)) + } + } + + impl<'r, 't> Iterator for Matches<'r, 't> { + type Item = Match<'t>; + + fn next(&mut self) -> Option { + self.0.next().map(|(start, end)| Match { + text: &self.1[start..end], + start, + end, + }) + } + } + + #[derive(Debug)] + pub struct Regex { + regex: onig::Regex, + } + + impl Regex { + pub fn new(regex_str: &str) -> Result> { + let mut case_sensitive = true; + let regex_str = if let Some(stripped) = regex_str.strip_suffix("(?i)") { + case_sensitive = false; + stripped + } else { + regex_str + }; + + let regex = onig::Regex::with_options( + regex_str, + onig::RegexOptions::REGEX_OPTION_CAPTURE_GROUP + | if case_sensitive { + onig::RegexOptions::REGEX_OPTION_NONE + } else { + onig::RegexOptions::REGEX_OPTION_IGNORECASE + }, + onig::Syntax::default(), + )?; + + Ok(Regex { regex }) + } + + pub fn is_match(&self, text: &str) -> bool { + self.regex.is_match(text) + } + + pub fn captures_iter<'r, 't>(&'r self, text: &'t str) -> CaptureMatches<'r, 't> { + CaptureMatches(self.regex.captures_iter(text), text) + } + + pub fn find_iter<'r, 't>(&'r self, text: &'t str) -> Matches<'r, 't> { + Matches(self.regex.find_iter(text), text) + } + + pub fn captures_len(&self) -> usize { + self.regex.captures_len() + 1 + } + + pub fn captures<'t>(&self, text: &'t str) -> Option> { + self.regex.captures(text).map(|x| Captures(x, text)) + } + + pub fn replace_all(&self, text: &str, replacement: &str) -> String { + self.regex.replace_all(&text, |caps: &onig::Captures| { + let mut replacement = replacement.to_owned(); + + for i in 1..caps.len() { + replacement = replacement.replace(&format!("${}", i), caps.at(i).unwrap_or("")); + } + + replacement + }) + } + } +} + +#[cfg(feature = "regex-all-test")] +mod regex_impl_all { + //! This backend is only used for testing. It uses all other backends and assert they do the same thing. + + use super::{regex_impl_fancy as impl_fancy, regex_impl_onig as impl_onig}; + pub use impl_fancy::{CaptureMatches, Captures, Match, Matches}; + use itertools::{EitherOrBoth, Itertools}; + use std::error::Error; + + macro_rules! option_eq { + ($a:expr, $b:expr) => { + match (&$a, &$b) { + (&Some(ref lhs), &Some(ref rhs)) if lhs == rhs => true, + (&None, &None) => true, + _ => false, + } + }; + } + + impl<'t> PartialEq> for impl_onig::Match<'t> { + fn eq(&self, other: &impl_fancy::Match<'t>) -> bool { + self.start() == other.start() + && self.end() == other.end() + && self.as_str() == other.as_str() + } + } + + impl<'t> PartialEq> for impl_onig::Captures<'t> { + fn eq(&self, other: &impl_fancy::Captures<'t>) -> bool { + self.len() == other.len() + && self.iter().zip(other.iter()).all(|(a, b)| option_eq!(a, b)) + } + } + + #[derive(Debug)] + pub struct Regex { + fancy_regex: impl_fancy::Regex, + onig_regex: impl_onig::Regex, + } + + impl Regex { + pub fn new(regex_str: &str) -> Result> { + let fancy_regex = impl_fancy::Regex::new(regex_str); + let onig_regex = impl_onig::Regex::new(regex_str); + + Ok(Regex { + fancy_regex: fancy_regex?, + onig_regex: onig_regex?, + }) + } + + pub fn is_match(&self, text: &str) -> bool { + let match_fancy = self.fancy_regex.is_match(text); + + assert_eq!( + match_fancy, + self.onig_regex.is_match(text), + "{} {:?}", + text, + self.fancy_regex + ); + match_fancy + } + + pub fn captures_iter<'r, 't>(&'r self, text: &'t str) -> CaptureMatches<'r, 't> { + assert!( + self.fancy_regex + .captures_iter(text) + .zip_longest(self.onig_regex.captures_iter(text)) + .all(|elem| { + if let EitherOrBoth::Both(a, b) = elem { + b == a + } else { + false + } + }), + "{:?}", + self.fancy_regex + ); + + self.fancy_regex.captures_iter(text) + } + + pub fn find_iter<'r, 't>(&'r self, text: &'t str) -> Matches<'r, 't> { + assert!( + self.fancy_regex + .find_iter(text) + .zip_longest(self.onig_regex.find_iter(text)) + .all(|elem| { + if let EitherOrBoth::Both(a, b) = elem { + b == a + } else { + false + } + }), + "{:?}", + self.fancy_regex + ); + + self.fancy_regex.find_iter(text) + } + + pub fn captures_len(&self) -> usize { + let out = self.fancy_regex.captures_len(); + assert_eq!( + out, + self.onig_regex.captures_len(), + "{:?}", + self.fancy_regex + ); + out + } + + pub fn captures<'t>(&self, text: &'t str) -> Option> { + let out = self.fancy_regex.captures(text); + let onig_out = self.onig_regex.captures(text); + assert!( + option_eq!(onig_out, out), + "{:?}: Fancy: {:#?}, Onig: {:#?}", + self.fancy_regex, + out, + onig_out + ); + out + } + + pub fn replace_all(&self, text: &str, replacement: &str) -> String { + let out = self.fancy_regex.replace_all(text, replacement); + assert_eq!( + out, + self.onig_regex.replace_all(text, replacement), + "{:?}", + self.fancy_regex + ); + out + } + } +} + +cfg_if::cfg_if! { + if #[cfg(feature = "regex-all-test")] { + use regex_impl_all as regex_impl; + } else if #[cfg(feature = "regex-onig")] { + use regex_impl_onig as regex_impl; + } else { + use regex_impl_fancy as regex_impl; + } +}