From 0bbff4a9b389b99949ccd0b429442607e62f89ed Mon Sep 17 00:00:00 2001
From: xingchensong
Date: Tue, 28 Nov 2023 18:34:18 +0800
Subject: [PATCH] feat(train): set grad_ckpt=True for whisper

---
 wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py
index 2dced14fd..f32ddb85e 100644
--- a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py
+++ b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py
@@ -57,6 +57,7 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str):
 
     configs['encoder'] = 'transformer'
     configs['encoder_conf'] = {}
+    configs['encoder_conf']['gradient_checkpointing'] = True
     configs['encoder_conf']['input_layer'] = 'conv1d2'
     configs['encoder_conf']['output_size'] = dims['n_audio_state']
     configs['encoder_conf']['attention_heads'] = dims['n_audio_head']
@@ -75,6 +76,7 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str):
 
     configs['decoder'] = 'transformer'
     configs['decoder_conf'] = {}
+    configs['decoder_conf']['gradient_checkpointing'] = True
     configs['decoder_conf']['attention_heads'] = dims['n_text_head']
     configs['decoder_conf']['linear_units'] = dims['n_text_state'] * 4
     configs['decoder_conf']['num_blocks'] = dims['n_text_layer']