Skip to content

Commit

Permalink
Style fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Dec 28, 2021
1 parent 8f72642 commit 4ed7954
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions tests/trainers/test_byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import pytest
import torch.nn as nn
from _pytest.fixtures import SubRequest
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer
from torchvision.models import resnet18
Expand All @@ -17,20 +16,20 @@


class TestBYOL:
def test_custom_augment_fn(self) -> None:
encoder = resnet18()
layer = encoder.conv1
new_layer = nn.Conv2d( # type: ignore[attr-defined]
in_channels=4,
out_channels=layer.out_channels,
kernel_size=layer.kernel_size,
stride=layer.stride,
padding=layer.padding,
bias=layer.bias,
).requires_grad_()
encoder.conv1 = new_layer
augment_fn = SimCLRAugmentation((2, 2))
BYOL(encoder, augment_fn=augment_fn)
def test_custom_augment_fn(self) -> None:
encoder = resnet18()
layer = encoder.conv1
new_layer = nn.Conv2d( # type: ignore[attr-defined]
in_channels=4,
out_channels=layer.out_channels,
kernel_size=layer.kernel_size,
stride=layer.stride,
padding=layer.padding,
bias=layer.bias,
).requires_grad_()
encoder.conv1 = new_layer
augment_fn = SimCLRAugmentation((2, 2))
BYOL(encoder, augment_fn=augment_fn)


class TestBYOLTask:
Expand Down

0 comments on commit 4ed7954

Please sign in to comment.