# contrastive_HPE.py
# Forked from DavidC001/CLaP.
import sys
sys.path.append('.')
import argparse
import json
import torch
import numpy as np
import os
#disable the warning
import warnings
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=FutureWarning)
from contrastive_training.contrastive import contrastive_pretraining
from pose_estimation.pose_estim import pose_estimation
import torch.distributed as dist
def main(args):
    """Run contrastive pre-training followed by pose estimation.

    The experiment is described by a JSON file. Its top-level keys may
    override the defaults for ``device``, ``models_dir`` and
    ``datasets_dir``; the ``contrastive`` and ``pose_estimation`` keys are
    required and are forwarded unchanged to the respective training stage.

    Args:
        args: Parsed command-line namespace; only ``args.experiment`` (path
            to the experiment JSON file) is read.

    Raises:
        KeyError: If the experiment file lacks the ``contrastive`` or
            ``pose_estimation`` section.
    """
    # Read the experiment JSON file.
    with open(args.experiment) as f:
        data = json.load(f)

    default_args = {
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "models_dir": "trained_models",
        "datasets_dir": "datasets",
        # NOTE: "base_model" is merged but not read anywhere in this function.
        "base_model": "resnet18",
    }
    # Values from the experiment file take precedence over the defaults.
    config = {**default_args, **data}

    # Read from the merged config (not the raw file contents) so that keys
    # omitted from the experiment file fall back to their defaults instead of
    # raising KeyError.
    device = config["device"]
    models_dir = config["models_dir"]
    datasets_dir = config["datasets_dir"]

    # Initialise a single-process (rank 0, world size 1) gloo process group
    # using a local TCP rendezvous.
    addr = 'localhost'
    port = '12355'
    dist.init_process_group(
        backend='gloo',
        rank=0,
        world_size=1,
        init_method=f"tcp://{addr}:{port}?use_libuv=0",
    )

    try:
        # Create the per-backbone model directories. exist_ok=True also
        # covers the case where models_dir already exists but one of the
        # sub-directories is missing.
        os.makedirs(os.path.join(models_dir, "resnet50"), exist_ok=True)
        os.makedirs(os.path.join(models_dir, "resnet18"), exist_ok=True)

        contrastive_pretraining(
            args=config["contrastive"],
            device=device,
            models_dir=models_dir,
            datasets_dir=datasets_dir,
        )
        pose_estimation(
            args=config["pose_estimation"],
            device=device,
            models_dir=models_dir,
            datasets_dir=datasets_dir,
        )
    finally:
        # Release the process group even if a training stage fails; the
        # guard avoids an error if it was already torn down elsewhere.
        if dist.is_initialized():
            dist.destroy_process_group()
if __name__ == "__main__":
    # The path to the experiment file is the only (mandatory) CLI option.
    cli = argparse.ArgumentParser(description='Contrastive training')
    cli.add_argument(
        '--experiment',
        required=True,
        type=str,
        help='Path to the experiment json file',
    )
    main(cli.parse_args())