# load_data.py
import logging
import random

from torchvision import datasets, transforms

import utils.dists as dists


class Generator(object):
    """Generate federated learning training and testing data."""

    # Abstract read function
    def read(self, path):
        # Read the dataset, set: trainset, testset, labels
        raise NotImplementedError

    # Group the data by label
    def group(self):
        # Create empty dict of labels
        grouped_data = {label: []
                        for label in self.labels}  # pylint: disable=no-member

        # Populate grouped data dict
        for datapoint in self.trainset:  # pylint: disable=all
            _, label = datapoint  # Extract label
            label = self.labels[label]
            grouped_data[label].append(  # pylint: disable=no-member
                datapoint)

        self.trainset = grouped_data  # Overwrite trainset with grouped data

    # Run data generation
    def generate(self, path):
        self.read(path)
        self.trainset_size = len(self.trainset)  # Extract trainset size
        self.group()

        return self.trainset
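

# A minimal sketch of a concrete generator, assuming the MNIST dataset from
# the torchvision imports above. The class name and the bare ToTensor
# transform are illustrative choices, not part of the original module.
class MNISTGenerator(Generator):
    """Example: read MNIST so that group() can bucket it by digit label."""

    def read(self, path):
        self.labels = list(range(10))  # Digit classes 0-9
        transform = transforms.Compose([transforms.ToTensor()])
        self.trainset = datasets.MNIST(
            path, train=True, download=True, transform=transform)
        self.testset = datasets.MNIST(
            path, train=False, download=True, transform=transform)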


class Loader(object):
    """Load and pass IID data partitions."""

    def __init__(self, config, generator):
        # Get data from generator
        self.config = config
        self.trainset = generator.trainset
        self.testset = generator.testset
        self.labels = generator.labels
        self.trainset_size = generator.trainset_size

        # Store used data separately
        self.used = {label: [] for label in self.labels}
        self.used['testset'] = []

    def extract(self, label, n):
        if len(self.trainset[label]) >= n:
            extracted = self.trainset[label][:n]  # Extract data
            self.used[label].extend(extracted)  # Move data to used
            del self.trainset[label][:n]  # Remove from trainset
            return extracted
        else:
            logging.warning('Insufficient data in label: {}'.format(label))
            logging.warning('Dumping used data for reuse')

            # Unmark all data as used, taking care not to shadow the
            # requested label with the loop variable
            for used_label in self.labels:
                self.trainset[used_label].extend(self.used[used_label])
                self.used[used_label] = []

            # Extract replenished data
            return self.extract(label, n)

    def get_partition(self, partition_size):
        # Get a partition that is uniform across all labels
        dist = dists.uniform(partition_size, len(self.labels))

        partition = []  # Extract data according to distribution
        for i, label in enumerate(self.labels):
            partition.extend(self.extract(label, dist[i]))

        # Shuffle data partition
        random.shuffle(partition)

        return partition

    def get_testset(self):
        # Return the entire testset
        return self.testset


class BiasLoader(Loader):
    """Load and pass 'preference bias' data partitions."""

    def get_partition(self, partition_size, pref):
        # Get a non-uniform partition with a preference bias

        # Extract bias configuration from config
        bias = self.config.data.bias['primary']
        secondary = self.config.data.bias['secondary']

        # Calculate sizes of majority and minority portions
        majority = int(partition_size * bias)
        minority = partition_size - majority

        # Calculate number of minor labels
        len_minor_labels = len(self.labels) - 1

        if secondary:
            # Distribute to a random secondary label
            dist = [0] * len_minor_labels
            dist[random.randint(0, len_minor_labels - 1)] = minority
        else:
            # Distribute among all minority labels
            dist = dists.uniform(minority, len_minor_labels)

        # Add majority data to distribution
        dist.insert(self.labels.index(pref), majority)

        partition = []  # Extract data according to distribution
        for i, label in enumerate(self.labels):
            partition.extend(self.extract(label, dist[i]))

        # Shuffle data partition
        random.shuffle(partition)

        return partition
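
    # Worked example of the split above, assuming 10 labels, bias = 0.8,
    # no secondary label, and partition_size = 100: the preferred label gets
    # majority = int(100 * 0.8) = 80 samples, and the remaining minority = 20
    # samples are spread uniformly across the other 9 labels.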


class ShardLoader(Loader):
    """Load and pass 'shard' data partitions."""

    def create_shards(self):
        # Extract shard configuration from config
        per_client = self.config.data.shard['per_client']

        # Determine correct total shards, shard size
        total = self.config.clients.total * per_client
        shard_size = int(self.trainset_size / total)

        data = []  # Flatten data
        for _, items in self.trainset.items():
            data.extend(items)

        shards = [data[(i * shard_size):((i + 1) * shard_size)]
                  for i in range(total)]
        random.shuffle(shards)

        self.shards = shards
        self.used = []

        logging.info('Created {} shards of size {}'.format(
            len(shards), shard_size))
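
    # Worked example of the shard math above, assuming 100 total clients,
    # 2 shards per client, and a 60,000-sample trainset: total = 200 shards
    # of shard_size = 300 samples each. Because the flattened data is still
    # grouped by label, each shard covers only one or two labels, which is
    # what makes the resulting client partitions non-IID.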

    def extract_shard(self):
        shard = self.shards[0]
        self.used.append(shard)
        del self.shards[0]
        return shard

    def get_partition(self):
        # Get a partition of shards

        # Extract number of shards per client
        per_client = self.config.data.shard['per_client']

        # Create data partition
        partition = []
        for _ in range(per_client):
            partition.extend(self.extract_shard())

        # Shuffle data partition
        random.shuffle(partition)

        return partition
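

# A minimal usage sketch, not part of the original module. It assumes the
# illustrative MNISTGenerator above; the IID Loader never reads its config
# argument, so None stands in for the project's real config object.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    generator = MNISTGenerator()
    generator.generate('./data')  # Read MNIST, then group it by label

    loader = Loader(None, generator)
    partition = loader.get_partition(600)  # ~60 samples from each label
    logging.info('Extracted an IID partition of {} samples'.format(
        len(partition)))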