Skip to content

Commit 5e2e77f

Browse files
cyyevergante
andauthored
Improve torch_dtype checks (#40808)
* Improve torch_dtype checks Signed-off-by: Yuanyuan Chen <cyyever@outlook.com> * Apply suggestions from code review --------- Signed-off-by: Yuanyuan Chen <cyyever@outlook.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
1 parent c81f426 commit 5e2e77f

File tree

3 files changed

+17
-5
lines changed

3 files changed

+17
-5
lines changed

src/transformers/commands/chat.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,14 @@ class ChatArguments:
289289
def __post_init__(self):
290290
"""Only used for BC `torch_dtype` argument."""
291291
# In this case only the BC torch_dtype was given
292-
if self.torch_dtype is not None and self.dtype == "auto":
293-
self.dtype = self.torch_dtype
292+
if self.torch_dtype is not None:
293+
if self.dtype is None:
294+
self.dtype = self.torch_dtype
295+
elif self.torch_dtype != self.dtype:
296+
raise ValueError(
297+
f"`torch_dtype` {self.torch_dtype} and `dtype` {self.dtype} have different values. `torch_dtype` is deprecated and "
298+
"will be removed in 4.59.0, please set `dtype` instead."
299+
)
294300

295301

296302
def chat_command_factory(args: Namespace):

src/transformers/commands/serving.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,8 +457,14 @@ class ServeArguments:
457457
def __post_init__(self):
458458
"""Only used for BC `torch_dtype` argument."""
459459
# In this case only the BC torch_dtype was given
460-
if self.torch_dtype is not None and self.dtype == "auto":
461-
self.dtype = self.torch_dtype
460+
if self.torch_dtype is not None:
461+
if self.dtype is None:
462+
self.dtype = self.torch_dtype
463+
elif self.torch_dtype != self.dtype:
464+
raise ValueError(
465+
f"`torch_dtype` {self.torch_dtype} and `dtype` {self.dtype} have different values. `torch_dtype` is deprecated and "
466+
"will be removed in 4.59.0, please set `dtype` instead."
467+
)
462468

463469

464470
class ServeCommand(BaseTransformersCLICommand):

src/transformers/pipelines/keypoint_matching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def __call__(
147147
def preprocess(self, images, timeout=None):
148148
images = [load_image(image, timeout=timeout) for image in images]
149149
model_inputs = self.image_processor(images=images, return_tensors=self.framework)
150-
model_inputs = model_inputs.to(self.torch_dtype)
150+
model_inputs = model_inputs.to(self.dtype)
151151
target_sizes = [image.size for image in images]
152152
preprocess_outputs = {"model_inputs": model_inputs, "target_sizes": target_sizes}
153153
return preprocess_outputs

0 commit comments

Comments
 (0)