diff --git a/hezar/preprocessors/tokenizers/tokenizer.py b/hezar/preprocessors/tokenizers/tokenizer.py
index 3723fe20..6d5db78d 100644
--- a/hezar/preprocessors/tokenizers/tokenizer.py
+++ b/hezar/preprocessors/tokenizers/tokenizer.py
@@ -60,6 +60,7 @@ class Tokenizer(Preprocessor):
         "mask_token",
         "additional_special_tokens",
     ]
+    uncastable_keys = ["word_ids", "tokens", "offset_mapping"]

     def __init__(self, config: TokenizerConfig, **kwargs):
         super().__init__(config, **kwargs)
@@ -117,7 +118,6 @@ def pad_encoded_batch(
         Returns:

         """
-
         if isinstance(inputs, (list, tuple)) and isinstance(inputs[0], Mapping):
             inputs = {key: [example[key] for example in inputs] for key in inputs[0].keys()}

@@ -155,6 +155,7 @@ def pad_encoded_batch(
         else:
             inputs_length = max_length

+        skip_keys += self.uncastable_keys  # avoid possible errors
         inputs = convert_batch_dict_dtype(inputs, dtype="list", skip_keys=skip_keys)

         skip_keys = skip_keys or []
@@ -170,7 +171,7 @@ def pad_encoded_batch(
                 padded_batch.append(padded_x)
             inputs[key] = padded_batch

-        inputs = convert_batch_dict_dtype(inputs, dtype=return_tensors)
+        inputs = convert_batch_dict_dtype(inputs, dtype=return_tensors, skip_keys=skip_keys)

         return inputs

@@ -282,7 +283,7 @@ def __call__(
                 for key, value in sanitized_outputs.items()
             }

-        outputs = convert_batch_dict_dtype(sanitized_outputs, dtype=return_tensors)
+        outputs = convert_batch_dict_dtype(sanitized_outputs, dtype=return_tensors, skip_keys=self.uncastable_keys)

        if device and return_tensors == "pt":
            outputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in outputs.items()}
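
The new `uncastable_keys` attribute exists because fields such as `word_ids` (which contains `None` for special tokens), `tokens` (strings), and `offset_mapping` (character-span tuples) cannot be cast to integer tensors, so they must be passed through `skip_keys` whenever the batch dict is converted. A minimal sketch of the idea is below; it uses a simplified stand-in for hezar's `convert_batch_dict_dtype`, and the helper names (`cast_batch_to_tensors`, `UNCASTABLE_KEYS`) are hypothetical, not part of the library.

```python
# Minimal sketch (not the hezar implementation) of why "word_ids", "tokens",
# and "offset_mapping" must be skipped when casting an encoded batch to tensors.
import torch

UNCASTABLE_KEYS = ["word_ids", "tokens", "offset_mapping"]  # hypothetical constant


def cast_batch_to_tensors(batch: dict, skip_keys=None) -> dict:
    """Cast list-valued fields to torch tensors, leaving skipped keys untouched."""
    skip_keys = skip_keys or []
    return {
        key: value if key in skip_keys else torch.tensor(value)
        for key, value in batch.items()
    }


batch = {
    "token_ids": [[2, 15, 37, 3], [2, 42, 3, 0]],
    "attention_mask": [[1, 1, 1, 1], [1, 1, 1, 0]],
    "tokens": [["[CLS]", "hello", "world", "[SEP]"], ["[CLS]", "hi", "[SEP]", "[PAD]"]],
    "word_ids": [[None, 0, 1, None], [None, 0, None, None]],  # None marks special tokens
}

casted = cast_batch_to_tensors(batch, skip_keys=UNCASTABLE_KEYS)
assert isinstance(casted["token_ids"], torch.Tensor)
assert isinstance(casted["tokens"], list)  # left as-is; torch.tensor() would fail on strings/None
```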