-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgtzanmelspec.py
59 lines (47 loc) · 1.64 KB
/
gtzanmelspec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
import pandas as pd
"""
gtzanmelspec.py
Contains the custom-made GTZAN Mel-Spectrogram dataset with labels
"""
# Lookup table from GTZAN genre name to integer class id.
# Ids are assigned in the order the genres are listed below (blues -> 0, ..., rock -> 9).
label_dict = {
    genre: class_id
    for class_id, genre in enumerate((
        'blues',
        'classical',
        'country',
        'disco',
        'hiphop',
        'jazz',
        'metal',
        'pop',
        'reggae',
        'rock',
    ))
}
class GTZANMel(Dataset):
    """GTZAN mel-spectrogram image dataset driven by a metadata CSV file.

    The first row of the CSV is treated as a header and discarded. For the
    remaining rows, column index 2 is read as the path of a mel-spectrogram
    image file and column index 3 as the genre name, which is converted to
    an integer class id through the module-level ``label_dict``.
    """

    def __init__(self, csv_path):
        """Load the metadata CSV and cache image paths and numeric labels.

        Args:
            csv_path (string): path to csv file

        Raises:
            ValueError: if a genre name in column 3 is not a key of
                ``label_dict``.
        """
        super().__init__()
        self.to_tensor = transforms.ToTensor()
        self.data_info = pd.read_csv(csv_path, header=None)
        # drop the header row which contains metadata info
        self.data_info.drop(self.data_info.head(1).index, inplace=True)
        # mel spectrogram image files are located in index 2
        self.image_arr = np.asarray(self.data_info.iloc[:, 2])
        # labels are located in index 3
        self.label_arr = self._encode_labels(self.data_info.iloc[:, 3], label_dict)
        self.data_len = len(self.data_info.index)

    @staticmethod
    def _encode_labels(labels, mapping):
        """Translate a Series of label names into an int64 array of class ids.

        Args:
            labels (pandas.Series): label names, one per sample.
            mapping (dict): label name -> integer class id.

        Returns:
            numpy.ndarray: int64 class ids, in the same order as ``labels``.

        Raises:
            ValueError: if any entry of ``labels`` is not a key of ``mapping``.
        """
        encoded = labels.map(mapping)
        # Series.map yields NaN for names missing from the mapping, which
        # would silently turn the whole label array into floats; fail loudly
        # here instead of deep inside a training loop.
        missing = encoded.isna()
        if missing.any():
            unknown = sorted(set(labels[missing].astype(str)))
            raise ValueError(
                "Unknown genre label(s) in metadata: {}".format(unknown)
            )
        return np.asarray(encoded, dtype=np.int64)

    def __getitem__(self, index):
        """Return the ``(image_tensor, label)`` pair stored at ``index``.

        The image file is opened from the cached path and converted with the
        ``ToTensor`` transform created in ``__init__``; the label is the
        integer class id cached for the same row.
        """
        image_path = self.image_arr[index]
        # The transform reads the pixel data while the file is open, so the
        # handle can be released as soon as the conversion is done.
        with Image.open(image_path) as img:
            imgtensor = self.to_tensor(img)
        imglabel = self.label_arr[index]
        return (imgtensor, imglabel)

    def __len__(self):
        """Return the number of samples (CSV rows excluding the header)."""
        return self.data_len
if __name__ == "__main__":
    # Manual smoke test: build the dataset from the test metadata file and
    # print the first (image tensor, label) sample.
    dataset = GTZANMel('metadata_test.csv')
    first_sample = dataset[0]
    print(first_sample)