From f2a44dc5d1d77ef358820e2ccf822428efc67e30 Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Fri, 12 Jul 2024 07:29:40 +0200
Subject: [PATCH] =?UTF-8?q?Revert=20"[BREAKING=20CHANGE]=20Ignore=20added?=
 =?UTF-8?q?=5Ftokens=20(both=20special=20and=20not)=20=E2=80=A6=20(#1569)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Revert "[BREAKING CHANGE] Ignore added_tokens (both special and not) in the decoder (#1513)"

This reverts commit 25aee8b88c8de3c5a52e2f9cb6281d6df00ad516.

* don't remove audit

* deprecate id_to_token

* use simple id to token

* don't break id_to_token since we are deprecating anyways?
---
 tokenizers/src/tokenizer/mod.rs | 44 ++++++++++++---------------------
 1 file changed, 16 insertions(+), 28 deletions(-)

diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index ebc68dfb1..b0836ca3c 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -847,35 +847,23 @@ where

     /// Decode the given ids, back to a String
     pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
-        let mut result = String::with_capacity(ids.len());
-        let mut chunks = Vec::with_capacity(ids.len());
-        for id in ids {
-            if let Some(added_token) = self.added_vocabulary.simple_id_to_token(*id) {
-                if skip_special_tokens && self.added_vocabulary.is_special_token(&added_token) {
-                    continue;
-                }
-                let text_chunk = if let Some(decoder) = &self.decoder {
-                    decoder.decode(chunks.clone())?
-                } else {
-                    chunks.join(" ")
-                };
-                result.push_str(&text_chunk);
-                if !result.is_empty() && self.decoder.is_none() {
-                    result.push(' ');
-                }
-                result.push_str(&added_token);
-                chunks.clear();
-            } else if let Some(token) = self.model.id_to_token(*id) {
-                chunks.push(token);
-            }
-        }
-        let text_chunk = if let Some(decoder) = &self.decoder {
-            decoder.decode(chunks.clone())?
+        let tokens = ids
+            .iter()
+            .filter_map(|id| {
+                self.added_vocabulary
+                    .simple_id_to_token(*id)
+                    .or_else(|| self.model.id_to_token(*id))
+                    .filter(|token| {
+                        !skip_special_tokens || !self.added_vocabulary.is_special_token(token)
+                    })
+            })
+            .collect::<Vec<_>>();
+
+        if let Some(decoder) = &self.decoder {
+            decoder.decode(tokens)
         } else {
-            chunks.join(" ")
-        };
-        result.push_str(&text_chunk);
-        Ok(result)
+            Ok(tokens.join(" "))
+        }
     }
 }

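For readers skimming the diff: the rewrite collapses the old imperative loop, which buffered model tokens in `chunks` and flushed them through the decoder whenever an added token appeared, into a single filter_map pass that looks each id up in the added vocabulary first, falls back to the model, drops special tokens when requested, and hands the whole token list to the decoder at once. Below is a minimal standalone sketch of that new flow, using plain HashMap/HashSet stand-ins for the vocabularies; the names (decode_sketch, added, model, special) are hypothetical and this is not the tokenizers API.

// Standalone sketch (hypothetical names, not the tokenizers API): mirrors
// the new decode control flow with map/set stand-ins for the vocabularies.
use std::collections::{HashMap, HashSet};

fn decode_sketch(
    ids: &[u32],
    added: &HashMap<u32, String>, // stand-in for added_vocabulary
    model: &HashMap<u32, String>, // stand-in for the model vocab
    special: &HashSet<String>,    // stand-in for is_special_token
    skip_special_tokens: bool,
) -> String {
    let tokens: Vec<String> = ids
        .iter()
        .filter_map(|id| {
            // The added vocabulary wins over the model, like
            // simple_id_to_token(..).or_else(|| model.id_to_token(..)).
            added
                .get(id)
                .or_else(|| model.get(id))
                .filter(|token| !skip_special_tokens || !special.contains(*token))
                .cloned()
        })
        .collect();
    // This sketch has no Decoder, so it takes the fallback branch:
    // join the surviving tokens with spaces.
    tokens.join(" ")
}

fn main() {
    let added = HashMap::from([(0, "<s>".to_string())]);
    let model = HashMap::from([(1, "hello".to_string()), (2, "world".to_string())]);
    let special = HashSet::from(["<s>".to_string()]);
    // With skip_special_tokens = true, "<s>" is filtered out.
    assert_eq!(decode_sketch(&[0, 1, 2], &added, &model, &special, true), "hello world");
    assert_eq!(decode_sketch(&[0, 1, 2], &added, &model, &special, false), "<s> hello world");
}

One behavioral consequence visible in the diff: added tokens now travel through the decoder together with regular model tokens, restoring the pre-#1513 behavior that the reverted commit had changed by splicing added tokens around separately decoded chunks.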