
feat(train): Support gradient checkpointing for Conformer & Transformer (whisper) #2173

Merged: 11 commits merged into main on Dec 1, 2023

Conversation

@xingchensong (Member) commented Nov 27, 2023

@robin1001 (Collaborator)

This reduces memory during training, at the cost of slower training speed?

@xingchensong (Member Author)

This reduces memory during training, at the cost of slower training speed?

yes, "fit 10x larger neural nets into memory at the cost of an additional 20% computation time"

@xingchensong (Member Author)

I think it's possible to fine-tune whisper-large-v3 (full parameters) on 24 * 2080ti with DeepSpeed + gradient checkpointing + Flash Attention.

@xingchensong (Member Author)


  1. find_unused_parameters=False is required by activation checkpointing in DDP.
  2. find_unused_parameters=False is also faster, since DDP then skips the extra pass over the autograd graph that searches for unused parameters.
  3. find_unused_parameters=True is required when ctc_weight == 0.0 or 1.0.

Overall, we expose a flag for users and default it to False (see the sketch below).
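
Roughly, the flag ends up on the DDP constructor. A sketch of the idea (the configs key name here is an assumption, not necessarily wenet's exact one):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_ddp(model, configs):
    # Default False: required once activation checkpointing is on, and it also
    # avoids DDP's extra traversal that searches for unused parameters.
    find_unused = configs.get('find_unused_parameters', False)
    return DDP(model.cuda(), find_unused_parameters=find_unused)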

@xingchensong changed the title from "feat(train): Support gradient checkpointing for Conformer & Transformer" to "feat(train): Support gradient checkpointing for Conformer & Transformer (whisper)" on Nov 28, 2023
@xingchensong force-pushed the xcsong-grad-ckpt branch 3 times, most recently from 79e8f2f to 0bbff4a on November 29, 2023 03:55
@xingchensong (Member Author) commented Nov 29, 2023

Benchmark on whisper-large-v3


On 4 * 3090 without gradient checkpointing, training runs into OOM.


8 * 2080ti works .... amazing

Config of DeepSpeed stage-3:

{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 4,
  "steps_per_print": 100,
  "gradient_clipping": 5,
  "fp16": {
    "enabled": false,
    "auto_cast": false,
    "loss_scale": 0,
    "initial_scale_power": 16,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "consecutive_hysteresis": false,
    "min_loss_scale": 1
  },
  "bf16": {
   "enabled": false
  },
  "zero_force_ds_cpu_optimizer": false,
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "none",
      "pin_memory": true
    },
    "offload_param": {
      "device": "none",
      "pin_memory": true
    },
    "allgather_partitions": true,
    "allgather_bucket_size": 1e3,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 1e3,
    "contiguous_gradients" : true,
    "stage3_max_live_parameters": 2e3,
    "stage3_max_reuse_distance": 2e3,
    "stage3_prefetch_bucket_size": 1e3,
    "stage3_param_persistence_threshold": 1e3
  },
  "activation_checkpointing": {
    "partition_activations": false,
    "cpu_checkpointing": false,
    "contiguous_memory_optimization": false,
    "number_checkpoints": null,
    "synchronize_checkpoint_boundary": false,
    "profile": false
  },
  "flops_profiler": {
    "enabled": false,
    "profile_step": 100,
    "module_depth": -1,
    "top_modules": 1,
    "detailed": true,
    "output_file": null
  },
  "tensorboard": {
    "enabled": false,
    "output_path": "tensorboard/ds_logs/",
    "job_name": "deepspeed"
  }
}
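
For reference, a config like this is normally handed to deepspeed.initialize and the script is started with the deepspeed launcher. A sketch (the model, optimizer, and file name below are assumptions for illustration, not taken from this PR):

import deepspeed
import torch

model = torch.nn.Linear(1280, 1280)          # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Launch with e.g. `deepspeed train.py` so distributed init happens.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config="ds_stage3.json")     # the JSON above saved to disk (assumed name)

# Training steps then go through the engine:
#   loss = model_engine(batch)
#   model_engine.backward(loss)
#   model_engine.step()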

@Mddct self-requested a review on November 29, 2023 05:38
@xingchensong (Member Author) commented Nov 29, 2023

Benchmark on Conformer

  • conformer.wenet.2.2.1: ddp
  • conformer.latest.373109c.grad_ckpt: ddp + grad_ckpt
  • conformer.latest.373109c.grad_ckpt.ds: deepspeed stage1 + grad_ckpt

[plot: cv_loss curves for the three configurations]

  1. cv_loss is almost identical to wenet v2.2.1.
  2. Training with grad_ckpt takes 15%~20% more computation time than v2.2.1 (under the same train_config.yaml).

@Mddct (Collaborator) commented Nov 29, 2023

fp32?

@xingchensong (Member Author)

fp32?

yes !!!

@Mddct (Collaborator) commented Nov 29, 2023

amazing again!!!

@xingchensong (Member Author)

TBD: should gradient_checkpointing be configured in the yaml or in the train args? It is currently configured in the yaml rather than in args, because so far only transformer/conformer support grad_ckpt and other models do not. If a user of, say, squeezeformer passed --gradient_checkpointing, it would silently have no effect (since squeezeformer's forward does not implement it).
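
If it stays in the yaml, reading the flag could look roughly like this (the key layout below is an assumption for illustration, not necessarily the exact schema this PR uses):

import yaml

# Assumed yaml layout:
#   encoder_conf:
#     gradient_checkpointing: true
with open('train_config.yaml', 'r') as fin:
    configs = yaml.safe_load(fin)

grad_ckpt = configs['encoder_conf'].get('gradient_checkpointing', False)
# The flag is handed to the encoder constructor, so models whose forward
# does not implement checkpointing simply never see it.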

@xingchensong (Member Author)

Passed the unit test (gradients are equal), ready for a final review.
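
A sketch of that kind of gradient-equality check (a generic PyTorch layer here, not wenet's actual unit test):

import torch
from torch.utils.checkpoint import checkpoint


def test_grad_ckpt_matches_plain_backward():
    torch.manual_seed(0)
    layer = torch.nn.Linear(8, 8)
    x = torch.randn(4, 8)

    # Plain forward/backward.
    layer(x).sum().backward()
    ref_grad = layer.weight.grad.clone()

    # Checkpointed forward/backward.
    layer.zero_grad()
    checkpoint(layer, x, use_reentrant=False).sum().backward()
    assert torch.allclose(ref_grad, layer.weight.grad)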

@xingchensong (Member Author) commented Nov 30, 2023

Wait, I just found that find_unused_param=False actually works in conjunction with ctc_weight=0.0. Perhaps we could simply remove the argument and keep find_unused_param=False as our default configuration.

@Mddct (Collaborator) commented Nov 30, 2023

Wait, I just found that find_unused_param=False actually works in conjunction with ctc_weight=0.0. Perhaps we could simply remove the argument and keep find_unused_param=False as our default configuration.

Model-related options might be better placed in the yaml, and we'd rather not affect other modules.

btw: training-related options belong in args.

@xingchensong (Member Author) commented Nov 30, 2023

So there are two options here:

  1. find_unused_param: add it or not? If we add it, where should it go?
  2. gradient_checkpointing: add it or not? If we add it, where should it go?

@Mddct (Collaborator) commented Nov 30, 2023

Put both in the yaml.

@xingchensong (Member Author) commented Nov 30, 2023

Known issue: cannot pass JIT scripting due to the checkpoint function:

  traceback : Traceback (most recent call last):
    File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
      return f(*args, **kwargs)
    File "/jfs-hdfs/user/xingchen.song/workspace/github/tmp/wenet.latest.grad_ckpt/wenet/examples/aishell/s0/wenet/bin/train.py", line 95, in main
      trace_and_print_model(args, model)
    File "/jfs-hdfs/user/xingchen.song/workspace/github/tmp/wenet.latest.grad_ckpt/wenet/wenet/utils/train_utils.py", line 372, in trace_and_print_model
      script_model = torch.jit.script(model)
    File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_script.py", line 1286, in script
      return torch.jit._recursive.create_script_module(
    File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_recursive.py", line 476, in create_script_module
      return create_script_module_impl(nn_module, concrete_type, stubs_fn)
    File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_recursive.py", line 538, in create_script_module_impl
      script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
    File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_script.py", line 615, in _construct
      init_fn(script_module)
    File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_recursive.py", line 516, in init_fn
      scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
    File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
      create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_recursive.py", line 393, in create_methods_and_properties_from_stubs
      concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
    File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_recursive.py", line 863, in try_compile_fn
      return torch.jit.script(fn, _rcb=rcb)
    File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_script.py", line 1340, in script
      ast = get_jit_def(obj, obj.__name__)
    File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/frontend.py", line 293, in get_jit_def
      return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
    File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/frontend.py", line 331, in build_def
      param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
    File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/frontend.py", line 355, in build_param_list
      raise NotSupportedError(ctx_range, _vararg_kwarg_err)
  torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
    File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 164
  def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
                                                               ~~~~~~~ <--- HERE
      r"""Checkpoint a model or part of the model

@BenBill2077

What is gradient ckpt?

@xingchensong (Member Author) commented Dec 1, 2023

What is gradient ckpt?

One-sentence summary: full-parameter fine-tuning of whisper-large (1.5B) on 2080ti cards is no longer a dream 😜 (see the blog link at the top for the technical details).

@xingchensong (Member Author)

JIT passes now! Please review this PR, @robin1001 @Mddct

@Mddct merged commit 8d6d23f into main on Dec 1, 2023
6 checks passed
@Mddct deleted the xcsong-grad-ckpt branch on December 1, 2023 07:25
@Mddct mentioned this pull request on Mar 27, 2024