
Add unigram bytefallback #1217

Merged 52 commits on Jun 26, 2023

Commits
044fb41
current updates will go red
ArthurZucker Apr 12, 2023
f26b0b7
cargo fmt
ArthurZucker Apr 14, 2023
f8c6c47
npm install
ArthurZucker Apr 14, 2023
ac7529a
Merge branch 'main' of https://github.com/huggingface/tokenizers into…
ArthurZucker May 30, 2023
ce61a40
refactor train for unigram to allow bytefallbakc (breaking)
ArthurZucker Jun 6, 2023
b327540
fmt
ArthurZucker Jun 6, 2023
5e13667
nits
ArthurZucker Jun 6, 2023
e9e42e8
Merge branch 'main' of https://github.com/huggingface/tokenizers into…
ArthurZucker Jun 20, 2023
92b6490
update
ArthurZucker Jun 20, 2023
0fef053
add a proper test
ArthurZucker Jun 20, 2023
dfd36ff
fix encode optimised fallback + add trainer arg
ArthurZucker Jun 21, 2023
c72eac1
fixes
ArthurZucker Jun 21, 2023
3323956
fixes
ArthurZucker Jun 21, 2023
03b5be5
fix tests
ArthurZucker Jun 21, 2023
c97fd26
add test
ArthurZucker Jun 21, 2023
4375f07
fmt
ArthurZucker Jun 21, 2023
cc2f12d
fix rust test
ArthurZucker Jun 21, 2023
00e3a3d
update python bindings
ArthurZucker Jun 21, 2023
8834796
update
ArthurZucker Jun 21, 2023
005698a
pub is okay and needed
ArthurZucker Jun 21, 2023
474d31e
more fix
ArthurZucker Jun 21, 2023
c2881b2
cleanup
ArthurZucker Jun 21, 2023
ad6f524
remove useles id
ArthurZucker Jun 21, 2023
29601fb
MissingUnkId error
ArthurZucker Jun 21, 2023
739790d
nits
ArthurZucker Jun 21, 2023
fa5b6a6
fix offset
ArthurZucker Jun 21, 2023
0684008
add a test in python
ArthurZucker Jun 21, 2023
d5f37bd
update src bindings
ArthurZucker Jun 21, 2023
263376f
remove bytefallback from trainer
ArthurZucker Jun 21, 2023
60c2dc0
styling
ArthurZucker Jun 21, 2023
e01fbae
update pckg
ArthurZucker Jun 21, 2023
112d423
lint
ArthurZucker Jun 21, 2023
9320717
fmt
ArthurZucker Jun 21, 2023
6476410
stup with dev
ArthurZucker Jun 21, 2023
dc7cced
update code based on review
ArthurZucker Jun 22, 2023
41f2a7e
remove unused function
ArthurZucker Jun 22, 2023
1cd282b
udpate python test to compare ids
ArthurZucker Jun 22, 2023
949bcb4
fix option bool issues
ArthurZucker Jun 22, 2023
03aacd8
final fix
ArthurZucker Jun 22, 2023
e44b03a
clippy
ArthurZucker Jun 22, 2023
ea02db0
fix npm isntall
ArthurZucker Jun 22, 2023
5ad1d63
update
ArthurZucker Jun 22, 2023
8bf84cf
update test
ArthurZucker Jun 22, 2023
6c3ea53
more in depth testing
ArthurZucker Jun 22, 2023
a04408c
Lint
ArthurZucker Jun 22, 2023
7fc68a3
last attempt to fix node
ArthurZucker Jun 22, 2023
e1b7a33
update node bindings
ArthurZucker Jun 22, 2023
16f9619
fmt
ArthurZucker Jun 22, 2023
003d284
Update tokenizers/src/models/unigram/model.rs
ArthurZucker Jun 23, 2023
2451f8c
update based on review
ArthurZucker Jun 26, 2023
4b90431
simpler test
ArthurZucker Jun 26, 2023
58912da
lint
ArthurZucker Jun 26, 2023
5 changes: 5 additions & 0 deletions bindings/node/lib/bindings/models.d.ts
@@ -170,6 +170,11 @@ export interface UnigramOptions {
* @default undefined
*/
unkId?: number;
/**
* Whether or not bytefallback support should be enabled.
* @default false
*/
byte_fallback?: boolean;
}

export namespace Unigram {
1 change: 1 addition & 0 deletions bindings/node/lib/bindings/models.test.ts
@@ -124,6 +124,7 @@ describe("Unigram", () => {
],
{
unkId: 0,
byte_fallback: false,
}
);
expect(unigram.constructor.name).toEqual("Model");
7 changes: 4 additions & 3 deletions bindings/node/native/src/models.rs
@@ -191,6 +191,7 @@ fn bpe_init(mut cx: FunctionContext) -> JsResult<JsModel> {
/// unkToken?: string,
/// continuingSubwordPrefix?: string,
/// endOfWordSuffix?: string
/// byteFallback?: bool
/// }, callback)
fn bpe_from_file(mut cx: FunctionContext) -> JsResult<JsUndefined> {
let (options, callback) = match cx.extract_opt::<BpeOptions>(2) {
@@ -369,16 +370,16 @@ fn wordlevel_empty(mut cx: FunctionContext) -> JsResult<JsModel> {
#[serde(rename_all = "camelCase")]
struct UnigramOptions {
unk_id: Option<usize>,
byte_fallback: Option<bool>,
}

/// unigram_init(vocab: [string, number][], options?: {
/// unkId?: number
/// })
fn unigram_init(mut cx: FunctionContext) -> JsResult<JsModel> {
let vocab = cx.extract::<Vec<(String, f64)>>(0)?;
let options = cx.extract_opt::<UnigramOptions>(1)?.unwrap_or_default();

let unigram = tk::models::unigram::Unigram::from(vocab, options.unk_id)
let byte_fallback = options.byte_fallback.unwrap_or(false);
let unigram = tk::models::unigram::Unigram::from(vocab, options.unk_id, byte_fallback)
.map_err(|e| Error(e.to_string()))?;

let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
13,594 changes: 7,108 additions & 6,486 deletions bindings/node/package-lock.json

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion bindings/node/package.json
@@ -16,7 +16,9 @@
"license": "Apache-2.0",
"dependencies": {
"@types/node": "^13.13.52",
"node-pre-gyp": "^0.14.0"
"native": "^0.3.3",
"node-pre-gyp": "^0.14.0",
"package.json": "^2.0.1"
},
"devDependencies": {
"@types/jest": "^26.0.24",
@@ -162,6 +162,7 @@ def from_spm(filename: str):
vocab = [(piece.piece, piece.score) for piece in m.pieces]
unk_id = m.trainer_spec.unk_id
model_type = m.trainer_spec.model_type
byte_fallback = m.trainer_spec.byte_fallback
if model_type != 1:
raise Exception(
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
@@ -170,7 +171,7 @@
replacement = "▁"
add_prefix_space = True

tokenizer = Tokenizer(Unigram(vocab, unk_id))
tokenizer = Tokenizer(Unigram(vocab, unk_id, byte_fallback))

tokenizer.normalizer = normalizers.Sequence(
[
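As context for the conversion change above, here is a minimal sketch of the resulting constructor call. The vocabulary, score values, and flag below are made up for illustration; in `from_spm` they come from the SentencePiece model's `pieces` and `trainer_spec`:

```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram

# Hypothetical stand-ins for values read from the .model file in from_spm.
vocab = [("<unk>", 0.0), ("▁", -1.0), ("hello", -2.0)]
unk_id = 0
byte_fallback = True  # m.trainer_spec.byte_fallback in the real code path

# The Unigram constructor now forwards the byte_fallback flag.
tokenizer = Tokenizer(Unigram(vocab, unk_id, byte_fallback))
```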
4 changes: 2 additions & 2 deletions bindings/python/py_src/tokenizers/models/__init__.pyi
@@ -242,11 +242,11 @@ class Unigram(Model):
An implementation of the Unigram algorithm

Args:
vocab (:obj:`List[Tuple[str, float]]`, `optional`):
vocab (:obj:`List[Tuple[str, float]]`, `optional`, `optional`):
A list of vocabulary items and their relative score [("am", -0.2442),...]
"""

def __init__(self, vocab):
def __init__(self, vocab, unk_id, byte_fallback):
pass
def get_trainer(self):
"""
26 changes: 17 additions & 9 deletions bindings/python/src/models.rs
@@ -804,24 +804,32 @@ impl PyWordLevel {
/// An implementation of the Unigram algorithm
///
/// Args:
/// vocab (:obj:`List[Tuple[str, float]]`, `optional`):
/// vocab (:obj:`List[Tuple[str, float]]`, `optional`, `optional`):
/// A list of vocabulary items and their relative score [("am", -0.2442),...]
#[pyclass(extends=PyModel, module = "tokenizers.models", name = "Unigram")]
#[pyo3(text_signature = "(self, vocab)")]
#[pyo3(text_signature = "(self, vocab, unk_id, byte_fallback)")]
pub struct PyUnigram {}

#[pymethods]
impl PyUnigram {
#[new]
fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>) -> PyResult<(Self, PyModel)> {
match (vocab, unk_id) {
(Some(vocab), unk_id) => {
let model = Unigram::from(vocab, unk_id).map_err(|e| {
exceptions::PyException::new_err(format!("Error while loading Unigram: {}", e))
})?;
fn new(
vocab: Option<Vec<(String, f64)>>,
unk_id: Option<usize>,
byte_fallback: Option<bool>,
) -> PyResult<(Self, PyModel)> {
match (vocab, unk_id, byte_fallback) {
(Some(vocab), unk_id, byte_fallback) => {
let model =
Unigram::from(vocab, unk_id, byte_fallback.unwrap_or(false)).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while loading Unigram: {}",
e
))
})?;
Ok((PyUnigram {}, model.into()))
}
(None, None) => Ok((PyUnigram {}, Unigram::default().into())),
(None, None, _) => Ok((PyUnigram {}, Unigram::default().into())),
_ => Err(exceptions::PyValueError::new_err(
"`vocab` and `unk_id` must be both specified",
)),
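A short sketch of how the three match arms above surface on the Python side; the behavior is inferred from the diff alone, so treat it as illustrative:

```python
from tokenizers.models import Unigram

# First arm: vocab given, unk_id and byte_fallback optional
# (byte_fallback falls back to False when omitted, via unwrap_or(false)).
model = Unigram([("<unk>", 0.0), ("a", -0.5)], 0, True)

# Second arm: nothing given at all uses Unigram::default(),
# a one-piece <unk> vocabulary.
default_model = Unigram()

# Final arm: an unk_id without a vocab raises a ValueError.
try:
    Unigram(None, 0)
except ValueError as err:
    print(err)  # `vocab` and `unk_id` must be both specified
```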
28 changes: 27 additions & 1 deletion bindings/python/tests/bindings/test_tokenizer.py
@@ -5,7 +5,7 @@

from tokenizers import AddedToken, Encoding, Tokenizer
from tokenizers.implementations import BertWordPieceTokenizer
from tokenizers.models import BPE, Model, WordPiece
from tokenizers.models import BPE, Model, WordPiece, Unigram
from tokenizers.normalizers import Lowercase
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.processors import BertProcessing, RobertaProcessing
@@ -412,3 +412,29 @@ def test_from_pretrained_revision(self):
tokenizer = Tokenizer.from_pretrained("anthony/tokenizers-test", revision="gpt-2")
output = tokenizer.encode("Hey there dear friend!", add_special_tokens=False)
assert output.tokens == ["Hey", "Ġthere", "Ġdear", "Ġfriend", "!"]

def test_unigram_byte_fallback(self):
vocab = [
("<unk>", 0.0),
("A", -0.01),
("sen", -0.02),
("te", -0.03),
("n", -0.04),
("ce", -0.05),
("<0xF0>", -0.06),
("<0x9F>", -0.06),
("<0xA4>", -0.06),
("<0x97>", -0.06),
(" ", -0.4),
]
tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=False))

output = tokenizer.encode("A sentence 🤗")
assert output.ids == [1, 10, 2, 3, 4, 5, 10, 0]
assert output.tokens == ["A", " ", "sen", "te", "n", "ce", " ", "🤗"]
Collaborator: Let's remove the tokens, please, only the ids matter.


tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=True))

output = tokenizer.encode("A sentence 🤗")
assert output.ids == [1, 10, 2, 3, 4, 5, 10, 6, 7, 8, 9]
assert output.tokens == ["A", " ", "sen", "te", "n", "ce", " ", "<0xF0>", "<0x9F>", "<0xA4>", "<0x97>"]
Collaborator: Same here, tokens are unimportant, only ids are.

Collaborator (author): I'm keeping them to understand what's going on in terms of fallback 😉

Collaborator: Good point.
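For readers checking the asserted ids: the four byte pieces are exactly the UTF-8 encoding of the emoji, which a one-liner confirms:

```python
# U+1F917 (🤗) encodes to four UTF-8 bytes, matching the <0xF0>, <0x9F>,
# <0xA4>, <0x97> entries placed in the test vocabulary above.
print([f"<0x{b:02X}>" for b in "🤗".encode("utf-8")])
# ['<0xF0>', '<0x9F>', '<0xA4>', '<0x97>']
```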

83 changes: 71 additions & 12 deletions tokenizers/src/models/unigram/model.rs
@@ -27,6 +27,7 @@ pub struct Unigram {

fuse_unk: bool,
is_optimized: bool,
byte_fallback: bool,
}
impl PartialEq for Unigram {
fn eq(&self, other: &Self) -> bool {
@@ -50,6 +51,7 @@ impl Clone for Unigram {
eos_id: self.eos_id,
fuse_unk: self.fuse_unk,
is_optimized: self.is_optimized,
byte_fallback: self.byte_fallback,
}
}
}
@@ -59,6 +61,7 @@ impl std::fmt::Debug for Unigram {
fmt.debug_struct("Unigram")
.field("vocab", &self.vocab.len())
.field("unk_id", &self.unk_id)
.field("byte_fallback", &self.byte_fallback)
.finish()
}
}
@@ -78,7 +81,7 @@ pub enum UnigramError {
impl Default for Unigram {
fn default() -> Self {
let vocab = vec![("<unk>".to_string(), 0.0)];
Self::from(vocab, Some(0)).unwrap()
Self::from(vocab, Some(0), false).unwrap()
}
}

@@ -89,7 +92,11 @@ impl Unigram {
/// unk_id, is the index within the vocabulary.
/// For now `Unigram` *requires* at least `unk` because we might find a never seen char.
/// Further versions might allow that part to be hidden.
pub fn from(vocab: Vec<(String, f64)>, unk_id: Option<usize>) -> Result<Self> {
pub fn from(
vocab: Vec<(String, f64)>,
unk_id: Option<usize>,
byte_fallback: bool,
) -> Result<Self> {
let n = vocab.len();
let mut token_to_ids: TokenMap = HashMap::new();
let mut builder = TrieBuilder::default();
@@ -102,7 +109,6 @@ impl Unigram {
return Err(Box::new(UnigramError::UnkIdNotInVocabulary));
}
}

let bos_id = n + 1;
let eos_id = n + 2;

@@ -130,6 +136,7 @@
fuse_unk,
cache: Cache::default(),
is_optimized,
byte_fallback,
})
}

@@ -143,7 +150,9 @@ impl Unigram {
pub(super) fn set_optimized(&mut self, is_optimized: bool) {
self.is_optimized = is_optimized;
}

pub fn byte_fallback(&self) -> bool {
self.byte_fallback
}
pub(super) fn len(&self) -> usize {
self.vocab.len()
}
@@ -205,7 +214,7 @@
/// ("abc".to_string(), 5.0),
/// ("abcd".to_string(), 10.0),
/// ];
/// let model = Unigram::from(pieces, Some(0)).unwrap();
/// let model = Unigram::from(pieces, Some(0), false).unwrap();
/// let result = model.encode("abcdacdxx").unwrap();
/// assert_eq!(result, vec!["abcd", "a", "cd", "xx"]);
/// ```
@@ -407,12 +416,31 @@ impl Model for Unigram {
let mut offset = 0;
let mut tokens = Vec::with_capacity(str_tokens.len());
for string in str_tokens {
let len = string.len();
let offsets = (offset, offset + len);
let id: u32 = match self.token_to_ids.get(&string) {
Some(id) => *id,
None => self.unk_id.ok_or(UnigramError::MissingUnkId)? as u32,
None => {
if self.byte_fallback {
let byte_tokens: Option<Vec<_>> = string
.bytes()
.map(|byte| -> Option<Token> {
let byte_string = format!("<0x{:02X}>", byte);
let id = self.token_to_ids.get(&byte_string);
id.map(|id| Token::new(*id, byte_string, (offset, offset + len)))
})
.collect();
if let Some(byte_tokens) = byte_tokens {
for token in byte_tokens {
tokens.push(token);
}
offset += len;
continue;
}
}
self.unk_id.ok_or(UnigramError::MissingUnkId)? as u32
}
};
let len = string.len();
let offsets = (offset, offset + len);
offset += len;
tokens.push(Token::new(id, string, offsets));
}
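A Python rendering of the fallback branch above, as a rough sketch; the names are illustrative, with `token_to_ids` standing in for the model's vocabulary map, and the all-or-nothing check mirroring the `Option<Vec<_>>` collect in the Rust code:

```python
def fallback_or_unk(string: str, token_to_ids: dict, unk_id: int) -> list:
    """Try one <0xXX> piece per UTF-8 byte of an unknown token; emit them
    only if every byte has a piece in the vocabulary, else a single unk."""
    byte_ids = [token_to_ids.get(f"<0x{b:02X}>") for b in string.encode("utf-8")]
    if all(i is not None for i in byte_ids):
        return byte_ids
    return [unk_id]
```

Note that every byte token inherits the offsets of the whole unknown string, `(offset, offset + len)`, which is why the Rust test further down asserts `(0, 2)` for both pieces of `é`.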
@@ -452,7 +480,7 @@ mod tests {
#[test]
fn test_populate_nodes_unk() {
let pieces = vec![("<unk>".to_string(), 0.0)];
let model = Unigram::from(pieces, Some(0)).unwrap();
let model = Unigram::from(pieces, Some(0), false).unwrap();

let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
model.populate_nodes(&mut lattice);
@@ -477,7 +505,7 @@
("ab".to_string(), 0.3),
("bc".to_string(), 0.4),
];
let model = Unigram::from(pieces, Some(0)).unwrap();
let model = Unigram::from(pieces, Some(0), false).unwrap();

let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
model.populate_nodes(&mut lattice);
@@ -514,7 +542,7 @@
("abcd".to_string(), 10.0),
];

let model = Unigram::from(sentencepieces, Some(0)).unwrap();
let model = Unigram::from(sentencepieces, Some(0), false).unwrap();
let result = model.encode("abcd").unwrap();
assert_eq!(result, vec!["abcd"]);
}
@@ -536,7 +564,7 @@
("qr".to_string(), -0.5),
];

let mut model = Unigram::from(sentencepieces, Some(0)).unwrap();
let mut model = Unigram::from(sentencepieces, Some(0), false).unwrap();

for is_optimized in &[true, false] {
model.set_optimized(*is_optimized);
@@ -573,4 +601,35 @@
assert_eq!(model.encode("abqrcd").unwrap(), vec!["ab", "q", "r", "cd"]);
}
}

#[test]
fn test_unigram_bytefallback() {
// In [97]: processor.encode_as_pieces("⅐⅛⅑ ")
// Out[97]: ['▁', '<0xE2>', '<0x85>', '<0x90>', '⅛', '<0xE2>', '<0x85>', '<0x91>', '▁']
let sentencepieces = vec![
("<unk>".to_string(), 0.0),
("<0xC3>".to_string(), -0.01),
("<0xA9>".to_string(), -0.03),
];
let unigram = Unigram::from(sentencepieces, Some(0), true).unwrap();
let tokens: Vec<Token> = unigram.tokenize("é").unwrap();
assert_eq!(
tokens,
[
Token {
id: 1,
value: "<0xC3>".to_string(),
offsets: (0, 2)
},
Token {
id: 2,
value: "<0xA9>".to_string(),
offsets: (0, 2)
}
]
);

let tokens = unigram.tokenize("?é").unwrap();
assert_eq!(tokens[0].id, 0);
}
}
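One detail of the final assertion worth spelling out, under the three-piece vocabulary defined in this test: `?` maps to id 0 not because byte fallback is off, but because its byte piece `<0x3F>` is absent, so the all-or-nothing check falls back to `<unk>`. A quick sketch:

```python
vocab = {"<unk>": 0, "<0xC3>": 1, "<0xA9>": 2}

for ch in ["?", "é"]:
    pieces = [f"<0x{b:02X}>" for b in ch.encode("utf-8")]
    covered = all(p in vocab for p in pieces)
    print(ch, pieces, "-> byte pieces" if covered else "-> <unk>")
# ? ['<0x3F>'] -> <unk>
# é ['<0xC3>', '<0xA9>'] -> byte pieces
```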