From b17eaa468c3153c915973806df1871eac512e074 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 23 Nov 2021 10:33:20 +0100 Subject: [PATCH] Fix torchscript prediction with GPU --- .../_model_adapters/_torchscript_model_adapter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_torchscript_model_adapter.py b/bioimageio/core/prediction_pipeline/_model_adapters/_torchscript_model_adapter.py index 21a844f0..1833d239 100644 --- a/bioimageio/core/prediction_pipeline/_model_adapters/_torchscript_model_adapter.py +++ b/bioimageio/core/prediction_pipeline/_model_adapters/_torchscript_model_adapter.py @@ -12,17 +12,17 @@ class TorchscriptModelAdapter(ModelAdapter): def __init__(self, *, bioimageio_model: nodes.Model, devices: Optional[List[str]] = None): weight_path = str(bioimageio_model.weights["pytorch_script"].source.resolve()) if devices is None: - devices = ["cuda" if torch.cuda.is_available() else "cpu"] + self.devices = ["cuda" if torch.cuda.is_available() else "cpu"] else: - devices = [torch.device(d) for d in devices] + self.devices = [torch.device(d) for d in devices] self._model = torch.jit.load(weight_path) - self._model.to(devices[0]) + self._model.to(self.devices[0]) self._internal_output_axes = [tuple(out.axes) for out in bioimageio_model.outputs] def forward(self, *batch: xr.DataArray) -> List[xr.DataArray]: with torch.no_grad(): - torch_tensor = [torch.from_numpy(b.data) for b in batch] + torch_tensor = [torch.from_numpy(b.data).to(self.devices[0]) for b in batch] result = self._model.forward(*torch_tensor) if not isinstance(result, (tuple, list)): result = [result]