feat: add augmentations inside the fit method

Optimox committed Mar 3, 2022
1 parent 95e0e9a commit b7bc522

Showing 6 changed files with 698 additions and 66 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -9,7 +9,7 @@ NO_COLOR=\\e[39m
OK_COLOR=\\e[32m
ERROR_COLOR=\\e[31m
WARN_COLOR=\\e[33m
PORT=8889
PORT=8887
.SILENT: ;
default: help; # default target

350 changes: 321 additions & 29 deletions census_example.ipynb

Large diffs are not rendered by default.

104 changes: 86 additions & 18 deletions forest_example.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -36,7 +36,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -48,9 +48,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"File already exists.\n"
]
}
],
"source": [
"out.parent.mkdir(parents=True, exist_ok=True)\n",
"if out.exists():\n",
@@ -74,7 +82,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -106,7 +114,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -135,7 +143,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -162,7 +170,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -184,9 +192,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/work/pytorch_tabnet/abstract_model.py:75: UserWarning: Device used : cuda\n",
" warnings.warn(f\"Device used : {self.device}\")\n"
]
}
],
"source": [
"clf = TabNetClassifier(\n",
" n_d=64, n_a=64, n_steps=5,\n",
@@ -212,7 +229,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -233,11 +250,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"max_epochs = 5 if not os.getenv(\"CI\", False) else 2"
"max_epochs = 50 if not os.getenv(\"CI\", False) else 2"
]
},
{
@@ -246,14 +263,65 @@
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 0 | loss: 1.33128 | train_accuracy: 0.3806 | valid_accuracy: 0.37699 | 0:00:09s\n",
"epoch 1 | loss: 0.85144 | train_accuracy: 0.46719 | valid_accuracy: 0.46617 | 0:00:18s\n",
"epoch 2 | loss: 0.76028 | train_accuracy: 0.24121 | valid_accuracy: 0.24136 | 0:00:28s\n",
"epoch 3 | loss: 0.72163 | train_accuracy: 0.37581 | valid_accuracy: 0.37446 | 0:00:37s\n",
"epoch 4 | loss: 0.69439 | train_accuracy: 0.39981 | valid_accuracy: 0.39821 | 0:00:47s\n",
"epoch 5 | loss: 0.68769 | train_accuracy: 0.35353 | valid_accuracy: 0.35281 | 0:00:56s\n",
"epoch 6 | loss: 0.67138 | train_accuracy: 0.39872 | valid_accuracy: 0.40041 | 0:01:05s\n",
"epoch 7 | loss: 0.66065 | train_accuracy: 0.36233 | valid_accuracy: 0.36153 | 0:01:15s\n",
"epoch 8 | loss: 0.64885 | train_accuracy: 0.36858 | valid_accuracy: 0.36746 | 0:01:24s\n",
"epoch 9 | loss: 0.63634 | train_accuracy: 0.44304 | valid_accuracy: 0.44233 | 0:01:34s\n",
"epoch 10 | loss: 0.62768 | train_accuracy: 0.46735 | valid_accuracy: 0.46732 | 0:01:43s\n",
"epoch 11 | loss: 0.62027 | train_accuracy: 0.47327 | valid_accuracy: 0.47222 | 0:01:53s\n",
"epoch 12 | loss: 0.60642 | train_accuracy: 0.53517 | valid_accuracy: 0.53464 | 0:02:02s\n",
"epoch 13 | loss: 0.59722 | train_accuracy: 0.57827 | valid_accuracy: 0.5759 | 0:02:11s\n",
"epoch 14 | loss: 0.59 | train_accuracy: 0.60386 | valid_accuracy: 0.60246 | 0:02:21s\n",
"epoch 15 | loss: 0.58667 | train_accuracy: 0.65109 | valid_accuracy: 0.64899 | 0:02:30s\n",
"epoch 16 | loss: 0.57656 | train_accuracy: 0.67919 | valid_accuracy: 0.67702 | 0:02:39s\n",
"epoch 17 | loss: 0.5699 | train_accuracy: 0.70977 | valid_accuracy: 0.7076 | 0:02:49s\n",
"epoch 18 | loss: 0.56334 | train_accuracy: 0.72038 | valid_accuracy: 0.71772 | 0:02:59s\n",
"epoch 19 | loss: 0.5563 | train_accuracy: 0.72982 | valid_accuracy: 0.72799 | 0:03:08s\n",
"epoch 20 | loss: 0.54899 | train_accuracy: 0.75258 | valid_accuracy: 0.75051 | 0:03:18s\n",
"epoch 21 | loss: 0.54273 | train_accuracy: 0.76514 | valid_accuracy: 0.76315 | 0:03:27s\n",
"epoch 22 | loss: 0.53724 | train_accuracy: 0.77599 | valid_accuracy: 0.77383 | 0:03:37s\n",
"epoch 23 | loss: 0.53204 | train_accuracy: 0.78374 | valid_accuracy: 0.7824 | 0:03:46s\n",
"epoch 24 | loss: 0.53061 | train_accuracy: 0.78619 | valid_accuracy: 0.78474 | 0:03:56s\n",
"epoch 25 | loss: 0.52426 | train_accuracy: 0.79484 | valid_accuracy: 0.79346 | 0:04:05s\n",
"epoch 26 | loss: 0.51819 | train_accuracy: 0.79971 | valid_accuracy: 0.79761 | 0:04:15s\n",
"epoch 27 | loss: 0.51111 | train_accuracy: 0.80415 | valid_accuracy: 0.80202 | 0:04:24s\n",
"epoch 28 | loss: 0.50293 | train_accuracy: 0.80586 | valid_accuracy: 0.80376 | 0:04:34s\n",
"epoch 29 | loss: 0.50046 | train_accuracy: 0.80905 | valid_accuracy: 0.80544 | 0:04:43s\n",
"epoch 30 | loss: 0.49694 | train_accuracy: 0.81142 | valid_accuracy: 0.8085 | 0:04:53s\n",
"epoch 31 | loss: 0.49653 | train_accuracy: 0.81658 | valid_accuracy: 0.81348 | 0:05:03s\n",
"epoch 32 | loss: 0.49146 | train_accuracy: 0.81722 | valid_accuracy: 0.81435 | 0:05:12s\n",
"epoch 33 | loss: 0.48624 | train_accuracy: 0.82066 | valid_accuracy: 0.81849 | 0:05:22s\n",
"epoch 34 | loss: 0.48193 | train_accuracy: 0.82114 | valid_accuracy: 0.81849 | 0:05:31s\n",
"epoch 35 | loss: 0.47773 | train_accuracy: 0.8264 | valid_accuracy: 0.8223 | 0:05:41s\n",
"epoch 36 | loss: 0.46996 | train_accuracy: 0.82837 | valid_accuracy: 0.8253 | 0:05:50s\n",
"epoch 37 | loss: 0.46653 | train_accuracy: 0.83143 | valid_accuracy: 0.82765 | 0:06:00s\n",
"epoch 38 | loss: 0.46379 | train_accuracy: 0.82878 | valid_accuracy: 0.82461 | 0:06:10s\n",
"epoch 39 | loss: 0.4592 | train_accuracy: 0.83674 | valid_accuracy: 0.83356 | 0:06:19s\n"
]
}
],
"source": [
"from pytorch_tabnet.augmentations import ClassificationSMOTE\n",
"aug = ClassificationSMOTE(p=0.2)\n",
"\n",
"clf.fit(\n",
" X_train=X_train, y_train=y_train,\n",
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
" eval_name=['train', 'valid'],\n",
" max_epochs=max_epochs, patience=100,\n",
" batch_size=16384, virtual_batch_size=256\n",
" batch_size=16384, virtual_batch_size=256,\n",
" augmentations=aug\n",
") "
]
},
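The cell above is the headline change of this commit: a ClassificationSMOTE instance is passed to fit through the new augmentations argument. For regression tasks the commit also ships a RegressionSMOTE counterpart (see pytorch_tabnet/augmentations.py below); a minimal sketch of the analogous usage, assuming an already constructed TabNetRegressor named reg and arrays X_train, y_train, X_valid, y_valid:

from pytorch_tabnet.augmentations import RegressionSMOTE

aug = RegressionSMOTE(p=0.2)  # mix about 20% of the rows in each batch

reg.fit(
    X_train=X_train, y_train=y_train,
    eval_set=[(X_valid, y_valid)],
    augmentations=aug,
)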
13 changes: 11 additions & 2 deletions pytorch_tabnet/abstract_model.py
@@ -125,7 +125,8 @@ def fit(
callbacks=None,
pin_memory=True,
from_unsupervised=None,
warm_start=False
warm_start=False,
augmentations=None,
):
"""Train a neural network stored in self.network
Using train_dataloader for training data and
@@ -183,6 +184,11 @@
self.input_dim = X_train.shape[1]
self._stop_training = False
self.pin_memory = pin_memory and (self.device.type != "cpu")
self.augmentations = augmentations

if self.augmentations is not None:
            # This ensures reproducibility
self.augmentations._set_seed()

eval_set = eval_set if eval_set else []

@@ -472,9 +478,12 @@ def _train_batch(self, X, y):
"""
batch_logs = {"batch_size": X.shape[0]}

X = X.to(self.device).float()
        X = X.to(self.device).float()  # Is this .float() needed?
y = y.to(self.device).float()

if self.augmentations is not None:
X, y = self.augmentations(X, y)

for param in self.network.parameters():
param.grad = None

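Taken together, the two hooks above define the whole contract a custom augmentation has to satisfy: fit() calls augmentations._set_seed() once before training, and _train_batch() calls augmentations(X, y) on every batch after it has been moved to the device. A minimal sketch of a custom augmentation built on that contract; the GaussianNoise class and its std parameter are hypothetical, not part of the library:

import torch

class GaussianNoise:
    # Hypothetical augmentation: add Gaussian noise to the features and
    # leave the targets untouched. Only the two hooks used in
    # abstract_model.py are needed: _set_seed() and __call__(X, y).
    def __init__(self, std=0.05, seed=0):
        self.std = std
        self.seed = seed

    def _set_seed(self):
        torch.manual_seed(self.seed)

    def __call__(self, X, y):
        # X and y are already on the model's device (see _train_batch)
        return X + self.std * torch.randn_like(X), y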
90 changes: 90 additions & 0 deletions pytorch_tabnet/augmentations.py
@@ -0,0 +1,90 @@
import torch
from pytorch_tabnet.utils import define_device
import numpy as np

# TODO : change this so that p would be the proportion of rows that are changed
# add a beta argument (beta distribution)
class RegressionSMOTE():
    """
    Apply a SMOTE-like augmentation for regression tasks.
    This will mix a proportion p of the rows in the batch with other rows,
    using mixing weights drawn from a beta distribution.
    The targets are averaged with the same weights (this might also work for
    binary classification with certain losses).
    """
def __init__(self, device_name="auto", p=0.8, alpha=0.5, beta=0.5, seed=0):
""
self.seed = seed
self._set_seed()
self.device = define_device(device_name)
self.alpha = alpha
self.beta = beta
self.p = p
if (p < 0.) or (p > 1.0):
raise ValueError("Value of p should be between 0. and 1.")

def _set_seed(self):
torch.manual_seed(self.seed)
np.random.seed(self.seed)
return

def __call__(self, X, y):
batch_size = X.shape[0]
random_values = torch.rand(batch_size, device=self.device)
idx_to_change = random_values < self.p

        # shift the beta samples from [0, 1] into [0.5, 1] so the original row keeps the dominant weight
np_betas = np.random.beta(self.alpha, self.beta, batch_size) / 2 + 0.5
random_betas = torch.from_numpy(np_betas).to(self.device).float()
index_permute = torch.randperm(batch_size, device=self.device)

X[idx_to_change] = random_betas[idx_to_change, None]*X[idx_to_change] + \
(1 - random_betas[idx_to_change, None])*X[index_permute][idx_to_change].view(X[idx_to_change].size())

y[idx_to_change] = random_betas[idx_to_change, None]*y[idx_to_change] + \
(1 - random_betas[idx_to_change, None])*y[index_permute][idx_to_change].view(y[idx_to_change].size())

return X, y

class ClassificationSMOTE():
    """
    Apply a SMOTE-like augmentation for classification tasks.
    This will mix a proportion p of the rows in the batch with other rows.
    The targets stay unchanged: each mixed row keeps the label of its original
    row, which always carries the dominant weight (greater than 0.5) in the mix.
    """
def __init__(self, device_name="auto", p=0.8, alpha=0.5, beta=0.5, seed=0):
""
self.seed = seed
self._set_seed()
self.device = define_device(device_name)
self.alpha = alpha
self.beta = beta
self.p = p
if (p < 0.) or (p > 1.0):
raise ValueError("Value of p should be between 0. and 1.")

def _set_seed(self):
torch.manual_seed(self.seed)
np.random.seed(self.seed)
return

def __call__(self, X, y):
batch_size = X.shape[0]
random_values = torch.rand(batch_size, device=self.device)
idx_to_change = random_values < self.p

        # shift the beta samples from [0, 1] into [0.5, 1] so the original row keeps the dominant weight
np_betas = np.random.beta(self.alpha, self.beta, batch_size) / 2 + 0.5
random_betas = torch.from_numpy(np_betas).to(self.device).float()
index_permute = torch.randperm(batch_size, device=self.device)

X[idx_to_change] = random_betas[idx_to_change, None]*X[idx_to_change] + \
(1 - random_betas[idx_to_change, None])*X[index_permute][idx_to_change].view(X[idx_to_change].size())


return X, y
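Both classes draw their mixing weights the same way: np.random.beta(alpha, beta, batch_size) / 2 + 0.5 maps beta samples from [0, 1] into [0.5, 1], so every mixed row is still dominated by its own features, which is why ClassificationSMOTE can leave the labels untouched. A quick sanity check of the classification variant, as a sketch assuming small random CPU tensors:

import torch
from pytorch_tabnet.augmentations import ClassificationSMOTE

X = torch.rand(8, 4)                   # batch of 8 rows with 4 features
y = torch.randint(0, 2, (8,)).float()  # binary labels

aug = ClassificationSMOTE(device_name="cpu", p=0.5, seed=0)
X_aug, y_aug = aug(X.clone(), y.clone())

assert X_aug.shape == X.shape  # features keep their shape
assert torch.equal(y_aug, y)   # labels are untouched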


