BatchFeature.to() supports non-tensor keys (huggingface#33918)
* Fix issue in oneformer preprocessing

* [run slow] oneformer

* [run_slow] oneformer

* Make the same fixes in DQA and object detection pipelines

* Fix BatchFeature.to() instead

* Revert pipeline-specific changes

* Add the same check in Pixtral's methods

* Add the same check in BatchEncoding

* make sure torch is imported
Rocketknight1 authored Oct 8, 2024
1 parent 3b44d2f commit fb360a6
Showing 4 changed files with 8 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/transformers/feature_extraction_utils.py
```diff
@@ -237,10 +237,10 @@ def to(self, *args, **kwargs) -> "BatchFeature":
         # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
         for k, v in self.items():
             # check if v is a floating point
-            if torch.is_floating_point(v):
+            if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
                 # cast and send to device
                 new_data[k] = v.to(*args, **kwargs)
-            elif device is not None:
+            elif isinstance(v, torch.Tensor) and device is not None:
                 new_data[k] = v.to(device=device)
             else:
                 new_data[k] = v
```
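With the `isinstance(v, torch.Tensor)` guard, `BatchFeature.to()` now skips the floating-point check for values that are not tensors and passes them through unchanged. A minimal sketch of the resulting behavior (the `word_labels` key is an illustrative stand-in for the kind of non-tensor output behind the original OneFormer issue):

```python
import torch

from transformers.feature_extraction_utils import BatchFeature

features = BatchFeature(
    data={
        "pixel_values": torch.rand(1, 3, 224, 224),  # floating-point tensor: cast and moved
        "word_labels": [["background", "person"]],   # plain list: previously crashed, now passed through
    }
)

# Before this commit, torch.is_floating_point() raised a TypeError on the list value;
# now only torch.Tensor values are cast/moved and everything else is kept as-is.
moved = features.to("cpu")
print(type(moved["word_labels"]))  # <class 'list'>
```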
4 changes: 2 additions & 2 deletions src/transformers/models/pixtral/image_processing_pixtral.py
```diff
@@ -86,10 +86,10 @@ def to(self, *args, **kwargs) -> "BatchMixFeature":
                 new_data[k] = [
                     element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
                 ]
-            elif torch.is_floating_point(v):
+            elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
                 # cast and send to device
                 new_data[k] = v.to(*args, **kwargs)
-            elif device is not None:
+            elif isinstance(v, torch.Tensor) and device is not None:
                 new_data[k] = v.to(device=device)
             else:
                 new_data[k] = v
```
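Pixtral's `BatchMixFeature.to()` gets the same tensor check, applied after the existing branch that walks nested lists of tensors. The per-value dispatch order, written out as a standalone sketch (`move_value` and the local `is_torch_tensor` stand-in are made up for illustration; the real method also accepts the device positionally):

```python
from typing import Any

import torch


def is_torch_tensor(x: Any) -> bool:
    # Local stand-in for the is_torch_tensor helper referenced in the diff above.
    return isinstance(x, torch.Tensor)


def move_value(v: Any, *args: Any, **kwargs: Any) -> Any:
    # Sketch of the per-value branch order in BatchMixFeature.to() after this change.
    device = kwargs.get("device")  # simplified: only reads the device from kwargs
    if isinstance(v, list):
        # Nested lists of tensors (e.g. variable-sized Pixtral image patches): move each tensor element.
        return [element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)]
    if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
        # Floating-point tensors are cast and/or moved with whatever was passed to .to().
        return v.to(*args, **kwargs)
    if isinstance(v, torch.Tensor) and device is not None:
        # Other tensors (e.g. LongTensor token ids) are moved but never cast.
        return v.to(device=device)
    # Non-tensor values (strings, ints, None, ...) now pass through unchanged instead of raising.
    return v
```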
4 changes: 2 additions & 2 deletions src/transformers/models/pixtral/processing_pixtral.py
```diff
@@ -90,10 +90,10 @@ def to(self, *args, **kwargs) -> "BatchMixFeature":
                 new_data[k] = [
                     element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
                 ]
-            elif torch.is_floating_point(v):
+            elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
                 # cast and send to device
                 new_data[k] = v.to(*args, **kwargs)
-            elif device is not None:
+            elif isinstance(v, torch.Tensor) and device is not None:
                 new_data[k] = v.to(device=device)
             else:
                 new_data[k] = v
```
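Because the floating-point check still comes first, a dtype-only call such as `.to(torch.float16)` casts only floating-point tensors: integer tensors keep their dtype and non-tensor values are untouched. A hedged example, assuming `BatchMixFeature` can be imported from the module edited above and that its argument parsing mirrors `BatchFeature.to()` (the `num_images` key is illustrative):

```python
import torch

from transformers.models.pixtral.processing_pixtral import BatchMixFeature

batch = BatchMixFeature(
    data={
        "input_ids": torch.tensor([[1, 2, 3]]),     # integer tensor: never cast by a dtype-only .to()
        "pixel_values": [[torch.rand(3, 16, 16)]],  # Pixtral-style nested list of image tensors
        "num_images": 1,                            # plain Python value: passed through unchanged
    }
)

half = batch.to(torch.float16)
print(half["pixel_values"][0].dtype)  # torch.float16 (the list branch casts each tensor element)
print(half["input_ids"].dtype)        # torch.int64   (no device given, so the tensor is left alone)
print(half["num_images"])             # 1
```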
3 changes: 2 additions & 1 deletion src/transformers/tokenization_utils_base.py
```diff
@@ -809,12 +809,13 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
             [`BatchEncoding`]: The same instance after modification.
         """
         requires_backends(self, ["torch"])
+        import torch

         # This check catches things like APEX blindly calling "to" on all inputs to a module
         # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
         # into a HalfTensor
         if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
-            self.data = {k: v.to(device=device) for k, v in self.data.items() if v is not None}
+            self.data = {k: v.to(device=device) for k, v in self.data.items() if isinstance(v, torch.Tensor)}
         else:
             logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
         return self
```
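In `BatchEncoding.to()`, the comprehension now filters on `isinstance(v, torch.Tensor)` instead of `v is not None`, so a non-tensor value no longer breaks the device move; note that in this version only the tensor entries are kept in `.data`. A small illustration (the `word_boxes` entry is an assumed example of the kind of extra value a pipeline might attach):

```python
import torch

from transformers import BatchEncoding

enc = BatchEncoding(
    {
        "input_ids": torch.tensor([[101, 2023, 102]]),
        "attention_mask": torch.tensor([[1, 1, 1]]),
        "word_boxes": [[0, 0, 10, 10]],  # assumed non-tensor entry for illustration
    }
)

# Previously this raised AttributeError ("'list' object has no attribute 'to'");
# now only torch.Tensor values are moved, and non-tensor entries are filtered out of enc.data.
enc = enc.to("cpu")
print(list(enc.keys()))  # ['input_ids', 'attention_mask']
```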
