from typing import Any, Dict, List, Tuple, Union

import numpy as np

__all__ = [
    "non_iid_partition_with_dirichlet_distribution",
    "record_data_stats",
]


def non_iid_partition_with_dirichlet_distribution(
    label_list: np.ndarray,
    client_num: int,
    classes: Union[int, List[int]],
    alpha: float,
    task: str = "classification",
) -> Dict[int, List[int]]:
    """
    Obtain the sample index list for each client via the Dirichlet distribution.

    Modified from `FedML <https://github.com/FedML-AI/FedML/blob/master/fedml_core/non_iid_partition/noniid_partition.py>`_.

    This LDA method was first proposed in:
    Measuring the Effects of Non-Identical Data Distribution for
    Federated Visual Classification (https://arxiv.org/pdf/1909.06335.pdf).

    It generates non-IIDness with an unbalanced number of samples per label.
    The Dirichlet distribution is a density over a K-dimensional vector p whose
    K components are positive and sum to 1, so it can model the probabilities
    of a K-way categorical event. In FL, we can view the sample counts of the
    clients as obeying a Dirichlet distribution.
    For more details of the Dirichlet distribution, please check
    https://en.wikipedia.org/wiki/Dirichlet_distribution.

    A minimal usage sketch is provided under the ``__main__`` guard at the
    bottom of this module.

    Parameters
    ----------
    label_list : numpy.ndarray
        The label list from a classification/segmentation dataset.
    client_num : int
        Number of clients.
    classes : int or list
        The number of classes (e.g., 10 for CIFAR-10), OR a list of
        segmentation categories.
    alpha : float
        A concentration parameter controlling the identicalness among clients
        (smaller values yield more skewed partitions).
    task : {"classification", "segmentation"}
        The CV-specific task, e.g., classification or segmentation.

    Returns
    -------
    net_dataidx_map : dict
        Keys are client indices; values are lists of sample indices.
    """
    net_dataidx_map = {}
    K = classes
    # For multi-class (segmentation) labels, the list is ragged and not a numpy array
    N = len(label_list) if task == "segmentation" else label_list.shape[0]

    # guarantee a minimum number of samples (10) in each client
    min_size = 0
    while min_size < 10:
        idx_batch = [[] for _ in range(client_num)]
        if task == "segmentation":
            # Unlike classification tasks, here one instance may have multiple categories/classes
            for c, cat in enumerate(classes):
                if c > 0:
                    idx_k = np.asarray(
                        [
                            np.any(label_list[i] == cat) and not np.any(np.in1d(label_list[i], classes[:c]))
                            for i in range(len(label_list))
                        ]
                    )
                else:
                    idx_k = np.asarray([np.any(label_list[i] == cat) for i in range(len(label_list))])
                # indices of images that contain category ``cat``
                # (and none of the previously processed categories)
                idx_k = np.where(idx_k)[0]
                (
                    idx_batch,
                    min_size,
                ) = _partition_class_samples_with_dirichlet_distribution(
                    N,
                    alpha,
                    client_num,
                    idx_batch,
                    idx_k,
                )
        else:
            # for each class in the dataset
            for k in range(K):
                # get the list of sample indices which belong to label k
                idx_k = np.where(label_list == k)[0]
                (
                    idx_batch,
                    min_size,
                ) = _partition_class_samples_with_dirichlet_distribution(
                    N,
                    alpha,
                    client_num,
                    idx_batch,
                    idx_k,
                )

    for i in range(client_num):
        np.random.shuffle(idx_batch[i])
        net_dataidx_map[i] = idx_batch[i]

    return net_dataidx_map


def _partition_class_samples_with_dirichlet_distribution(
    N: int,
    alpha: float,
    client_num: int,
    idx_batch: List[List[int]],
    idx_k: np.ndarray,
) -> Tuple[List[List[int]], int]:
    """Split the sample indices ``idx_k`` of one class among ``client_num`` clients."""
    np.random.shuffle(idx_k)
    # use the Dirichlet distribution to determine the unbalanced proportion for each client (client_num in total)
    # e.g., when client_num = 4, proportions = [0.29543505 0.38414498 0.31998781 0.00043216], sum(proportions) = 1
    proportions = np.random.dirichlet(np.repeat(alpha, client_num))

    # zero out the proportion of any client that already holds more than the
    # average N / client_num samples, so it receives no samples of this class
    proportions = np.array([p * (len(idx_j) < N / client_num) for p, idx_j in zip(proportions, idx_batch)])
    # renormalize and turn the proportions into cumulative split points within idx_k
    proportions = proportions / proportions.sum()
    proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]

    # append each client's share of idx_k to its running index list
    idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
    min_size = min([len(idx_j) for idx_j in idx_batch])

    return idx_batch, min_size


def record_data_stats(
    y_train: np.ndarray,
    net_dataidx_map: Dict[int, List[int]],
    task: str = "classification",
) -> Dict[int, Dict[Any, int]]:
    """Count, for each client, how many samples of each class it holds."""
    net_cls_counts = {}
    for net_i, dataidx in net_dataidx_map.items():
        unq, unq_cnt = (
            np.unique(np.concatenate(y_train[dataidx]), return_counts=True)
            if task == "segmentation"
            else np.unique(y_train[dataidx], return_counts=True)
        )
        tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))}
        net_cls_counts[net_i] = tmp
    # logging.debug("Data statistics: %s" % str(net_cls_counts))
    return net_cls_counts
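

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): partition synthetic CIFAR-10-like
# labels across a handful of clients and summarize the resulting per-client
# class counts. The sample count, client count, class count, and alpha value
# below are arbitrary choices for demonstration, not defaults of this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # 5000 samples with integer labels in {0, ..., 9}
    synthetic_labels = np.random.randint(0, 10, size=5000)
    partition = non_iid_partition_with_dirichlet_distribution(
        label_list=synthetic_labels,
        client_num=5,
        classes=10,
        alpha=0.5,  # smaller alpha -> more skewed (less IID) partitions
        task="classification",
    )
    stats = record_data_stats(synthetic_labels, partition, task="classification")
    for client_id, class_counts in stats.items():
        total = sum(class_counts.values())
        print(f"client {client_id}: {total} samples, per-class counts {class_counts}")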