Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend existing unit tests using Cover-Agent #2331

Merged
merged 2 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion tests/test_layers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn

from timm.layers import create_act_layer, set_layer_config
from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn

import importlib
import os
Expand Down Expand Up @@ -76,3 +76,46 @@ def test_hard_swish_grad():
def test_hard_mish_grad():
for _ in range(100):
_run_act_layer_grad('hard_mish')

def test_get_act_layer_empty_string():
# Empty string should return None
assert get_act_layer('') is None


def test_create_act_layer_inplace_error():
class NoInplaceAct(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x

# Should recover when inplace arg causes TypeError
layer = create_act_layer(NoInplaceAct, inplace=True)
assert isinstance(layer, NoInplaceAct)


def test_create_act_layer_edge_cases():
# Test None input
assert create_act_layer(None) is None

# Test TypeError handling for inplace
class CustomAct(nn.Module):
def __init__(self, **kwargs):
super().__init__()
def forward(self, x):
return x

result = create_act_layer(CustomAct, inplace=True)
assert isinstance(result, CustomAct)


def test_get_act_fn_callable():
def custom_act(x):
return x
assert get_act_fn(custom_act) is custom_act


def test_get_act_fn_none():
assert get_act_fn(None) is None
assert get_act_fn('') is None

81 changes: 81 additions & 0 deletions tests/test_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import torch
from torch.testing._internal.common_utils import TestCase
from torch.nn import Parameter

from timm.optim.optim_factory import param_groups_layer_decay, param_groups_weight_decay
from timm.scheduler import PlateauLRScheduler

from timm.optim import create_optimizer_v2
Expand Down Expand Up @@ -741,3 +743,82 @@ def test_lookahead_radam(optimizer):
lambda params: create_optimizer_v2(params, optimizer, lr=1e-4)
)


def test_param_groups_layer_decay_with_end_decay():
model = torch.nn.Sequential(
torch.nn.Linear(10, 5),
torch.nn.ReLU(),
torch.nn.Linear(5, 2)
)

param_groups = param_groups_layer_decay(
model,
weight_decay=0.05,
layer_decay=0.75,
end_layer_decay=0.5,
verbose=True
)

assert len(param_groups) > 0
# Verify layer scaling is applied with end decay
for group in param_groups:
assert 'lr_scale' in group
assert group['lr_scale'] <= 1.0
assert group['lr_scale'] >= 0.5


