Introduce Knowledge Distillation Base #432

Merged: 11 commits merged into main on Dec 9, 2024
Conversation

@austin362667 (Collaborator) commented Dec 7, 2024

Summary

Recreate #417 from the main repo.

Thanks to the nice suggestions from @Tcc0403 and @hongpeng-guo. This PR is the first split of #408, focusing solely on introducing the Knowledge Distillation base class. As a result, this PR does not include any tests at the moment.

Code Changes

  1. Refactor beta into two weights, weight_hard_loss and weight_soft_loss, used as the coefficients for hard_loss and soft_loss. @Tcc0403 also pointed out that we could use torch.lerp if applicable (see the sketch after this list).

  2. Pass teacher_logits and student_logits directly to the divergence loss function. This avoids the redundant computation of converting logits to log probabilities and then reverting them to raw logits. Note, however, that we are not reusing the student_log_probs value calculated during ce_loss in the distillation base.

    1. Remove the unnecessary get_batch_logps in test/utils.py.
  3. Modify chunking dimensions from B to B * T. Thanks to @hongpeng-guo's great advice.

    1. Fix the loss calculation to use per-token values instead of averaging across the sequence length dimension.
  4. Normalize the distillation_loss using (full_target != ignore_index).sum().
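
As a quick illustration of item 1 above, here is a minimal sketch (not the PR's actual code; the values are made up) of how the two weights combine the losses, and how torch.lerp expresses the same combination when the weights sum to 1:

```python
import torch

# Hypothetical values; the actual defaults live in the PR's base class.
weight_hard_loss = 0.5
weight_soft_loss = 0.5

hard_loss = torch.tensor(2.3)  # e.g. cross-entropy against ground-truth labels
soft_loss = torch.tensor(1.7)  # e.g. divergence between student and teacher logits

# Weighted sum of the two terms:
loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss

# When the two weights sum to 1, the same combination can be written with torch.lerp,
# which interpolates from hard_loss to soft_loss by weight_soft_loss:
loss_lerp = torch.lerp(hard_loss, soft_loss, weight_soft_loss)
assert torch.allclose(loss, loss_lerp)
```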

TODO

  1. Although a slight slowdown is reasonable, we need to investigate why this PR's implementation is significantly slower than the naive approach. Thanks to @Tcc0403's clarification.

    The issue arises because we are not properly configuring the chunk_size for the B * T dimension, which is extremely large (a few thousand). The previous default of 1 results in an excessive number of chunks.

    In contrast, this problem does not occur with the preference loss, as chunking is performed on the B dimension. This produces fewer than 10 chunks, which is efficient and works as expected.

    In conclusion, setting chunk_size to 1024 works pretty well in the new benchmark results, as shown in Add JSD Loss for Distillation #425 (see the chunk-size sketch after this list).

  2. Introduce Knowledge Distillation Base #417 (comment)
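
To make the chunk-size point in item 1 concrete, here is a toy sketch (shapes are made up and unrelated to any specific model) of how many chunks the flattened B * T dimension yields with chunk_size=1024 versus the previous default of 1:

```python
import torch

# Toy shapes; real values depend on the model and batch.
B, T, H = 4, 2048, 4096           # batch size, sequence length, hidden size
chunk_size = 1024                 # the new default proposed here (vs. the previous default of 1)

hidden = torch.randn(B * T, H)    # chunking happens on the flattened B * T dimension

chunks = torch.split(hidden, chunk_size, dim=0)
print(len(chunks))                # 8 chunks at chunk_size=1024, versus 8192 at chunk_size=1
```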

Knowledge Distillation

Knowledge Distillation (KD; Hinton et al. 2015, Gou et al. 2020) is a straightforward way to build a smaller, cheaper model (“student model”) to speed up inference by transferring skills from a pre-trained expensive model (“teacher model”) into the student.

In knowledge distillation, a student model is trained to replicate the outputs of a teacher model using a distillation loss. Neural networks typically include a softmax layer; for instance, a large language model produces a probability distribution over tokens. Let z_t and z_s represent the logits before the softmax layer for the teacher and student models, respectively. The distillation loss reduces the discrepancy between the two softmax outputs at a high temperature T. When ground truth labels y are available, this approach can be combined with a supervised learning objective, such as cross-entropy, to compare the student’s outputs with the ground truth.

The combined loss function is defined as:

$$\mathcal{L}_{\text{knowledge distillation}} = w_{\text{soft}} \cdot \mathcal{L}_{\text{distill}}(\mathbf{z_t}, \mathbf{z_s}, T) + w_{\text{hard}} \cdot \mathcal{L}_{\text{cross entropy}}(\mathbf{y}, \mathbf{z_s}).$$

Note that here we directly pass in logits rather than log-probabilities (logprobs). @Tcc0403
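
For reference, the combined loss above can be sketched in plain PyTorch as follows. This is only an illustration of the formula, not the fused implementation in this PR; the function name, signature, and defaults are hypothetical:

```python
import torch
import torch.nn.functional as F

def knowledge_distillation_loss(
    student_logits,      # (B * T, V)
    teacher_logits,      # (B * T, V)
    labels,              # (B * T,)
    weight_soft=0.5,
    weight_hard=0.5,
    temperature=2.0,
    ignore_index=-100,
):
    """Illustrative combined loss: soft divergence term at temperature T plus hard cross-entropy."""
    # Soft term: KL divergence between temperature-scaled teacher and student distributions,
    # computed in log-space to avoid underflow.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
    soft = F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
    # Normalize by the number of non-ignored target tokens, as this PR does.
    soft = soft / (labels != ignore_index).sum()

    # Hard term: standard cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(student_logits, labels, ignore_index=ignore_index)

    return weight_soft * soft + weight_hard * hard
```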

Shared DistillationBase

To support various distillation learning objectives, this PR adds a LigerFusedLinearDistillationBase, which is basically the same as the design proposed by @hongpeng-guo in this discussion: #371 (comment). Thank you @hongpeng-guo for thinking this through. A rough sketch of the pattern follows.
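
For readers unfamiliar with the fused-linear pattern, here is a simplified, unfused sketch of the idea. This is not the actual LigerFusedLinearDistillationBase; the function name and signature are hypothetical, and it only illustrates chunking the flattened B * T dimension, materializing student/teacher logits one chunk at a time, and plugging in an arbitrary divergence loss:

```python
import torch
import torch.nn.functional as F

def fused_linear_distillation_sketch(
    student_hidden,       # (B * T, H_s)
    student_weight,       # (V, H_s), the student's lm_head weight
    teacher_hidden,       # (B * T, H_t)
    teacher_weight,       # (V, H_t), the teacher's lm_head weight
    labels,               # (B * T,)
    distillation_loss_fn,          # pluggable divergence, e.g. a JSD or KL implementation
    weight_hard_loss=0.5,
    weight_soft_loss=0.5,
    temperature=1.0,
    chunk_size=1024,
    ignore_index=-100,
):
    """Simplified, unfused illustration of the pattern: never materialize the full
    (B * T, V) logits; process the flattened B * T dimension chunk by chunk."""
    n_tokens = (labels != ignore_index).sum()
    total = student_hidden.new_zeros(())
    for s_h, t_h, y in zip(
        student_hidden.split(chunk_size),
        teacher_hidden.split(chunk_size),
        labels.split(chunk_size),
    ):
        student_logits_chunk = s_h @ student_weight.t()   # (chunk, V)
        teacher_logits_chunk = t_h @ teacher_weight.t()   # (chunk, V)
        hard = F.cross_entropy(student_logits_chunk, y, ignore_index=ignore_index, reduction="sum")
        soft = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
        total = total + weight_hard_loss * hard + weight_soft_loss * soft
    return total / n_tokens
```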

Testing Done

I'll post JSD tests and benchmark results in the next PR: #425

  • Hardware Type: L40S
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

austin362667 and others added 11 commits December 7, 2024 00:03
Set default `chunk_size` to `1024`

Rebase

(Commits are signed off by Austin Liu <austin362667@gmail.com>.)
@Tcc0403 (Collaborator) commented Dec 7, 2024

Here, we directly pass in logits rather than logprobs.

Sorry for the misleading question and the late response. Passing logprobs is totally fine; it's actually better, since it avoids underflow issues by staying in log-space. Torch's KLDivLoss also expects inputs in log-space, and the extra computation going from softmax to log-softmax shouldn't be an issue anyway. So if most APIs expect logprobs as input, then I think that's the way to go.

What I questioned in the last comment was about the xxx_per_token_xxx things, but there are no problems with this PR now.
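
For context, a minimal example of the log-space convention mentioned above (assuming both student and teacher inputs are passed as log-probabilities):

```python
import torch
import torch.nn.functional as F

student_logits = torch.randn(8, 128)
teacher_logits = torch.randn(8, 128)

# torch.nn.KLDivLoss expects its input (first argument) in log-space; with log_target=True
# the target is in log-space too, which avoids underflow from tiny probabilities.
kl = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
loss = kl(
    F.log_softmax(student_logits, dim=-1),   # student log-probs
    F.log_softmax(teacher_logits, dim=-1),   # teacher log-probs
)
```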

@hongpeng-guo (Collaborator) left a comment

LGTM.

@austin362667 (Collaborator, Author) commented Dec 8, 2024

Thanks @Tcc0403 and @hongpeng-guo .

I think we can merge this PR first to unblock other distillation loss implementations.

I have some follow-ups in mind to iterate on:

  1. Refactor to use logprobs afterwards, to align with sglang/vllm interfaces. src
  2. Figure out temperature scaling in logprobs. src
  3. Pre-compute the teacher logits/logprobs offline beforehand, rather than having to keep the teacher model loaded during training. src

@ByronHsu WDYT?

@ByronHsu ByronHsu merged commit fcba35a into main Dec 9, 2024
4 of 6 checks passed
@ByronHsu ByronHsu deleted the austin362667/feat/distill_base branch December 9, 2024 06:08
@@ -243,6 +244,7 @@ def _compute_loss(
hard_loss /= full_target.shape[0]

soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
soft_loss /= (full_target.shape[0] // student_input_chunk.shape[0])
soft_loss /= full_target.shape[0]
@austin362667 (Collaborator, Author) commented on this diff:

@shivam15s Could you help me understand why this normalization term was modified? 😀
