diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py
index c56b63065..5c8d75a6f 100644
--- a/fast_llm/functional/config.py
+++ b/fast_llm/functional/config.py
@@ -93,6 +93,12 @@ def _set_activation_fn_map() -> None:
 MAX_DROPLESS_BLOCK_SIZE_ROW = 128
 
 
+class ReverseKLImpl(str, enum.Enum):
+    tp = "tp"
+    stp = "stp"
+    no_tp = "no_tp"
+
+
 class CrossEntropyImpl(str, enum.Enum):
     auto = "auto"
     torch = "torch"
diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py
index b18a9ec0b..1be4ed82b 100644
--- a/fast_llm/functional/cross_entropy.py
+++ b/fast_llm/functional/cross_entropy.py
@@ -1,7 +1,7 @@
 import torch
 
 from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce
-from fast_llm.functional.config import CrossEntropyImpl, TargetFormat
+from fast_llm.functional.config import CrossEntropyImpl, ReverseKLImpl, TargetFormat
 from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward
 from fast_llm.utils import Assert
 
@@ -234,6 +234,9 @@ def _torch_reverse_kl_forward_backward_vocab_parallel(
     grad_output: float | None,
     target_format: TargetFormat,
     group: ProcessGroup | None = None,
+    logits_scale_factor: float = 1.0,
+    teacher_softmax_temperature: float = 1.0,
+    **kwargs,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     Reverse KL using PyTorch's native kl_div function.
@@ -241,6 +244,12 @@ def _torch_reverse_kl_forward_backward_vocab_parallel(
     This works with sequence-tensor-parallel (distributing over the sequence dimention) as well as a non-TP case.
     In sequence-tensor-parallel, where we split along sequence dim., we compute per split loss and then average the loss.
     """
+    Assert.eq(
+        teacher_softmax_temperature,
+        1,
+        msg="Teacher softmax temperature must be 1 for vocab-parallel reverse KL",
+    )
+    Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for vocab-parallel reverse KL")
     # TODO: merge into single function _torch_reverse_kl_forward_backward
     Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format")
     Assert.eq(target.shape, logits.shape)
@@ -249,10 +258,10 @@ def _torch_reverse_kl_forward_backward_vocab_parallel(
         Assert.eq(loss_mask.shape, logits.shape[:-1])
 
     # Compute log probabilities - let _fused_softmax handle scaling internally
-    teacher_log_probs = distributed_log_softmax(target, group=group)
+    teacher_log_probs = distributed_log_softmax(target.float(), group=group)
     batch_size = logits.shape[0]
     with torch.enable_grad():
-        logits_ = logits.detach().requires_grad_(grad_output is not None)
+        logits_ = logits.float().detach().requires_grad_(grad_output is not None)
         student_log_probs = distributed_log_softmax(logits_, group=group)
 
         # Reverse KL: input=teacher_log_probs, target=student_probs
@@ -284,20 +293,19 @@ def _torch_reverse_kl_forward_backward_vocab_parallel(
     return loss.detach_(), grad
 
 
-def _torch_reverse_kl_forward_backward(
+def _torch_reverse_kl_forward_backward_no_tp(
     logits: torch.Tensor,
     target: torch.Tensor,
     loss_mask: torch.Tensor | None,
     grad_output: float | None,
     logits_scale_factor: float,
     target_format: TargetFormat,
-    group: ProcessGroup | None = None,
     teacher_softmax_temperature: float = 1.0,
+    **kwargs,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     Reverse KL using PyTorch's native kl_div function.
-    This works with sequence-tensor-parallel (distributing over the sequence dimention) as well as a non-TP case.
-    In sequence-tensor-parallel, where we split along sequence dim., we compute per split loss and then average the loss.
+    This is only used for the no-TP case.
     """
     Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format")
     Assert.eq(target.shape, logits.shape)
@@ -309,20 +317,20 @@ def _torch_reverse_kl_forward_backward(
     # Clamp to prevent extreme values that cause NaNs in log_softmax
     scaled_target = torch.clamp(scaled_target, min=-100.0, max=100.0)
 
-    teacher_log_probs = torch.log_softmax(scaled_target, dim=-1)
+    teacher_log_probs = torch.log_softmax(scaled_target.float(), dim=-1)
 
     # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p))
     # Use kl_div with: input=log(p), target=q, log_target=False
     # This gives: Σ q * (log(q) - log(p)) = exactly what we want!
 
     with torch.enable_grad():
-        logits_ = logits.detach().requires_grad_(grad_output is not None)
+        logits_ = logits.float().detach().requires_grad_(grad_output is not None)
 
         scaled_logits = logits_ * logits_scale_factor
         # Clamp to prevent extreme values that cause NaNs in log_softmax
         scaled_logits = torch.clamp(scaled_logits, min=-100.0, max=100.0)
-        student_log_probs = torch.log_softmax(scaled_logits, dim=-1)
-
+        student_log_probs = torch.log_softmax(scaled_logits.float(), dim=-1)
+
         # Reverse KL: input=teacher_log_probs, target=student_probs
         if loss_mask is None:
             loss = torch.nn.functional.kl_div(
@@ -336,11 +344,7 @@ def _torch_reverse_kl_forward_backward(
             loss_per_sample = torch.nn.functional.kl_div(
                 teacher_log_probs, student_log_probs, reduction="none", log_target=True
             ).sum(dim=-1)
-            loss = (loss_per_sample * loss_mask).mean()
-
-    if group is not None and target_format != TargetFormat.labels:
-        all_reduce(loss, op=ReduceOp.SUM, group=group)
-        loss /= group.size()
+            loss = (loss_per_sample * loss_mask).sum() / loss_mask.sum()
 
     if grad_output is not None:
         # note, we never get here in TP over seq. dim.
@@ -352,6 +356,88 @@ def _torch_reverse_kl_forward_backward(
     return loss.detach_(), grad
 
 
+def _torch_reverse_kl_forward_backward_sequence_tensor_parallel(
+    logits: torch.Tensor,
+    target: torch.Tensor,
+    loss_mask: torch.Tensor | None,
+    grad_output: float | None,
+    logits_scale_factor: float,
+    target_format: TargetFormat,
+    teacher_softmax_temperature: float = 1.0,
+    total_valid_tokens: int | None = None,  # total number of unmasked tokens in the batch
+    **kwargs,
+) -> tuple[torch.Tensor, torch.Tensor | None]:
+    """
+    Reverse KL using PyTorch's native kl_div function.
+    This is only used for the sequence-tensor-parallel case, where we split over the sequence dimension.
+ """ + Assert.eq( + total_valid_tokens is not None, + msg="Total valid tokens must be provided for sequence-tensor-parallel reverse KL", + ) + Assert.eq( + teacher_softmax_temperature, + 1, + msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL", + ) + Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL") + Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) + # Scale target logits more carefully + scaled_target = target * (logits_scale_factor / teacher_softmax_temperature) + # Clamp to prevent extreme values that cause NaNs in log_softmax + scaled_target = torch.clamp(scaled_target, min=-100.0, max=100.0) + + teacher_log_probs = torch.log_softmax(scaled_target.float(), dim=-1) + + # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p)) + # Use kl_div with: input=log(p), target=q, log_target=False + # This gives: Σ q * (log(q) - log(p)) = exactly what we want! + + with torch.enable_grad(): + logits_ = logits.float().detach().requires_grad_(grad_output is not None) + + scaled_logits = logits_ * logits_scale_factor + # Clamp to prevent extreme values that cause NaNs in log_softmax + scaled_logits = torch.clamp(scaled_logits, min=-100.0, max=100.0) + student_log_probs = torch.log_softmax(scaled_logits.float(), dim=-1) + + # Reverse KL: input=teacher_log_probs, target=student_probs + if loss_mask is None: + loss = torch.nn.functional.kl_div( + teacher_log_probs, # input = log(p) + student_log_probs, # target = log(q) + reduction="sum", + log_target=True, + ) + else: + # Apply loss mask - this requires some reshaping + loss_per_sample = torch.nn.functional.kl_div( + teacher_log_probs, student_log_probs, reduction="none", log_target=True + ).sum(dim=-1) + loss = (loss_per_sample * loss_mask).sum() # this can be 0.0 if all tokens are masked + + if grad_output is not None: + # note, if we compute gradient w.r.t sum of losses, + # and grad_output should reflect the scaling by 1/valid samples + loss.backward(torch.full_like(loss, grad_output)) + grad = logits_.grad.to(logits.dtype) + else: + grad = None + + return loss.detach_(), grad + + +REVERSE_KL_IMPLEMENTATIONS = { + ReverseKLImpl.no_tp: _torch_reverse_kl_forward_backward_no_tp, + ReverseKLImpl.tp: _torch_reverse_kl_forward_backward_vocab_parallel, + ReverseKLImpl.stp: _torch_reverse_kl_forward_backward_sequence_tensor_parallel, +} + + def reverse_kl_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -361,7 +447,8 @@ def reverse_kl_forward_backward( logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, - vocab_parallel: bool = False, + reverse_kl_impl: ReverseKLImpl = ReverseKLImpl.no_tp, + total_valid_tokens: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). @@ -404,27 +491,15 @@ def reverse_kl_forward_backward( assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) - # TODO: implement fused? 
-    if vocab_parallel:
-        Assert.eq(teacher_softmax_temperature, 1)
-        Assert.eq(logits_scale_factor, 1)
-        raise NotImplementedError("Vocab parallel reverse KL is not implemented yet.")
-        return _torch_reverse_kl_forward_backward_vocab_parallel(
-            logits,
-            target,
-            loss_mask,
-            grad_output,
-            target_format,
-            group,
-        )
-    else:
-        return _torch_reverse_kl_forward_backward(
-            logits,
-            target,
-            loss_mask,
-            grad_output,
-            logits_scale_factor,
-            target_format,
-            group,
-            teacher_softmax_temperature,
-        )
+    # TODO: implement fused reverse KL?
+    return REVERSE_KL_IMPLEMENTATIONS[reverse_kl_impl](
+        logits=logits,
+        target=target,
+        loss_mask=loss_mask,
+        grad_output=grad_output,
+        logits_scale_factor=logits_scale_factor,
+        target_format=target_format,
+        teacher_softmax_temperature=teacher_softmax_temperature,
+        group=group,
+        total_valid_tokens=total_valid_tokens,
+    )
diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py
index 24c06d5cc..b1f3564b9 100644
--- a/fast_llm/layers/language_model/head.py
+++ b/fast_llm/layers/language_model/head.py
@@ -11,7 +11,13 @@
 from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace
 from fast_llm.engine.distributed.config import DistributedDimNames
 from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward
-from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig
+from fast_llm.functional.config import (
+    CrossEntropyImpl,
+    DistillationLossImpl,
+    ReverseKLImpl,
+    TargetFormat,
+    TritonConfig,
+)
 from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward
 from fast_llm.functional.dpo import compute_dpo_loss
 from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward
@@ -313,12 +319,13 @@ def _logits_cross_entropy_forward_backward_split(
                     logit_input_grad_.copy_(grad_)
                 loss = loss_ if loss is None else loss + loss_
                 del grad_, loss_
-        loss_count = (self._cross_entropy_splits or 1) * (self._group_size if self._sequence_parallel_logits else 1)
-        if loss_count != 1:
-            loss.div_(loss_count)
-        if self._sequence_parallel_logits:
-            # TODO: Async
-            all_reduce(loss, group=self._tensor_space.distributed.tensor_group)
+        assert self._cross_entropy_splits is None, "Cross-entropy splits are not supported for now"
+        # loss_count = (self._cross_entropy_splits or 1) * (self._group_size if self._sequence_parallel_logits else 1)
+        # if loss_count != 1:
+        #     loss.div_(loss_count)
+        # if self._sequence_parallel_logits:
+        #     # TODO: Async
+        #     all_reduce(loss, group=self._tensor_space.distributed.tensor_group)
         return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None
 
     def _logits_cross_entropy_forward_backward(
@@ -412,6 +419,29 @@ def _logits_cross_entropy_forward_backward(
 
         if distillation_target is not None and self._distillation_loss_factor > 0.0:
             if self._distillation_loss_implementation == DistillationLossImpl.reverse_kl:
+                local_valid_tokens = total_valid_tokens = logits.shape[0]
+                if logits.shape[-1] != self._config.vocab_size:
+                    reverse_kl_impl = ReverseKLImpl.tp
+                    assert loss_mask is None, "Loss mask is not implemented for TP (vocab dim) reverse KL yet"
+                elif self._sequence_parallel_logits:
+                    # grad_output already reflects the 1/group_size scaling, see _forward_backward
+                    reverse_kl_impl = ReverseKLImpl.stp
+                    if loss_mask is not None:
+                        local_valid_tokens = loss_mask.sum()
+                        total_valid_tokens = local_valid_tokens.clone()
+                        all_reduce(
+                            total_valid_tokens, op=ReduceOp.SUM, group=self._tensor_space.distributed.tensor_group
+                        )
+                    else:
+                        local_valid_tokens = logits.shape[0]
+                        total_valid_tokens = local_valid_tokens * self._group_size
+                    # the loss function computes gradients w.r.t. the sum of the local losses,
+                    # so we multiply back by the group size and divide by the number of valid tokens to get the correct scaling.
+                    # note: the function also returns the sum of the local losses, which is normalized for reporting below.
+                    grad_output *= self._group_size / total_valid_tokens  # undo the 1/group_size pre-scaling and normalize
+                else:
+                    reverse_kl_impl = ReverseKLImpl.no_tp
+
                 distillation_loss, distillation_grad = reverse_kl_forward_backward(
                     logits.flatten(0, -2),
                     distillation_target,
@@ -423,8 +453,14 @@
                     target_format=(
                         TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits
                     ),
-                    vocab_parallel=logits.shape[-1] != self._config.vocab_size,
+                    reverse_kl_impl=reverse_kl_impl,
+                    total_valid_tokens=total_valid_tokens,
                 )
+                if self._sequence_parallel_logits:
+                    # distillation_loss is a local sum, so we all-reduce it and divide by the total number of valid tokens
+                    all_reduce(distillation_loss, op=ReduceOp.SUM, group=self._tensor_space.distributed.tensor_group)
+                    distillation_loss /= total_valid_tokens  # final global loss
+
             elif self._distillation_loss_implementation == DistillationLossImpl.cross_entropy:
                 distillation_loss, distillation_grad = cross_entropy_forward_backward(
                     logits.flatten(0, -2),
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index ebf84fc58..0c9d0e6ca 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -143,11 +143,13 @@ def preprocess_meta(
             micro_batch_size = batch_meta.micro_batch_size
             sequence_length = batch_meta.sequence_length
             micro_sequence_length = batch_meta.micro_sequence_length
+            truncate_documents = batch_meta.truncate_documents
         else:
             micro_batch_size, sequence_length = batch_meta.shape
             if phase != PhaseType.inference:
                 sequence_length -= self._config.prediction_heads
             micro_sequence_length = sequence_length
+            truncate_documents = True
 
         if self._config.vision_encoder.enabled:
             try:
@@ -245,6 +247,7 @@
             TransformerKwargs.sequence_length: sequence_length,
             TransformerKwargs.sequence_q_dim: sequence_q_dim,
             TransformerKwargs.micro_batch_size: micro_batch_size,
+            LanguageModelKwargs.mask_inputs: not truncate_documents,
         }
         common_kwargs.update(vision_kwargs)
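
Normalization note for the sequence-tensor-parallel (stp) path, with a minimal single-process sketch: each rank returns the sum of per-token reverse-KL terms over its sequence slice, the sums are all-reduced, and the result is divided by the global number of unmasked tokens (grad_output is pre-scaled by group_size / total_valid_tokens for the same reason). The helper names below (`n_ranks`, `local_reverse_kl_sum`) are illustrative only and not part of the Fast-LLM API.

```python
# Single-process sketch of the stp normalization: summing per-rank reverse-KL sums and
# dividing by the total number of unmasked tokens matches the global masked mean.
import torch

torch.manual_seed(0)
n_ranks, tokens_per_rank, vocab = 2, 4, 8  # illustrative sizes, not Fast-LLM defaults

logits = torch.randn(n_ranks * tokens_per_rank, vocab)   # student logits
target = torch.randn(n_ranks * tokens_per_rank, vocab)   # teacher logits
loss_mask = torch.tensor([1, 1, 0, 1, 1, 0, 1, 1], dtype=torch.float32)


def local_reverse_kl_sum(student_logits, teacher_logits, mask):
    teacher_log_probs = torch.log_softmax(teacher_logits.float(), dim=-1)
    student_log_probs = torch.log_softmax(student_logits.float(), dim=-1)
    # KL(q||p) with q = student, p = teacher, summed over the vocab dimension per token.
    per_token = torch.nn.functional.kl_div(
        teacher_log_probs, student_log_probs, reduction="none", log_target=True
    ).sum(dim=-1)
    return (per_token * mask).sum()


# Each "rank" sees its own slice of the sequence dimension.
local_sums = [
    local_reverse_kl_sum(
        logits[r * tokens_per_rank : (r + 1) * tokens_per_rank],
        target[r * tokens_per_rank : (r + 1) * tokens_per_rank],
        loss_mask[r * tokens_per_rank : (r + 1) * tokens_per_rank],
    )
    for r in range(n_ranks)
]
total_valid_tokens = loss_mask.sum()

# Stand-in for all_reduce(SUM) across ranks, then divide by the global number of valid tokens...
global_loss = torch.stack(local_sums).sum() / total_valid_tokens
# ...which matches the masked mean computed over the full (unsplit) batch.
reference = local_reverse_kl_sum(logits, target, loss_mask) / total_valid_tokens
assert torch.allclose(global_loss, reference)
print(global_loss.item(), reference.item())
```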