def test_param_groups_layer_decay_with_matcher():
class ModelWithMatcher(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer1 = torch.nn.Linear(10, 5)
self.layer2 = torch.nn.Linear(5, 2)

def group_matcher(self, coarse=False):
return lambda name: int(name.split('.')[0][-1])

model = ModelWithMatcher()
param_groups = param_groups_layer_decay(
model,
weight_decay=0.05,
layer_decay=0.75,
verbose=True
)

assert len(param_groups) > 0
# Verify layer scaling is applied
for group in param_groups:
assert 'lr_scale' in group
assert 'weight_decay' in group
assert len(group['params']) > 0


def test_param_groups_weight_decay():
model = torch.nn.Sequential(
torch.nn.Linear(10, 5),
torch.nn.ReLU(),
torch.nn.Linear(5, 2)
)
weight_decay = 0.01
no_weight_decay_list = ['1.weight']

param_groups = param_groups_weight_decay(
model,
weight_decay=weight_decay,
no_weight_decay_list=no_weight_decay_list
)

assert len(param_groups) == 2
assert param_groups[0]['weight_decay'] == 0.0
assert param_groups[1]['weight_decay'] == weight_decay

# Verify parameters are correctly grouped
no_decay_params = set(param_groups[0]['params'])
decay_params = set(param_groups[1]['params'])

for name, param in model.named_parameters():
if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
assert param in no_decay_params
else:
assert param in decay_params

137 changes: 136 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@
from torchvision.ops.misc import FrozenBatchNorm2d

import timm
import pytest
from timm.utils.model import freeze, unfreeze
from timm.utils.model import ActivationStatsHook
from timm.utils.model import extract_spp_stats

from timm.utils.model import _freeze_unfreeze
from timm.utils.model import avg_sq_ch_mean, avg_ch_var, avg_ch_var_residual
from timm.utils.model import reparameterize_model
from timm.utils.model import get_state_dict

def test_freeze_unfreeze():
model = timm.create_model('resnet18')
Expand Down Expand Up @@ -54,4 +61,132 @@ def test_freeze_unfreeze():
freeze(model.layer1[0], ['bn1'])
assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
unfreeze(model.layer1[0], ['bn1'])
assert isinstance(model.layer1[0].bn1, BatchNorm2d)
assert isinstance(model.layer1[0].bn1, BatchNorm2d)

def test_activation_stats_hook_validation():
model = timm.create_model('resnet18')

def test_hook(model, input, output):
return output.mean().item()

# Test error case with mismatched lengths
with pytest.raises(ValueError, match="Please provide `hook_fns` for each `hook_fn_locs`"):
ActivationStatsHook(
model,
hook_fn_locs=['layer1.0.conv1', 'layer1.0.conv2'],
hook_fns=[test_hook]
)


def test_extract_spp_stats():
model = timm.create_model('resnet18')

def test_hook(model, input, output):
return output.mean().item()

stats = extract_spp_stats(
model,
hook_fn_locs=['layer1.0.conv1'],
hook_fns=[test_hook],
input_shape=[2, 3, 32, 32]
)

assert isinstance(stats, dict)
assert test_hook.__name__ in stats
assert isinstance(stats[test_hook.__name__], list)
assert len(stats[test_hook.__name__]) > 0

def test_freeze_unfreeze_bn_root():
import torch.nn as nn
from timm.layers import BatchNormAct2d

# Create batch norm layers
bn = nn.BatchNorm2d(10)
bn_act = BatchNormAct2d(10)

# Test with BatchNorm2d as root
with pytest.raises(AssertionError):
_freeze_unfreeze(bn, mode="freeze")

# Test with BatchNormAct2d as root
with pytest.raises(AssertionError):
_freeze_unfreeze(bn_act, mode="freeze")


def test_activation_stats_functions():
import torch

# Create sample input tensor [batch, channels, height, width]
x = torch.randn(2, 3, 4, 4)

# Test avg_sq_ch_mean
result1 = avg_sq_ch_mean(None, None, x)
assert isinstance(result1, float)

# Test avg_ch_var
result2 = avg_ch_var(None, None, x)
assert isinstance(result2, float)

# Test avg_ch_var_residual
result3 = avg_ch_var_residual(None, None, x)
assert isinstance(result3, float)


def test_reparameterize_model():
import torch.nn as nn

class FusableModule(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 1)

def fuse(self):
return nn.Identity()

class ModelWithFusable(nn.Module):
def __init__(self):
super().__init__()
self.fusable = FusableModule()
self.normal = nn.Linear(10, 10)

model = ModelWithFusable()

# Test with inplace=False (should create a copy)
new_model = reparameterize_model(model, inplace=False)
assert isinstance(new_model.fusable, nn.Identity)
assert isinstance(model.fusable, FusableModule) # Original unchanged

# Test with inplace=True
reparameterize_model(model, inplace=True)
assert isinstance(model.fusable, nn.Identity)


def test_get_state_dict_custom_unwrap():
import torch.nn as nn

class CustomModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 10)

model = CustomModel()

def custom_unwrap(m):
return m

state_dict = get_state_dict(model, unwrap_fn=custom_unwrap)
assert 'linear.weight' in state_dict
assert 'linear.bias' in state_dict


def test_freeze_unfreeze_string_input():
model = timm.create_model('resnet18')

# Test with string input
_freeze_unfreeze(model, 'layer1', mode='freeze')
assert model.layer1[0].conv1.weight.requires_grad == False

# Test unfreezing with string input
_freeze_unfreeze(model, 'layer1', mode='unfreeze')
assert model.layer1[0].conv1.weight.requires_grad == True

Loading