Skip to content

Commit

Permalink
use property
Browse files Browse the repository at this point in the history
Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
  • Loading branch information
B-Step62 committed Mar 1, 2024
1 parent 00a61c1 commit c0e9e4a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 24 deletions.
25 changes: 16 additions & 9 deletions src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,16 +861,8 @@ def __init__(
raise ValueError(f"{device} unrecognized or not available.")
else:
self.device = device if device is not None else -1
self.torch_dtype = torch_dtype

if not self.torch_dtype and is_torch_available():
# If pipeline dtype is not specified, populate it from the model
# NB: We should only do this when the extracted dtype is not default one (float32),
# because not all models/pipelines support torch_dtype. Here we assume that if the
# model dtype is not float32 it is set by the user with torch_dtype param, so the
# model or pipeline should support it.
if hasattr(model, "dtype") and model.dtype not in (torch.float32, "float32", "torch.float32"):
self.torch_dtype = model.dtype
self._initial_torch_dtype = torch_dtype

self.binary_output = binary_output

Expand Down Expand Up @@ -964,6 +956,21 @@ def predict(self, X):
"""
return self(X)

@property
def torch_dtype(self):
    """
    Dtype reported for this pipeline, derived from the underlying model.

    Returns:
        ``self.model.dtype`` when the model exposes a ``dtype`` attribute and
        either (1) ``torch_dtype`` was passed when the pipeline was created
        (``self._initial_torch_dtype`` is not ``None``) or (2) the model dtype
        is not the default float32. Otherwise ``self._initial_torch_dtype``
        is returned, which is ``None`` when the user did not specify a dtype.
    """
    if not hasattr(self.model, "dtype"):
        return self._initial_torch_dtype

    model_dtype = self.model.dtype

    # NB: We extract dtype from the underlying model, but it is possible that the model has a dtype
    # while the pipeline subclass doesn't support it. In that case we should not return anything,
    # but it is not straightforward to detect this in a generic way. Therefore, we assume that the
    # pipeline supports torch_dtype if (1) the torch_dtype argument was set by the user when creating
    # the pipeline, or (2) the extracted dtype is not the default one (float32).
    if self._initial_torch_dtype is not None:
        return model_dtype

    # Detect the default dtype without referencing the `torch` module directly: `torch` is not
    # guaranteed to be in scope, and naming `torch.float32` here would raise NameError for models
    # that expose `dtype` when torch is unavailable. `str(torch.float32)` is "torch.float32", so
    # the string comparison covers the torch dtype object as well as the plain string spellings.
    is_default_dtype = model_dtype in ("float32", "torch.float32") or str(model_dtype) == "torch.float32"
    if not is_default_dtype:
        return model_dtype

    # Default dtype and nothing requested by the user: conservatively report no dtype (None here).
    return self._initial_torch_dtype

@contextmanager
def device_placement(self):
"""
Expand Down
38 changes: 23 additions & 15 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,27 +200,35 @@ def test_unbatch_attentions_hidden_states(self):
self.assertEqual(len(outputs), 20)

@require_torch
def test_torch_dtype_set_to_pipeline(self):
def test_torch_dtype_property(self):
import torch
model_id = "hf-internal-testing/tiny-random-distilbert"

# If dtype is specified in the pipeline constructor, it should be set to the pipeline and the model config
pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert", torch_dtype=torch.float16)
# If dtype is specified in the pipeline constructor, the property should return that type
pipe = pipeline(model=model_id, torch_dtype=torch.float16)
self.assertEqual(pipe.torch_dtype, torch.float16)
self.assertEqual(pipe.model.config.torch_dtype, torch.float16)

# If dtype is not specified, it should be set based on the model config
model = DistilBertForSequenceClassification.from_pretrained(
"hf-internal-testing/tiny-random-distilbert", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-distilbert")
pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer)
# If the underlying model changes dtype, the property should return the new type
pipe.model.to(torch.bfloat16)
self.assertEqual(pipe.torch_dtype, torch.bfloat16)

# If dtype is not specified and not available in the model config, it should be set based
# on the model's parameters dtype
model.config.torch_dtype = None
pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer)
self.assertEqual(pipe.torch_dtype, torch.bfloat16)
# Even if the model dtype is the default one, we can safely assume the pipeline supports torch_dtype
# as it is constructed with torch_dtype specified
pipe.model.to(torch.float32)
self.assertEqual(pipe.torch_dtype, torch.float32)

# If dtype is NOT specified in the pipeline constructor, the property should NOT return type
# as we don't know if the pipeline supports torch_dtype
pipe = pipeline(model=model_id)
self.assertEqual(pipe.torch_dtype, None)

# If the model changes to non default dtype, we assume the pipeline supports torch_dtype
pipe.model.to(torch.float16)
self.assertEqual(pipe.torch_dtype, torch.float16)

# If the model dtype is the default, we conservatively assume the pipeline doesn't support torch_dtype
pipe.model.to(torch.float32)
self.assertEqual(pipe.torch_dtype, None)


@is_pipeline_test
Expand Down

0 comments on commit c0e9e4a

Please sign in to comment.