
Quickfix for byte-producing LM compatibility
gsarti committed Jan 22, 2024
1 parent 09c8481 commit b24357a
Showing 3 changed files with 19 additions and 12 deletions.
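The failure mode motivating the fix, in a minimal standalone sketch (the token values are illustrative, not taken from the commit): tokenizers of byte-producing LMs can return raw bytes instead of str from convert_ids_to_tokens, which breaks downstream string operations such as joining tokens.

    # Illustrative only: a byte-producing LM's tokenizer may yield bytes entries.
    tokens = [b"Hel", b"lo", " world"]
    # "".join(tokens) would raise TypeError: sequence item 0: expected str instance, bytes found
    fixed = [t.decode("utf-8") if isinstance(t, bytes) else t for t in tokens]
    print("".join(fixed))  # -> "Hello world"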
@@ -289,6 +289,7 @@ def filter_rank_tokens(
 ) -> tuple[list[tuple[int, float, str]], float]:
     indices = list(range(0, len(scores)))
     token_score_tuples = sorted(zip(indices, scores, tokens), key=lambda x: abs(x[1]), reverse=True)
+    threshold = None
     if std_threshold:
         threshold = tensor(scores).mean() + std_threshold * tensor(scores).std()
         token_score_tuples = [(i, s, t) for i, s, t in token_score_tuples if abs(s) > threshold]
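The added threshold = None matters because threshold was otherwise never bound when std_threshold was falsy, so any later reference to it (the signature promises a float as the second return value; the return statement itself is not shown in the hunk) raised UnboundLocalError. A minimal sketch of the pattern, with a hypothetical function name:

    from torch import tensor

    def filter_by_std(scores, std_threshold=None):
        threshold = None  # the quickfix: always bound, even without a std_threshold
        kept = list(enumerate(scores))
        if std_threshold:
            threshold = (tensor(scores).mean() + std_threshold * tensor(scores).std()).item()
            kept = [(i, s) for i, s in kept if abs(s) > threshold]
        return kept, threshold  # previously: UnboundLocalError when std_threshold was unset

    print(filter_by_std([0.1, -0.9, 0.2], std_threshold=1.0))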
@@ -66,14 +66,14 @@ def format_context_comment(
         )
         for idx, score, tok in context_ranked_tokens:
             context_tokens[idx] = f"[bold green]{tok}({score:.3f})[/bold green]"
-        cci_threshold_comment = f"(CCI > {threshold:.3f})"
+        cci_threshold_comment = f"(CCI > {threshold:.3f})" if threshold is not None else ""
         return f"\n[bold]{context_type} context {cci_threshold_comment}:[/bold]\t{''.join(context_tokens)}"
 
     out_string = ""
     output_current_tokens = get_filtered_tokens(
         output.output_current, model, args.special_tokens_to_keep, replace_special_characters=True, is_target=True
     )
-    cti_theshold_comment = f"(CTI > {cti_threshold:.3f})"
+    cti_theshold_comment = f"(CTI > {cti_threshold:.3f})" if cti_threshold is not None else ""
     for example_idx, cci_out in enumerate(output.cci_scores, start=1):
         curr_output_tokens = output_current_tokens.copy()
         cti_idx = cci_out.cti_idx
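The if ... is not None guards on both comment strings are needed because a float format spec cannot be applied to None:

    threshold = None
    # f"(CCI > {threshold:.3f})" would raise:
    # TypeError: unsupported format string passed to NoneType.__format__
    comment = f"(CCI > {threshold:.3f})" if threshold is not None else ""
    print(repr(comment))  # -> ''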
26 changes: 16 additions & 10 deletions inseq/models/huggingface_model.py
@@ -113,12 +113,12 @@ def __init__(
         else:
             self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs)
         if self.model.config.pad_token_id is not None:
-            self.pad_token = self.tokenizer.convert_ids_to_tokens(self.model.config.pad_token_id)
+            self.pad_token = self._convert_ids_to_tokens(self.model.config.pad_token_id, skip_special_tokens=False)
             self.tokenizer.pad_token = self.pad_token
         self.bos_token_id = getattr(self.model.config, "decoder_start_token_id", None)
         if self.bos_token_id is None:
             self.bos_token_id = self.model.config.bos_token_id
-        self.bos_token = self.tokenizer.convert_ids_to_tokens(self.bos_token_id)
+        self.bos_token = self._convert_ids_to_tokens(self.bos_token_id, skip_special_tokens=False)
         self.eos_token_id = getattr(self.model.config, "eos_token_id", None)
         if self.eos_token_id is None:
             self.eos_token_id = self.tokenizer.pad_token_id
@@ -277,7 +277,7 @@ def encode(
             baseline_ids = torch.cat((bos_ids, baseline_ids), dim=1)
         return BatchEncoding(
             input_ids=batch["input_ids"],
-            input_tokens=[self.tokenizer.convert_ids_to_tokens(x) for x in batch["input_ids"]],
+            input_tokens=[self._convert_ids_to_tokens(x, skip_special_tokens=False) for x in batch["input_ids"]],
             attention_mask=batch["attention_mask"],
             baseline_ids=baseline_ids,
         )
@@ -304,14 +304,20 @@ def embed_ids(self, ids: IdsTensor, as_targets: bool = False) -> EmbeddingsTensor:
         embeddings = self.get_embedding_layer()(ids)
         return embeddings * self.embed_scale
 
+    def _convert_ids_to_tokens(self, ids: IdsTensor, skip_special_tokens: bool = True) -> OneOrMoreTokenSequences:
+        tokens = self.tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
+        if isinstance(tokens, bytes) and not isinstance(tokens, str):
+            return tokens.decode("utf-8")
+        elif isinstance(tokens, list):
+            return [t.decode("utf-8") if isinstance(t, bytes) else t for t in tokens]
+        return tokens
+
     def convert_ids_to_tokens(
         self, ids: IdsTensor, skip_special_tokens: Optional[bool] = True
     ) -> OneOrMoreTokenSequences:
         if ids.ndim < 2:
-            return self.tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
-        return [
-            self.tokenizer.convert_ids_to_tokens(id_slice, skip_special_tokens=skip_special_tokens) for id_slice in ids
-        ]
+            return self._convert_ids_to_tokens(ids, skip_special_tokens)
+        return [self._convert_ids_to_tokens(id_slice, skip_special_tokens) for id_slice in ids]
 
     def convert_tokens_to_ids(self, tokens: TextInput) -> OneOrMoreIdSequences:
         if isinstance(tokens[0], str):
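The new private helper funnels every id-to-token lookup through one place that decodes stray bytes. Its normalization logic, extracted into a standalone sketch (the tokenizer call is omitted; normalize_tokens is a hypothetical name):

    def normalize_tokens(tokens):
        # A single id may come back as one bytes token, a sequence as a mixed list.
        if isinstance(tokens, bytes):
            return tokens.decode("utf-8")
        elif isinstance(tokens, list):
            return [t.decode("utf-8") if isinstance(t, bytes) else t for t in tokens]
        return tokens

    print(normalize_tokens(b"<pad>"))        # -> '<pad>'
    print(normalize_tokens([b"Hel", "lo"]))  # -> ['Hel', 'lo']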
@@ -326,7 +332,7 @@ def convert_tokens_to_string(
     ) -> TextInput:
         if isinstance(tokens, list) and len(tokens) == 0:
             return ""
-        elif isinstance(tokens[0], str):
+        elif isinstance(tokens[0], (bytes, str)):
             tmp_decode_state = self.tokenizer._decode_use_source_tokenizer
             self.tokenizer._decode_use_source_tokenizer = not as_targets
             out_strings = self.tokenizer.convert_tokens_to_string(
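Widening the guard from str to (bytes, str) works because isinstance accepts a tuple of types, so byte tokens now take the same code path as string tokens:

    for tok in ("Hello", b"Hel"):
        print(isinstance(tok, (bytes, str)))  # True for both; bytes failed the old check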
@@ -348,7 +354,7 @@ def convert_string_to_tokens(
                 text_target=text if as_targets else None,
                 add_special_tokens=not skip_special_tokens,
             )["input_ids"]
-            return self.tokenizer.convert_ids_to_tokens(ids, skip_special_tokens)
+            return self._convert_ids_to_tokens(ids, skip_special_tokens)
         return [self.convert_string_to_tokens(t, skip_special_tokens, as_targets) for t in text]
 
     def clean_tokens(
@@ -372,7 +378,7 @@ def clean_tokens(
         """
         if isinstance(tokens, list) and len(tokens) == 0:
             return []
-        elif isinstance(tokens[0], str):
+        elif isinstance(tokens[0], (bytes, str)):
             clean_tokens = []
             for tok in tokens:
                 clean_tok = self.convert_tokens_to_string(
