diff --git a/examples/stable-diffusion/main.rs b/examples/stable-diffusion/main.rs
index 2b843273..b9abd393 100644
--- a/examples/stable-diffusion/main.rs
+++ b/examples/stable-diffusion/main.rs
@@ -39,7 +39,7 @@
 ///
 // TODO: fix tensor_tools so that it works properly there.
 // TODO: Split this file, probably in a way similar to huggingface/diffusers.
-use std::collections::{HashMap, HashSet};
+use std::collections::HashMap;
 use std::io::BufRead;
 use tch::{kind, nn, nn::Module, Device, Kind, Tensor};
 
@@ -319,13 +319,15 @@ const BYTES_TO_UNICODE: [(u8, char); 256] = [
 
 const PAT: &str = r"<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+";
 
+type Str = Box<str>;
+
 // This is mostly a Rust rewrite of the original Python CLIP code.
 // https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
 struct Tokenizer {
     re: regex::Regex,
-    encoder: HashMap<String, usize>,
-    decoder: HashMap<usize, String>,
-    bpe_ranks: HashMap<(String, String), usize>,
+    encoder: HashMap<Str, usize>,
+    decoder: HashMap<usize, Str>,
+    bpe_ranks: HashMap<(Str, Str), usize>,
     start_of_text_token: usize,
     end_of_text_token: usize,
 }
@@ -335,31 +337,31 @@ impl Tokenizer {
         let bpe_file = std::fs::File::open(bpe_path)?;
         let bpe_lines: Result<Vec<String>, _> = std::io::BufReader::new(bpe_file).lines().collect();
         let bpe_lines = bpe_lines?;
-        let bpe_lines: Result<Vec<(String, String)>, _> = bpe_lines[1..49152 - 256 - 2 + 1]
+        let bpe_lines: Result<Vec<(Str, Str)>, _> = bpe_lines[1..49152 - 256 - 2 + 1]
             .iter()
             .map(|line| {
                 let vs: Vec<_> = line.split_whitespace().collect();
                 if vs.len() != 2 {
                     anyhow::bail!("expected two items got {} '{}'", vs.len(), line)
                 }
-                Ok((vs[0].to_string(), vs[1].to_string()))
+                Ok((Box::from(vs[0]), Box::from(vs[1])))
             })
             .collect();
         let bpe_lines = bpe_lines?;
-        let mut vocab: Vec<String> = Vec::new();
+        let mut vocab: Vec<Str> = Vec::new();
         for (_index, elem) in BYTES_TO_UNICODE {
-            vocab.push(elem.into())
+            vocab.push(String::from(elem).into())
         }
         for (_index, elem) in BYTES_TO_UNICODE {
-            vocab.push(format!("{}</w>", elem));
+            vocab.push(format!("{}</w>", elem).into());
         }
         for elem in bpe_lines.iter() {
-            vocab.push(format!("{}{}", elem.0, elem.1))
+            vocab.push(format!("{}{}", elem.0, elem.1).into())
         }
         let start_of_text_token = vocab.len();
-        vocab.push("<|startoftext|>".to_string());
+        vocab.push("<|startoftext|>".into());
         let end_of_text_token = vocab.len();
-        vocab.push("<|endoftext|>".to_string());
+        vocab.push("<|endoftext|>".into());
         let encoder: HashMap<_, _> = vocab.into_iter().enumerate().map(|(i, v)| (v, i)).collect();
         let decoder: HashMap<_, _> = encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
         let bpe_ranks: HashMap<_, _> =
@@ -370,28 +372,22 @@ impl Tokenizer {
         Ok(tokenizer)
     }
 
-    fn get_pairs(word: &[String]) -> HashSet<(String, String)> {
-        let mut pairs = HashSet::new();
-        for (i, v) in word.iter().enumerate() {
-            if i > 0 {
-                pairs.insert((word[i - 1].clone(), v.clone()));
-            }
-        }
-        pairs
-    }
-
     fn bpe(&self, token: &str) -> Vec<usize> {
-        let mut word: Vec<String> = token.chars().map(|x| x.to_string()).collect();
+        let mut word: Vec<Str> = token.chars().map(|c| c.to_string().into()).collect();
         if word.is_empty() {
             return Vec::new();
         }
-        let last_index = word.len() - 1;
-        word[last_index] = format!("{}</w>", word[last_index]);
+        let mut last_word = word.pop().unwrap().into_string();
+        last_word.push_str("</w>");
+        word.push(last_word.into());
+
         while word.len() > 1 {
             let mut current_min = None;
-            let pairs = Self::get_pairs(&word);
-            for p in pairs.iter() {
-                match self.bpe_ranks.get(p) {
+            for w in word.windows(2) {
+                let (a, b) = (&w[0], &w[1]);
+                let p = (a, b);
+                // this clone is a little sad, but Borrow is too inflexible currently
+                match self.bpe_ranks.get(&(a.clone(), b.clone())) {
                     None => {}
                     Some(v) => {
                         let should_replace = match current_min {
@@ -404,16 +400,17 @@
                     }
                 }
             }
-            let (first, second) = match current_min {
-                None => break,
-                Some((_v, (first, second))) => (first, second),
+            let (first, second) = if let Some((_index, (first, second))) = current_min {
+                (first, second)
+            } else {
+                break;
             };
-            let mut new_word = vec![];
+            let mut new_word: Vec<Str> = vec![];
             let mut index = 0;
             while index < word.len() {
                 let w = &word[index];
-                if index + 1 < word.len() && w == first && &word[index + 1] == second {
-                    new_word.push(format!("{}{}", first, second));
+                if index + 1 < word.len() && w == first && word[index + 1] == *second {
+                    new_word.push(format!("{}{}", first, second).into());
                     index += 2
                 } else {
                     new_word.push(w.clone());
@@ -447,9 +444,9 @@ impl Tokenizer {
         Ok(bpe_tokens)
     }
 
-    fn decode(&self, tokens: &[usize]) -> String {
-        let s: String = tokens.iter().map(|token| self.decoder[token].as_str()).collect();
-        s.replace("</w>", " ")
+    fn decode(&self, tokens: &[usize]) -> Str {
+        let s: String = tokens.iter().map(|token| &*self.decoder[token]).collect();
+        s.replace("</w>", " ").into()
     }
 }
 
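
Note on the "sad clone" in the pair lookup above: `HashMap::get` accepts any `&Q` where the key type implements `Borrow<Q>`, but there is no impl that lets a `(Box<str>, Box<str>)` key be looked up through a `(&str, &str)` pair, so the lookup has to build an owned tuple. A minimal standalone sketch of the limitation, illustration only and not part of the patch:

use std::collections::HashMap;

fn main() {
    // Same key shape as the tokenizer's bpe_ranks map.
    let mut ranks: HashMap<(Box<str>, Box<str>), usize> = HashMap::new();
    ranks.insert(("h".into(), "e".into()), 0);

    let (a, b): (Box<str>, Box<str>) = ("h".into(), "e".into());

    // `get` requires `(Box<str>, Box<str>): Borrow<Q>`; no such impl
    // exists for `(&str, &str)`, so this line does not type-check:
    // ranks.get(&(&*a, &*b));

    // The workaround used in the patch: clone both halves into an
    // owned key for the lookup.
    assert_eq!(ranks.get(&(a.clone(), b.clone())), Some(&0));
}

Avoiding the clone would need something like hashbrown's raw-entry API (hashing the borrowed pair manually), which seems out of proportion for this example.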