diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index dfa75768d8f811..80a48198f3fe6d 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -422,6 +422,7 @@ def pipeline( revision: Optional[str] = None, use_fast: bool = True, use_auth_token: Optional[Union[str, bool]] = None, + device: Optional[Union[int, str, "torch.device"]] = None, device_map=None, torch_dtype=None, trust_remote_code: Optional[bool] = None, @@ -508,6 +509,9 @@ def pipeline( use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + device (`int` or `str` or `torch.device`, *optional*): + Defines the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank like `1`) on which this + pipeline will be allocated. device_map (`str` or `Dict[str, Union[int, str, torch.device]`, *optional*): Sent directly as `model_kwargs` (just a simpler shortcut). When `accelerate` library is present, set `device_map="auto"` to compute the most optimized `device_map` automatically. [More @@ -802,4 +806,7 @@ def pipeline( if feature_extractor is not None: kwargs["feature_extractor"] = feature_extractor + if device is not None: + kwargs["device"] = device + return pipeline_class(model=model, framework=framework, task=task, **kwargs) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 6e2c28e5ddf84d..a0ce06ec5e33f1 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -704,7 +704,7 @@ def predict(self, X): Reference to the object in charge of parsing supplied pipeline parameters. device (`int`, *optional*, defaults to -1): Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on - the associated CUDA device id. You can pass native `torch.device` too. 
+ the associated CUDA device id. You can pass native `torch.device` or a `str` too. binary_output (`bool`, *optional*, defaults to `False`): Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text. """ @@ -747,7 +747,7 @@ def __init__( framework: Optional[str] = None, task: str = "", args_parser: ArgumentHandler = None, - device: int = -1, + device: Union[int, str, "torch.device"] = -1, binary_output: bool = False, **kwargs, ): @@ -760,14 +760,21 @@ def __init__( self.feature_extractor = feature_extractor self.modelcard = modelcard self.framework = framework - if is_torch_available() and isinstance(device, torch.device): - self.device = device + if is_torch_available() and self.framework == "pt": + if isinstance(device, torch.device): + self.device = device + elif isinstance(device, str): + self.device = torch.device(device) + elif device < 0: + self.device = torch.device("cpu") + else: + self.device = torch.device(f"cuda:{device}") else: - self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}") + self.device = device self.binary_output = binary_output # Special handling - if self.framework == "pt" and self.device.type == "cuda": + if self.framework == "pt" and self.device.type != "cpu": self.model = self.model.to(self.device) # Update config with task specific parameters