Integrate MS-AMP Support for FP8 as a separate backend #2232

Merged: 20 commits merged into main from msamp-v2 on Dec 15, 2023

Conversation

@muellerzr (Collaborator) commented Dec 8, 2023

Integrate MS-AMP into the Accelerator (round 2)

What does this add?

This PR introduces an additional backend for FP8 support through MS-AMP, which has been shown to reduce memory usage when training in FP8 precision while maintaining accuracy.
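For context, a minimal sketch of how MS-AMP is used standalone (outside of Accelerate), based on the msamp.initialize entry point in MS-AMP's README; the exact arguments may differ between MS-AMP versions:

# Standalone MS-AMP sketch (not part of this PR): msamp.initialize wraps the
# model and optimizer so supported ops run in FP8 with low-bit optimizer states.
import torch
import msamp

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = msamp.initialize(model, optimizer, opt_level="O2")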

Who is it for?

Individuals training in FP8 (H100s, 4090s, etc.)

Issues linked to

Azure/MS-AMP#128

What parts of the API does this impact?

User-facing:

Two new arguments were added to the FP8RecipeKwargs:

  • backend (str): Whether to use MS-AMP or TE (TransformerEngine). Defaults to MS-AMP.
  • optimization_level (str): One of "O1" or "O2". "O3" is for DeepSpeed; we need to wait for them to update to DeepSpeed v0.9.3 to match what Accelerate supports.

General guideline to optimization levels:

  • O1: Weight gradients and all_reduce communications are done in FP8, reducing GPU
    memory usage and communication bandwidth.
  • O2: First-order optimizer states are in 8-bit and second-order states are in FP16.
    Only available when using Adam or AdamW. This maintains accuracy and potentially saves
    the most memory.
  • O3: Specifically for DeepSpeed; model weights and master weights are stored in FP8.
    If fp8 is selected and DeepSpeed is enabled, this level is used by default.
    (Not currently available.)

As a result, "O2" is the default. Here is an overview of each optimization level and what it does, taken from their docs:

| Optimization Level | Computation (GEMM) | Comm | Weight | Master Weight | Weight Gradient | Optimizer States |
|---|---|---|---|---|---|---|
| FP16 AMP | FP16 | FP32 | FP32 | N/A | FP32 | FP32+FP32 |
| Nvidia TE | FP8 | FP32 | FP32 | N/A | FP32 | FP32+FP32 |
| MS-AMP O1 | FP8 | FP8 | FP16 | N/A | FP8 | FP32+FP32 |
| MS-AMP O2 | FP8 | FP8 | FP16 | N/A | FP8 | FP8+FP16 |
| MS-AMP O3 | FP8 | FP8 | FP8 | FP16 | FP8 | FP8+FP16 |

Basic Usage Example(s):

A user can either do:

accelerator = Accelerator(mixed_precision="fp8")

Or use the FP8RecipeKwargs:

# To use TransformerEngine instead
kwarg_handlers = [FP8RecipeKwargs(backend="TE")]

# To change the optimization level
kwarg_handlers = [FP8RecipeKwargs(optimization_level="O1")]

accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=kwarg_handlers,
)
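
Presumably both settings can also be combined in a single handler. A sketch (note: a later comment in this thread passes the level as opt_level, so the exact keyword name may differ from optimization_level above):

# Sketch: selecting the MS-AMP backend and optimization level together.
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

kwarg_handlers = [FP8RecipeKwargs(backend="msamp", opt_level="O1")]
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwarg_handlers)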

Benchmarks

When running on bloomz-560m I saw a memory decrease of ~1/3.

More experiments need to be conducted on the performance differences between TE and MS-AMP. For instance, when running my sample speed-benchmark script (here) on the first 100 batches, I saw a stark contrast in the ending training loss between TE and MS-AMP:

BF16 (baseline): 2.4867
TE: 11.3125
MS-AMP: 2.89

I also found that overall there wasn't much of a time save with MS-AMP; it actually added time instead (BF16 was ~0.139 s/batch, while MS-AMP was ~0.169 s/batch). I want to run some more tests to verify, but these were my local results.

The performance difference between BF16 and MS-AMP isn't large, but it is a stark contrast when compared to TE. More work is needed to investigate why, so I've opted to make MS-AMP an entirely separate backend rather than combine the two.
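
For reference, a minimal sketch of how per-batch timings like the ones above could be collected (a hypothetical helper, not the exact benchmark script used for these numbers):

# Hypothetical helper for measuring average wall-clock seconds per training batch.
import time
import torch

def time_batches(model, dataloader, optimizer, accelerator, num_batches=100):
    times = []
    for step, batch in enumerate(dataloader):
        if step == num_batches:
            break
        torch.cuda.synchronize()
        start = time.perf_counter()
        loss = model(**batch).loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
    return sum(times) / len(times)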

What went wrong in the last PR

While the training speed results looked very good, I quickly noticed issues with losses that didn't make sense: models simply weren't training or converging, which was surprising. For now I'm taking a more staged approach to the integration while we learn the behavior (of both TE and MS-AMP) through longer training runs.

@muellerzr (Collaborator, Author) commented

TODO: write some doc guides on MS-AMP

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member) left a comment

Thanks for enabling MS-AMP FP8 in accelerate. Overall this looks good, but I had a couple of comments; please have a look.

In addition, I have these more general comments:

  1. IIUC, when using FP8 with TE and the device is detected not to support it, there is an automatic fallback to fp16. Is there a similar mechanism for MS-AMP?
  2. Implementation-wise, the arguments for FP8 with MS-AMP vs. TE are completely disjoint, right? This is a bit unfortunate, as users could e.g. set MS-AMP as the backend, change amax_history_len, and wonder why it has no effect. To be super user-friendly, we would have to add checks and docs so that only arguments valid for the given backend can be changed (see the sketch after this list). A cleaner solution would be a completely separate dataclass for MS-AMP, although that might clash with the accelerate philosophy of abstracting away such implementation details.
  3. I guess there's no way to run CI tests for this :(
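
A hypothetical sketch of the kind of per-backend check suggested in point 2 (field names and defaults are illustrative, not the actual FP8RecipeKwargs definition):

# Hypothetical per-backend argument check; field names are illustrative only.
import warnings
from dataclasses import dataclass
from typing import Optional

@dataclass
class FP8RecipeKwargsSketch:
    backend: str = "MSAMP"
    opt_level: Optional[str] = None         # MS-AMP-only
    amax_history_len: Optional[int] = None  # TE-only
    fp8_format: Optional[str] = None        # TE-only

    def __post_init__(self):
        te_only, msamp_only = ("amax_history_len", "fp8_format"), ("opt_level",)
        ignored = te_only if self.backend.upper() == "MSAMP" else msamp_only
        for name in ignored:
            if getattr(self, name) is not None:
                warnings.warn(f"`{name}` has no effect when backend={self.backend!r}.")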

@muellerzr (Collaborator, Author) commented

@BenjaminBossan re:

A cleaner solution would be to use a completely separate dataclass for MS-AMP, although that might clash with the accelerate philosophy of abstracting away such implementation details.

Eventually we will support a "mixed" backend that combines both, as MS-AMP has support for converting te.Linear layers to MS-AMP ones to get the best of both worlds. I'm not enabling this, though, until tests I run show that combining them does not degrade final accuracy, so for now they are disjoint and part of the same plugin.

@muellerzr (Collaborator, Author) commented

@BenjaminBossan re: 1, we straight up don't allow it. In AcceleratorState we have:

            if mixed_precision == "fp8" and not is_fp8_available():
                raise ValueError("Using `fp8` precision requires `transformer_engine` to be installed.")

@BenjaminBossan (Member) commented

we straight up don't allow it

Oh okay. I was looking at this part of the code:

elif self.mixed_precision == "fp8":
    if not has_transformer_engine_layers(model):
        with torch.no_grad():
            convert_model(model)
        model._converted_to_transformer_engine = True
        model._original_forward = model.forward
    kwargs = self.fp8_recipe_handler.to_kwargs() if self.fp8_recipe_handler is not None else {}
    if "fp8_format" in kwargs:
        kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
    fp8_recipe = te_recipe.DelayedScaling(**kwargs)
    cuda_device_capacity = torch.cuda.get_device_capability()
    fp8_enabled = cuda_device_capacity >= (8, 9)
    if not fp8_enabled:
        logger.warn(
            f"The current device has compute capability of {cuda_device_capacity} which is "
            "insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace "
            "or higher, compute capability of 8.9 or higher). Will use FP16 instead."
        )
    model.forward = fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe)(model.forward)

@muellerzr (Collaborator, Author) commented

We should probably refactor this then as part of is_fp8_available. Thanks for the smell :)
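
A hypothetical sketch of what folding the capability check into the availability helper could look like (not the actual accelerate implementation):

# Hypothetical availability check that also considers device capability.
import importlib.util
import torch

def is_fp8_available_sketch() -> bool:
    has_backend = (
        importlib.util.find_spec("transformer_engine") is not None
        or importlib.util.find_spec("msamp") is not None
    )
    if not has_backend or not torch.cuda.is_available():
        return False
    # FP8 tensor cores require Hopper/Ada Lovelace (compute capability >= 8.9).
    return torch.cuda.get_device_capability() >= (8, 9)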

@BenjaminBossan (Member) left a comment

LGTM, thanks! Unfortunately, I can't test it :(

@SunMarc (Member) left a comment

LGTM! Just a few nits.

@pacman100 (Contributor) left a comment

Thank you @muellerzr for working on the MS-AMP FP8 support! ✨ Overall this looks good wrt the integration and the memory savings of 33% (1/3) for a 560M-param model. However, it would be great to see more experiments at intermediate model scales of ~10B parameters, as the paper claims:

Experiment results show that, during the training of GPT-175B model
on H100 GPU platform, our FP8 mixed-precision training framework not only achieved a
remarkable 42% reduction in real memory usage but also ran 64% faster than the widely
adopted BF16 framework (i.e., Megatron-LM), surpassing the speed of Nvidia Transformer Engine by 17%.

Here, it is odd that we see no savings in time. Maybe it is the case that with reduced memory they fit larger batches, leading to faster training?

@muellerzr (Collaborator, Author) commented

@pacman100 I'd expect that to likely be the case; I'll check FLOPS on the scaled-up training to see what they result in :)

@muellerzr (Collaborator, Author) commented

cc @MKhalusova for docs and then we can merge 🤗

@skyshine102 commented

Thank you @muellerzr for working on the MS-AMP FP8 support! ✨ Overall this looks good wrt the integration and the memory savings of 33% (1/3) for a 560M-param model. However, it would be great to see more experiments at intermediate model scales of ~10B parameters, as the paper claims:

Experiment results show that, during the training of GPT-175B model
on H100 GPU platform, our FP8 mixed-precision training framework not only achieved a
remarkable 42% reduction in real memory usage but also ran 64% faster than the widely
adopted BF16 framework (i.e., Megatron-LM), surpassing the speed of Nvidia Transformer Engine by 17%.

Here, it is odd that we see no savings in time. Maybe it is the case that with reduced memory they fit larger batches, leading to faster training?

Original paper shows that TE>MS-AMP>BF16 under same batch size setting (paper table 5). It is interesting to see MS-AMP slower than BF16.
Thank you all for integrating this feature into accelerate!

@MKhalusova (Contributor) left a comment

Great work on the docs! I left a few suggestions :)

@muellerzr (Collaborator, Author) commented

@skyshine102:

Original paper shows that TE>MS-AMP>BF16 under same batch size setting (paper table 5). It is interesting to see MS-AMP slower than BF16.

I have an inkling model size plays a huge factor in this. With TE I've pretty much always seen it to be slower unless our model size is > 3B

muellerzr and others added 2 commits December 15, 2023 12:43
Co-authored-by: Maria Khalusova <kafooster@gmail.com>
@muellerzr merged commit b052839 into main on Dec 15, 2023 (25 checks passed)
@muellerzr deleted the msamp-v2 branch on December 15, 2023 at 18:07
@skyshine102 commented

@skyshine102:

Original paper shows that TE>MS-AMP>BF16 under same batch size setting (paper table 5). It is interesting to see MS-AMP slower than BF16.

I have an inkling model size plays a huge factor in this. With TE I've pretty much always seen it to be slower unless our model size is > 3B

I see!

@wangpengfei1013 commented Jan 19, 2024

I have tried this work, it's amazing. But I have some questions:
1. TE and MS-AMP didn't show any difference?
2. BF16 is a little slower than the two, and consumed a little more GPU memory compared to FP8.

GPU: L20
Model size: Yuan 2B
code:

# Imports added for completeness; training_args, model_args, data_args, device_map,
# and make_supervised_data_module come from the rest of this script (argument parsing
# and data helpers) and are not shown here.
import torch
import transformers
from torch.utils.data import DataLoader
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

compute_dtype = (
    torch.float16
    if training_args.fp16
    else (torch.bfloat16 if training_args.bf16 else torch.float32)
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir,
    device_map=device_map,
    trust_remote_code=True,
    torch_dtype=compute_dtype,
)

# kwargs = FP8RecipeKwargs(backend="te", fp8_format="HYBRID")
kwargs = FP8RecipeKwargs(backend="msamp", opt_level="O1")

accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs])
# accelerator = Accelerator()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = accelerator.device
model.to(device)

tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir,
    model_max_length=training_args.model_max_length,
    padding_side="right",
    use_fast=False,
    trust_remote_code=True,
)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.add_tokens(
    ['<eod>', '<sep>', '<pad>', '<mask>', '<predict>', '<FIM_SUFFIX>', '<FIM_PREFIX>', '<FIM_MIDDLE>',
     '<commit_before>', '<commit_msg>', '<commit_after>', '<jupyter_start>', '<jupyter_text>',
     '<jupyter_code>', '<jupyter_output>', '<empty_output>'], special_tokens=True)

data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
# With this dataloader, seq_len may differ from batch to batch.
train_dataloader = DataLoader(data_module['train_dataset'], shuffle=True, batch_size=8)
eval_dataloader = DataLoader(data_module['eval_dataset'], batch_size=8)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)  # num of batches * num of epochs
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, num_training_steps, gamma=0.1, last_epoch=-1, verbose=False)

model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

i = 0
for epoch in range(num_epochs):
    for batch in tqdm(train_dataloader):
        # To train on GPU, move the batch tensors to the GPU:
        batch = {k: v.to(device) for k, v in batch.items()}
        loss = model(**batch).loss
        print("step: ", i, "loss: ", loss.item())
        # loss.backward()
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        if (i + 1) % 500 == 0:
            accelerator.save_state("wpf_test_abcd")
        i = i + 1
