From 7aa9407f7e61d320f14d161ef3d26337367f50f5 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 31 Aug 2020 16:03:24 -0700 Subject: [PATCH] Add torch jit ignore decorator to Lightning properties [PR](https://github.com/pytorch/pytorch/pull/42390) adds support to Torchscript for module properties, and as a result, these module properties in Lghtning will no longer compile because they use unsupported features in their properties (that are now being compiled) --- pytorch_lightning/core/lightning.py | 4 ++++ pytorch_lightning/utilities/device_dtype_mixin.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 66d067a5146b6..70cd5d847eadd 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -76,6 +76,7 @@ def __init__(self, *args, **kwargs): self._example_input_array = None self._datamodule = None + @torch.jit.ignore @property def example_input_array(self) -> Any: return self._example_input_array @@ -84,6 +85,7 @@ def example_input_array(self) -> Any: def example_input_array(self, example: Any) -> None: self._example_input_array = example + @torch.jit.ignore @property def datamodule(self) -> Any: return self._datamodule @@ -92,6 +94,7 @@ def datamodule(self) -> Any: def datamodule(self, datamodule: Any) -> None: self._datamodule = datamodule + @torch.jit.ignore @property def on_gpu(self): """ @@ -1720,6 +1723,7 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg torch.onnx.export(self, input_data, file_path, **kwargs) + @torch.jit.ignore @property def hparams(self) -> Union[AttributeDict, str]: if not hasattr(self, '_hparams'): diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py index bea3df3e5ced9..1767e06f968f7 100644 --- a/pytorch_lightning/utilities/device_dtype_mixin.py +++ b/pytorch_lightning/utilities/device_dtype_mixin.py @@ -11,6 +11,7 @@ def __init__(self): self._dtype = torch.get_default_dtype() self._device = torch.device('cpu') + @torch.jit.ignore @property def dtype(self) -> Union[str, torch.dtype]: return self._dtype @@ -20,6 +21,7 @@ def dtype(self, new_dtype: Union[str, torch.dtype]): # necessary to avoid infinite recursion raise RuntimeError('Cannot set the dtype explicitly. Please use module.to(new_dtype).') + @torch.jit.ignore @property def device(self) -> Union[str, torch.device]: return self._device