Skip to content

Commit 927f031

Browse files
committed
Major module / path restructure, timm.models.layers -> timm.layers, add _ prefix to all non model modules in timm.models
1 parent da6644b commit 927f031

File tree

149 files changed

+1387
-1269
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

149 files changed

+1387
-1269
lines changed

avg_checkpoints.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import os
1717
import glob
1818
import hashlib
19-
from timm.models.helpers import load_state_dict
19+
from timm.models import load_state_dict
2020

2121
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
2222
parser.add_argument('--input', default='', type=str, metavar='PATH',

clean_checkpoint.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import hashlib
1414
import shutil
1515
from collections import OrderedDict
16-
from timm.models.helpers import load_state_dict
16+
from timm.models import load_state_dict
1717

1818
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
1919
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',

hubconf.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
dependencies = ['torch']
2-
from timm.models import registry
3-
4-
globals().update(registry._model_entrypoints)
2+
import timm
3+
globals().update(timm.models._registry._model_entrypoints)

inference.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,23 @@
55
66
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
77
"""
8-
import os
9-
import time
108
import argparse
119
import json
1210
import logging
11+
import os
12+
import time
1313
from contextlib import suppress
1414
from functools import partial
1515

1616
import numpy as np
1717
import pandas as pd
1818
import torch
1919

20-
from timm.models import create_model, apply_test_time_pool, load_checkpoint
2120
from timm.data import create_dataset, create_loader, resolve_data_config
21+
from timm.layers import apply_test_time_pool
22+
from timm.models import create_model
2223
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser
2324

24-
25-
2625
try:
2726
from apex import amp
2827
has_apex = True

tests/test_layers.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
import pytest
21
import torch
32
import torch.nn as nn
4-
import platform
5-
import os
63

7-
from timm.models.layers import create_act_layer, get_act_layer, set_layer_config
4+
from timm.layers import create_act_layer, set_layer_config
85

96

107
class MLP(nn.Module):

tests/test_models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import timm
1616
from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value
17-
from timm.models.fx_features import _leaf_modules, _autowrap_functions
17+
from timm.models._features_fx import _leaf_modules, _autowrap_functions
1818

1919
if hasattr(torch._C, '_jit_set_profiling_executor'):
2020
# legacy executor is too slow to compile large models for unit tests

timm/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .version import __version__
2+
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable
23
from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \
3-
is_scriptable, is_exportable, set_scriptable, set_exportable, \
44
is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

timm/data/readers/class_map.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import pickle
33

4+
45
def load_class_map(map_or_filename, root=''):
56
if isinstance(map_or_filename, dict):
67
assert dict, 'class_map dict must be non-empty'
@@ -14,7 +15,7 @@ def load_class_map(map_or_filename, root=''):
1415
with open(class_map_path) as f:
1516
class_to_idx = {v.strip(): k for k, v in enumerate(f)}
1617
elif class_map_ext == '.pkl':
17-
with open(class_map_path,'rb') as f:
18+
with open(class_map_path, 'rb') as f:
1819
class_to_idx = pickle.load(f)
1920
else:
2021
assert False, f'Unsupported class map file extension ({class_map_ext}).'

timm/layers/__init__.py

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from .activations import *
2+
from .adaptive_avgmax_pool import \
3+
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
4+
from .blur_pool import BlurPool2d
5+
from .classifier import ClassifierHead, create_classifier
6+
from .cond_conv2d import CondConv2d, get_condconv_initializer
7+
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
8+
set_layer_config
9+
from .conv2d_same import Conv2dSame, conv2d_same
10+
from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
11+
from .create_act import create_act_layer, get_act_layer, get_act_fn
12+
from .create_attn import get_attn, create_attn
13+
from .create_conv2d import create_conv2d
14+
from .create_norm import get_norm_layer, create_norm_layer
15+
from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
16+
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
17+
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
18+
from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
19+
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
20+
from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
21+
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
22+
from .gather_excite import GatherExcite
23+
from .global_context import GlobalContext
24+
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
25+
from .inplace_abn import InplaceAbn
26+
from .linear import Linear
27+
from .mixed_conv2d import MixedConv2d
28+
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
29+
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
30+
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
31+
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
32+
from .padding import get_padding, get_same_padding, pad_same
33+
from .patch_embed import PatchEmbed
34+
from .pool2d_same import AvgPool2dSame, create_pool2d
35+
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
36+
from .selective_kernel import SelectiveKernel
37+
from .separable_conv import SeparableConv2d, SeparableConvNormAct
38+
from .space_to_depth import SpaceToDepthModule
39+
from .split_attn import SplitAttn
40+
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
41+
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
42+
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
43+
from .trace_utils import _assert, _float_to_int
44+
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

timm/models/__init__.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,18 @@
6464
from .xception_aligned import *
6565
from .xcit import *
6666

67-
from .factory import create_model, parse_model_name, safe_model_name
68-
from .helpers import load_checkpoint, resume_checkpoint, model_parameters
69-
from .layers import TestTimePoolHead, apply_test_time_pool
70-
from .layers import convert_splitbn_model, convert_sync_batchnorm
71-
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
72-
from .layers import set_fast_norm
73-
from .pretrained import PretrainedCfg, filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
74-
from .registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules,\
67+
from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained, resolve_pretrained_cfg, \
68+
set_pretrained_download_progress, set_pretrained_check_hash
69+
from ._factory import create_model, parse_model_name, safe_model_name
70+
from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
71+
from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
72+
register_notrace_module, register_notrace_function
73+
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint
74+
from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
75+
from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
76+
group_modules, group_parameters, checkpoint_seq, adapt_input_conv
77+
from ._pretrained import PretrainedCfg, DefaultCfg, \
78+
filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
79+
from ._prune import adapt_model_from_string
80+
from ._registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules, \
7581
is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

0 commit comments

Comments
 (0)