Unable to make the deepspeed zero3 integration work with falcon7b #739

Closed
vwxyzjn opened this issue Sep 5, 2023 · 8 comments

@vwxyzjn (Contributor) commented Sep 5, 2023

accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml examples/scripts/sentiment_tuning.py --batch_size 32 --model_name tiiuae/falcon-7b --mini_batch_size 1 --log_with wandb

Trace

Traceback (most recent call last):
  File "examples/scripts/sentiment_tuning.py", line 154, in <module>
    model = trl_model_class.from_pretrained(
  File "/fsx/costa/trl/trl/models/modeling_base.py", line 199, in from_pretrained
    pretrained_model = cls.transformers_parent_class.from_pretrained(
  File "/admin/home/costa/.pyenv/versions/3.8.11/envs/trl/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py", line 479, in from_pretrained
    return model_class.from_pretrained(
  File "/admin/home/costa/.pyenv/versions/3.8.11/envs/trl/lib/python3.8/site-packages/transformers/modeling_utils.py", line 2881, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/admin/home/costa/.pyenv/versions/3.8.11/envs/trl/lib/python3.8/site-packages/transformers/modeling_utils.py", line 3278, in _load_pretrained_model
    raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
RuntimeError: Error(s) in loading state_dict for RWForCausalLM:
        size mismatch for transformer.h.0.self_attention.query_key_value.weight: copying a param with shape torch.Size([4672, 4544]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for transformer.h.0.self_attention.dense.weight: copying a param with shape torch.Size([4544, 4544]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for transformer.h.0.mlp.dense_h_to_4h.weight: copying a param with shape torch.Size([18176, 4544]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for transformer.h.0.mlp.dense_4h_to_h.weight: copying a param with shape torch.Size([4544, 18176]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for transformer.h.1.self_attention.query_key_value.weight: copying a param with shape torch.Size([4672, 4544]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for transformer.h.1.self_attention.dense.weight: copying a param with shape torch.Size([4544, 4544]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for transformer.h.1.mlp.dense_h_to_4h.weight: copying a param with shape torch.Size([18176, 4544]) from checkpoint, the shape in current model is torch.Size([0]).
...

@lewtun (Member) commented Sep 6, 2023

I think it's an issue with the FalconLinear layer somehow not having the right device set. Looking at the logs, one sees this:

[2023-09-06 09:12:38,823] [WARNING] [partition_parameters.py:836:_post_init_method] param `weight` in FalconLinear not on GPU so was not broadcasted from rank 0
[2023-09-06 09:12:38,828] [WARNING] [partition_parameters.py:836:_post_init_method] param `bias` in FalconLinear not on GPU so was not broadcasted from rank 0

A fix is probably required on the transformers side, but first I need to check if this is just a trl or transformers related issue :)
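
As a quick way to see which modules were skipped by zero.init(), a small helper along these lines could be run right after from_pretrained (a hypothetical sketch, not from the thread; ds_shape is the attribute DeepSpeed attaches to parameters it has partitioned):

import torch

def report_param_placement(model: torch.nn.Module) -> None:
    # Under ZeRO-3 zero.init(), partitioned parameters have an empty local shape
    # and DeepSpeed records the full shape on the `ds_shape` attribute; parameters
    # that zero.init() did not pick up keep their full local shape and may still
    # sit on CPU, matching the warnings above.
    for name, param in model.named_parameters():
        full_shape = getattr(param, "ds_shape", None)  # None if DeepSpeed does not manage this param
        print(f"{name}: device={param.device}, "
              f"local_shape={tuple(param.shape)}, ds_shape={full_shape}")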

@pacman100 (Contributor)

Setting trust_remote_code=True solves this issue:
[Screenshot: 2023-09-06 at 3:44:06 PM]
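
For reference, here is a minimal sketch of passing that flag through TRL's wrapper (assuming extra kwargs are forwarded to the underlying transformers from_pretrained call):

from trl import AutoModelForCausalLMWithValueHead

# Sketch: trust_remote_code=True makes transformers use the modelling code
# shipped with the checkpoint rather than the in-library Falcon implementation.
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "tiiuae/falcon-7b",
    trust_remote_code=True,
)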

@lewtun (Member) commented Sep 6, 2023

OK, I think I've traced it to an issue with ZeRO-3 and two instances of AutoModelForCausalLMWithValueHead (one for the reference model, the other for the active model).

Here's a minimal repro:

# falcon_zero3_bug.py
from trl import AutoModelForCausalLMWithValueHead

print("Loading CLM")
model_clm = AutoModelForCausalLMWithValueHead.from_pretrained("tiiuae/falcon-rw-1b")
print("Loaded CLM!")

print("Loading VH...")
model_vh = AutoModelForCausalLMWithValueHead.from_pretrained("tiiuae/falcon-rw-1b")
print("Loaded VH!")

Command to reproduce error:

accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml falcon_zero3_bug.py

Curiously, there isn't a problem with instantiating two plain AutoModelForCausalLM instances, so there must be an issue with the custom FalconLinear layer in transformers and how deepspeed interacts with it.
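
For contrast, a sketch of the case that reportedly works, i.e. two plain causal-LM instances under the same launch command (file name and prints are illustrative):

# falcon_zero3_ok.py
from transformers import AutoModelForCausalLM

print("Loading first model...")
model_a = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-rw-1b")
print("Loaded first model!")

print("Loading second model...")
model_b = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-rw-1b")
print("Loaded second model!")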

@lewtun (Member) commented Sep 6, 2023

@pacman100 which transformers version are you using? For transformers==4.33.0 I find that trust_remote_code=True still produces the same error.

Here's my complete env:

- `transformers` version: 4.33.0
- Platform: Linux-5.15.0-1023-aws-x86_64-with-glibc2.31
- Python version: 3.10.10
- Huggingface_hub version: 0.16.4
- Safetensors version: 0.3.1
- Accelerate version: 0.22.0
- PyTorch version (GPU?): 2.0.1+cu118 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

@lewtun (Member) commented Sep 6, 2023

Aha, I figured out the root cause: it's coming from zero.init() (link) and the fact that modeling_falcon.py in transformers has a custom FalconLinear layer. Disabling the init in examples/accelerate_configs/deepspeed_zero3.yaml with:

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
- zero3_init_flag: true
+ zero3_init_flag: false
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

and then running the following works:

accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml examples/scripts/sentiment_tuning.py --batch_size 32 --mini_batch_size 32 --log_with wandb --model_name tiiuae/falcon-rw-1b
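
For completeness, the same workaround could presumably be expressed programmatically via accelerate's DeepSpeedPlugin instead of the YAML config (my own sketch, not something used in this thread):

from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# Sketch: the programmatic counterpart of `zero3_init_flag: false` in the YAML.
ds_plugin = DeepSpeedPlugin(
    zero_stage=3,
    zero3_init_flag=False,        # skip deepspeed.zero.Init() during from_pretrained
    zero3_save_16bit_model=True,
    gradient_accumulation_steps=1,
)
accelerator = Accelerator(deepspeed_plugin=ds_plugin)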

I think the solution is to adjust the modeling_falcon.py code in `transformers` to support the custom linear layer; I'll take a look at this.

Comparison plot against the GPT-2 runs:

[Screenshot: 2023-09-06 at 12:42:41]

@pacman100 (Contributor)

Thank you @lewtun for the deep dive. I was also looking into this, and the weird part is that the minimal code you gave works sometimes and fails other times.

@lewtun (Member) commented Sep 8, 2023

To make zero.init() work for Falcon models, my current hunch is that we need to change the Falcon modelling code in transformers to use deepspeed.zero.GatheredParameters like other models do (example). Does this sound correct, @pacman100?
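
For illustration only, the general shape of that pattern looks roughly like this (function name and init values are hypothetical, not the actual transformers code):

import torch
import deepspeed

def init_linear_under_zero3(linear: torch.nn.Linear, std: float = 0.02) -> None:
    # Gather the ZeRO-3 shards so the full tensors exist on rank 0, initialise
    # them there, and let DeepSpeed re-partition/broadcast on context exit.
    with deepspeed.zero.GatheredParameters(list(linear.parameters()), modifier_rank=0):
        if torch.distributed.get_rank() == 0:
            linear.weight.data.normal_(mean=0.0, std=std)
            if linear.bias is not None:
                linear.bias.data.zero_()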

@younesbelkada (Contributor)

Perhaps this should be fixed now with #758, but I'm not sure.

@lvwerra closed this as completed Sep 25, 2023