diff --git a/setup.py b/setup.py
index 113f64b8..e390d1be 100644
--- a/setup.py
+++ b/setup.py
@@ -4,11 +4,12 @@
 with open('README.md', 'r') as f:
     long_description = f.read()
 
-
+description = 'A Modular, Configuration-Driven Framework for Knowledge Distillation. ' \
+              'Trained models, training logs and configurations are available for ensuring the reproducibility.'
 setup(
     name='torchdistill',
-    version='0.0.1',
-    description='A unified knowledge distillation framework.',
+    version='0.0.2',
+    description=description,
     long_description=long_description,
     long_description_content_type='text/markdown',
     url='https://github.com/yoshitomo-matsubara/torchdistill',
diff --git a/torchdistill/core/distillation.py b/torchdistill/core/distillation.py
index e386d8d2..d20179dc 100644
--- a/torchdistill/core/distillation.py
+++ b/torchdistill/core/distillation.py
@@ -11,7 +11,7 @@
 from torchdistill.common.module_util import check_if_wrapped, freeze_module_params, get_module, unfreeze_module_params
 from torchdistill.core.forward_proc import get_forward_proc_func
 from torchdistill.core.util import set_hooks, wrap_model, change_device, tensor2numpy2tensor, extract_io_dict, \
-    extract_sub_model_output_dict
+    update_io_dict, extract_sub_model_output_dict
 from torchdistill.datasets.util import build_data_loaders
 from torchdistill.losses.custom import get_custom_loss
 from torchdistill.losses.single import KDLoss, get_single_loss
@@ -227,10 +227,11 @@ def get_teacher_output(self, sample_batch, targets, supp_dict):
         # Deep copy of teacher info dict if teacher special module contains trainable module(s)
         teacher_io_dict4cache = copy.deepcopy(self.teacher_io_dict) \
             if self.teacher_updatable and isinstance(cache_file_paths, (list, tuple)) is not None else None
+        extracted_teacher_io_dict = extract_io_dict(self.teacher_io_dict, self.device)
         if isinstance(self.teacher_model, SpecialModule):
-            self.teacher_model.post_forward(self.teacher_io_dict)
+            self.teacher_model.post_forward(extracted_teacher_io_dict)
 
-        extracted_teacher_io_dict = extract_io_dict(self.teacher_io_dict, self.device)
+        update_io_dict(extracted_teacher_io_dict, extract_io_dict(self.teacher_io_dict, self.device))
         # Write cache files if output file paths (cache_file_paths) are given
         if isinstance(cache_file_paths, (list, tuple)):
             if teacher_io_dict4cache is None:
@@ -249,13 +250,15 @@ def forward(self, sample_batch, targets, supp_dict):
         teacher_outputs, extracted_teacher_io_dict =\
             self.get_teacher_output(sample_batch, targets, supp_dict=supp_dict)
         student_outputs = self.student_forward_proc(self.student_model, sample_batch, targets, supp_dict)
+        extracted_student_io_dict = extract_io_dict(self.student_io_dict, self.device)
         if isinstance(self.student_model, SpecialModule):
-            self.student_model.post_forward(self.student_io_dict)
+            self.student_model.post_forward(extracted_student_io_dict)
 
         org_loss_dict = self.extract_org_loss(self.org_criterion, student_outputs, teacher_outputs, targets,
                                               uses_teacher_output=self.uses_teacher_output, supp_dict=supp_dict)
+        update_io_dict(extracted_student_io_dict, extract_io_dict(self.student_io_dict, self.device))
         output_dict = {'teacher': extracted_teacher_io_dict,
-                       'student': extract_io_dict(self.student_io_dict, self.device)}
+                       'student': extracted_student_io_dict}
         total_loss = self.criterion(output_dict, org_loss_dict, targets)
         return total_loss
 
diff --git a/torchdistill/core/forward_hook.py b/torchdistill/core/forward_hook.py
index dc558db2..3aa8b4f1 100644
--- a/torchdistill/core/forward_hook.py
+++ b/torchdistill/core/forward_hook.py
@@ -112,7 +112,7 @@ def pop_io_dict(self):
             for io_type in list(module_io_dict.keys()):
                 sub_dict = module_io_dict.pop(io_type)
                 values = [sub_dict[key] for key in sorted(sub_dict.keys())]
-                gathered_obj = gather(values, self.target_device) if self.uses_cuda else values[-1]
+                gathered_obj = gather(values, self.target_device) if self.uses_cuda and len(values) > 1 else values[-1]
                 gathered_io_dict[module_path][io_type] = gathered_obj
         return gathered_io_dict
 
diff --git a/torchdistill/core/util.py b/torchdistill/core/util.py
index edf359bb..bacaaa2c 100644
--- a/torchdistill/core/util.py
+++ b/torchdistill/core/util.py
@@ -85,18 +85,25 @@ def tensor2numpy2tensor(data, device):
 
 
 def extract_io_dict(model_io_dict, target_device):
-    uses_cuda = target_device.type.startswith('cuda')
+    uses_cuda = target_device.type == 'cuda'
     gathered_io_dict = dict()
     for module_path, module_io_dict in model_io_dict.items():
         gathered_io_dict[module_path] = dict()
         for io_type in list(module_io_dict.keys()):
             sub_dict = module_io_dict.pop(io_type)
             values = [sub_dict[key] for key in sorted(sub_dict.keys())]
-            gathered_obj = gather(values, target_device) if uses_cuda else values[-1]
+            gathered_obj = gather(values, target_device) if uses_cuda and len(values) > 1 else values[-1]
             gathered_io_dict[module_path][io_type] = gathered_obj
     return gathered_io_dict
 
 
+def update_io_dict(main_io_dict, new_io_dict):
+    for key, module_io_dict in new_io_dict.items():
+        for io_type, value in module_io_dict.items():
+            if len(value) > 0:
+                main_io_dict[key][io_type] = value
+
+
 def extract_sub_model_output_dict(model_output_dict, index):
     sub_model_output_dict = dict()
     for module_path, sub_model_io_dict in model_output_dict.items():
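
Note (not part of the diff): the sketch below only illustrates what the new update_io_dict helper in torchdistill/core/util.py does. The dictionaries are toy stand-ins for the hook I/O dicts keyed by module path and io type ('encoder.layer1' is a hypothetical module path; real values are tensors gathered by extract_io_dict). The point is that get_teacher_output and forward can call extract_io_dict before post_forward and merge again afterwards without an empty entry overwriting values already captured.

# Standalone sketch of the update_io_dict helper added in torchdistill/core/util.py.
def update_io_dict(main_io_dict, new_io_dict):
    # Copy only non-empty entries, so values extracted before post_forward are preserved.
    for key, module_io_dict in new_io_dict.items():
        for io_type, value in module_io_dict.items():
            if len(value) > 0:
                main_io_dict[key][io_type] = value


extracted = {'encoder.layer1': {'output': [0.1, 0.2]}}  # captured before post_forward
latest = {'encoder.layer1': {'output': []}}             # hooks recorded nothing afterwards
update_io_dict(extracted, latest)
assert extracted['encoder.layer1']['output'] == [0.1, 0.2]  # empty entry did not overwrite it

The related len(values) > 1 guards in extract_io_dict and pop_io_dict simply skip the multi-device gather when only a single value was recorded, e.g. in single-GPU runs outside DataParallel.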