From b24357ad4a7d74610b1a3cada095e333fc70e153 Mon Sep 17 00:00:00 2001 From: Gabriele Sarti Date: Tue, 23 Jan 2024 00:13:10 +0100 Subject: [PATCH] Quickfix for byte-producing LM compatibility --- .../attribute_context_helpers.py | 1 + .../attribute_context_viz_helpers.py | 4 +-- inseq/models/huggingface_model.py | 26 ++++++++++++------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/inseq/commands/attribute_context/attribute_context_helpers.py b/inseq/commands/attribute_context/attribute_context_helpers.py index 7736c9f9..ce9d5ca4 100644 --- a/inseq/commands/attribute_context/attribute_context_helpers.py +++ b/inseq/commands/attribute_context/attribute_context_helpers.py @@ -289,6 +289,7 @@ def filter_rank_tokens( ) -> tuple[list[tuple[int, float, str]], float]: indices = list(range(0, len(scores))) token_score_tuples = sorted(zip(indices, scores, tokens), key=lambda x: abs(x[1]), reverse=True) + threshold = None if std_threshold: threshold = tensor(scores).mean() + std_threshold * tensor(scores).std() token_score_tuples = [(i, s, t) for i, s, t in token_score_tuples if abs(s) > threshold] diff --git a/inseq/commands/attribute_context/attribute_context_viz_helpers.py b/inseq/commands/attribute_context/attribute_context_viz_helpers.py index f225bd25..4f754e13 100644 --- a/inseq/commands/attribute_context/attribute_context_viz_helpers.py +++ b/inseq/commands/attribute_context/attribute_context_viz_helpers.py @@ -66,14 +66,14 @@ def format_context_comment( ) for idx, score, tok in context_ranked_tokens: context_tokens[idx] = f"[bold green]{tok}({score:.3f})[/bold green]" - cci_threshold_comment = f"(CCI > {threshold:.3f})" + cci_threshold_comment = f"(CCI > {threshold:.3f})" if threshold is not None else "" return f"\n[bold]{context_type} context {cci_threshold_comment}:[/bold]\t{''.join(context_tokens)}" out_string = "" output_current_tokens = get_filtered_tokens( output.output_current, model, args.special_tokens_to_keep, 
replace_special_characters=True, is_target=True ) - cti_theshold_comment = f"(CTI > {cti_threshold:.3f})" + cti_theshold_comment = f"(CTI > {cti_threshold:.3f})" if cti_threshold is not None else "" for example_idx, cci_out in enumerate(output.cci_scores, start=1): curr_output_tokens = output_current_tokens.copy() cti_idx = cci_out.cti_idx diff --git a/inseq/models/huggingface_model.py b/inseq/models/huggingface_model.py index 930507c9..51e2386c 100644 --- a/inseq/models/huggingface_model.py +++ b/inseq/models/huggingface_model.py @@ -113,12 +113,12 @@ def __init__( else: self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs) if self.model.config.pad_token_id is not None: - self.pad_token = self.tokenizer.convert_ids_to_tokens(self.model.config.pad_token_id) + self.pad_token = self._convert_ids_to_tokens(self.model.config.pad_token_id, skip_special_tokens=False) self.tokenizer.pad_token = self.pad_token self.bos_token_id = getattr(self.model.config, "decoder_start_token_id", None) if self.bos_token_id is None: self.bos_token_id = self.model.config.bos_token_id - self.bos_token = self.tokenizer.convert_ids_to_tokens(self.bos_token_id) + self.bos_token = self._convert_ids_to_tokens(self.bos_token_id, skip_special_tokens=False) self.eos_token_id = getattr(self.model.config, "eos_token_id", None) if self.eos_token_id is None: self.eos_token_id = self.tokenizer.pad_token_id @@ -277,7 +277,7 @@ def encode( baseline_ids = torch.cat((bos_ids, baseline_ids), dim=1) return BatchEncoding( input_ids=batch["input_ids"], - input_tokens=[self.tokenizer.convert_ids_to_tokens(x) for x in batch["input_ids"]], + input_tokens=[self._convert_ids_to_tokens(x, skip_special_tokens=False) for x in batch["input_ids"]], attention_mask=batch["attention_mask"], baseline_ids=baseline_ids, ) @@ -304,14 +304,20 @@ def embed_ids(self, ids: IdsTensor, as_targets: bool = False) -> EmbeddingsTenso embeddings = self.get_embedding_layer()(ids) return embeddings * self.embed_scale 
+ def _convert_ids_to_tokens(self, ids: IdsTensor, skip_special_tokens: bool = True) -> OneOrMoreTokenSequences: + tokens = self.tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) + if isinstance(tokens, bytes): + return tokens.decode("utf-8") + elif isinstance(tokens, list): + return [t.decode("utf-8") if isinstance(t, bytes) else t for t in tokens] + return tokens + def convert_ids_to_tokens( self, ids: IdsTensor, skip_special_tokens: Optional[bool] = True ) -> OneOrMoreTokenSequences: if ids.ndim < 2: - return self.tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) - return [ - self.tokenizer.convert_ids_to_tokens(id_slice, skip_special_tokens=skip_special_tokens) for id_slice in ids - ] + return self._convert_ids_to_tokens(ids, skip_special_tokens) + return [self._convert_ids_to_tokens(id_slice, skip_special_tokens) for id_slice in ids] def convert_tokens_to_ids(self, tokens: TextInput) -> OneOrMoreIdSequences: if isinstance(tokens[0], str): @@ -326,7 +332,7 @@ def convert_tokens_to_string( ) -> TextInput: if isinstance(tokens, list) and len(tokens) == 0: return "" - elif isinstance(tokens[0], str): + elif isinstance(tokens[0], (bytes, str)): tmp_decode_state = self.tokenizer._decode_use_source_tokenizer self.tokenizer._decode_use_source_tokenizer = not as_targets out_strings = self.tokenizer.convert_tokens_to_string( @@ -348,7 +354,7 @@ def convert_string_to_tokens( text_target=text if as_targets else None, add_special_tokens=not skip_special_tokens, )["input_ids"] - return self.tokenizer.convert_ids_to_tokens(ids, skip_special_tokens) + return self._convert_ids_to_tokens(ids, skip_special_tokens) return [self.convert_string_to_tokens(t, skip_special_tokens, as_targets) for t in text] def clean_tokens( @@ -372,7 +378,7 @@ def clean_tokens( """ if isinstance(tokens, list) and len(tokens) == 0: return [] - elif isinstance(tokens[0], str): + elif isinstance(tokens[0], (bytes, 
str)): clean_tokens = [] for tok in tokens: clean_tok = self.convert_tokens_to_string(