HF Trainer: ALST/Ulysses sequence parallelism integration via HF Accelerate #41832
Conversation
…lerate Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
I have a hard time finding where I can add the documentation for this new backend. Context parallelism isn't documented anywhere in HF Trainer - how can users discover it? If I'm missing it, could you please point me to where I should extend the documentation? Thank you. Same story with CP tests - there are none :( so I need to figure out how to write some.
OK, the next issue with the existing CP/FSDP integration. What does the following mean? This arg tells users absolutely nothing about what value(s) to pass.
cc @SunMarc
Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
```diff
  return env

- def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
+ def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None, return_pathlib_obj=False):
```
This is a really old version. In the latest incarnation it always returns a Path object. But to keep BC, I added a new flag here instead. The tests are less clunky that way.
The latest version is here: https://github.com/stas00/ml-engineering/blob/master/testing/testing_utils.py
If you want, you could switch to the latest version instead and adapt the tests to simplify them.
cc @ydshieh
It's much better for it to always return a pathlib.Path object, but you'd need to tweak a few tests which use this API.
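For illustration, a minimal sketch of how the backward-compatible flag being discussed might behave. This is not the actual transformers testing helper; the setup/cleanup handled by before/after is omitted and the class name is made up:

```python
from pathlib import Path
import tempfile


class _TmpDirHelperSketch:
    # Sketch only, assuming the current helper returns the temp dir as a plain str.
    def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None, return_pathlib_obj=False):
        # ... the real helper also handles pre/post cleanup via `before`/`after`;
        # omitted here for brevity.
        path = tmp_dir if tmp_dir is not None else tempfile.mkdtemp()
        # The new flag keeps BC: existing callers that expect a str are unaffected,
        # while new tests can opt into a pathlib.Path return value.
        return Path(path) if return_pathlib_obj else path
```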
@SunMarc, this is ready for a review. Tests fail because they need the accelerate PR huggingface/accelerate#3817. I just didn't know where to update docs since parallelism doesn't seem to be documented here at all. Please correct me if I'm wrong. Thanks to @kashif for help with the test.
SunMarc left a comment
Thanks for this clean integration! I left a couple of comments. @kashif @qgallouedec, it would be great if you could have a look at this PR so that we can also make it compatible with TRL.
src/transformers/trainer.py (outdated)
```python
losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
# special handling for SFT, which has prompt tokens that aren't used in loss computation
good_tokens = (shift_labels != -100).view(-1).sum()
good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))
total_good_tokens = sum(good_tokens_per_rank)
loss = total_loss / max(total_good_tokens, 1)
```
We probably don't need to do this if num_items_in_batch is computed and passed to unwrapped_model.loss_function. num_items_in_batch was introduced to fix gradient accumulation: https://unsloth.ai/blog/gradient. num_items_in_batch is basically total_good_tokens if grad_acc = 1.
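For context, a small sketch of the gradient-accumulation normalization problem that num_items_in_batch addresses (made-up per-token losses; this is not the Trainer code itself):

```python
import torch

# Two accumulation micro-steps with very different unmasked-token counts (made-up losses)
step_token_losses = [torch.rand(10), torch.rand(990)]

# Averaging per-step mean losses weights the 10-token step as heavily as the 990-token step
biased = sum(losses.mean() for losses in step_token_losses) / len(step_token_losses)

# Dividing each step's summed loss by the global unmasked-token count (num_items_in_batch)
# makes the accumulated loss match the loss of the full, un-split batch
num_items_in_batch = sum(losses.numel() for losses in step_token_losses)  # 1000
fixed = sum(losses.sum() / num_items_in_batch for losses in step_token_losses)

assert torch.allclose(fixed, torch.cat(step_token_losses).mean())
```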
Sorry, what is not needed?
This code block is there because we need to compute the correct loss across SP ranks. If you just average the per-rank losses, the result will be incorrect in the case of -100 masked tokens (SFT), since each rank is likely to process a different number of unmasked tokens (this is not DP averaging).
Unless what you mean is that we don't need to calculate total_good_tokens since num_items_in_batch is already that, but the rest of the code remains - did I understand you correctly?
If you pass num_items_in_batch to loss_function, it will sum the loss and then divide it by num_items_in_batch directly. This way I think we don't actually need to recalculate total_loss from the averaged losses and good_tokens_per_rank. Maybe I'm wrong, so please correct me! But I think this might solve the grad acc issue. In any case, we will keep the current code, as not all models accept num_items_in_batch when calculating the loss.
total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))
```python
def ForCausalLMLoss(
    logits,
    labels,
    vocab_size: int,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    shift_labels: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()

    if shift_labels is None:
        # Shift so that tokens < n predict n
        labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    logits = logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(logits.device)
    loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
    return loss


def fixed_cross_entropy(
    source: torch.Tensor,
    target: torch.Tensor,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    **kwargs,
) -> torch.Tensor:
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss
```
If you pass num_items_in_batch you indeed don't need to do the local loss calculation, since it'll do that already. But we need to compute a loss that is combined across ranks.
Here is an example. Let's take a 2k-token sample, SP-split across 2 ranks, using SFT:
- SP rank0 - 900 masked and 100 non-masked tokens (a long initial prompt that is -100 masked out)
- SP rank1 - 100 masked and 900 non-masked tokens
So each rank produces the correct local loss if we use num_items_in_batch - but how do you combine the losses of the 2 ranks? A straight average will give a very skewed result, because rank0's loss is computed from 9x fewer non-masked tokens.
Let's take it to a more telling example:
- SP rank0 - 1000 masked and 0 non-masked tokens (a long initial prompt that is masked out)
- SP rank1 - 0 masked and 1000 non-masked tokens
Here rank0 can't contribute anything to the total loss - a plain average of the 2 losses would be completely broken, since rank0's local loss is undefined (the loss function will return a NaN or None).
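A small numeric sketch of the first example (the loss values are made up; the weighting mirrors the good_tokens logic in the diff above):

```python
import torch

# Made-up per-rank values for the 100 / 900 unmasked-token split described above
losses_per_rank = torch.tensor([2.0, 1.0])           # mean loss over each rank's own unmasked tokens
good_tokens_per_rank = torch.tensor([100.0, 900.0])  # non -100 tokens per rank

naive = losses_per_rank.mean()  # 1.5 - weights rank0 as heavily as rank1
weighted = (losses_per_rank * good_tokens_per_rank).sum() / good_tokens_per_rank.sum()
# weighted = (2.0 * 100 + 1.0 * 900) / 1000 = 1.1, the true mean over all 1000 unmasked tokens
```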
> So each rank produces the correct local loss if we use num_items_in_batch - but how do you combine the losses of the 2 ranks? A straight average will give a very skewed result, because rank0's loss is computed from 9x fewer non-masked tokens.
The denominator of both losses is num_items_in_batch, and the value of each loss already takes into account the number of non-masked tokens, since we use reduction="sum". So we just sum them to get the final loss. In your first example, num_items_in_batch will be equal to 1000. For rank0 the loss will be (l1 + ... + l100) / 1000, and for rank1 it will be (l1 + ... + l900) / 1000.
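A sketch of the arithmetic being described, assuming num_items_in_batch is the global count of unmasked tokens across the SP ranks (the per-token losses are made up):

```python
import torch

# Made-up per-token losses for the 100 + 900 unmasked tokens of the example above
rank0_token_losses = torch.rand(100)
rank1_token_losses = torch.rand(900)
num_items_in_batch = 1000  # global unmasked-token count, shared by both ranks

# With reduction="sum" and the shared denominator, each rank's loss is already
# weighted by its own unmasked-token count...
rank0_loss = rank0_token_losses.sum() / num_items_in_batch
rank1_loss = rank1_token_losses.sum() / num_items_in_batch

# ...so a plain sum across ranks equals the global mean over all unmasked tokens.
combined = rank0_loss + rank1_loss
global_mean = torch.cat([rank0_token_losses, rank1_token_losses]).mean()
assert torch.allclose(combined, global_mean)
```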
I have a feeling we are talking past each other. I'm talking about a differentiable loss combination across ranks, and I think you're talking about the local rank's loss.
Could you please point me to the code in HF Trainer that performs a differentiable loss combination across multiple ranks? I couldn't find any.
Thanks @SunMarc, I did an initial run with default settings in TRL and these branches, and it all worked nicely (apart from model saving, due to not updating deepspeed I think). I will check the config that uses a bespoke ...
We indeed do not have docs related to that. @kashif, as you added CP support in the Trainer, would you be willing to add some docs around that?
@SunMarc docs on the TRL side: huggingface/trl#4420
I can look at the docs on the accelerate side as well...
Ideally we would have a dedicated doc like you suggested, which could then link into the DeepSpeed docs for the nuances. The key is for the user to quickly understand what's possible, so a single context-parallelism entry-point doc would be very useful to users.
Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
Is there a mismatch between the docs and the code?
@zhangwj618 the doc is wrong... my bad, let me fix it!
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Thanks a lot @stas00 for your work 🤗
Super! Thanks a lot to Marc and Kashif for help with the integration, and to Weijie Zhang for being the first early adopter!
Integrates HF Accelerate's support for ALST/Ulysses sequence parallelism (huggingface/accelerate#3817) into HF Trainer.
TODO: