-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathinference_time.py
82 lines (55 loc) · 1.86 KB
/
inference_time.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import logging
import random
import hydra
import numpy as np
import multiprocessing as mp
import wandb
from omegaconf import DictConfig, OmegaConf
import torch
import time
from agents.utils import sim_framework_path
log = logging.getLogger(__name__)

# Report up-front whether PyTorch can see a CUDA device.
print(torch.cuda.is_available())


def _add_resolver(*numbers):
    """Return the sum of all arguments; registered below as the ``add`` resolver."""
    return sum(numbers)


# Makes ``${add:a,b,...}`` interpolations available to the Hydra/OmegaConf configs.
OmegaConf.register_new_resolver("add", _add_resolver)

# Drop any cached-but-unused CUDA allocator blocks before the run starts.
torch.cuda.empty_cache()
def set_seed_everywhere(seed):
    """Seed every global RNG this script may draw from with ``seed``.

    Applied in order: PyTorch's default generator, every CUDA device generator
    (only when CUDA is available), NumPy's global RNG, and Python's ``random``.
    """
    seeders = [torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    seeders += [np.random.seed, random.seed]
    for apply_seed in seeders:
        apply_seed(seed)
def _prepare_state(batch, obs_seq_len, device):
    """Build the ``(bp_imgs, inhand_imgs)`` model input from one dataloader batch.

    Keeps only the first ``obs_seq_len`` steps along dim 1 of each image tensor,
    makes them contiguous and moves them to ``device``. The action and mask
    entries of the batch are not needed for a forward pass and are ignored.
    """
    # Assumes batches are 4-tuples (bp_imgs, inhand_imgs, action, mask); the
    # unpacking raises ValueError otherwise.
    bp_imgs, inhand_imgs, _action, _mask = batch
    # Slice before moving so only the time steps actually used are transferred.
    bp_imgs = bp_imgs[:, :obs_seq_len].contiguous().to(device)
    inhand_imgs = inhand_imgs[:, :obs_seq_len].contiguous().to(device)
    return bp_imgs, inhand_imgs


def _timed_forward(model, state, device):
    """Run ``model(state, goal=None)`` once and return its wall time in seconds.

    CUDA kernels are launched asynchronously, so on a CUDA device the device is
    synchronized both before starting and before stopping the clock; without
    that only the launch overhead would be measured.
    """
    on_cuda = device.type == "cuda"
    if on_cuda:
        torch.cuda.synchronize(device)
    start = time.perf_counter()
    model(state, goal=None)
    if on_cuda:
        torch.cuda.synchronize(device)
    return time.perf_counter() - start


@hydra.main(config_path="config", config_name="multi_task.yaml")
def main(cfg: DictConfig) -> None:
    """Measure the mean forward-pass time of the configured agent's model.

    Seeds all RNGs, starts a disabled W&B run, instantiates ``cfg.agents`` and
    times ``agent.model`` once per batch of ``agent.train_dataloader``. The
    mean time (seconds) is printed; an empty dataloader only logs a warning.
    """
    set_seed_everywhere(cfg.seed)

    # throw_on_missing makes unset ('???') config values fail here, up-front.
    wandb.config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    run = wandb.init(
        project=cfg.wandb.project,
        entity=cfg.wandb.entity,
        group=cfg.group,
        mode="disabled",
        config=wandb.config,
    )

    agent = hydra.utils.instantiate(cfg.agents)
    device = torch.device(agent.device)
    model = agent.model
    # Benchmark inference, not training: disable dropout / batch-norm updates.
    if isinstance(model, torch.nn.Module):
        model.eval()

    all_time = []
    # No autograd graph is needed (or wanted) when only timing the forward pass.
    with torch.no_grad():
        for batch in agent.train_dataloader:
            state = _prepare_state(batch, agent.obs_seq_len, device)
            all_time.append(_timed_forward(model, state, device))

    if all_time:
        print('mean inference time is: ', np.mean(all_time))
    else:
        log.warning("train_dataloader yielded no batches; nothing was timed")

    run.finish()


if __name__ == "__main__":
    # Hydra parses the command line and supplies ``cfg``.
    main()