-
Notifications
You must be signed in to change notification settings - Fork 1
/
map_occupancy_prior_atc.py
58 lines (41 loc) · 1.36 KB
/
map_occupancy_prior_atc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# %% IPython magics: auto-reload edited modules when run interactively
# ! %load_ext autoreload
# ! %autoreload 2
# %% Imports
from pathlib import Path
import numpy as np
import torch
from bff.nets import DiscreteDirectional
from bff.utils import estimate_dynamics, plot_dir
from mod.occupancy import OccupancyMap
from mod.utils import Direction
# Root of the ATC dataset; edit BASE_PATH to point at your local copy. Only
# the map metadata is read from here -- the network weights ("models/...")
# and the generated prior maps ("maps/...") are resolved relative to the
# current working directory.
BASE_PATH = Path("/mnt/hdd/datasets/ATC/")
MAP_METADATA = BASE_PATH / "localization_grid.yaml"

# Hyper-parameters of the trained network to load. They are also encoded in
# the weight / map file names through ``net_id_string`` below.
NET_EPOCHS = 120
NET_WINDOW_SIZE = 64
NET_SCALE_FACTOR = 20

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
PLOT_DPI = 300

# Occupancy grid of the environment, built from the YAML metadata file.
occupancy = OccupancyMap.from_yaml(MAP_METADATA)

# File-name suffix shared by the weights file and the saved prior map,
# e.g. "_w64_s20_t_120" for the values above.
net_id_string = "_w{}_s{}_t_{}".format(
    NET_WINDOW_SIZE, NET_SCALE_FACTOR, NET_EPOCHS
)
net = DiscreteDirectional(NET_WINDOW_SIZE)
net.load_weights("models/people_net" + net_id_string + ".pth")
# %% Build the deep prior map
# Run the loaded network over the occupancy grid via bff's estimate_dynamics.
# ``net_scale`` is the scale factor the network was configured with
# (NET_SCALE_FACTOR); ``scale=1`` and ``batch_size=5`` are passed as-is --
# their exact semantics are defined by estimate_dynamics.
prior = estimate_dynamics(
    net,
    occupancy,
    device=DEVICE,
    batch_size=5,
    scale=1,
    net_scale=NET_SCALE_FACTOR,
)
# %% Save deep prior map
# The prior is written to ./maps (relative to the working directory) under a
# name that embeds the network hyper-parameters. np.save does not create
# missing parent directories, so ensure maps/ exists before writing.
prior_path = Path("maps") / f"map_atc{net_id_string}.npy"
prior_path.parent.mkdir(parents=True, exist_ok=True)
np.save(prior_path, prior)
# %% Load deep prior map
# Read the saved prior back from disk so the cells below can run without
# re-running the build cell. The path is recomputed here so this cell does
# not depend on the save cell having been executed.
prior_path = Path("maps") / f"map_atc{net_id_string}.npy"
prior = np.load(prior_path)
# %% Visualize
# Plot the prior over the occupancy map for the NW, NE, SW and SE directions,
# in that order, at PLOT_DPI resolution.
for direction in (Direction.NW, Direction.NE, Direction.SW, Direction.SE):
    plot_dir(occupancy, prior, direction, dpi=PLOT_DPI)