-
Notifications
You must be signed in to change notification settings - Fork 1
/
build_bayesian_discrete_directional_kth.py
44 lines (37 loc) · 1.33 KB
/
build_bayesian_discrete_directional_kth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import pickle
import pandas as pd
from mod.grid import Grid
from mod.models import BayesianDiscreteDirectional
from mod.utils import XYCoords, get_local_settings
# Number of CSV rows pandas reads per chunk; the loop further down saves one
# model snapshot after each chunk.
CHUNK_SIZE = 2000
# Optional cap on the number of processed observations; None disables the cap.
MAX_OBSERVATIONS = None

# Only the second value returned by get_local_settings is used in this script.
_, local = get_local_settings("config/local_settings_kth.json")
# NOTE(review): the folder settings are concatenated as-is, so they presumably
# have to end with a path separator -- confirm against the settings file.
csv_path = local["dataset_folder"] + "kth_trajectory_data.csv"
# Filename prefix for the pickled snapshots; "_<observation count>.p" is
# appended when a snapshot is written.
pickle_path = local["pickle_folder"] + "bayes/discrete_directional_kth"

# With chunksize set, the CSV is consumed piece by piece in the loop below
# rather than being loaded as a single DataFrame.
input_file = pd.read_csv(csv_path, chunksize=CHUNK_SIZE)

# Grid object that later receives the data chunks (add_data), is refreshed
# with update_model, and is pickled as the snapshot.
# NOTE(review): units of origin/resolution are not visible here -- confirm
# against mod.grid.Grid.
g = Grid(
    origin=XYCoords(-58.9, -30.75),
    resolution=1,
    model=BayesianDiscreteDirectional,
)
# Snapshot of the model before any observation has been added; it is saved
# under the observation count 0 so it sorts before the per-chunk snapshots.
total_observations = 0
print("Processing prior")
g.update_model()
filename = f"{pickle_path}_{total_observations:07d}.p"
# Use a context manager so the snapshot file is flushed and closed right away
# instead of whenever the file object happens to be garbage-collected.
with open(filename, "wb") as snapshot_file:
    pickle.dump(g, snapshot_file)
print(f"** Saved {filename}")
# Feed the data to the grid chunk by chunk, refreshing the model and writing a
# snapshot (named after the running observation count) after every chunk.
for chunk in input_file:
    chunk_len = len(chunk.index)
    print(
        f"Processing chunk [{total_observations}-"
        f"{total_observations + chunk_len}]"
    )
    g.add_data(chunk)
    total_observations += chunk_len
    print("** Chunk processed, updating model...")
    g.update_model()
    filename = f"{pickle_path}_{total_observations:07d}.p"
    # Close each snapshot file deterministically; without the context manager
    # one open handle per chunk lingers until garbage collection.
    with open(filename, "wb") as snapshot_file:
        pickle.dump(g, snapshot_file)
    print(f"** Saved {filename}")
    # The cap is checked after a whole chunk, so the final count may exceed
    # MAX_OBSERVATIONS by up to one chunk.
    if MAX_OBSERVATIONS is not None and total_observations >= MAX_OBSERVATIONS:
        print(f"** Stopping, max observation ({MAX_OBSERVATIONS}) reached")
        break