Skip to content

Commit

Permalink
Fix previous incorrect conflict-merge resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
arxyzan committed Aug 13, 2023
1 parent 78fcc09 commit 7a2223f
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions hezar/preprocessors/tokenizers/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class Tokenizer(Preprocessor):
"mask_token",
"additional_special_tokens",
]
uncastable_keys = ["word_ids", "tokens", "offset_mapping"]

def __init__(self, config: TokenizerConfig, **kwargs):
super().__init__(config, **kwargs)
Expand Down Expand Up @@ -117,7 +118,6 @@ def pad_encoded_batch(
Returns:
"""

if isinstance(inputs, (list, tuple)) and isinstance(inputs[0], Mapping):
inputs = {key: [example[key] for example in inputs] for key in inputs[0].keys()}

Expand Down Expand Up @@ -155,6 +155,7 @@ def pad_encoded_batch(
else:
inputs_length = max_length

skip_keys += self.uncastable_keys # avoid possible errors
inputs = convert_batch_dict_dtype(inputs, dtype="list", skip_keys=skip_keys)

skip_keys = skip_keys or []
Expand All @@ -170,7 +171,7 @@ def pad_encoded_batch(
padded_batch.append(padded_x)
inputs[key] = padded_batch

inputs = convert_batch_dict_dtype(inputs, dtype=return_tensors)
inputs = convert_batch_dict_dtype(inputs, dtype=return_tensors, skip_keys=skip_keys)

return inputs

Expand Down Expand Up @@ -282,7 +283,7 @@ def __call__(
for key, value in sanitized_outputs.items()
}

outputs = convert_batch_dict_dtype(sanitized_outputs, dtype=return_tensors)
outputs = convert_batch_dict_dtype(sanitized_outputs, dtype=return_tensors, skip_keys=self.uncastable_keys)
if device and return_tensors == "pt":
outputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in outputs.items()}

Expand Down

0 comments on commit 7a2223f

Please sign in to comment.