Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testing: shared load_state_dict_from_url monkeypatch #2223

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Testing: shared load_state_dict_from_url monkeypatch
adamjstewart committed Aug 13, 2024
commit ee563db8d78f5c48e77eaa6c6aa30104e93ba72e
18 changes: 18 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from typing import Any

import pytest
import torch
import torchvision
from pytest import MonkeyPatch


def load(*args: Any, progress: bool = False, **kwargs: Any) -> Any:
return torch.load(*args, **kwargs)


@pytest.fixture
def load_state_dict_from_url(monkeypatch: MonkeyPatch) -> None:
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
21 changes: 10 additions & 11 deletions tests/models/test_dofa.py
Original file line number Diff line number Diff line change
@@ -2,11 +2,9 @@
# Licensed under the MIT License.

from pathlib import Path
from typing import Any

import pytest
import torch
import torchvision
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchvision.models._api import WeightsEnum
@@ -22,11 +20,6 @@
)


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestDOFA:
@pytest.mark.parametrize(
'wavelengths',
@@ -86,7 +79,11 @@ def weights(self, request: SubRequest) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = dofa_base_patch16_224()
@@ -95,7 +92,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_dofa(self) -> None:
@@ -123,7 +119,11 @@ def weights(self, request: SubRequest) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = dofa_large_patch16_224()
@@ -132,7 +132,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_dofa(self) -> None:
21 changes: 10 additions & 11 deletions tests/models/test_resnet.py
Original file line number Diff line number Diff line change
@@ -2,32 +2,29 @@
# Licensed under the MIT License.

from pathlib import Path
from typing import Any

import pytest
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchvision.models._api import WeightsEnum

from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestResNet18:
@pytest.fixture(params=[*ResNet18_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model('resnet18', in_chans=weights.meta['in_chans'])
@@ -36,7 +33,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_resnet(self) -> None:
@@ -64,7 +60,11 @@ def weights(self, request: SubRequest) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model('resnet50', in_chans=weights.meta['in_chans'])
@@ -73,7 +73,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_resnet(self) -> None:
13 changes: 5 additions & 8 deletions tests/models/test_swin.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,6 @@
# Licensed under the MIT License.

from pathlib import Path
from typing import Any

import pytest
import torch
@@ -14,19 +13,18 @@
from torchgeo.models import Swin_V2_B_Weights, swin_v2_b


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestSwin_V2_B:
@pytest.fixture(params=[*Swin_V2_B_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = torchvision.models.swin_v2_b()
@@ -35,7 +33,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_swin_v2_b(self) -> None:
14 changes: 5 additions & 9 deletions tests/models/test_vit.py
Original file line number Diff line number Diff line change
@@ -2,32 +2,29 @@
# Licensed under the MIT License.

from pathlib import Path
from typing import Any

import pytest
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchvision.models._api import WeightsEnum

from torchgeo.models import ViTSmall16_Weights, vit_small_patch16_224


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestViTSmall16:
@pytest.fixture(params=[*ViTSmall16_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model(
@@ -38,7 +35,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_vit(self) -> None:
14 changes: 5 additions & 9 deletions tests/trainers/test_byol.py
Original file line number Diff line number Diff line change
@@ -3,13 +3,11 @@

import os
from pathlib import Path
from typing import Any

import pytest
import timm
import torch
import torch.nn as nn
import torchvision
from pytest import MonkeyPatch
from torchvision.models import resnet18
from torchvision.models._api import WeightsEnum
@@ -21,11 +19,6 @@
from torchgeo.trainers.byol import BYOL, SimCLRAugmentation


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestBYOL:
def test_custom_augment_fn(self) -> None:
model = resnet18()
@@ -88,7 +81,11 @@ def weights(self) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model(
@@ -99,7 +96,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_weight_file(self, checkpoint: str) -> None:
13 changes: 5 additions & 8 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,6 @@
import timm
import torch
import torch.nn as nn
import torchvision
from lightning.pytorch import Trainer
from pytest import MonkeyPatch
from torch.nn.modules import Module
@@ -56,11 +55,6 @@ def create_model(*args: Any, **kwargs: Any) -> Module:
return ClassificationTestModel(**kwargs)


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


def plot(*args: Any, **kwargs: Any) -> None:
return None

@@ -125,7 +119,11 @@ def weights(self) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model(
@@ -136,7 +134,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_weight_file(self, checkpoint: str) -> None:
13 changes: 5 additions & 8 deletions tests/trainers/test_moco.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,6 @@
import pytest
import timm
import torch
import torchvision
from pytest import MonkeyPatch
from torch.nn import Module
from torchvision.models._api import WeightsEnum
@@ -25,11 +24,6 @@ def create_model(*args: Any, **kwargs: Any) -> Module:
return ClassificationTestModel(**kwargs)


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestMoCoTask:
@pytest.mark.parametrize(
'name',
@@ -89,7 +83,11 @@ def weights(self) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model(
@@ -100,7 +98,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_weight_file(self, checkpoint: str) -> None:
Loading