forked from open-mmlab/mmengine
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcheckpoint.py
780 lines (652 loc) · 28.2 KB
/
checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
# Copyright (c) OpenMMLab. All rights reserved.
import io
import logging
import os
import os.path as osp
import pkgutil
import re
import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory
from typing import Callable, Dict, Optional
import torch
import mmengine
from mmengine.dist import get_dist_info
from mmengine.fileio import FileClient, get_file_backend
from mmengine.fileio import load as load_file
from mmengine.logging import print_log
from mmengine.model import BaseTTAModel, is_model_wrapper
from mmengine.utils import deprecated_function, digit_version, mkdir_or_exist
from mmengine.utils.dl_utils import load_url
# `MMENGINE_HOME` is the highest-priority directory for saving checkpoints
# downloaded from the Internet. If it is not set, `$XDG_CACHE_HOME/mmengine`
# (or `~/.cache/mmengine`) is used instead.
# Note that `XDG_CACHE_HOME` defines the base directory relative to which
# user-specific non-essential data files should be stored. If `XDG_CACHE_HOME`
# is either not set or empty, a default equal to `~/.cache` should be used.
ENV_MMENGINE_HOME = 'MMENGINE_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
def _get_mmengine_home():
    """Return the directory used to cache downloaded checkpoints.

    Resolution order:

    1. ``$MMENGINE_HOME`` if the variable is set.
    2. ``$XDG_CACHE_HOME/mmengine`` if ``XDG_CACHE_HOME`` is set and
       non-empty.
    3. ``~/.cache/mmengine`` otherwise.

    ``~`` is expanded and the directory is created if it does not exist.

    Returns:
        str: Path of the cache directory, with ``~`` expanded.
    """
    # An empty XDG_CACHE_HOME must be treated like an unset one (XDG Base
    # Directory spec). `os.getenv(name, default)` would return '' here and
    # silently produce a relative 'mmengine' directory in the CWD.
    xdg_cache_home = os.getenv(ENV_XDG_CACHE_HOME) or DEFAULT_CACHE_DIR
    mmengine_home = os.path.expanduser(
        os.getenv(ENV_MMENGINE_HOME,
                  os.path.join(xdg_cache_home, 'mmengine')))
    mkdir_or_exist(mmengine_home)
    return mmengine_home
def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False. Model wrappers
    and :class:`BaseTTAModel` instances are unwrapped (via their
    ``.module`` attribute) at every level of the module tree.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.

    Raises:
        RuntimeError: If ``strict`` is True and there are unexpected keys,
            missing keys (``num_batches_tracked`` entries excluded) or errors
            reported by ``_load_from_state_dict``. It is raised on every rank
            so distributed processes fail together. In non-strict mode the
            warning is only logged on rank 0.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    # Copy so the caller's dict is untouched, but keep version metadata.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_model_wrapper(module) or isinstance(module, BaseTTAModel):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    if not err_msg:
        return

    err_msg.insert(
        0, 'The model and loaded state dict do not match exactly\n')
    err_msg = '\n'.join(err_msg)
    if strict:
        # Raise on all ranks: raising only on rank 0 would let the other
        # ranks continue with mismatched weights and desynchronize.
        raise RuntimeError(err_msg)
    rank, _ = get_dist_info()
    if rank == 0:
        print_log(err_msg, logger=logger, level=logging.WARNING)
