Introduce Knowledge Distillation Base #432

Merged · 11 commits · Dec 9, 2024
250 changes: 250 additions & 0 deletions src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -0,0 +1,250 @@
from abc import abstractmethod
from functools import partial

import torch
from torch.nn import functional as F


class LigerFusedLinearDistillationBase(torch.autograd.Function):

@abstractmethod
def distillation_loss_fn(student_logits, teacher_logits, temperature):
"""
Compute distillation loss.
Args:
student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
temperature (float): Temperature used to soften the student and teacher distributions.
"""
raise NotImplementedError("Distillation loss function must be implemented.")

@staticmethod
def chunk_forward(
student_input_chunk,
student_weight,
teacher_input_chunk,
teacher_weight,
target_chunk,
student_bias=None,
teacher_bias=None,
ignore_index=-100,
compute_ce_loss=True,
):
# Student
student_logits_chunk = student_input_chunk @ student_weight.t()
if student_bias is not None:
student_logits_chunk += student_bias
student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1)

# Teacher
with torch.no_grad():
teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
if teacher_bias is not None:
teacher_logits_chunk += teacher_bias

# The hard/task loss
ce_loss = 0.0
if compute_ce_loss:
ce_loss = F.nll_loss(
student_log_probs_chunk.view(-1, student_log_probs_chunk.shape[-1]),
target_chunk.view(-1),
reduction="sum",
ignore_index=ignore_index,
)

return student_logits_chunk, teacher_logits_chunk, ce_loss

@staticmethod
def forward(
ctx,
student_input,
student_weight,
teacher_input,
teacher_weight,
target,
student_bias=None,
teacher_bias=None,
loss_fn=None,
chunk_size=1024,
ignore_index=-100,
weight_hard_loss=0.5,
weight_soft_loss=0.5,
compute_ce_loss=True,
temperature=1.0,
compiled=True,
**loss_kwargs,
):
"""
Base class for fused linear layer with distillation loss.
Only need to compute gradients for student model.

Args:
student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, student_hidden_size).
student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, student_hidden_size).
teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, teacher_hidden_size).
teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, teacher_hidden_size).
target (torch.Tensor): Target truth label tensor. Shape: (batch_size * seq_len).
student_bias (torch.Tensor, optional): Student bias tensor. Shape: (vocab_size,).
teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
chunk_size (int): Size of a chunk.
compute_ce_loss (bool): Whether to compute CE loss.
ignore_index (int): Index to ignore for loss computation.
weight_hard_loss (float): Weight for hard/task loss.
weight_soft_loss (float): Weight for soft/distillation loss.
temperature (float): Temperature used to soften the student and teacher logits in the distillation loss.
compiled (bool): Whether to use torch compile for chunk accumulation.
loss_kwargs (dict): Other possible arguments that the loss function might need.
"""
CHUNK_SIZE = chunk_size
grad_weight = torch.zeros_like(student_weight)
grad_inputs = []
grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
loss_acc = torch.zeros((), device=student_input.device)

loss_func_to_call = partial(
LigerFusedLinearDistillationBase._compute_loss,
distillation_loss_fn=loss_fn,
full_target=target,
ignore_index=ignore_index,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
compute_ce_loss=compute_ce_loss,
temperature=temperature,
**loss_kwargs,
)

def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
if student_bias is not None:
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
chunk_loss,
(
chunk_soft_loss,
chunk_hard_loss,
chunk_student_logits,
chunk_teacher_logits,
),
) = torch.func.grad_and_value(
loss_func_to_call, argnums=(0, 1, 5), has_aux=True
)(
student_input_chunk,
student_weight,
teacher_input_chunk,
teacher_weight,
target_chunk,
student_bias,
teacher_bias,
)
grad_bias.add_(chunk_grad_bias)
else:
(chunk_grad_input, chunk_grad_weight), (
chunk_loss,
(
chunk_soft_loss,
chunk_hard_loss,
chunk_student_logits,
chunk_teacher_logits,
),
) = torch.func.grad_and_value(
loss_func_to_call, argnums=(0, 1), has_aux=True
)(
student_input_chunk,
student_weight,
teacher_input_chunk,
teacher_weight,
target_chunk,
student_bias,
teacher_bias,
)
grad_weight.add_(chunk_grad_weight)
loss_acc.add_(chunk_loss)
return chunk_grad_input

if compiled:
accumulate_chunk = torch.compile(accumulate_chunk)

num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
_student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
_teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
_target_chunks = torch.chunk(target, chunks=num_chunks, dim=0)

for student_input_chunk, teacher_input_chunk, target_chunk in zip(
_student_input_chunks, _teacher_input_chunks, _target_chunks
):
grad_input = accumulate_chunk(
student_input_chunk, teacher_input_chunk, target_chunk
)
grad_inputs.append(grad_input)

ctx.save_for_backward(
torch.cat(grad_inputs, dim=0),
grad_weight,
grad_bias,
)
return loss_acc

@staticmethod
def backward(ctx, grad_output):
grad_input, grad_weight, grad_bias = ctx.saved_tensors
if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
grad_input = grad_input * grad_output
grad_weight = grad_weight * grad_output
grad_bias = grad_bias * grad_output if grad_bias is not None else None

return grad_input, grad_weight, None, grad_bias

@staticmethod
def _compute_loss(
student_input_chunk,
student_weight,
teacher_input_chunk,
teacher_weight,
target_chunk,
student_bias=None,
teacher_bias=None,
distillation_loss_fn=None,
full_target=None,
ignore_index=-100,
temperature=1.0,
weight_hard_loss=0.5,
weight_soft_loss=0.5,
compute_ce_loss=True,
**loss_kwargs,
):
"""
Compute the total loss for a chunk of input and target, using a knowledge distillation loss function.
Args:
distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
full_target (torch.Tensor): Full target tensor, used to normalize the chunk losses. Shape: (batch_size * seq_len,).
ignore_index (int): Index to ignore for loss computation.
temperature (float): Temperature used to soften the logits in the distillation loss.
weight_hard_loss (float): Weight for hard loss.
weight_soft_loss (float): Weight for soft loss.
compute_ce_loss (bool): Whether to compute CE loss.
loss_kwargs (dict): Additional arguments for the loss function.
"""
student_logits_chunk, teacher_logits_chunk, hard_loss = (
LigerFusedLinearDistillationBase.chunk_forward(
student_input_chunk,
student_weight,
teacher_input_chunk,
teacher_weight,
target_chunk,
student_bias=student_bias,
teacher_bias=teacher_bias,
ignore_index=ignore_index,
compute_ce_loss=compute_ce_loss,
)
)

hard_loss /= full_target.shape[0]

soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
soft_loss /= full_target.shape[0]
Review comment from the PR author on the normalization above:
@shivam15s Could you help me understand why this normalization term was modified? 😀

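As a reading of the code above (not a reply from the thread): with N = full_target.shape[0] tokens in the full batch and per-token losses l_i, each chunk contributes its summed loss divided by N, so accumulating over chunks yields the mean over the whole batch,

$$\sum_{\text{chunks } c} \frac{1}{N} \sum_{i \in c} \ell_i \;=\; \frac{1}{N} \sum_{i=1}^{N} \ell_i .$$
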
loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)
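
A concrete distillation loss only needs to supply distillation_loss_fn; chunking, the fused linear projection, the CE (hard) loss, and gradient accumulation all come from the base class above. The following is a minimal sketch of that extension pattern, not code from this PR: the class name ExampleKLDivDistillationFunction and the KL-divergence formulation are illustrative assumptions.

# Illustrative sketch only (not part of this PR): shows how a concrete loss
# could plug into LigerFusedLinearDistillationBase.
import torch
from torch.nn import functional as F

from liger_kernel.chunked_loss.fused_linear_distillation import (
    LigerFusedLinearDistillationBase,
)


class ExampleKLDivDistillationFunction(LigerFusedLinearDistillationBase):
    @staticmethod
    def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0):
        # Temperature-softened KL(teacher || student), summed over the chunk;
        # the base class divides by the full batch size afterwards.
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
        return F.kl_div(
            student_log_probs,
            teacher_log_probs,
            reduction="sum",
            log_target=True,
        )

    @staticmethod
    def forward(
        ctx,
        student_input,
        student_weight,
        teacher_input,
        teacher_weight,
        target,
        temperature=1.0,
    ):
        return LigerFusedLinearDistillationBase.forward(
            ctx,
            student_input,
            student_weight,
            teacher_input,
            teacher_weight,
            target,
            loss_fn=ExampleKLDivDistillationFunction.distillation_loss_fn,
            temperature=temperature,
        )

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, grad_weight, _, _ = LigerFusedLinearDistillationBase.backward(
            ctx, grad_output
        )
        # One gradient per forward input: only the student hidden states and
        # student lm_head weight receive gradients; the teacher tensors,
        # target, and temperature do not.
        return grad_input, grad_weight, None, None, None, None


# Example call (names are placeholders):
# loss = ExampleKLDivDistillationFunction.apply(
#     student_hidden, student_lm_head_weight,
#     teacher_hidden, teacher_lm_head_weight,
#     labels,
# )
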
109 changes: 109 additions & 0 deletions test/utils.py
@@ -520,3 +520,112 @@ def get_batch_loss_metrics(
policy_nll_loss,
)
return loss, (*return_vars, *aggregated_aux_outputs)


class HFDistillationLoss:
def __init__(
self,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
ignore_index: int = -100,
temperature: float = 1,
):
self.weight_hard_loss = weight_hard_loss
self.weight_soft_loss = weight_soft_loss
self.ignore_index = ignore_index
self.temperature = temperature

@abstractmethod
def distillation_loss(self, student_logits, teacher_logits):
"""Abstract method for computing distillation loss."""
pass

def concatenated_forward(
self,
student_input: torch.FloatTensor,
student_weight: torch.FloatTensor,
teacher_input: torch.FloatTensor,
teacher_weight: torch.FloatTensor,
target: torch.LongTensor,
student_bias: torch.FloatTensor = None,
teacher_bias: torch.FloatTensor = None,
) -> Tuple[
torch.FloatTensor,
torch.FloatTensor,
torch.FloatTensor,
torch.FloatTensor,
torch.FloatTensor,
]:
"""Compute forward pass for both student and teacher models."""

student_batch_seq_len_size, student_hidden_size = student_input.shape
student_input_reshaped = student_input.view(-1, student_hidden_size)
teacher_batch_seq_len_size, teacher_hidden_size = teacher_input.shape
teacher_input_reshaped = teacher_input.view(-1, teacher_hidden_size)

student_outputs = student_input_reshaped @ student_weight.t()
if student_bias is not None:
student_outputs = student_outputs + student_bias

teacher_outputs = teacher_input_reshaped @ teacher_weight.t()
if teacher_bias is not None:
teacher_outputs = teacher_outputs + teacher_bias

student_logits = student_outputs.view(student_batch_seq_len_size, -1).float()
teacher_logits = teacher_outputs.view(teacher_batch_seq_len_size, -1).float()

if torch.all(target == self.ignore_index):
return torch.tensor(0.0)

def cross_entropy_loss(logits, labels):
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss

labels = target
ce_loss = cross_entropy_loss(
student_logits.view(-1, student_logits.shape[-1]),
labels.view(-1),
)

return (
student_logits,
teacher_logits,
ce_loss,
)

def get_batch_loss_metrics(
self,
student_input: torch.FloatTensor,
student_weight: torch.FloatTensor,
teacher_input: torch.FloatTensor,
teacher_weight: torch.FloatTensor,
target: torch.LongTensor,
student_bias: torch.FloatTensor = None,
teacher_bias: torch.FloatTensor = None,
):
"""Compute the distillation loss metrics for the given batch."""
forward_output = self.concatenated_forward(
student_input,
student_weight,
teacher_input,
teacher_weight,
target,
student_bias,
teacher_bias,
)
(
student_logits,
teacher_logits,
hard_loss,
) = forward_output

soft_loss = self.distillation_loss(student_logits, teacher_logits)
# full loss
loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss.mean()
return loss
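
On the test side, HFDistillationLoss is the eager reference implementation; a subclass only has to define distillation_loss over the full (unchunked) logits. A minimal sketch, assuming a KL-divergence objective and that it sits alongside HFDistillationLoss in test/utils.py (the class name ExampleHFKLDivLoss is hypothetical, not part of this PR):

# Illustrative sketch only (not part of this PR): a reference distillation
# loss built on HFDistillationLoss, for comparison against the fused kernel.
import torch
import torch.nn.functional as F


class ExampleHFKLDivLoss(HFDistillationLoss):
    def distillation_loss(self, student_logits, teacher_logits):
        # Temperature-softened KL(teacher || student), kept per token so the
        # harness above can call .mean() over the batch.
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
        per_token_kl = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction="none",
        ).sum(dim=-1)
        return per_token_kl

In a test, the loss from ExampleHFKLDivLoss.get_batch_loss_metrics(...) would then be compared against the chunked variant built on LigerFusedLinearDistillationBase.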