Skip to content
This repository has been archived by the owner on Mar 22, 2021. It is now read-only.

Commit

Permalink
added k-fold validation and averaging, added saving oof predictions, … (#73)
Browse files Browse the repository at this point in the history

* added k-fold validation and averaging, added saving oof predictions, fixed pytorch memory issues, updated results exploration

* updated utils
  • Loading branch information
jakubczakon authored Sep 10, 2018
1 parent 6fab94b commit 17c8eec
Show file tree
Hide file tree
Showing 7 changed files with 570 additions and 393 deletions.
287 changes: 81 additions & 206 deletions common_blocks/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

import numpy as np
import torch
from PIL import Image
from deepsense import neptune
import neptune
from torch.autograd import Variable
from torch.optim.lr_scheduler import ExponentialLR
from tempfile import TemporaryDirectory
Expand All @@ -25,6 +24,7 @@
ORIGINAL_SIZE = (101, 101)
THRESHOLD = 0.5


class Callback:
def __init__(self):
self.epoch_id = None
Expand Down Expand Up @@ -159,63 +159,6 @@ def on_batch_end(self, metrics, *args, **kwargs):
self.batch_id += 1


class ValidationMonitor(Callback):
    """Periodically runs validation and logs each named validation loss."""

    def __init__(self, epoch_every=None, batch_every=None):
        super().__init__()
        # An explicit 0 disables the corresponding trigger entirely.
        self.epoch_every = False if epoch_every == 0 else epoch_every
        self.batch_every = False if batch_every == 0 else batch_every

    def on_epoch_end(self, *args, **kwargs):
        should_validate = self.epoch_every and (self.epoch_id % self.epoch_every) == 0
        if should_validate:
            # Evaluate in eval mode, then restore train mode.
            self.model.eval()
            validation_losses = self.get_validation_loss()
            self.model.train()
            for loss_name, loss_value in validation_losses.items():
                # Losses arrive as torch tensors; extract the scalar for logging.
                loss_value = loss_value.data.cpu().numpy()[0]
                logger.info('epoch {0} validation {1}: {2:.5f}'.format(self.epoch_id, loss_name, loss_value))
        self.epoch_id += 1


class EarlyStopping(Callback):
    """Stops training when the summed validation loss has not improved for
    `patience` consecutive epochs.

    Sets an internal flag which `training_break()` exposes; the training loop
    is expected to poll that method after each epoch.
    """

    def __init__(self, patience, minimize=True):
        super().__init__()
        self.patience = patience  # epochs without improvement tolerated before breaking
        self.minimize = minimize  # True -> lower loss is better
        self.best_score = None
        self.epoch_since_best = 0
        self._training_break = False

    def on_epoch_end(self, *args, **kwargs):
        self.model.eval()
        val_loss = self.get_validation_loss()
        loss_sum = val_loss['sum']
        # Losses arrive as torch tensors; extract the scalar.
        loss_sum = loss_sum.data.cpu().numpy()[0]

        self.model.train()

        # BUG FIX: the original `if not self.best_score:` also fired when the
        # best score was exactly 0.0 (falsy), re-initializing it every epoch
        # and defeating early stopping; compare against None explicitly.
        if self.best_score is None:
            self.best_score = loss_sum

        if (self.minimize and loss_sum < self.best_score) or (not self.minimize and loss_sum > self.best_score):
            self.best_score = loss_sum
            self.epoch_since_best = 0
        else:
            self.epoch_since_best += 1

        if self.epoch_since_best > self.patience:
            self._training_break = True
        self.epoch_id += 1

    def training_break(self, *args, **kwargs):
        """Return True once early stopping has been triggered."""
        return self._training_break


class ExponentialLRScheduler(Callback):
def __init__(self, gamma, epoch_every=1, batch_every=None):
super().__init__()
Expand Down Expand Up @@ -256,50 +199,63 @@ def on_batch_end(self, *args, **kwargs):
self.batch_id += 1


class ModelCheckpoint(Callback):
def __init__(self, filepath, epoch_every=1, minimize=True):
class ExperimentTiming(Callback):
def __init__(self, epoch_every=None, batch_every=None):
super().__init__()
self.filepath = filepath
self.minimize = minimize
self.best_score = None

if epoch_every == 0:
self.epoch_every = False
else:
self.epoch_every = epoch_every
if batch_every == 0:
self.batch_every = False
else:
self.batch_every = batch_every
self.batch_start = None
self.epoch_start = None
self.current_sum = None
self.current_mean = None

def on_train_begin(self, *args, **kwargs):
self.epoch_id = 0
self.batch_id = 0
os.makedirs(os.path.dirname(self.filepath), exist_ok=True)

def on_epoch_end(self, *args, **kwargs):
if self.epoch_every and ((self.epoch_id % self.epoch_every) == 0):
self.model.eval()
val_loss = self.get_validation_loss()
loss_sum = val_loss['sum']
loss_sum = loss_sum.data.cpu().numpy()[0]

self.model.train()
logger.info('starting training...')

if self.best_score is None:
self.best_score = loss_sum
def on_train_end(self, *args, **kwargs):
logger.info('training finished')

if (self.minimize and loss_sum < self.best_score) or (not self.minimize and loss_sum > self.best_score) or (
self.epoch_id == 0):
self.best_score = loss_sum
save_model(self.model, self.filepath)
logger.info('epoch {0} model saved to {1}'.format(self.epoch_id, self.filepath))
def on_epoch_begin(self, *args, **kwargs):
if self.epoch_id > 0:
epoch_time = datetime.now() - self.epoch_start
if self.epoch_every:
if (self.epoch_id % self.epoch_every) == 0:
logger.info('epoch {0} time {1}'.format(self.epoch_id - 1, str(epoch_time)[:-7]))
self.epoch_start = datetime.now()
self.current_sum = timedelta()
self.current_mean = timedelta()
logger.info('epoch {0} ...'.format(self.epoch_id))

self.epoch_id += 1
def on_batch_begin(self, *args, **kwargs):
if self.batch_id > 0:
current_delta = datetime.now() - self.batch_start
self.current_sum += current_delta
self.current_mean = self.current_sum / self.batch_id
if self.batch_every:
if self.batch_id > 0 and (((self.batch_id - 1) % self.batch_every) == 0):
logger.info('epoch {0} average batch time: {1}'.format(self.epoch_id, str(self.current_mean)[:-5]))
if self.batch_every:
if self.batch_id == 0 or self.batch_id % self.batch_every == 0:
logger.info('epoch {0} batch {1} ...'.format(self.epoch_id, self.batch_id))
self.batch_start = datetime.now()


class NeptuneMonitor(Callback):
def __init__(self, model_name):
def __init__(self, image_nr, image_resize, model_name):
super().__init__()
self.model_name = model_name
self.ctx = neptune.Context()
self.epoch_loss_averager = Averager()
self.image_nr = image_nr
self.image_resize = image_resize

def on_train_begin(self, *args, **kwargs):
self.epoch_loss_averagers = {}
Expand Down Expand Up @@ -338,8 +294,8 @@ def _send_numeric_channels(self, *args, **kwargs):
self.ctx.channel_send('{} epoch_val {} loss'.format(self.model_name, name), x=self.epoch_id, y=loss)


class ExperimentTiming(Callback):
def __init__(self, epoch_every=None, batch_every=None):
class ValidationMonitor(Callback):
def __init__(self, data_dir, loader_mode, epoch_every=None, batch_every=None):
super().__init__()
if epoch_every == 0:
self.epoch_every = False
Expand All @@ -349,119 +305,7 @@ def __init__(self, epoch_every=None, batch_every=None):
self.batch_every = False
else:
self.batch_every = batch_every
self.batch_start = None
self.epoch_start = None
self.current_sum = None
self.current_mean = None

def on_train_begin(self, *args, **kwargs):
self.epoch_id = 0
self.batch_id = 0
logger.info('starting training...')

def on_train_end(self, *args, **kwargs):
logger.info('training finished')

def on_epoch_begin(self, *args, **kwargs):
if self.epoch_id > 0:
epoch_time = datetime.now() - self.epoch_start
if self.epoch_every:
if (self.epoch_id % self.epoch_every) == 0:
logger.info('epoch {0} time {1}'.format(self.epoch_id - 1, str(epoch_time)[:-7]))
self.epoch_start = datetime.now()
self.current_sum = timedelta()
self.current_mean = timedelta()
logger.info('epoch {0} ...'.format(self.epoch_id))

def on_batch_begin(self, *args, **kwargs):
if self.batch_id > 0:
current_delta = datetime.now() - self.batch_start
self.current_sum += current_delta
self.current_mean = self.current_sum / self.batch_id
if self.batch_every:
if self.batch_id > 0 and (((self.batch_id - 1) % self.batch_every) == 0):
logger.info('epoch {0} average batch time: {1}'.format(self.epoch_id, str(self.current_mean)[:-5]))
if self.batch_every:
if self.batch_id == 0 or self.batch_id % self.batch_every == 0:
logger.info('epoch {0} batch {1} ...'.format(self.epoch_id, self.batch_id))
self.batch_start = datetime.now()


class ReduceLROnPlateau(Callback):  # thank you keras
    """Placeholder for a Keras-style reduce-LR-on-plateau callback; not implemented."""

    def __init__(self):
        super().__init__()


class NeptuneMonitorSegmentation(NeptuneMonitor):
    """NeptuneMonitor extension for segmentation models: in addition to numeric
    channels it can send side-by-side (prediction | ground truth) mask images.
    """

    def __init__(self, image_nr, image_resize, model_name):
        # image_nr: cut-off index — at most image_nr + 1 images are sent per output
        # image_resize: multiplicative resize factor applied before sending
        super().__init__(model_name)
        self.image_nr = image_nr
        self.image_resize = image_resize

    def on_epoch_end(self, *args, **kwargs):
        self._send_numeric_channels()
        # Image sending is intentionally disabled here (kept for manual debugging).
        # self._send_image_channels()
        self.epoch_id += 1

    def _send_image_channels(self):
        """Send, per output name, images of prediction and ground-truth masks glued
        side by side to the corresponding Neptune image channel.
        """
        self.model.eval()
        pred_masks = self.get_prediction_masks()
        self.model.train()

        for name, pred_mask in pred_masks.items():
            for i, image_duplet in enumerate(pred_mask):
                # image_duplet is (2, h, w): [0] = prediction, [1] = ground truth.
                h, w = image_duplet.shape[1:]
                # Glue the two masks horizontally with a 10-pixel black gap.
                image_glued = np.zeros((h, 2 * w + 10))

                image_glued[:, :w] = image_duplet[0, :, :]
                image_glued[:, (w + 10):] = image_duplet[1, :, :]

                # Masks are assumed to be in [0, 1] — scaled to 8-bit grayscale here.
                pill_image = Image.fromarray((image_glued * 255.).astype(np.uint8))
                h_, w_ = image_glued.shape
                pill_image = pill_image.resize((int(self.image_resize * w_), int(self.image_resize * h_)),
                                               Image.ANTIALIAS)

                self.ctx.channel_send('{} {}'.format(self.model_name, name), neptune.Image(
                    name='epoch{}_batch{}_idx{}'.format(self.epoch_id, self.batch_id, i),
                    description="true and prediction masks",
                    data=pill_image))

                # NOTE(review): this sends image_nr + 1 images (break happens after
                # index == image_nr is sent) — confirm that is intended.
                if i == self.image_nr:
                    break

    def get_prediction_masks(self):
        """Run the model on the first validation batch only and return, per output
        name, an array of stacked (prediction, ground_truth) mask pairs.

        Raises ValueError when the batch does not carry one target per output.
        """
        prediction_masks = {}
        batch_gen, steps = self.validation_datagen
        for batch_id, data in enumerate(batch_gen):
            if len(data) != len(self.output_names) + 1:
                raise ValueError('incorrect targets provided')
            X = data[0]
            targets_tensors = data[1:]

            if torch.cuda.is_available():
                X = Variable(X).cuda()
            else:
                X = Variable(X)

            outputs_batch = self.model(X)
            # Multi-output model: pair each output with its own target; otherwise
            # the single output tensor is paired with every target name.
            if len(outputs_batch) == len(self.output_names):
                for name, output, target in zip(self.output_names, outputs_batch, targets_tensors):
                    # `sigmoid` is a module-level helper defined elsewhere in this file.
                    prediction = sigmoid(np.squeeze(output.data.cpu().numpy(), axis=1))
                    ground_truth = np.squeeze(target.cpu().numpy(), axis=1)
                    prediction_masks[name] = np.stack([prediction, ground_truth], axis=1)
            else:
                for name, target in zip(self.output_names, targets_tensors):
                    prediction = sigmoid(np.squeeze(outputs_batch.data.cpu().numpy(), axis=1))
                    ground_truth = np.squeeze(target.cpu().numpy(), axis=1)
                    prediction_masks[name] = np.stack([prediction, ground_truth], axis=1)
            # Deliberate: only the first validation batch is inspected.
            break
        return prediction_masks


class ValidationMonitorSegmentation(ValidationMonitor):
def __init__(self, data_dir, loader_mode, *args, **kwargs):
super().__init__(*args, **kwargs)
self.data_dir = data_dir
self.validation_pipeline = postprocessing_pipeline_simplified
self.loader_mode = loader_mode
Expand All @@ -483,6 +327,16 @@ def set_params(self, transformer, validation_datagen, meta_valid=None, *args, **
def get_validation_loss(self):
    """Return the validation losses computed by the `_get_validation_loss` helper."""
    losses = self._get_validation_loss()
    return losses

def on_epoch_end(self, *args, **kwargs):
    """Every `epoch_every` epochs, run validation and log each named loss."""
    run_validation = bool(self.epoch_every) and self.epoch_id % self.epoch_every == 0
    if run_validation:
        # Evaluate in eval mode, then restore train mode.
        self.model.eval()
        losses = self.get_validation_loss()
        self.model.train()
        for name, loss in losses.items():
            scalar = loss.data.cpu().numpy()[0]
            logger.info('epoch {0} validation {1}: {2:.5f}'.format(self.epoch_id, name, scalar))
    self.epoch_id += 1

def _get_validation_loss(self):
output, epoch_loss = self._transform()
y_pred = self._generate_prediction(output)
Expand Down Expand Up @@ -565,11 +419,24 @@ def _generate_prediction(self, outputs):
return y_pred


class ModelCheckpointSegmentation(ModelCheckpoint):
def __init__(self, metric_name='sum', *args, **kwargs):
super().__init__(*args, **kwargs)
class ModelCheckpoint(Callback):
def __init__(self, filepath, metric_name='sum', epoch_every=1, minimize=True):
    """Checkpoint callback: persists the model to `filepath` when the monitored
    validation metric improves.

    filepath: destination file for the persisted model (parent dirs are
        created in on_train_begin).
    metric_name: key of the validation-loss dict to monitor.
    epoch_every: check every N epochs; 0 disables checkpointing entirely.
    minimize: True if a lower metric value is better.
    """
    # BUG FIX: super().__init__() was missing here, unlike every other
    # Callback subclass in this file, so base-class state was never set up.
    super().__init__()
    self.filepath = filepath
    self.minimize = minimize
    self.best_score = None

    # epoch_every == 0 explicitly disables checkpointing.
    if epoch_every == 0:
        self.epoch_every = False
    else:
        self.epoch_every = epoch_every

    self.metric_name = metric_name

def on_train_begin(self, *args, **kwargs):
    """Reset epoch/batch counters and ensure the checkpoint directory exists."""
    self.epoch_id = 0
    self.batch_id = 0
    checkpoint_dir = os.path.dirname(self.filepath)
    os.makedirs(checkpoint_dir, exist_ok=True)

def on_epoch_end(self, *args, **kwargs):
if self.epoch_every and ((self.epoch_id % self.epoch_every) == 0):
self.model.eval()
Expand All @@ -583,19 +450,27 @@ def on_epoch_end(self, *args, **kwargs):
self.best_score = loss_sum

if (self.minimize and loss_sum < self.best_score) or (not self.minimize and loss_sum > self.best_score) or (
self.epoch_id == 0):
self.epoch_id == 0):
self.best_score = loss_sum
persist_torch_model(self.model, self.filepath)
logger.info('epoch {0} model saved to {1}'.format(self.epoch_id, self.filepath))

self.epoch_id += 1


class EarlyStoppingSegmentation(EarlyStopping):
def __init__(self, metric_name='sum', *args, **kwargs):
super().__init__(*args, **kwargs)
class EarlyStopping(Callback):
def __init__(self, metric_name='sum', patience=1000, minimize=True):
    """Track a validation metric and flag a training break once it has not
    improved for `patience` epochs.
    """
    super().__init__()
    self.metric_name = metric_name   # key of the validation-loss dict to monitor
    self.patience = patience         # epochs without improvement tolerated
    self.minimize = minimize         # True -> lower metric is better
    self.best_score = None
    self.epoch_since_best = 0
    self._training_break = False

def training_break(self, *args, **kwargs):
    """Return True once early stopping has been triggered."""
    should_break = self._training_break
    return should_break

def on_epoch_end(self, *args, **kwargs):
self.model.eval()
val_loss = self.get_validation_loss()
Expand All @@ -619,7 +494,7 @@ def on_epoch_end(self, *args, **kwargs):


def postprocessing_pipeline_simplified(cache_dirpath, loader_mode):
if loader_mode == 'crop_and_pad':
if loader_mode == 'resize_and_pad':
size_adjustment_function = partial(crop_image, target_size=ORIGINAL_SIZE)
elif loader_mode == 'resize':
size_adjustment_function = partial(resize_image, target_size=ORIGINAL_SIZE)
Expand Down
Loading

0 comments on commit 17c8eec

Please sign in to comment.