def get_torchvision_models():
    """Collect the download URLs of torchvision's pretrained weights.

    Returns:
        dict: Mapping from model key to checkpoint URL.

        - torchvision < 0.13: the union of every ``model_urls`` dict found in
          the (non-package) modules under ``torchvision.models``.
        - torchvision >= 0.13: the entries of the bundled
          ``hub/torchvision_0.12.json`` (kept for backward compatibility),
          plus ``'<model>.default'`` and ``'<model>.<weight_name>'`` keys
          built from every ``*_Weights`` enum, lower-cased, e.g.
          ``'resnet50.imagenet1k_v1'``.
    """
    import torchvision

    if digit_version(torchvision.__version__) < digit_version('0.13.0a0'):
        model_urls = dict()
        # When the version of torchvision is lower than 0.13, the model url is
        # not declared in `torchvision.model.__init__.py`, so we need to
        # iterate through `torchvision.models.__path__` to get the url for
        # each model.
        for _, name, ispkg in pkgutil.walk_packages(
                torchvision.models.__path__):
            if ispkg:
                continue
            _zoo = import_module(f'torchvision.models.{name}')
            if hasattr(_zoo, 'model_urls'):
                model_urls.update(getattr(_zoo, 'model_urls'))
        return model_urls

    # Since torchvision bumps to v0.13, the weight loading logic,
    # model keys and model urls have been changed. Here the URLs of old
    # version is loaded to avoid breaking back compatibility. If the
    # torchvision version>=0.13.0, new URLs will be added. Users can get
    # the resnet50 checkpoint by setting 'resnet50.imagenet1k_v1',
    # 'resnet50' or 'ResNet50_Weights.IMAGENET1K_V1' in the config.
    json_path = osp.join(mmengine.__path__[0], 'hub/torchvision_0.12.json')
    model_urls = load_file(json_path)
    if digit_version(torchvision.__version__) < digit_version('0.14.0a0'):
        weights_list = [
            cls for cls_name, cls in torchvision.models.__dict__.items()
            if cls_name.endswith('_Weights')
        ]
    else:
        weights_list = [
            torchvision.models.get_model_weights(model)
            for model in torchvision.models.list_models(torchvision.models)
        ]

    for cls in weights_list:
        # The name of torchvision model weights classes ends with
        # `_Weights` such as `ResNet18_Weights`. However, some model weight
        # classes, such as `MNASNet0_75_Weights` does not have any urls in
        # torchvision 0.13.0 and cannot be iterated. Here we simply check
        # `DEFAULT` attribute to ensure the class is not empty.
        if not hasattr(cls, 'DEFAULT'):
            continue
        # The key prefix depends only on the class, so compute it once.
        cls_key = cls.__name__.replace('_Weights', '').lower()
        # `DEFAULT` is an enum alias and is not yielded when iterating `cls`,
        # so its URL is recorded explicitly.
        model_urls[f'{cls_key}.default'] = cls.DEFAULT.url
        for weight_enum in cls:
            model_urls[f'{cls_key}.{weight_enum.name.lower()}'] = \
                weight_enum.url
    return model_urls
def get_external_models():
    """Return the mapping used to resolve ``open-mmlab://`` model names.

    The ``hub/openmmlab.json`` file bundled with mmengine is loaded first.
    If an ``open_mmlab.json`` file exists in the mmengine home directory,
    its entries are merged on top and override the bundled ones.

    Returns:
        dict: Model name -> checkpoint URL or path.
    """
    home_dir = _get_mmengine_home()
    urls = load_file(osp.join(mmengine.__path__[0], 'hub/openmmlab.json'))
    assert isinstance(urls, dict)

    user_json = osp.join(home_dir, 'open_mmlab.json')
    if not osp.exists(user_json):
        return urls

    user_urls = load_file(user_json)
    assert isinstance(user_urls, dict)
    urls.update(user_urls)
    return urls
def get_mmcls_models():
    """Return the mapping used to resolve ``mmcls://`` model names.

    Returns:
        The parsed content of the ``hub/mmcls.json`` file bundled with
        mmengine (model name -> checkpoint URL).
    """
    return load_file(osp.join(mmengine.__path__[0], 'hub/mmcls.json'))
def get_deprecated_model_names():
    """Return the table of renamed ``open-mmlab://`` model names.

    Returns:
        dict: Deprecated model name -> replacement model name, read from the
        ``hub/deprecated.json`` file bundled with mmengine.
    """
    table_path = osp.join(mmengine.__path__[0], 'hub/deprecated.json')
    renames = load_file(table_path)
    assert isinstance(renames, dict)
    return renames
