-
Notifications
You must be signed in to change notification settings - Fork 0
/
push_to_hf.py
75 lines (57 loc) · 1.98 KB
/
push_to_hf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import json
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
import click
import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file
from msma import EDMScorer, ScoreFlow, build_model_from_pickle
# Default Hub repository; override with --repo-name to push to your own repo.
DEFAULT_REPO_NAME = "ahsanMah/localizing-edm"


@click.command
@click.option(
    "--basedir",
    help="Directory holding the model weights and logs",
    type=str,
    required=True,
)
@click.option(
    "--preset", help="Preset of the score model used", type=str, required=True
)
@click.option(
    "--repo-name",
    help="Hugging Face Hub repository to push to (use your own repo)",
    type=str,
    default=DEFAULT_REPO_NAME,
    show_default=True,
)
def main(basedir, preset, repo_name=DEFAULT_REPO_NAME):
    """Push a trained ScoreFlow model for PRESET to the Hugging Face Hub.

    Reads config.json, flow.pt, gmm.pkl, refscores.npz and logs/ from
    BASEDIR/PRESET and uploads them, together with the weights re-saved as
    model.safetensors, into the PRESET folder of the target repository.
    """
    basedir = Path(basedir)
    modeldir = basedir / preset

    # Rebuild the score network for this preset, then wrap it in ScoreFlow
    # using the "PatchFlow" section of the saved config.json.
    net = build_model_from_pickle(preset)
    with open(modeldir / "config.json", "rb") as f:
        model_params = json.load(f)
    model = ScoreFlow(
        net,
        device="cpu",
        **model_params["PatchFlow"],
    )

    # The model is constructed on CPU, so map the flow checkpoint onto CPU too;
    # without map_location a checkpoint saved from a GPU run cannot be loaded
    # on a machine without CUDA.
    flow_state = torch.load(modeldir / "flow.pt", map_location="cpu")
    model.flow.load_state_dict(flow_state)

    api = HfApi()
    # Create the repo if it does not exist yet and get its canonical repo_id.
    repo_id = api.create_repo(repo_name, exist_ok=True).repo_id

    # Stage every artifact in a temporary directory and push them in a single
    # commit. The upload must happen inside the `with` block: the directory
    # and its contents are deleted when the context exits.
    with TemporaryDirectory() as tmp:
        tmpdir = Path(tmp)

        # Weights of the full ScoreFlow model, in safetensors format.
        save_file(model.state_dict(), tmpdir / "model.safetensors")

        # Model config as pretty-printed, key-sorted JSON.
        (tmpdir / "config.json").write_text(
            json.dumps(model.config, sort_keys=True, indent=4)
        )

        # GMM and cached reference scores are copied verbatim.
        shutil.copyfile(modeldir / "gmm.pkl", tmpdir / "gmm.pkl")
        shutil.copyfile(modeldir / "refscores.npz", tmpdir / "refscores.npz")

        # TODO: generate and upload a README.md model card
        # (previously sketched as generate_model_card(model)).

        # Training logs.
        shutil.copytree(modeldir / "logs", tmpdir / "logs")

        # Push everything under the `preset` subfolder of the repo.
        api.upload_folder(repo_id=repo_id, path_in_repo=preset, folder_path=tmpdir)
if __name__ == "__main__":
    # Script entry point: hand control to the click command, which reads its
    # options from the command line.
    main()