From 6cdbdfadda636ed5fab760822e31a62a0266cc89 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 1 May 2024 01:16:30 +0800 Subject: [PATCH 1/3] fix: try to import torch_geometric only when init Raindrop; --- pypots/nn/modules/raindrop/backbone.py | 38 ++++++++++---------------- pypots/nn/modules/raindrop/layers.py | 11 ++------ 2 files changed, 17 insertions(+), 32 deletions(-) diff --git a/pypots/nn/modules/raindrop/backbone.py b/pypots/nn/modules/raindrop/backbone.py index 50af4fe0..7a7e30af 100644 --- a/pypots/nn/modules/raindrop/backbone.py +++ b/pypots/nn/modules/raindrop/backbone.py @@ -13,27 +13,6 @@ import torch.nn as nn import torch.nn.functional as F from torch.nn import TransformerEncoderLayer, TransformerEncoder -from torch.nn.parameter import Parameter - -from ....utils.logging import logger - -try: - from .layers import PositionalEncoding, ObservationPropagation - from torch_geometric.nn.inits import glorot -except ImportError as e: - logger.error( - f"❌ {e}\n" - "Note torch_geometric is missing, please install it with " - "'pip install torch_geometric torch_scatter torch_sparse' or " - "'conda install -c pyg pyg pytorch-scatter pytorch-sparse'" - ) -except NameError as e: - logger.error( - f"❌ {e}\n" - "Note torch_geometric is missing, please install it with " - "'pip install torch_geometric torch_scatter torch_sparse' or " - "'conda install -c pyg pyg pytorch-scatter pytorch-sparse'" - ) class BackboneRaindrop(nn.Module): @@ -53,7 +32,18 @@ def __init__( sensor_wise_mask=False, static=False, ): + + try: + from .layers import PositionalEncoding, ObservationPropagation + except (ImportError, NameError) as e: + raise ImportError( + f"❌ {e}. Note that torch_geometric is missing, please install it with " + "'pip install torch_geometric torch_scatter torch_sparse' or " + "'conda install -c pyg pyg pytorch-scatter pytorch-sparse'" + ) + super().__init__() + self.n_layers = n_layers self.n_features = n_features self.d_model = d_model @@ -91,7 +81,7 @@ def __init__( ) self.transformer_encoder = TransformerEncoder(encoder_layers, n_layers) - self.R_u = Parameter(torch.Tensor(1, self.n_features * self.d_ob)) + self.R_u = nn.Linear(1, self.n_features * self.d_ob, bias=False) self.ob_propagation = ObservationPropagation( in_channels=max_len * self.d_ob, @@ -126,7 +116,7 @@ def init_weights(self): self.encoder.weight.data.uniform_(-init_range, init_range) if self.static: self.static_emb.weight.data.uniform_(-init_range, init_range) - glorot(self.R_u) + nn.init.xavier_uniform(self.R_u) # xavier_uniform also known as glorot def forward( self, @@ -134,7 +124,7 @@ def forward( timestamps: torch.Tensor, lengths: torch.Tensor, ) -> Tuple[torch.Tensor, ...]: - """Forward processing of BRITS. + """Forward processing of Raindrop. Parameters ---------- diff --git a/pypots/nn/modules/raindrop/layers.py b/pypots/nn/modules/raindrop/layers.py index d70fba66..4c9f8f91 100644 --- a/pypots/nn/modules/raindrop/layers.py +++ b/pypots/nn/modules/raindrop/layers.py @@ -18,8 +18,6 @@ from torch.nn import init from torch.nn.parameter import Parameter -from ....utils.logging import logger - try: from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.inits import glorot @@ -28,12 +26,9 @@ from torch_scatter import scatter from torch_sparse import SparseTensor except ImportError as e: - logger.error( - f"❌ {e}\n" - "Note torch_geometric is missing, please install it with " - "'pip install torch_geometric torch_scatter torch_sparse' or " - "'conda install -c pyg pyg pytorch-scatter pytorch-sparse'" - ) + # Modules here only for Raindrop model, and torch_geometric import errors are caught in BackboneRaindrop. + # Hence, we can pass them here. + pass class PositionalEncoding(nn.Module): From 01a3c626b81da1a90a8d3fab9a715b4b77b49a7e Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sun, 5 May 2024 23:13:53 +0800 Subject: [PATCH 2/3] fix: linting error; --- pypots/nn/modules/raindrop/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pypots/nn/modules/raindrop/layers.py b/pypots/nn/modules/raindrop/layers.py index 4c9f8f91..0c56a5ba 100644 --- a/pypots/nn/modules/raindrop/layers.py +++ b/pypots/nn/modules/raindrop/layers.py @@ -25,7 +25,7 @@ from torch_geometric.utils import softmax from torch_scatter import scatter from torch_sparse import SparseTensor -except ImportError as e: +except ImportError: # Modules here only for Raindrop model, and torch_geometric import errors are caught in BackboneRaindrop. # Hence, we can pass them here. pass From 7357caace5c603ba1aa4536337f971d6442f11b1 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sun, 5 May 2024 23:17:02 +0800 Subject: [PATCH 3/3] no message --- pypots/nn/modules/raindrop/backbone.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pypots/nn/modules/raindrop/backbone.py b/pypots/nn/modules/raindrop/backbone.py index 7a7e30af..3d23146b 100644 --- a/pypots/nn/modules/raindrop/backbone.py +++ b/pypots/nn/modules/raindrop/backbone.py @@ -81,7 +81,7 @@ def __init__( ) self.transformer_encoder = TransformerEncoder(encoder_layers, n_layers) - self.R_u = nn.Linear(1, self.n_features * self.d_ob, bias=False) + self.R_u = nn.Parameter(torch.Tensor(1, self.n_features * self.d_ob)) self.ob_propagation = ObservationPropagation( in_channels=max_len * self.d_ob,