diff --git a/run_flax_speech_recognition_ctc.py b/run_flax_speech_recognition_ctc.py
index c53525c..a3dcd66 100644
--- a/run_flax_speech_recognition_ctc.py
+++ b/run_flax_speech_recognition_ctc.py
@@ -106,12 +106,30 @@ class ModelArguments:
     freeze_feature_encoder: bool = field(
         default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
     )
+    activation_dropout: float = field(
+        default=0.1,
+        metadata={
+            "help": "The dropout probability for activations inside the fully connected layers."
+        },
+    )
     hidden_dropout: float = field(
         default=0.1,
         metadata={
             "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
         },
     )
+    feat_proj_dropout: float = field(
+        default=0.0,
+        metadata={
+            "help": "The dropout probability for the output of the feature projection."
+        },
+    )
+    mask_time_prob: float = field(
+        default=0.1,
+        metadata={
+            "help": "The SpecAugment masking probability: the chance of each feature vector along the time axis being chosen as the start of a masked span."
+        },
+    )
 
 
 @flax.struct.dataclass
@@ -835,7 +853,10 @@ def main():
     config.update(
         {
             "gradient_checkpointing": training_args.gradient_checkpointing,
+            "activation_dropout": model_args.activation_dropout,
             "hidden_dropout": model_args.hidden_dropout,
+            "feat_proj_dropout": model_args.feat_proj_dropout,
+            "mask_time_prob": model_args.mask_time_prob,
             "vocab_size": tokenizer.vocab_size,
         }
     )
diff --git a/run_flax_speech_recognition_ctc_ngram.py b/run_flax_speech_recognition_ctc_ngram.py
index 6f757de..8cf0c7a 100644
--- a/run_flax_speech_recognition_ctc_ngram.py
+++ b/run_flax_speech_recognition_ctc_ngram.py
@@ -110,12 +110,30 @@ class ModelArguments:
     freeze_feature_encoder: bool = field(
         default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
     )
+    activation_dropout: float = field(
+        default=0.1,
+        metadata={
+            "help": "The dropout probability for activations inside the fully connected layers."
+        },
+    )
     hidden_dropout: float = field(
         default=0.1,
         metadata={
             "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
         },
     )
+    feat_proj_dropout: float = field(
+        default=0.0,
+        metadata={
+            "help": "The dropout probability for the output of the feature projection."
+        },
+    )
+    mask_time_prob: float = field(
+        default=0.1,
+        metadata={
+            "help": "The SpecAugment masking probability: the chance of each feature vector along the time axis being chosen as the start of a masked span."
+        },
+    )
 
 
 @flax.struct.dataclass
@@ -849,7 +867,10 @@ def main():
     config.update(
         {
             "gradient_checkpointing": training_args.gradient_checkpointing,
+            "activation_dropout": model_args.activation_dropout,
             "hidden_dropout": model_args.hidden_dropout,
+            "feat_proj_dropout": model_args.feat_proj_dropout,
+            "mask_time_prob": model_args.mask_time_prob,
             "vocab_size": tokenizer.vocab_size,
         }
     )
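For context: these new `ModelArguments` fields become command-line flags automatically, assuming the scripts parse them with `HfArgumentParser` as the upstream Transformers Flax examples do (the parsing code is not part of this diff, so that is an assumption). Below is a minimal sketch with a hypothetical, trimmed-down `ModelArguments` that keeps only the three fields added by this patch:

```python
# Minimal sketch, assuming HfArgumentParser is used as in the Transformers
# examples; this trimmed ModelArguments is hypothetical and keeps only the
# fields introduced by this patch, with the defaults from the diff.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    activation_dropout: float = field(default=0.1)
    feat_proj_dropout: float = field(default=0.0)
    mask_time_prob: float = field(default=0.1)


parser = HfArgumentParser(ModelArguments)
# Equivalent to passing the flags on the command line, e.g.
#   --activation_dropout 0.055 --feat_proj_dropout 0.094 --mask_time_prob 0.082
(model_args,) = parser.parse_args_into_dataclasses(
    args=["--activation_dropout", "0.055", "--feat_proj_dropout", "0.094", "--mask_time_prob", "0.082"]
)
print(model_args.mask_time_prob)  # 0.082; forwarded to config.update(...) in main()
```

Each parsed value is then written into the model config by the `config.update` call in the second hunk of each file, so these flags override the dropout and SpecAugment settings baked into the pretrained checkpoint.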