https://pytutorial.marcoinacio.com/sections/pytorch_lightning/
#!pip install pytorch_lightning optuna mlflow
import numpy as np
import scipy.stats as stats
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import random_split, TensorDataset, DataLoader
import pickle
from copy import deepcopy
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import tempfile
import os
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.preprocessing import StandardScaler
import optuna
from optuna.integration import PyTorchLightningPruningCallback
%matplotlib inline
pl.__version__
'1.3.1'
Let's start by generating some random data.
# Reproducible synthetic data from a linear model: target = X @ beta.
torch.manual_seed(1)
beta = torch.rand(10, 1)

train_inputv = torch.randn(700, 10)
train_target = torch.mm(train_inputv, beta)

test_inputv = torch.randn(200, 10)
test_target = torch.mm(test_inputv, beta)

# Discretize the continuous target into 4 ordinal classes: the label of a
# sample counts how many of the training-set quantile cutpoints (10%, 70%,
# 90%) its target exceeds, giving labels in {0, 1, 2, 3}.
cutpoints = [torch.quantile(train_target, x).item() for x in [.1, .7, .9]]
train_target_label = sum([0+(train_target > cutpoint) for cutpoint in cutpoints], 0)
train_target_label = train_target_label.flatten()
test_target_label = sum([0+(test_target > cutpoint) for cutpoint in cutpoints], 0)
test_target_label = test_target_label.flatten()
Let's scale our data to help the neural network training process.
# Standardize the features; statistics are estimated on the training split
# only, so no information from the test set leaks into preprocessing.
scaler = StandardScaler()
scaler.fit(train_inputv.numpy())
train_inputv = torch.as_tensor(scaler.transform(train_inputv), dtype=torch.float32)
test_inputv = torch.as_tensor(scaler.transform(test_inputv), dtype=torch.float32)
class LitNN(pl.LightningModule):
    """Fully-connected network usable either as a regressor or a classifier.

    When ``n_classification_labels`` is truthy the network outputs one logit
    per class and is trained with cross-entropy; otherwise it outputs a
    single value and is trained with MSE.

    Args:
        nfeatures: number of input features.
        n_classification_labels: number of classes (must not be 1), or a
            falsy value (0/None) for regression.
        hsizes: sizes of the hidden layers.
        lr: Adam learning rate.
        weight_decay: Adam weight decay (L2 penalty).
        batch_size: stored as an attribute so Lightning's batch-size tuner
            (``auto_scale_batch_size``) can find and adjust it.
        dropout: dropout probability applied after each hidden layer.
    """

    def __init__(self, nfeatures, n_classification_labels, hsizes=(50, 10),
                 lr=0.01, weight_decay=0, batch_size=50, dropout=0.5):
        super().__init__()
        # A single-logit classifier trained with cross-entropy is degenerate.
        assert n_classification_labels != 1
        self.lr = lr
        self.batch_size = batch_size
        self.weight_decay = weight_decay
        self.n_classification_labels = n_classification_labels

        # Hidden stack: Linear -> ELU -> BatchNorm -> Dropout per layer.
        input_size = nfeatures
        modules_list = []
        for hsize in hsizes:
            modules_list.extend([
                nn.Linear(input_size, hsize),
                nn.ELU(),
                nn.BatchNorm1d(hsize),
                nn.Dropout(dropout),
            ])
            input_size = hsize
        # One logit per class for classification, one value for regression.
        out_size = n_classification_labels if n_classification_labels else 1
        modules_list.append(self._initialize_layer(nn.Linear(input_size, out_size)))
        self.modules_list = nn.ModuleList(modules_list)

    def forward(self, x):
        """Apply every layer of the stack in sequence and return the result."""
        for module in self.modules_list:
            x = module(x)
        return x

    def _initialize_layer(self, layer):
        """Zero the bias and Xavier-normal-initialize the weights of `layer`."""
        nn.init.constant_(layer.bias, 0)
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(layer.weight, gain=gain)
        return layer

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr,
                                     weight_decay=self.weight_decay)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        """One training step; logs cross-entropy (classification) or RMSE
        (regression) and returns the loss to backpropagate."""
        inputv, target = train_batch
        output = self.forward(inputv)
        if self.n_classification_labels:
            loss = F.cross_entropy(output, target)
            self.log('train_loss_ce', loss.item())
        else:
            loss = F.mse_loss(output, target)
            self.log('train_loss_rmse', np.sqrt(loss.item()))
        return loss

    def test_validation_step(self, batch, batch_idx, name):
        """Shared evaluation logic for validation and test.

        Logs cross-entropy and zero-one loss for classification, or RMSE and
        MAE for regression, under the ``{name}_``-prefixed metric keys.
        """
        inputv, target = batch
        output = self.forward(inputv)
        if self.n_classification_labels:
            loss_ce = F.cross_entropy(output, target).item()
            loss_zo = (torch.argmax(output, 1) != target)+0.
            loss_zo = loss_zo.mean().item()
            self.log(f'{name}_loss_ce', loss_ce)
            self.log(f'{name}_loss_zo', loss_zo)
        else:
            loss_mse = F.mse_loss(output, target).item()
            loss_mae = F.l1_loss(output, target).item()
            self.log(f'{name}_loss_rmse', np.sqrt(loss_mse))
            self.log(f'{name}_loss_mae', loss_mae)

    def validation_step(self, val_batch, batch_idx):
        self.test_validation_step(val_batch, batch_idx, 'val')

    def test_step(self, test_batch, batch_idx):
        self.test_validation_step(test_batch, batch_idx, 'test')