def _process_mmcls_checkpoint(checkpoint):
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
# Some checkpoints converted from 3rd-party repo don't
# have the "state_dict" key.
state_dict = checkpoint
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('backbone.'):
new_state_dict[k[9:]] = v
new_checkpoint = dict(state_dict=new_state_dict)
return new_checkpoint
class CheckpointLoader:
    """A general checkpoint loader to manage all schemes.

    Loader functions are registered against one or more path prefixes with
    :meth:`register_scheme`. Each prefix is treated as a regular expression
    and matched at the start of the path with :func:`re.match`.
    """

    _schemes: Dict[str, Callable] = {}

    @classmethod
    def _register_scheme(cls, prefixes, loader, force=False):
        """Register ``loader`` for every prefix in ``prefixes``.

        Registration is all-or-nothing. If any prefix is already registered
        (or repeated within ``prefixes``) and ``force`` is False, a KeyError
        is raised and the registry is left unchanged.
        """
        if isinstance(prefixes, str):
            prefixes = [prefixes]
        else:
            assert isinstance(prefixes, (list, tuple))

        if not force:
            # Validate everything before mutating `_schemes` so a conflict on
            # a later prefix cannot leave earlier ones half-registered.
            seen = set(cls._schemes)
            for prefix in prefixes:
                if prefix in seen:
                    raise KeyError(
                        f'{prefix} is already registered as a loader backend, '
                        'add "force=True" if you want to override it')
                seen.add(prefix)

        for prefix in prefixes:
            cls._schemes[prefix] = loader

        # Reverse lexicographic order: any prefix that extends another one
        # sorts before it, so e.g. the catch-all '' prefix is tried last.
        cls._schemes = OrderedDict(
            sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True))

    @classmethod
    def register_scheme(cls, prefixes, loader=None, force=False):
        """Register a loader to CheckpointLoader.

        This method can be used as a normal class method or a decorator.

        Args:
            prefixes (str or list[str] or tuple[str]):
                The prefix of the registered loader.
            loader (function, optional): The loader function to be registered.
                When this method is used as a decorator, loader is None.
                Defaults to None.
            force (bool, optional): Whether to override the loader
                if the prefix has already been registered. Defaults to False.

        Returns:
            None when ``loader`` is given; otherwise a decorator that
            registers and returns the decorated function unchanged.
        """
        if loader is not None:
            cls._register_scheme(prefixes, loader, force=force)
            return

        def _register(loader_cls):
            cls._register_scheme(prefixes, loader_cls, force=force)
            return loader_cls

        return _register

    @classmethod
    def _get_checkpoint_loader(cls, path):
        """Find the first registered loader whose prefix matches ``path``.

        Prefixes are tried in the registry's sorted order. The local loader
        is registered under the empty prefix, which matches any path, so it
        acts as the fallback.

        Args:
            path (str): checkpoint path

        Returns:
            callable or None: checkpoint loader, or None if no prefix matches.
        """
        for p in cls._schemes:
            # use regular match to handle some cases that where the prefix of
            # loader has a prefix. For example, both 's3://path' and
            # 'open-mmlab:s3://path' should return `load_from_ceph`
            if re.match(p, path) is not None:
                return cls._schemes[p]

    @classmethod
    def load_checkpoint(cls, filename, map_location=None, logger='current'):
        """load checkpoint through URL scheme path.

        Args:
            filename (str): checkpoint file name with given prefix
            map_location (str, optional): Same as :func:`torch.load`.
                Default: None
            logger (str): The logger for message. Defaults to 'current'.

        Returns:
            dict or OrderedDict: The loaded checkpoint.
        """
        checkpoint_loader = cls._get_checkpoint_loader(filename)
        class_name = checkpoint_loader.__name__
        # Loader functions are named `load_from_<backend>`; drop the
        # 10-character 'load_from_' part for the log message.
        print_log(
            f'Loads checkpoint by {class_name[10:]} backend from path: '
            f'{filename}',
            logger=logger)
        return checkpoint_loader(filename, map_location)
