-
Notifications
You must be signed in to change notification settings - Fork 856
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimize byte pair merge for really big tokens (40x faster for a 2500 token word) #239
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
// This check is new and seems buggy (possibly with PyO3 interaction) | ||
#![allow(clippy::borrow_deref_ref)] | ||
|
||
use std::collections::HashSet; | ||
use std::collections::{BTreeMap, BTreeSet, HashSet}; | ||
use std::iter::successors; | ||
use std::num::NonZeroU64; | ||
use std::thread; | ||
|
||
|
@@ -15,7 +16,17 @@ use rustc_hash::FxHashMap as HashMap; | |
|
||
type Rank = u32; | ||
|
||
const LARGE_ENCODER_CHARACTER_LIMIT: usize = 500; | ||
|
||
fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> { | ||
if piece.len() < LARGE_ENCODER_CHARACTER_LIMIT { | ||
_byte_pair_merge_small(ranks, piece) // Quadratic, but lightweight | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. quadtratic algo is usually faster for very small words - which is always the case for natural language, but e.g. DNA sequences or a DOS attack can be avoided by switching to the linearithmic algo |
||
} else { | ||
_byte_pair_merge_large(ranks, piece) // Linearithmic, but heavy | ||
} | ||
} | ||
|
||
fn _byte_pair_merge_small(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> { | ||
// This is a vector of (start, rank). | ||
// The rank is of the pair starting at position start. | ||
let mut parts = Vec::with_capacity(piece.len() + 1); | ||
|
@@ -73,6 +84,78 @@ fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, | |
parts | ||
} | ||
|
||
fn _byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> { | ||
let mut rank_indexes = BTreeMap::<Rank, BTreeSet<usize>>::new(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. grouped by rank, the values ordered by index (basically a LinkedHashSet inside) |
||
let mut index_rank = vec![Rank::MAX; piece.len() + 1]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mutations seemed easier this way, compared to creating a struct with index/rank/prev/next - especially in Rust |
||
let mut index_prev = vec![usize::MAX; piece.len() + 1]; | ||
let mut index_next = vec![usize::MAX; piece.len() + 1]; | ||
|
||
let get_rank = |start_idx: usize, end_idx: usize| -> Rank { | ||
*piece.get(start_idx..end_idx) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. when |
||
.and_then(|p| ranks.get(p)) | ||
.unwrap_or(&Rank::MAX) | ||
}; | ||
|
||
let mut prev_node = None; | ||
for i in 0..=piece.len() { | ||
let rank = get_rank(i, i + 2); | ||
index_rank[i] = rank; | ||
if let Some(prev) = prev_node { | ||
index_prev[i] = prev; | ||
index_next[prev] = i; | ||
} | ||
prev_node = Some(i); | ||
|
||
rank_indexes.entry(rank).or_default().insert(i); | ||
} | ||
|
||
while rank_indexes.len() > 1 { | ||
let mut skip_next = false; | ||
if let Some((_, nodes)) = rank_indexes.pop_first() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. next min is popped off in logarithmic time instead of linearly |
||
for &min_node in &nodes { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. duplicates are processed in bulk (since the next min is strictly greater than equal), no need to remove them one-by-one |
||
if skip_next { | ||
skip_next = false; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. when merging neighboring elements with the same ranks |
||
continue; | ||
} | ||
|
||
let min_rank = index_rank[min_node]; | ||
|
||
let prev_node = index_prev[min_node]; | ||
let next_node = index_next[min_node]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. getting next and previous requires lookups now |
||
let next_next_node = index_next[next_node]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @l0rinc I think these lines would panic (out-of-range) if your min_node is close to an end, how do you know you have 3 nodes to the right of it? Your last node's next will be a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nevermind, I'm misreading this. The only one that can be "none" is the next_next_next_node; but then it's usize::MAX and the |
||
let next_next_next_node = index_next[next_next_node]; | ||
Comment on lines
+123
to
+126
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
if prev_node != usize::MAX { | ||
let new_rank = get_rank(prev_node, next_next_node); | ||
if index_rank[prev_node] != new_rank { | ||
rank_indexes.get_mut(&index_rank[prev_node]).unwrap().remove(&prev_node); | ||
index_rank[prev_node] = new_rank; | ||
rank_indexes.entry(new_rank).or_default().insert(prev_node); | ||
} | ||
} | ||
|
||
let new_rank = get_rank(min_node, next_next_next_node); | ||
index_rank[min_node] = new_rank; | ||
rank_indexes.entry(new_rank).or_default().insert(min_node); | ||
|
||
index_next[min_node] = next_next_node; | ||
index_prev[next_next_node] = min_node; | ||
|
||
let next_node_rank = index_rank[next_node]; | ||
if next_node_rank == min_rank { | ||
skip_next = true; | ||
} else if next_node_rank != Rank::MAX { | ||
rank_indexes.get_mut(&next_node_rank).unwrap().remove(&next_node); | ||
} | ||
} | ||
} | ||
} | ||
|
||
successors(Some(0), |&n| index_next.get(n).filter(|&&x| x != usize::MAX).copied()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. iterate until there's a valid rank |
||
.map(|n| (n, Rank::MAX)) | ||
.collect() | ||
} | ||
|
||
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> { | ||
assert!(piece.len() > 1); | ||
_byte_pair_merge(&ranks, &piece) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,13 +61,16 @@ def test_simple_regex(): | |
def test_basic_encode(): | ||
enc = tiktoken.get_encoding("r50k_base") | ||
assert enc.encode("hello world") == [31373, 995] | ||
assert enc.encode("a" * 1000) == [24794] * 250 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. to cover the big encoder as well |
||
|
||
enc = tiktoken.get_encoding("p50k_base") | ||
assert enc.encode("hello world") == [31373, 995] | ||
assert enc.encode("a" * 1000) == [24794] * 250 | ||
|
||
enc = tiktoken.get_encoding("cl100k_base") | ||
assert enc.encode("hello world") == [15339, 1917] | ||
assert enc.encode(" \x850") == [220, 126, 227, 15] | ||
assert enc.encode("a" * 1000) == [70540] * 125 | ||
|
||
|
||
def test_encode_empty(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in the Java version this was controlled by an environmental variable, which enabled us to run all tests against both implementations - should I do it here as well?