class DataModule(pl.LightningDataModule):
    """Wrap train/val/test tensors as Lightning dataloaders.

    The training data is split 90/10 into train/validation with a seeded
    generator, so the split is reproducible across runs.

    Args:
        train_inputv: training features (anything ``torch.as_tensor`` accepts).
        train_target: training targets; cast to long for classification,
            float32 for regression.
        test_inputv: optional test features.
        test_target: optional test targets.
        n_classification_labels: number of classes (must not be 1), or a
            falsy value for regression; must be passed explicitly.
        batch_size: dataloader batch size, capped at the dataset size.
        num_workers: worker processes per DataLoader.
        train_val_split_seed: seed for the train/validation split.
    """

    def __init__(self, train_inputv, train_target,
                 test_inputv=None, test_target=None,
                 n_classification_labels=None, batch_size=50,
                 num_workers=2, train_val_split_seed=0):
        super().__init__()
        assert n_classification_labels is not None
        # A single-class classifier is degenerate.
        assert n_classification_labels != 1
        self.batch_size = min(batch_size, len(train_target))
        self.n_classification_labels = n_classification_labels

        # cross_entropy requires integer class labels; mse_loss needs floats.
        y_dtype = torch.long if n_classification_labels else torch.float32
        self.train_inputv = torch.as_tensor(train_inputv, dtype=torch.float32)
        self.train_target = torch.as_tensor(train_target, dtype=y_dtype)

        self.test_inputv = test_inputv
        self.test_target = test_target
        if test_inputv is not None:
            self.test_inputv = torch.as_tensor(test_inputv, dtype=torch.float32)
        if test_target is not None:
            self.test_target = torch.as_tensor(test_target, dtype=y_dtype)

        self.num_workers = num_workers
        self.train_val_split_seed = train_val_split_seed

    def setup(self, stage):
        """Build the datasets needed by `stage` ('fit' or 'test')."""
        if stage == 'fit':
            full_dataset = TensorDataset(self.train_inputv, self.train_target)
            # Seeded generator => deterministic 90/10 train/val split.
            generator = torch.Generator().manual_seed(self.train_val_split_seed)
            partitions = [len(full_dataset) - len(full_dataset)//10,
                          len(full_dataset)//10]
            full_dataset = torch.utils.data.random_split(full_dataset, partitions,
                                                         generator=generator)
            self.train_dataset, self.val_dataset = full_dataset
        if stage == 'test':
            if self.test_inputv is not None:
                self.test_dataset = TensorDataset(self.test_inputv, self.test_target)

    def train_dataloader(self):
        # drop_last avoids a trailing short batch (a single-sample batch
        # would make BatchNorm fail in training mode).
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          drop_last=True, shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        if self.test_inputv is None:
            raise RuntimeError("Test data not set")
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)
Let's check the zero-one (misclassification) error of an Extra Trees classifier as a simple baseline for our models.
# Baseline for comparison: misclassification rate of an Extra Trees
# classifier on the held-out test set.
baseline = ExtraTreesClassifier(n_estimators=1000, random_state=0)
baseline.fit(train_inputv, train_target_label)
(baseline.predict(test_inputv) != test_target_label.numpy()).mean()
0.32
Now, we train a neural network with fixed hyperparameters.
# Data module and model with fixed hyperparameters (4 ordinal classes).
datamodule = DataModule(train_inputv, train_target_label,
                        test_inputv, test_target_label,
                        n_classification_labels=4)
