
Commit

Set JSD loss default chunk_size to 2
Signed-off-by: Austin Liu <austin362667@gmail.com>
austin362667 authored and shivam15s committed Dec 17, 2024
1 parent 7b22ac7 commit 1d3b064
Showing 2 changed files with 6 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/liger_kernel/chunked_loss/jsd_loss.py
```diff
@@ -82,6 +82,7 @@ def forward(
             teacher_weight=teacher_weight,
             target=true_labels,
             loss_fn=LigerFusedLinearJSDFunction.distillation_loss_fn,
+            chunk_size=2,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             ignore_index=ignore_index,
```
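
For context: in the chunked-loss design, `chunk_size` controls (roughly) how many rows of the flattened `(B*T, H)` input go through the fused linear-plus-loss step at a time, so the full `(B*T, V)` logits for student and teacher are never materialized at once. Below is a minimal sketch of that idea, with a hypothetical `chunked_jsd_loss` helper and simplified JSD math; it is an illustration of the technique, not Liger Kernel's actual kernel.

```python
import torch
import torch.nn.functional as F


def chunked_jsd_loss(student_x, student_w, teacher_x, teacher_w, chunk_size=2):
    """Sketch: accumulate JSD over row chunks so only (chunk_size, V) logits
    are live at any moment. Hypothetical helper, for illustration only."""
    n_rows = student_x.shape[0]
    total = student_x.new_zeros(())
    for start in range(0, n_rows, chunk_size):
        s_logp = F.log_softmax(student_x[start:start + chunk_size] @ student_w.T, dim=-1)
        t_logp = F.log_softmax(teacher_x[start:start + chunk_size] @ teacher_w.T, dim=-1)
        # log of the mean distribution M = (P + Q) / 2, computed stably in log space
        m_logp = torch.logsumexp(torch.stack([s_logp, t_logp]), dim=0) - torch.log(
            torch.tensor(2.0, device=student_x.device)
        )
        # JSD(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M)
        total = total + 0.5 * (
            F.kl_div(m_logp, s_logp, log_target=True, reduction="sum")
            + F.kl_div(m_logp, t_logp, log_target=True, reduction="sum")
        )
    return total / n_rows
```

With a small `chunk_size` such as 2, the live logits buffers stay tiny (2 × V per model) at the cost of more matmul launches; larger values trade memory for fewer iterations.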
12 changes: 5 additions & 7 deletions test/chunked_loss/test_jsd_loss.py
```diff
@@ -229,14 +229,12 @@ def test_correctness(

     target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

-    with torch.autograd.detect_anomaly():
-        output1 = torch_lm_head_jsd(student_input1, teacher_input, target)
-        output2 = liger_lm_head_jsd(student_input2, teacher_input, target)
+    loss1 = torch_lm_head_jsd(student_input1, teacher_input, target)
+    loss2 = liger_lm_head_jsd(student_input2, teacher_input, target)
+    assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)

-    assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol)
-
-    output1.backward()
-    output2.backward()
+    loss1.backward()
+    loss2.backward()

     assert_verbose_allclose(
         student_input1.grad, student_input2.grad, atol=atol, rtol=rtol
```
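
The reworked test drops the `torch.autograd.detect_anomaly()` wrapper, a debugging aid with nontrivial overhead, and renames `output*` to `loss*` for clarity. The structure it follows is the standard forward/backward equivalence check between a reference and an optimized implementation; here is a self-contained sketch of that pattern with toy stand-ins for `torch_lm_head_jsd` and `liger_lm_head_jsd`.

```python
import torch


def reference_fn(x):            # stand-in for torch_lm_head_jsd
    return (x ** 2).mean()


def optimized_fn(x):            # stand-in for liger_lm_head_jsd
    return (x * x).mean()       # same math via a different code path


x1 = torch.randn(8, 4, requires_grad=True)
x2 = x1.detach().clone().requires_grad_(True)  # independent leaf, same data

loss1 = reference_fn(x1)
loss2 = optimized_fn(x2)
torch.testing.assert_close(loss1, loss2)       # forward agreement

loss1.backward()
loss2.backward()
torch.testing.assert_close(x1.grad, x2.grad)   # backward agreement
```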
