diff --git a/src/pytorch_ie/pipeline.py b/src/pytorch_ie/pipeline.py
index 72caccc6..e80d2658 100644
--- a/src/pytorch_ie/pipeline.py
+++ b/src/pytorch_ie/pipeline.py
@@ -49,13 +49,15 @@ def __init__(
         self,
         model: PyTorchIEModel,
         taskmodule: TaskModule,
-        # args_parser: ArgumentHandler = None,
-        device: int = -1,
+        device: Union[int, str] = "cpu",
         binary_output: bool = False,
         **kwargs,
     ):
         self.taskmodule = taskmodule
-        self.device = torch.device("cpu" if device < 0 else f"cuda:{device}")
+        device_str = (
+            ("cpu" if device < 0 else f"cuda:{device}") if isinstance(device, int) else device
+        )
+        self.device = torch.device(device_str)
         self.binary_output = binary_output

         # Module.to() returns just self, but moved to the device. This is not correctly
@@ -324,13 +326,6 @@ def __call__(
         forward_params = {**self._forward_params, **forward_params}
         postprocess_params = {**self._postprocess_params, **postprocess_params}

-        self.call_count += 1
-        if self.call_count > 10 and self.device.type == "cuda":
-            warnings.warn(
-                "You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset",
-                UserWarning,
-            )
-
         single_document = False
         if isinstance(documents, Document):
             single_document = True
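
For reference, a minimal standalone sketch of the new device handling introduced in `Pipeline.__init__`: integer values keep the previous convention (negative → CPU, non-negative → `cuda:<index>`), while strings such as `"cpu"` or `"cuda:1"` are passed to `torch.device` unchanged. The helper name `resolve_device` is hypothetical and used only to illustrate the logic from the diff above.

```python
from typing import Union

import torch


def resolve_device(device: Union[int, str] = "cpu") -> torch.device:
    # Hypothetical helper mirroring the device resolution in Pipeline.__init__.
    device_str = (
        ("cpu" if device < 0 else f"cuda:{device}") if isinstance(device, int) else device
    )
    return torch.device(device_str)


assert resolve_device(-1) == torch.device("cpu")          # old int convention: negative -> CPU
assert resolve_device(0) == torch.device("cuda:0")        # non-negative int -> that CUDA index
assert resolve_device("cpu") == torch.device("cpu")       # new default is a plain string
assert resolve_device("cuda:1") == torch.device("cuda:1") # strings pass through unchanged
```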