How do I freeze weights when using FSDP? #807

Closed
2 of 4 tasks
antopost opened this issue Nov 1, 2022 · 3 comments

antopost commented Nov 1, 2022

System Info

- `Accelerate` version: 0.14.0.dev0
- Platform: Linux-5.4.0-128-generic-x86_64-with-glibc2.17
- Python version: 3.8.12
- Numpy version: 1.21.5
- PyTorch version (GPU?): 1.12.1+cu113 (True)
- `Accelerate` default config:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: FSDP
        - mixed_precision: fp16
        - use_cpu: False
        - num_processes: 2
        - machine_rank: 0
        - num_machines: 1
        - gpu_ids: None
        - main_process_ip: None
        - main_process_port: None
        - rdzv_backend: static
        - same_network: True
        - main_training_function: main
        - deepspeed_config: {}
        - fsdp_config: {'fsdp_auto_wrap_policy': 'NO_WRAP', 'fsdp_backward_prefetch_policy': 'NO_PREFETCH', 'fsdp_offload_params': False, 'fsdp_sharding_strategy': 1, 'fsdp_state_dict_type': 'FULL_STATE_DICT'}
        - downcast_bf16: no

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

Running into some issues when freezing weights during multi-GPU training with FSDP.
I've tried preparing my model both before and after freezing the weights, with different but equally disappointing results.

def freeze_layers(model, to_freeze, verbose=True):
    for i, (name, param) in enumerate(model.named_parameters()):
        freeze = '--> freeze' if i in to_freeze else ''
        if verbose:
            print(i, name, ' '*(45 - len(name)-len(str(i))), freeze)    # >>> layer_name   ---> freeze
        if freeze:
            param.requires_grad = False

Preparing before:

from accelerate import Accelerator

accelerator = Accelerator()
device = accelerator.device

# init model
model = load_model(args.model).to(device)
model = accelerator.prepare(model)

# freeze layers
to_freeze = [0, 1, 2, 3]
freeze_layers(model, to_freeze)

freeze_layers prints this to console:

0 _fsdp_wrapped_module.flat_param --> freeze

Preparing after:

from accelerate import Accelerator

accelerator = Accelerator()
device = accelerator.device

# init model
model = load_model(args.model).to(device)

# freeze layers 0 to 3
to_freeze = [0, 1, 2, 3]
freeze_layers(model, to_freeze)

# wait for all processes to freeze the desired layers
accelerator.wait_for_everyone()
model = accelerator.prepare(model)

I get this error:

Traceback (most recent call last):
  File "train.py", line 673, in <module>
    main(config, output_dir, args)
  File "train.py", line 636, in main
    TA = TrainAgent(config, output_dir, args)
  File "train.py", line 112, in __init__
    self.model = self.accelerator.prepare(self.model)
  File "/opt/conda/lib/python3.8/site-packages/accelerate/accelerator.py", line 681, in prepare
    result = tuple(
  File "/opt/conda/lib/python3.8/site-packages/accelerate/accelerator.py", line 682, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/opt/conda/lib/python3.8/site-packages/accelerate/accelerator.py", line 556, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/opt/conda/lib/python3.8/site-packages/accelerate/accelerator.py", line 731, in prepare_model
    model = FSDP(
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 814, in __init__
    self._fsdp_wrapped_module: FlattenParamsWrapper = FlattenParamsWrapper(
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/fsdp/flatten_params_wrapper.py", line 319, in __init__
    params, param_infos, shared_param_infos = self._init_flatten_params()
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/fsdp/flatten_params_wrapper.py", line 370, in _init_flatten_params
    assert (
AssertionError: expects all parameters to have same requires_grad

Any help would be much appreciated :)

Expected behavior

The specified model layers on each process should be set to requires_grad=False.

sgugger (Collaborator) commented Nov 1, 2022

cc @pacman100

pacman100 (Contributor) commented

Hello @antopost, the "NO_WRAP" policy doesn't save any CUDA memory because all the parameters of the entire model are gathered during the forward pass instead of just a few layers at a time. More details here: https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html

FSDP will put the entire model in one FSDP unit, which reduces computation efficiency and memory efficiency. The way it works is that, suppose your model contains 100 Linear layers. If you do FSDP(model), there will only be one FSDP unit which wraps the entire model. In that case, the allgather would collect the full parameters for all 100 linear layers, and hence won't save CUDA memory for parameter sharding. Also, since there is only one blocking allgather call for all 100 linear layers, there is no overlap of communication and computation between layers.

To avoid that, you can pass in an fsdp_auto_wrap_policy, which will seal the current FSDP unit and start a new one automatically when the specified condition is met (e.g., a size limit). That way you will have multiple FSDP units, and only one FSDP unit needs to collect full parameters at a time. E.g., suppose you have 5 FSDP units, and each wraps 20 linear layers. Then, in the forward pass, the 1st FSDP unit will allgather parameters for the first 20 linear layers, do the computation, discard the parameters, and then move on to the next 20 linear layers. So, at any point in time, each rank only materializes parameters/grads for 20 linear layers instead of 100.
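
As a rough sketch of this (assuming the process group has already been initialized, e.g. by launching with accelerate launch or torchrun, and using an arbitrary size threshold), an auto-wrap policy can be passed like this when constructing FSDP directly; with 🤗 Accelerate the same thing is normally selected via accelerate config (e.g. SIZE_BASED_WRAP):

import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Toy model: 100 Linear layers, as in the example above.
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(100)])

# Seal the current FSDP unit and start a new one once it holds ~10k parameters
# (arbitrary threshold for this sketch), so no single unit gathers the whole model.
auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=10_000)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy)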

Next, regarding freezing weights when using FSDP: the parameters within each FSDP unit are flattened into a single tensor, and a unit can span multiple layers.

  1. This is the reason preparing after freezing leads to expects all parameters to have same requires_grad: with NO_WRAP, all layers are part of a single FSDP unit, so all of them are combined and flattened into a single flat parameter that mixes frozen and trainable weights.
  2. Preparing prior to freezing leads to the model params of the single FSDP unit (NO_WRAP) being flattened into only one named parameter, _fsdp_wrapped_module.flat_param, as you noted above. Freezing this would set requires_grad=False on all params.

Freezing certain weights requires manual wrapping of the model, with each frozen layer wrapped into a separate FSDP unit so that all parameters of a given FSDP unit have the same requires_grad. Please go through https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/ for an example of manual wrapping. You can freeze the model layers, then do the manual wrapping into FSDP units, and pass the model to accelerator.prepare. In that case, where the model is manually wrapped, accelerator.prepare would be a no-op, as the model is already an FSDP model.
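
As a rough sketch of what that could look like (backbone and head are hypothetical submodule names for this example; the exact wrapping depends on your architecture, and the process group is assumed to be set up by accelerate launch):

from accelerate import Accelerator
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

accelerator = Accelerator()

# init model; `backbone` and `head` are placeholder submodules for this sketch
model = load_model(args.model).to(accelerator.device)

# 1. freeze the part you don't want to train
for param in model.backbone.parameters():
    param.requires_grad = False

# 2. manually wrap so that every FSDP unit has uniform requires_grad:
#    the frozen backbone is one unit, the trainable head another
model.backbone = FSDP(model.backbone)
model.head = FSDP(model.head)
model = FSDP(model)  # outer unit for any remaining parameters

# 3. prepare() leaves an already-wrapped FSDP model as-is
model = accelerator.prepare(model)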

As NO_WRAP doesn't save any CUDA memory, you might as well use standard PyTorch DDP, in which freezing weights is straightforward. Also, this doesn't concern the integration of FSDP in 🤗 Accelerate; please open an issue with PyTorch for more assistance on this.
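
For reference, a minimal sketch of the DDP route (distributed_type: MULTI_GPU in accelerate config), reusing the freeze_layers helper from the reproduction above and passing only the trainable parameters to the optimizer (hyperparameters are placeholders):

import torch
from accelerate import Accelerator

accelerator = Accelerator()

model = load_model(args.model)
freeze_layers(model, to_freeze=[0, 1, 2, 3])  # sets requires_grad=False on the frozen params

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)

# DDP has no flat parameters, so mixed requires_grad across layers is fine
model, optimizer = accelerator.prepare(model, optimizer)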

antopost commented Nov 2, 2022

Thanks for the clarification!
