Skip to content

Commit

Permalink
Populate torch_dtype from model to pipeline (#28940)
Browse files Browse the repository at this point in the history
* Populate torch_dtype from model to pipeline

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

* use property

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

* lint

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

* Remove default handling

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

---------

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
  • Loading branch information
B-Step62 authored Mar 25, 2024
1 parent afe73ae commit 8e9a220
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,7 @@ def __init__(
raise ValueError(f"{device} unrecognized or not available.")
else:
self.device = device if device is not None else -1
self.torch_dtype = torch_dtype

self.binary_output = binary_output

# We shouldn't call `model.to()` for models loaded with accelerate
Expand Down Expand Up @@ -964,6 +964,13 @@ def predict(self, X):
"""
return self(X)

@property
def torch_dtype(self) -> Optional["torch.dtype"]:
"""
Torch dtype of the model (if it's Pytorch model), `None` otherwise.
"""
return getattr(self.model, "dtype", None)

@contextmanager
def device_placement(self):
"""
Expand Down
23 changes: 23 additions & 0 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,29 @@ def test_unbatch_attentions_hidden_states(self):
outputs = text_classifier(["This is great !"] * 20, batch_size=32)
self.assertEqual(len(outputs), 20)

@require_torch
def test_torch_dtype_property(self):
import torch

model_id = "hf-internal-testing/tiny-random-distilbert"

# If dtype is specified in the pipeline constructor, the property should return that type
pipe = pipeline(model=model_id, torch_dtype=torch.float16)
self.assertEqual(pipe.torch_dtype, torch.float16)

# If the underlying model changes dtype, the property should return the new type
pipe.model.to(torch.bfloat16)
self.assertEqual(pipe.torch_dtype, torch.bfloat16)

# If dtype is NOT specified in the pipeline constructor, the property should just return
# the dtype of the underlying model (default)
pipe = pipeline(model=model_id)
self.assertEqual(pipe.torch_dtype, torch.float32)

# If underlying model doesn't have dtype property, simply return None
pipe.model = None
self.assertIsNone(pipe.torch_dtype)


@is_pipeline_test
class PipelineScikitCompatTest(unittest.TestCase):
Expand Down

0 comments on commit 8e9a220

Please sign in to comment.