diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index c64ee7b749e876..4d85a783c8c638 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -861,7 +861,7 @@ def __init__(
             raise ValueError(f"{device} unrecognized or not available.")
         else:
             self.device = device if device is not None else -1
-        self.torch_dtype = torch_dtype
+
         self.binary_output = binary_output
 
         # We shouldn't call `model.to()` for models loaded with accelerate
@@ -964,6 +964,13 @@ def predict(self, X):
         """
         return self(X)
 
+    @property
+    def torch_dtype(self) -> Optional["torch.dtype"]:
+        """
+        Torch dtype of the model (if it's a PyTorch model), `None` otherwise.
+        """
+        return getattr(self.model, "dtype", None)
+
     @contextmanager
     def device_placement(self):
         """
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 5e3e15f39c10ea..13b97aff3216b5 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -199,6 +199,29 @@ def test_unbatch_attentions_hidden_states(self):
         outputs = text_classifier(["This is great !"] * 20, batch_size=32)
         self.assertEqual(len(outputs), 20)
 
+    @require_torch
+    def test_torch_dtype_property(self):
+        import torch
+
+        model_id = "hf-internal-testing/tiny-random-distilbert"
+
+        # If dtype is specified in the pipeline constructor, the property should return that type
+        pipe = pipeline(model=model_id, torch_dtype=torch.float16)
+        self.assertEqual(pipe.torch_dtype, torch.float16)
+
+        # If the underlying model changes dtype, the property should return the new type
+        pipe.model.to(torch.bfloat16)
+        self.assertEqual(pipe.torch_dtype, torch.bfloat16)
+
+        # If dtype is NOT specified in the pipeline constructor, the property should just return
+        # the dtype of the underlying model (default)
+        pipe = pipeline(model=model_id)
+        self.assertEqual(pipe.torch_dtype, torch.float32)
+
+        # If the underlying model doesn't have a dtype, simply return None
+        pipe.model = None
+        self.assertIsNone(pipe.torch_dtype)
+
 
 @is_pipeline_test
 class PipelineScikitCompatTest(unittest.TestCase):
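
A minimal usage sketch of the behaviour this diff introduces, assuming the patch above is applied. The checkpoint id is the same tiny test model used in the new test; since the property simply delegates to `getattr(self.model, "dtype", None)`, casting the underlying model is reflected immediately.

```python
import torch
from transformers import pipeline

# Construct a pipeline with an explicit dtype; the property reports the model's actual dtype.
pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert", torch_dtype=torch.float16)
print(pipe.torch_dtype)  # torch.float16

# The value is no longer a stored attribute, so casting the model afterwards is picked up too.
pipe.model.to(torch.bfloat16)
print(pipe.torch_dtype)  # torch.bfloat16
```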