
Conversation

@stas00
Contributor

@stas00 stas00 commented Oct 23, 2025

Integrates HF Accelerate's ALST/Ulysses sequence parallelism support (huggingface/accelerate#3817) into the HF Trainer

TODO:

…lerate

Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
@stas00
Contributor Author

stas00 commented Oct 23, 2025

I have a hard time finding where I can add the documentation for this new backend. Context parallelism isn't documented anywhere in HF Trainer - how can users discover it? If I'm missing it, could you please point me to where I should extend the documentation? Thank you.

Same story with CP tests - there are none :( so I need to figure out how to write some.

@stas00
Contributor Author

stas00 commented Oct 24, 2025

OK, the next issue with the existing CP/FSDP integration: what does the following mean?

 $ sometrainerscript.py --help
 [...]
 --parallelism_config PARALLELISM_CONFIG, --parallelism-config PARALLELISM_CONFIG

This arg tells users absolutely nothing about what value(s) to pass to --parallelism_config. If it's not meant to be used via CLI args and can only be set explicitly in code, perhaps it shouldn't be listed in --help, or at least the help text should say that it has to be set in code?
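
For context, a minimal sketch of what setting this programmatically might look like (my assumption, not code from this PR; the ParallelismConfig import path is taken from accelerate, and sp_size from the accelerate PR referenced above):

    # Hedged sketch: assumes accelerate's ParallelismConfig and the sp_size field
    # added in huggingface/accelerate#3817.
    from accelerate.parallelism_config import ParallelismConfig
    from transformers import TrainingArguments

    pc = ParallelismConfig(sp_size=4)  # ALST/Ulysses sequence parallelism over 4 ranks
    args = TrainingArguments(
        output_dir="output",
        parallelism_config=pc,  # cannot be expressed as a plain string on the command line
    )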

@Rocketknight1
Member

cc @SunMarc

@stas00 stas00 marked this pull request as draft October 27, 2025 04:42
Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
 return env

-def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
+def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None, return_pathlib_obj=False):
Contributor Author

@stas00 stas00 Oct 28, 2025

This is a really old version. In the latest incarnation it always returns a Path object. But to keep BC, I added a new flag here instead. The tests are less clunky that way.

The latest version is here: https://github.com/stas00/ml-engineering/blob/master/testing/testing_utils.py

If you want, you could switch to the latest version instead and adapt the tests to simplify them.

Contributor Author

It's much better for it to always return a pathlib.Path object, but you'd need to tweak a few tests that use this API.

Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
@stas00 stas00 marked this pull request as ready for review October 28, 2025 02:34
@stas00
Contributor Author

stas00 commented Oct 28, 2025

@SunMarc, this is ready for a review. Tests fail because they need the accelerate PR huggingface/accelerate#3817

I just didn't know where to update docs since parallelism doesn't seem to be documented here at all. Please correct me if I'm wrong.

Thanks to @kashif for help with the test.

Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
Member

@SunMarc SunMarc left a comment

Thanks for this clean integration! I left a couple of comments. It would be great, @kashif @qgallouedec, if you could have a look at this PR so that we can also make it compatible with TRL.

Comment on lines 3865 to 3871
losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
# special dealing with SFT that has prompt tokens that aren't used in loss computation
good_tokens = (shift_labels != -100).view(-1).sum()
good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))
total_good_tokens = sum(good_tokens_per_rank)
loss = total_loss / max(total_good_tokens, 1)
Member

We probably don't need to do this if num_items_in_batch is computed and passed to unwrapped_model.loss_function. num_items_in_batch was introduced to fix gradient accumulation https://unsloth.ai/blog/gradient. num_items_in_batch is basically total_good_tokens if grad_acc = 1.

Contributor Author

sorry, what is not needed?

This code block exists because we need to compute the correct loss across SP ranks. If you just average the per-rank losses, the result will be incorrect in the case of -100-masked tokens (SFT), since each rank is likely to process a different number of unmasked tokens (this is not DP averaging).

Unless what you mean is that we don't need to calculate total_good_tokens since num_items_in_batch is already that, but the rest of the code remains - did I understand you correctly?

Member

@SunMarc SunMarc Nov 4, 2025

If you pass num_items_in_batch to loss_function, it will sum the loss and then divide it by num_items_in_batch directly. This way I think we don't actually need to recalculate total_loss from the averaged losses and good_tokens_per_rank. Maybe I'm wrong, so please correct me! But I think this might solve the grad acc issue. In any case, we will keep the current code, as not all models accept num_items_in_batch when calculating the loss.

total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))

from typing import Optional

import torch
from torch import nn


def ForCausalLMLoss(
    logits,
    labels,
    vocab_size: int,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    shift_labels: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()

    if shift_labels is None:
        # Shift so that tokens < n predict n
        labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    logits = logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(logits.device)
    loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
    return loss


def fixed_cross_entropy(
    source: torch.Tensor,
    target: torch.Tensor,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    **kwargs,
) -> torch.Tensor:
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss
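
A quick numerical check of the reduction logic above (my own sketch, not code from the PR or the library): when num_items_in_batch equals the number of non-ignored targets, the sum-then-divide path matches the plain mean reduction.

    import torch

    torch.manual_seed(0)
    logits = torch.randn(8, 10)                             # 8 target positions, vocab of 10
    targets = torch.tensor([1, 2, -100, 4, 5, -100, 7, 8])  # two positions are ignored
    n_good = int((targets != -100).sum())                   # 6 non-ignored targets

    # "sum" reduction divided by the global token count vs. plain "mean" reduction:
    print(fixed_cross_entropy(logits, targets, num_items_in_batch=n_good))
    print(fixed_cross_entropy(logits, targets))             # same value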

Contributor Author

@stas00 stas00 Nov 4, 2025

If you pass num_items_in_batch, you indeed don't need to do the local loss calculation, since it'll do that already. But we need to calculate a combined loss across the SP ranks.

Here is an example: let's take a 2k-token sample that is SP-split across 2 ranks using SFT:

  1. SP rank0 - 900 masked and 100 non-masked tokens (a long initial prompt that is -100 masked out)
  2. SP rank1 - 100 masked and 900 non-masked tokens

So each rank produces the correct loss if we use num_items_in_batch - but how do you combine the losses of the 2 ranks? A straight average will give a very skewed result, because rank0's loss is backed by 9x fewer non-masked tokens.

Let's take an even more telling example:

  1. SP rank0 - 1000 masked and 0 non-masked tokens (a long initial prompt that is masked out)
  2. SP rank1 - 0 masked and 1000 non-masked tokens

Here rank0 can't contribute anything to the total loss - a naive averaging of the 2 losses would be completely broken, since you'd be averaging with undefined behavior: the loss function would return a NaN or None for rank0.
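
To make the first example concrete (a sketch with made-up per-rank loss values): a token-count-weighted combination recovers the loss over all 1000 non-masked tokens, while a straight average of the two rank losses does not.

    # Hypothetical per-rank mean losses over non-masked tokens only
    loss_rank0, tokens_rank0 = 2.0, 100   # rank0: 100 non-masked tokens
    loss_rank1, tokens_rank1 = 1.0, 900   # rank1: 900 non-masked tokens

    straight_avg = (loss_rank0 + loss_rank1) / 2   # 1.5 - skewed toward rank0
    weighted = (loss_rank0 * tokens_rank0 + loss_rank1 * tokens_rank1) / (tokens_rank0 + tokens_rank1)
    print(straight_avg, weighted)                  # 1.5 vs 1.1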

Member

> So each rank produces the correct loss if we use num_items_in_batch - but how do you combine the losses of the 2 ranks? A straight average will give a very skewed result, because rank0's loss is backed by 9x fewer non-masked tokens.

The denominator of both losses is num_items_in_batch, and the value of each loss already takes into account the number of non-masked tokens since we use reduction = "sum". So we just sum them to get the final loss. In your first example, num_items_in_batch will be equal to 1000. For rank0, the loss will be (L1+...+L100)/1000, and for rank1 it will be (l1+...+l900)/1000.
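
Checking that arithmetic with toy numbers (my own sketch, made-up per-token losses): summing the per-rank partial losses, each already divided by the global num_items_in_batch, reproduces the mean over all 1000 non-masked tokens.

    import torch

    torch.manual_seed(0)
    rank0_token_losses = torch.rand(100)   # rank0 holds 100 non-masked targets
    rank1_token_losses = torch.rand(900)   # rank1 holds 900 non-masked targets
    num_items_in_batch = 1000

    rank0_loss = rank0_token_losses.sum() / num_items_in_batch  # (L1+...+L100)/1000
    rank1_loss = rank1_token_losses.sum() / num_items_in_batch  # (l1+...+l900)/1000
    combined = rank0_loss + rank1_loss
    global_mean = torch.cat([rank0_token_losses, rank1_token_losses]).mean()
    print(torch.allclose(combined, global_mean))  # True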

Contributor Author

I have a feeling we are talking past each other. I'm talking about a differentiable loss combination across ranks, and I think you're talking about the local rank's loss.

Could you please point me to the code in HF Trainer that performs a differentiable loss combination across multiple ranks? I couldn't find any.

@kashif
Contributor

kashif commented Nov 4, 2025

Thanks @SunMarc, I did an initial run with default settings in TRL and these branches, and it all worked nicely (apart from model saving, due to not updating deepspeed, I think). I will check the config that uses a bespoke compute_loss in the SFT trainer, thanks for the heads up!

@SunMarc
Member

SunMarc commented Nov 4, 2025

> @SunMarc, this is ready for a review. Tests fail because they need the accelerate PR huggingface/accelerate#3817
>
> I just didn't know where to update docs since parallelism doesn't seem to be documented here at all. Please correct me if I'm wrong.
>
> Thanks to @kashif for help with the test.

We indeed do not have docs related to that. @kashif, as you added CP support in Trainer, would you be willing to add some docs around that?
@stas00, I think we can either update the deepspeed docs https://huggingface.co/docs/transformers/main/en/deepspeed and/or create a new doc called contextparallel like in accelerate.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kashif
Contributor

kashif commented Nov 4, 2025

@SunMarc docs on TRL side: huggingface/trl#4420

@kashif
Contributor

kashif commented Nov 4, 2025

I can look at the docs on the accelerate side as well...

@stas00
Contributor Author

stas00 commented Nov 4, 2025

> @stas00, I think we can either update the deepspeed docs https://huggingface.co/docs/transformers/main/en/deepspeed and/or create a new doc called contextparallel like in accelerate.

Ideally we would have a dedicated doc like you suggested, which could then link into the deepspeed doc for nuances, as one way to do that. The key is for the user to quickly understand what's possible, so a single context-parallelism entry-point doc would be very useful to users.

sfc-gh-sbekman and others added 6 commits November 5, 2025 03:38
Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
@zhangwj618

zhangwj618 commented Nov 20, 2025

Is there a mismatch between the docs and the code? docs/source/en/deepspeed.md says:
"By default, when you only configure sp_size, DP is automatically calculated as dp_size = world_size / sp_size."
However, when I run the code with sp_size != world_size, I get an error unless I specify dp_replicate_size manually.
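
For illustration, a hypothetical sketch of the workaround being described (the import path and field names are assumptions based on accelerate and this thread, not verified against the merged code):

    # e.g. world_size = 8, sequence parallelism over 4 ranks
    from accelerate.parallelism_config import ParallelismConfig

    pc = ParallelismConfig(sp_size=4, dp_replicate_size=2)  # 8 ranks = 2 (DP) x 4 (SP)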

@kashif
Contributor

kashif commented Nov 20, 2025

@zhangwj618 the doc is wrong... my bad, let me fix it!

kashif and others added 3 commits November 20, 2025 17:28
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
@ArthurZucker ArthurZucker merged commit 7e0ea69 into huggingface:main Nov 21, 2025
15 of 21 checks passed
@ArthurZucker
Collaborator

Thanks a lot @stas00 for your work 🤗

@stas00 stas00 deleted the alst-integration branch November 21, 2025 17:39
@stas00
Contributor Author

stas00 commented Nov 21, 2025

Super! Thanks a lot to Marc and Kashif for their help with the integration, and to Weijie Zhang for being the first early adopter!