@CheckpointLoader.register_scheme(prefixes='')
def load_from_local(filename, map_location):
    """Load a checkpoint from a path on the local filesystem.

    Registered under the empty prefix, which matches every path.

    Args:
        filename (str): Local checkpoint path; ``~`` is expanded.
        map_location (str, optional): Same as :func:`torch.load`.

    Returns:
        dict or OrderedDict: The loaded checkpoint.

    Raises:
        FileNotFoundError: If the expanded path is not an existing file.
    """
    local_path = osp.expanduser(filename)
    if osp.isfile(local_path):
        return torch.load(local_path, map_location=map_location)
    raise FileNotFoundError(f'{local_path} can not be found.')
@CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
def load_from_http(filename,
                   map_location=None,
                   model_dir=None,
                   progress=os.isatty(0)):
    """Load a checkpoint from an HTTP or HTTPS URL.

    In a distributed run, global rank 0 fetches the checkpoint first while
    the other ranks wait at a barrier. Afterwards every other rank calls
    ``load_url`` itself, presumably reusing the copy rank 0 cached in
    ``model_dir``.

    Args:
        filename (str): URL of the checkpoint.
        map_location (str, optional): Same as :func:`torch.load`.
        model_dir (string, optional): directory in which to save the object,
            Default: None
        progress (bool): Whether to display a download progress bar.
            Defaults to whether stdin is a TTY (evaluated at import time).

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    def _fetch():
        return load_url(
            filename,
            model_dir=model_dir,
            map_location=map_location,
            progress=progress)

    rank, world_size = get_dist_info()
    if rank == 0:
        checkpoint = _fetch()
    if world_size > 1:
        torch.distributed.barrier()
    if rank > 0:
        checkpoint = _fetch()
    return checkpoint
@CheckpointLoader.register_scheme(prefixes='pavi://')
def load_from_pavi(filename, map_location=None):
    """Load a checkpoint stored in pavi modelcloud.

    The model is downloaded into a fresh temporary directory, loaded with
    :func:`torch.load`, and the directory is removed afterwards. Every rank
    performs its own download.

    Args:
        filename (str): checkpoint file path with pavi prefix
        map_location (str, optional): Same as :func:`torch.load`.
            Default: None

    Returns:
        dict or OrderedDict: The loaded checkpoint.

    Raises:
        ImportError: If pavi is not installed.
    """
    assert filename.startswith('pavi://'), \
        f'Expected filename startswith `pavi://`, but get {filename}'
    remote_path = filename[len('pavi://'):]

    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')

    remote_model = modelcloud.get(remote_path)
    with TemporaryDirectory() as work_dir:
        local_copy = osp.join(work_dir, remote_model.name)
        remote_model.download(local_copy)
        return torch.load(local_copy, map_location=map_location)
@CheckpointLoader.register_scheme(
    prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://'])
def load_from_ceph(filename, map_location=None, backend='petrel'):
    """Load a checkpoint from an ``s3://`` or ``petrel://`` path.

    The path may carry an extra ``<name>:`` prefix, e.g.
    ``open-mmlab:s3://bucket/ckpt.pth``. The file is read fully into memory
    through the mmengine file backend and deserialized from there. Each
    rank reads it independently.

    Args:
        filename (str): checkpoint file path with s3 prefix
        map_location (str, optional): Same as :func:`torch.load`.
        backend (str, optional): The storage backend type.
            Defaults to 'petrel'.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    storage = get_file_backend(filename, backend_args={'backend': backend})
    payload = storage.get(filename)
    with io.BytesIO(payload) as stream:
        return torch.load(stream, map_location=map_location)
