
[pipeline] Fix str device issue #24396

Merged
merged 10 commits into huggingface:main from add-str-support
Jun 26, 2023

Conversation

younesbelkada (Contributor):

What does this PR do?

Addresses: #24140 (comment)

Currently, passing device="cuda" is not supported when creating a pipeline, because torch.cuda.set_device(self.device) expects the device to have an explicit index. The fix is to create an indexed device when initializing a pipeline with a str device.

Handy reproducible snippet:

from transformers import pipeline

# this works
pipe = pipeline("text-generation", device=0)
pipe("Hello")

# this works
pipe = pipeline("text-generation", device="cuda:0")
pipe("Hello")

# this fails
pipe = pipeline("text-generation", device="cuda")
pipe("Hello")

cc @amyeroberts @Narsil

HuggingFaceDocBuilderDev commented Jun 21, 2023:

The documentation is not available anymore as the PR was closed or merged.

amyeroberts (Collaborator) left a comment:

Thanks for fixing!

Just some more comments on making it a bit more robust. They're suggestions, so it's up to you if you want to add them.

@require_torch_gpu
def test_pipeline_cuda(self):
    pipe = pipeline("text-generation", device="cuda")

amyeroberts (Collaborator) commented:

Could you also add an equivalent test here for "cuda:0" to make sure things still work even if the logic changes upstream?

@@ -793,6 +793,8 @@ def __init__(
             if isinstance(device, torch.device):
                 self.device = device
             elif isinstance(device, str):
+                if device == "cuda":
+                    device = f"cuda:{torch.cuda.current_device()}"
                 self.device = torch.device(device)
             elif device < 0:
                 self.device = torch.device("cpu")
amyeroberts (Collaborator) commented:

I can't comment on the line below, but we could make this if/elif/else check a bit safer by doing:

elif isinstance(device, int):
    self.device = device
else:
    raise ValueError(f"Device type not supported. Got {device}")
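
Putting the diff and this suggestion together, the device-normalization branch might read roughly like this (a sketch, not the PR's final code; the _resolve_device name is hypothetical):

import torch

def _resolve_device(device):
    # sketch combining the indexed-"cuda" fix with the explicit type checks
    if isinstance(device, torch.device):
        return device
    elif isinstance(device, str):
        if device == "cuda":
            device = f"cuda:{torch.cuda.current_device()}"
        return torch.device(device)
    elif isinstance(device, int):
        if device < 0:
            return torch.device("cpu")
        return torch.device(f"cuda:{device}")
    else:
        raise ValueError(f"Device type not supported. Got {device}")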

Narsil (Contributor) commented Jun 21, 2023:

https://pytorch.org/docs/stable/generated/torch.cuda.set_device.html

set_device seems strongly discouraged, so I'm unsure about the current_device() usage.

torch.device("cuda") works, though. What's the issue?

Narsil (Contributor) commented Jun 21, 2023:

Also,

python -c 'from transformers import pipeline; pipe = pipeline(model="gpt2", device="cuda")'

works on main, so I'm not sure what the issue is.

younesbelkada (Contributor, Author) commented Jun 21, 2023:

@Narsil what you shared works on main, but it throws an error as soon as you actually run the pipeline (I attached a reproducible snippet above).

Put differently, this fails on main and this PR fixes it:

python -c 'from transformers import pipeline; pipe = pipeline(model="gpt2", device="cuda"); pipe("hello")'

Narsil (Contributor) commented Jun 21, 2023:

Can we remove the set_device call instead, then? That seems better:

diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index 510c07cf5..b5975d081 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -901,10 +901,8 @@ class Pipeline(_ScikitCompat):
             with tf.device("/CPU:0" if self.device == -1 else f"/device:GPU:{self.device}"):
                 yield
         else:
-            if self.device.type == "cuda":
-                torch.cuda.set_device(self.device)
-
-            yield
+            with torch.cuda.device(self.device):
+                yield
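
For context on what the context manager buys over a bare set_device call: it restores the previously current device on exit. A minimal sketch (assumes a CUDA-capable machine):

import torch

# torch.cuda.device temporarily switches the current CUDA device and
# restores the previous one when the block exits
with torch.cuda.device(torch.device("cuda:0")):
    x = torch.randn(1, device="cuda")  # allocated on cuda:0
print(torch.cuda.current_device())  # back to whatever it was before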

Narsil (Contributor) commented Jun 21, 2023:

The initial thing does fail, and it seems to be linked to the fact that there are multiple set_device calls happening, causing issues.

Removing it does fix the issue (but the test you added to the test suite isn't failing on main, and since that test is what is supposed to catch the regression, it's what I tried :) ).

younesbelkada (Contributor, Author) commented Jun 21, 2023:

I am happy to revert some of the changes I proposed and add yours, it looks much better. However, I have a few questions:

1. Is it OK to call that context manager if self.device is CPU? I think we need a check on top of that to make sure we're not on CPU, similar to what we had before (see the sketch after this comment).

import torch
device = torch.device("cpu")

with torch.cuda.device(device):
    print(torch.randn(1))

Throws:

    raise ValueError('Expected a cuda device, but got: {}'.format(device))
ValueError: Expected a cuda device, but got: cpu

EDIT: just with torch.device(self.device) seems to work.

2. I am not sure, but I think the with-device context manager is only available since PT 2.0, no?
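
Regarding question 1, a minimal sketch of the kind of guard that avoids entering torch.cuda.device on CPU (the maybe_cuda_context helper is hypothetical, not code from this PR):

import contextlib

import torch

def maybe_cuda_context(device: torch.device):
    # hypothetical guard: torch.cuda.device raises ValueError for non-CUDA
    # devices (as shown above), so only enter it for CUDA devices
    if device.type == "cuda":
        return torch.cuda.device(device)
    return contextlib.nullcontext()

with maybe_cuda_context(torch.device("cpu")):
    print(torch.randn(1))  # runs on CPU without raising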

Narsil (Contributor) commented Jun 21, 2023:

> 2. I am not sure, but I think the with-device context manager is only available since PT 2.0, no?

I don't know; those are very good questions that I don't have the answer to. I just know that set_device is now strongly discouraged, so it's probably the source of our issues.

younesbelkada (Contributor, Author) commented:

Thanks! I can confirm the context manager doesn't work for PT==1.9, which we should still support:

Traceback (most recent call last):
  File "scratch.py", line 203, in <module>
    with torch.device(device):
AttributeError: __enter__

Therefore I just added some changes to ensure backward compatibility with older PT versions. WDYT?

@@ -793,11 +793,16 @@ def __init__(
             if isinstance(device, torch.device):
                 self.device = device
             elif isinstance(device, str):
                 if device == "cuda":
younesbelkada (Contributor, Author) commented:

Maybe?

Suggested change:
-                if device == "cuda":
+                if device == "cuda" and not hasattr(torch.device, "__enter__"):

amyeroberts (Collaborator) commented:

Just double checking, if this condition is true, does the line below run OK?

self.device = torch.device(device)

younesbelkada (Contributor, Author) commented Jun 21, 2023:

Yes, I think so; there shouldn't be an issue, i.e. torch.device(f"cuda:{i}") should work as long as i < n_gpus.

amyeroberts (Collaborator) left a comment:

I think the updates look good 👍

-                torch.cuda.set_device(self.device)
-
-            yield
+            if hasattr(torch.device, "__enter__"):
amyeroberts (Collaborator) commented:

nit: we typically check for compatibility with flags like is_torch_greater_or_equal_than_2_0 in pytorch_utils. It's a bit cleaner than checking for a dunder attribute, and it's clearer to the reader what's being checked.
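
For illustration, such a flag is typically defined along these lines (a sketch; the actual definition in transformers.pytorch_utils may differ):

from packaging import version

import torch

# version gate in the style of is_torch_greater_or_equal_than_2_0;
# base_version drops local/dev suffixes such as "+cu118"
parsed_torch_version = version.parse(version.parse(torch.__version__).base_version)
is_torch_greater_or_equal_than_2_0 = parsed_torch_version >= version.parse("2.0")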

younesbelkada (Contributor, Author) commented:

Perfect, I will install different PT versions and try to track down when the support was added.

younesbelkada (Contributor, Author) commented:

OK, I can confirm it was introduced in PT >= 2.0.0.


younesbelkada (Contributor, Author) commented:

Hi @Narsil
Let me know if the changes all look good to you; happy to address any additional comments you have.

Narsil (Contributor) commented Jun 21, 2023:

May I attempt a different thing?

I think the fix is correct, but I'm wondering if simply relying on the torch.cuda.device context manager could remove the need for the compat layer.

younesbelkada (Contributor, Author) commented:

Sure, yes!

Narsil (Contributor) commented Jun 21, 2023:

I cannot push to your branch, so here is the diff:

diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index 626d33a3d..ee117e62a 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -50,7 +50,6 @@ if is_torch_available():
     from torch.utils.data import DataLoader, Dataset

     from ..models.auto.modeling_auto import AutoModel
-    from ..pytorch_utils import is_torch_greater_or_equal_than_2_0

     # Re-export for backward compatibility
     from .pt_utils import KeyDataset
@@ -794,16 +793,11 @@ class Pipeline(_ScikitCompat):
             if isinstance(device, torch.device):
                 self.device = device
             elif isinstance(device, str):
-                if device == "cuda" and not is_torch_greater_or_equal_than_2_0:
-                    # for backward compatiblity if using `set_device` and `cuda`
-                    device = f"cuda:{torch.cuda.current_device()}"
                 self.device = torch.device(device)
             elif device < 0:
                 self.device = torch.device("cpu")
-            elif isinstance(device, int):
-                self.device = torch.device(f"cuda:{device}")
             else:
-                raise ValueError(f"Device type not supported. Got {device}")
+                self.device = torch.device(f"cuda:{device}")
         else:
             self.device = device if device is not None else -1
         self.torch_dtype = torch_dtype
@@ -908,13 +902,10 @@ class Pipeline(_ScikitCompat):
             with tf.device("/CPU:0" if self.device == -1 else f"/device:GPU:{self.device}"):
                 yield
         else:
-            if is_torch_greater_or_equal_than_2_0:
-                with torch.device(self.device):
+            if self.device.type == "cuda":
+                with torch.cuda.device(self.device):
                     yield
-            # for backward compatibility
             else:
-                if self.device.type == "cuda":
-                    torch.cuda.set_device(self.device)
                 yield
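
As a standalone illustration, the PyTorch branch of device_placement would then behave roughly like this (a sketch, not the verbatim method; the TensorFlow branch is omitted):

import contextlib

import torch

@contextlib.contextmanager
def device_placement(device: torch.device):
    # only CUDA devices need the context manager; cpu is already the default
    if device.type == "cuda":
        with torch.cuda.device(device):
            yield
    else:
        yield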

Narsil (Contributor) commented Jun 21, 2023:

torch.cuda.device is defined for torch==1.9, so it should work.

And torch.device("cpu")... well, it's the default; there's no need to context-manage it.

younesbelkada (Contributor, Author) commented Jun 26, 2023:

Hi @Narsil
I am not sure if with torch.cuda.device(self.device): is supported for torch<2.0:

https://pytorch.org/tutorials/recipes/recipes/changing_default_device.html

Maybe we should merge this PR for now, also to unblock @thomasw21 & @NouamaneTazi. What do you think?

thomasw21 (Contributor) commented:

I don't think we're blocked by this.

> And torch.device("cpu")... well, it's the default; there's no need to context-manage it.

Not sure of the context of this sentence, but we're overriding the default to cuda, so having a context manager to switch back to cpu makes sense to me.
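
A minimal illustration of that point, assuming PyTorch >= 2.0 (where the default device can be overridden globally):

import torch

torch.set_default_device("cuda")  # override the default device globally
x = torch.randn(2)                # allocated on cuda

with torch.device("cpu"):         # temporarily switch back to cpu
    y = torch.randn(2)            # allocated on cpu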

Narsil (Contributor) commented Jun 26, 2023:

https://pytorch.org/docs/1.9.0/generated/torch.cuda.device.html?highlight=torch%20cuda%20device#torch.cuda.device

It is supported from 1.9.0 onwards, at least according to the docs.

younesbelkada (Contributor, Author) commented:

Great! Agreed with those changes.

@younesbelkada younesbelkada requested a review from sgugger June 26, 2023 10:27
sgugger (Collaborator) left a comment:

Thanks for the fix!

@younesbelkada younesbelkada merged commit 914289a into huggingface:main Jun 26, 2023
@younesbelkada younesbelkada deleted the add-str-support branch June 26, 2023 11:58