Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to import torch_geometric only when init Raindrop #381

Merged
merged 3 commits into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 14 additions & 24 deletions pypots/nn/modules/raindrop/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,6 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoderLayer, TransformerEncoder
from torch.nn.parameter import Parameter

from ....utils.logging import logger

try:
from .layers import PositionalEncoding, ObservationPropagation
from torch_geometric.nn.inits import glorot
except ImportError as e:
logger.error(
f"❌ {e}\n"
"Note torch_geometric is missing, please install it with "
"'pip install torch_geometric torch_scatter torch_sparse' or "
"'conda install -c pyg pyg pytorch-scatter pytorch-sparse'"
)
except NameError as e:
logger.error(
f"❌ {e}\n"
"Note torch_geometric is missing, please install it with "
"'pip install torch_geometric torch_scatter torch_sparse' or "
"'conda install -c pyg pyg pytorch-scatter pytorch-sparse'"
)


class BackboneRaindrop(nn.Module):
Expand All @@ -53,7 +32,18 @@ def __init__(
sensor_wise_mask=False,
static=False,
):

try:
from .layers import PositionalEncoding, ObservationPropagation
except (ImportError, NameError) as e:
raise ImportError(
f"❌ {e}. Note that torch_geometric is missing, please install it with "
"'pip install torch_geometric torch_scatter torch_sparse' or "
"'conda install -c pyg pyg pytorch-scatter pytorch-sparse'"
)

super().__init__()

self.n_layers = n_layers
self.n_features = n_features
self.d_model = d_model
Expand Down Expand Up @@ -91,7 +81,7 @@ def __init__(
)
self.transformer_encoder = TransformerEncoder(encoder_layers, n_layers)

self.R_u = Parameter(torch.Tensor(1, self.n_features * self.d_ob))
self.R_u = nn.Parameter(torch.Tensor(1, self.n_features * self.d_ob))

self.ob_propagation = ObservationPropagation(
in_channels=max_len * self.d_ob,
Expand Down Expand Up @@ -126,15 +116,15 @@ def init_weights(self):
self.encoder.weight.data.uniform_(-init_range, init_range)
if self.static:
self.static_emb.weight.data.uniform_(-init_range, init_range)
glorot(self.R_u)
nn.init.xavier_uniform(self.R_u) # xavier_uniform also known as glorot

def forward(
self,
X: torch.Tensor,
timestamps: torch.Tensor,
lengths: torch.Tensor,
) -> Tuple[torch.Tensor, ...]:
"""Forward processing of BRITS.
"""Forward processing of Raindrop.

Parameters
----------
Expand Down
13 changes: 4 additions & 9 deletions pypots/nn/modules/raindrop/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,17 @@
from torch.nn import init
from torch.nn.parameter import Parameter

from ....utils.logging import logger

try:
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot
from torch_geometric.typing import PairTensor, Adj, OptTensor
from torch_geometric.utils import softmax
from torch_scatter import scatter
from torch_sparse import SparseTensor
except ImportError as e:
logger.error(
f"❌ {e}\n"
"Note torch_geometric is missing, please install it with "
"'pip install torch_geometric torch_scatter torch_sparse' or "
"'conda install -c pyg pyg pytorch-scatter pytorch-sparse'"
)
except ImportError:
# Modules here only for Raindrop model, and torch_geometric import errors are caught in BackboneRaindrop.
# Hence, we can pass them here.
pass


class PositionalEncoding(nn.Module):
Expand Down
Loading