-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbuild_balanced_subset.py
34 lines (30 loc) · 1.17 KB
/
build_balanced_subset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
import os.path as osp
import attack.config as cfg
from attack import datasets
from attack.utils.utils import create_dir
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import shutil
# Build a class-balanced subset of an image dataset laid out as one
# sub-directory per class: SIZE images are sampled without replacement from
# every class directory and copied into a mirror directory tree under
# ``dest_root``.

# Alternative source/destination (CINIC-10), kept for reference:
#dataset_dir = osp.join(cfg.DATASET_ROOT, "CINIC10_2/train")
#dest_root = osp.join("/mydata/model-extraction/data/cinic10_balanced_subset65000")
dataset_dir = osp.join(cfg.DATASET_ROOT, "cifar10/train")
SIZE = 50  # number of images sampled per class
# The directory suffix is SIZE*10, i.e. the total image count when the
# dataset has 10 class directories.
dest_root = osp.join(f"/mydata/model-extraction/data/cifar10_balanced_subset{SIZE*10}")
create_dir(dest_root)

# Seed the global NumPy RNG so the sampled subset is reproducible.
np.random.seed(cfg.DS_SEED)

# os.listdir() returns entries in arbitrary order; sorting makes both the
# order in which classes consume the random stream and the mapping from
# sampled position to filename deterministic, so a given seed always selects
# the same images.
for c in sorted(os.listdir(dataset_dir)):
    src_dir = osp.join(dataset_dir, c)
    if not osp.isdir(src_dir):
        # Skip stray files in the dataset root; only directories are classes.
        continue

    image_lst = sorted(os.listdir(src_dir))
    if len(image_lst) < SIZE:
        raise ValueError(
            f"class '{c}' has only {len(image_lst)} images in {src_dir}; "
            f"cannot sample {SIZE} without replacement"
        )

    # sample SIZE images from each class
    sampled_images = np.random.choice(image_lst, size=SIZE, replace=False)

    dest_dir = osp.join(dest_root, c)
    create_dir(dest_dir)
    for image in sampled_images:
        # copy each image to the destination
        from_path = osp.join(src_dir, image)
        to_path = osp.join(dest_dir, image)
        shutil.copy(from_path, to_path)