-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathcreate_dataset.py
49 lines (39 loc) · 1.57 KB
/
create_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import argparse
import numpy as np
from src.utils import read_dataset, read_multivariate_dataset
# Folder handed to read_dataset for datasets not listed in
# multivariate_datasets.
dataset_dir = './datasets/UCRArchive_2018'

# Folder handed to read_multivariate_dataset for the listed datasets.
multivariate_dir = './datasets/multivariate'

# Base folder for the script's output; the script appends a per-shot
# sub-folder to it before saving.
output_dir = './tmp'

# Dataset names that are routed to the multivariate reader.
multivariate_datasets = [
    'CharacterTrajectories',
    'ECG',
    'KickvsPunch',
    'NetFlow',
]
def argsparser():
    """Build the command-line parser for the dataset creation script.

    Options:
        --dataset: dataset name (default 'Coffee').
        --seed: integer random seed (default 0).
        --shot: number of labeled time-series per class, integer
            (default 1).

    Returns:
        argparse.ArgumentParser: the configured parser; the caller is
        responsible for calling ``parse_args`` on it.
    """
    # The first positional parameter of ArgumentParser is `prog`, so the
    # title has to be passed by keyword to appear as the help description
    # instead of replacing the program name in the usage line.
    parser = argparse.ArgumentParser(description="SimTSC data creator")
    parser.add_argument('--dataset', help='Dataset name', default='Coffee')
    parser.add_argument('--seed', help='Random seed', type=int, default=0)
    parser.add_argument(
        '--shot',
        help='How many labeled time-series per class',
        type=int,
        default=1,
    )
    return parser
if __name__ == "__main__":
    # Script entry point: read one dataset, then store its samples, labels
    # and train/test indices as a single file under the output folder.

    # Get the arguments
    parser = argsparser()
    args = parser.parse_args()

    # Seeding
    np.random.seed(args.seed)

    # Decide once which reader/folder applies, so the output location and
    # the reader used below can never disagree.
    is_multivariate = args.dataset in multivariate_datasets

    # Create dirs
    if is_multivariate:
        subdir = 'multivariate_datasets_' + str(args.shot) + '_shot'
    else:
        subdir = 'ucr_datasets_' + str(args.shot) + '_shot'
    output_dir = os.path.join(output_dir, subdir)
    # exist_ok avoids the check-then-create race of testing
    # os.path.exists first (e.g. several runs started in parallel).
    os.makedirs(output_dir, exist_ok=True)

    # Read data
    if is_multivariate:
        X, y, train_idx, test_idx = read_multivariate_dataset(
            multivariate_dir, args.dataset, args.shot)
    else:
        X, y, train_idx, test_idx = read_dataset(
            dataset_dir, args.dataset, args.shot)

    data = {
        'X': X,
        'y': y,
        'train_idx': train_idx,
        'test_idx': test_idx
    }

    # NOTE: np.save appends '.npy' to the file name and stores the dict as
    # a pickled object array, so readers must load it with
    # allow_pickle=True.
    np.save(os.path.join(output_dir, args.dataset), data)