Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testing: shared load_state_dict_from_url monkeypatch #2223

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from typing import Any

import pytest
import torch
import torchvision
from pytest import MonkeyPatch


def load(*args: Any, progress: bool = False, **kwargs: Any) -> Any:
return torch.load(*args, **kwargs)


@pytest.fixture
def load_state_dict_from_url(monkeypatch: MonkeyPatch) -> None:
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
21 changes: 10 additions & 11 deletions tests/models/test_dofa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
# Licensed under the MIT License.

from pathlib import Path
from typing import Any

import pytest
import torch
import torchvision
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchvision.models._api import WeightsEnum
Expand All @@ -22,11 +20,6 @@
)


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestDOFA:
@pytest.mark.parametrize(
'wavelengths',
Expand Down Expand Up @@ -86,7 +79,11 @@ def weights(self, request: SubRequest) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = dofa_base_patch16_224()
Expand All @@ -95,7 +92,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_dofa(self) -> None:
Expand Down Expand Up @@ -123,7 +119,11 @@ def weights(self, request: SubRequest) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = dofa_large_patch16_224()
Expand All @@ -132,7 +132,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_dofa(self) -> None:
Expand Down
21 changes: 10 additions & 11 deletions tests/models/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,29 @@
# Licensed under the MIT License.

from pathlib import Path
from typing import Any

import pytest
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchvision.models._api import WeightsEnum

from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestResNet18:
@pytest.fixture(params=[*ResNet18_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model('resnet18', in_chans=weights.meta['in_chans'])
Expand All @@ -36,7 +33,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_resnet(self) -> None:
Expand Down Expand Up @@ -64,7 +60,11 @@ def weights(self, request: SubRequest) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model('resnet50', in_chans=weights.meta['in_chans'])
Expand All @@ -73,7 +73,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_resnet(self) -> None:
Expand Down
13 changes: 5 additions & 8 deletions tests/models/test_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Licensed under the MIT License.

from pathlib import Path
from typing import Any

import pytest
import torch
Expand All @@ -14,19 +13,18 @@
from torchgeo.models import Swin_V2_B_Weights, swin_v2_b


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestSwin_V2_B:
@pytest.fixture(params=[*Swin_V2_B_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = torchvision.models.swin_v2_b()
Expand All @@ -35,7 +33,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_swin_v2_b(self) -> None:
Expand Down
14 changes: 5 additions & 9 deletions tests/models/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,29 @@
# Licensed under the MIT License.

from pathlib import Path
from typing import Any

import pytest
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchvision.models._api import WeightsEnum

from torchgeo.models import ViTSmall16_Weights, vit_small_patch16_224


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestViTSmall16:
@pytest.fixture(params=[*ViTSmall16_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model(
Expand All @@ -38,7 +35,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_vit(self) -> None:
Expand Down
14 changes: 5 additions & 9 deletions tests/trainers/test_byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@

import os
from pathlib import Path
from typing import Any

import pytest
import timm
import torch
import torch.nn as nn
import torchvision
from pytest import MonkeyPatch
from torchvision.models import resnet18
from torchvision.models._api import WeightsEnum
Expand All @@ -21,11 +19,6 @@
from torchgeo.trainers.byol import BYOL, SimCLRAugmentation


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestBYOL:
def test_custom_augment_fn(self) -> None:
model = resnet18()
Expand Down Expand Up @@ -88,7 +81,11 @@ def weights(self) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model(
Expand All @@ -99,7 +96,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_weight_file(self, checkpoint: str) -> None:
Expand Down
13 changes: 5 additions & 8 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import timm
import torch
import torch.nn as nn
import torchvision
from lightning.pytorch import Trainer
from pytest import MonkeyPatch
from torch.nn.modules import Module
Expand Down Expand Up @@ -56,11 +55,6 @@ def create_model(*args: Any, **kwargs: Any) -> Module:
return ClassificationTestModel(**kwargs)


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


def plot(*args: Any, **kwargs: Any) -> None:
return None

Expand Down Expand Up @@ -125,7 +119,11 @@ def weights(self) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model(
Expand All @@ -136,7 +134,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_weight_file(self, checkpoint: str) -> None:
Expand Down
13 changes: 5 additions & 8 deletions tests/trainers/test_moco.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pytest
import timm
import torch
import torchvision
from pytest import MonkeyPatch
from torch.nn import Module
from torchvision.models._api import WeightsEnum
Expand All @@ -25,11 +24,6 @@ def create_model(*args: Any, **kwargs: Any) -> Module:
return ClassificationTestModel(**kwargs)


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestMoCoTask:
@pytest.mark.parametrize(
'name',
Expand Down Expand Up @@ -89,7 +83,11 @@ def weights(self) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model(
Expand All @@ -100,7 +98,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_weight_file(self, checkpoint: str) -> None:
Expand Down
Loading
Loading