update_model.py — 159 lines (125 loc), 4.63 KB. (Web-page navigation chrome and rendered line-number gutter removed; actual source follows.)
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Update the CDFs parameters of a trained model.
To be called on a model checkpoint after training. This will update the internal
CDFs related buffers required for entropy coding.
"""
import argparse
import hashlib
import sys
from pathlib import Path
from typing import Dict
import torch
from compressai.models.priors import (
FactorizedPrior,
JointAutoregressiveHierarchicalPriors,
MeanScaleHyperprior,
ScaleHyperprior,
)
from compressai.zoo.image import model_architectures as zoo_models
def sha256_file(filepath: Path, len_hash_prefix: int = 8) -> str:
    """Return the first ``len_hash_prefix`` hex chars of the file's SHA-256.

    Adapted from the pytorch github repo. The file is streamed in fixed-size
    chunks so arbitrarily large checkpoints can be hashed without loading
    them fully into memory.
    """
    hasher = hashlib.sha256()
    with filepath.open("rb") as f:
        # iter(callable, sentinel) yields 8 KiB chunks until read() returns b"".
        for chunk in iter(lambda: f.read(8192), b""):
            hasher.update(chunk)
    return hasher.hexdigest()[:len_hash_prefix]
def load_checkpoint(filepath: Path) -> Dict[str, torch.Tensor]:
    """Load a checkpoint file and return the bare model state dict.

    Training checkpoints may wrap the weights under a ``"network"`` or
    ``"state_dict"`` key; a plain state dict is returned unchanged.
    Tensors are mapped to CPU so no GPU is required.
    """
    checkpoint = torch.load(filepath, map_location="cpu")
    for wrapper_key in ("network", "state_dict"):
        if wrapper_key in checkpoint:
            return checkpoint[wrapper_key]
    return checkpoint
# CLI help text for the parser built in `setup_args`.
description = """
Export a trained model to a new checkpoint with an updated CDFs parameters and a
hash prefix, so that it can be loaded later via `load_state_dict_from_url`.
""".strip()
# Short aliases mapping architecture name -> model class for the base priors.
models = {
    "factorized-prior": FactorizedPrior,
    "jarhp": JointAutoregressiveHierarchicalPriors,
    "mean-scale-hyperprior": MeanScaleHyperprior,
    "scale-hyperprior": ScaleHyperprior,
}
# Zoo entries are factory functions (entrypoints), not classes; `update`
# distinguishes the two cases with isinstance(..., type).
models.update(zoo_models)
def setup_args():
    """Build and return the command-line argument parser for this script."""
    p = argparse.ArgumentParser(description=description)
    p.add_argument(
        "filepath",
        type=str,
        help="Path to the checkpoint model to be exported.",
    )
    p.add_argument("-n", "--name", type=str, help="Exported model name.")
    p.add_argument("-d", "--dir", type=str, help="Exported model directory.")
    # Opt-out flag: skip refreshing the entropy-coder CDF buffers.
    p.add_argument(
        "--no-update",
        default=False,
        action="store_true",
        help="Do not update the model CDFs parameters.",
    )
    # Architecture choices come from the module-level `models` registry.
    p.add_argument(
        "-a",
        "--architecture",
        choices=models.keys(),
        default="scale-hyperprior",
        help="Set model architecture (default: %(default)s).",
    )
    return p
def update(info):
    """Refresh a checkpoint's entropy-coder CDF buffers and re-save it.

    ``info`` is a dict with keys: ``filepath`` (input checkpoint),
    ``architecture`` (key into the module-level ``models`` registry),
    ``no-update`` (bool, skip the CDF refresh), ``updated-name`` (output
    stem, falsy to derive it from the input name) and ``fileroot``
    (output directory, ``None`` for the current working directory).

    Raises:
        RuntimeError: if ``info["filepath"]`` is not an existing file.
    """
    filepath = Path(info["filepath"]).resolve()
    if not filepath.is_file():
        raise RuntimeError(f'"{filepath}" is not a valid file.')
    state_dict = load_checkpoint(filepath)

    model_cls_or_entrypoint = models[info["architecture"]]
    if not isinstance(model_cls_or_entrypoint, type):
        # Zoo entries are factory functions: call to obtain the model class.
        model_cls = model_cls_or_entrypoint()
    else:
        model_cls = model_cls_or_entrypoint
    net = model_cls.from_state_dict(state_dict)

    if not info["no-update"]:
        # Recompute the internal CDF buffers required for entropy coding.
        net.update(force=True)
    state_dict = net.state_dict()

    # Output stem: the explicit "updated-name", or the input file's name
    # with all suffixes stripped (e.g. "018.pth.tar" -> "018").
    if not info["updated-name"]:
        filename = filepath
        while filename.suffixes:
            filename = Path(filename.stem)
    else:
        filename = info["updated-name"]

    ext = "".join(filepath.suffixes)

    if info["fileroot"] is not None:
        output_dir = Path(info["fileroot"])
        # parents=True so a nested output directory is created as needed.
        output_dir.mkdir(parents=True, exist_ok=True)
    else:
        output_dir = Path.cwd()

    # BUG FIX: the original wrote to a hard-coded "(unknown)" stem, silently
    # discarding the `filename` computed above so every export collided on
    # the same wrongly-named file. Use the computed stem instead.
    filepath = output_dir / f"{filename}{ext}"
    torch.save(state_dict, filepath)
def main(argv):
    """Script entry point: run `update` with a hard-coded configuration.

    NOTE(review): ``argv`` is accepted but ignored — the CLI defined by
    ``setup_args`` is bypassed in favor of the literal config below.
    """
    ##
    # architecture: ['bmshj2018-factorized', 'bmshj2018-hyperprior', 'mbt2018-mean', 'mbt2018', 'cheng2020-anchor', 'cheng2020-attn']
    config = {
        "architecture": 'cheng2020-attn',
        "no-update": False,
        "filepath": 'checkpoint/cheng2020-attn-lambda0.01_ILSVRC2012/018.pth.tar',
        "fileroot": 'checkpoint/cheng2020-attn-lambda0.01_ILSVRC2012',
        "updated-name": 'updated018',
        "updated-nameprefix": 'updated',
    }
    update(config)


if __name__ == "__main__":
    main(sys.argv[1:])