Skip to content

Commit

Permalink
Store tokens in u32 instead of usize
Browse files Browse the repository at this point in the history
Based on upstream commit openai@c2960c1.

cc openai#251
  • Loading branch information
tmm1 committed Oct 17, 2024
1 parent 241dee1 commit f40333c
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 67 deletions.
2 changes: 1 addition & 1 deletion build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ fn main() {
fn generate(
name: &str,
file: &mut File,
mergeable_ranks: &HashMap<Vec<u8>, usize>,
mergeable_ranks: &HashMap<Vec<u8>, Rank>,
) {
writeln!(
file,
Expand Down
72 changes: 37 additions & 35 deletions src/corebpe.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
use std::num::NonZeroU64;
use std::thread;

use fancy_regex::Regex;
use rustc_hash::FxHashMap as HashMap;
use rustc_hash::FxHashSet as HashSet;
use std::sync::Arc;

pub type Rank = u32;

fn _byte_pair_merge<T>(
piece: &[u8],
ranks: &HashMap<Vec<u8>, usize>,
ranks: &HashMap<Vec<u8>, Rank>,
f: impl Fn(std::ops::Range<usize>) -> T,
) -> Vec<T> {
// This is a vector of (start, rank).
// The rank is of the byte pair starting at position start.
// The rank of the last item in the vector is not a valid value.
let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect();
let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();

let get_rank = {
#[inline(always)]
|parts: &Vec<(usize, usize)>, start_idx: usize, skip: usize| {
|parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
if (start_idx + skip + 2) < parts.len() {
ranks
.get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
Expand All @@ -33,8 +36,8 @@ fn _byte_pair_merge<T>(
for i in 0..parts.len() - 2 {
match get_rank(&parts, i, 0) {
Some(rank) => {
// usize::MAX is a sentinel value and cannot be a valid rank
debug_assert!(rank != usize::MAX);
// Rank::MAX is a sentinel value and cannot be a valid rank
debug_assert!(rank != Rank::MAX);
parts[i].1 = rank;
}
None => {
Expand All @@ -57,26 +60,26 @@ fn _byte_pair_merge<T>(
break;
}

// usize::MAX is a sentinel rank value allowing us to
// Rank::MAX is a sentinel rank value allowing us to
// take the min more quickly
let mut min_rank: (usize, usize) = (usize::MAX, 0);
let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
if rank < min_rank.0 {
min_rank = (rank, i);
}
}

if min_rank.0 != usize::MAX {
if min_rank.0 != Rank::MAX {
let i = min_rank.1;

// NOTE: We are about to remove parts[i + 1]. We do not do it
// yet because there are cache-locality benefits to updating
// parts[i] and parts[i-1] before removing, which could thrash
// the cache. Thus, we update the rank calculation by skipping over
// parts[i + 1], by invoking `get_rank!` with `skip = 1`.
parts[i].1 = get_rank(&parts, i, 1).unwrap_or(usize::MAX);
parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
if i > 0 {
parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(usize::MAX);
parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
}

parts.remove(i + 1);
Expand All @@ -91,14 +94,14 @@ fn _byte_pair_merge<T>(
out
}

pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
/// Encode `piece` with byte-pair merging, returning the rank of each
/// resulting token.
///
/// A one-byte piece is always a token on its own, so it is looked up
/// directly; longer pieces go through the full merge loop in
/// `_byte_pair_merge`. Like the direct indexing below, this panics if a
/// produced segment is missing from `ranks`.
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
    match piece {
        // Fast path: a single byte maps straight to its rank.
        [_] => vec![ranks[piece]],
        // General case: run the merge loop and map each final segment
        // back to its rank.
        _ => _byte_pair_merge(piece, ranks, |seg| ranks[&piece[seg.start..seg.end]]),
    }
}

pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
if piece.len() == 1 {
return vec![piece];
}
Expand Down Expand Up @@ -146,7 +149,6 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) ->
// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.

use std::num::NonZeroU64;
pub struct FakeThreadId(NonZeroU64);

fn hash_current_thread() -> usize {
Expand All @@ -166,10 +168,10 @@ const MAX_NUM_THREADS: usize = 8;

#[derive(Debug)]
pub struct CoreBPE {
pub encoder: HashMap<Vec<u8>, usize>,
special_tokens_encoder: HashMap<String, usize>,
decoder: HashMap<usize, &'static [u8]>,
special_tokens_decoder: HashMap<usize, Vec<u8>>,
pub encoder: HashMap<Vec<u8>, Rank>,
special_tokens_encoder: HashMap<String, Rank>,
decoder: HashMap<Rank, &'static [u8]>,
special_tokens_decoder: HashMap<Rank, Vec<u8>>,
regex_tls: Arc<[Regex]>,
special_regex_tls: Arc<[Regex]>,
sorted_token_bytes: Vec<&'static [u8]>,
Expand All @@ -187,7 +189,7 @@ impl CoreBPE {
&self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
}

fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
fn _decode_native(&self, tokens: &[Rank]) -> Vec<u8> {
let mut ret = Vec::with_capacity(tokens.len() * 2);
for token in tokens {
let token_bytes = self
Expand All @@ -200,7 +202,7 @@ impl CoreBPE {
ret
}

fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
// This is the core of the encoding logic; the other functions in here
// just make things complicated :-)
let regex = self._get_tl_regex();
Expand All @@ -216,7 +218,7 @@ impl CoreBPE {
ret
}

fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) {
fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
let special_regex = self._get_tl_special_regex();
let regex = self._get_tl_regex();
let mut ret = vec![];
Expand Down Expand Up @@ -274,9 +276,9 @@ impl CoreBPE {

fn _increase_last_piece_token_len(
&self,
tokens: Vec<usize>,
tokens: Vec<Rank>,
mut last_piece_token_len: usize,
) -> (Vec<usize>, usize) {
) -> (Vec<Rank>, usize) {
// Unfortunately, the locations where our regex splits can be unstable.
// For the purposes of determining unstable tokens, unstable regex splitting
// is only a problem if a split that was present disappears, since this can
Expand Down Expand Up @@ -315,7 +317,7 @@ impl CoreBPE {
&self,
text: &str,
allowed_special: &HashSet<&str>,
) -> (Vec<usize>, HashSet<Vec<usize>>) {
) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
if last_piece_token_len == 0 {
// If last_piece_token_len is zero, the last token was a special token and we have
Expand Down Expand Up @@ -430,8 +432,8 @@ impl CoreBPE {

impl CoreBPE {
pub fn new(
encoder: HashMap<Vec<u8>, usize>,
special_tokens_encoder: HashMap<String, usize>,
encoder: HashMap<Vec<u8>, Rank>,
special_tokens_encoder: HashMap<String, Rank>,
pattern: &str,
) -> Result<Self, fancy_regex::Error> {
let regex = Regex::new(pattern)?;
Expand All @@ -445,7 +447,7 @@ impl CoreBPE {
};

// Use unsafe to extend the lifetime of references to the encoder's keys
let decoder: HashMap<usize, &'static [u8]> = encoder
let decoder: HashMap<Rank, &'static [u8]> = encoder
.iter()
.map(|(k, v)| {
let bytes: &[u8] = k.as_slice();
Expand All @@ -459,7 +461,7 @@ impl CoreBPE {
"Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
);

let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
.iter()
.map(|(k, v)| (*v, k.as_bytes().to_vec()))
.collect();
Expand Down Expand Up @@ -497,15 +499,15 @@ impl CoreBPE {
// Encoding
// ====================

pub fn encode_ordinary(&self, text: &str) -> Vec<usize> {
/// Encode `text` into token ranks with no special-token handling.
///
/// Thin public wrapper over `_encode_ordinary_native`, which performs
/// the regex split and per-piece BPE encoding.
pub fn encode_ordinary(&self, text: &str) -> Vec<Rank> {
self._encode_ordinary_native(text)
}

pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> Vec<usize> {
/// Encode `text` into token ranks, recognizing the special tokens
/// listed in `allowed_special` where they occur.
///
/// Returns only the token sequence; the second element of
/// `_encode_native`'s result (the trailing piece's token length,
/// used by the unstable-token logic) is discarded.
pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> Vec<Rank> {
    // `allowed_special` is already a reference; the previous
    // `&allowed_special` created a needless `&&HashSet` that auto-deref
    // had to strip (clippy::needless_borrow).
    self._encode_native(text, allowed_special).0
}

pub fn _encode_bytes(&self, bytes: &[u8]) -> Vec<usize> {
pub fn _encode_bytes(&self, bytes: &[u8]) -> Vec<Rank> {
match std::str::from_utf8(bytes) {
Ok(text) => self._encode_ordinary_native(text),
Err(e) => {
Expand Down Expand Up @@ -534,11 +536,11 @@ impl CoreBPE {
&self,
text: &str,
allowed_special: &HashSet<&str>,
) -> (Vec<usize>, HashSet<Vec<usize>>) {
) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
self._encode_unstable_native(text, &allowed_special)
}

pub fn encode_single_token(&self, piece: &[u8]) -> Result<usize, Vec<u8>> {
pub fn encode_single_token(&self, piece: &[u8]) -> Result<Rank, Vec<u8>> {
if let Some(token) = self.encoder.get(piece).copied() {
return Ok(token);
}
Expand All @@ -550,7 +552,7 @@ impl CoreBPE {
Err(piece.to_owned())
}

pub fn encode_single_piece(&self, piece: &[u8]) -> Vec<usize> {
pub fn encode_single_piece(&self, piece: &[u8]) -> Vec<Rank> {
if let Some(token) = self.encoder.get(piece) {
return vec![*token];
}
Expand All @@ -561,11 +563,11 @@ impl CoreBPE {
// Decoding
// ====================

pub fn decode_bytes(&self, tokens: &[usize]) -> Vec<u8> {
/// Decode a sequence of token ranks back into the concatenated raw
/// bytes of their pieces.
///
/// NOTE(review): delegates to `_decode_native`, which looks each token
/// up in the decoder tables — confirm its behavior on tokens absent
/// from both tables before relying on it for untrusted input.
pub fn decode_bytes(&self, tokens: &[Rank]) -> Vec<u8> {
    // `tokens` is already a `&[Rank]`; the previous `&tokens` passed a
    // needless double-reference (clippy::needless_borrow).
    self._decode_native(tokens)
}

pub fn decode_single_token_bytes(&self, token: usize) -> Result<Vec<u8>, usize> {
pub fn decode_single_token_bytes(&self, token: Rank) -> Result<Vec<u8>, Rank> {
if let Some(bytes) = self.decoder.get(&token) {
return Ok(bytes.to_vec());
}
Expand Down
Loading

0 comments on commit f40333c

Please sign in to comment.