-
Notifications
You must be signed in to change notification settings - Fork 230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce Knowledge Distillation Base #432
Conversation
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com> Set default `chunk_size` to `1024` Signed-off-by: Austin Liu <austin362667@gmail.com> Rebase Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Sorry for the misleading question and late response. Passing What I questioned in the last comment was about the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
Thanks @Tcc0403 and @hongpeng-guo . I think we can merge this PR first? To unblock other distill loss impl. I have some follow-ups to iterate on in my mind:
@ByronHsu WDYT? |
@@ -243,6 +244,7 @@ def _compute_loss( | |||
hard_loss /= full_target.shape[0] | |||
|
|||
soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature) | |||
soft_loss /= (full_target.shape[0] // student_input_chunk.shape[0]) | |||
soft_loss /= full_target.shape[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@shivam15s Could you help me understand why this normalization term was modified? 😀
Summary
Recreate #417 from the main repo.
Thanks to the nice suggestions from @Tcc0403 and @hongpeng-guo. This PR is the first split of #408, focusing solely on introducing the Knowledge Distillation base class. As a result, this PR does not include any tests at the moment.
Code Changes
Refactor
beta
into two weights:weight_hard_loss
andweight_soft_loss
, as coefficients betweenhard_loss
andsoft_loss
. @Tcc0403 also pointed out that we could usetorch.lerp
if applicable.Pass
teacher_logits
andstudent_logits
directly to the divergence loss function. This avoids redundant computations of converting logits to log probabilities and then reverting them to raw logits. However note that we are not reusing thestudent_log_probs
value calculated duringce_loss
in distillation base.get_batch_logps
intest/utils.py
.Modify
chunking
dimensions fromB
toB * T
. Thanks to @hongpeng-guo's great advice.Normalize the
distillation_loss
using(full_target != ignore_index).sum()
.TODO
Although a slightly slowdown is reasonable, we need to investigate why this PR's implementation is significantly slower compared to the naive approach. Thanks to @Tcc0403 's clarification.
The issue arises because we are not properly configuring the
chunk_size
for theB * T
dimension, which is extremely large (a few thousand). The previous default of 1 results in an excessive number of chunks.In contrast, this problem does not occur with the preference loss, as chunking is performed on the
B
dimension. This produces fewer than 10 chunks, which is efficient and works as expected.In conclusion, I set
chunk_size
to1024
works pretty well in new benchmark results as shown in Add JSD Loss for Distillation #425Introduce Knowledge Distillation Base #417 (comment)
Knowledge Distillation
Knowledge Distillation (KD; Hinton et al. 2015, Gou et al. 2020) is a straightforward way to build a smaller, cheaper model (“student model”) to speed up inference by transferring skills from a pre-trained expensive model (“teacher model”) into the student.
In knowledge distillation, a student model is trained to replicate the outputs of a teacher model using a distillation loss. Neural networks typically include a softmax layer; for instance, a large language model produces a probability distribution over tokens. Let
z_t
andz_s
represent the logits before the softmax layer for the teacher and student models, respectively. The distillation loss reduces the discrepancy between the two softmax outputs at a high temperatureT
. When ground truth labelsy
are available, this approach can be combined with a supervised learning objective, such as cross-entropy, to compare the student’s outputs with the ground truth.The combined loss function is defined as:
Here, we directly pass in
logits
rather thanlogpbs
. @Tcc0403Shared
DistillationBase
To support various distillation learning objectives, this PR aims to add a
LigerFusedLinearDistillationBase
which is basically same as propose by @hongpeng-guo within this discussion #371 (comment). Thank you @hongpeng-guo for thinking through this.Testing Done
I'll post JSD tests and benchmarks results in next PR: #425
make test
to ensure correctnessmake checkstyle
to ensure code stylemake test-convergence
to ensure convergence