
Conversation

@3outeille 3outeille commented Nov 17, 2025

Context

Reference PR: huggingface#1

This PR enables:

  • Llama-like HF models to work with 4D parallelism (FSDP, CP, TP, PP, and combinations thereof). The following models were tested:
    • meta-llama/Llama-3.2-1B
    • microsoft/phi-2
    • Qwen/Qwen2.5-7B
    • mistralai/Mistral-7B-v0.1
    • ByteDance-Seed/Seed-Coder-8B-Instruct
    • Qwen/Qwen3-4B-Instruct-2507
    • arcee-ai/AFM-4.5B
    • ibm-granite/granite-3b-code-base-2k
    • baidu/ERNIE-4.5-0.3B-Base-PT
    • kyutai/helium-1-preview-2b
    • allenai/OLMo-7B-hf
    • mistralai/Ministral-8B-Instruct-2410
  • Patching HF model weight initialisation. Without this, the loss and grad_norm start very high. A minimal sketch of the idea follows this list.
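
Below is a minimal, hypothetical sketch of what such an init patch can look like, assuming torchtitan-style truncated-normal initialisation; the function name and std value are illustrative assumptions, not the PR's actual code.

import torch.nn as nn

def patch_hf_weight_init(model: nn.Module, std: float = 0.02) -> None:
    """Re-initialise Linear/Embedding weights with a truncated normal and zero biases."""
    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.trunc_normal_(module.weight, mean=0.0, std=std)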

Usage

  • Requirements: transformers==4.57.1
  • Config: torchtitan/torchtitan/experiments/transformers_backend/configs/qwen3.toml
...
[model]
- name = "llama3"
+ name = "transformers_backend"
flavor = "debugmodel"
hf_assets_path = "./tests/assets/tokenizer"

+[hf_transformers]
+model = "Qwen/Qwen3-4B-Instruct-2507"
...
  • Train: LOG_RANK=7 CONFIG_FILE=<YOUR_PATH>/torchtitan/experiments/transformers_backend/configs/qwen3.toml ./run_train.sh --job.custom_config_module=torchtitan.experiments.transformers_backend.job_config --compile.enable

Testing methodology

  • Following the converging.md guidelines, I compare the baseline FSDP=2 against FSDP=2 combined with each other parallelism.
  • More precisely, test_hf_integration.py produces the following layout (a sketch of how the diff logs can be generated follows the grid-search script below):
    results/
        |_ meta-llama
            |_ Llama-3.2-1B
                |_ debugmodel/
                    |_ seed_checkpoint/
                        |_ config.toml
                        |_ seed.slurm
                        |_ step-0/
                           |_ ....
                    |_ fsdp2_tp1_cp1_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                    |_ fsdp2_tp2_cp1_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp1_pp2/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp2_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp2_pp2/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log`
                |_ full/
                ...
  • Here is the grid-search script used to test the HF modeling:
#!/usr/bin/bash
# For each model: generate the configs, build a seed checkpoint, wait for it to
# complete, then submit all the parallelism jobs that restore from it.
model_names=(
    "meta-llama/Llama-3.2-1B"
    "microsoft/phi-2"
    "Qwen/Qwen2.5-7B"
    "mistralai/Mistral-7B-v0.1"
    "ByteDance-Seed/Seed-Coder-8B-Instruct"
    "Qwen/Qwen3-4B-Instruct-2507"
    "arcee-ai/AFM-4.5B"
    "ibm-granite/granite-3b-code-base-2k"
    "baidu/ERNIE-4.5-0.3B-Base-PT"
    "kyutai/helium-1-preview-2b"
    "allenai/OLMo-7B-hf"
    "mistralai/Ministral-8B-Instruct-2410"
)

for model_name in "${model_names[@]}"; do
    # Start from a clean results directory for this model.
    rm -rf "slurm_results/${model_name}"

    # Generate config.toml + SLURM files for the seed checkpoint and every parallelism combination.
    python test_hf_integration.py create_configs --model_name "$model_name" --out_dir slurm_results --flavor debugmodel
    # Build the seed checkpoint first; all runs restore from it.
    python test_hf_integration.py submit_jobs --inp_dir "slurm_results/${model_name}/debugmodel/seed_checkpoint" --qos high
    # Block until the seed-checkpoint job reports completion.
    while [ ! -f "slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt" ] || [ "$(cat "slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt")" != "completed" ]; do
        echo "Waiting for seed checkpoint from ${model_name} to complete ..."
        sleep 1
    done
    # Submit the FSDP/TP/CP/PP jobs for this model.
    python test_hf_integration.py submit_jobs --inp_dir "slurm_results/${model_name}/debugmodel" --qos high
    echo "================"
done
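
For reference, here is a rough, hypothetical sketch of how a diff_baseline_vs_nd_parallelism.log could be produced, assuming the training logs print per-step loss and grad_norm; the regex, file paths, and function names below are assumptions, not the actual test_hf_integration.py code.

import re
from pathlib import Path

# Assumed log line shape, e.g. "step: 10  loss: 7.53  grad_norm: 1.24"
METRIC_RE = re.compile(r"step:\s*(\d+).*?loss:\s*([\d.eE+-]+).*?grad_norm:\s*([\d.eE+-]+)")

def parse_metrics(log_path: Path) -> dict[int, tuple[float, float]]:
    """Map step -> (loss, grad_norm) parsed from a training log."""
    metrics: dict[int, tuple[float, float]] = {}
    for line in log_path.read_text().splitlines():
        match = METRIC_RE.search(line)
        if match:
            metrics[int(match.group(1))] = (float(match.group(2)), float(match.group(3)))
    return metrics

def diff_logs(baseline_log: Path, nd_parallelism_log: Path) -> None:
    """Print every step where the baseline and the n-D parallelism run disagree."""
    baseline, candidate = parse_metrics(baseline_log), parse_metrics(nd_parallelism_log)
    for step in sorted(baseline.keys() & candidate.keys()):
        if baseline[step] != candidate[step]:
            print(f"step {step}: baseline={baseline[step]} nd_parallelism={candidate[step]}")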

Further tasks

  • MoE (handled in PR Add transformer backend (MoE) clean, huggingface/torchtitan#3)
    • Missing build_optimizers_with_moe_load_balancing support for MoE
    • Missing TP/PP/EP support for MoE
  • When using the HF modeling in the FSDP=2 vs FSDP=2 + PP=2 test, the loss and grad_norm are not bitwise-matching (though they converge), while they do match with the Torchtitan modeling (tracked in Fix pp convergence to be bitwise, huggingface/torchtitan#4)
  • Add convergence tests to CI using a tiny model + gloo backend (once PP is bitwise-matching)
  • The HF modeling has lower MFU than the Torchtitan modeling
  • NOTE: set torch._dynamo.config.cache_size_limit = 128 to avoid graph recompilation when using torch.compile with activation checkpointing (see the snippet below)
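
A minimal snippet for that note, assuming it runs once at startup before the model is compiled:

import torch._dynamo

# Allow more compiled-graph variants to be cached so torch.compile combined with
# activation checkpointing does not keep triggering recompilation.
torch._dynamo.config.cache_size_limit = 128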

meta-cla bot commented Nov 17, 2025

Hi @3outeille!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@3outeille changed the title from "3outeille/transformers backend" to "3outeille/transformers backend (Dense model only)" on Nov 17, 2025

meta-cla bot commented Nov 17, 2025

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

meta-cla bot added the CLA Signed label on Nov 17, 2025
@wwwjn wwwjn (Contributor) left a comment

Thanks for the great work again; left some comments.

num_layers: int,
input_weight: int = 1,
output_weight: int = 1,
include_rotary_emb: bool = False,
Contributor

This change is not included in https://github.com/huggingface/torchtitan/pull/1/files, can you quickly remind me why we need to include rotary embedding when PP is applied?

Contributor

And in torchtitan models, we make rotary_emb a function, not a module, but it looks like for HF models rotary_emb is a module, and that's why this module needs to be included in PP?

Author

That was to address the issue you mentioned here: huggingface#1 (comment)

Can we modify and add a parameter to the function signature in pytorch/torchtitan@main/torchtitan/distributed/pipeline_parallel.py#L41 instead of keeping 2 copies? I feel it's very easy for them to diverge in the future

Author

And in torchtitan models, we make rotary_emb a function, not a module, but it looks like for HF models rotary_emb is a module, and that's why this module needs to be included in PP?

Yes, exactly!
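
To make the module-vs-function difference concrete, here is an illustrative sketch (hypothetical names, not the PR's exact code) of a PP module-splitting helper with an include_rotary_emb switch:

# In HF modeling the rotary embedding is a shared nn.Module on the base model, so the
# splitting logic needs to keep it on every stage; in torchtitan the rotary frequencies
# come from a plain function/buffer, so nothing extra has to be preserved.
def module_names_for_stage(
    stage_layer_ids: list[int],
    is_first: bool,
    is_last: bool,
    include_rotary_emb: bool = False,
) -> list[str]:
    names = [f"model.layers.{i}" for i in stage_layer_ids]
    if is_first:
        names.append("model.embed_tokens")
    if is_last:
        names.extend(["model.norm", "lm_head"])
    if include_rotary_emb:
        # Every stage runs decoder layers that call the shared rotary module.
        names.append("model.rotary_emb")
    return names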

setattr(model, module_name, None)
# Replace with Identity or None based on configuration
replacement = (
nn.Identity() if use_identity_for_missing_modules else None
Contributor

Could you quickly remind me why we need to use Identity() here?

Contributor

I think it's because HF defines their models without checks like if tok_embeddings is None.

I still worry that such identities break DCP and could be the source of the PP numerics issue. The concrete question is, when loading from seed checkpoint, are all the PP ranks restored perfectly?

cc @fegin if you know this definitively.

Author @3outeille commented Nov 18, 2025

The concrete question is, when loading from seed checkpoint, are all the PP ranks restored perfectly?

Seems like PP ranks are restored perfectly, because we have a perfect match with Qwen but not with Llama, for example (cf. the screenshot at huggingface#4)
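
For context, a rough sketch of the idea discussed in this thread (hypothetical helper name, not the PR's exact implementation): on each PP stage, modules owned by other stages are replaced either by nn.Identity() or by None, depending on what the HF forward pass tolerates.

import torch.nn as nn

def prune_modules_for_stage(
    model: nn.Module,
    unused_module_names: list[str],
    use_identity_for_missing_modules: bool,
) -> None:
    """Drop modules that belong to other PP stages from this stage's model copy."""
    for module_name in unused_module_names:
        # Replace with Identity or None based on configuration.
        replacement = nn.Identity() if use_identity_for_missing_modules else None
        setattr(model, module_name, replacement)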

)


def apply_fsdp(
Contributor

Reading this function, it is the same as the apply_fsdp function in llama4/parallelize (I know we will keep MoE capability for the next PR); can we reuse the apply_fsdp function from llama4 and avoid keeping multiple copies?

Contributor

Oh, I see the difference: the only change is moe_block = transformer_block.mlp on line 337; in transformers models the MoE module is named mlp instead of moe. Can we use the same getter/setter approach to rename it in model.py, so we can reuse the apply_fsdp function from llama4?

I don't have a strong opinion on this, but I'm a little concerned that if we keep several copies, they will diverge easily in the future.

Author @3outeille commented Nov 18, 2025

Valid concern. I'll reuse apply_fsdp from llama3 for now, as this PR handles only dense models. It will make more sense to handle the getter/setter in the MoE PR.
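
As a side note, one possible shape for the rename discussed above (purely illustrative, not part of this PR) is a small accessor that tolerates both namings, so a single apply_fsdp could serve torchtitan's moe and transformers' mlp:

import torch.nn as nn

def get_moe_block(transformer_block: nn.Module) -> nn.Module | None:
    """Return the MoE submodule whether it is named `moe` (torchtitan) or `mlp` (HF)."""
    for name in ("moe", "mlp"):
        module = getattr(transformer_block, name, None)
        if module is not None:
            return module
    return None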

@tianyu-l tianyu-l (Contributor) left a comment

Please address final comments.

setattr(model, module_name, None)
# Replace with Identity or None based on configuration
replacement = (
nn.Identity() if use_identity_for_missing_modules else None

Contributor

It sounds like the changes are caused by the specific way transformers defines its models. Then let's fork the changed functions into experiments/transformers_backend/. I apologize for the back & forth.

Author

But isn't the compromise good enough? Copy-pasting means not noticing changes in pipeline parallelism later on.