@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
def load_from_torchvision(filename, map_location=None):
    """Load a torchvision pretrained checkpoint by model key.

    ``torchvision://<key>`` (or the deprecated ``modelzoo://<key>``) is
    looked up in :func:`get_torchvision_models` and downloaded over HTTP.

    Args:
        filename (str): checkpoint file path with modelzoo or
            torchvision prefix
        map_location (str, optional): Same as :func:`torch.load`.

    Returns:
        dict or OrderedDict: The loaded checkpoint.

    Raises:
        KeyError: If the model key is unknown.
    """
    model_urls = get_torchvision_models()
    legacy_prefix = 'modelzoo://'
    if not filename.startswith(legacy_prefix):
        model_name = filename[len('torchvision://'):]
    else:
        warnings.warn(
            'The URL scheme of "modelzoo://" is deprecated, please '
            'use "torchvision://" instead', DeprecationWarning)
        model_name = filename[len(legacy_prefix):]
    url = model_urls[model_name]
    return load_from_http(url, map_location=map_location)
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
def load_from_openmmlab(filename, map_location=None):
    """Load an OpenMMLab checkpoint by model name.

    The name after the prefix is first passed through the deprecation table
    (emitting a DeprecationWarning when renamed), then resolved with
    :func:`get_external_models`. URL entries are downloaded over HTTP(S).
    Any other entry is joined onto the mmengine home directory and loaded
    from disk.

    Args:
        filename (str): checkpoint file path with open-mmlab or
            openmmlab prefix
        map_location (str, optional): Same as :func:`torch.load`.
            Default: None

    Returns:
        dict or OrderedDict: The loaded checkpoint.

    Raises:
        KeyError: If the model name is unknown.
        FileNotFoundError: If a non-URL entry does not point to a file.
    """
    model_urls = get_external_models()
    prefix_str = 'open-mmlab://'
    if not filename.startswith(prefix_str):
        prefix_str = 'openmmlab://'
    model_name = filename[len(prefix_str):]

    deprecated_urls = get_deprecated_model_names()
    if model_name in deprecated_urls:
        replacement = deprecated_urls[model_name]
        warnings.warn(
            f'{prefix_str}{model_name} is deprecated in favor '
            f'of {prefix_str}{replacement}', DeprecationWarning)
        model_name = replacement

    model_url = model_urls[model_name]
    if model_url.startswith(('http://', 'https://')):
        return load_from_http(model_url, map_location=map_location)

    # Non-URL entries are resolved relative to the mmengine home directory.
    local_path = osp.join(_get_mmengine_home(), model_url)
    if not osp.isfile(local_path):
        raise FileNotFoundError(f'{local_path} can not be found.')
    return torch.load(local_path, map_location=map_location)
@CheckpointLoader.register_scheme(prefixes='mmcls://')
def load_from_mmcls(filename, map_location=None):
    """Load the backbone weights of an mmcls model by name.

    The name after ``mmcls://`` is resolved with :func:`get_mmcls_models`,
    downloaded over HTTP, and reduced to its backbone weights by
    :func:`_process_mmcls_checkpoint`.

    Args:
        filename (str): checkpoint file path with mmcls prefix
        map_location (str, optional): Same as :func:`torch.load`.

    Returns:
        dict: ``{'state_dict': OrderedDict}`` of backbone weights.
    """
    url_table = get_mmcls_models()
    url = url_table[filename[len('mmcls://'):]]
    raw_checkpoint = load_from_http(url, map_location=map_location)
    return _process_mmcls_checkpoint(raw_checkpoint)
def _load_checkpoint(filename, map_location=None, logger=None):
    """Load a checkpoint through whichever registered scheme matches.

    This is a thin wrapper over :meth:`CheckpointLoader.load_checkpoint`.

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str, optional): Same as :func:`torch.load`.
            Default: None.
        logger (:mod:`logging.Logger`, optional): The logger for error message.
            Default: None

    Returns:
        dict or OrderedDict: The loaded checkpoint. It can be either an
        OrderedDict storing model weights or a dict containing other
        information, which depends on the checkpoint.
    """
    loaded = CheckpointLoader.load_checkpoint(filename, map_location, logger)
    return loaded
def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
    """Load the part of a checkpoint that belongs to one sub-module.

    Only keys starting with ``prefix`` (a ``'.'`` is appended if missing)
    are kept, and the prefix is removed from them.

    Args:
        prefix (str): The prefix of sub-module.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str | None): Same as :func:`torch.load`. Default: None.

    Returns:
        dict: The stripped sub-module state dict.

    Raises:
        AssertionError: If no key starts with the prefix.
    """
    checkpoint = _load_checkpoint(filename, map_location=map_location)
    full_state = (checkpoint['state_dict']
                  if 'state_dict' in checkpoint else checkpoint)

    if not prefix.endswith('.'):
        prefix += '.'
    cut = len(prefix)

    sub_state = {}
    for key, value in full_state.items():
        if key.startswith(prefix):
            sub_state[key[cut:]] = value

    assert sub_state, f'{prefix} is not in the pretrained model'
    return sub_state
