6 changes: 6 additions & 0 deletions fast_llm/functional/config.py
@@ -93,6 +93,12 @@ def _set_activation_fn_map() -> None:
MAX_DROPLESS_BLOCK_SIZE_ROW = 128


class ReverseKLImpl(str, enum.Enum):
tp = "tp"
stp = "stp"
no_tp = "no_tp"


class CrossEntropyImpl(str, enum.Enum):
auto = "auto"
torch = "torch"
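For context on how this enum is meant to be consumed: because ReverseKLImpl inherits from str, a raw string from a config maps directly onto a member, and the member can key a dispatch table (the real table, REVERSE_KL_IMPLEMENTATIONS, is defined in cross_entropy.py below). A minimal, self-contained sketch of that pattern, with placeholder handler functions that are purely illustrative:

import enum


class ReverseKLImpl(str, enum.Enum):
    tp = "tp"        # tensor-parallel over the vocab dimension
    stp = "stp"      # sequence-tensor-parallel (split over the sequence dimension)
    no_tp = "no_tp"  # no tensor parallelism


def _impl_tp() -> str:
    return "vocab-parallel path"


def _impl_stp() -> str:
    return "sequence-tensor-parallel path"


def _impl_no_tp() -> str:
    return "single-rank path"


# Placeholder handlers keyed by the enum, mirroring the shape of REVERSE_KL_IMPLEMENTATIONS.
IMPLEMENTATIONS = {ReverseKLImpl.tp: _impl_tp, ReverseKLImpl.stp: _impl_stp, ReverseKLImpl.no_tp: _impl_no_tp}

# A config string selects the member (and thus the implementation) directly.
assert ReverseKLImpl("stp") is ReverseKLImpl.stp
assert IMPLEMENTATIONS[ReverseKLImpl("stp")]() == "sequence-tensor-parallel path"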
157 changes: 116 additions & 41 deletions fast_llm/functional/cross_entropy.py
@@ -1,7 +1,7 @@
import torch

from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce
from fast_llm.functional.config import CrossEntropyImpl, TargetFormat
from fast_llm.functional.config import CrossEntropyImpl, ReverseKLImpl, TargetFormat
from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward
from fast_llm.utils import Assert

@@ -234,13 +234,22 @@ def _torch_reverse_kl_forward_backward_vocab_parallel(
grad_output: float | None,
target_format: TargetFormat,
group: ProcessGroup | None = None,
logits_scale_factor: float = 1.0,
teacher_softmax_temperature: float = 1.0,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Reverse KL using PyTorch's native kl_div function.
This is used for the TP version, where we split across the vocab dimension.
This works with sequence-tensor-parallel (distributing over the sequence dimension) as well as the non-TP case.
In sequence-tensor-parallel, where we split along the sequence dimension, we compute the per-split loss and then average the losses.
"""
Assert.eq(
teacher_softmax_temperature,
1,
msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL",
)
Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL")
# TODO: merge into single function _torch_reverse_kl_forward_backward
Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format")
Assert.eq(target.shape, logits.shape)
@@ -249,10 +258,10 @@ def _torch_reverse_kl_forward_backward_vocab_parallel(
Assert.eq(loss_mask.shape, logits.shape[:-1])

# Compute log probabilities - let _fused_softmax handle scaling internally
teacher_log_probs = distributed_log_softmax(target, group=group)
teacher_log_probs = distributed_log_softmax(target.float(), group=group)
batch_size = logits.shape[0]
with torch.enable_grad():
logits_ = logits.detach().requires_grad_(grad_output is not None)
logits_ = logits.float().detach().requires_grad_(grad_output is not None)
student_log_probs = distributed_log_softmax(logits_, group=group)

# Reverse KL: input=teacher_log_probs, target=student_probs
@@ -284,20 +293,19 @@ def _torch_reverse_kl_forward_backward_vocab_parallel(
return loss.detach_(), grad


def _torch_reverse_kl_forward_backward(
def _torch_reverse_kl_forward_backward_no_tp(
logits: torch.Tensor,
target: torch.Tensor,
loss_mask: torch.Tensor | None,
grad_output: float | None,
logits_scale_factor: float,
target_format: TargetFormat,
group: ProcessGroup | None = None,
teacher_softmax_temperature: float = 1.0,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Reverse KL using PyTorch's native kl_div function.
This works with sequence-tensor-parallel (distributing over the sequence dimension) as well as the non-TP case.
In sequence-tensor-parallel, where we split along the sequence dimension, we compute the per-split loss and then average the losses.
This is only used for the no-TP case.
"""
Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format")
Assert.eq(target.shape, logits.shape)
@@ -309,20 +317,20 @@ def _torch_reverse_kl_forward_backward(
# Clamp to prevent extreme values that cause NaNs in log_softmax
scaled_target = torch.clamp(scaled_target, min=-100.0, max=100.0)

teacher_log_probs = torch.log_softmax(scaled_target, dim=-1)
teacher_log_probs = torch.log_softmax(scaled_target.float(), dim=-1)

# For reverse KL: KL(q||p) = Ξ£ q * log(q/p) = Ξ£ q * (log(q) - log(p))
# Use kl_div with: input=log(p) (teacher), target=log(q) (student), log_target=True
# This gives: Ξ£ q * (log(q) - log(p)), which is exactly what we want.

with torch.enable_grad():
logits_ = logits.detach().requires_grad_(grad_output is not None)
logits_ = logits.float().detach().requires_grad_(grad_output is not None)

scaled_logits = logits_ * logits_scale_factor
# Clamp to prevent extreme values that cause NaNs in log_softmax
scaled_logits = torch.clamp(scaled_logits, min=-100.0, max=100.0)
student_log_probs = torch.log_softmax(scaled_logits, dim=-1)
student_log_probs = torch.log_softmax(scaled_logits.float(), dim=-1)

# Reverse KL: input=teacher_log_probs, target=student_probs
if loss_mask is None:
loss = torch.nn.functional.kl_div(
@@ -336,11 +344,7 @@ def _torch_reverse_kl_forward_backward(
loss_per_sample = torch.nn.functional.kl_div(
teacher_log_probs, student_log_probs, reduction="none", log_target=True
).sum(dim=-1)
loss = (loss_per_sample * loss_mask).mean()

if group is not None and target_format != TargetFormat.labels:
all_reduce(loss, op=ReduceOp.SUM, group=group)
loss /= group.size()
loss = (loss_per_sample * loss_mask).sum() / loss_mask.sum()

if grad_output is not None:
# Note: we never get here when using TP over the sequence dimension.
@@ -352,6 +356,88 @@ def _torch_reverse_kl_forward_backward(
return loss.detach_(), grad


def _torch_reverse_kl_forward_backward_sequence_tensor_parallel(
logits: torch.Tensor,
target: torch.Tensor,
loss_mask: torch.Tensor | None,
grad_output: float | None,
logits_scale_factor: float,
target_format: TargetFormat,
teacher_softmax_temperature: float = 1.0,
total_valid_tokens: int | None = None, # total number of unmasked tokens in the batch
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Reverse KL using PyTorch's native kl_div function.
This is only used for the sequence-tensor-parallel case, where we split over the sequence dimension.
"""
assert (
total_valid_tokens is not None
), "Total valid tokens must be provided for sequence-tensor-parallel reverse KL"
Assert.eq(
teacher_softmax_temperature,
1,
msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL",
)
Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL")
Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format")
Assert.eq(target.shape, logits.shape)
assert target.dtype.is_floating_point, target.dtype
if loss_mask is not None:
Assert.eq(loss_mask.shape, logits.shape[:-1])
# Scale target logits more carefully
scaled_target = target * (logits_scale_factor / teacher_softmax_temperature)
# Clamp to prevent extreme values that cause NaNs in log_softmax
scaled_target = torch.clamp(scaled_target, min=-100.0, max=100.0)

teacher_log_probs = torch.log_softmax(scaled_target.float(), dim=-1)

# For reverse KL: KL(q||p) = Ξ£ q * log(q/p) = Ξ£ q * (log(q) - log(p))
# Use kl_div with: input=log(p) (teacher), target=log(q) (student), log_target=True
# This gives: Ξ£ q * (log(q) - log(p)), which is exactly what we want.

with torch.enable_grad():
logits_ = logits.float().detach().requires_grad_(grad_output is not None)

scaled_logits = logits_ * logits_scale_factor
# Clamp to prevent extreme values that cause NaNs in log_softmax
scaled_logits = torch.clamp(scaled_logits, min=-100.0, max=100.0)
student_log_probs = torch.log_softmax(scaled_logits.float(), dim=-1)

# Reverse KL: input=teacher_log_probs, target=student_probs
if loss_mask is None:
loss = torch.nn.functional.kl_div(
teacher_log_probs, # input = log(p)
student_log_probs, # target = log(q)
reduction="sum",
log_target=True,
)
else:
# Apply loss mask - this requires some reshaping
loss_per_sample = torch.nn.functional.kl_div(
teacher_log_probs, student_log_probs, reduction="none", log_target=True
).sum(dim=-1)
loss = (loss_per_sample * loss_mask).sum() # this can be 0.0 if all tokens are masked

if grad_output is not None:
# Note: we compute the gradient w.r.t. the sum of losses,
# so grad_output must already reflect the 1/total_valid_tokens scaling.
loss.backward(torch.full_like(loss, grad_output))
grad = logits_.grad.to(logits.dtype)
else:
grad = None

return loss.detach_(), grad


REVERSE_KL_IMPLEMENTATIONS = {
ReverseKLImpl.no_tp: _torch_reverse_kl_forward_backward_no_tp,
ReverseKLImpl.tp: _torch_reverse_kl_forward_backward_vocab_parallel,
ReverseKLImpl.stp: _torch_reverse_kl_forward_backward_sequence_tensor_parallel,
}


def reverse_kl_forward_backward(
logits: torch.Tensor,
target: torch.Tensor,
@@ -361,7 +447,8 @@ def reverse_kl_forward_backward(
logits_scale_factor: float = 1.0,
teacher_softmax_temperature: float = 1.0,
target_format: TargetFormat = TargetFormat.labels,
vocab_parallel: bool = False,
reverse_kl_impl: ReverseKLImpl = ReverseKLImpl.no_tp,
total_valid_tokens: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher).
@@ -404,27 +491,15 @@ def reverse_kl_forward_backward(
assert target.dtype.is_floating_point, target.dtype
if loss_mask is not None:
Assert.eq(loss_mask.shape, logits.shape[:-1])
# TODO: implement fused?
if vocab_parallel:
Assert.eq(teacher_softmax_temperature, 1)
Assert.eq(logits_scale_factor, 1)
raise NotImplementedError("Vocab parallel reverse KL is not implemented yet.")
return _torch_reverse_kl_forward_backward_vocab_parallel(
logits,
target,
loss_mask,
grad_output,
target_format,
group,
)
else:
return _torch_reverse_kl_forward_backward(
logits,
target,
loss_mask,
grad_output,
logits_scale_factor,
target_format,
group,
teacher_softmax_temperature,
)
# TODO: implement fused reverse KL?
return REVERSE_KL_IMPLEMENTATIONS[reverse_kl_impl](
logits=logits,
target=target,
loss_mask=loss_mask,
grad_output=grad_output,
logits_scale_factor=logits_scale_factor,
target_format=target_format,
teacher_softmax_temperature=teacher_softmax_temperature,
group=group,
total_valid_tokens=total_valid_tokens,
)
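As a sanity check on the kl_div convention used throughout this file (input = teacher log-probs, target = student log-probs, log_target=True), here is a small standalone sketch; the shapes and random logits are made up purely for illustration:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
student_logits = torch.randn(4, 7)  # (tokens, vocab), arbitrary small example
teacher_logits = torch.randn(4, 7)

student_log_probs = torch.log_softmax(student_logits, dim=-1)  # log q
teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1)  # log p

# F.kl_div(input, target, log_target=True) sums exp(target) * (target - input),
# i.e. q * (log q - log p) with input = log p and target = log q.
loss_via_kl_div = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)

# Direct evaluation of reverse KL(q || p) = sum_i q_i * (log q_i - log p_i).
q = student_log_probs.exp()
loss_direct = (q * (student_log_probs - teacher_log_probs)).sum()

torch.testing.assert_close(loss_via_kl_div, loss_direct)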
52 changes: 44 additions & 8 deletions fast_llm/layers/language_model/head.py
@@ -11,7 +11,13 @@
from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace
from fast_llm.engine.distributed.config import DistributedDimNames
from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward
from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig
from fast_llm.functional.config import (
CrossEntropyImpl,
DistillationLossImpl,
ReverseKLImpl,
TargetFormat,
TritonConfig,
)
from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward
from fast_llm.functional.dpo import compute_dpo_loss
from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward
@@ -313,12 +319,13 @@ def _logits_cross_entropy_forward_backward_split(
logit_input_grad_.copy_(grad_)
loss = loss_ if loss is None else loss + loss_
del grad_, loss_
loss_count = (self._cross_entropy_splits or 1) * (self._group_size if self._sequence_parallel_logits else 1)
if loss_count != 1:
loss.div_(loss_count)
if self._sequence_parallel_logits:
# TODO: Async
all_reduce(loss, group=self._tensor_space.distributed.tensor_group)
assert self._cross_entropy_splits is None, "This is not supported for now"
# loss_count = (self._cross_entropy_splits or 1) * (self._group_size if self._sequence_parallel_logits else 1)
# if loss_count != 1:
# loss.div_(loss_count)
# if self._sequence_parallel_logits:
# # TODO: Async
# all_reduce(loss, group=self._tensor_space.distributed.tensor_group)
return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None

def _logits_cross_entropy_forward_backward(
@@ -412,6 +419,29 @@ def _logits_cross_entropy_forward_backward(

if distillation_target is not None and self._distillation_loss_factor > 0.0:
if self._distillation_loss_implementation == DistillationLossImpl.reverse_kl:
local_valid_tokens = total_valid_tokens = logits.shape[0]
if logits.shape[-1] != self._config.vocab_size:
reverse_kl_impl = ReverseKLImpl.tp
assert loss_mask is None, "Loss mask is not implemented for TP (vocab dim) reverse KL yet"
elif self._sequence_parallel_logits:
# grad_output already reflects the 1/group_size (number of ranks) scaling, see _forward_backward
reverse_kl_impl = ReverseKLImpl.stp
if loss_mask is not None:
local_valid_tokens = loss_mask.sum()
total_valid_tokens = local_valid_tokens.clone()
all_reduce(
total_valid_tokens, op=ReduceOp.SUM, group=self._tensor_space.distributed.tensor_group
)
else:
local_valid_tokens = logits.shape[0]
total_valid_tokens = local_valid_tokens * self._group_size
# In the loss function we compute grads w.r.t. the sum of losses,
# so we multiply back by the group size and divide by the number of valid tokens to get the correct scaling.
# Note: the function returns the sum of local losses, so this must be handled accordingly for reporting.
grad_output *= self._group_size / total_valid_tokens # multiply back by the group size
else:
reverse_kl_impl = ReverseKLImpl.no_tp

distillation_loss, distillation_grad = reverse_kl_forward_backward(
logits.flatten(0, -2),
distillation_target,
@@ -423,8 +453,14 @@ def _logits_cross_entropy_forward_backward(
target_format=(
TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits
),
vocab_parallel=logits.shape[-1] != self._config.vocab_size,
reverse_kl_impl=reverse_kl_impl,
total_valid_tokens=total_valid_tokens,
)
if self._sequence_parallel_logits:
# distillation_loss is the local sum, so we all-reduce it and divide by the number of valid tokens to get the correct scaling
all_reduce(distillation_loss, op=ReduceOp.SUM, group=self._tensor_space.distributed.tensor_group)
distillation_loss /= total_valid_tokens # final global loss

elif self._distillation_loss_implementation == DistillationLossImpl.cross_entropy:
distillation_loss, distillation_grad = cross_entropy_forward_backward(
logits.flatten(0, -2),
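To make the scaling in the sequence-tensor-parallel branch explicit: each rank computes the sum of its local per-token losses, the number of valid tokens is all-reduced, and the global loss is the all-reduced sum of local losses divided by total_valid_tokens; grad_output is multiplied by group_size / total_valid_tokens because the framework already folds a 1/group_size factor into it. A single-process sketch of the loss arithmetic under these assumptions (the all-reduces are simulated by summing over shards; shapes, masks and values are made up):

import torch

torch.manual_seed(0)
group_size = 4
per_token_loss = torch.rand(32)              # pretend per-token reverse-KL values for the full sequence
loss_mask = (torch.rand(32) > 0.25).float()  # 1.0 for tokens that count, 0.0 for masked tokens

# Reference: masked mean over the whole (unsplit) sequence.
reference = (per_token_loss * loss_mask).sum() / loss_mask.sum()

# Sequence-tensor-parallel view: each rank holds one shard of the sequence.
loss_shards = per_token_loss.chunk(group_size)
mask_shards = loss_mask.chunk(group_size)

local_sums = [(l * m).sum() for l, m in zip(loss_shards, mask_shards)]
local_valid = [m.sum() for m in mask_shards]

total_valid_tokens = sum(local_valid)                # in the real code: all_reduce(local_valid_tokens, SUM)
global_loss = sum(local_sums) / total_valid_tokens   # all_reduce(loss, SUM), then divide

torch.testing.assert_close(global_loss, reference)

# Gradient side: grad_output already carries a 1/group_size factor, so the head
# multiplies it back and divides by the token count:
# grad_output *= group_size / total_valid_tokens.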
3 changes: 3 additions & 0 deletions fast_llm/models/gpt/model.py
@@ -143,11 +143,13 @@ def preprocess_meta(
micro_batch_size = batch_meta.micro_batch_size
sequence_length = batch_meta.sequence_length
micro_sequence_length = batch_meta.micro_sequence_length
truncate_documents = batch_meta.truncate_documents
else:
micro_batch_size, sequence_length = batch_meta.shape
if phase != PhaseType.inference:
sequence_length -= self._config.prediction_heads
micro_sequence_length = sequence_length
truncate_documents = True

if self._config.vision_encoder.enabled:
try:
@@ -245,6 +247,7 @@ def preprocess_meta(
TransformerKwargs.sequence_length: sequence_length,
TransformerKwargs.sequence_q_dim: sequence_q_dim,
TransformerKwargs.micro_batch_size: micro_batch_size,
LanguageModelKwargs.mask_inputs: not truncate_documents,
}
common_kwargs.update(vision_kwargs)
