added student adapter (#820)
Co-authored-by: Ofri Masad <ofrimasad@users.noreply.github.com>
shaydeci and ofrimasad authored Apr 13, 2023

1 parent bcdc408 commit 6b762b3
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions src/super_gradients/training/models/kd_modules/kd_module.py
@@ -24,9 +24,10 @@ class implementing Knowledge Distillation logic as an SgModule
     run_teacher_on_eval: bool- whether to run self.teacher at eval mode regardless of self.train(mode)
     arch_params: HpmStruct- Architecture H.P.
-                            Additionally, by passing teacher_input_adapter (torch.nn.Module) one can modify the teacher net s.t
-                            teacher = torch.nn.Sequential(teacher_input_adapter, teacher). This is useful when teacher net expects a
-                            different input format from the student (for example different normalization).
+                            Additionally, by passing teacher_input_adapter (torch.nn.Module) one can modify the teacher net to act as if
+                            teacher = torch.nn.Sequential(teacher_input_adapter, teacher). This is useful when teacher net expects a
+                            different input format from the student (for example different normalization).
+                            Equivalent arg for the student model, can be passed through student_input_adapter.
     """

@@ -36,6 +37,7 @@ def __init__(self, arch_params: HpmStruct, student: SgModule, teacher: torch.nn.
         self.student = student
         self.teacher = teacher
         self.teacher_input_adapter = get_param(self.arch_params, "teacher_input_adapter")
+        self.student_input_adapter = get_param(self.arch_params, "student_input_adapter")
         self.run_teacher_on_eval = run_teacher_on_eval
         self._freeze_teacher()
@@ -62,10 +64,17 @@ def eval(self):
         self.teacher.eval()
 
     def forward(self, x):
+        if self.student_input_adapter is not None:
+            student_output = self.student(self.student_input_adapter(x))
+        else:
+            student_output = self.student(x)
+
         if self.teacher_input_adapter is not None:
-            return KDOutput(student_output=self.student(x), teacher_output=self.teacher(self.teacher_input_adapter(x)))
+            teacher_output = self.teacher(self.teacher_input_adapter(x))
         else:
-            return KDOutput(student_output=self.student(x), teacher_output=self.teacher(x))
+            teacher_output = self.teacher(x)
+
+        return KDOutput(student_output=student_output, teacher_output=teacher_output)
 
     def initialize_param_groups(self, lr: float, training_params: HpmStruct) -> list:
         return self.student.initialize_param_groups(lr, training_params)
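To illustrate the new argument, here is a minimal usage sketch. It is not part of the commit; it only follows the __init__ shown in the diff, which reads both adapters from arch_params. The NormalizeAdapter module and the resnet18/resnet50 student/teacher pairing are illustrative assumptions.

    import torch
    from super_gradients.training import models
    from super_gradients.training.utils import HpmStruct
    from super_gradients.training.models.kd_modules.kd_module import KDModule

    # Illustrative adapter (an assumption, not from the commit): re-normalizes
    # inputs so the student sees a different normalization than the teacher,
    # the use case named in the docstring above.
    class NormalizeAdapter(torch.nn.Module):
        def __init__(self, mean, std):
            super().__init__()
            self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
            self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

        def forward(self, x):
            return (x - self.mean) / self.std

    # Student/teacher choices are arbitrary stand-ins.
    student = models.get("resnet18", num_classes=10)
    teacher = models.get("resnet50", num_classes=10)

    kd_net = KDModule(
        arch_params=HpmStruct(
            student_input_adapter=NormalizeAdapter(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ),
        student=student,
        teacher=teacher,
    )

    # forward() now adapts the student input and returns both outputs.
    out = kd_net(torch.rand(2, 3, 224, 224))  # KDOutput(student_output=..., teacher_output=...)

A teacher-side adapter is wired up the same way via teacher_input_adapter, which predates this commit.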