def _load_checkpoint_to_model(model,
                              checkpoint,
                              strict=False,
                              logger=None,
                              revise_keys=((r'^module\.', ''), )):
    """Load the weights held in ``checkpoint`` into ``model``.

    Args:
        model (Module): Module to load the weights into.
        checkpoint (dict): Either a dict with a ``'state_dict'`` entry or the
            state dict itself.
        strict (bool): Forwarded to :func:`load_state_dict`.
        logger: Forwarded to :func:`load_state_dict`.
        revise_keys (Sequence[tuple[str, str]]): ``(pattern, replacement)``
            pairs applied in order with :func:`re.sub` to every key of the
            state dict. The default strips a leading ``'module.'``.

    Returns:
        dict: ``checkpoint``, unchanged.
    """
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    metadata = getattr(state_dict, '_metadata', OrderedDict())
    # Always work on a fresh OrderedDict. A plain `dict` rejects attribute
    # assignment, so with an empty `revise_keys` the `_metadata` assignment
    # below would otherwise raise. This also avoids mutating the caller's
    # state dict.
    state_dict = OrderedDict(state_dict)
    # strip prefix of state_dict
    for pattern, replacement in revise_keys:
        state_dict = OrderedDict(
            (re.sub(pattern, replacement, k), v)
            for k, v in state_dict.items())
    # Keep metadata in state_dict
    state_dict._metadata = metadata

    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint
def load_checkpoint(model,
                    filename,
                    map_location=None,
                    strict=False,
                    logger=None,
                    revise_keys=[(r'^module\.', '')]):
    """Load checkpoint from a file or URI into ``model``.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to strictly enforce that the checkpoint keys
            match the model's keys.
        logger (:mod:`logging.Logger` or None): The logger for error message.
        revise_keys (list): A list of customized keywords to modify the
            state_dict in checkpoint. Each item is a (pattern, replacement)
            pair of the regular expression operations. Default: strip
            the prefix 'module.' by [(r'^module\\.', '')].

    Returns:
        dict or OrderedDict: The loaded checkpoint.

    Raises:
        RuntimeError: If the loaded object is not a dict.
    """
    checkpoint = _load_checkpoint(filename, map_location, logger)
    # OrderedDict is a subclass of dict, so both are accepted here.
    if isinstance(checkpoint, dict):
        return _load_checkpoint_to_model(model, checkpoint, strict, logger,
                                         revise_keys)
    raise RuntimeError(
        f'No state_dict found in checkpoint file {filename}')
def weights_to_cpu(state_dict):
    """Copy a model state_dict to CPU.

    Args:
        state_dict (OrderedDict): Model weights, possibly on GPU.

    Returns:
        OrderedDict: A new state dict, in the same key order, whose values
        are ``value.cpu()``. The input's ``_metadata`` attribute (or an
        empty OrderedDict if it has none) is carried over.
    """
    state_dict_cpu = OrderedDict(
        (key, val.cpu()) for key, val in state_dict.items())
    # Keep metadata in state_dict
    state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
    return state_dict_cpu
@deprecated_function(
    since='0.3.0',
    removed_in='0.5.0',
    instructions='`_save_to_state_dict` will be deprecated in the future, '
    'please use `nn.Module._save_to_state_dict` directly.')
