-
Notifications
You must be signed in to change notification settings - Fork 2
/
dataset_nonIID_split.py
55 lines (40 loc) · 1.81 KB
/
dataset_nonIID_split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import sys
import pandas as pd
from create_data import create_data_distribution
dir_path = os.getcwd()
clients = int(sys.argv[1])
data_dir = str(sys.argv[2])
dataset_train = pd.read_csv(data_dir + "/kiba_train.csv", sep=',', header=0)
dataset_test = pd.read_csv(data_dir + "/kiba_test.csv", sep=',', header=0)
dataset_complete = dataset_train.append(dataset_test, ignore_index=True).sample(frac=1)
dist = {}
for client, value in zip(range(clients), [1/clients]*clients):
dist[client] = value
print("Current distribution for split " + str(clients) + ": " + str(dist))
dataset = dataset_complete.copy(deep=True)
prop_table = dataset['target_sequence'].value_counts(normalize=True).sample(frac=1)
run_path = data_dir + "/run_nonIID_fl_protein_" + str(clients)
os.makedirs(run_path, exist_ok=True)
assignments = {}
dist_sorted = dict(sorted(dist.items(), key=lambda x: x[1]))
last_client = list(dist_sorted.items())[-1][0]
for client, perc in dist_sorted.items():
client_path = run_path + "/client_" + str(client)
os.makedirs(client_path, exist_ok=True)
count = 0
drugs = []
for drug, prop in prop_table.items():
if (count + prop) > perc and len(drugs) > 0 and not client == last_client:
break
count += prop
drugs.append(drug)
prop_table.drop(drugs, inplace=True)
partition = dataset.loc[dataset['target_sequence'].isin(drugs)]
train_partition = partition.sample(frac=0.7)
test_partition = partition.drop(train_partition.index)
train_partition.to_csv(client_path + "/kiba_train.csv")
print("New TRAIN dataset split created at " + client_path + "/kiba_train.csv")
test_partition.to_csv(client_path + "/kiba_test.csv")
print("New TEST dataset split created at " + client_path + "/kiba_test.csv")
create_data_distribution(client_path)