feat(train): Support gradient checkpointing for Conformer & Transformer (whisper) #2173
Conversation
Does this reduce memory during training at the cost of slower training speed?
Yes: "fit 10x larger neural nets into memory at the cost of an additional 20% computation time".
I think it's possible to fine-tune whisper-large-v3 (full parameters) on 24 * 2080ti with DeepSpeed + gradient checkpointing + Flash Attention.
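For reference, gradient checkpointing drops intermediate activations during the forward pass and recomputes them during backward, trading a modest amount of extra compute for a large memory saving. A minimal PyTorch sketch of the idea (toy module, not this PR's actual code):

import torch
import torch.utils.checkpoint as ckpt

class TinyEncoder(torch.nn.Module):
    def __init__(self, num_layers: int = 4, dim: int = 256, grad_ckpt: bool = False):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [torch.nn.TransformerEncoderLayer(dim, nhead=4, dropout=0.0, batch_first=True)
             for _ in range(num_layers)])
        self.grad_ckpt = grad_ckpt

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.grad_ckpt and self.training:
                # activations of this layer are not stored; they are
                # recomputed during the backward pass
                xs = ckpt.checkpoint(layer, xs, use_reentrant=False)
            else:
                xs = layer(xs)
        return xs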
Force-pushed from 3b318ec to fbb0557.
Overall, we expose a flag for users and default it to
Force-pushed from 79e8f2f to 0bbff4a.
benchmark on whisper-large-v3 (4 * 3090)
8 * 2080ti works .... amazing
config of deepspeed stage-3:
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 4,
"steps_per_print": 100,
"gradient_clipping": 5,
"fp16": {
"enabled": false,
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"consecutive_hysteresis": false,
"min_loss_scale": 1
},
"bf16": {
"enabled": false
},
"zero_force_ds_cpu_optimizer": false,
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"offload_param": {
"device": "none",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 1e3,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 1e3,
"contiguous_gradients" : true,
"stage3_max_live_parameters": 2e3,
"stage3_max_reuse_distance": 2e3,
"stage3_prefetch_bucket_size": 1e3,
"stage3_param_persistence_threshold": 1e3
},
"activation_checkpointing": {
"partition_activations": false,
"cpu_checkpointing": false,
"contiguous_memory_optimization": false,
"number_checkpoints": null,
"synchronize_checkpoint_boundary": false,
"profile": false
},
"flops_profiler": {
"enabled": false,
"profile_step": 100,
"module_depth": -1,
"top_modules": 1,
"detailed": true,
"output_file": null
},
"tensorboard": {
"enabled": false,
"output_path": "tensorboard/ds_logs/",
"job_name": "deepspeed"
}
}
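For context, a JSON config like the one above is typically handed to deepspeed.initialize; a minimal sketch (the toy model, optimizer, and file name are placeholders, not the wenet training entry point):

import deepspeed
import torch

# toy model standing in for the real network; the JSON above is assumed
# to be saved as ds_config.json
model = torch.nn.Linear(80, 512)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config="ds_config.json")

# in the training loop, engine.backward(loss) and engine.step() replace
# the usual loss.backward() and optimizer.step()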
benchmark on conformer
fp32?
yes !!!
amazing again!!!
TBD:
Unit test passes (gradients are equal), ready for a final review.
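A self-contained sketch of what such a gradient-equality check can look like (toy module, not the actual unit test in this PR):

import copy
import torch
import torch.utils.checkpoint as ckpt

class Block(torch.nn.Module):
    def __init__(self, dim: int = 64, grad_ckpt: bool = False):
        super().__init__()
        self.lin1 = torch.nn.Linear(dim, dim)
        self.lin2 = torch.nn.Linear(dim, dim)
        self.grad_ckpt = grad_ckpt

    def _ff(self, xs: torch.Tensor) -> torch.Tensor:
        return self.lin2(torch.relu(self.lin1(xs)))

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        if self.grad_ckpt and self.training:
            return ckpt.checkpoint(self._ff, xs, use_reentrant=False)
        return self._ff(xs)

def test_grad_ckpt_gradients_match():
    torch.manual_seed(0)
    plain = Block(grad_ckpt=False)
    ckpted = copy.deepcopy(plain)
    ckpted.grad_ckpt = True
    xs = torch.randn(2, 10, 64)
    plain.train()
    ckpted.train()
    plain(xs).sum().backward()
    ckpted(xs).sum().backward()
    # recomputation must not change the gradients
    for p1, p2 in zip(plain.parameters(), ckpted.parameters()):
        assert torch.allclose(p1.grad, p2.grad)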
wait, I just found that
Model-related options are probably better placed in the yaml, and I don't want to affect other modules. BTW: train-related options go into the args.
So the two parameters here should both go into the yaml.
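A sketch of what keeping the two options in the model yaml could look like on the Python side (key names are hypothetical, not necessarily the ones this PR finally uses):

import yaml

# hypothetical train.yaml fragment:
#   encoder_conf:
#     gradient_checkpointing: true
#   decoder_conf:
#     gradient_checkpointing: true
with open("train.yaml", "r") as fin:
    configs = yaml.safe_load(fin)

enc_ckpt = configs["encoder_conf"].get("gradient_checkpointing", False)
dec_ckpt = configs["decoder_conf"].get("gradient_checkpointing", False)
# the flags are then forwarded to the encoder/decoder constructors,
# so nothing train-related (args) needs to change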
known issue: cannot pass JIT scripting. Traceback:
Traceback (most recent call last):
File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
return f(*args, **kwargs)
File "/jfs-hdfs/user/xingchen.song/workspace/github/tmp/wenet.latest.grad_ckpt/wenet/examples/aishell/s0/wenet/bin/train.py", line 95, in main
trace_and_print_model(args, model)
File "/jfs-hdfs/user/xingchen.song/workspace/github/tmp/wenet.latest.grad_ckpt/wenet/wenet/utils/train_utils.py", line 372, in trace_and_print_model
script_model = torch.jit.script(model)
File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_script.py", line 1286, in script
return torch.jit._recursive.create_script_module(
File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_recursive.py", line 476, in create_script_module
return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_recursive.py", line 538, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_script.py", line 615, in _construct
init_fn(script_module)
File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_recursive.py", line 516, in init_fn
scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_recursive.py", line 393, in create_methods_and_properties_from_stubs
concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_recursive.py", line 863, in try_compile_fn
return torch.jit.script(fn, _rcb=rcb)
File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/_script.py", line 1340, in script
ast = get_jit_def(obj, obj.__name__)
File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/frontend.py", line 293, in get_jit_def
return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/frontend.py", line 331, in build_def
param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/jit/frontend.py", line 355, in build_param_list
raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
File "/bucket/output/jfs-hdfs/user/xingchen.song/tools/miniconda3/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 164
def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
~~~~~~~ <--- HERE
r"""Checkpoint a model or part of the model
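The failure comes from torch.utils.checkpoint.checkpoint taking *args plus a keyword-only default, which TorchScript refuses to compile. One common workaround (a sketch under that assumption, not necessarily the exact fix in this PR) is to move the checkpointed loop into a method marked @torch.jit.unused, so scripting only ever sees the plain path:

import torch
import torch.utils.checkpoint as ckpt

class FeedForward(torch.nn.Module):
    def __init__(self, dim: int = 256):
        super().__init__()
        self.lin1 = torch.nn.Linear(dim, dim)
        self.lin2 = torch.nn.Linear(dim, dim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        return self.lin2(torch.relu(self.lin1(xs)))

class Encoder(torch.nn.Module):
    def __init__(self, num_layers: int = 2, grad_ckpt: bool = False):
        super().__init__()
        self.encoders = torch.nn.ModuleList(
            [FeedForward() for _ in range(num_layers)])
        self.grad_ckpt = grad_ckpt

    def forward_layers(self, xs: torch.Tensor) -> torch.Tensor:
        for layer in self.encoders:
            xs = layer(xs)
        return xs

    @torch.jit.unused
    def forward_layers_checkpointed(self, xs: torch.Tensor) -> torch.Tensor:
        # never compiled by TorchScript, so the *args-based checkpoint
        # call is allowed here
        for layer in self.encoders:
            xs = ckpt.checkpoint(layer, xs, use_reentrant=False)
        return xs

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        if self.grad_ckpt and self.training:
            xs = self.forward_layers_checkpointed(xs)
        else:
            xs = self.forward_layers(xs)
        return xs

# torch.jit.script(Encoder()) succeeds: the @torch.jit.unused method is
# replaced by a stub that raises only if it is ever called from the
# scripted module, and export/inference never takes that branch.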
What is gradient ckpt?
One-sentence summary: full-parameter fine-tuning of whisper-large (1.5B) on 2080ti is no longer a dream 😜 (see the blog link at the beginning for the technical details).
Force-pushed from 2be317c to 0ecbd37.
JIT now passes! Please review this PR, @robin1001 @Mddct
Two brief intros on gradient checkpointing:
TODO