def _save_to_state_dict(module, destination, prefix, keep_vars):
    """Write ``module``'s own parameters and persistent buffers into a dict.

    Children are not visited. Keys are ``prefix + name``. ``None`` entries
    and buffers listed in ``module._non_persistent_buffers_set`` are skipped.
    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (dict): A dict where state will be stored.
        prefix (str): The prefix for parameters and buffers used in this
            module.
        keep_vars (bool): If True, store the tensors as-is; otherwise store
            ``tensor.detach()``.
    """
    def _export(tensor):
        return tensor if keep_vars else tensor.detach()

    for param_name, param in module._parameters.items():
        if param is None:
            continue
        destination[prefix + param_name] = _export(param)

    for buf_name, buf in module._buffers.items():
        if buf is None or buf_name in module._non_persistent_buffers_set:
            continue
        destination[prefix + buf_name] = _export(buf)
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.
    This method is modified from :meth:`torch.nn.Module.state_dict` to
    recursively check parallel module in case that the model has a complicated
    structure, e.g., nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Returned dict for the state of the
            module. If it has no ``_metadata`` attribute, per-module version
            metadata is simply not recorded.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module. A value
        returned by a state-dict hook replaces it.
    """
    # recursively check parallel module in case that the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
    if is_model_wrapper(module):
        module = module.module

    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    local_metadata = dict(version=module._version)
    # A caller-supplied destination (e.g. a plain OrderedDict) may lack
    # `_metadata`; like torch.nn.Module.state_dict, skip recording then.
    if hasattr(destination, '_metadata'):
        destination._metadata[prefix[:-1]] = local_metadata
    module._save_to_state_dict(destination, prefix, keep_vars)
    for name, child in module._modules.items():
        if child is not None:
            get_state_dict(
                child, destination, prefix + name + '.', keep_vars=keep_vars)
    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination
def save_checkpoint(checkpoint,
                    filename,
                    file_client_args=None,
                    backend_args=None):
    """Save checkpoint to file.

    ``pavi://`` paths are uploaded to pavi modelcloud. A missing model
    folder is created there. Any other path is serialized in memory with
    :func:`torch.save` and written through the matching file backend.

    Args:
        checkpoint (dict): The checkpoint to be saved.
        filename (str): Checkpoint filename.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmengine.fileio.FileClient` for details.
            Defaults to None. It will be deprecated in future. Please use
            `backend_args` instead.
        backend_args (dict, optional): Arguments to instantiate the
            backend corresponding to the prefix of the uri. Defaults to None.
            New in v0.2.0.

    Raises:
        ValueError: If both ``file_client_args`` and ``backend_args`` are
            given, or if either is given together with a ``pavi://`` path.
        ImportError: If ``filename`` is a ``pavi://`` path and pavi is not
            installed.
    """
    if file_client_args is not None:
        warnings.warn(
            '"file_client_args" will be deprecated in future. '
            'Please use "backend_args" instead', DeprecationWarning)
        if backend_args is not None:
            raise ValueError(
                '"file_client_args" and "backend_args" cannot be set '
                'at the same time.')

    if filename.startswith('pavi://'):
        if file_client_args is not None or backend_args is not None:
            raise ValueError(
                '"file_client_args" or "backend_args" should be "None" if '
                'filename starts with "pavi://"')
        try:
            from pavi import exception, modelcloud
        except ImportError:
            raise ImportError(
                'Please install pavi to save checkpoint to modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            model = modelcloud.get(model_dir)
        except exception.NodeNotFoundError:
            model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            model.create_file(checkpoint_file, name=model_name)
    else:
        # The deprecated FileClient path is only taken when the caller
        # explicitly passed `file_client_args`; building a FileClient
        # otherwise would be wasted work.
        if file_client_args is None:
            file_backend = get_file_backend(
                filename, backend_args=backend_args)
        else:
            file_backend = FileClient.infer_client(file_client_args, filename)
        with io.BytesIO() as f:
            torch.save(checkpoint, f)
            file_backend.put(f.getvalue(), filename)
def find_latest_checkpoint(path: str) -> Optional[str]:
    """Find the latest checkpoint from the given path.

    Refer to https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py # noqa: E501

    The latest checkpoint is recorded in a text file named
    ``last_checkpoint`` inside ``path``.

    Args:
        path (str): The directory to search.

    Returns:
        str or None: The content of ``last_checkpoint`` with surrounding
        whitespace stripped, or None (after logging a message) if that file
        does not exist.
    """
    marker_file = osp.join(path, 'last_checkpoint')
    if not os.path.exists(marker_file):
        print_log('Did not find last_checkpoint to be resumed.')
        return None
    with open(marker_file) as f:
        return f.read().strip()