[pipeline] Fix str device issue #24396
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thanks for fixing!
Just some more comments on making it a bit more robust. They're suggestions, so it's up to you if you want to add them.
```python
@require_torch_gpu
def test_pipeline_cuda(self):
    pipe = pipeline("text-generation", device="cuda")
```
Could you also add an equivalent test here for "cuda:0" to make sure things still work even if the logic changes upstream?
src/transformers/pipelines/base.py
Outdated

```diff
@@ -793,6 +793,8 @@ def __init__(
         if isinstance(device, torch.device):
             self.device = device
         elif isinstance(device, str):
+            if device == "cuda":
+                device = f"cuda:{torch.cuda.current_device()}"
             self.device = torch.device(device)
         elif device < 0:
             self.device = torch.device("cpu")
```
I can't comment on the line below, but we could make this if/elif/else check a bit safer by doing

```python
elif isinstance(device, int):
    self.device = device
else:
    raise ValueError(f"Device type not supported. Got {device}")
```
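Putting the suggestion together with the earlier fix, the constructor's branch logic would look roughly like this. The sketch below is a standalone, torch-free approximation; the `normalize_device` helper, its string return values, and the `current_cuda_index` parameter (standing in for `torch.cuda.current_device()`) are illustrative, not transformers API:

```python
def normalize_device(device, current_cuda_index=0):
    """Sketch of the stricter if/elif/else device handling discussed above.

    Returns a normalized device string, or raises on unsupported types.
    `current_cuda_index` stands in for torch.cuda.current_device().
    """
    if isinstance(device, str):
        # Bare "cuda" gets an explicit index, mirroring the PR's fix.
        if device == "cuda":
            device = f"cuda:{current_cuda_index}"
        return device
    elif isinstance(device, int):
        # Negative ints mean CPU; non-negative ints pick a CUDA device.
        return "cpu" if device < 0 else f"cuda:{device}"
    else:
        raise ValueError(f"Device type not supported. Got {device}")
```

With an explicit `else` branch, an unexpected type (say, a float) fails loudly instead of slipping through to a later comparison.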
https://pytorch.org/docs/stable/generated/torch.cuda.set_device.html

`set_device` seems strongly discouraged, so I'm unsure about the `current_device()` usage.

`torch.device("cuda")` works though, what's the issue? Also works on main.
@Narsil what you shared works on main, but it should throw an error if you try to run an example with it (I attached a reproducible snippet above). Alternatively, this fails on main and this PR fixes it:

```shell
python -c 'from transformers import pipeline; pipe = pipeline(model="gpt2", device="cuda"); pipe("hello")'
```
Can we remove the `set_device` call instead, like this?

```diff
diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index 510c07cf5..b5975d081 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -901,10 +901,8 @@ class Pipeline(_ScikitCompat):
             with tf.device("/CPU:0" if self.device == -1 else f"/device:GPU:{self.device}"):
                 yield
         else:
-            if self.device.type == "cuda":
-                torch.cuda.set_device(self.device)
-
-            yield
+            with torch.cuda.device(self.device):
+                yield
```
The initial thing fails indeed, and seems to be linked to the fact that there are multiple […]. By removing it the issue is indeed fixed (but the test you added in the test suite isn't failing on main, and since this is what's supposed to catch the regression, this is what I tried :) )
I am happy to revert some of the changes I proposed and add yours, it looks much better. However I have a few questions:

```python
import torch

device = torch.device("cpu")
with torch.cuda.device(device):
    print(torch.randn(1))
```

Throws:

```
    raise ValueError('Expected a cuda device, but got: {}'.format(device))
ValueError: Expected a cuda device, but got: cpu
```

EDIT: just 2 - I am not sure but I think the […]
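One way around this `Expected a cuda device` error is to only enter the CUDA context when the device actually is a CUDA one, falling back to a no-op context otherwise. A hedged, torch-free sketch of that selection logic (`SimpleNamespace` stands in for `torch.device`, and `fake_cuda_ctx` mimics `torch.cuda.device` rejecting non-CUDA devices; both names are illustrative):

```python
from contextlib import contextmanager, nullcontext
from types import SimpleNamespace


@contextmanager
def fake_cuda_ctx(device):
    # Stand-in for torch.cuda.device: rejects non-CUDA devices like the real one.
    if device.type != "cuda":
        raise ValueError(f"Expected a cuda device, but got: {device.type}")
    yield


def device_placement(device):
    """Pick a context manager: CUDA context for CUDA devices, no-op for CPU."""
    if device.type == "cuda":
        return fake_cuda_ctx(device)
    return nullcontext()
```

Guarding on `device.type == "cuda"` is exactly what the merged version of the PR ends up doing in `device_placement`.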
I don't know, all those are very good questions for which I don't have the answer. I just know that now […]
Thanks!

```
Traceback (most recent call last):
  File "scratch.py", line 203, in <module>
    with torch.device(device):
AttributeError: __enter__
```

Therefore I just added some changes to ensure backward compatibility with older PT versions. WDYT?
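Feature detection with `hasattr` is one way to branch on this: on older torch versions `torch.device` has no `__enter__`, so probing for the context-manager protocol distinguishes them. A torch-free sketch with two dummy device classes (both class names and the `supports_device_context` helper are illustrative):

```python
class OldDevice:
    """Mimics torch.device before the context-manager support: no __enter__."""
    pass


class NewDevice:
    """Mimics torch.device with context-manager support: usable with `with`."""
    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False


def supports_device_context(device_cls):
    # Same probe as in the suggestion above: check for the context protocol.
    return hasattr(device_cls, "__enter__")
```

Using an object without `__enter__` in a `with` statement raises exactly the `AttributeError: __enter__` shown in the traceback above.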
src/transformers/pipelines/base.py
Outdated

```diff
@@ -793,11 +793,16 @@ def __init__(
         if isinstance(device, torch.device):
             self.device = device
         elif isinstance(device, str):
             if device == "cuda":
```
Maybe?

```diff
-            if device == "cuda":
+            if device == "cuda" and not hasattr(torch.device, "__enter__"):
```
Just double checking, if this condition is true, does the line below run OK?

```python
self.device = torch.device(device)
```
Yes, I think so, there shouldn't be an issue, i.e. `torch.device(f"cuda:{i}")` should work as long as `i < n_gpus`.
I think the updates look good 👍
src/transformers/pipelines/base.py
Outdated

```diff
-                torch.cuda.set_device(self.device)
-
-                yield
+            if hasattr(torch.device, "__enter__"):
```
nit: we typically check for compatibility with flags like `is_torch_greater_or_equal_than_2_0` in `pytorch_utils`. It's a bit cleaner than checking for a private attribute and it's clearer to the reader what's being checked.
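Such flags are typically computed once at import time by comparing the installed version against a threshold, then imported wherever a branch is needed. A stdlib-only sketch of how a flag like `is_torch_greater_or_equal_than_2_0` could be derived (the `parse_version` helper and the hard-coded `TORCH_VERSION` string are illustrative; transformers' real flag lives in `pytorch_utils` and uses proper version parsing):

```python
def parse_version(version_string):
    """Tiny version parser: keeps only the leading numeric components."""
    parts = []
    # Drop local-version suffixes like "+cu117" before splitting on dots.
    for piece in version_string.split("+")[0].split("."):
        digits = "".join(ch for ch in piece if ch.isdigit())
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


# Computed once at import time; would come from torch.__version__ in practice.
TORCH_VERSION = "2.1.0"
is_torch_greater_or_equal_than_2_0 = parse_version(TORCH_VERSION) >= (2, 0)
```

Call sites then read `if is_torch_greater_or_equal_than_2_0:` rather than probing a private attribute, which documents the intent directly.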
Perfect, will install different PT versions and try to track down from when the support has been added.
Ok, I can confirm it has been introduced in PT>=2.0.0.
Hi @Narsil
May I attempt a different thing? I think the fix is correct, but I'm wondering if simply relying on […]
Sure yes!
Cannot push:

```diff
diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index 626d33a3d..ee117e62a 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -50,7 +50,6 @@ if is_torch_available():
     from torch.utils.data import DataLoader, Dataset

     from ..models.auto.modeling_auto import AutoModel
-    from ..pytorch_utils import is_torch_greater_or_equal_than_2_0

     # Re-export for backward compatibility
     from .pt_utils import KeyDataset
@@ -794,16 +793,11 @@ class Pipeline(_ScikitCompat):
         if isinstance(device, torch.device):
             self.device = device
         elif isinstance(device, str):
-            if device == "cuda" and not is_torch_greater_or_equal_than_2_0:
-                # for backward compatiblity if using `set_device` and `cuda`
-                device = f"cuda:{torch.cuda.current_device()}"
             self.device = torch.device(device)
         elif device < 0:
             self.device = torch.device("cpu")
-        elif isinstance(device, int):
-            self.device = torch.device(f"cuda:{device}")
         else:
-            raise ValueError(f"Device type not supported. Got {device}")
+            self.device = torch.device(f"cuda:{device}")
         else:
             self.device = device if device is not None else -1
         self.torch_dtype = torch_dtype
@@ -908,13 +902,10 @@ class Pipeline(_ScikitCompat):
             with tf.device("/CPU:0" if self.device == -1 else f"/device:GPU:{self.device}"):
                 yield
         else:
-            if is_torch_greater_or_equal_than_2_0:
-                with torch.device(self.device):
+            if self.device.type == "cuda":
+                with torch.cuda.device(self.device):
                     yield
-            # for backward compatibility
             else:
-                if self.device.type == "cuda":
-                    torch.cuda.set_device(self.device)
                 yield
```
Hi @Narsil

https://pytorch.org/tutorials/recipes/recipes/changing_default_device.html

Maybe we should merge this PR for now to also unblock @thomasw21 & @NouamaneTazi. What do you think?
> I don't think we're blocked by this.

Not sure of the context of this sentence, but we're overriding the default to […]
It is supported from 1.9.0+, at least in the docs.
Great! Agreed with those changes.
Thanks for the fix!
What does this PR do?
Addresses: #24140 (comment)
Currently, passing `device="cuda"` is not supported when creating a pipeline. This is because `torch.cuda.set_device(self.device)` expects the device to have an explicit index. The fix is to create an indexed device when initializing a pipeline with a str device.

Handy reproducible snippet:
cc @amyeroberts @Narsil