-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtimeseries_loader.py
51 lines (45 loc) · 1.82 KB
/
timeseries_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import random
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
class TimeSeriesDatasetLazy(Dataset):
    """Map-style dataset that loads one CSV time series per item, on demand.

    Only the file paths are collected at construction time; each CSV is
    parsed when it is indexed. Paths are discovered with the glob pattern
    ``{data_root}/{task}/{split}/*/*.csv`` for every entry of ``tasks`` and
    shuffled once with the global ``random`` module.

    Args:
        data_root: Directory that contains one sub-directory per task.
        tasks: Iterable of task sub-directory names to include.
        split: Name of the split sub-directory (default ``'TEST'``).
        sequence_length: Number of points every returned series is
            resampled to.
        balance: If true and ``split == "TRAIN"``, the file list of every
            dataset directory is repeated ``ceil(largest / size)`` times so
            that small datasets are oversampled towards the largest one.
            For any other split the flag has no effect and every matching
            file is listed exactly once.
    """

    def __init__(self, data_root, tasks, split='TEST', sequence_length=512,
                 balance=False):
        super().__init__()
        if balance and split == "TRAIN":
            paths = self._oversampled_paths(data_root, tasks, split)
        else:
            # Balanced evaluation splits use the same per-task layout as the
            # unbalanced listing, so `tasks` is honoured in every mode.
            paths = []
            for task in tasks:
                paths.extend(glob.glob(f"{data_root}/{task}/{split}/*/*.csv"))
        random.shuffle(paths)
        self.data = np.array(paths)
        self.sequence_length = sequence_length

    @staticmethod
    def _oversampled_paths(data_root, tasks, split):
        """Return CSV paths with each dataset directory repeated up to the largest."""
        per_dataset = []
        for task in tasks:
            for dataset_dir in glob.glob(f"{data_root}/{task}/{split}/*"):
                csv_files = glob.glob(f"{dataset_dir}/*.csv")
                # A directory without CSV files contributes nothing and
                # would otherwise cause a division by zero below.
                if csv_files:
                    per_dataset.append(csv_files)

        # default=0 keeps an empty match (no dataset directories) from raising.
        largest = max((len(csv_files) for csv_files in per_dataset), default=0)

        paths = []
        for csv_files in per_dataset:
            # Integer ceiling division; evaluates to 1 for the largest dataset.
            repeats = -(-largest // len(csv_files))
            paths.extend(csv_files * repeats)
        return paths

    def __getitem__(self, idx):
        """Load, resample and standardise the series stored at position ``idx``.

        The CSV is flattened to a single row of values, linearly
        interpolated to ``sequence_length`` points when its length differs,
        and then shifted and scaled to zero mean and (approximately) unit
        variance along the time axis.

        Returns:
            A float32 tensor with singleton dimensions squeezed out.
        """
        # NOTE(review): read_csv is called with its defaults, so the first
        # line of each file is consumed as a header row - confirm that the
        # data files actually carry one.
        values = pd.read_csv(self.data[idx]).values.flatten()
        x = torch.from_numpy(values).float().view(1, -1)
        if x.shape[-1] != self.sequence_length:
            # F.interpolate in 'linear' mode expects a 3-D input.
            x = F.interpolate(x.view(1, 1, -1), self.sequence_length, mode='linear')
        # The small epsilon guards against division by zero for constant series.
        x = (x - x.mean(-1, keepdims=True)) / (x.var(-1, keepdims=True) + 1e-5).sqrt()
        return x.squeeze()

    def __len__(self):
        """Return the number of file paths (including oversampled repeats)."""
        return len(self.data)