smodel = LitNN(nfeatures=train_inputv.shape[1], n_classification_labels=4)

# Stop training once the validation cross-entropy has not improved for
# 30 consecutive validation checks.
early_stop_callback = EarlyStopping(monitor='val_loss_ce', min_delta=0.00,
                                    patience=30, verbose=False, mode='min')

# use MLFlow as logger if available, see other options at
# https://pytorch-lightning.readthedocs.io/en/latest/common/loggers.html
# you can start MLFLow server with:
# mlflow server --backend-store-uri=./ml-runs
try:
    from pytorch_lightning.loggers import MLFlowLogger
    logger = MLFlowLogger(experiment_name="Default",
                          tracking_uri="file:./mlruns")
except ImportError:
    # default: Tensorboard, you can start with:
    # tensorboard --logdir lightning_logs
    logger = True

trainer = pl.Trainer(
    precision=32,
    gpus=torch.cuda.device_count(),
    tpu_cores=None,
    logger=logger,
    val_check_interval=0.25,  # run validation 4 times per epoch
    auto_scale_batch_size=True,
    #auto_lr_find=True,
    callbacks=early_stop_callback,
    max_epochs=100,
)

# Tune (find "best" batch_size / lr), then fit and evaluate.
trainer.tune(smodel, datamodule=datamodule)
trainer.fit(smodel, datamodule=datamodule)
trainer.test(smodel, datamodule=datamodule)

# Predict on the raw test inputs; deep copies keep the fitted trainer and
# model untouched.
test_pred = np.vstack(deepcopy(trainer).predict(deepcopy(smodel), DataLoader(test_inputv)))

