-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprepare_dataset_task_2.py
36 lines (26 loc) · 1.21 KB
/
prepare_dataset_task_2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
import pandas as pd
from load_labels import get_room_labels_by_filename
from transform import load_image_to_numpy
from constants import TASK2_CLASS_COL, TASK2_CLASS_LABELS
from helpers import get_paths, get_paths_and_names, filter_non_existent_paths, serialize_labels_with_images
# Side length handed to load_image_to_numpy for every image.
IMAGE_SIDE_LENGTH = 8

# Input locations (relative to the current working directory).
data_path = "./data"
labels_file = "labels.csv"
images_dir = "main_task_data"

# Output .npy files, written next to the inputs inside data_path.
labels_np_file, images_np_file = "labels_task2.npy", "imgs_task2.npy"
labels_np_path = os.path.join(data_path, labels_np_file)
images_np_path = os.path.join(data_path, images_np_file)
def main():
    """Build the task-2 dataset and serialize it as label/image files.

    Reads the ``filename`` and ``TASK2_CLASS_COL`` columns of the labels CSV,
    keeps only rows whose class value is in ``TASK2_CLASS_LABELS``, pairs each
    image with its label, and passes the resulting ``(label, image)`` tuples
    to ``serialize_labels_with_images``.
    """
    csv_path = os.path.join(data_path, labels_file)
    labels_df = pd.read_csv(csv_path, usecols=["filename", TASK2_CLASS_COL])

    # Drop rows whose class is not one of the task-2 labels.
    is_task2_class = labels_df[TASK2_CLASS_COL].isin(TASK2_CLASS_LABELS)
    labels_df = labels_df[is_task2_class]

    image_root = os.path.join(data_path, images_dir)
    jpg_paths = get_paths(image_root, ".jpg")
    named_paths = get_paths_and_names(jpg_paths)
    # NOTE(review): presumably discards images with no row in labels_df --
    # confirm against helpers.filter_non_existent_paths.
    named_paths = filter_non_existent_paths(named_paths, labels_df)

    samples = []
    for img_path, img_name in named_paths:
        # Label is looked up before the image is loaded, for every pair.
        label = get_room_labels_by_filename(labels_df, img_name)
        pixels = load_image_to_numpy(img_path, side_length=IMAGE_SIDE_LENGTH)
        samples.append((label, pixels))

    serialize_labels_with_images(
        samples,
        labels_path=labels_np_path,
        images_path=images_np_path,
    )
# Run the dataset preparation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()