Skip to content

Commit

Permalink
isort
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Jun 16, 2024
1 parent b43faff commit 13ec2da
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/gluonts/torch/model/samformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from .module import SamFormerModel
from .lightning_module import SamFormerLightningModule
from .estimator import SamFormerEstimator
from .lightning_module import SamFormerLightningModule
from .module import SamFormerModel

__all__ = [
"SamFormerModel",
Expand Down
20 changes: 10 additions & 10 deletions src/gluonts/torch/model/samformer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,30 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Optional, Iterable, Dict, Any
from typing import Any, Dict, Iterable, Optional

import torch
import lightning.pytorch as pl
import torch

from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import as_stacked_batches
from gluonts.itertools import Cyclic
from gluonts.torch.distributions import Output, StudentTOutput
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.transform import (
AsNumpyArray,
Transformation,
AddObservedValuesIndicator,
AsNumpyArray,
ExpectedNumInstanceSampler,
InstanceSampler,
InstanceSplitter,
ValidationSplitSampler,
TestSplitSampler,
ExpectedNumInstanceSampler,
SelectFields,
TestSplitSampler,
Transformation,
ValidationSplitSampler,
)
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.distributions import Output, StudentTOutput

from .lightning_module import SamFormerLightningModule

Expand Down
8 changes: 4 additions & 4 deletions src/gluonts/torch/model/samformer/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
self.lr = lr
self.weight_decay = weight_decay
self.rho = rho
self.sam = sam
self.sam = sam

self.automatic_optimization = False

Expand Down Expand Up @@ -99,9 +99,9 @@ def training_step(self, batch, batch_idx: int): # type: ignore
self.manual_backward(train_loss_2)
opt.second_step(zero_grad=True)
else:
opt.zero_grad()
self.manual_backward(train_loss)
opt.step()
opt.zero_grad()
self.manual_backward(train_loss)
opt.step()

self.log(
"train_loss",
Expand Down
1 change: 0 additions & 1 deletion src/gluonts/torch/model/samformer/sam.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch

from torch.optim import Optimizer


Expand Down

0 comments on commit 13ec2da

Please sign in to comment.