Skip to content

Commit

Permalink
Remove default handling
Browse files Browse the repository at this point in the history
Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
  • Loading branch information
B-Step62 committed Mar 11, 2024
1 parent c921678 commit e9ae6c9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 31 deletions.
21 changes: 5 additions & 16 deletions src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,8 +862,6 @@ def __init__(
else:
self.device = device if device is not None else -1

self._initial_torch_dtype = torch_dtype

self.binary_output = binary_output

# We shouldn't call `model.to()` for models loaded with accelerate
Expand Down Expand Up @@ -957,20 +955,11 @@ def predict(self, X):
return self(X)

@property
def torch_dtype(self):
if hasattr(self.model, "dtype"):
# NB: We extract dtype from the underlying model, but it is possible that the model has dtype
# but the pipeline subclass doesn't support it. In such case we should not return anything,
# but it is not straightforward to detect it in a generic way. Therefore, we assume that the
# pipeline support torch_dtype if (1) the extracted dtype is not default one (float32), or
# (2) the torch_dtype argument was set by the user when creating the pipeline.
if self._initial_torch_dtype is not None or self.model.dtype not in (
torch.float32,
"float32",
"torch.float32",
):
return self.model.dtype
return self._initial_torch_dtype
def torch_dtype(self) -> Optional["torch.dtype"]:
"""
Torch dtype of the model (if it's Pytorch model), `None` otherwise.
"""
return getattr(self.model, "dtype", None)

@contextmanager
def device_placement(self):
Expand Down
21 changes: 6 additions & 15 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,23 +213,14 @@ def test_torch_dtype_property(self):
pipe.model.to(torch.bfloat16)
self.assertEqual(pipe.torch_dtype, torch.bfloat16)

# Even if the model dtype is the default one, we can safely assume the pipeline supports torch_dtype
# as it is constructed with torch_dtype specified
pipe.model.to(torch.float32)
self.assertEqual(pipe.torch_dtype, torch.float32)

# If dtype is NOT specified in the pipeline constructor, the property should NOT return type
# as we don't know if the pipeline supports torch_dtype
# If dtype is NOT specified in the pipeline constructor, the property should just return
# the dtype of the underlying model (default)
pipe = pipeline(model=model_id)
self.assertEqual(pipe.torch_dtype, None)

# If the model changes to non default dtype, we assume the pipeline supports torch_dtype
pipe.model.to(torch.float16)
self.assertEqual(pipe.torch_dtype, torch.float16)
self.assertEqual(pipe.torch_dtype, torch.float32)

# If the model dtype is the default, we conservatively assume the pipeline doesn't support torch_dtype
pipe.model.to(torch.float32)
self.assertEqual(pipe.torch_dtype, None)
# If underlying model doesn't have dtype property, simply return None
pipe.model = None
self.assertIsNone(pipe.torch_dtype)


@is_pipeline_test
Expand Down

0 comments on commit e9ae6c9

Please sign in to comment.