diff --git a/pypots/nn/modules/raindrop/backbone.py b/pypots/nn/modules/raindrop/backbone.py index 50af4fe0..3d23146b 100644 --- a/pypots/nn/modules/raindrop/backbone.py +++ b/pypots/nn/modules/raindrop/backbone.py @@ -13,27 +13,6 @@ import torch.nn as nn import torch.nn.functional as F from torch.nn import TransformerEncoderLayer, TransformerEncoder -from torch.nn.parameter import Parameter - -from ....utils.logging import logger - -try: - from .layers import PositionalEncoding, ObservationPropagation - from torch_geometric.nn.inits import glorot -except ImportError as e: - logger.error( - f"❌ {e}\n" - "Note torch_geometric is missing, please install it with " - "'pip install torch_geometric torch_scatter torch_sparse' or " - "'conda install -c pyg pyg pytorch-scatter pytorch-sparse'" - ) -except NameError as e: - logger.error( - f"❌ {e}\n" - "Note torch_geometric is missing, please install it with " - "'pip install torch_geometric torch_scatter torch_sparse' or " - "'conda install -c pyg pyg pytorch-scatter pytorch-sparse'" - ) class BackboneRaindrop(nn.Module): @@ -53,7 +32,18 @@ def __init__( sensor_wise_mask=False, static=False, ): + + try: + from .layers import PositionalEncoding, ObservationPropagation + except (ImportError, NameError) as e: + raise ImportError( + f"❌ {e}. Note that torch_geometric is missing, please install it with " + "'pip install torch_geometric torch_scatter torch_sparse' or " + "'conda install -c pyg pyg pytorch-scatter pytorch-sparse'" + ) + super().__init__() + self.n_layers = n_layers self.n_features = n_features self.d_model = d_model @@ -91,7 +81,7 @@ def __init__( ) self.transformer_encoder = TransformerEncoder(encoder_layers, n_layers) - self.R_u = Parameter(torch.Tensor(1, self.n_features * self.d_ob)) + self.R_u = nn.Parameter(torch.Tensor(1, self.n_features * self.d_ob)) self.ob_propagation = ObservationPropagation( in_channels=max_len * self.d_ob, @@ -126,7 +116,7 @@ def init_weights(self): self.encoder.weight.data.uniform_(-init_range, init_range) if self.static: self.static_emb.weight.data.uniform_(-init_range, init_range) - glorot(self.R_u) + nn.init.xavier_uniform(self.R_u) # xavier_uniform also known as glorot def forward( self, @@ -134,7 +124,7 @@ def forward( timestamps: torch.Tensor, lengths: torch.Tensor, ) -> Tuple[torch.Tensor, ...]: - """Forward processing of BRITS. + """Forward processing of Raindrop. Parameters ---------- diff --git a/pypots/nn/modules/raindrop/layers.py b/pypots/nn/modules/raindrop/layers.py index d70fba66..0c56a5ba 100644 --- a/pypots/nn/modules/raindrop/layers.py +++ b/pypots/nn/modules/raindrop/layers.py @@ -18,8 +18,6 @@ from torch.nn import init from torch.nn.parameter import Parameter -from ....utils.logging import logger - try: from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.inits import glorot @@ -27,13 +25,10 @@ from torch_geometric.utils import softmax from torch_scatter import scatter from torch_sparse import SparseTensor -except ImportError as e: - logger.error( - f"❌ {e}\n" - "Note torch_geometric is missing, please install it with " - "'pip install torch_geometric torch_scatter torch_sparse' or " - "'conda install -c pyg pyg pytorch-scatter pytorch-sparse'" - ) +except ImportError: + # Modules here only for Raindrop model, and torch_geometric import errors are caught in BackboneRaindrop. + # Hence, we can pass them here. + pass class PositionalEncoding(nn.Module):