Skip to content

Commit

Permalink
Merge pull request #427 from WenjieDu/dev
Browse files Browse the repository at this point in the history
Fix ETSformer tuning bug, and release v0.6rc1
  • Loading branch information
WenjieDu authored May 28, 2024
2 parents 2668891 + c39beb2 commit e6d7c9f
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 19 deletions.
2 changes: 1 addition & 1 deletion pypots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#
# Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'
__version__ = "0.5"
__version__ = "0.6rc1"


from . import imputation, classification, clustering, forecasting, optim, data, utils
Expand Down
6 changes: 4 additions & 2 deletions pypots/cli/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .base import BaseCommand
from .utils import load_package_from_path
from ..classification import BRITS as BRITS_classification
from ..classification import GRUD as GRUD_classification
from ..classification import Raindrop
from ..clustering import CRLI, VaDER
from ..data.saving.h5 import load_dict_from_h5
Expand Down Expand Up @@ -80,8 +81,9 @@
"pypots.imputation.GPVAE": GPVAE,
"pypots.imputation.BRITS": BRITS,
"pypots.imputation.MRNN": MRNN,
"pypots.imputation.GRUD": GRUD,
# classification models
"pypots.classification.GRUD": GRUD,
"pypots.classification.GRUD": GRUD_classification,
"pypots.classification.BRITS": BRITS_classification,
"pypots.classification.Raindrop": Raindrop,
# clustering models
Expand Down Expand Up @@ -248,7 +250,7 @@ def run(self):
)
raise RuntimeError(
f"Hyperparameters do not match. Mismatched hyperparameters "
f"(in the tuning configuration but not in the given model's arguments): {list(mismatched)}"
f"(in the tuning configuration but not in {model_class.__name__}'s arguments): {list(mismatched)}"
)

# initializing optimizer and model
Expand Down
16 changes: 8 additions & 8 deletions pypots/imputation/etsformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,14 @@ class ETSformer(BaseNNImputer):

def __init__(
self,
n_steps,
n_features,
n_e_layers,
n_d_layers,
d_model,
n_heads,
d_ffn,
top_k,
n_steps: int,
n_features: int,
n_e_layers: int,
n_d_layers: int,
d_model: int,
n_heads: int,
d_ffn: int,
top_k: int,
dropout: float = 0,
ORT_weight: float = 1,
MIT_weight: float = 1,
Expand Down
14 changes: 7 additions & 7 deletions pypots/imputation/fedformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,13 @@ class FEDformer(BaseNNImputer):

def __init__(
self,
n_steps,
n_features,
n_layers,
d_model,
n_heads,
d_ffn,
moving_avg_window_size,
n_steps: int,
n_features: int,
n_layers: int,
d_model: int,
n_heads: int,
d_ffn: int,
moving_avg_window_size: int,
dropout: float = 0,
version="Fourier",
modes=32,
Expand Down
2 changes: 1 addition & 1 deletion pypots/nn/modules/etsformer/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def forward(self, x):
f = fft.rfftfreq(t)[self.low_freq :]

x_freq, index_tuple = self.topk_freq(x_freq)
f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2))
f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2)).to(x_freq.device)
f = rearrange(f[index_tuple], "b f d -> b f () d").to(x_freq.device)

return self.extrapolate(x_freq, f, t)
Expand Down

0 comments on commit e6d7c9f

Please sign in to comment.