Skip to content

Commit

Permalink
Merge pull request #58 from yoshitomo-matsubara/dev
Browse files Browse the repository at this point in the history
Support gradient accumulation, max gradient norm, misc, and fix typos
  • Loading branch information
yoshitomo-matsubara authored Feb 6, 2021
2 parents bf8d4f8 + 5252f94 commit 5fc2fa3
Show file tree
Hide file tree
Showing 45 changed files with 128 additions and 53 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ and the configurations reuse the hyperparameters such as number of epochs used i
Executable code can be found in [examples/](examples/) such as
- [Image classification](examples/image_classification.py): ImageNet (ILSVRC 2012), CIFAR-10, CIFAR-100, etc
- [Object detection](examples/object_detection.py): COCO 2017, etc
- [Semantic segmentation](examples/semantic_segmentation.py): COCO 2017, etc
- [Semantic segmentation](examples/semantic_segmentation.py): COCO 2017, PASCAL VOC, etc

## Google Colab Examples
### CIFAR-10 and CIFAR-100
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
2 changes: 1 addition & 1 deletion configs/sample/cifar10/ce/resnet110-final_run.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
2 changes: 1 addition & 1 deletion configs/sample/cifar10/ce/resnet1202-final_run.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
2 changes: 1 addition & 1 deletion configs/sample/cifar10/ce/resnet20-final_run.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
2 changes: 1 addition & 1 deletion configs/sample/cifar10/ce/resnet32-final_run.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
2 changes: 1 addition & 1 deletion configs/sample/cifar10/ce/resnet44-final_run.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
2 changes: 1 addition & 1 deletion configs/sample/cifar10/ce/resnet56-final_run.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
2 changes: 1 addition & 1 deletion configs/sample/cifar10/ce/wide_resnet16_8-final_run.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
2 changes: 1 addition & 1 deletion configs/sample/cifar10/ce/wide_resnet28_10-final_run.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
2 changes: 1 addition & 1 deletion configs/sample/cifar10/ce/wide_resnet40_4-final_run.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar10:
name: &dataset_name 'cifar10'
type: 'CIFAR10'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar100:
name: &dataset_name 'cifar100'
type: 'CIFAR100'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar100:
name: &dataset_name 'cifar100'
type: 'CIFAR100'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar100:
name: &dataset_name 'cifar100'
type: 'CIFAR100'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar100:
name: &dataset_name 'cifar100'
type: 'CIFAR100'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar100:
name: &dataset_name 'cifar100'
type: 'CIFAR100'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar100:
name: &dataset_name 'cifar100'
type: 'CIFAR100'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
2 changes: 1 addition & 1 deletion configs/sample/cifar100/ce/wide_resnet16_8-final_run.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar100:
name: &dataset_name 'cifar100'
type: 'CIFAR100'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar100:
name: &dataset_name 'cifar100'
type: 'CIFAR100'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
2 changes: 1 addition & 1 deletion configs/sample/cifar100/ce/wide_resnet28_10-final_run.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar100:
name: &dataset_name 'cifar100'
type: 'CIFAR100'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar100:
name: &dataset_name 'cifar100'
type: 'CIFAR100'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
2 changes: 1 addition & 1 deletion configs/sample/cifar100/ce/wide_resnet40_4-final_run.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar100:
name: &dataset_name 'cifar100'
type: 'CIFAR100'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datasets:
ilsvrc2012:
cifar100:
name: &dataset_name 'cifar100'
type: 'CIFAR100'
root: &root_dir !join ['./resource/dataset/', *dataset_name]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
'Trained models, training logs and configurations are available for ensuring the reproducibiliy.'
setup(
name='torchdistill',
version='0.1.2',
version='0.1.3',
description=description,
long_description=long_description,
long_description_content_type='text/markdown',
Expand Down
35 changes: 32 additions & 3 deletions torchdistill/core/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,18 @@ def setup(self, train_config):
trainable_module_list.append(self.teacher_model)

self.optimizer = get_optimizer(trainable_module_list, optim_config['type'], optim_params_config)
self.optimizer.zero_grad()
self.max_grad_norm = optim_config.get('max_grad_norm', None)
self.grad_accum_step = optim_config.get('grad_accum_step', 1)
optimizer_reset = True

scheduler_config = train_config.get('scheduler', None)
if scheduler_config is not None and len(scheduler_config) > 0:
self.lr_scheduler = get_scheduler(self.optimizer, scheduler_config['type'], scheduler_config['params'])
self.scheduling_step = scheduler_config.get('scheduling_step', 0)
elif optimizer_reset:
self.lr_scheduler = None
self.scheduling_step = None

# Set up apex if you require mixed-precision training
self.apex = False
Expand All @@ -168,13 +173,15 @@ def setup(self, train_config):
def __init__(self, teacher_model, student_model, dataset_dict,
train_config, device, device_ids, distributed, lr_factor):
super().__init__()
# Key attributes (should not be modified)
self.org_teacher_model = teacher_model
self.org_student_model = student_model
self.dataset_dict = dataset_dict
self.device = device
self.device_ids = device_ids
self.distributed = distributed
self.lr_factor = lr_factor
# Local attributes (can be updated at each stage)
self.teacher_model = None
self.student_model = None
self.teacher_forward_proc, self.student_forward_proc = None, None
Expand All @@ -183,6 +190,10 @@ def __init__(self, teacher_model, student_model, dataset_dict,
self.train_data_loader, self.val_data_loader, self.optimizer, self.lr_scheduler = None, None, None, None
self.org_criterion, self.criterion, self.uses_teacher_output, self.extract_org_loss = None, None, None, None
self.teacher_updatable, self.teacher_any_frozen, self.student_any_frozen = None, None, None
self.grad_accum_step = None
self.max_grad_norm = None
self.scheduling_step = 0
self.stage_grad_count = 0
self.apex = None
self.setup(train_config)
self.num_epochs = train_config['num_epochs']
Expand Down Expand Up @@ -265,16 +276,33 @@ def forward(self, sample_batch, targets, supp_dict):
return total_loss

def update_params(self, loss):
    """Backpropagate ``loss`` and step the optimizer once enough gradients are accumulated.

    Every call increments ``self.stage_grad_count``. Gradients are accumulated
    across calls and the optimizer only steps (followed by ``zero_grad``) on
    every ``self.grad_accum_step``-th call. When ``self.max_grad_norm`` is set,
    the accumulated gradients are clipped right before that step. A step-wise
    LR scheduler is advanced every ``self.scheduling_step`` calls when that
    value is positive.

    Args:
        loss: loss tensor to backpropagate. When ``self.grad_accum_step > 1``
            it is divided in place by ``self.grad_accum_step`` so the
            accumulated gradient matches the average over the accumulated calls.
    """
    self.stage_grad_count += 1
    if self.grad_accum_step > 1:
        loss /= self.grad_accum_step

    if self.apex:
        # `amp` comes from NVIDIA apex and is only touched when mixed precision is enabled
        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()

    # Gradients are intentionally NOT cleared at the top of this method:
    # they must survive across calls until the accumulation boundary below.
    if self.stage_grad_count % self.grad_accum_step == 0:
        if self.max_grad_norm is not None:
            # Optimizer param groups keep their tensors under the 'params' key
            target_params = amp.master_params(self.optimizer) if self.apex \
                else [p for group in self.optimizer.param_groups for p in group['params']]
            torch.nn.utils.clip_grad_norm_(target_params, self.max_grad_norm)

        self.optimizer.step()
        self.optimizer.zero_grad()

    # Step-wise scheduler step
    if self.lr_scheduler is not None and self.scheduling_step > 0 \
            and self.stage_grad_count % self.scheduling_step == 0:
        self.lr_scheduler.step()

def post_process(self, **kwargs):
if self.lr_scheduler is not None:
# Epoch-wise scheduler step
if self.lr_scheduler is not None and self.scheduling_step <= 0:
self.lr_scheduler.step()
if isinstance(self.teacher_model, SpecialModule):
self.teacher_model.post_process()
Expand Down Expand Up @@ -310,6 +338,7 @@ def __init__(self, teacher_model, student_model, data_loader_dict,

def advance_to_next_stage(self):
self.clean_modules()
self.stage_grad_count = 0
self.stage_number += 1
next_stage_config = self.train_config['stage{}'.format(self.stage_number)]
self.setup(next_stage_config)
Expand Down
Loading

0 comments on commit 5fc2fa3

Please sign in to comment.