@@ -72,55 +72,5 @@ def __init__(self, data_root, domains: list, status: str = 'train', trim: int =
72
72
73
73
assert len (self .domain_id ) == len (self .data ), f'domain_id not match data'
74
74
75
- # self.train = True if status == 'train' else False
76
- # self.num_domain = len(domains)
77
- #
78
- # # read txt files
79
- # data = []
80
- # len_domains = []
81
- # for _domain in domains:
82
- # suffix = 'train' if status == 'train' else 'test'
83
- # _data = read_images_labels(
84
- # os.path.join(f'dataset_map/domainnet', f'{_domain}_{suffix}.txt'),
85
- # shuffle=(status == 'train'),
86
- # trim=trim
87
- # )
88
- # len_domains.append(len(_data))
89
- # data.append(_data)
90
- #
91
- # max_len = max(len_domains)
92
- # # keep all domains have same # data | training
93
- # if status == 'train':
94
- # for i in range(len(data)):
95
- # data[i] = data[i] * round(.5 + max_len / len(data[i]))
96
- # self.data = data
97
- # else:
98
- # # cat to one domain
99
- # self.data = [functools.reduce(operator.iconcat, data, [])]
100
- #
101
- # domain_id = [[i] * len_domains[i] for i in range(self.num_domain)]
102
- # self.domain_id = [functools.reduce(operator.iconcat, domain_id, [])]
103
- #
104
- # assert len(self.domain_id) == len(self.data), f'domain_id not match data'
105
- #
106
- # def __getitem__(self, index):
107
- # if self.train:
108
- # domain = random.randint(0, self.num_domain - 1)
109
- # path, label = self.data[domain][index]
110
- # else:
111
- # domain = self.domain_id[0][index]
112
- # path, label = self.data[0][index]
113
- # path = os.path.join(self.image_root, path)
114
- # with Image.open(path) as image:
115
- # image = image.convert('RGB')
116
- # if self.transform is not None:
117
- # image = self.transform(image)
118
- #
119
- # return {
120
- # 'image': image,
121
- # 'label': label,
122
- # 'domain': domain
123
- # }
124
-
125
75
def __len__ (self ):
126
76
return len (self .data )
0 commit comments