# Verify that the trained model is picklable.
_ = pickle.dumps(smodel)
smodel.trainer.callback_metrics
GPU available: True, used: True GPU available: True, used: True TPU available: False, using: 0 TPU cores TPU available: False, using: 0 TPU cores LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] /home/marco/.local/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:69: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance. warnings.warn(*args, **kwargs)
THIS SHOULD BE CALLED!!!!! 2
/home/marco/.local/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:69: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance. warnings.warn(*args, **kwargs) Batch size 2 succeeded, trying batch size 4 Batch size 2 succeeded, trying batch size 4 LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
actual batch_size 2 actual batch_size 2 actual batch_size 2
Batch size 4 succeeded, trying batch size 8 Batch size 4 succeeded, trying batch size 8 LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
actual batch_size 2 actual batch_size 2 actual batch_size 2
Batch size 8 succeeded, trying batch size 16 Batch size 8 succeeded, trying batch size 16 LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
actual batch_size 2 actual batch_size 2 actual batch_size 2
Batch size 16 succeeded, trying batch size 32 Batch size 16 succeeded, trying batch size 32 LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
actual batch_size 2 actual batch_size 2 actual batch_size 2
Batch size 32 succeeded, trying batch size 64 Batch size 32 succeeded, trying batch size 64 LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
actual batch_size 2 actual batch_size 2 actual batch_size 2
Batch size 64 succeeded, trying batch size 128 Batch size 64 succeeded, trying batch size 128 LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
actual batch_size 2 actual batch_size 2 actual batch_size 2
Batch size 128 succeeded, trying batch size 256 Batch size 128 succeeded, trying batch size 256 LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
actual batch_size 2 actual batch_size 2 actual batch_size 2
Batch size 256 succeeded, trying batch size 512 Batch size 256 succeeded, trying batch size 512 LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
actual batch_size 2 actual batch_size 2 actual batch_size 2
Batch size 512 succeeded, trying batch size 1024 Batch size 512 succeeded, trying batch size 1024 LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
actual batch_size 2 actual batch_size 2 actual batch_size 2
Batch size 630 succeeded, trying batch size 1260 Batch size 630 succeeded, trying batch size 1260 Finished batch size finder, will continue with full run using batch size 630 Finished batch size finder, will continue with full run using batch size 630 Restored states from the checkpoint file at /home/marco/Documents/projects/python-intro/sections/scale_batch_size_temp_model.ckpt Restored states from the checkpoint file at /home/marco/Documents/projects/python-intro/sections/scale_batch_size_temp_model.ckpt LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
actual batch_size 2 actual batch_size 2 actual batch_size 2
| Name | Type | Params -------------------------------------------- 0 | modules_list | ModuleList | 1.2 K -------------------------------------------- 1.2 K Trainable params 0 Non-trainable params 1.2 K Total params 0.005 Total estimated model params size (MB) | Name | Type | Params -------------------------------------------- 0 | modules_list | ModuleList | 1.2 K -------------------------------------------- 1.2 K Trainable params 0 Non-trainable params 1.2 K Total params 0.005 Total estimated model params size (MB)
actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2
actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2
actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2
actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2
actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2
actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2
actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2
actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2
actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2
actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2
actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2
actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2 actual batch_size 2
/home/marco/.local/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:69: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown... warnings.warn(*args, **kwargs) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] /home/marco/.local/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:69: UserWarning: The dataloader, test dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance. warnings.warn(*args, **kwargs)
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) /tmp/pytmpfiles/ipykernel_45814/2495237515.py in <module> 46 47 # test smodel ---> 48 trainer.test(smodel, datamodule = datamodule) 49 50 # predict smodel ~/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in test(self, model, test_dataloaders, ckpt_path, verbose, datamodule) 577 578 # run test --> 579 results = self._run(model) 580 581 assert self.state.stopped ~/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model) 754 755 # dispatch `start_training` or `start_evaluating` or `start_predicting` --> 756 self.dispatch() 757 758 # plugin will finalized fitting (e.g. ddp_spawn will load trained model) ~/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self) 791 def dispatch(self): 792 if self.evaluating: --> 793 self.accelerator.start_evaluating(self) 794 elif self.predicting: 795 self.accelerator.start_predicting(self) ~/.local/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in start_evaluating(self, trainer) 97 98 def start_evaluating(self, trainer: 'pl.Trainer') -> None: ---> 99 self.training_type_plugin.start_evaluating(trainer) 100 101 def start_predicting(self, trainer: 'pl.Trainer') -> None: ~/.local/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_evaluating(self, trainer) 146 def start_evaluating(self, trainer: 'pl.Trainer') -> None: 147 # double dispatch to initiate the test loop --> 148 self._results = trainer.run_stage() 149 150 def start_predicting(self, trainer: 'pl.Trainer') -> None: ~/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self) 802 803 if self.evaluating: --> 804 return self.run_evaluate() 805 if self.predicting: 806 return self.run_predict() 
~/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_evaluate(self) 1042 1043 with self.profiler.profile(f"run_{self.state.stage}_evaluation"): -> 1044 eval_loop_results = self.run_evaluation() 1045 1046 # remove the tensors from the eval results ~/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_evaluation(self, on_epoch) 1007 1008 # log epoch metrics -> 1009 eval_loop_results = self.logger_connector.get_evaluate_epoch_results() 1010 1011 # save predictions to disk ~/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py in get_evaluate_epoch_results(self) 274 metrics_to_log = self.cached_results.get_epoch_log_metrics() 275 if len(metrics_to_log) > 0: --> 276 self.log_metrics(metrics_to_log, {}) 277 278 self.prepare_eval_loop_results() ~/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py in log_metrics(self, metrics, grad_norm_dic, step) 227 if self.trainer.is_global_zero: 228 self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step) --> 229 self.trainer.logger.save() 230 231 # track the logged metrics ~/.local/lib/python3.8/site-packages/pytorch_lightning/loggers/base.py in save(self) 302 def save(self) -> None: 303 """Save log data.""" --> 304 self._finalize_agg_metrics() 305 306 def finalize(self, status: str) -> None: ~/.local/lib/python3.8/site-packages/pytorch_lightning/loggers/base.py in _finalize_agg_metrics(self) 143 144 if metrics_to_log is not None: --> 145 self.log_metrics(metrics=metrics_to_log, step=agg_step) 146 147 def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): ~/.local/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py in wrapped_fn(*args, **kwargs) 47 def wrapped_fn(*args, **kwargs): 48 if rank_zero_only.rank == 0: ---> 49 return fn(*args, **kwargs) 50 51 return wrapped_fn 
~/.local/lib/python3.8/site-packages/pytorch_lightning/loggers/mlflow.py in log_metrics(self, metrics, step) 203 k = new_k 204 --> 205 self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step) 206 207 @rank_zero_only ~/miniforge3/lib/python3.8/site-packages/mlflow/tracking/client.py in log_metric(self, run_id, key, value, timestamp, step) 169 _validate_metric(key, value, timestamp, step) 170 metric = Metric(key, value, timestamp, step) --> 171 self.store.log_metric(run_id, metric) 172 173 def log_param(self, run_id, key, value): ~/miniforge3/lib/python3.8/site-packages/mlflow/store/file_store.py in log_metric(self, run_id, metric) 595 _validate_run_id(run_id) 596 _validate_metric_name(metric.key) --> 597 run = self.get_run(run_id) 598 check_run_is_active(run.info) 599 metric_path = self._get_metric_path(run.info.experiment_id, run_id, metric.key) ~/miniforge3/lib/python3.8/site-packages/mlflow/store/file_store.py in get_run(self, run_id) 398 databricks_pb2.INVALID_STATE) 399 metrics = self.get_all_metrics(run_id) --> 400 params = self.get_all_params(run_id) 401 tags = self.get_all_tags(run_id) 402 return Run(run_info, RunData(metrics, params, tags)) ~/miniforge3/lib/python3.8/site-packages/mlflow/store/file_store.py in get_all_params(self, run_uuid) 518 519 def get_all_params(self, run_uuid): --> 520 parent_path, param_files = self._get_run_files(run_uuid, "param") 521 params = [] 522 for param_file in param_files: ~/miniforge3/lib/python3.8/site-packages/mlflow/store/file_store.py in _get_run_files(self, run_uuid, resource_type) 422 def _get_run_files(self, run_uuid, resource_type): 423 _validate_run_id(run_uuid) --> 424 run_info = self._get_run_info(run_uuid) 425 if run_info is None: 426 raise MlflowException("Run '%s' metadata is in invalid state." % run_uuid, ~/miniforge3/lib/python3.8/site-packages/mlflow/store/file_store.py in _get_run_info(self, run_uuid) 406 Note: Will get both active and deleted runs. 
407 """ --> 408 exp_id, run_dir = self._find_run_root(run_uuid) 409 if run_dir is None: 410 raise MlflowException("Run '%s' not found" % run_uuid, ~/miniforge3/lib/python3.8/site-packages/mlflow/store/file_store.py in _find_run_root(self, run_uuid) 339 all_experiments = self._get_active_experiments(True) + self._get_deleted_experiments(True) 340 for experiment_dir in all_experiments: --> 341 runs = find(experiment_dir, run_uuid, full_path=True) 342 if len(runs) == 0: 343 continue ~/miniforge3/lib/python3.8/site-packages/mlflow/utils/file_utils.py in find(root, name, full_path) 86 """ 87 path_name = os.path.join(root, name) ---> 88 return list_all(root, lambda x: x == path_name, full_path) 89 90 ~/miniforge3/lib/python3.8/site-packages/mlflow/utils/file_utils.py in list_all(root, filter_func, full_path) 44 if not is_directory(root): 45 raise Exception("Invalid parent directory '%s'" % root) ---> 46 matches = [x for x in os.listdir(root) if filter_func(os.path.join(root, x))] 47 return [os.path.join(root, m) for m in matches] if full_path else matches 48 ~/miniforge3/lib/python3.8/site-packages/mlflow/utils/file_utils.py in <listcomp>(.0) 44 if not is_directory(root): 45 raise Exception("Invalid parent directory '%s'" % root) ---> 46 matches = [x for x in os.listdir(root) if filter_func(os.path.join(root, x))] 47 return [os.path.join(root, m) for m in matches] if full_path else matches 48 ~/miniforge3/lib/python3.8/posixpath.py in join(a, *p) 86 path += b 87 else: ---> 88 path += sep + b 89 except (TypeError, AttributeError, BytesWarning): 90 genericpath._check_arg_types('join', a, *p) KeyboardInterrupt:
Let's optimize the hyperparameters using the Optuna library
# Create (or reuse, when re-running this notebook cell) the Optuna study
# that minimizes the validation loss across hyperparameter trials.
try:
    study
except NameError:
    study = optuna.create_study(direction="minimize",
                                pruner=optuna.pruners.SuccessiveHalvingPruner())

# Create (or reuse) a scratch directory where each trial's fitted model is
# pickled so the best one can be reloaded after the search finishes.
try:
    tempdir
except NameError:
    # tempfile.mkdtemp() creates the directory and leaves it in place.
    # The previous `tempfile.TemporaryDirectory().name` + os.mkdir() was
    # fragile: the TemporaryDirectory object is garbage-collected right
    # away, which deletes the freshly created directory before os.mkdir
    # recreates it (racy, and relies on CPython refcounting semantics).
    tempdir = tempfile.mkdtemp()
    print(tempdir)
def objective(trial: optuna.trial.Trial) -> float:
    """Train one LitNN with trial-sampled hyperparameters and return its
    final validation cross-entropy loss (the quantity Optuna minimizes)."""
    # Sample the search space; the second hidden layer's upper bound
    # shrinks as the first layer grows, keeping total width bounded.
    hsize1 = trial.suggest_int("hsize1", 10, 1000)
    hsize2 = trial.suggest_int("hsize2", 10, max(20, 1000 - hsize1))
    batch_size = trial.suggest_int("batch_size", 50, 400)
    lr = trial.suggest_float("lr", 1e-5, 0.1)
    dropout = trial.suggest_float("dropout", 0.0, 0.5)
    weight_decay = trial.suggest_float("weight_decay", 0.0, 0.01)

    sampled = dict(
        hsize1=hsize1,
        hsize2=hsize2,
        batch_size=batch_size,
        lr=lr,
        dropout=dropout,
        weight_decay=weight_decay,
    )

    net = LitNN(
        hsizes=[hsize1, hsize2],
        lr=lr,
        batch_size=batch_size,
        dropout=dropout,
        weight_decay=weight_decay,
        nfeatures=train_inputv.shape[1],
        n_classification_labels=4,
    )
    data = DataModule(train_inputv, train_target_label,
                      batch_size=batch_size, n_classification_labels=4)

    stopper = EarlyStopping(
        monitor='val_loss_ce',
        min_delta=0.00,
        patience=30,
        verbose=False,
        mode='min',
    )

    # Prefer MLflow logging when the integration is installed; otherwise
    # fall back to Lightning's default logger (logger=True).
    try:
        from pytorch_lightning.loggers import MLFlowLogger
        run_logger = MLFlowLogger(
            experiment_name="Default",
            tracking_uri="file:./mlruns"
        )
    except ImportError:
        run_logger = True

    fitter = pl.Trainer(
        precision=32,
        gpus=torch.cuda.device_count(),
        logger=run_logger,
        val_check_interval=0.25,
        callbacks=[stopper,
                   PyTorchLightningPruningCallback(trial, monitor="val_loss_ce")],
        max_epochs=100,
    )
    fitter.fit(net, datamodule=data)
    fitter.logger.log_hyperparams(sampled)

    # Persist this trial's fitted model so the best one can be reloaded
    # once the whole study finishes.
    with open(f"{os.path.join(tempdir, str(trial.number))}.pkl", "wb") as f:
        pickle.dump(net, f)

    return fitter.callback_metrics["val_loss_ce"].item()
# Run the search: up to 10000 trials, but capped at a 6-second wall-clock
# budget (demo-sized; raise `timeout` for a real search).
study.optimize(objective, n_trials=10000, timeout=6)

print("Number of finished trials: {}".format(len(study.trials)))
print("Best trial:", study.best_params)

# Reload the model pickled by the best-performing trial.
best_model_path = f"{os.path.join(tempdir, str(study.best_trial.number))}.pkl"
with open(best_model_path, "rb") as f:
    best_model = pickle.load(f)
Let's compare the results with our previous model:
# Evaluate the tuned model on the test split using its restored trainer.
# NOTE(review): `datamodule` here is whatever was last bound at notebook
# level — confirm it is the classification DataModule this model expects.
best_model.trainer.test(best_model, datamodule = datamodule)
# Metrics collected by the tuned model's trainer during the test run.
best_model.trainer.callback_metrics
# For comparison: metrics from the earlier hand-configured model.
smodel.trainer.callback_metrics
Let's summarize the results:
# save on study on disk
with open(f"{os.path.join(tempdir, 'study')}.pkl", "wb") as f:
pickle.dump(study, f)
print("Number of finished trials: {}".format(len(study.trials)))
print("Best trial:", study.best_params)
with open(f"{os.path.join(tempdir, str(study.best_trial.number))}.pkl", "rb") as f:
best_model = pickle.load(f)
trials_summary = sorted(study.trials, key=lambda x: np.inf if x.value is None else x.value)
trials_summary = [dict(trial_number=trial.number, loss=trial.value, **trial.params) for trial in trials_summary]
trials_summary = pd.DataFrame(trials_summary)
trials_summary.iloc[:200]
# Baseline regression run: n_classification_labels=0 selects the
# regression (RMSE) path in LitNN/DataModule.
datamodule = DataModule(train_inputv, train_target,
                        test_inputv, test_target,
                        n_classification_labels=0)
smodel = LitNN(nfeatures=train_inputv.shape[1], n_classification_labels=0)

# Stop training once the validation RMSE has not improved for 30
# consecutive validation checks.
early_stop_callback = EarlyStopping(
    monitor='val_loss_rmse',
    min_delta=0.00,
    patience=30,
    verbose=False,
    mode='min'
)

# Log to MLflow when the integration is installed; otherwise fall back to
# Lightning's default logger (logger=True).
try:
    from pytorch_lightning.loggers import MLFlowLogger
    logger = MLFlowLogger(
        experiment_name="Default",
        tracking_uri="file:./mlruns"
    )
except ImportError:
    logger = True

trainer = pl.Trainer(
    precision=32,
    gpus=torch.cuda.device_count(),
    tpu_cores=None,
    logger=logger,
    val_check_interval=0.25, # do validation check 4 times for each epoch
    #auto_scale_batch_size=True,
    #auto_lr_find=True,
    callbacks=early_stop_callback,
    max_epochs = 100,
)
trainer.fit(smodel, datamodule = datamodule)
trainer.test(smodel, datamodule = datamodule)

# Predict on the raw test inputs. Trainer and model are deep-copied so
# `.predict()` does not disturb the fitted `smodel`/`trainer` state.
# NOTE(review): the deepcopy looks like a workaround for PL 1.3 predict
# side effects — confirm whether it is still needed on newer versions.
test_pred = np.vstack(deepcopy(trainer).predict(deepcopy(smodel), DataLoader(test_inputv)))

# Sanity check: the fitted LightningModule remains picklable.
_ = pickle.dumps(smodel)
smodel.trainer.callback_metrics