Skip to content

Commit

Permalink
feat(train): set grad_ckpt=True for whisper
Browse files Browse the repository at this point in the history
  • Loading branch information
xingchensong committed Nov 29, 2023
1 parent 4d35b5a commit 79e8f2f
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str):

configs['encoder'] = 'transformer'
configs['encoder_conf'] = {}
configs['encoder_conf']['gradient_checkpointing'] = True
configs['encoder_conf']['input_layer'] = 'conv1d2'
configs['encoder_conf']['output_size'] = dims['n_audio_state']
configs['encoder_conf']['attention_heads'] = dims['n_audio_head']
Expand All @@ -75,6 +76,7 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str):

configs['decoder'] = 'transformer'
configs['decoder_conf'] = {}
configs['decoder_conf']['gradient_checkpointing'] = True
configs['decoder_conf']['attention_heads'] = dims['n_text_head']
configs['decoder_conf']['linear_units'] = dims['n_text_state'] * 4
configs['decoder_conf']['num_blocks'] = dims['n_text_layer']
Expand Down

0 comments on commit 79e8f2f

Please sign in to comment.