Skip to content

Commit

Permalink
imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Oct 19, 2020
1 parent 1c440da commit 0879a96
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
8 changes: 4 additions & 4 deletions pl_bolts/callbacks/self_supervised.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import math
from typing import Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning import Callback
from pytorch_lightning.metrics.functional import accuracy
from torch.nn import functional as F

from pl_bolts.models.self_supervised.evaluator import SSLEvaluator

class SSLOnlineEvaluator(pl.Callback): # pragma: no-cover

class SSLOnlineEvaluator(Callback): # pragma: no-cover
"""
Attaches a MLP for finetuning using the standard self-supervised protocol.
Expand Down Expand Up @@ -41,8 +43,6 @@ def __init__(
self.num_classes = num_classes

def on_pretrain_routine_start(self, trainer, pl_module):
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator

# attach the evaluator to the module

if hasattr(pl_module, 'z_dim'):
Expand Down
8 changes: 4 additions & 4 deletions pl_bolts/models/self_supervised/swav/swav_online_eval.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning import Callback
from pytorch_lightning.metrics.functional import accuracy
from torch.nn import functional as F

from pl_bolts.models.self_supervised.evaluator import SSLEvaluator

class SwavOnlineEvaluator(pl.Callback):

class SwavOnlineEvaluator(Callback):
def __init__(
self,
drop_p: float = 0.2,
Expand All @@ -29,8 +31,6 @@ def __init__(
self.acc = []

def on_pretrain_routine_start(self, trainer, pl_module):
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator

pl_module.non_linear_evaluator = SSLEvaluator(
n_input=self.z_dim,
n_classes=self.num_classes,
Expand Down
4 changes: 2 additions & 2 deletions tests/models/self_supervised/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import torch
from pytorch_lightning import seed_everything

from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised import CPCV2, AMDIM, MocoV2, SimCLR, BYOL, SwAV
from pl_bolts.models.self_supervised.cpc import CPCTrainTransformsCIFAR10, CPCEvalTransformsCIFAR10
from pl_bolts.models.self_supervised.moco.callbacks import MocoLRScheduler
from pl_bolts.models.self_supervised.moco.transforms import (Moco2TrainCIFAR10Transforms, Moco2EvalCIFAR10Transforms)
from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform
from pl_bolts.models.self_supervised.swav.transforms import SwAVTrainDataTransform, SwAVEvalDataTransform
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization


# TODO: this test is hanging (runs for more then 10min) so we need to use GPU or optimize it...
Expand Down

0 comments on commit 0879a96

Please sign in to comment.