-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathSMOTE.py
92 lines (78 loc) · 2.84 KB
/
SMOTE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# -*- coding: utf-8 -*-
"""
@author : Suvodeep M (smajumd3@ncsu.edu)
"""
from __future__ import print_function, division
import pdb
import unittest
import random
from collections import Counter
import pandas as pd
import numpy as np
from sklearn.neighbors import NearestNeighbors as NN
class smote(object):
    """SMOTE over-sampler for a labelled pandas.DataFrame.

    Every class smaller than the majority class is topped up with synthetic
    rows. Each synthetic row is interpolated between a randomly chosen row of
    that class and one of its nearest same-class neighbours (SMOTE, Chawla et
    al., 2002). The class label must be the last column of the DataFrame.
    """

    def __init__(self, pd_data, neighbor=5, r=2, up_to_num=None, auto=True):
        """
        :param pd_data: pandas.DataFrame, the last column must be class label
        :param neighbor: number of nearest same-class neighbours to choose from
        :param r: Minkowski power passed to NearestNeighbors as ``p``
            (2 = Euclidean, 1 = Manhattan)
        :param up_to_num: number of synthetic rows to add to each minority
            class when ``auto`` is False; if not supplied (None, or the legacy
            empty list), minority classes are over-sampled up to the size of
            the majority class
        :param auto: if True, every minority class is over-sampled up to the
            size of the majority class and ``up_to_num`` is ignored
        :raises ValueError: if pd_data is empty
        """
        self.set_data(pd_data)
        self.auto = auto
        self.neighbor = neighbor
        # Not read anywhere in this class; kept so existing attribute
        # access keeps working.
        self.up_to_max = False
        self.up_to_num = up_to_num
        self.r = r
        self.label_num = len(set(pd_data[pd_data.columns[-1]].values))

    def set_data(self, pd_data):
        """Store pd_data as the data to over-sample.

        :raises ValueError: if pd_data is empty
        """
        if pd_data.empty:
            raise ValueError(
                "pd_data must be a non-empty DataFrame whose last column "
                "is the class label")
        self.data = pd_data

    def _label_counts(self):
        """Count rows per class label (the last column of self.data)."""
        return Counter(self.data[self.data.columns[-1]].values)

    def get_majority_num(self):
        """Return the number of rows in the most frequent class."""
        return max(self._label_counts().values())

    def _rows_to_add(self, deficit):
        """Return how many synthetic rows to add to a class `deficit` rows short of the majority."""
        not_supplied = self.up_to_num is None or (
            isinstance(self.up_to_num, list) and not self.up_to_num)
        if self.auto or not_supplied:
            return deficit
        return self.up_to_num

    def _fit_knn(self, minority):
        """Fit a nearest-neighbour index over the feature rows of one class."""
        # One extra neighbour is requested because every query row is
        # normally returned as its own nearest neighbour; _pick_pair drops it.
        num_neigh = min(self.neighbor + 1, len(minority))
        return NN(n_neighbors=num_neigh, p=self.r,
                  algorithm='ball_tree').fit(minority)

    def _pick_pair(self, minority, knn):
        """Pick a random row of `minority` and one of its nearest same-class neighbours.

        :return: (neighbour_row, sample_row), both rows of `minority`
        """
        sample_idx = random.randint(0, len(minority) - 1)
        sample = minority[sample_idx]
        _, ngbr = knn.kneighbors(sample.reshape(1, -1))
        # ngbr holds row indices into `minority` for the single query row.
        # Drop the sample itself and keep at most self.neighbor others.
        candidates = [j for j in ngbr[0] if j != sample_idx][:self.neighbor]
        if not candidates:
            # A single-row class has no other neighbour, so the synthetic row
            # ends up as a copy of the sample.
            return sample, sample
        return minority[random.choice(candidates)], sample

    @staticmethod
    def _interpolate(sample, neighbour):
        """Return synthetic feature values between `sample` and `neighbour`.

        Each feature gets its own random gap in [0, 1). The result is moved
        from the sample towards the neighbour and clipped below at 0.
        """
        new_row = []
        for s_val, n_val in zip(sample, neighbour):
            gap = random.random()
            new_row.append(max(0, s_val + (n_val - s_val) * gap))
        return new_row

    def run(self):
        """
        run smote

        :return: pandas.DataFrame holding the original rows followed by the
            synthetic ones. Column labels are positional integers and the
            class label is in the last column.
        """
        total_data = self.data.values.tolist()
        label_counts = self._label_counts()
        majority_num = max(label_counts.values())
        last_column = self.data[self.data.columns[-1]]
        for label, num in label_counts.items():
            if num >= majority_num:
                continue
            data_w_label = self.data.loc[last_column == label]
            minority = data_w_label[self.data.columns[:-1]].values
            knn = self._fit_knn(minority)
            for _ in range(self._rows_to_add(majority_num - num)):
                neighbour, sample = self._pick_pair(minority, knn)
                new_row = self._interpolate(sample, neighbour)
                new_row.append(label)
                total_data.append(new_row)
        return pd.DataFrame(total_data)