[pre-commit.ci] pre-commit suggestions #958

Merged 5 commits on Mar 28, 2023.
Changes from all commits
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -29,7 +29,7 @@ repos:
         name: Upgrade code
 
   - repo: https://github.com/PyCQA/docformatter
-    rev: v1.5.0
+    rev: v1.5.1
     hooks:
       - id: docformatter
         args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120]
@@ -45,12 +45,12 @@ repos:
         exclude: CHANGELOG.md
 
   - repo: https://github.com/PyCQA/isort
-    rev: 5.11.2
+    rev: 5.12.0
    hooks:
       - id: isort
 
   - repo: https://github.com/psf/black
-    rev: 22.12.0
+    rev: 23.1.0
     hooks:
       - id: black
         name: Format code
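
These are pre-commit.ci's routine hook bumps: docformatter v1.5.0 → v1.5.1, isort 5.11.2 → 5.12.0, and black 22.12.0 → 23.1.0. Every remaining hunk in this PR looks like the mechanical result of re-running the updated hooks over the repository. The dominant pattern is Black 23.1's 2023 stable style, which, among other changes, removes empty lines at the start of a code block and drops redundant parentheses around for-loop targets. A minimal sketch of both rules (the class and names are illustrative, not from this repo):

    # Layout that Black 22.12.0 left alone:
    #
    # class Monitor:
    #
    #     name = "monitor"
    #
    #     def pairs(self, items):
    #         for (key, value) in items:
    #             yield key, value

    # Black 23.1.0 rewrites it as follows: the blank line opening the class
    # body is removed and the parentheses around the loop target are dropped.
    class Monitor:
        name = "monitor"

        def pairs(self, items):
            for key, value in items:
                yield key, value

Behavior is identical either way; only the layout changes.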
3 changes: 0 additions & 3 deletions pl_bolts/callbacks/data_monitor.py
@@ -22,7 +22,6 @@
 
 @under_review()
 class DataMonitorBase(Callback):
-
     supported_loggers = (
         TensorBoardLogger,
         WandbLogger,
@@ -113,7 +112,6 @@ def _is_logger_available(self, logger: LightningLoggerBase) -> bool:
 
 @under_review()
 class ModuleDataMonitor(DataMonitorBase):
-
     GROUP_NAME_INPUT = "input"
     GROUP_NAME_OUTPUT = "output"
 
@@ -199,7 +197,6 @@ def hook(_: Module, inp: Sequence, out: Sequence) -> None:
 
 @under_review()
 class TrainingDataMonitor(DataMonitorBase):
-
     GROUP_NAME = "training_step"
 
     def __init__(self, log_every_n_steps: int = None):
1 change: 0 additions & 1 deletion pl_bolts/datamodules/experience_source.py
@@ -103,7 +103,6 @@ def runner(self, device: torch.device) -> Tuple[Experience]:
 
         # step through each env
         for env_idx, (env, action) in enumerate(zip(self.pool, actions)):
-
             exp = self.env_step(env_idx, env, action)
             history = self.histories[env_idx]
             history.append(exp)
1 change: 0 additions & 1 deletion pl_bolts/datamodules/kitti_datamodule.py
@@ -19,7 +19,6 @@
 
 @under_review()
 class KittiDataModule(LightningDataModule):
-
     name = "kitti"
 
     def __init__(
1 change: 0 additions & 1 deletion pl_bolts/datamodules/sklearn_datamodule.py
@@ -117,7 +117,6 @@ def __init__(
         *args,
         **kwargs,
     ) -> None:
-
         super().__init__(*args, **kwargs)
         self.num_workers = num_workers
         self.batch_size = batch_size
1 change: 0 additions & 1 deletion pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -18,7 +18,6 @@
 
 @under_review()
 class SSLImagenetDataModule(LightningDataModule):  # pragma: no cover
-
     name = "imagenet"
 
     def __init__(
1 change: 0 additions & 1 deletion pl_bolts/datamodules/vision_datamodule.py
@@ -8,7 +8,6 @@
 
 
 class VisionDataModule(LightningDataModule):
-
     EXTRA_ARGS: dict = {}
     name: str = ""
     #: Dataset class to use
1 change: 0 additions & 1 deletion pl_bolts/datasets/base_dataset.py
@@ -16,7 +16,6 @@
 
 @under_review()
 class LightDataset(ABC, Dataset):
-
     data: Tensor
     targets: Tensor
     normalize: tuple
3 changes: 0 additions & 3 deletions pl_bolts/datasets/ssl_amdim_datasets.py
@@ -30,7 +30,6 @@ def generate_train_val_split(cls, examples, labels, pct_val):
 
         cts = {x: 0 for x in range(nb_classes)}
         for img, class_idx in zip(examples, labels):
-
             # allow labeled
             if cts[class_idx] < nb_val_images:
                 val_x.append(img)
@@ -60,7 +59,6 @@ def select_nb_imgs_per_class(cls, examples, labels, nb_imgs_in_val):
 
         cts = {x: 0 for x in range(nb_classes)}
         for img_name, class_idx in zip(examples, labels):
-
             # allow labeled
             if cts[class_idx] < nb_imgs_in_val:
                 labeled.append(img_name)
@@ -76,7 +74,6 @@ select_nb_imgs_per_class(cls, examples, labels, nb_imgs_in_val):
 
     @classmethod
     def deterministic_shuffle(cls, x, y):
-
         n = len(x)
         idxs = list(range(0, n))
         np.random.seed(1234)
10 changes: 3 additions & 7 deletions pl_bolts/losses/self_supervised_learning.py
@@ -140,11 +140,8 @@ def forward(self, anchor_representations, positive_representations, mask_mat):
 
         # trick 2: tanh clip
         raw_scores = tanh_clip(raw_scores, clip_val=self.tclip)
-        """
-        pos_scores includes scores for all the positive samples
-        neg_scores includes scores for all the negative samples, with
-        scores for positive samples set to the min score (-self.tclip here)
-        """
+        """pos_scores includes scores for all the positive samples neg_scores includes scores for all the negative
+        samples, with scores for positive samples set to the min score (-self.tclip here)"""
         # ----------------------
         # EXTRACT POSITIVE SCORES
         # use the index mask to pull all the diagonals which are b1 x b1
@@ -337,8 +334,7 @@ def forward(self, anchor_maps, positive_maps):
 
         regularizer = 0
         losses = []
-        for (ai, pi) in self.map_indexes:
-
+        for ai, pi in self.map_indexes:
             # choose a random map
             if ai == -1:
                 ai = np.random.randint(0, len(anchor_maps))
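
The rewrap in the first hunk above is presumably docformatter (bumped to v1.5.1) applying the configured args [--in-place, --wrap-summaries=115, --wrap-descriptions=120]: the multi-line triple-quoted string is joined into one summary and rewrapped to 115 columns, with the closing quotes pulled onto the final line. Note that the string here is a free-standing expression in the middle of forward(), used as a comment rather than a real docstring. A hypothetical before/after under these assumptions:

    # Hypothetical functions; assumes docformatter with --wrap-summaries=115.
    # Before:
    def before():
        """
        pos_scores holds the positive samples
        and neg_scores holds the rest.
        """

    # After: lines joined into a single summary and rewrapped.
    def after():
        """pos_scores holds the positive samples and neg_scores holds the rest."""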
@@ -113,7 +113,6 @@ def forward(self, x):
         return self.model(x)
 
     def training_step(self, batch, batch_idx):
-
         images, targets = batch
         targets = [{k: v for k, v in t.items()} for t in targets]
 
2 changes: 0 additions & 2 deletions pl_bolts/models/detection/retinanet/retinanet_module.py
@@ -9,7 +9,6 @@
 from pl_bolts.utils.warnings import warn_missing_pkg
 
 if _TORCHVISION_AVAILABLE:
-
     from torchvision.models.detection.retinanet import RetinaNet as torchvision_RetinaNet
     from torchvision.models.detection.retinanet import RetinaNetHead, retinanet_resnet50_fpn
     from torchvision.ops import box_iou
@@ -97,7 +96,6 @@ def forward(self, x):
         return self.model(x)
 
     def training_step(self, batch, batch_idx):
-
         images, targets = batch
         targets = [{k: v for k, v in t.items()} for t in targets]
 
1 change: 0 additions & 1 deletion pl_bolts/models/gans/pix2pix/pix2pix_module.py
@@ -18,7 +18,6 @@ def _weights_init(m):
 @under_review()
 class Pix2Pix(LightningModule):
     def __init__(self, in_channels, out_channels, learning_rate=0.0002, lambda_recon=200):
-
         super().__init__()
         self.save_hyperparameters()
 
1 change: 0 additions & 1 deletion pl_bolts/models/mnist_module.py
@@ -27,7 +27,6 @@ class LitMNIST(LightningModule):
     """
 
     def __init__(self, hidden_dim: int = 128, learning_rate: float = 1e-3, **kwargs: Any) -> None:
-
         if not _TORCHVISION_AVAILABLE:  # pragma: no cover
             raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")
 
8 changes: 7 additions & 1 deletion pl_bolts/models/rl/per_dqn_model.py
@@ -91,7 +91,13 @@ def train_batch(
         states, actions, rewards, dones, new_states = samples
 
         for idx, _ in enumerate(dones):
-            yield (states[idx], actions[idx], rewards[idx], dones[idx], new_states[idx],), indices[
+            yield (
+                states[idx],
+                actions[idx],
+                rewards[idx],
+                dones[idx],
+                new_states[idx],
+            ), indices[
                 idx
             ], weights[idx]
 
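
The exploded tuple above is consistent with Black's "magic trailing comma": a trailing comma inside brackets forces Black to keep one element per line regardless of line length. A minimal sketch of the rule, assuming it drove this reformat (names are illustrative, not this repo's API):

    # The trailing comma after indices[idx] keeps the tuple exploded across
    # lines even though it would fit on one; removing it would let Black
    # collapse the tuple back onto a single line.
    def batches(samples, indices, weights):
        for idx, sample in enumerate(samples):
            yield (
                sample,
                indices[idx],
            ), weights[idx]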
1 change: 0 additions & 1 deletion pl_bolts/models/rl/reinforce_model.py
@@ -173,7 +173,6 @@ def train_batch(
         """
 
         while True:
-
             action = self.agent(self.state, self.device)
 
             next_state, reward, done, _ = self.env.step(action[0])
1 change: 0 additions & 1 deletion pl_bolts/models/rl/vanilla_policy_gradient_model.py
@@ -132,7 +132,6 @@ def train_batch(
         """
 
         while True:
-
             action = self.agent(self.state, self.device)
 
             next_state, reward, done, _ = self.env.step(action[0])
1 change: 0 additions & 1 deletion pl_bolts/models/self_supervised/byol/byol_module.py
@@ -74,7 +74,6 @@ def __init__(
         initial_tau: float = 0.996,
         **kwargs: Any,
     ) -> None:
-
         super().__init__()
         self.save_hyperparameters(ignore="base_encoder")
 
2 changes: 0 additions & 2 deletions pl_bolts/models/self_supervised/byol/models.py
@@ -18,7 +18,6 @@ class MLP(nn.Module):
     """
 
     def __init__(self, input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256) -> None:
-
         super().__init__()
 
         self.model = nn.Sequential(
@@ -53,7 +52,6 @@ def __init__(
         projector_hidden_dim: int = 4096,
         projector_out_dim: int = 256,
     ) -> None:
-
         super().__init__()
 
         if isinstance(encoder, str):
1 change: 0 additions & 1 deletion pl_bolts/models/self_supervised/moco/moco2_module.py
@@ -225,7 +225,6 @@ def forward(self, img_q, img_k, queue):
 
         # compute key features
         with torch.no_grad():  # no gradient to keys
-
             # shuffle for making use of BN
             if self._use_ddp(self.trainer):
                 img_k, idx_unshuffle = self._batch_shuffle_ddp(img_k)
2 changes: 0 additions & 2 deletions pl_bolts/models/self_supervised/simclr/transforms.py
@@ -31,7 +31,6 @@ class SimCLRTrainDataTransform:
     def __init__(
         self, input_height: int = 224, gaussian_blur: bool = True, jitter_strength: float = 1.0, normalize=None
     ) -> None:
-
         if not _TORCHVISION_AVAILABLE:  # pragma: no cover
             raise ModuleNotFoundError("You want to use `transforms` from `torchvision` which is not installed yet.")
 
@@ -140,7 +139,6 @@ class SimCLRFinetuneTransform(SimCLRTrainDataTransform):
     def __init__(
         self, input_height: int = 224, jitter_strength: float = 1.0, normalize=None, eval_transform: bool = False
     ) -> None:
-
         super().__init__(
             normalize=normalize, input_height=input_height, gaussian_blur=None, jitter_strength=jitter_strength
         )
@@ -75,7 +75,6 @@ def __init__(
         exclude_bn_bias: bool = False,
         **kwargs,
     ) -> None:
-
         super().__init__()
         self.save_hyperparameters(ignore="base_encoder")
 
1 change: 0 additions & 1 deletion pl_bolts/models/self_supervised/swav/transforms.py
@@ -130,7 +130,6 @@ class SwAVFinetuneTransform:
     def __init__(
         self, input_height: int = 224, jitter_strength: float = 1.0, normalize=None, eval_transform: bool = False
     ) -> None:
-
         self.jitter_strength = jitter_strength
         self.input_height = input_height
         self.normalize = normalize
1 change: 0 additions & 1 deletion pl_bolts/models/vision/unet.py
@@ -32,7 +32,6 @@ def __init__(
         features_start: int = 64,
         bilinear: bool = False,
     ):
-
         if num_layers < 1:
             raise ValueError(f"num_layers = {num_layers}, expected: num_layers > 0")
 
3 changes: 0 additions & 3 deletions pl_bolts/utils/arguments.py
@@ -90,14 +90,11 @@ def gather_lit_args(cls: Any, root_cls: Optional[Any] = None) -> List[LitArg]:
     arguments: List[LitArg] = []
     argument_names = []
     for obj in inspect.getmro(cls):
-
         if obj is root_cls and len(arguments) > 0:
             break
 
         if issubclass(obj, root_cls):
-
             default_params = inspect.signature(obj.__init__).parameters  # type: ignore
-
             for arg in default_params:
                 arg_type = default_params[arg].annotation
                 arg_default = default_params[arg].default
2 changes: 0 additions & 2 deletions tests/datamodules/test_experience_sources.py
@@ -185,7 +185,6 @@ def test_source_is_done_2step_episode(self):
         self.source.histories[0].append(self.exp1)
 
         for idx, exp in enumerate(self.source.runner(self.device)):
-
             self.assertTrue(isinstance(exp, tuple))
 
             if idx == 0:
@@ -211,7 +210,6 @@ def test_source_is_done_metrics(self):
         history += [self.exp1, self.exp2, self.exp2]
 
         for idx, exp in enumerate(self.source.runner(self.device)):
-
             if idx == n_steps - 1:
                 self.assertEqual(self.source._total_rewards[0], 1)
                 self.assertEqual(self.source.total_steps[0], 1)
1 change: 0 additions & 1 deletion tests/datasets/test_datasets.py
@@ -29,7 +29,6 @@
 @pytest.mark.parametrize("batch_size,num_samples", [(16, 100), (1, 0)])
 def test_dummy_ds(catch_warnings, batch_size, num_samples):
     if num_samples > 0:
-
         ds = DummyDataset((1, 28, 28), (1,), num_samples=num_samples)
         dl = DataLoader(ds, batch_size=batch_size)
 
1 change: 0 additions & 1 deletion tests/losses/test_rl_loss.py
@@ -13,7 +13,6 @@
 
 class TestRLLoss(TestCase):
     def setUp(self) -> None:
-
         self.state = torch.rand(32, 4, 84, 84)
         self.next_state = torch.rand(32, 4, 84, 84)
         self.action = torch.ones([32])
1 change: 0 additions & 1 deletion tests/models/rl/unit/test_agents.py
@@ -31,7 +31,6 @@ def setUp(self) -> None:
         self.value_agent = ValueAgent(self.net, self.env.action_space.n)
 
     def test_value_agent(self):
-
         action = self.value_agent(self.state, self.device)
         self.assertIsInstance(action, list)
         self.assertIsInstance(action[0], int)
2 changes: 0 additions & 2 deletions tests/utils/test_arguments.py
@@ -7,7 +7,6 @@
 
 
 class DummyParentModel(LightningModule):
-
     name = "parent-model"
 
     def __init__(self, a: int, b: str, c: str = "parent_model_c"):
@@ -19,7 +18,6 @@ def forward(self, x):
 
 
 class DummyParentDataModule(LightningDataModule):
-
     name = "parent-dm"
 
     def __init__(self, d: str, c: str = "parent_dm_c"):