Introduce {pre,post}-{epoch,forward} processes and registries #274

Merged · 2 commits · Dec 28, 2022
90 changes: 86 additions & 4 deletions tests/registry_test.py
@@ -1,6 +1,8 @@
from unittest import TestCase

from torchdistill.core.registry import register_forward_proc_func, FORWARD_PROC_FUNC_DICT
from torchdistill.core.registry import register_forward_proc_func, get_forward_proc_func, register_pre_epoch_proc_func, \
register_pre_forward_proc_func, register_post_forward_proc_func, register_post_epoch_proc_func, \
get_pre_epoch_proc_func, get_pre_forward_proc_func, get_post_forward_proc_func, get_post_epoch_proc_func
from torchdistill.datasets.registry import register_dataset, register_collate_func, register_sample_loader_class, \
register_sample_loader_func, register_batch_sampler, register_transform, register_dataset_wrapper, \
DATASET_DICT, COLLATE_FUNC_DICT, SAMPLE_LOADER_CLASS_DICT, SAMPLE_LOADER_FUNC_DICT, BATCH_SAMPLER_DICT, \
@@ -50,20 +52,20 @@ def test_register_forward_proc_func(self):
def test_forward_proc0(model, batch):
return model(batch)

assert FORWARD_PROC_FUNC_DICT['test_forward_proc0'] == test_forward_proc0
assert get_forward_proc_func('test_forward_proc0') == test_forward_proc0

@register_forward_proc_func()
def test_forward_proc1(model, batch):
return model(batch)

assert FORWARD_PROC_FUNC_DICT['test_forward_proc1'] == test_forward_proc1
assert get_forward_proc_func('test_forward_proc1') == test_forward_proc1
random_name = 'custom_forward_proc_name2'

@register_forward_proc_func(key=random_name)
def test_forward_proc2(model, batch, label):
return model(batch, label)

assert FORWARD_PROC_FUNC_DICT[random_name] == test_forward_proc2
assert get_forward_proc_func(random_name) == test_forward_proc2

def test_register_collate_func(self):
@register_collate_func
@@ -419,3 +421,83 @@ def __init__(self):
self.name = 'test2'

assert SPECIAL_MODULE_DICT[random_name] == TestSpecialModule2

def test_register_pre_epoch_proc_func(self):
@register_pre_epoch_proc_func
def test_pre_epoch_proc_func0():
pass

assert get_pre_epoch_proc_func('test_pre_epoch_proc_func0') == test_pre_epoch_proc_func0

@register_pre_epoch_proc_func()
def test_pre_epoch_proc_func1():
pass

assert get_pre_epoch_proc_func('test_pre_epoch_proc_func1') == test_pre_epoch_proc_func1
random_name = 'custom_pre_epoch_proc_func_name2'

@register_pre_epoch_proc_func(key=random_name)
def test_pre_epoch_proc_func2():
pass

assert get_pre_epoch_proc_func(random_name) == test_pre_epoch_proc_func2

def test_register_pre_forward_proc_func(self):
@register_pre_forward_proc_func
def test_pre_forward_proc_func0():
pass

assert get_pre_forward_proc_func('test_pre_forward_proc_func0') == test_pre_forward_proc_func0

@register_pre_forward_proc_func()
def test_pre_forward_proc_func1():
pass

assert get_pre_forward_proc_func('test_pre_forward_proc_func1') == test_pre_forward_proc_func1
random_name = 'custom_pre_forward_proc_func_name2'

@register_pre_forward_proc_func(key=random_name)
def test_pre_forward_proc_func2():
pass

assert get_pre_forward_proc_func(random_name) == test_pre_forward_proc_func2

def test_register_post_forward_proc_func(self):
@register_post_forward_proc_func
def test_post_forward_proc_func0():
pass

assert get_post_forward_proc_func('test_post_forward_proc_func0') == test_post_forward_proc_func0

@register_post_forward_proc_func()
def test_post_forward_proc_func1():
pass

assert get_post_forward_proc_func('test_post_forward_proc_func1') == test_post_forward_proc_func1
random_name = 'custom_post_forward_proc_func_name2'

@register_post_forward_proc_func(key=random_name)
def test_post_forward_proc_func2():
pass

assert get_post_forward_proc_func(random_name) == test_post_forward_proc_func2

def test_register_post_epoch_proc_func(self):
@register_post_epoch_proc_func
def test_post_epoch_proc_func0():
pass

assert get_post_epoch_proc_func('test_post_epoch_proc_func0') == test_post_epoch_proc_func0

@register_post_epoch_proc_func()
def test_post_epoch_proc_func1():
pass

assert get_post_epoch_proc_func('test_post_epoch_proc_func1') == test_post_epoch_proc_func1
random_name = 'custom_post_epoch_proc_func_name2'

@register_post_epoch_proc_func(key=random_name)
def test_post_epoch_proc_func2():
pass

assert get_post_epoch_proc_func(random_name) == test_post_epoch_proc_func2
2 changes: 1 addition & 1 deletion torchdistill/core/__init__.py
@@ -1 +1 @@
from . import forward_proc
from . import pre_epoch_proc, pre_forward_proc, forward_proc, post_forward_proc, post_epoch_proc
29 changes: 28 additions & 1 deletion torchdistill/core/distillation.py
@@ -5,7 +5,12 @@
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR

from .registry import get_forward_proc_func
from .post_epoch_proc import default_post_epoch_process_with_teacher
from .post_forward_proc import default_post_forward_process
from .pre_epoch_proc import default_pre_epoch_process_with_teacher
from .pre_forward_proc import default_pre_forward_process
from .registry import get_pre_epoch_proc_func, get_pre_forward_proc_func, get_forward_proc_func, \
get_post_forward_proc_func, get_post_epoch_proc_func
from .util import set_hooks, wrap_model, change_device, tensor2numpy2tensor, clear_io_dict, \
extract_io_dict, update_io_dict, extract_sub_model_output_dict
from ..common.constant import SELF_MODULE_PATH, def_logger
@@ -86,6 +91,25 @@ def setup_loss(self, train_config):
logger.info(self.criterion)
self.extract_org_loss = get_func2extract_org_output(criterion_config.get('func2extract_org_loss', None))

def setup_pre_post_processes(self, train_config):
pre_epoch_process = default_pre_epoch_process_with_teacher
if 'pre_epoch_process' in train_config:
pre_epoch_process = get_pre_epoch_proc_func(train_config['pre_epoch_process'])
setattr(DistillationBox, 'pre_epoch_process', pre_epoch_process)
pre_forward_process = default_pre_forward_process
if 'pre_forward_process' in train_config:
pre_forward_process = get_pre_forward_proc_func(train_config['pre_forward_process'])
setattr(DistillationBox, 'pre_forward_process', pre_forward_process)
post_forward_process = default_post_forward_process
if 'post_forward_process' in train_config:
post_forward_process = get_post_forward_proc_func(train_config['post_forward_process'])

setattr(DistillationBox, 'post_forward_process', post_forward_process)
post_epoch_process = default_post_epoch_process_with_teacher
if 'post_epoch_process' in train_config:
post_epoch_process = get_post_epoch_proc_func(train_config['post_epoch_process'])
setattr(DistillationBox, 'post_epoch_process', post_epoch_process)

def setup(self, train_config):
# Set up train and val data loaders
self.setup_data_loaders(train_config)
@@ -182,6 +206,9 @@ def setup(self, train_config):
self.accelerator.prepare(self.student_model, self.optimizer,
self.train_data_loader, self.val_data_loader)

# Set up {pre,post}-{epoch,forward} processes
self.setup_pre_post_processes(train_config)


def __init__(self, teacher_model, student_model, dataset_dict,
train_config, device, device_ids, distributed, lr_factor, accelerator=None):
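For context, a minimal sketch (not part of the diff) of how the keys read in setup_pre_post_processes could appear in a train_config; the function names on the right are hypothetical and would have to be registered beforehand with the matching register_*_proc_func decorators:

    # Hypothetical train_config entries; the keys are the ones read by
    # setup_pre_post_processes above, the values are made-up registry keys.
    train_config = {
        'pre_epoch_process': 'my_pre_epoch_process',
        'post_forward_process': 'my_post_forward_process'
        # 'pre_forward_process' and 'post_epoch_process' are omitted here,
        # so the default_* functions are used instead.
    }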
46 changes: 46 additions & 0 deletions torchdistill/core/post_epoch_proc.py
@@ -0,0 +1,46 @@
from torch import distributed as dist
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR

from .registry import register_post_epoch_proc_func
from ..common.constant import def_logger
from ..models.special import SpecialModule

logger = def_logger.getChild(__name__)


@register_post_epoch_proc_func
def default_post_epoch_process_with_teacher(self, **kwargs):
# Epoch-wise scheduler step
if self.lr_scheduler is not None and self.scheduling_step <= 0:
if isinstance(self.lr_scheduler, ReduceLROnPlateau):
metrics = kwargs['metrics']
self.lr_scheduler.step(metrics)
elif isinstance(self.lr_scheduler, LambdaLR):
epoch = self.lr_scheduler.last_epoch + 1
self.lr_scheduler.step(epoch)
else:
self.lr_scheduler.step()
if isinstance(self.teacher_model, SpecialModule):
self.teacher_model.post_process()
if isinstance(self.student_model, SpecialModule):
self.student_model.post_process()
if self.distributed:
dist.barrier()


@register_post_epoch_proc_func
def default_post_epoch_process_without_teacher(self, **kwargs):
# Epoch-wise scheduler step
if self.lr_scheduler is not None and self.scheduling_step <= 0:
if isinstance(self.lr_scheduler, ReduceLROnPlateau):
metrics = kwargs['metrics']
self.lr_scheduler.step(metrics)
elif isinstance(self.lr_scheduler, LambdaLR):
epoch = self.lr_scheduler.last_epoch + 1
self.lr_scheduler.step(epoch)
else:
self.lr_scheduler.step()
if isinstance(self.model, SpecialModule):
self.model.post_process()
if self.distributed:
dist.barrier()
39 changes: 39 additions & 0 deletions torchdistill/core/post_forward_proc.py
@@ -0,0 +1,39 @@
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR

from .registry import register_post_forward_proc_func
from ..common.constant import def_logger

logger = def_logger.getChild(__name__)


@register_post_forward_proc_func
def default_post_forward_process(self, loss, **kwargs):
self.stage_grad_count += 1
if self.grad_accum_step > 1:
loss /= self.grad_accum_step

if self.accelerator is not None:
self.accelerator.backward(loss)
else:
loss.backward()

if self.stage_grad_count % self.grad_accum_step == 0:
if self.max_grad_norm is not None:
target_params = [p for group in self.optimizer.param_groups for p in group['params']]
torch.nn.utils.clip_grad_norm_(target_params, self.max_grad_norm)

self.optimizer.step()
self.optimizer.zero_grad()

# Step-wise scheduler step
if self.lr_scheduler is not None and self.scheduling_step > 0 \
and self.stage_grad_count % self.scheduling_step == 0:
if isinstance(self.lr_scheduler, ReduceLROnPlateau):
metrics = kwargs['metrics']
self.lr_scheduler.step(metrics)
elif isinstance(self.lr_scheduler, LambdaLR):
local_epoch = int(self.stage_grad_count / self.scheduling_step)
self.lr_scheduler.step(local_epoch)
else:
self.lr_scheduler.step()
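A note on how this interacts with post_epoch_proc.py above (an illustration, not part of the diff): the step-wise branch here fires only when scheduling_step > 0, while the epoch-wise branch in the default_post_epoch_process_* functions fires only when scheduling_step <= 0, so the scheduler is never stepped twice. With grad_accum_step = 4, for example, the loss is scaled by 1/4 on every call and optimizer.step() runs on every fourth call:

    # Sketch of the counting logic with a hypothetical accumulation step.
    grad_accum_step = 4
    for stage_grad_count in range(1, 9):
        if stage_grad_count % grad_accum_step == 0:
            # optimizer.step() / zero_grad() would run here, i.e. at counts 4 and 8
            pass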
23 changes: 23 additions & 0 deletions torchdistill/core/pre_epoch_proc.py
@@ -0,0 +1,23 @@
from .registry import register_pre_epoch_proc_func
from .util import clear_io_dict
from ..common.constant import def_logger

logger = def_logger.getChild(__name__)


@register_pre_epoch_proc_func
def default_pre_epoch_process_with_teacher(self, epoch, **kwargs):
clear_io_dict(self.teacher_io_dict)
clear_io_dict(self.student_io_dict)
self.teacher_model.eval()
self.student_model.train()
if self.distributed:
self.train_data_loader.batch_sampler.sampler.set_epoch(epoch)


@register_pre_epoch_proc_func
def default_pre_epoch_process_without_teacher(self, epoch, **kwargs):
clear_io_dict(self.model_io_dict)
self.model.train()
if self.distributed:
self.train_data_loader.batch_sampler.sampler.set_epoch(epoch)
9 changes: 9 additions & 0 deletions torchdistill/core/pre_forward_proc.py
@@ -0,0 +1,9 @@
from .registry import register_pre_forward_proc_func
from ..common.constant import def_logger

logger = def_logger.getChild(__name__)


@register_pre_forward_proc_func
def default_pre_forward_process(self, **kwargs):
pass
84 changes: 84 additions & 0 deletions torchdistill/core/registry.py
@@ -1,4 +1,36 @@
PRE_EPOCH_PROC_FUNC_DICT = dict()
PRE_FORWARD_PROC_FUNC_DICT = dict()
FORWARD_PROC_FUNC_DICT = dict()
POST_FORWARD_PROC_FUNC_DICT = dict()
POST_EPOCH_PROC_FUNC_DICT = dict()


def register_pre_epoch_proc_func(arg=None, **kwargs):
def _register_pre_epoch_proc_func(func):
key = kwargs.get('key')
if key is None:
key = func.__name__

PRE_EPOCH_PROC_FUNC_DICT[key] = func
return func

if callable(arg):
return _register_pre_epoch_proc_func(arg)
return _register_pre_epoch_proc_func


def register_pre_forward_proc_func(arg=None, **kwargs):
def _register_pre_forward_proc_func(func):
key = kwargs.get('key')
if key is None:
key = func.__name__

PRE_FORWARD_PROC_FUNC_DICT[key] = func
return func

if callable(arg):
return _register_pre_forward_proc_func(arg)
return _register_pre_forward_proc_func


def register_forward_proc_func(arg=None, **kwargs):
@@ -15,9 +47,61 @@ def _register_forward_proc_func(func):
return _register_forward_proc_func


def register_post_forward_proc_func(arg=None, **kwargs):
def _register_post_forward_proc_func(func):
key = kwargs.get('key')
if key is None:
key = func.__name__

POST_FORWARD_PROC_FUNC_DICT[key] = func
return func

if callable(arg):
return _register_post_forward_proc_func(arg)
return _register_post_forward_proc_func


def register_post_epoch_proc_func(arg=None, **kwargs):
def _register_post_epoch_proc_func(func):
key = kwargs.get('key')
if key is None:
key = func.__name__

POST_EPOCH_PROC_FUNC_DICT[key] = func
return func

if callable(arg):
return _register_post_epoch_proc_func(arg)
return _register_post_epoch_proc_func


def get_pre_epoch_proc_func(key):
if key in PRE_EPOCH_PROC_FUNC_DICT:
return PRE_EPOCH_PROC_FUNC_DICT[key]
raise ValueError('No pre-epoch process function `{}` registered'.format(key))


def get_pre_forward_proc_func(key):
if key in PRE_FORWARD_PROC_FUNC_DICT:
return PRE_FORWARD_PROC_FUNC_DICT[key]
raise ValueError('No pre-forward process function `{}` registered'.format(key))


def get_forward_proc_func(key):
if key is None:
return FORWARD_PROC_FUNC_DICT['forward_batch_only']
elif key in FORWARD_PROC_FUNC_DICT:
return FORWARD_PROC_FUNC_DICT[key]
raise ValueError('No forward process function `{}` registered'.format(key))


def get_post_forward_proc_func(key):
if key in POST_FORWARD_PROC_FUNC_DICT:
return POST_FORWARD_PROC_FUNC_DICT[key]
raise ValueError('No post-forward process function `{}` registered'.format(key))


def get_post_epoch_proc_func(key):
if key in POST_EPOCH_PROC_FUNC_DICT:
return POST_EPOCH_PROC_FUNC_DICT[key]
raise ValueError('No post-epoch process function `{}` registered'.format(key))
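
As a usage sketch (the function name below is made up for illustration), a process function can be registered with the bare decorator, with empty parentheses, or with an explicit key, and then looked up by name:

    from torchdistill.core.registry import register_post_epoch_proc_func, get_post_epoch_proc_func

    @register_post_epoch_proc_func(key='my_post_epoch_process')
    def my_post_epoch_process(self, **kwargs):
        # Hypothetical body: only step an epoch-wise scheduler.
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

    assert get_post_epoch_proc_func('my_post_epoch_process') is my_post_epoch_process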