From 24c3676a8d4c2e60d2726e9bcd9bdbed740610e0 Mon Sep 17 00:00:00 2001
From: kamo-naoyuki
Date: Wed, 18 May 2022 16:16:53 +0900
Subject: [PATCH] Apply isort

---
 egs/arctic/tts1/local/clean_text.py | 1 -
 egs/chime6/asr1/local/extract_noises.py | 5 +-
 egs/chime6/asr1/local/make_noise_list.py | 1 -
 egs/cmu_indic/tts1/local/clean_text.py | 1 -
 egs/covost2/st1/local/process_tsv.py | 2 +-
 egs/csj/asr1/local/csj_rm_tag.py | 3 +-
 egs/iwslt16/mt1/local/extract_recog_text.py | 2 +-
 egs/iwslt16/mt1/local/generate_json.py | 6 +-
 egs/iwslt16/mt1/local/generate_vocab.py | 2 +-
 egs/iwslt18/st1/local/parse_xml.py | 2 +-
 egs/iwslt21/asr1/local/filter_parentheses.py | 1 +
 .../st1/local/data_prep.py | 1 -
 egs/jnas/asr1/local/filter_text.py | 3 +-
 .../asr1/local/get_space_normalized_hyps.py | 2 +-
 .../asr1/local/get_transcriptions.py | 3 +-
 egs/libri_css/asr1/local/best_wer_matching.py | 3 +-
 .../asr1/local/get_perspeaker_output.py | 2 +-
 egs/libri_css/asr1/local/prepare_data.py | 1 +
 .../local/segmentation/apply_webrtcvad.py | 1 +
 egs/ljspeech/tts1/local/clean_text.py | 2 +-
 egs/lrs/avsr1/local/se_batch.py | 5 +-
 egs/mgb2/asr1/local/process_xml.py | 3 +-
 egs/mgb2/asr1/local/text_segmenting.py | 1 +
 .../asr1/local/data_prep.py | 5 +-
 .../asr1/local/construct_dataset.py | 1 -
 egs/puebla_nahuatl/asr1/local/data_prep.py | 2 -
 egs/puebla_nahuatl/st1/local/data_prep.py | 2 +-
 egs/reverb/asr1/local/filterjson.py | 3 +-
 egs/reverb/asr1/local/run_wpe.py | 7 +--
 egs/reverb/asr1_multich/local/filterjson.py | 2 +-
 egs/tweb/tts1/local/clean_text.py | 1 -
 egs/vais1000/tts1/local/clean_text.py | 1 -
 .../tts1_en_fi/local/clean_text_css10.py | 11 ++--
 .../vc1_task1/local/clean_text_asr_result.py | 2 +-
 .../vc1_task2/local/clean_text_finnish.py | 15 ++---
 .../vc1_task2/local/clean_text_german.py | 3 +-
 .../vc1_task2/local/clean_text_mandarin.py | 7 +--
 egs/vcc20/voc1/local/subset_data_dir.py | 2 +-
 egs/voxforge/asr1/local/filter_text.py | 3 +-
 egs/wsj/asr1/local/filtering_samples.py | 5 +-
 egs/wsj_mix/asr1/local/merge_scp2json.py | 2 +-
 egs/wsj_mix/asr1/local/mergejson.py | 1 -
 .../asr1/local/data_prep.py | 5 +-
 .../asr1/pyscripts/audio/format_wav_scp.py | 6 +-
 .../pyscripts/utils/convert_text_to_phn.py | 4 +-
 .../asr1/pyscripts/utils/evaluate_f0.py | 6 +-
 .../asr1/pyscripts/utils/evaluate_mcd.py | 6 +-
 .../asr1/pyscripts/utils/extract_xvectors.py | 10 ++--
 .../asr1/pyscripts/utils/plot_sinc_filters.py | 5 +-
 .../asr1/pyscripts/utils/rotate_logfile.py | 2 +-
 .../asr1/pyscripts/utils/score_intent.py | 3 +-
 .../pyscripts/utils/score_summarization.py | 9 ++-
 .../diar1/pyscripts/utils/convert_rttm.py | 13 ++--
 .../diar1/pyscripts/utils/make_rttm.py | 6 +-
 egs2/TEMPLATE/ssl1/pyscripts/dump_km_label.py | 9 ++-
 .../TEMPLATE/ssl1/pyscripts/feature_loader.py | 1 -
 egs2/TEMPLATE/ssl1/pyscripts/sklearn_km.py | 12 ++--
 .../asr1/local/remove_missing.py | 1 -
 egs2/aishell3/tts1/local/data_prep.py | 1 +
 .../local/prepare_audioset_category_list.py | 2 +-
 egs2/aishell4/enh1/local/split_train_dev.py | 10 ++--
 .../enh1/local/split_train_dev_by_column.py | 7 +--
 .../enh1/local/split_train_dev_by_prefix.py | 7 +--
 egs2/bn_openslr53/asr1/local/data_prep.py | 1 -
 egs2/bur_openslr80/asr1/local/data_prep.py | 1 -
 egs2/catslu/asr1/local/data_prep.py | 4 +-
 egs2/chime4/asr1/local/sym_channel.py | 2 +-
 egs2/clarity21/enh1/local/prep_data.py | 1 -
 .../enh1/local/prepare_dev_data.py | 2 +-
 .../dirha_wsj/asr1/local/prepare_dirha_wsj.py | 4 +-
 egs2/dsing/asr1/local/data_prep.py | 8 +--
 egs2/fsc/asr1/local/data_prep.py | 1 +
egs2/fsc_challenge/asr1/local/data_prep.py | 3 +- egs2/fsc_unseen/asr1/local/data_prep.py | 3 +- egs2/grabo/asr1/local/data_prep.py | 5 +- egs2/grabo/asr1/local/score.py | 2 +- egs2/indic_speech/tts1/local/data_prep.py | 1 - .../asr1/local/prepare_alffa_data.py | 2 +- .../asr1/local/prepare_iwslt_data.py | 1 + egs2/iwslt22_dialect/asr1/local/preprocess.py | 6 +- egs2/iwslt22_dialect/st1/local/preprocess.py | 6 +- egs2/jdcinal/asr1/local/score.py | 1 + egs2/jkac/tts1/local/prep_segments.py | 3 +- egs2/jmd/tts1/local/clean_text.py | 1 - egs2/jtubespeech/tts1/local/prune.py | 7 ++- egs2/jtubespeech/tts1/local/split.py | 7 ++- egs2/jv_openslr35/asr1/local/data_prep.py | 1 - .../asr1/local/get_space_normalized_hyps.py | 2 +- .../asr1/local/get_transcriptions.py | 3 +- .../diar1/local/prepare_diarization.py | 2 +- .../local/feature_extract/cvtransforms.py | 1 + .../feature_extract/extract_visual_feature.py | 4 +- .../feature_extract/models/pretrained.py | 3 +- .../local/feature_extract/video_processing.py | 6 +- egs2/lrs3/asr1/local/data_prep.py | 7 ++- egs2/mediaspeech/asr1/local/data_prep.py | 9 ++- egs2/microsoft_speech/asr1/local/process.py | 6 +- .../diar1/local/simulation/make_mixture.py | 5 +- .../simulation/make_mixture_nooverlap.py | 5 +- .../diar1/local/simulation/random_mixture.py | 7 ++- .../simulation/random_mixture_nooverlap.py | 7 ++- egs2/misp2021/asr1/local/find_wav.py | 6 +- egs2/misp2021/asr1/local/prepare_far_data.py | 8 +-- egs2/misp2021/asr1/local/run_beamformit.py | 2 +- egs2/misp2021/asr1/local/run_wpe.py | 9 +-- .../avsr1/local/concatenate_feature.py | 5 +- egs2/misp2021/avsr1/local/find_wav.py | 6 +- egs2/misp2021/avsr1/local/prepare_far_data.py | 8 +-- .../avsr1/local/prepare_far_video_roi.py | 11 ++-- .../prepare_visual_embedding_extractor.py | 5 +- egs2/misp2021/avsr1/local/run_beamformit.py | 2 +- egs2/misp2021/avsr1/local/run_wpe.py | 9 +-- egs2/ml_openslr63/asr1/local/data_prep.py | 1 - egs2/mr_openslr64/asr1/local/data_prep.py | 1 - egs2/ms_indic_18/asr1/local/prepare_data.py | 2 +- egs2/open_li52/asr1/local/filter_text.py | 3 +- .../asr1/local/data_prep.py | 9 ++- egs2/seame/asr1/local/preprocess.py | 8 +-- egs2/seame/asr1/local/split_lang_trn.py | 12 ++-- egs2/sinhala/asr1/local/data_prep.py | 8 +-- .../asr1/local/data_prep_slue.py | 1 + egs2/slue-voxceleb/asr1/local/f1_score.py | 6 +- .../asr1/local/generate_asr_files.py | 3 +- .../local/data_prep_original_slue_format.py | 3 +- ...ta_prep_original_slue_format_transcript.py | 3 +- egs2/slue-voxpopuli/asr1/local/eval_utils.py | 5 +- egs2/slue-voxpopuli/asr1/local/score.py | 4 +- egs2/slurp/asr1/local/prepare_slurp_data.py | 4 +- .../asr1/local/convert_to_entity_file.py | 4 +- .../asr1/local/evaluation/evaluate.py | 5 +- .../asr1/local/evaluation/metrics/__init__.py | 3 +- .../asr1/local/evaluation/metrics/distance.py | 3 +- .../asr1/local/evaluation/util.py | 5 +- .../asr1/local/prepare_slurp_data.py | 4 +- .../asr1/local/prepare_slurp_entity_data.py | 4 +- egs2/snips/asr1/local/data_prep.py | 2 +- .../speechcommands/asr1/local/data_prep_12.py | 8 +-- .../speechcommands/asr1/local/data_prep_35.py | 4 +- egs2/speechcommands/asr1/local/score.py | 2 +- .../asr1/local/sunda_data_prep.py | 1 - .../asr1/local/prepare_sentiment.py | 4 +- egs2/swbd_sentiment/asr1/local/score_f1.py | 3 +- egs2/totonac/asr1/local/data_prep.py | 5 +- egs2/wenetspeech/asr1/local/extract_meta.py | 4 +- egs2/wenetspeech/asr1/local/process_opus.py | 5 +- .../asr1/local/filter_text.py | 3 +- egs2/zh_openslr38/asr1/local/data_split.py | 4 +- 
espnet/asr/asr_utils.py | 1 - espnet/asr/chainer_backend/asr.py | 33 ++++------- espnet/asr/pytorch_backend/asr.py | 39 +++++------- espnet/asr/pytorch_backend/asr_init.py | 6 +- espnet/asr/pytorch_backend/asr_mix.py | 37 +++++------- espnet/asr/pytorch_backend/recog.py | 10 ++-- espnet/bin/asr_align.py | 16 +++-- espnet/bin/asr_enhance.py | 4 +- espnet/bin/asr_recog.py | 2 +- espnet/bin/mt_trans.py | 2 +- espnet/bin/tts_decode.py | 3 +- espnet/bin/vc_decode.py | 3 +- espnet/lm/chainer_backend/extlm.py | 1 + espnet/lm/chainer_backend/lm.py | 27 +++------ espnet/lm/lm_utils.py | 10 ++-- espnet/lm/pytorch_backend/lm.py | 34 ++++------- espnet/mt/pytorch_backend/mt.py | 32 ++++------ espnet/nets/batch_beam_search.py | 9 +-- espnet/nets/batch_beam_search_online.py | 22 ++++--- espnet/nets/batch_beam_search_online_sim.py | 3 +- espnet/nets/beam_search.py | 13 ++-- espnet/nets/beam_search_transducer.py | 17 ++---- espnet/nets/chainer_backend/ctc.py | 2 +- .../chainer_backend/deterministic_embed_id.py | 11 +--- espnet/nets/chainer_backend/e2e_asr.py | 5 +- .../chainer_backend/e2e_asr_transformer.py | 27 ++++----- espnet/nets/chainer_backend/rnn/attentions.py | 1 - espnet/nets/chainer_backend/rnn/decoders.py | 6 +- espnet/nets/chainer_backend/rnn/encoders.py | 3 +- espnet/nets/chainer_backend/rnn/training.py | 15 ++--- .../chainer_backend/transformer/attention.py | 2 - .../chainer_backend/transformer/decoder.py | 7 +-- .../transformer/decoder_layer.py | 9 ++- .../chainer_backend/transformer/embedding.py | 1 - .../chainer_backend/transformer/encoder.py | 14 ++--- .../transformer/encoder_layer.py | 9 ++- .../transformer/label_smoothing_loss.py | 1 - .../transformer/positionwise_feed_forward.py | 2 - .../transformer/subsampling.py | 10 ++-- .../chainer_backend/transformer/training.py | 9 ++- espnet/nets/ctc_prefix_score.py | 3 +- espnet/nets/e2e_asr_common.py | 2 +- .../pytorch_backend/conformer/argument.py | 2 +- .../contextual_block_encoder_layer.py | 3 +- .../nets/pytorch_backend/conformer/encoder.py | 38 ++++++------ .../conformer/encoder_layer.py | 1 - espnet/nets/pytorch_backend/ctc.py | 2 +- espnet/nets/pytorch_backend/e2e_asr.py | 30 +++++----- .../nets/pytorch_backend/e2e_asr_conformer.py | 11 ++-- .../nets/pytorch_backend/e2e_asr_maskctc.py | 16 ++--- espnet/nets/pytorch_backend/e2e_asr_mix.py | 28 ++++----- .../e2e_asr_mix_transformer.py | 10 ++-- espnet/nets/pytorch_backend/e2e_asr_mulenc.py | 13 ++-- .../pytorch_backend/e2e_asr_transducer.py | 58 ++++++++++-------- .../pytorch_backend/e2e_asr_transformer.py | 42 +++++++------ espnet/nets/pytorch_backend/e2e_mt.py | 18 +++--- .../pytorch_backend/e2e_mt_transformer.py | 26 ++++---- espnet/nets/pytorch_backend/e2e_st.py | 28 ++++----- .../nets/pytorch_backend/e2e_st_conformer.py | 11 ++-- .../pytorch_backend/e2e_st_transformer.py | 30 +++++----- .../pytorch_backend/e2e_tts_fastspeech.py | 30 +++++----- .../nets/pytorch_backend/e2e_tts_tacotron2.py | 8 +-- .../pytorch_backend/e2e_tts_transformer.py | 27 +++++---- .../nets/pytorch_backend/e2e_vc_tacotron2.py | 17 +++--- .../pytorch_backend/e2e_vc_transformer.py | 26 ++++---- .../frontends/dnn_beamformer.py | 9 ++- .../nets/pytorch_backend/frontends/dnn_wpe.py | 2 +- .../frontends/feature_transform.py | 4 +- .../pytorch_backend/frontends/frontend.py | 5 +- .../frontends/mask_estimator.py | 3 +- espnet/nets/pytorch_backend/lm/default.py | 6 +- espnet/nets/pytorch_backend/lm/transformer.py | 9 ++- espnet/nets/pytorch_backend/rnn/attentions.py | 5 +- espnet/nets/pytorch_backend/rnn/decoders.py | 
16 ++--- espnet/nets/pytorch_backend/rnn/encoders.py | 8 +-- espnet/nets/pytorch_backend/tacotron2/cbhg.py | 4 +- .../nets/pytorch_backend/tacotron2/decoder.py | 1 - .../nets/pytorch_backend/tacotron2/encoder.py | 5 +- .../pytorch_backend/transducer/arguments.py | 2 +- .../nets/pytorch_backend/transducer/blocks.py | 50 +++++++--------- .../pytorch_backend/transducer/conv1d_nets.py | 4 +- .../transducer/custom_decoder.py | 18 ++---- .../transducer/custom_encoder.py | 8 +-- .../transducer/error_calculator.py | 4 +- .../pytorch_backend/transducer/initializer.py | 2 +- .../pytorch_backend/transducer/rnn_decoder.py | 12 +--- .../pytorch_backend/transducer/rnn_encoder.py | 11 +--- .../transducer/transducer_tasks.py | 10 +--- .../transducer/transformer_decoder_layer.py | 8 +-- .../nets/pytorch_backend/transducer/utils.py | 10 +--- .../nets/pytorch_backend/transducer/vgg2l.py | 3 +- .../contextual_block_encoder_layer.py | 1 - .../pytorch_backend/transformer/decoder.py | 28 +++++---- .../transformer/dynamic_conv.py | 3 +- .../transformer/dynamic_conv2d.py | 3 +- .../pytorch_backend/transformer/embedding.py | 1 + .../pytorch_backend/transformer/encoder.py | 33 ++++++----- .../transformer/encoder_layer.py | 1 - .../transformer/encoder_mix.py | 9 ++- .../pytorch_backend/transformer/lightconv.py | 3 +- .../transformer/lightconv2d.py | 3 +- .../transformer/longformer_attention.py | 3 +- .../nets/pytorch_backend/transformer/plot.py | 2 +- .../transformer/subsampling.py | 3 +- .../transformer/subsampling_without_posenc.py | 1 + espnet/nets/pytorch_backend/wavenet.py | 1 - espnet/nets/scorer_interface.py | 6 +- espnet/nets/scorers/ctc.py | 3 +- espnet/nets/scorers/length_bonus.py | 4 +- espnet/nets/scorers/ngram.py | 4 +- espnet/nets/transducer_decoder_interface.py | 7 +-- espnet/nets/tts_interface.py | 1 - espnet/optimizer/chainer.py | 4 +- espnet/optimizer/pytorch.py | 4 +- espnet/st/pytorch_backend/st.py | 38 +++++------- espnet/transform/transformation.py | 7 +-- espnet/tts/pytorch_backend/tts.py | 22 +++---- espnet/utils/cli_utils.py | 2 +- espnet/utils/io_utils.py | 2 +- espnet/utils/training/iterators.py | 8 +-- espnet/utils/training/train_utils.py | 3 +- espnet/vc/pytorch_backend/vc.py | 22 +++---- espnet2/asr/decoder/abs_decoder.py | 3 +- espnet2/asr/decoder/mlm_decoder.py | 13 ++-- espnet2/asr/decoder/rnn_decoder.py | 5 +- espnet2/asr/decoder/transformer_decoder.py | 30 +++++----- espnet2/asr/encoder/abs_encoder.py | 6 +- espnet2/asr/encoder/conformer_encoder.py | 59 +++++++++---------- .../contextual_block_conformer_encoder.py | 47 +++++++-------- .../contextual_block_transformer_encoder.py | 38 ++++++------ espnet2/asr/encoder/hubert_encoder.py | 18 +++--- espnet2/asr/encoder/longformer_encoder.py | 42 ++++++------- espnet2/asr/encoder/rnn_encoder.py | 9 +-- espnet2/asr/encoder/transformer_encoder.py | 32 +++++----- espnet2/asr/encoder/vgg_rnn_encoder.py | 6 +- espnet2/asr/encoder/wav2vec2_encoder.py | 7 +-- espnet2/asr/espnet_model.py | 21 +++---- espnet2/asr/frontend/abs_frontend.py | 3 +- espnet2/asr/frontend/default.py | 6 +- espnet2/asr/frontend/fused.py | 10 ++-- espnet2/asr/frontend/s3prl.py | 10 ++-- espnet2/asr/frontend/windowing.py | 6 +- espnet2/asr/maskctc_model.py | 23 +++----- espnet2/asr/postencoder/abs_postencoder.py | 3 +- .../hugging_face_transformers_postencoder.py | 11 ++-- espnet2/asr/preencoder/abs_preencoder.py | 3 +- espnet2/asr/preencoder/linear.py | 5 +- espnet2/asr/preencoder/sinc.py | 11 ++-- espnet2/asr/specaug/abs_specaug.py | 3 +- espnet2/asr/specaug/specaug.py | 
8 +-- .../asr/transducer/beam_search_transducer.py | 16 ++--- espnet2/asr/transducer/error_calculator.py | 3 +- espnet2/asr/transducer/transducer_decoder.py | 11 +--- espnet2/bin/aggregate_stats_dirs.py | 5 +- espnet2/bin/asr_align.py | 28 ++++----- espnet2/bin/asr_inference.py | 44 ++++++-------- espnet2/bin/asr_inference_k2.py | 19 ++---- espnet2/bin/asr_inference_maskctc.py | 23 +++----- espnet2/bin/asr_inference_streaming.py | 52 ++++++++-------- espnet2/bin/diar_inference.py | 18 ++---- espnet2/bin/enh_inference.py | 20 ++----- espnet2/bin/enh_scoring.py | 10 ++-- espnet2/bin/launch.py | 5 +- espnet2/bin/lm_calc_perplexity.py | 15 ++--- espnet2/bin/mt_inference.py | 30 ++++------ espnet2/bin/split_scps.py | 7 +-- espnet2/bin/st_inference.py | 30 ++++------ espnet2/bin/st_inference_streaming.py | 50 +++++++--------- espnet2/bin/tokenize_text.py | 12 ++-- espnet2/bin/tts_inference.py | 20 ++----- espnet2/diar/abs_diar.py | 3 +- espnet2/diar/attractor/abs_attractor.py | 3 +- espnet2/diar/decoder/abs_decoder.py | 3 +- espnet2/diar/espnet_model.py | 8 +-- espnet2/enh/abs_enh.py | 3 +- espnet2/enh/decoder/abs_decoder.py | 3 +- espnet2/enh/decoder/stft_decoder.py | 2 +- espnet2/enh/encoder/abs_encoder.py | 3 +- espnet2/enh/encoder/stft_encoder.py | 2 +- espnet2/enh/espnet_enh_s2t_model.py | 9 +-- espnet2/enh/espnet_model.py | 9 +-- espnet2/enh/layers/beamformer.py | 21 ++----- espnet2/enh/layers/complex_utils.py | 7 +-- espnet2/enh/layers/dc_crn.py | 4 +- espnet2/enh/layers/dnn_beamformer.py | 41 ++++++------- espnet2/enh/layers/dnn_wpe.py | 8 +-- espnet2/enh/layers/dprnn.py | 3 +- espnet2/enh/layers/ifasnet.py | 3 +- espnet2/enh/layers/mask_estimator.py | 11 ++-- espnet2/enh/layers/skim.py | 4 +- espnet2/enh/layers/wpe.py | 10 +--- espnet2/enh/loss/criterions/abs_loss.py | 4 +- espnet2/enh/loss/criterions/tf_domain.py | 13 ++-- espnet2/enh/loss/criterions/time_domain.py | 3 +- espnet2/enh/loss/wrappers/abs_wrapper.py | 7 +-- espnet2/enh/separator/abs_separator.py | 7 +-- espnet2/enh/separator/asteroid_models.py | 6 +- espnet2/enh/separator/conformer_separator.py | 16 ++--- espnet2/enh/separator/dan_separator.py | 8 +-- espnet2/enh/separator/dc_crn_separator.py | 12 +--- espnet2/enh/separator/dccrn_separator.py | 16 ++--- espnet2/enh/separator/dpcl_e2e_separator.py | 8 +-- espnet2/enh/separator/dpcl_separator.py | 8 +-- espnet2/enh/separator/dprnn_separator.py | 13 +--- espnet2/enh/separator/fasnet_separator.py | 8 +-- espnet2/enh/separator/neural_beamformer.py | 6 +- espnet2/enh/separator/rnn_separator.py | 11 +--- espnet2/enh/separator/skim_separator.py | 6 +- espnet2/enh/separator/svoice_separator.py | 10 +--- espnet2/enh/separator/tcn_separator.py | 9 +-- .../enh/separator/transformer_separator.py | 25 +++----- espnet2/fileio/datadir_writer.py | 5 +- espnet2/fileio/read_text.py | 4 +- espnet2/fileio/rttm.py | 7 +-- espnet2/fst/lm_rescore.py | 5 +- espnet2/gan_tts/abs_gan_tts.py | 7 +-- espnet2/gan_tts/espnet_model.py | 7 +-- espnet2/gan_tts/hifigan/__init__.py | 14 +++-- espnet2/gan_tts/hifigan/hifigan.py | 6 +- espnet2/gan_tts/hifigan/loss.py | 5 +- espnet2/gan_tts/hifigan/residual_block.py | 4 +- espnet2/gan_tts/jets/alignments.py | 1 - espnet2/gan_tts/jets/generator.py | 35 +++++------ espnet2/gan_tts/jets/jets.py | 27 ++++----- espnet2/gan_tts/jets/loss.py | 7 +-- espnet2/gan_tts/joint/joint_text2wav.py | 39 ++++++------ espnet2/gan_tts/melgan/melgan.py | 5 +- espnet2/gan_tts/melgan/pqmf.py | 1 - espnet2/gan_tts/melgan/residual_stack.py | 3 +- 
espnet2/gan_tts/parallel_wavegan/__init__.py | 8 +-- .../parallel_wavegan/parallel_wavegan.py | 10 +--- espnet2/gan_tts/parallel_wavegan/upsample.py | 5 +- espnet2/gan_tts/style_melgan/__init__.py | 6 +- espnet2/gan_tts/style_melgan/style_melgan.py | 6 +- espnet2/gan_tts/utils/__init__.py | 3 +- espnet2/gan_tts/vits/duration_predictor.py | 9 +-- espnet2/gan_tts/vits/flow.py | 8 +-- espnet2/gan_tts/vits/generator.py | 7 +-- .../gan_tts/vits/monotonic_align/__init__.py | 4 +- espnet2/gan_tts/vits/monotonic_align/setup.py | 7 +-- espnet2/gan_tts/vits/posterior_encoder.py | 7 +-- espnet2/gan_tts/vits/residual_coupling.py | 4 +- espnet2/gan_tts/vits/text_encoder.py | 1 - espnet2/gan_tts/vits/transform.py | 4 +- espnet2/gan_tts/vits/vits.py | 23 ++++---- espnet2/gan_tts/wavenet/residual_block.py | 4 +- espnet2/gan_tts/wavenet/wavenet.py | 4 +- espnet2/hubert/espnet_model.py | 11 +--- espnet2/hubert/hubert_loss.py | 2 +- espnet2/iterators/abs_iter_factory.py | 3 +- espnet2/iterators/chunk_iter_factory.py | 8 +-- espnet2/iterators/multiple_iter_factory.py | 4 +- espnet2/iterators/sequence_iter_factory.py | 4 +- espnet2/layers/abs_normalize.py | 3 +- espnet2/layers/global_mvn.py | 5 +- espnet2/layers/inversible_interface.py | 3 +- espnet2/layers/label_aggregation.py | 4 +- espnet2/layers/log_mel.py | 3 +- espnet2/layers/mask_along_axis.py | 4 +- espnet2/layers/sinc_conv.py | 3 +- espnet2/layers/stft.py | 12 ++-- espnet2/layers/utterance_mvn.py | 2 +- espnet2/lm/abs_model.py | 3 +- espnet2/lm/espnet_model.py | 6 +- espnet2/lm/seq_rnn_lm.py | 3 +- espnet2/lm/transformer_lm.py | 9 ++- espnet2/main_funcs/average_nbest_models.py | 7 +-- .../main_funcs/calculate_all_attentions.py | 26 +++----- espnet2/main_funcs/collect_stats.py | 8 +-- espnet2/main_funcs/pack_funcs.py | 12 ++-- espnet2/mt/espnet_model.py | 21 +++---- espnet2/mt/frontend/embedding.py | 9 ++- espnet2/samplers/abs_sampler.py | 6 +- espnet2/samplers/build_batch_sampler.py | 9 +-- espnet2/samplers/folded_batch_sampler.py | 9 +-- espnet2/samplers/length_batch_sampler.py | 5 +- .../samplers/num_elements_batch_sampler.py | 5 +- espnet2/samplers/sorted_batch_sampler.py | 3 +- espnet2/samplers/unsorted_batch_sampler.py | 3 +- espnet2/schedulers/abs_scheduler.py | 3 +- espnet2/schedulers/noam_lr.py | 2 +- espnet2/st/espnet_model.py | 23 +++----- espnet2/tasks/abs_task.py | 54 ++++++----------- espnet2/tasks/asr.py | 53 ++++++----------- espnet2/tasks/diar.py | 14 +---- espnet2/tasks/enh.py | 31 ++++------ espnet2/tasks/enh_s2t.py | 32 ++++------ espnet2/tasks/gan_tts.py | 19 ++---- espnet2/tasks/hubert.py | 21 ++----- espnet2/tasks/lm.py | 14 +---- espnet2/tasks/mt.py | 41 +++++-------- espnet2/tasks/st.py | 47 ++++++--------- espnet2/tasks/tts.py | 19 ++---- espnet2/text/abs_tokenizer.py | 6 +- espnet2/text/build_tokenizer.py | 3 +- espnet2/text/char_tokenizer.py | 6 +- espnet2/text/cleaner.py | 2 +- espnet2/text/phoneme_tokenizer.py | 23 +++----- espnet2/text/sentencepiece_tokenizer.py | 4 +- espnet2/text/token_id_converter.py | 5 +- espnet2/text/word_tokenizer.py | 6 +- espnet2/torch_utils/initialize.py | 1 + espnet2/torch_utils/load_pretrained_model.py | 6 +- espnet2/train/abs_espnet_model.py | 6 +- espnet2/train/abs_gan_espnet_model.py | 6 +- espnet2/train/class_choices.py | 7 +-- espnet2/train/collate_fn.py | 9 +-- espnet2/train/dataset.py | 21 ++----- espnet2/train/gan_trainer.py | 21 ++----- espnet2/train/iterable_dataset.py | 7 +-- espnet2/train/preprocessor.py | 12 +--- espnet2/train/reporter.py | 20 ++----- espnet2/train/trainer.py | 33 
++++------- espnet2/tts/abs_tts.py | 6 +- espnet2/tts/espnet_model.py | 7 +-- espnet2/tts/fastspeech/fastspeech.py | 39 +++++------- espnet2/tts/fastspeech2/fastspeech2.py | 34 +++++------ espnet2/tts/fastspeech2/loss.py | 6 +- espnet2/tts/fastspeech2/variance_predictor.py | 1 - .../tts/feats_extract/abs_feats_extract.py | 7 +-- espnet2/tts/feats_extract/dio.py | 9 +-- espnet2/tts/feats_extract/energy.py | 8 +-- .../tts/feats_extract/linear_spectrogram.py | 5 +- espnet2/tts/feats_extract/log_mel_fbank.py | 6 +- espnet2/tts/feats_extract/log_spectrogram.py | 5 +- espnet2/tts/gst/style_encoder.py | 7 +-- espnet2/tts/tacotron2/tacotron2.py | 22 +++---- espnet2/tts/transformer/transformer.py | 35 ++++++----- espnet2/tts/utils/__init__.py | 5 +- .../parallel_wavegan_pretrained_vocoder.py | 7 +-- espnet2/utils/griffin_lim.py | 5 +- espnet2/utils/types.py | 4 +- setup.py | 5 +- test/espnet2/asr/decoder/test_rnn_decoder.py | 2 +- .../asr/decoder/test_transformer_decoder.py | 21 +++---- ...st_contextual_block_transformer_encoder.py | 5 +- .../asr/encoder/test_longformer_encoder.py | 3 +- test/espnet2/asr/frontend/test_fused.py | 2 +- test/espnet2/asr/frontend/test_s3prl.py | 3 +- ...t_hugging_face_transformers_postencoder.py | 5 +- test/espnet2/asr/preencoder/test_linear.py | 3 +- test/espnet2/asr/preencoder/test_sinc.py | 4 +- test/espnet2/asr/test_maskctc_model.py | 3 +- test/espnet2/bin/test_aggregate_stats_dirs.py | 3 +- test/espnet2/bin/test_asr_align.py | 8 +-- test/espnet2/bin/test_asr_inference.py | 8 +-- test/espnet2/bin/test_asr_inference_k2.py | 3 +- .../espnet2/bin/test_asr_inference_maskctc.py | 8 +-- test/espnet2/bin/test_asr_train.py | 3 +- test/espnet2/bin/test_diar_inference.py | 4 +- test/espnet2/bin/test_diar_train.py | 3 +- test/espnet2/bin/test_enh_inference.py | 6 +- test/espnet2/bin/test_enh_s2t_train.py | 3 +- test/espnet2/bin/test_enh_scoring.py | 3 +- test/espnet2/bin/test_enh_train.py | 3 +- test/espnet2/bin/test_hubert_train.py | 3 +- test/espnet2/bin/test_lm_calc_perplexity.py | 3 +- test/espnet2/bin/test_lm_train.py | 3 +- test/espnet2/bin/test_pack.py | 3 +- test/espnet2/bin/test_st_inference.py | 8 +-- test/espnet2/bin/test_st_train.py | 3 +- test/espnet2/bin/test_tokenize_text.py | 3 +- test/espnet2/bin/test_tts_inference.py | 6 +- test/espnet2/bin/test_tts_train.py | 3 +- test/espnet2/enh/decoder/test_stft_decoder.py | 1 - test/espnet2/enh/layers/test_complex_utils.py | 15 ++--- test/espnet2/enh/layers/test_conv_utils.py | 4 +- test/espnet2/enh/layers/test_enh_layers.py | 10 ++-- .../enh/loss/criterions/test_tf_domain.py | 14 ++--- .../enh/loss/criterions/test_time_domain.py | 10 ++-- .../wrappers/test_multilayer_pit_solver.py | 1 - .../enh/loss/wrappers/test_pit_solver.py | 4 +- test/espnet2/enh/separator/test_beamformer.py | 3 +- .../enh/separator/test_conformer_separator.py | 1 - .../enh/separator/test_dan_separator.py | 1 - .../enh/separator/test_dc_crn_separator.py | 4 +- .../enh/separator/test_dccrn_separator.py | 3 +- .../enh/separator/test_dpcl_e2e_separator.py | 1 - .../enh/separator/test_dpcl_separator.py | 1 - .../enh/separator/test_dprnn_separator.py | 1 - .../enh/separator/test_fasnet_separator.py | 1 - .../enh/separator/test_rnn_separator.py | 1 - .../enh/separator/test_skim_separator.py | 1 - .../enh/separator/test_svoice_separator.py | 1 - .../enh/separator/test_tcn_separator.py | 1 - .../separator/test_transformer_separator.py | 1 - test/espnet2/enh/test_espnet_enh_s2t_model.py | 1 - test/espnet2/enh/test_espnet_model.py | 8 +-- 
test/espnet2/fileio/test_npy_scp.py | 6 +- test/espnet2/fileio/test_read_text.py | 3 +- test/espnet2/gan_tts/hifigan/test_hifigan.py | 12 ++-- .../gan_tts/joint/test_joint_text2wav.py | 3 +- test/espnet2/gan_tts/melgan/test_melgan.py | 10 ++-- .../parallel_wavegan/test_parallel_wavegan.py | 13 ++-- .../gan_tts/style_melgan/test_style_melgan.py | 11 ++-- test/espnet2/hubert/test_hubert_loss.py | 5 +- .../iterators/test_chunk_iter_factory.py | 4 +- test/espnet2/layers/test_sinc_filters.py | 6 +- test/espnet2/lm/test_seq_rnn_lm.py | 2 +- test/espnet2/lm/test_transformer_lm.py | 2 +- .../test_calculate_all_attentions.py | 8 ++- test/espnet2/main_funcs/test_pack_funcs.py | 7 +-- .../text/test_sentencepiece_tokenizer.py | 2 +- test/espnet2/torch_utils/test_device_funcs.py | 3 +- test/espnet2/train/test_collate_fn.py | 3 +- test/espnet2/train/test_distributed_utils.py | 7 +-- test/espnet2/train/test_reporter.py | 7 +-- .../tts/feats_extract/test_log_mel_fbank.py | 2 +- .../tts/feats_extract/test_log_spectrogram.py | 2 +- test/espnet2/utils/test_build_dataclass.py | 2 +- test/espnet2/utils/test_sized_dict.py | 3 +- test/espnet2/utils/test_types.py | 13 ++-- test/test_asr_init.py | 5 +- test/test_batch_beam_search.py | 9 +-- test/test_custom_transducer.py | 8 +-- test/test_e2e_asr.py | 5 +- test/test_e2e_asr_conformer.py | 1 + test/test_e2e_asr_maskctc.py | 1 + test/test_e2e_asr_mulenc.py | 2 +- test/test_e2e_asr_transducer.py | 8 +-- test/test_e2e_asr_transformer.py | 7 ++- test/test_e2e_compatibility.py | 6 +- test/test_e2e_mt.py | 2 +- test/test_e2e_mt_transformer.py | 1 + test/test_e2e_st.py | 2 +- test/test_e2e_st_conformer.py | 1 + test/test_e2e_st_transformer.py | 1 + test/test_e2e_tts_fastspeech.py | 12 ++-- test/test_e2e_tts_tacotron2.py | 7 +-- test/test_e2e_tts_transformer.py | 8 +-- test/test_e2e_vc_tacotron2.py | 7 +-- test/test_e2e_vc_transformer.py | 8 +-- test/test_lm.py | 7 +-- test/test_multi_spkrs.py | 4 +- test/test_ngram.py | 4 +- test/test_positional_encoding.py | 6 +- test/test_recog.py | 2 +- test/test_scheduler.py | 8 +-- test/test_sentencepiece.py | 1 - test/test_transformer_decode.py | 1 - test/test_utils.py | 6 +- utils/addjson.py | 1 - utils/apply-cmvn.py | 5 +- utils/calculate_rtf.py | 3 +- utils/compute-cmvn-stats.py | 3 +- utils/compute-fbank-feats.py | 4 +- utils/compute-stft-feats.py | 4 +- utils/convert_fbank_to_wav.py | 4 +- utils/copy-feats.py | 5 +- utils/dump-pcm.py | 2 +- utils/eval-source-separation.py | 10 ++-- utils/eval_perm_free_error.py | 2 +- utils/feat-to-shape.py | 3 +- utils/feats2npy.py | 7 ++- utils/generate_wav_from_fbank.py | 6 +- utils/json2sctm.py | 5 +- utils/make_pair_json.py | 2 +- utils/mcd_calculate.py | 5 +- utils/merge_scp2json.py | 4 +- utils/spm_train | 1 - utils/text2vocabulary.py | 3 +- 608 files changed, 2047 insertions(+), 3066 deletions(-) diff --git a/egs/arctic/tts1/local/clean_text.py b/egs/arctic/tts1/local/clean_text.py index 7b14f47a61a..6fd5ce649e0 100755 --- a/egs/arctic/tts1/local/clean_text.py +++ b/egs/arctic/tts1/local/clean_text.py @@ -8,7 +8,6 @@ from tacotron_cleaner.cleaners import custom_english_cleaners - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("text", type=str, help="text to be cleaned") diff --git a/egs/chime6/asr1/local/extract_noises.py b/egs/chime6/asr1/local/extract_noises.py index 7c96666b5c9..b79e6fcaeaa 100755 --- a/egs/chime6/asr1/local/extract_noises.py +++ b/egs/chime6/asr1/local/extract_noises.py @@ -6,11 +6,12 @@ import argparse import json import math 
-import numpy as np import os -import scipy.io.wavfile as siw import sys +import numpy as np +import scipy.io.wavfile as siw + def get_args(): parser = argparse.ArgumentParser( diff --git a/egs/chime6/asr1/local/make_noise_list.py b/egs/chime6/asr1/local/make_noise_list.py index 1674bb71b4d..b8f84fc3fed 100755 --- a/egs/chime6/asr1/local/make_noise_list.py +++ b/egs/chime6/asr1/local/make_noise_list.py @@ -7,7 +7,6 @@ import os import sys - if len(sys.argv) != 2: print("Usage: {} ".format(sys.argv[0])) raise SystemExit(1) diff --git a/egs/cmu_indic/tts1/local/clean_text.py b/egs/cmu_indic/tts1/local/clean_text.py index 7b14f47a61a..6fd5ce649e0 100755 --- a/egs/cmu_indic/tts1/local/clean_text.py +++ b/egs/cmu_indic/tts1/local/clean_text.py @@ -8,7 +8,6 @@ from tacotron_cleaner.cleaners import custom_english_cleaners - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("text", type=str, help="text to be cleaned") diff --git a/egs/covost2/st1/local/process_tsv.py b/egs/covost2/st1/local/process_tsv.py index 2c46d83df75..84609ad1f68 100755 --- a/egs/covost2/st1/local/process_tsv.py +++ b/egs/covost2/st1/local/process_tsv.py @@ -5,8 +5,8 @@ import argparse import codecs -import pandas as pd +import pandas as pd parser = argparse.ArgumentParser(description="extract translation from tsv file") parser.add_argument("tsv_path", type=str, default=None, help="input tsv path") diff --git a/egs/csj/asr1/local/csj_rm_tag.py b/egs/csj/asr1/local/csj_rm_tag.py index 0a23ca59708..dfe5ba5e4f3 100755 --- a/egs/csj/asr1/local/csj_rm_tag.py +++ b/egs/csj/asr1/local/csj_rm_tag.py @@ -6,9 +6,8 @@ import argparse import codecs -from io import open import sys - +from io import open PY2 = sys.version_info[0] == 2 sys.stdin = codecs.getreader("utf-8")(sys.stdin if PY2 else sys.stdin.buffer) diff --git a/egs/iwslt16/mt1/local/extract_recog_text.py b/egs/iwslt16/mt1/local/extract_recog_text.py index bf2dbdfda9e..bba75a17b9a 100755 --- a/egs/iwslt16/mt1/local/extract_recog_text.py +++ b/egs/iwslt16/mt1/local/extract_recog_text.py @@ -4,9 +4,9 @@ """ import argparse import glob -from itertools import takewhile import json import os +from itertools import takewhile def get_args(): diff --git a/egs/iwslt16/mt1/local/generate_json.py b/egs/iwslt16/mt1/local/generate_json.py index 2dd4d66a098..4e81eb8d7f1 100755 --- a/egs/iwslt16/mt1/local/generate_json.py +++ b/egs/iwslt16/mt1/local/generate_json.py @@ -5,11 +5,9 @@ """ import argparse import json -from logging import getLogger import os -from typing import Dict -from typing import List - +from logging import getLogger +from typing import Dict, List logger = getLogger(__name__) diff --git a/egs/iwslt16/mt1/local/generate_vocab.py b/egs/iwslt16/mt1/local/generate_vocab.py index c97c4c069c5..f060d3b4aae 100755 --- a/egs/iwslt16/mt1/local/generate_vocab.py +++ b/egs/iwslt16/mt1/local/generate_vocab.py @@ -6,8 +6,8 @@ format: token + whitespace + index """ import argparse -from collections import defaultdict import fileinput +from collections import defaultdict def get_args(): diff --git a/egs/iwslt18/st1/local/parse_xml.py b/egs/iwslt18/st1/local/parse_xml.py index e42f8e2c79e..067926ee50f 100755 --- a/egs/iwslt18/st1/local/parse_xml.py +++ b/egs/iwslt18/st1/local/parse_xml.py @@ -6,10 +6,10 @@ import argparse import codecs -from collections import OrderedDict import os import re import xml.etree.ElementTree as etree +from collections import OrderedDict def main(): diff --git a/egs/iwslt21/asr1/local/filter_parentheses.py 
b/egs/iwslt21/asr1/local/filter_parentheses.py index 8c27bf39d27..b0c77d3a314 100755 --- a/egs/iwslt21/asr1/local/filter_parentheses.py +++ b/egs/iwslt21/asr1/local/filter_parentheses.py @@ -7,6 +7,7 @@ import argparse import codecs import re + import regex parser = argparse.ArgumentParser() diff --git a/egs/iwslt21_low_resource/st1/local/data_prep.py b/egs/iwslt21_low_resource/st1/local/data_prep.py index 75153cc426f..60df7d00d8c 100644 --- a/egs/iwslt21_low_resource/st1/local/data_prep.py +++ b/egs/iwslt21_low_resource/st1/local/data_prep.py @@ -1,7 +1,6 @@ import argparse import os - if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert data into kaldi format") parser.add_argument("data_dir", type=str) diff --git a/egs/jnas/asr1/local/filter_text.py b/egs/jnas/asr1/local/filter_text.py index db35c1754da..c5b000ce4c0 100755 --- a/egs/jnas/asr1/local/filter_text.py +++ b/egs/jnas/asr1/local/filter_text.py @@ -6,9 +6,8 @@ import argparse import codecs -from io import open import sys - +from io import open PY2 = sys.version_info[0] == 2 sys.stdin = codecs.getreader("utf-8")(sys.stdin if PY2 else sys.stdin.buffer) diff --git a/egs/ksponspeech/asr1/local/get_space_normalized_hyps.py b/egs/ksponspeech/asr1/local/get_space_normalized_hyps.py index c105b47c578..1f5225bfe83 100755 --- a/egs/ksponspeech/asr1/local/get_space_normalized_hyps.py +++ b/egs/ksponspeech/asr1/local/get_space_normalized_hyps.py @@ -4,11 +4,11 @@ # Copyright 2020 Electronics and Telecommunications Research Institute (Jeong-Uk, Bang) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import configargparse import logging import os import sys +import configargparse from numpy import zeros space_sym = "▁" diff --git a/egs/ksponspeech/asr1/local/get_transcriptions.py b/egs/ksponspeech/asr1/local/get_transcriptions.py index 9d1db4b9225..771c377641f 100644 --- a/egs/ksponspeech/asr1/local/get_transcriptions.py +++ b/egs/ksponspeech/asr1/local/get_transcriptions.py @@ -5,13 +5,14 @@ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import codecs -import configargparse import logging import os import re import shutil import sys +import configargparse + def get_parser(): """Get default arguments.""" diff --git a/egs/libri_css/asr1/local/best_wer_matching.py b/egs/libri_css/asr1/local/best_wer_matching.py index d9688496ad6..67e1d4b808a 100755 --- a/egs/libri_css/asr1/local/best_wer_matching.py +++ b/egs/libri_css/asr1/local/best_wer_matching.py @@ -5,9 +5,10 @@ import io import itertools import math +import sys + import numpy as np from scipy.optimize import linear_sum_assignment -import sys # Helper function to group the list by ref/hyp ids diff --git a/egs/libri_css/asr1/local/get_perspeaker_output.py b/egs/libri_css/asr1/local/get_perspeaker_output.py index 3dcdfae1340..3f0361ca320 100755 --- a/egs/libri_css/asr1/local/get_perspeaker_output.py +++ b/egs/libri_css/asr1/local/get_perspeaker_output.py @@ -5,9 +5,9 @@ into per_speaker output (text) file""" import argparse -from collections import defaultdict import itertools import os +from collections import defaultdict def get_args(): diff --git a/egs/libri_css/asr1/local/prepare_data.py b/egs/libri_css/asr1/local/prepare_data.py index f5b2e409f5c..f3800935c47 100755 --- a/egs/libri_css/asr1/local/prepare_data.py +++ b/egs/libri_css/asr1/local/prepare_data.py @@ -7,6 +7,7 @@ import argparse import glob import os + import soundfile as sf import tqdm diff --git a/egs/libri_css/asr1/local/segmentation/apply_webrtcvad.py 
b/egs/libri_css/asr1/local/segmentation/apply_webrtcvad.py index 08ca2f9d765..e30005fd518 100755 --- a/egs/libri_css/asr1/local/segmentation/apply_webrtcvad.py +++ b/egs/libri_css/asr1/local/segmentation/apply_webrtcvad.py @@ -12,6 +12,7 @@ import os import sys import wave + import webrtcvad diff --git a/egs/ljspeech/tts1/local/clean_text.py b/egs/ljspeech/tts1/local/clean_text.py index 14c6721ece4..ee7c5fcfa1f 100755 --- a/egs/ljspeech/tts1/local/clean_text.py +++ b/egs/ljspeech/tts1/local/clean_text.py @@ -5,8 +5,8 @@ import argparse import codecs -import nltk +import nltk from tacotron_cleaner.cleaners import custom_english_cleaners try: diff --git a/egs/lrs/avsr1/local/se_batch.py b/egs/lrs/avsr1/local/se_batch.py index c5f0a58bf6b..6b78ee965eb 100755 --- a/egs/lrs/avsr1/local/se_batch.py +++ b/egs/lrs/avsr1/local/se_batch.py @@ -5,11 +5,12 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.""" -from deepxi.utils import read_wav import glob -import numpy as np import os +import numpy as np +from deepxi.utils import read_wav + def Batch(fdir, snr_l=[]): """REQUIRES REWRITING. WILL BE MOVED TO deepxi/utils.py diff --git a/egs/mgb2/asr1/local/process_xml.py b/egs/mgb2/asr1/local/process_xml.py index dadfb97845e..e0fa189d083 100644 --- a/egs/mgb2/asr1/local/process_xml.py +++ b/egs/mgb2/asr1/local/process_xml.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 import argparse -from bs4 import BeautifulSoup import sys +from bs4 import BeautifulSoup + def get_args(): parser = argparse.ArgumentParser(description="""This script process xml file.""") diff --git a/egs/mgb2/asr1/local/text_segmenting.py b/egs/mgb2/asr1/local/text_segmenting.py index ec9004a20b1..6cfa58fb135 100644 --- a/egs/mgb2/asr1/local/text_segmenting.py +++ b/egs/mgb2/asr1/local/text_segmenting.py @@ -4,6 +4,7 @@ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import argparse + import pandas as pd diff --git a/egs/polyphone_swiss_french/asr1/local/data_prep.py b/egs/polyphone_swiss_french/asr1/local/data_prep.py index 6926ceabc1f..6a41a0b8717 100755 --- a/egs/polyphone_swiss_french/asr1/local/data_prep.py +++ b/egs/polyphone_swiss_french/asr1/local/data_prep.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 -from collections import defaultdict import os import pathlib -from random import shuffle import re import subprocess import sys +from collections import defaultdict +from random import shuffle class FrPolyphonePrepper: @@ -570,6 +570,7 @@ def _generate_random(self, corpus, splits): if __name__ == "__main__": import argparse + import yaml example = "{0} --config conf/dataprep.yml".format(sys.argv[0]) diff --git a/egs/puebla_nahuatl/asr1/local/construct_dataset.py b/egs/puebla_nahuatl/asr1/local/construct_dataset.py index 752c915e22f..fd471e44ebf 100644 --- a/egs/puebla_nahuatl/asr1/local/construct_dataset.py +++ b/egs/puebla_nahuatl/asr1/local/construct_dataset.py @@ -1,5 +1,4 @@ import os - from argparse import ArgumentParser diff --git a/egs/puebla_nahuatl/asr1/local/data_prep.py b/egs/puebla_nahuatl/asr1/local/data_prep.py index 959d9a91250..6d90f9a0d4c 100755 --- a/egs/puebla_nahuatl/asr1/local/data_prep.py +++ b/egs/puebla_nahuatl/asr1/local/data_prep.py @@ -5,11 +5,9 @@ import shutil import string import sys - from argparse import ArgumentParser from xml.dom.minidom import parse - s = "".join(chr(c) for c in range(sys.maxunicode + 1)) ws = "".join(re.findall(r"\s", s)) outtab = " " * len(ws) diff --git a/egs/puebla_nahuatl/st1/local/data_prep.py 
b/egs/puebla_nahuatl/st1/local/data_prep.py index 74a39fdf478..3d07917fdbc 100644 --- a/egs/puebla_nahuatl/st1/local/data_prep.py +++ b/egs/puebla_nahuatl/st1/local/data_prep.py @@ -1,10 +1,10 @@ # -*- coding: UTF-8 -*- -from argparse import ArgumentParser import os import re import string import sys +from argparse import ArgumentParser from xml.dom.minidom import parse s = "".join(chr(c) for c in range(sys.maxunicode + 1)) diff --git a/egs/reverb/asr1/local/filterjson.py b/egs/reverb/asr1/local/filterjson.py index 00dff00fca3..400177e3d17 100755 --- a/egs/reverb/asr1/local/filterjson.py +++ b/egs/reverb/asr1/local/filterjson.py @@ -6,12 +6,11 @@ import argparse import codecs -from io import open import json import logging import re import sys - +from io import open PY2 = sys.version_info[0] == 2 sys.stdin = codecs.getreader("utf-8")(sys.stdin if PY2 else sys.stdin.buffer) diff --git a/egs/reverb/asr1/local/run_wpe.py b/egs/reverb/asr1/local/run_wpe.py index 309cf609d90..84d21b3b5c7 100755 --- a/egs/reverb/asr1/local/run_wpe.py +++ b/egs/reverb/asr1/local/run_wpe.py @@ -6,12 +6,11 @@ import argparse import errno import os -import soundfile as sf -from nara_wpe.utils import istft -from nara_wpe.utils import stft -from nara_wpe.wpe import wpe import numpy as np +import soundfile as sf +from nara_wpe.utils import istft, stft +from nara_wpe.wpe import wpe parser = argparse.ArgumentParser() parser.add_argument("--files", "-f", nargs="+") diff --git a/egs/reverb/asr1_multich/local/filterjson.py b/egs/reverb/asr1_multich/local/filterjson.py index 8841d546dc2..400177e3d17 100755 --- a/egs/reverb/asr1_multich/local/filterjson.py +++ b/egs/reverb/asr1_multich/local/filterjson.py @@ -6,11 +6,11 @@ import argparse import codecs -from io import open import json import logging import re import sys +from io import open PY2 = sys.version_info[0] == 2 sys.stdin = codecs.getreader("utf-8")(sys.stdin if PY2 else sys.stdin.buffer) diff --git a/egs/tweb/tts1/local/clean_text.py b/egs/tweb/tts1/local/clean_text.py index 07a34438f24..c7634744928 100755 --- a/egs/tweb/tts1/local/clean_text.py +++ b/egs/tweb/tts1/local/clean_text.py @@ -8,7 +8,6 @@ from tacotron_cleaner.cleaners import custom_english_cleaners - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("text", type=str, help="text to be cleaned") diff --git a/egs/vais1000/tts1/local/clean_text.py b/egs/vais1000/tts1/local/clean_text.py index 8f89b943092..d1e320c654e 100755 --- a/egs/vais1000/tts1/local/clean_text.py +++ b/egs/vais1000/tts1/local/clean_text.py @@ -8,7 +8,6 @@ from vietnamese_cleaner.vietnamese_cleaners import vietnamese_cleaner - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("text", type=str, help="text to be cleaned") diff --git a/egs/vcc20/tts1_en_fi/local/clean_text_css10.py b/egs/vcc20/tts1_en_fi/local/clean_text_css10.py index 2b64d394028..df477d1e59c 100755 --- a/egs/vcc20/tts1_en_fi/local/clean_text_css10.py +++ b/egs/vcc20/tts1_en_fi/local/clean_text_css10.py @@ -9,13 +9,10 @@ import os import nltk -from tacotron_cleaner.cleaners import collapse_whitespace -from tacotron_cleaner.cleaners import expand_abbreviations -from tacotron_cleaner.cleaners import expand_numbers -from tacotron_cleaner.cleaners import expand_symbols -from tacotron_cleaner.cleaners import lowercase -from tacotron_cleaner.cleaners import remove_unnecessary_symbols -from tacotron_cleaner.cleaners import uppercase +from tacotron_cleaner.cleaners import (collapse_whitespace, + expand_abbreviations, 
expand_numbers, + expand_symbols, lowercase, + remove_unnecessary_symbols, uppercase) try: # For phoneme conversion, use https://github.com/Kyubyong/g2p. diff --git a/egs/vcc20/vc1_task1/local/clean_text_asr_result.py b/egs/vcc20/vc1_task1/local/clean_text_asr_result.py index 9dc253855e6..381f9ffaf6b 100755 --- a/egs/vcc20/vc1_task1/local/clean_text_asr_result.py +++ b/egs/vcc20/vc1_task1/local/clean_text_asr_result.py @@ -5,8 +5,8 @@ import argparse import codecs -import nltk +import nltk from tacotron_cleaner.cleaners import custom_english_cleaners try: diff --git a/egs/vcc20/vc1_task2/local/clean_text_finnish.py b/egs/vcc20/vc1_task2/local/clean_text_finnish.py index fbbe6fa8a76..ac28a8be708 100755 --- a/egs/vcc20/vc1_task2/local/clean_text_finnish.py +++ b/egs/vcc20/vc1_task2/local/clean_text_finnish.py @@ -5,16 +5,13 @@ import argparse import codecs -import nltk -from tacotron_cleaner.cleaners import collapse_whitespace -from tacotron_cleaner.cleaners import custom_english_cleaners -from tacotron_cleaner.cleaners import expand_abbreviations -from tacotron_cleaner.cleaners import expand_numbers -from tacotron_cleaner.cleaners import expand_symbols -from tacotron_cleaner.cleaners import lowercase -from tacotron_cleaner.cleaners import remove_unnecessary_symbols -from tacotron_cleaner.cleaners import uppercase +import nltk +from tacotron_cleaner.cleaners import (collapse_whitespace, + custom_english_cleaners, + expand_abbreviations, expand_numbers, + expand_symbols, lowercase, + remove_unnecessary_symbols, uppercase) E_lang_tag = "en_US" diff --git a/egs/vcc20/vc1_task2/local/clean_text_german.py b/egs/vcc20/vc1_task2/local/clean_text_german.py index b9123de1578..a10fd4e8f2e 100755 --- a/egs/vcc20/vc1_task2/local/clean_text_german.py +++ b/egs/vcc20/vc1_task2/local/clean_text_german.py @@ -5,11 +5,10 @@ import argparse import codecs -import nltk +import nltk from tacotron_cleaner.cleaners import custom_english_cleaners - E_lang_tag = "en_US" try: diff --git a/egs/vcc20/vc1_task2/local/clean_text_mandarin.py b/egs/vcc20/vc1_task2/local/clean_text_mandarin.py index e1932ceebd0..9a2784f0a2c 100755 --- a/egs/vcc20/vc1_task2/local/clean_text_mandarin.py +++ b/egs/vcc20/vc1_task2/local/clean_text_mandarin.py @@ -5,14 +5,13 @@ import argparse import codecs -import nltk +import nltk +from pypinyin import Style from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin from pypinyin.converter import DefaultConverter from pypinyin.core import Pinyin -from pypinyin import Style -from pypinyin.style._utils import get_finals -from pypinyin.style._utils import get_initials +from pypinyin.style._utils import get_finals, get_initials from tacotron_cleaner.cleaners import custom_english_cleaners diff --git a/egs/vcc20/voc1/local/subset_data_dir.py b/egs/vcc20/voc1/local/subset_data_dir.py index 841d0fb2bfc..968cd3d02d1 100755 --- a/egs/vcc20/voc1/local/subset_data_dir.py +++ b/egs/vcc20/voc1/local/subset_data_dir.py @@ -5,8 +5,8 @@ # consisting of some specified number of utterances. 
import argparse -from io import open import sys +from io import open def get_parser(): diff --git a/egs/voxforge/asr1/local/filter_text.py b/egs/voxforge/asr1/local/filter_text.py index db35c1754da..c5b000ce4c0 100755 --- a/egs/voxforge/asr1/local/filter_text.py +++ b/egs/voxforge/asr1/local/filter_text.py @@ -6,9 +6,8 @@ import argparse import codecs -from io import open import sys - +from io import open PY2 = sys.version_info[0] == 2 sys.stdin = codecs.getreader("utf-8")(sys.stdin if PY2 else sys.stdin.buffer) diff --git a/egs/wsj/asr1/local/filtering_samples.py b/egs/wsj/asr1/local/filtering_samples.py index 27766d43e58..4b91b004373 100755 --- a/egs/wsj/asr1/local/filtering_samples.py +++ b/egs/wsj/asr1/local/filtering_samples.py @@ -4,16 +4,15 @@ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -from functools import reduce import json -from operator import mul import sys +from functools import reduce +from operator import mul from espnet.bin.asr_train import get_parser from espnet.nets.pytorch_backend.nets_utils import get_subsample from espnet.utils.dynamic_import import dynamic_import - if __name__ == "__main__": cmd_args = sys.argv[1:] parser = get_parser(required=False) diff --git a/egs/wsj_mix/asr1/local/merge_scp2json.py b/egs/wsj_mix/asr1/local/merge_scp2json.py index 52260785b9d..7cf55f2d35f 100755 --- a/egs/wsj_mix/asr1/local/merge_scp2json.py +++ b/egs/wsj_mix/asr1/local/merge_scp2json.py @@ -3,10 +3,10 @@ import argparse import codecs -from io import open import json import logging import sys +from io import open from espnet.utils.cli_utils import get_commandline_args diff --git a/egs/wsj_mix/asr1/local/mergejson.py b/egs/wsj_mix/asr1/local/mergejson.py index 0926a858469..8b965cb97e5 100755 --- a/egs/wsj_mix/asr1/local/mergejson.py +++ b/egs/wsj_mix/asr1/local/mergejson.py @@ -11,7 +11,6 @@ import logging import sys - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("jsons", type=str, nargs="+", help="json files") diff --git a/egs/yoloxochitl_mixtec/asr1/local/data_prep.py b/egs/yoloxochitl_mixtec/asr1/local/data_prep.py index 91fcee41249..e96e633df00 100755 --- a/egs/yoloxochitl_mixtec/asr1/local/data_prep.py +++ b/egs/yoloxochitl_mixtec/asr1/local/data_prep.py @@ -1,14 +1,15 @@ # -*- coding: UTF-8 -*- -from argparse import ArgumentParser import os import re import shutil -import soundfile as sf import string import sys +from argparse import ArgumentParser from xml.dom.minidom import parse +import soundfile as sf + s = "".join(chr(c) for c in range(sys.maxunicode + 1)) ws = "".join(re.findall(r"\s", s)) outtab = " " * len(ws) diff --git a/egs2/TEMPLATE/asr1/pyscripts/audio/format_wav_scp.py b/egs2/TEMPLATE/asr1/pyscripts/audio/format_wav_scp.py index cca465bb93c..06bb01f926b 100755 --- a/egs2/TEMPLATE/asr1/pyscripts/audio/format_wav_scp.py +++ b/egs2/TEMPLATE/asr1/pyscripts/audio/format_wav_scp.py @@ -3,19 +3,19 @@ import logging from io import BytesIO from pathlib import Path -from typing import Tuple, Optional +from typing import Optional, Tuple -import kaldiio import humanfriendly +import kaldiio import numpy as np import resampy import soundfile from tqdm import tqdm from typeguard import check_argument_types -from espnet.utils.cli_utils import get_commandline_args from espnet2.fileio.read_text import read_2column_text from espnet2.fileio.sound_scp import SoundScpWriter +from espnet.utils.cli_utils import get_commandline_args def humanfriendly_or_none(value: str): diff --git 
a/egs2/TEMPLATE/asr1/pyscripts/utils/convert_text_to_phn.py b/egs2/TEMPLATE/asr1/pyscripts/utils/convert_text_to_phn.py index a6605409f15..052b23ca636 100755 --- a/egs2/TEMPLATE/asr1/pyscripts/utils/convert_text_to_phn.py +++ b/egs2/TEMPLATE/asr1/pyscripts/utils/convert_text_to_phn.py @@ -9,9 +9,7 @@ import codecs import contextlib -from joblib import delayed -from joblib import Parallel -from joblib import parallel +from joblib import Parallel, delayed, parallel from tqdm import tqdm from espnet2.text.cleaner import TextCleaner diff --git a/egs2/TEMPLATE/asr1/pyscripts/utils/evaluate_f0.py b/egs2/TEMPLATE/asr1/pyscripts/utils/evaluate_f0.py index e27e57624ee..bc9a3709f99 100755 --- a/egs2/TEMPLATE/asr1/pyscripts/utils/evaluate_f0.py +++ b/egs2/TEMPLATE/asr1/pyscripts/utils/evaluate_f0.py @@ -10,17 +10,13 @@ import logging import multiprocessing as mp import os - -from typing import Dict -from typing import List -from typing import Tuple +from typing import Dict, List, Tuple import librosa import numpy as np import pysptk import pyworld as pw import soundfile as sf - from fastdtw import fastdtw from scipy import spatial diff --git a/egs2/TEMPLATE/asr1/pyscripts/utils/evaluate_mcd.py b/egs2/TEMPLATE/asr1/pyscripts/utils/evaluate_mcd.py index 379438217ea..213dc60b563 100755 --- a/egs2/TEMPLATE/asr1/pyscripts/utils/evaluate_mcd.py +++ b/egs2/TEMPLATE/asr1/pyscripts/utils/evaluate_mcd.py @@ -10,16 +10,12 @@ import logging import multiprocessing as mp import os - -from typing import Dict -from typing import List -from typing import Tuple +from typing import Dict, List, Tuple import librosa import numpy as np import pysptk import soundfile as sf - from fastdtw import fastdtw from scipy import spatial diff --git a/egs2/TEMPLATE/asr1/pyscripts/utils/extract_xvectors.py b/egs2/TEMPLATE/asr1/pyscripts/utils/extract_xvectors.py index e64b82dc515..a58a844be0a 100755 --- a/egs2/TEMPLATE/asr1/pyscripts/utils/extract_xvectors.py +++ b/egs2/TEMPLATE/asr1/pyscripts/utils/extract_xvectors.py @@ -3,14 +3,14 @@ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import argparse -import kaldiio import logging -from pathlib import Path -import sys -import torch import os -import numpy as np +import sys +from pathlib import Path +import kaldiio +import numpy as np +import torch from tqdm.contrib import tqdm from espnet2.fileio.sound_scp import SoundScpReader diff --git a/egs2/TEMPLATE/asr1/pyscripts/utils/plot_sinc_filters.py b/egs2/TEMPLATE/asr1/pyscripts/utils/plot_sinc_filters.py index 001ba49d34b..56e06d73d51 100755 --- a/egs2/TEMPLATE/asr1/pyscripts/utils/plot_sinc_filters.py +++ b/egs2/TEMPLATE/asr1/pyscripts/utils/plot_sinc_filters.py @@ -12,10 +12,11 @@ """ import argparse +import sys +from pathlib import Path + import matplotlib.pyplot as plt import numpy as np -from pathlib import Path -import sys import torch diff --git a/egs2/TEMPLATE/asr1/pyscripts/utils/rotate_logfile.py b/egs2/TEMPLATE/asr1/pyscripts/utils/rotate_logfile.py index e30c7a1e682..aa2818d3a9f 100755 --- a/egs2/TEMPLATE/asr1/pyscripts/utils/rotate_logfile.py +++ b/egs2/TEMPLATE/asr1/pyscripts/utils/rotate_logfile.py @@ -6,8 +6,8 @@ """Rotate log-file.""" import argparse -from pathlib import Path import shutil +from pathlib import Path def rotate(path, max_num_log_files=1000): diff --git a/egs2/TEMPLATE/asr1/pyscripts/utils/score_intent.py b/egs2/TEMPLATE/asr1/pyscripts/utils/score_intent.py index 4f0f074c9db..ccfba96010d 100755 --- a/egs2/TEMPLATE/asr1/pyscripts/utils/score_intent.py +++ 
b/egs2/TEMPLATE/asr1/pyscripts/utils/score_intent.py @@ -5,11 +5,12 @@ # Apache 2.0 +import argparse import os import re import sys + import pandas as pd -import argparse def get_classification_result(hyp_file, ref_file, hyp_write, ref_write): diff --git a/egs2/TEMPLATE/asr1/pyscripts/utils/score_summarization.py b/egs2/TEMPLATE/asr1/pyscripts/utils/score_summarization.py index 35202f1ce88..781ecebfd12 100644 --- a/egs2/TEMPLATE/asr1/pyscripts/utils/score_summarization.py +++ b/egs2/TEMPLATE/asr1/pyscripts/utils/score_summarization.py @@ -1,10 +1,9 @@ -import sys import os -from datasets import load_metric -import numpy as np -from nlgeval import compute_metrics -from nlgeval import NLGEval +import sys +import numpy as np +from datasets import load_metric +from nlgeval import NLGEval, compute_metrics ref_file = sys.argv[1] hyp_file = sys.argv[2] diff --git a/egs2/TEMPLATE/diar1/pyscripts/utils/convert_rttm.py b/egs2/TEMPLATE/diar1/pyscripts/utils/convert_rttm.py index d5d4b257b36..e3e1047d7bb 100755 --- a/egs2/TEMPLATE/diar1/pyscripts/utils/convert_rttm.py +++ b/egs2/TEMPLATE/diar1/pyscripts/utils/convert_rttm.py @@ -1,19 +1,20 @@ #!/usr/bin/env python3 +import argparse import collections.abc -import humanfriendly +import logging +import os +import re from pathlib import Path from typing import Union -import argparse -import logging +import humanfriendly import numpy as np -import re -import os import soundfile -from espnet2.utils.types import str_or_int from typeguard import check_argument_types +from espnet2.utils.types import str_or_int + def convert_rttm_text( path: Union[Path, str], diff --git a/egs2/TEMPLATE/diar1/pyscripts/utils/make_rttm.py b/egs2/TEMPLATE/diar1/pyscripts/utils/make_rttm.py index f8b9c8c05af..1f08fce0060 100755 --- a/egs2/TEMPLATE/diar1/pyscripts/utils/make_rttm.py +++ b/egs2/TEMPLATE/diar1/pyscripts/utils/make_rttm.py @@ -5,11 +5,13 @@ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import argparse -from espnet2.fileio.npy_scp import NpyScpReader import logging + +import humanfriendly import numpy as np from scipy.signal import medfilt -import humanfriendly + +from espnet2.fileio.npy_scp import NpyScpReader def get_parser() -> argparse.Namespace: diff --git a/egs2/TEMPLATE/ssl1/pyscripts/dump_km_label.py b/egs2/TEMPLATE/ssl1/pyscripts/dump_km_label.py index 552c84f89ad..7b98eb97daa 100644 --- a/egs2/TEMPLATE/ssl1/pyscripts/dump_km_label.py +++ b/egs2/TEMPLATE/ssl1/pyscripts/dump_km_label.py @@ -1,16 +1,15 @@ import argparse import logging import os +import pdb import sys -import numpy as np - import joblib +import numpy as np import torch import tqdm -import pdb - -from sklearn_km import MfccFeatureReader, get_path_iterator, HubertFeatureReader +from sklearn_km import (HubertFeatureReader, MfccFeatureReader, + get_path_iterator) logging.basicConfig( level=logging.DEBUG, diff --git a/egs2/TEMPLATE/ssl1/pyscripts/feature_loader.py b/egs2/TEMPLATE/ssl1/pyscripts/feature_loader.py index b0dae8a2074..050817349f4 100644 --- a/egs2/TEMPLATE/ssl1/pyscripts/feature_loader.py +++ b/egs2/TEMPLATE/ssl1/pyscripts/feature_loader.py @@ -14,7 +14,6 @@ import sys import fairseq - import soundfile as sf import torch import torchaudio diff --git a/egs2/TEMPLATE/ssl1/pyscripts/sklearn_km.py b/egs2/TEMPLATE/ssl1/pyscripts/sklearn_km.py index ce0c82fcd3c..d97e9df26c1 100644 --- a/egs2/TEMPLATE/ssl1/pyscripts/sklearn_km.py +++ b/egs2/TEMPLATE/ssl1/pyscripts/sklearn_km.py @@ -8,28 +8,24 @@ import argparse import logging +import math import os import sys -from random import 
sample import warnings +from random import sample +import fairseq import joblib import numpy as np -import math - import soundfile as sf import torch import torchaudio import tqdm - +from feature_loader import HubertFeatureReader, MfccFeatureReader from sklearn.cluster import MiniBatchKMeans -import fairseq from espnet2.asr.encoder.hubert_encoder import FairseqHubertEncoder -from feature_loader import MfccFeatureReader -from feature_loader import HubertFeatureReader - logging.basicConfig( level=logging.DEBUG, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", diff --git a/egs2/accented_french_openslr57/asr1/local/remove_missing.py b/egs2/accented_french_openslr57/asr1/local/remove_missing.py index 937144f75d8..1469b4a55bb 100644 --- a/egs2/accented_french_openslr57/asr1/local/remove_missing.py +++ b/egs2/accented_french_openslr57/asr1/local/remove_missing.py @@ -4,7 +4,6 @@ import argparse import os - parser = argparse.ArgumentParser(description="Normalize test text.") parser.add_argument("--folder", type=str, help="path of download folder") parser.add_argument("--train", type=str, help="path of train folder") diff --git a/egs2/aishell3/tts1/local/data_prep.py b/egs2/aishell3/tts1/local/data_prep.py index 706c28d5642..679232b9f3e 100644 --- a/egs2/aishell3/tts1/local/data_prep.py +++ b/egs2/aishell3/tts1/local/data_prep.py @@ -1,5 +1,6 @@ import argparse import os + from espnet2.utils.types import str2bool SPK_LABEL_LEN = 7 diff --git a/egs2/aishell4/enh1/local/prepare_audioset_category_list.py b/egs2/aishell4/enh1/local/prepare_audioset_category_list.py index 2c9a09bb0c6..af591399f3f 100644 --- a/egs2/aishell4/enh1/local/prepare_audioset_category_list.py +++ b/egs2/aishell4/enh1/local/prepare_audioset_category_list.py @@ -2,9 +2,9 @@ # Copyright 2022 Shanghai Jiao Tong University (Author: Wangyou Zhang) # Apache 2.0 -from pathlib import Path import re import sys +from pathlib import Path def prepare_audioset_category(audio_list, audioset_dir, output_file, skip_csv_rows=3): diff --git a/egs2/aishell4/enh1/local/split_train_dev.py b/egs2/aishell4/enh1/local/split_train_dev.py index 8961c40b12d..e7e7d75e239 100755 --- a/egs2/aishell4/enh1/local/split_train_dev.py +++ b/egs2/aishell4/enh1/local/split_train_dev.py @@ -2,14 +2,12 @@ # Copyright 2022 Shanghai Jiao Tong University (Authors: Wangyou Zhang) # Apache 2.0 -from collections import Counter -from collections import defaultdict -from fractions import Fraction import math -from pathlib import Path import random -from typing import List -from typing import Tuple +from collections import Counter, defaultdict +from fractions import Fraction +from pathlib import Path +from typing import List, Tuple def int_or_float_or_numstr(value): diff --git a/egs2/aishell4/enh1/local/split_train_dev_by_column.py b/egs2/aishell4/enh1/local/split_train_dev_by_column.py index ff50a9407a7..2555244065f 100755 --- a/egs2/aishell4/enh1/local/split_train_dev_by_column.py +++ b/egs2/aishell4/enh1/local/split_train_dev_by_column.py @@ -3,13 +3,12 @@ # Copyright 2022 Shanghai Jiao Tong University (Authors: Wangyou Zhang) # Apache 2.0 import argparse +import random from collections import defaultdict from pathlib import Path -import random -from split_train_dev import int_or_float_or_numstr -from split_train_dev import split_train_dev -from split_train_dev import split_train_dev_v2 +from split_train_dev import (int_or_float_or_numstr, split_train_dev, + split_train_dev_v2) def get_parser(): diff --git 
a/egs2/aishell4/enh1/local/split_train_dev_by_prefix.py b/egs2/aishell4/enh1/local/split_train_dev_by_prefix.py index c04cfb1a584..e02c927cecc 100755 --- a/egs2/aishell4/enh1/local/split_train_dev_by_prefix.py +++ b/egs2/aishell4/enh1/local/split_train_dev_by_prefix.py @@ -3,13 +3,12 @@ # Copyright 2022 Shanghai Jiao Tong University (Authors: Wangyou Zhang) # Apache 2.0 import argparse +import random from collections import defaultdict from pathlib import Path -import random -from split_train_dev import int_or_float_or_numstr -from split_train_dev import split_train_dev -from split_train_dev import split_train_dev_v2 +from split_train_dev import (int_or_float_or_numstr, split_train_dev, + split_train_dev_v2) def get_parser(): diff --git a/egs2/bn_openslr53/asr1/local/data_prep.py b/egs2/bn_openslr53/asr1/local/data_prep.py index 4cb5a47596b..5d831435277 100644 --- a/egs2/bn_openslr53/asr1/local/data_prep.py +++ b/egs2/bn_openslr53/asr1/local/data_prep.py @@ -8,7 +8,6 @@ import os import random - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-d", help="downloads directory", type=str, default="downloads") diff --git a/egs2/bur_openslr80/asr1/local/data_prep.py b/egs2/bur_openslr80/asr1/local/data_prep.py index 98180ea4b2e..654779696aa 100644 --- a/egs2/bur_openslr80/asr1/local/data_prep.py +++ b/egs2/bur_openslr80/asr1/local/data_prep.py @@ -8,7 +8,6 @@ import os import random - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-d", help="downloads directory", type=str, default="downloads") diff --git a/egs2/catslu/asr1/local/data_prep.py b/egs2/catslu/asr1/local/data_prep.py index 2ce83727a07..55bf6d2d979 100755 --- a/egs2/catslu/asr1/local/data_prep.py +++ b/egs2/catslu/asr1/local/data_prep.py @@ -4,11 +4,11 @@ # 2021 Carnegie Mellon University # Apache 2.0 +import json import os +import string as string_lib import sys from pathlib import Path -import json -import string as string_lib if len(sys.argv) != 2: print("Usage: python data_prep.py [catslu_root]") diff --git a/egs2/chime4/asr1/local/sym_channel.py b/egs2/chime4/asr1/local/sym_channel.py index 8a3bdcce2a9..dcffd487c4c 100644 --- a/egs2/chime4/asr1/local/sym_channel.py +++ b/egs2/chime4/asr1/local/sym_channel.py @@ -1,6 +1,6 @@ +import argparse import os from os import path -import argparse def create_sym(data_dir, track, wav): diff --git a/egs2/clarity21/enh1/local/prep_data.py b/egs2/clarity21/enh1/local/prep_data.py index fa61e757742..5ff0dc61b1d 100644 --- a/egs2/clarity21/enh1/local/prep_data.py +++ b/egs2/clarity21/enh1/local/prep_data.py @@ -2,7 +2,6 @@ import json import os - parser = argparse.ArgumentParser("Clarity") parser.add_argument( "--clarity_root", diff --git a/egs2/conferencingspeech21/enh1/local/prepare_dev_data.py b/egs2/conferencingspeech21/enh1/local/prepare_dev_data.py index 7ea801511b8..1e0d6f97354 100755 --- a/egs2/conferencingspeech21/enh1/local/prepare_dev_data.py +++ b/egs2/conferencingspeech21/enh1/local/prepare_dev_data.py @@ -3,8 +3,8 @@ # Copyright 2021 Shanghai Jiao Tong University (Authors: Wangyou Zhang) # Apache 2.0 import argparse -from pathlib import Path import re +from pathlib import Path from espnet2.fileio.datadir_writer import DatadirWriter from espnet2.utils.types import str2bool diff --git a/egs2/dirha_wsj/asr1/local/prepare_dirha_wsj.py b/egs2/dirha_wsj/asr1/local/prepare_dirha_wsj.py index 8f017acd1e4..b29dabcf123 100755 --- a/egs2/dirha_wsj/asr1/local/prepare_dirha_wsj.py +++ 
b/egs2/dirha_wsj/asr1/local/prepare_dirha_wsj.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 import argparse +import warnings +import xml.etree.ElementTree as ET from pathlib import Path from typing import Optional -import xml.etree.ElementTree as ET -import warnings import numpy as np import soundfile diff --git a/egs2/dsing/asr1/local/data_prep.py b/egs2/dsing/asr1/local/data_prep.py index 98d82fe1259..4cc3e893e2d 100644 --- a/egs2/dsing/asr1/local/data_prep.py +++ b/egs2/dsing/asr1/local/data_prep.py @@ -1,11 +1,11 @@ # Source from https://github.com/groadabike/Kaldi-Dsing-task -import json import argparse -from os.path import join, exists, isfile -from os import makedirs, listdir -import re import hashlib +import json +import re +from os import listdir, makedirs +from os.path import exists, isfile, join class DataSet: diff --git a/egs2/fsc/asr1/local/data_prep.py b/egs2/fsc/asr1/local/data_prep.py index f6cc9cb42ce..d430799ecbf 100644 --- a/egs2/fsc/asr1/local/data_prep.py +++ b/egs2/fsc/asr1/local/data_prep.py @@ -8,6 +8,7 @@ import os import re import sys + import pandas as pd if len(sys.argv) != 2: diff --git a/egs2/fsc_challenge/asr1/local/data_prep.py b/egs2/fsc_challenge/asr1/local/data_prep.py index 95de097faac..86dbd5b1ecf 100644 --- a/egs2/fsc_challenge/asr1/local/data_prep.py +++ b/egs2/fsc_challenge/asr1/local/data_prep.py @@ -5,9 +5,10 @@ import os import re +import string import sys + import pandas as pd -import string if len(sys.argv) != 2: print("Usage: python data_prep.py [fsc_root]") diff --git a/egs2/fsc_unseen/asr1/local/data_prep.py b/egs2/fsc_unseen/asr1/local/data_prep.py index 68d2d1798cf..cd81e313131 100644 --- a/egs2/fsc_unseen/asr1/local/data_prep.py +++ b/egs2/fsc_unseen/asr1/local/data_prep.py @@ -5,9 +5,10 @@ import os import re +import string import sys + import pandas as pd -import string if len(sys.argv) != 2: print("Usage: python data_prep.py [fsc_root]") diff --git a/egs2/grabo/asr1/local/data_prep.py b/egs2/grabo/asr1/local/data_prep.py index ef91723c164..8b76c178cc8 100644 --- a/egs2/grabo/asr1/local/data_prep.py +++ b/egs2/grabo/asr1/local/data_prep.py @@ -9,13 +9,12 @@ # https://arxiv.org/pdf/2008.01994.pdf (for train/dev/test split) -import os +import argparse import glob +import os import random -import argparse import xml.etree.ElementTree as ET - parser = argparse.ArgumentParser(description="Process Grabo dataset.") parser.add_argument( "--data_path", diff --git a/egs2/grabo/asr1/local/score.py b/egs2/grabo/asr1/local/score.py index b1c79a976c9..9ba1b759cdc 100644 --- a/egs2/grabo/asr1/local/score.py +++ b/egs2/grabo/asr1/local/score.py @@ -2,9 +2,9 @@ # Copyright 2021 Carnegie Mellon University (Yifan Peng) +import argparse import os import os.path -import argparse parser = argparse.ArgumentParser(description="Calculate classification accuracy.") parser.add_argument("--wer_dir", type=str, help="folder containing hyp.trn and ref.trn") diff --git a/egs2/indic_speech/tts1/local/data_prep.py b/egs2/indic_speech/tts1/local/data_prep.py index 6229dc0e179..b3932f2f445 100644 --- a/egs2/indic_speech/tts1/local/data_prep.py +++ b/egs2/indic_speech/tts1/local/data_prep.py @@ -11,7 +11,6 @@ from tqdm import tqdm - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-d", help="downloads directory", type=str, default="downloads") diff --git a/egs2/iwslt21_low_resource/asr1/local/prepare_alffa_data.py b/egs2/iwslt21_low_resource/asr1/local/prepare_alffa_data.py index 3fde4f274f0..a727ba21019 100755 --- 
a/egs2/iwslt21_low_resource/asr1/local/prepare_alffa_data.py +++ b/egs2/iwslt21_low_resource/asr1/local/prepare_alffa_data.py @@ -4,8 +4,8 @@ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import os -from shutil import copyfile import sys +from shutil import copyfile idir = sys.argv[1] diff --git a/egs2/iwslt21_low_resource/asr1/local/prepare_iwslt_data.py b/egs2/iwslt21_low_resource/asr1/local/prepare_iwslt_data.py index 95bd9f933a5..c5779746b97 100755 --- a/egs2/iwslt21_low_resource/asr1/local/prepare_iwslt_data.py +++ b/egs2/iwslt21_low_resource/asr1/local/prepare_iwslt_data.py @@ -6,6 +6,7 @@ import argparse import os import re + import yaml parser = argparse.ArgumentParser( diff --git a/egs2/iwslt22_dialect/asr1/local/preprocess.py b/egs2/iwslt22_dialect/asr1/local/preprocess.py index bbd1e42d342..f92f965e6dd 100755 --- a/egs2/iwslt22_dialect/asr1/local/preprocess.py +++ b/egs2/iwslt22_dialect/asr1/local/preprocess.py @@ -5,11 +5,11 @@ TBD """ -import re -import os -import sys import argparse import itertools +import os +import re +import sys parser = argparse.ArgumentParser() parser.add_argument( diff --git a/egs2/iwslt22_dialect/st1/local/preprocess.py b/egs2/iwslt22_dialect/st1/local/preprocess.py index 2d02de1eb64..2e21e0d604c 100755 --- a/egs2/iwslt22_dialect/st1/local/preprocess.py +++ b/egs2/iwslt22_dialect/st1/local/preprocess.py @@ -5,11 +5,11 @@ TBD """ -import re -import os -import sys import argparse import itertools +import os +import re +import sys parser = argparse.ArgumentParser() parser.add_argument( diff --git a/egs2/jdcinal/asr1/local/score.py b/egs2/jdcinal/asr1/local/score.py index 8b68151c4e7..59d9d4ef900 100755 --- a/egs2/jdcinal/asr1/local/score.py +++ b/egs2/jdcinal/asr1/local/score.py @@ -8,6 +8,7 @@ import os import re import sys + import pandas as pd diff --git a/egs2/jkac/tts1/local/prep_segments.py b/egs2/jkac/tts1/local/prep_segments.py index 2090624521e..3a6213890d2 100755 --- a/egs2/jkac/tts1/local/prep_segments.py +++ b/egs2/jkac/tts1/local/prep_segments.py @@ -4,10 +4,11 @@ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import argparse -from collections import namedtuple import os import re import sys +from collections import namedtuple + import yaml diff --git a/egs2/jmd/tts1/local/clean_text.py b/egs2/jmd/tts1/local/clean_text.py index 5110effc9e4..1a73ef3bf7b 100755 --- a/egs2/jmd/tts1/local/clean_text.py +++ b/egs2/jmd/tts1/local/clean_text.py @@ -6,7 +6,6 @@ import argparse import re - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( diff --git a/egs2/jtubespeech/tts1/local/prune.py b/egs2/jtubespeech/tts1/local/prune.py index a6beac9d8d3..f09ce6164b8 100644 --- a/egs2/jtubespeech/tts1/local/prune.py +++ b/egs2/jtubespeech/tts1/local/prune.py @@ -3,11 +3,12 @@ # Copyright 2021 Takaaki Saeki # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import os +import argparse import glob -import tqdm +import os + import soundfile as sf -import argparse +import tqdm if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/egs2/jtubespeech/tts1/local/split.py b/egs2/jtubespeech/tts1/local/split.py index df59e412941..277e42dc9a5 100644 --- a/egs2/jtubespeech/tts1/local/split.py +++ b/egs2/jtubespeech/tts1/local/split.py @@ -3,11 +3,12 @@ # Copyright 2021 Takaaki Saeki # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import os +import argparse import glob -import tqdm +import os + import soundfile as sf -import argparse +import tqdm if __name__ == "__main__": parser = 
argparse.ArgumentParser() diff --git a/egs2/jv_openslr35/asr1/local/data_prep.py b/egs2/jv_openslr35/asr1/local/data_prep.py index 4cb5a47596b..5d831435277 100644 --- a/egs2/jv_openslr35/asr1/local/data_prep.py +++ b/egs2/jv_openslr35/asr1/local/data_prep.py @@ -8,7 +8,6 @@ import os import random - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-d", help="downloads directory", type=str, default="downloads") diff --git a/egs2/ksponspeech/asr1/local/get_space_normalized_hyps.py b/egs2/ksponspeech/asr1/local/get_space_normalized_hyps.py index c105b47c578..1f5225bfe83 100755 --- a/egs2/ksponspeech/asr1/local/get_space_normalized_hyps.py +++ b/egs2/ksponspeech/asr1/local/get_space_normalized_hyps.py @@ -4,11 +4,11 @@ # Copyright 2020 Electronics and Telecommunications Research Institute (Jeong-Uk, Bang) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import configargparse import logging import os import sys +import configargparse from numpy import zeros space_sym = "▁" diff --git a/egs2/ksponspeech/asr1/local/get_transcriptions.py b/egs2/ksponspeech/asr1/local/get_transcriptions.py index 9d1db4b9225..771c377641f 100644 --- a/egs2/ksponspeech/asr1/local/get_transcriptions.py +++ b/egs2/ksponspeech/asr1/local/get_transcriptions.py @@ -5,13 +5,14 @@ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import codecs -import configargparse import logging import os import re import shutil import sys +import configargparse + def get_parser(): """Get default arguments.""" diff --git a/egs2/librimix/diar1/local/prepare_diarization.py b/egs2/librimix/diar1/local/prepare_diarization.py index b4ab66e5c44..ca42c42e21f 100755 --- a/egs2/librimix/diar1/local/prepare_diarization.py +++ b/egs2/librimix/diar1/local/prepare_diarization.py @@ -1,6 +1,6 @@ +import argparse import os import re -import argparse def float2str(number, size=6): diff --git a/egs2/lrs2/lipreading1/local/feature_extract/cvtransforms.py b/egs2/lrs2/lipreading1/local/feature_extract/cvtransforms.py index 8a2c0710a7d..b7dac9c0d1f 100644 --- a/egs2/lrs2/lipreading1/local/feature_extract/cvtransforms.py +++ b/egs2/lrs2/lipreading1/local/feature_extract/cvtransforms.py @@ -1,5 +1,6 @@ # coding: utf-8 import random + import cv2 import numpy as np diff --git a/egs2/lrs2/lipreading1/local/feature_extract/extract_visual_feature.py b/egs2/lrs2/lipreading1/local/feature_extract/extract_visual_feature.py index 8164bdb54ba..b4efd021081 100644 --- a/egs2/lrs2/lipreading1/local/feature_extract/extract_visual_feature.py +++ b/egs2/lrs2/lipreading1/local/feature_extract/extract_visual_feature.py @@ -4,17 +4,17 @@ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import argparse -from distutils.util import strtobool import logging +from distutils.util import strtobool import kaldiio import numpy import resampy from video_processing import * +from espnet2.utils.types import int_or_none from espnet.utils.cli_utils import get_commandline_args from espnet.utils.cli_writers import file_writer_helper -from espnet2.utils.types import int_or_none def get_parser(): diff --git a/egs2/lrs2/lipreading1/local/feature_extract/models/pretrained.py b/egs2/lrs2/lipreading1/local/feature_extract/models/pretrained.py index ef0fd388231..feb10aa5a89 100755 --- a/egs2/lrs2/lipreading1/local/feature_extract/models/pretrained.py +++ b/egs2/lrs2/lipreading1/local/feature_extract/models/pretrained.py @@ -1,8 +1,7 @@ # coding: utf-8 import math -import numpy as np - +import numpy as np import torch import torch.nn as nn from 
torch.autograd import Variable diff --git a/egs2/lrs2/lipreading1/local/feature_extract/video_processing.py b/egs2/lrs2/lipreading1/local/feature_extract/video_processing.py index 216de5d03f1..b9812b18d3c 100644 --- a/egs2/lrs2/lipreading1/local/feature_extract/video_processing.py +++ b/egs2/lrs2/lipreading1/local/feature_extract/video_processing.py @@ -1,8 +1,8 @@ -import skvideo.io -import skimage.transform +import cvtransforms import face_alignment import numpy as np -import cvtransforms +import skimage.transform +import skvideo.io import torch from models import pretrained diff --git a/egs2/lrs3/asr1/local/data_prep.py b/egs2/lrs3/asr1/local/data_prep.py index 2ba8c7a816b..a3458813eeb 100644 --- a/egs2/lrs3/asr1/local/data_prep.py +++ b/egs2/lrs3/asr1/local/data_prep.py @@ -5,12 +5,13 @@ # Apache 2.0 -import os import argparse import logging -import numpy as np +import os from pathlib import Path -from typing import Union, List +from typing import List, Union + +import numpy as np class Utils: diff --git a/egs2/mediaspeech/asr1/local/data_prep.py b/egs2/mediaspeech/asr1/local/data_prep.py index 42162c53da0..ba92504859d 100755 --- a/egs2/mediaspeech/asr1/local/data_prep.py +++ b/egs2/mediaspeech/asr1/local/data_prep.py @@ -1,12 +1,11 @@ -import os -import os.path -import json +import argparse import glob +import json import math -import argparse +import os +import os.path import random - parser = argparse.ArgumentParser(description="Prepare mediaspeech") parser.add_argument( "--data_path", type=str, help="Path to the directory containing all files" diff --git a/egs2/microsoft_speech/asr1/local/process.py b/egs2/microsoft_speech/asr1/local/process.py index eab7b85de9a..f0ff3b862a2 100644 --- a/egs2/microsoft_speech/asr1/local/process.py +++ b/egs2/microsoft_speech/asr1/local/process.py @@ -1,10 +1,10 @@ -import os -import wave import contextlib -from tqdm import tqdm +import os import random import sys +import wave +from tqdm import tqdm microsoft_speech_corpus_path = sys.argv[1] lang = sys.argv[2] diff --git a/egs2/mini_librispeech/diar1/local/simulation/make_mixture.py b/egs2/mini_librispeech/diar1/local/simulation/make_mixture.py index ad16f72ec18..8f5fbaef5f9 100755 --- a/egs2/mini_librispeech/diar1/local/simulation/make_mixture.py +++ b/egs2/mini_librispeech/diar1/local/simulation/make_mixture.py @@ -13,12 +13,13 @@ import argparse +import json +import math import os + import common import numpy as np -import math import soundfile as sf -import json parser = argparse.ArgumentParser() parser.add_argument("script", help="list of json") diff --git a/egs2/mini_librispeech/diar1/local/simulation/make_mixture_nooverlap.py b/egs2/mini_librispeech/diar1/local/simulation/make_mixture_nooverlap.py index 9b8c24cd87f..2d79dbddc0f 100755 --- a/egs2/mini_librispeech/diar1/local/simulation/make_mixture_nooverlap.py +++ b/egs2/mini_librispeech/diar1/local/simulation/make_mixture_nooverlap.py @@ -14,12 +14,13 @@ import argparse +import json +import math import os + import common import numpy as np -import math import soundfile as sf -import json parser = argparse.ArgumentParser() parser.add_argument("script", help="list of json") diff --git a/egs2/mini_librispeech/diar1/local/simulation/random_mixture.py b/egs2/mini_librispeech/diar1/local/simulation/random_mixture.py index 7d67d056d99..61022c56ff5 100755 --- a/egs2/mini_librispeech/diar1/local/simulation/random_mixture.py +++ b/egs2/mini_librispeech/diar1/local/simulation/random_mixture.py @@ -40,12 +40,13 @@ """ import argparse +import itertools 
+import json import os -import common import random + +import common import numpy as np -import json -import itertools parser = argparse.ArgumentParser() parser.add_argument("data_dir", help="data dir of single-speaker recordings") diff --git a/egs2/mini_librispeech/diar1/local/simulation/random_mixture_nooverlap.py b/egs2/mini_librispeech/diar1/local/simulation/random_mixture_nooverlap.py index b6e417f81ab..acbdde34f20 100755 --- a/egs2/mini_librispeech/diar1/local/simulation/random_mixture_nooverlap.py +++ b/egs2/mini_librispeech/diar1/local/simulation/random_mixture_nooverlap.py @@ -42,12 +42,13 @@ """ import argparse +import itertools +import json import os -import common import random + +import common import numpy as np -import json -import itertools parser = argparse.ArgumentParser() parser.add_argument("data_dir", help="data dir of single-speaker recordings") diff --git a/egs2/misp2021/asr1/local/find_wav.py b/egs2/misp2021/asr1/local/find_wav.py index 216c30de4c9..22979044bdc 100755 --- a/egs2/misp2021/asr1/local/find_wav.py +++ b/egs2/misp2021/asr1/local/find_wav.py @@ -1,9 +1,9 @@ #!/usr/bin/env python # _*_ coding: UTF-8 _*_ -import os -import glob -import codecs import argparse +import codecs +import glob +import os def find_wav(data_root, scp_dir, scp_name="wpe", wav_type="Far", n_split=1): diff --git a/egs2/misp2021/asr1/local/prepare_far_data.py b/egs2/misp2021/asr1/local/prepare_far_data.py index b98766f2168..ac845a8043d 100755 --- a/egs2/misp2021/asr1/local/prepare_far_data.py +++ b/egs2/misp2021/asr1/local/prepare_far_data.py @@ -1,11 +1,11 @@ #!/usr/bin/env python # -- coding: UTF-8 -import os -import glob -import codecs import argparse -from multiprocessing import Pool +import codecs +import glob +import os import sys +from multiprocessing import Pool def text2lines(textpath, lines_content=None): diff --git a/egs2/misp2021/asr1/local/run_beamformit.py b/egs2/misp2021/asr1/local/run_beamformit.py index 8070542bb30..d55b7bf30a4 100755 --- a/egs2/misp2021/asr1/local/run_beamformit.py +++ b/egs2/misp2021/asr1/local/run_beamformit.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # _*_ coding: UTF-8 _*_ -import os import argparse +import os def beamformit_worker( diff --git a/egs2/misp2021/asr1/local/run_wpe.py b/egs2/misp2021/asr1/local/run_wpe.py index 0815dd1dfae..b66037a20a0 100755 --- a/egs2/misp2021/asr1/local/run_wpe.py +++ b/egs2/misp2021/asr1/local/run_wpe.py @@ -1,13 +1,14 @@ #!/usr/bin/env python # _*_ coding: UTF-8 _*_ -import os -import codecs import argparse +import codecs +import os +from multiprocessing import Pool + import numpy as np import scipy.io.wavfile as wf -from multiprocessing import Pool +from nara_wpe.utils import istft, stft from nara_wpe.wpe import wpe_v8 as wpe -from nara_wpe.utils import stft, istft def wpe_worker( diff --git a/egs2/misp2021/avsr1/local/concatenate_feature.py b/egs2/misp2021/avsr1/local/concatenate_feature.py index 41676ca21a1..ff6a1fcc8b8 100755 --- a/egs2/misp2021/avsr1/local/concatenate_feature.py +++ b/egs2/misp2021/avsr1/local/concatenate_feature.py @@ -1,8 +1,9 @@ #!/usr/bin/env python # _*_ coding: UTF-8 _*_ -import os -import codecs import argparse +import codecs +import os + import kaldiio import numpy as np from tqdm import tqdm diff --git a/egs2/misp2021/avsr1/local/find_wav.py b/egs2/misp2021/avsr1/local/find_wav.py index 216c30de4c9..22979044bdc 100755 --- a/egs2/misp2021/avsr1/local/find_wav.py +++ b/egs2/misp2021/avsr1/local/find_wav.py @@ -1,9 +1,9 @@ #!/usr/bin/env python # _*_ coding: UTF-8 _*_ -import os -import glob 
-import codecs import argparse +import codecs +import glob +import os def find_wav(data_root, scp_dir, scp_name="wpe", wav_type="Far", n_split=1): diff --git a/egs2/misp2021/avsr1/local/prepare_far_data.py b/egs2/misp2021/avsr1/local/prepare_far_data.py index b98766f2168..ac845a8043d 100755 --- a/egs2/misp2021/avsr1/local/prepare_far_data.py +++ b/egs2/misp2021/avsr1/local/prepare_far_data.py @@ -1,11 +1,11 @@ #!/usr/bin/env python # -- coding: UTF-8 -import os -import glob -import codecs import argparse -from multiprocessing import Pool +import codecs +import glob +import os import sys +from multiprocessing import Pool def text2lines(textpath, lines_content=None): diff --git a/egs2/misp2021/avsr1/local/prepare_far_video_roi.py b/egs2/misp2021/avsr1/local/prepare_far_video_roi.py index 32d8a0fae4d..01accc5ec9a 100755 --- a/egs2/misp2021/avsr1/local/prepare_far_video_roi.py +++ b/egs2/misp2021/avsr1/local/prepare_far_video_roi.py @@ -1,14 +1,15 @@ #!/usr/bin/env python # _*_ coding: UTF-8 _*_ -import os -import cv2 +import argparse +import codecs import json +import os import time -import codecs -import argparse +from multiprocessing import Pool + +import cv2 import numpy as np from tqdm import tqdm -from multiprocessing import Pool def crop_frame_roi(frame, roi_bound, roi_size=(96, 96)): diff --git a/egs2/misp2021/avsr1/local/prepare_visual_embedding_extractor.py b/egs2/misp2021/avsr1/local/prepare_visual_embedding_extractor.py index 38eb7f60611..4cb619f79db 100755 --- a/egs2/misp2021/avsr1/local/prepare_visual_embedding_extractor.py +++ b/egs2/misp2021/avsr1/local/prepare_visual_embedding_extractor.py @@ -1,8 +1,9 @@ #!/usr/bin/env python # _*_ coding: UTF-8 _*_ -import os -import codecs import argparse +import codecs +import os + from tqdm import tqdm diff --git a/egs2/misp2021/avsr1/local/run_beamformit.py b/egs2/misp2021/avsr1/local/run_beamformit.py index 8070542bb30..d55b7bf30a4 100755 --- a/egs2/misp2021/avsr1/local/run_beamformit.py +++ b/egs2/misp2021/avsr1/local/run_beamformit.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # _*_ coding: UTF-8 _*_ -import os import argparse +import os def beamformit_worker( diff --git a/egs2/misp2021/avsr1/local/run_wpe.py b/egs2/misp2021/avsr1/local/run_wpe.py index 0815dd1dfae..b66037a20a0 100755 --- a/egs2/misp2021/avsr1/local/run_wpe.py +++ b/egs2/misp2021/avsr1/local/run_wpe.py @@ -1,13 +1,14 @@ #!/usr/bin/env python # _*_ coding: UTF-8 _*_ -import os -import codecs import argparse +import codecs +import os +from multiprocessing import Pool + import numpy as np import scipy.io.wavfile as wf -from multiprocessing import Pool +from nara_wpe.utils import istft, stft from nara_wpe.wpe import wpe_v8 as wpe -from nara_wpe.utils import stft, istft def wpe_worker( diff --git a/egs2/ml_openslr63/asr1/local/data_prep.py b/egs2/ml_openslr63/asr1/local/data_prep.py index bd174f75e68..84cb779d18b 100644 --- a/egs2/ml_openslr63/asr1/local/data_prep.py +++ b/egs2/ml_openslr63/asr1/local/data_prep.py @@ -9,7 +9,6 @@ import os import random - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-d", help="downloads directory", type=str, default="downloads") diff --git a/egs2/mr_openslr64/asr1/local/data_prep.py b/egs2/mr_openslr64/asr1/local/data_prep.py index ed446ef71ae..f1f0245f657 100644 --- a/egs2/mr_openslr64/asr1/local/data_prep.py +++ b/egs2/mr_openslr64/asr1/local/data_prep.py @@ -8,7 +8,6 @@ import os import random - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-d", help="downloads 
directory", type=str, default="downloads") diff --git a/egs2/ms_indic_18/asr1/local/prepare_data.py b/egs2/ms_indic_18/asr1/local/prepare_data.py index 464a1f43b11..d51f817b481 100755 --- a/egs2/ms_indic_18/asr1/local/prepare_data.py +++ b/egs2/ms_indic_18/asr1/local/prepare_data.py @@ -8,8 +8,8 @@ import os import random import sys -import librosa +import librosa if len(sys.argv) != 3: print("Usage: python prepare_data.py [data-directory] [language-ID]") diff --git a/egs2/open_li52/asr1/local/filter_text.py b/egs2/open_li52/asr1/local/filter_text.py index db35c1754da..c5b000ce4c0 100755 --- a/egs2/open_li52/asr1/local/filter_text.py +++ b/egs2/open_li52/asr1/local/filter_text.py @@ -6,9 +6,8 @@ import argparse import codecs -from io import open import sys - +from io import open PY2 = sys.version_info[0] == 2 sys.stdin = codecs.getreader("utf-8")(sys.stdin if PY2 else sys.stdin.buffer) diff --git a/egs2/primewords_chinese/asr1/local/data_prep.py b/egs2/primewords_chinese/asr1/local/data_prep.py index 11258bc597f..0c666eb96ac 100644 --- a/egs2/primewords_chinese/asr1/local/data_prep.py +++ b/egs2/primewords_chinese/asr1/local/data_prep.py @@ -3,13 +3,12 @@ # Copyright 2021 Carnegie Mellon University (Yifan Peng) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import os -import os.path -import json +import argparse import glob +import json import math -import argparse - +import os +import os.path parser = argparse.ArgumentParser(description="Prepare Primewords_Chinese") parser.add_argument( diff --git a/egs2/seame/asr1/local/preprocess.py b/egs2/seame/asr1/local/preprocess.py index eb0ccfac47b..51a1e341c61 100755 --- a/egs2/seame/asr1/local/preprocess.py +++ b/egs2/seame/asr1/local/preprocess.py @@ -18,13 +18,13 @@ [3] https://github.com/zengzp0912/SEAME-dev-set """ -import re -import os -import sys import argparse -import itertools import collections +import itertools +import os import random as rd +import re +import sys rd.seed(531) diff --git a/egs2/seame/asr1/local/split_lang_trn.py b/egs2/seame/asr1/local/split_lang_trn.py index 1cff8674a73..896cfec69b0 100755 --- a/egs2/seame/asr1/local/split_lang_trn.py +++ b/egs2/seame/asr1/local/split_lang_trn.py @@ -1,16 +1,12 @@ #!/usr/bin/env python3 # -*- encoding: utf8 -*- -import os import argparse +import os -from preprocess import ( - remove_redundant_whitespaces, - extract_mandarin_only, - extract_non_mandarin, - insert_space_between_mandarin, -) - +from preprocess import (extract_mandarin_only, extract_non_mandarin, + insert_space_between_mandarin, + remove_redundant_whitespaces) if __name__ == "__main__": # Parse arguments diff --git a/egs2/sinhala/asr1/local/data_prep.py b/egs2/sinhala/asr1/local/data_prep.py index ba218d62260..4bdbd737b33 100644 --- a/egs2/sinhala/asr1/local/data_prep.py +++ b/egs2/sinhala/asr1/local/data_prep.py @@ -8,13 +8,11 @@ import os import re import sys -import pandas as pd -from tqdm import tqdm -import pandas as pd -import os + import numpy as np +import pandas as pd from sklearn.model_selection import train_test_split - +from tqdm import tqdm if len(sys.argv) != 2: print("Usage: python data_prep.py [SINHALA]") diff --git a/egs2/slue-voxceleb/asr1/local/data_prep_slue.py b/egs2/slue-voxceleb/asr1/local/data_prep_slue.py index 89b42059e30..43e0428806a 100644 --- a/egs2/slue-voxceleb/asr1/local/data_prep_slue.py +++ b/egs2/slue-voxceleb/asr1/local/data_prep_slue.py @@ -8,6 +8,7 @@ import os import re import sys + import pandas as pd if len(sys.argv) != 2: diff --git 
a/egs2/slue-voxceleb/asr1/local/f1_score.py b/egs2/slue-voxceleb/asr1/local/f1_score.py index 4f45752a812..4b71566da27 100755 --- a/egs2/slue-voxceleb/asr1/local/f1_score.py +++ b/egs2/slue-voxceleb/asr1/local/f1_score.py @@ -5,13 +5,13 @@ # Apache 2.0 +import argparse import os import re import sys + import pandas as pd -import argparse -from sklearn.metrics import f1_score -from sklearn.metrics import classification_report +from sklearn.metrics import classification_report, f1_score def get_classification_result(hyp_file, ref_file): diff --git a/egs2/slue-voxceleb/asr1/local/generate_asr_files.py b/egs2/slue-voxceleb/asr1/local/generate_asr_files.py index dd8a4645410..4e990b5cda4 100644 --- a/egs2/slue-voxceleb/asr1/local/generate_asr_files.py +++ b/egs2/slue-voxceleb/asr1/local/generate_asr_files.py @@ -5,11 +5,12 @@ # Apache 2.0 +import argparse import os import re import sys + import pandas as pd -import argparse def generate_asr_files(txt_file, transcript_file): diff --git a/egs2/slue-voxpopuli/asr1/local/data_prep_original_slue_format.py b/egs2/slue-voxpopuli/asr1/local/data_prep_original_slue_format.py index 005da336b83..32a20ff548b 100644 --- a/egs2/slue-voxpopuli/asr1/local/data_prep_original_slue_format.py +++ b/egs2/slue-voxpopuli/asr1/local/data_prep_original_slue_format.py @@ -1,10 +1,11 @@ #!/usr/bin/env python3 import os -import pandas as pd import re import sys +import pandas as pd + if len(sys.argv) != 2: print("Usage: python data_prep.py [root]") sys.exit(1) diff --git a/egs2/slue-voxpopuli/asr1/local/data_prep_original_slue_format_transcript.py b/egs2/slue-voxpopuli/asr1/local/data_prep_original_slue_format_transcript.py index 622c830e4cf..515e477a00c 100644 --- a/egs2/slue-voxpopuli/asr1/local/data_prep_original_slue_format_transcript.py +++ b/egs2/slue-voxpopuli/asr1/local/data_prep_original_slue_format_transcript.py @@ -1,10 +1,11 @@ #!/usr/bin/env python3 import os -import pandas as pd import re import sys +import pandas as pd + if len(sys.argv) != 2: print("Usage: python data_prep.py [root]") sys.exit(1) diff --git a/egs2/slue-voxpopuli/asr1/local/eval_utils.py b/egs2/slue-voxpopuli/asr1/local/eval_utils.py index 9310735eb56..eaebc2881f4 100644 --- a/egs2/slue-voxpopuli/asr1/local/eval_utils.py +++ b/egs2/slue-voxpopuli/asr1/local/eval_utils.py @@ -1,7 +1,8 @@ -from typing import List from collections import defaultdict -import numpy as np +from typing import List + import editdistance +import numpy as np def get_ner_scores(all_gt, all_predictions): diff --git a/egs2/slue-voxpopuli/asr1/local/score.py b/egs2/slue-voxpopuli/asr1/local/score.py index 4239ed2b151..4663833c58f 100755 --- a/egs2/slue-voxpopuli/asr1/local/score.py +++ b/egs2/slue-voxpopuli/asr1/local/score.py @@ -4,14 +4,14 @@ # 2021 Carnegie Mellon University # Apache 2.0 +import argparse import json import os import re import sys -import pandas as pd -import argparse import eval_utils +import pandas as pd ontonotes_to_combined_label = { "GPE": "PLACE", diff --git a/egs2/slurp/asr1/local/prepare_slurp_data.py b/egs2/slurp/asr1/local/prepare_slurp_data.py index 1120d03f9a5..d7dfb1d2674 100644 --- a/egs2/slurp/asr1/local/prepare_slurp_data.py +++ b/egs2/slurp/asr1/local/prepare_slurp_data.py @@ -5,9 +5,9 @@ import json import os -import sys -import subprocess import re +import subprocess +import sys idir = sys.argv[1] diff --git a/egs2/slurp_entity/asr1/local/convert_to_entity_file.py b/egs2/slurp_entity/asr1/local/convert_to_entity_file.py index e37898f1ae9..9d65c79c5e0 100644 --- 
a/egs2/slurp_entity/asr1/local/convert_to_entity_file.py +++ b/egs2/slurp_entity/asr1/local/convert_to_entity_file.py @@ -1,7 +1,7 @@ -import json -import sys import argparse +import json import os +import sys def generate_entity_file(line_arr, output_file="result_test.json"): diff --git a/egs2/slurp_entity/asr1/local/evaluation/evaluate.py b/egs2/slurp_entity/asr1/local/evaluation/evaluate.py index 908fb2d77c0..bd6b2e468da 100755 --- a/egs2/slurp_entity/asr1/local/evaluation/evaluate.py +++ b/egs2/slurp_entity/asr1/local/evaluation/evaluate.py @@ -1,10 +1,9 @@ import argparse import logging -from progress.bar import Bar - from metrics import ErrorMetric -from util import format_results, load_predictions, load_gold_data +from progress.bar import Bar +from util import format_results, load_gold_data, load_predictions logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/egs2/slurp_entity/asr1/local/evaluation/metrics/__init__.py b/egs2/slurp_entity/asr1/local/evaluation/metrics/__init__.py index 80b25690873..6148e4fdd55 100755 --- a/egs2/slurp_entity/asr1/local/evaluation/metrics/__init__.py +++ b/egs2/slurp_entity/asr1/local/evaluation/metrics/__init__.py @@ -1,3 +1,2 @@ from .distance import Distance -from .metrics import ErrorMetric -from .metrics import compute_metrics +from .metrics import ErrorMetric, compute_metrics diff --git a/egs2/slurp_entity/asr1/local/evaluation/metrics/distance.py b/egs2/slurp_entity/asr1/local/evaluation/metrics/distance.py index 18928317486..3451a96bd9b 100755 --- a/egs2/slurp_entity/asr1/local/evaluation/metrics/distance.py +++ b/egs2/slurp_entity/asr1/local/evaluation/metrics/distance.py @@ -1,5 +1,6 @@ -from jiwer import wer from typing import List, Union + +from jiwer import wer from textdistance.algorithms.edit_based import levenshtein DISTANCE_OPTIONS = {"word", "char"} diff --git a/egs2/slurp_entity/asr1/local/evaluation/util.py b/egs2/slurp_entity/asr1/local/evaluation/util.py index c5c2b3560d5..a818d476caf 100755 --- a/egs2/slurp_entity/asr1/local/evaluation/util.py +++ b/egs2/slurp_entity/asr1/local/evaluation/util.py @@ -1,10 +1,9 @@ import json import logging import os -import tabulate - -from typing import Dict, Any, Tuple +from typing import Any, Dict, Tuple +import tabulate from progress.bar import Bar logging.basicConfig( diff --git a/egs2/slurp_entity/asr1/local/prepare_slurp_data.py b/egs2/slurp_entity/asr1/local/prepare_slurp_data.py index 1120d03f9a5..d7dfb1d2674 100644 --- a/egs2/slurp_entity/asr1/local/prepare_slurp_data.py +++ b/egs2/slurp_entity/asr1/local/prepare_slurp_data.py @@ -5,9 +5,9 @@ import json import os -import sys -import subprocess import re +import subprocess +import sys idir = sys.argv[1] diff --git a/egs2/slurp_entity/asr1/local/prepare_slurp_entity_data.py b/egs2/slurp_entity/asr1/local/prepare_slurp_entity_data.py index 358d947ced9..220a7cfc042 100644 --- a/egs2/slurp_entity/asr1/local/prepare_slurp_entity_data.py +++ b/egs2/slurp_entity/asr1/local/prepare_slurp_entity_data.py @@ -5,9 +5,9 @@ import json import os -import sys -import subprocess import re +import subprocess +import sys idir = sys.argv[1] diff --git a/egs2/snips/asr1/local/data_prep.py b/egs2/snips/asr1/local/data_prep.py index 79cd5e2b420..1c83ac5e749 100644 --- a/egs2/snips/asr1/local/data_prep.py +++ b/egs2/snips/asr1/local/data_prep.py @@ -2,8 +2,8 @@ # Copyright 2021 Yuekai Zhang -import json import argparse +import json parser = argparse.ArgumentParser(description="Process snips dataset.") 
parser.add_argument("--wav_path", type=str, help="file path for audios") diff --git a/egs2/speechcommands/asr1/local/data_prep_12.py b/egs2/speechcommands/asr1/local/data_prep_12.py index b61bf6ac0f8..0fa1cdffa37 100644 --- a/egs2/speechcommands/asr1/local/data_prep_12.py +++ b/egs2/speechcommands/asr1/local/data_prep_12.py @@ -8,15 +8,15 @@ # https://www.tensorflow.org/datasets/catalog/speech_commands -import os -import os.path +import argparse import csv import glob -import argparse +import os +import os.path + import numpy as np from scipy.io import wavfile - parser = argparse.ArgumentParser(description="Process speech commands dataset.") parser.add_argument( "--data_path", diff --git a/egs2/speechcommands/asr1/local/data_prep_35.py b/egs2/speechcommands/asr1/local/data_prep_35.py index 6b88e026a46..147c1c76c7e 100644 --- a/egs2/speechcommands/asr1/local/data_prep_35.py +++ b/egs2/speechcommands/asr1/local/data_prep_35.py @@ -6,11 +6,11 @@ # Speech Commands Dataset: https://arxiv.org/abs/1804.03209 +import argparse import os import os.path -import argparse -import numpy as np +import numpy as np parser = argparse.ArgumentParser( description="Process speech commands dataset with 35 commands." diff --git a/egs2/speechcommands/asr1/local/score.py b/egs2/speechcommands/asr1/local/score.py index b1c79a976c9..9ba1b759cdc 100644 --- a/egs2/speechcommands/asr1/local/score.py +++ b/egs2/speechcommands/asr1/local/score.py @@ -2,9 +2,9 @@ # Copyright 2021 Carnegie Mellon University (Yifan Peng) +import argparse import os import os.path -import argparse parser = argparse.ArgumentParser(description="Calculate classification accuracy.") parser.add_argument("--wer_dir", type=str, help="folder containing hyp.trn and ref.trn") diff --git a/egs2/su_openslr36/asr1/local/sunda_data_prep.py b/egs2/su_openslr36/asr1/local/sunda_data_prep.py index f2196874b91..02c1c0d43c9 100644 --- a/egs2/su_openslr36/asr1/local/sunda_data_prep.py +++ b/egs2/su_openslr36/asr1/local/sunda_data_prep.py @@ -8,7 +8,6 @@ import os import random - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-d", help="downloads directory", type=str, default="downloads") diff --git a/egs2/swbd_sentiment/asr1/local/prepare_sentiment.py b/egs2/swbd_sentiment/asr1/local/prepare_sentiment.py index 8921fa4272d..e76ce12fc5f 100755 --- a/egs2/swbd_sentiment/asr1/local/prepare_sentiment.py +++ b/egs2/swbd_sentiment/asr1/local/prepare_sentiment.py @@ -1,7 +1,7 @@ -import os -import re import argparse import math +import os +import re def float2str(number, size=6): diff --git a/egs2/swbd_sentiment/asr1/local/score_f1.py b/egs2/swbd_sentiment/asr1/local/score_f1.py index a36c37c7b1f..6408b70d8a3 100755 --- a/egs2/swbd_sentiment/asr1/local/score_f1.py +++ b/egs2/swbd_sentiment/asr1/local/score_f1.py @@ -5,11 +5,12 @@ # Apache 2.0 +import argparse import os import re import sys + import pandas as pd -import argparse from sklearn.metrics import f1_score diff --git a/egs2/totonac/asr1/local/data_prep.py b/egs2/totonac/asr1/local/data_prep.py index e3f76e03c0a..6edc792e651 100644 --- a/egs2/totonac/asr1/local/data_prep.py +++ b/egs2/totonac/asr1/local/data_prep.py @@ -1,12 +1,13 @@ -from argparse import ArgumentParser import os import re import shutil -import soundfile as sf import string import sys +from argparse import ArgumentParser from xml.dom.minidom import parse +import soundfile as sf + s = "".join(chr(c) for c in range(sys.maxunicode + 1)) ws = "".join(re.findall(r"\s", s)) outtab = " " * len(ws) diff --git 
a/egs2/wenetspeech/asr1/local/extract_meta.py b/egs2/wenetspeech/asr1/local/extract_meta.py index 30fa8803406..6074162038b 100755 --- a/egs2/wenetspeech/asr1/local/extract_meta.py +++ b/egs2/wenetspeech/asr1/local/extract_meta.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys -import os import argparse import json +import os +import sys def get_args(): diff --git a/egs2/wenetspeech/asr1/local/process_opus.py b/egs2/wenetspeech/asr1/local/process_opus.py index 7d6a6af8d1a..044d183b93f 100755 --- a/egs2/wenetspeech/asr1/local/process_opus.py +++ b/egs2/wenetspeech/asr1/local/process_opus.py @@ -16,9 +16,10 @@ # usage: python3 process_opus.py wav.scp segments output_wav.scp -from pydub import AudioSegment -import sys import os +import sys + +from pydub import AudioSegment def read_file(wav_scp, segments): diff --git a/egs2/yoloxochitl_mixtec/asr1/local/filter_text.py b/egs2/yoloxochitl_mixtec/asr1/local/filter_text.py index 162d09eeb68..c79c6d3032c 100755 --- a/egs2/yoloxochitl_mixtec/asr1/local/filter_text.py +++ b/egs2/yoloxochitl_mixtec/asr1/local/filter_text.py @@ -5,9 +5,8 @@ import argparse import codecs -from io import open import sys - +from io import open sys.stdin = codecs.getreader("utf-8")(sys.stdin.buffer) sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer) diff --git a/egs2/zh_openslr38/asr1/local/data_split.py b/egs2/zh_openslr38/asr1/local/data_split.py index df952d304cd..9200424b314 100644 --- a/egs2/zh_openslr38/asr1/local/data_split.py +++ b/egs2/zh_openslr38/asr1/local/data_split.py @@ -1,10 +1,10 @@ """ Split data to train, dev, test """ -import sys import os -from collections import defaultdict import random +import sys +from collections import defaultdict train_size = 0.9 random.seed(1) diff --git a/espnet/asr/asr_utils.py b/espnet/asr/asr_utils.py index e8c7387ae4b..ea61f646102 100644 --- a/espnet/asr/asr_utils.py +++ b/espnet/asr/asr_utils.py @@ -12,7 +12,6 @@ import numpy as np import torch - # * -------------------- training iterator related -------------------- * diff --git a/espnet/asr/chainer_backend/asr.py b/espnet/asr/chainer_backend/asr.py index 976d920bfbd..12b4424735e 100644 --- a/espnet/asr/chainer_backend/asr.py +++ b/espnet/asr/chainer_backend/asr.py @@ -6,40 +6,32 @@ import json import logging import os -import six # chainer related import chainer - +import six from chainer import training - from chainer.datasets import TransformDataset from chainer.training import extensions +# rnnlm +import espnet.lm.chainer_backend.extlm as extlm_chainer +import espnet.lm.chainer_backend.lm as lm_chainer # espnet related -from espnet.asr.asr_utils import adadelta_eps_decay -from espnet.asr.asr_utils import add_results_to_json -from espnet.asr.asr_utils import chainer_load -from espnet.asr.asr_utils import CompareValueTrigger -from espnet.asr.asr_utils import get_model_conf -from espnet.asr.asr_utils import restore_snapshot +from espnet.asr.asr_utils import (CompareValueTrigger, adadelta_eps_decay, + add_results_to_json, chainer_load, + get_model_conf, restore_snapshot) from espnet.nets.asr_interface import ASRInterface from espnet.utils.deterministic_utils import set_deterministic_chainer from espnet.utils.dynamic_import import dynamic_import from espnet.utils.io_utils import LoadInputsAndTargets from espnet.utils.training.batchfy import make_batchset from espnet.utils.training.evaluator import BaseEvaluator -from espnet.utils.training.iterators import ShufflingEnabler -from 
espnet.utils.training.iterators import ToggleableShufflingMultiprocessIterator -from espnet.utils.training.iterators import ToggleableShufflingSerialIterator -from espnet.utils.training.train_utils import check_early_stop -from espnet.utils.training.train_utils import set_early_stop - -# rnnlm -import espnet.lm.chainer_backend.extlm as extlm_chainer -import espnet.lm.chainer_backend.lm as lm_chainer - +from espnet.utils.training.iterators import ( + ShufflingEnabler, ToggleableShufflingMultiprocessIterator, + ToggleableShufflingSerialIterator) from espnet.utils.training.tensorboard_logger import TensorboardLogger +from espnet.utils.training.train_utils import check_early_stop, set_early_stop def train(args): @@ -278,7 +270,8 @@ def train(args): trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"), ) if args.opt == "noam": - from espnet.nets.chainer_backend.transformer.training import VaswaniRule + from espnet.nets.chainer_backend.transformer.training import \ + VaswaniRule trainer.extend( VaswaniRule( diff --git a/espnet/asr/pytorch_backend/asr.py b/espnet/asr/pytorch_backend/asr.py index 7a265f2badf..1a63a11e660 100644 --- a/espnet/asr/pytorch_backend/asr.py +++ b/espnet/asr/pytorch_backend/asr.py @@ -9,42 +9,35 @@ import logging import math import os -from packaging.version import parse as V +import numpy as np +import torch from chainer import reporter as reporter_module from chainer import training from chainer.training import extensions from chainer.training.updater import StandardUpdater -import numpy as np -import torch +from packaging.version import parse as V from torch.nn.parallel import data_parallel -from espnet.asr.asr_utils import adadelta_eps_decay -from espnet.asr.asr_utils import add_results_to_json -from espnet.asr.asr_utils import CompareValueTrigger -from espnet.asr.asr_utils import format_mulenc_args -from espnet.asr.asr_utils import get_model_conf -from espnet.asr.asr_utils import plot_spectrogram -from espnet.asr.asr_utils import restore_snapshot -from espnet.asr.asr_utils import snapshot_object -from espnet.asr.asr_utils import torch_load -from espnet.asr.asr_utils import torch_resume -from espnet.asr.asr_utils import torch_snapshot -from espnet.asr.pytorch_backend.asr_init import freeze_modules -from espnet.asr.pytorch_backend.asr_init import load_trained_model -from espnet.asr.pytorch_backend.asr_init import load_trained_modules import espnet.lm.pytorch_backend.extlm as extlm_pytorch +import espnet.nets.pytorch_backend.lm.default as lm_pytorch +from espnet.asr.asr_utils import (CompareValueTrigger, adadelta_eps_decay, + add_results_to_json, format_mulenc_args, + get_model_conf, plot_spectrogram, + restore_snapshot, snapshot_object, + torch_load, torch_resume, torch_snapshot) +from espnet.asr.pytorch_backend.asr_init import (freeze_modules, + load_trained_model, + load_trained_modules) from espnet.nets.asr_interface import ASRInterface from espnet.nets.beam_search_transducer import BeamSearchTransducer from espnet.nets.pytorch_backend.e2e_asr import pad_list -import espnet.nets.pytorch_backend.lm.default as lm_pytorch from espnet.nets.pytorch_backend.streaming.segment import SegmentStreamingE2E from espnet.nets.pytorch_backend.streaming.window import WindowStreamingE2E from espnet.transform.spectrogram import IStft from espnet.transform.transformation import Transformation from espnet.utils.cli_writers import file_writer_helper -from espnet.utils.dataset import ChainerDataLoader -from espnet.utils.dataset import TransformDataset +from 
espnet.utils.dataset import ChainerDataLoader, TransformDataset from espnet.utils.deterministic_utils import set_deterministic_pytorch from espnet.utils.dynamic_import import dynamic_import from espnet.utils.io_utils import LoadInputsAndTargets @@ -52,8 +45,7 @@ from espnet.utils.training.evaluator import BaseEvaluator from espnet.utils.training.iterators import ShufflingEnabler from espnet.utils.training.tensorboard_logger import TensorboardLogger -from espnet.utils.training.train_utils import check_early_stop -from espnet.utils.training.train_utils import set_early_stop +from espnet.utils.training.train_utils import check_early_stop, set_early_stop def _recursive_to(xs, device): @@ -508,7 +500,8 @@ def train(args): elif args.opt == "adam": optimizer = torch.optim.Adam(model_params, weight_decay=args.weight_decay) elif args.opt == "noam": - from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt + from espnet.nets.pytorch_backend.transformer.optimizer import \ + get_std_opt if "transducer" in mtl_mode: if args.noam_adim > 0: diff --git a/espnet/asr/pytorch_backend/asr_init.py b/espnet/asr/pytorch_backend/asr_init.py index 51bca5b7808..0a124ea437c 100644 --- a/espnet/asr/pytorch_backend/asr_init.py +++ b/espnet/asr/pytorch_backend/asr_init.py @@ -1,13 +1,13 @@ """Finetuning methods.""" -from collections import OrderedDict import logging import os import re +from collections import OrderedDict + import torch -from espnet.asr.asr_utils import get_model_conf -from espnet.asr.asr_utils import torch_load +from espnet.asr.asr_utils import get_model_conf, torch_load from espnet.nets.asr_interface import ASRInterface from espnet.nets.mt_interface import MTInterface from espnet.nets.pytorch_backend.transducer.utils import custom_torch_load diff --git a/espnet/asr/pytorch_backend/asr_mix.py b/espnet/asr/pytorch_backend/asr_mix.py index 53208f16f8e..048e0b8b543 100644 --- a/espnet/asr/pytorch_backend/asr_mix.py +++ b/espnet/asr/pytorch_backend/asr_mix.py @@ -9,41 +9,33 @@ import json import logging import os +from itertools import zip_longest as zip_longest +import numpy as np +import torch # chainer related from chainer import training from chainer.training import extensions -from itertools import zip_longest as zip_longest -import numpy as np -import torch -from espnet.asr.asr_mix_utils import add_results_to_json -from espnet.asr.asr_utils import adadelta_eps_decay - -from espnet.asr.asr_utils import CompareValueTrigger -from espnet.asr.asr_utils import get_model_conf -from espnet.asr.asr_utils import restore_snapshot -from espnet.asr.asr_utils import snapshot_object -from espnet.asr.asr_utils import torch_load -from espnet.asr.asr_utils import torch_resume -from espnet.asr.asr_utils import torch_snapshot -from espnet.asr.pytorch_backend.asr import CustomEvaluator -from espnet.asr.pytorch_backend.asr import CustomUpdater -from espnet.asr.pytorch_backend.asr import load_trained_model import espnet.lm.pytorch_backend.extlm as extlm_pytorch +import espnet.nets.pytorch_backend.lm.default as lm_pytorch +from espnet.asr.asr_mix_utils import add_results_to_json +from espnet.asr.asr_utils import (CompareValueTrigger, adadelta_eps_decay, + get_model_conf, restore_snapshot, + snapshot_object, torch_load, torch_resume, + torch_snapshot) +from espnet.asr.pytorch_backend.asr import (CustomEvaluator, CustomUpdater, + load_trained_model) from espnet.nets.asr_interface import ASRInterface from espnet.nets.pytorch_backend.e2e_asr_mix import pad_list -import espnet.nets.pytorch_backend.lm.default 
as lm_pytorch -from espnet.utils.dataset import ChainerDataLoader -from espnet.utils.dataset import TransformDataset +from espnet.utils.dataset import ChainerDataLoader, TransformDataset from espnet.utils.deterministic_utils import set_deterministic_pytorch from espnet.utils.dynamic_import import dynamic_import from espnet.utils.io_utils import LoadInputsAndTargets from espnet.utils.training.batchfy import make_batchset from espnet.utils.training.iterators import ShufflingEnabler from espnet.utils.training.tensorboard_logger import TensorboardLogger -from espnet.utils.training.train_utils import check_early_stop -from espnet.utils.training.train_utils import set_early_stop +from espnet.utils.training.train_utils import check_early_stop, set_early_stop class CustomConverter(object): @@ -225,7 +217,8 @@ def train(args): elif args.opt == "adam": optimizer = torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay) elif args.opt == "noam": - from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt + from espnet.nets.pytorch_backend.transformer.optimizer import \ + get_std_opt optimizer = get_std_opt( model.parameters(), diff --git a/espnet/asr/pytorch_backend/recog.py b/espnet/asr/pytorch_backend/recog.py index 0824f6e7b26..2804b317a4d 100644 --- a/espnet/asr/pytorch_backend/recog.py +++ b/espnet/asr/pytorch_backend/recog.py @@ -2,13 +2,12 @@ import json import logging -from packaging.version import parse as V import torch +from packaging.version import parse as V -from espnet.asr.asr_utils import add_results_to_json -from espnet.asr.asr_utils import get_model_conf -from espnet.asr.asr_utils import torch_load +from espnet.asr.asr_utils import (add_results_to_json, get_model_conf, + torch_load) from espnet.asr.pytorch_backend.asr import load_trained_model from espnet.nets.asr_interface import ASRInterface from espnet.nets.batch_beam_search import BatchBeamSearch @@ -99,8 +98,7 @@ def recog_v2(args): lm = None if args.ngram_model: - from espnet.nets.scorers.ngram import NgramFullScorer - from espnet.nets.scorers.ngram import NgramPartScorer + from espnet.nets.scorers.ngram import NgramFullScorer, NgramPartScorer if args.ngram_scorer == "full": ngram = NgramFullScorer(args.ngram_model, train_args.char_list) diff --git a/espnet/bin/asr_align.py b/espnet/bin/asr_align.py index e1ba35ffaee..bd44fda99ef 100755 --- a/espnet/bin/asr_align.py +++ b/espnet/bin/asr_align.py @@ -41,23 +41,21 @@ with the option `--gratis-blank`. 
""" -import configargparse +import json import logging import os import sys +import configargparse +import torch +# imports for CTC segmentation +from ctc_segmentation import (CtcSegmentationParameters, ctc_segmentation, + determine_utterance_segments, prepare_text) + # imports for inference from espnet.asr.pytorch_backend.asr_init import load_trained_model from espnet.nets.asr_interface import ASRInterface from espnet.utils.io_utils import LoadInputsAndTargets -import json -import torch - -# imports for CTC segmentation -from ctc_segmentation import ctc_segmentation -from ctc_segmentation import CtcSegmentationParameters -from ctc_segmentation import determine_utterance_segments -from ctc_segmentation import prepare_text # NOTE: you need this func to generate our sphinx doc diff --git a/espnet/bin/asr_enhance.py b/espnet/bin/asr_enhance.py index 98f0d693caa..2cc33ac4c01 100755 --- a/espnet/bin/asr_enhance.py +++ b/espnet/bin/asr_enhance.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 -import configargparse -from distutils.util import strtobool import logging import os import random import sys +from distutils.util import strtobool +import configargparse import numpy as np from espnet.asr.pytorch_backend.asr import enhance diff --git a/espnet/bin/asr_recog.py b/espnet/bin/asr_recog.py index 3275ecf2243..d641ef4b822 100755 --- a/espnet/bin/asr_recog.py +++ b/espnet/bin/asr_recog.py @@ -6,12 +6,12 @@ """End-to-end speech recognition model decoding script.""" -import configargparse import logging import os import random import sys +import configargparse import numpy as np from espnet.utils.cli_utils import strtobool diff --git a/espnet/bin/mt_trans.py b/espnet/bin/mt_trans.py index c229f16d79f..7aa74ea62ba 100755 --- a/espnet/bin/mt_trans.py +++ b/espnet/bin/mt_trans.py @@ -6,12 +6,12 @@ """Neural machine translation model decoding script.""" -import configargparse import logging import os import random import sys +import configargparse import numpy as np diff --git a/espnet/bin/tts_decode.py b/espnet/bin/tts_decode.py index 71e53439c57..5ddc6ff0b30 100755 --- a/espnet/bin/tts_decode.py +++ b/espnet/bin/tts_decode.py @@ -5,12 +5,13 @@ """TTS decoding script.""" -import configargparse import logging import os import subprocess import sys +import configargparse + from espnet.utils.cli_utils import strtobool diff --git a/espnet/bin/vc_decode.py b/espnet/bin/vc_decode.py index 1802b76769f..319dde112ac 100755 --- a/espnet/bin/vc_decode.py +++ b/espnet/bin/vc_decode.py @@ -5,12 +5,13 @@ """VC decoding script.""" -import configargparse import logging import os import subprocess import sys +import configargparse + from espnet.utils.cli_utils import strtobool diff --git a/espnet/lm/chainer_backend/extlm.py b/espnet/lm/chainer_backend/extlm.py index 711e878c1d8..84051a69544 100644 --- a/espnet/lm/chainer_backend/extlm.py +++ b/espnet/lm/chainer_backend/extlm.py @@ -8,6 +8,7 @@ import chainer import chainer.functions as F + from espnet.lm.lm_utils import make_lexical_tree diff --git a/espnet/lm/chainer_backend/lm.py b/espnet/lm/chainer_backend/lm.py index 3cfcd6fd2d5..ce483663d60 100644 --- a/espnet/lm/chainer_backend/lm.py +++ b/espnet/lm/chainer_backend/lm.py @@ -10,40 +10,31 @@ import copy import json import logging -import numpy as np -import six import chainer -from chainer.dataset import convert import chainer.functions as F import chainer.links as L - +import numpy as np +import six +from chainer import link, reporter, training +from chainer.dataset import convert # for classifier link from 
chainer.functions.loss import softmax_cross_entropy -from chainer import link -from chainer import reporter -from chainer import training from chainer.training import extensions -from espnet.lm.lm_utils import compute_perplexity -from espnet.lm.lm_utils import count_tokens -from espnet.lm.lm_utils import MakeSymlinkToBestModel -from espnet.lm.lm_utils import ParallelSentenceIterator -from espnet.lm.lm_utils import read_tokens - import espnet.nets.chainer_backend.deterministic_embed_id as DL +from espnet.lm.lm_utils import (MakeSymlinkToBestModel, + ParallelSentenceIterator, compute_perplexity, + count_tokens, read_tokens) from espnet.nets.lm_interface import LMInterface from espnet.optimizer.factory import dynamic_import_optimizer from espnet.scheduler.chainer import ChainerScheduler from espnet.scheduler.scheduler import dynamic_import_scheduler - -from espnet.utils.training.tensorboard_logger import TensorboardLogger - from espnet.utils.deterministic_utils import set_deterministic_chainer from espnet.utils.training.evaluator import BaseEvaluator from espnet.utils.training.iterators import ShufflingEnabler -from espnet.utils.training.train_utils import check_early_stop -from espnet.utils.training.train_utils import set_early_stop +from espnet.utils.training.tensorboard_logger import TensorboardLogger +from espnet.utils.training.train_utils import check_early_stop, set_early_stop # TODO(karita): reimplement RNNLM with new interface diff --git a/espnet/lm/lm_utils.py b/espnet/lm/lm_utils.py index bb43e5de0e7..273aeabead9 100644 --- a/espnet/lm/lm_utils.py +++ b/espnet/lm/lm_utils.py @@ -6,16 +6,16 @@ # This code is ported from the following implementation written in Torch. # https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py -import chainer -import h5py import logging -import numpy as np import os import random -import six -from tqdm import tqdm +import chainer +import h5py +import numpy as np +import six from chainer.training import extension +from tqdm import tqdm def load_dataset(path, label_dict, outdir=None): diff --git a/espnet/lm/pytorch_backend/lm.py b/espnet/lm/pytorch_backend/lm.py index 2b4efe529f7..8c820531e0d 100644 --- a/espnet/lm/pytorch_backend/lm.py +++ b/espnet/lm/pytorch_backend/lm.py @@ -9,41 +9,29 @@ import copy import json import logging -import numpy as np +import numpy as np import torch import torch.nn as nn -from torch.nn.parallel import data_parallel - -from chainer import Chain +from chainer import Chain, reporter, training from chainer.dataset import convert -from chainer import reporter -from chainer import training from chainer.training import extensions +from torch.nn.parallel import data_parallel -from espnet.lm.lm_utils import count_tokens -from espnet.lm.lm_utils import load_dataset -from espnet.lm.lm_utils import MakeSymlinkToBestModel -from espnet.lm.lm_utils import ParallelSentenceIterator -from espnet.lm.lm_utils import read_tokens -from espnet.nets.lm_interface import dynamic_import_lm -from espnet.nets.lm_interface import LMInterface +from espnet.asr.asr_utils import (snapshot_object, torch_load, torch_resume, + torch_snapshot) +from espnet.lm.lm_utils import (MakeSymlinkToBestModel, + ParallelSentenceIterator, count_tokens, + load_dataset, read_tokens) +from espnet.nets.lm_interface import LMInterface, dynamic_import_lm from espnet.optimizer.factory import dynamic_import_optimizer from espnet.scheduler.pytorch import PyTorchScheduler from espnet.scheduler.scheduler import dynamic_import_scheduler - -from 
espnet.asr.asr_utils import snapshot_object -from espnet.asr.asr_utils import torch_load -from espnet.asr.asr_utils import torch_resume -from espnet.asr.asr_utils import torch_snapshot - -from espnet.utils.training.tensorboard_logger import TensorboardLogger - from espnet.utils.deterministic_utils import set_deterministic_pytorch from espnet.utils.training.evaluator import BaseEvaluator from espnet.utils.training.iterators import ShufflingEnabler -from espnet.utils.training.train_utils import check_early_stop -from espnet.utils.training.train_utils import set_early_stop +from espnet.utils.training.tensorboard_logger import TensorboardLogger +from espnet.utils.training.train_utils import check_early_stop, set_early_stop def compute_perplexity(result): diff --git a/espnet/mt/pytorch_backend/mt.py b/espnet/mt/pytorch_backend/mt.py index 47f5b817b03..372c11e4c03 100644 --- a/espnet/mt/pytorch_backend/mt.py +++ b/espnet/mt/pytorch_backend/mt.py @@ -11,36 +11,27 @@ import logging import os -from chainer import training -from chainer.training import extensions import numpy as np import torch +from chainer import training +from chainer.training import extensions -from espnet.asr.asr_utils import adadelta_eps_decay -from espnet.asr.asr_utils import adam_lr_decay -from espnet.asr.asr_utils import add_results_to_json -from espnet.asr.asr_utils import CompareValueTrigger -from espnet.asr.asr_utils import restore_snapshot -from espnet.asr.asr_utils import snapshot_object -from espnet.asr.asr_utils import torch_load -from espnet.asr.asr_utils import torch_resume -from espnet.asr.asr_utils import torch_snapshot +from espnet.asr.asr_utils import (CompareValueTrigger, adadelta_eps_decay, + adam_lr_decay, add_results_to_json, + restore_snapshot, snapshot_object, + torch_load, torch_resume, torch_snapshot) +from espnet.asr.pytorch_backend.asr import (CustomEvaluator, CustomUpdater, + load_trained_model) from espnet.nets.mt_interface import MTInterface from espnet.nets.pytorch_backend.e2e_asr import pad_list -from espnet.utils.dataset import ChainerDataLoader -from espnet.utils.dataset import TransformDataset +from espnet.utils.dataset import ChainerDataLoader, TransformDataset from espnet.utils.deterministic_utils import set_deterministic_pytorch from espnet.utils.dynamic_import import dynamic_import from espnet.utils.io_utils import LoadInputsAndTargets from espnet.utils.training.batchfy import make_batchset from espnet.utils.training.iterators import ShufflingEnabler from espnet.utils.training.tensorboard_logger import TensorboardLogger -from espnet.utils.training.train_utils import check_early_stop -from espnet.utils.training.train_utils import set_early_stop - -from espnet.asr.pytorch_backend.asr import CustomEvaluator -from espnet.asr.pytorch_backend.asr import CustomUpdater -from espnet.asr.pytorch_backend.asr import load_trained_model +from espnet.utils.training.train_utils import check_early_stop, set_early_stop class CustomConverter(object): @@ -163,7 +154,8 @@ def train(args): model.parameters(), lr=args.lr, weight_decay=args.weight_decay ) elif args.opt == "noam": - from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt + from espnet.nets.pytorch_backend.transformer.optimizer import \ + get_std_opt optimizer = get_std_opt( model.parameters(), diff --git a/espnet/nets/batch_beam_search.py b/espnet/nets/batch_beam_search.py index 9418fadea46..f31d876a934 100644 --- a/espnet/nets/batch_beam_search.py +++ b/espnet/nets/batch_beam_search.py @@ -1,17 +1,12 @@ """Parallel beam search 
module.""" import logging -from typing import Any -from typing import Dict -from typing import List -from typing import NamedTuple -from typing import Tuple +from typing import Any, Dict, List, NamedTuple, Tuple import torch from torch.nn.utils.rnn import pad_sequence -from espnet.nets.beam_search import BeamSearch -from espnet.nets.beam_search import Hypothesis +from espnet.nets.beam_search import BeamSearch, Hypothesis class BatchHypothesis(NamedTuple): diff --git a/espnet/nets/batch_beam_search_online.py b/espnet/nets/batch_beam_search_online.py index 9190a09144a..cdb23dd9e89 100644 --- a/espnet/nets/batch_beam_search_online.py +++ b/espnet/nets/batch_beam_search_online.py @@ -1,19 +1,17 @@ """Parallel beam search module for online simulation.""" -from espnet.nets.batch_beam_search import ( - BatchBeamSearch, # noqa: H301 - BatchHypothesis, # noqa: H301 -) -from espnet.nets.beam_search import Hypothesis -from espnet.nets.e2e_asr_common import end_detect import logging +from typing import Any # noqa: H301 +from typing import Dict # noqa: H301 +from typing import List # noqa: H301 +from typing import Tuple # noqa: H301 + import torch -from typing import ( - List, # noqa: H301 - Tuple, # noqa: H301 - Dict, # noqa: H301 - Any, # noqa: H301 -) + +from espnet.nets.batch_beam_search import BatchBeamSearch # noqa: H301 +from espnet.nets.batch_beam_search import BatchHypothesis # noqa: H301 +from espnet.nets.beam_search import Hypothesis +from espnet.nets.e2e_asr_common import end_detect class BatchBeamSearchOnline(BatchBeamSearch): diff --git a/espnet/nets/batch_beam_search_online_sim.py b/espnet/nets/batch_beam_search_online_sim.py index 2c0ecf3bfb1..f65e7e1025b 100644 --- a/espnet/nets/batch_beam_search_online_sim.py +++ b/espnet/nets/batch_beam_search_online_sim.py @@ -4,9 +4,8 @@ from pathlib import Path from typing import List -import yaml - import torch +import yaml from espnet.nets.batch_beam_search import BatchBeamSearch from espnet.nets.beam_search import Hypothesis diff --git a/espnet/nets/beam_search.py b/espnet/nets/beam_search.py index 0f33d8c63bf..558f1dd5af7 100644 --- a/espnet/nets/beam_search.py +++ b/espnet/nets/beam_search.py @@ -1,19 +1,14 @@ """Beam search module.""" -from itertools import chain import logging -from typing import Any -from typing import Dict -from typing import List -from typing import NamedTuple -from typing import Tuple -from typing import Union +from itertools import chain +from typing import Any, Dict, List, NamedTuple, Tuple, Union import torch from espnet.nets.e2e_asr_common import end_detect -from espnet.nets.scorer_interface import PartialScorerInterface -from espnet.nets.scorer_interface import ScorerInterface +from espnet.nets.scorer_interface import (PartialScorerInterface, + ScorerInterface) class Hypothesis(NamedTuple): diff --git a/espnet/nets/beam_search_transducer.py b/espnet/nets/beam_search_transducer.py index a14dcd8618a..0983517ed6a 100644 --- a/espnet/nets/beam_search_transducer.py +++ b/espnet/nets/beam_search_transducer.py @@ -1,8 +1,7 @@ """Search algorithms for Transducer models.""" import logging -from typing import List -from typing import Union +from typing import List, Union import numpy as np import torch @@ -10,15 +9,11 @@ from espnet.nets.pytorch_backend.transducer.custom_decoder import CustomDecoder from espnet.nets.pytorch_backend.transducer.joint_network import JointNetwork from espnet.nets.pytorch_backend.transducer.rnn_decoder import RNNDecoder -from espnet.nets.pytorch_backend.transducer.utils import 
create_lm_batch_states -from espnet.nets.pytorch_backend.transducer.utils import init_lm_state -from espnet.nets.pytorch_backend.transducer.utils import is_prefix -from espnet.nets.pytorch_backend.transducer.utils import recombine_hyps -from espnet.nets.pytorch_backend.transducer.utils import select_k_expansions -from espnet.nets.pytorch_backend.transducer.utils import select_lm_state -from espnet.nets.pytorch_backend.transducer.utils import subtract -from espnet.nets.transducer_decoder_interface import ExtendedHypothesis -from espnet.nets.transducer_decoder_interface import Hypothesis +from espnet.nets.pytorch_backend.transducer.utils import ( + create_lm_batch_states, init_lm_state, is_prefix, recombine_hyps, + select_k_expansions, select_lm_state, subtract) +from espnet.nets.transducer_decoder_interface import (ExtendedHypothesis, + Hypothesis) class BeamSearchTransducer: diff --git a/espnet/nets/chainer_backend/ctc.py b/espnet/nets/chainer_backend/ctc.py index f1788df4c74..878f90d8834 100644 --- a/espnet/nets/chainer_backend/ctc.py +++ b/espnet/nets/chainer_backend/ctc.py @@ -1,10 +1,10 @@ import logging import chainer -from chainer import cuda import chainer.functions as F import chainer.links as L import numpy as np +from chainer import cuda class CTC(chainer.Chain): diff --git a/espnet/nets/chainer_backend/deterministic_embed_id.py b/espnet/nets/chainer_backend/deterministic_embed_id.py index 22bc3e3b3ae..b0889e87935 100644 --- a/espnet/nets/chainer_backend/deterministic_embed_id.py +++ b/espnet/nets/chainer_backend/deterministic_embed_id.py @@ -1,15 +1,10 @@ +import chainer import numpy import six - -import chainer -from chainer import cuda -from chainer import function_node -from chainer.initializers import normal - # from chainer.functions.connection import embed_id -from chainer import link +from chainer import cuda, function_node, link, variable +from chainer.initializers import normal from chainer.utils import type_check -from chainer import variable """Deterministic EmbedID link and function diff --git a/espnet/nets/chainer_backend/e2e_asr.py b/espnet/nets/chainer_backend/e2e_asr.py index eb3a9a37f98..cd5f22b8c94 100644 --- a/espnet/nets/chainer_backend/e2e_asr.py +++ b/espnet/nets/chainer_backend/e2e_asr.py @@ -7,8 +7,8 @@ import math import chainer -from chainer import reporter import numpy as np +from chainer import reporter from espnet.nets.chainer_backend.asr_interface import ChainerASRInterface from espnet.nets.chainer_backend.ctc import ctc_for @@ -215,7 +215,8 @@ def custom_updater(iters, optimizer, converter, device=-1, accum_grad=1): @staticmethod def custom_parallel_updater(iters, optimizer, converter, devices, accum_grad=1): """Get custom_parallel_updater of the model.""" - from espnet.nets.chainer_backend.rnn.training import CustomParallelUpdater + from espnet.nets.chainer_backend.rnn.training import \ + CustomParallelUpdater return CustomParallelUpdater( iters, diff --git a/espnet/nets/chainer_backend/e2e_asr_transformer.py b/espnet/nets/chainer_backend/e2e_asr_transformer.py index 07c63d23697..d3bdfeb78b7 100644 --- a/espnet/nets/chainer_backend/e2e_asr_transformer.py +++ b/espnet/nets/chainer_backend/e2e_asr_transformer.py @@ -1,37 +1,34 @@ # encoding: utf-8 """Transformer-based model for End-to-end ASR.""" -from argparse import Namespace -from distutils.util import strtobool import logging import math +from argparse import Namespace +from distutils.util import strtobool import chainer import chainer.functions as F -from chainer import reporter import numpy as 
np import six +from chainer import reporter from espnet.nets.chainer_backend.asr_interface import ChainerASRInterface -from espnet.nets.chainer_backend.transformer.attention import MultiHeadAttention from espnet.nets.chainer_backend.transformer import ctc +from espnet.nets.chainer_backend.transformer.attention import \ + MultiHeadAttention from espnet.nets.chainer_backend.transformer.decoder import Decoder from espnet.nets.chainer_backend.transformer.encoder import Encoder -from espnet.nets.chainer_backend.transformer.label_smoothing_loss import ( - LabelSmoothingLoss, # noqa: H301 -) -from espnet.nets.chainer_backend.transformer.training import CustomConverter -from espnet.nets.chainer_backend.transformer.training import CustomUpdater -from espnet.nets.chainer_backend.transformer.training import ( - CustomParallelUpdater, # noqa: H301 -) +from espnet.nets.chainer_backend.transformer.label_smoothing_loss import \ + LabelSmoothingLoss # noqa: H301 +from espnet.nets.chainer_backend.transformer.training import \ + CustomParallelUpdater # noqa: H301 +from espnet.nets.chainer_backend.transformer.training import (CustomConverter, + CustomUpdater) from espnet.nets.ctc_prefix_score import CTCPrefixScore -from espnet.nets.e2e_asr_common import end_detect -from espnet.nets.e2e_asr_common import ErrorCalculator +from espnet.nets.e2e_asr_common import ErrorCalculator, end_detect from espnet.nets.pytorch_backend.nets_utils import get_subsample from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport - CTC_SCORING_RATIO = 1.5 MAX_DECODER_OUTPUT = 5 diff --git a/espnet/nets/chainer_backend/rnn/attentions.py b/espnet/nets/chainer_backend/rnn/attentions.py index e9a776e5b2e..c4343256ddc 100644 --- a/espnet/nets/chainer_backend/rnn/attentions.py +++ b/espnet/nets/chainer_backend/rnn/attentions.py @@ -1,7 +1,6 @@ import chainer import chainer.functions as F import chainer.links as L - import numpy as np diff --git a/espnet/nets/chainer_backend/rnn/decoders.py b/espnet/nets/chainer_backend/rnn/decoders.py index 308f509a8b3..5cba1266f40 100644 --- a/espnet/nets/chainer_backend/rnn/decoders.py +++ b/espnet/nets/chainer_backend/rnn/decoders.py @@ -1,16 +1,14 @@ import logging import random -import six +from argparse import Namespace import chainer import chainer.functions as F import chainer.links as L import numpy as np +import six import espnet.nets.chainer_backend.deterministic_embed_id as DL - -from argparse import Namespace - from espnet.nets.ctc_prefix_score import CTCPrefixScore from espnet.nets.e2e_asr_common import end_detect diff --git a/espnet/nets/chainer_backend/rnn/encoders.py b/espnet/nets/chainer_backend/rnn/encoders.py index 0590ccf8108..aa064304e28 100644 --- a/espnet/nets/chainer_backend/rnn/encoders.py +++ b/espnet/nets/chainer_backend/rnn/encoders.py @@ -1,11 +1,10 @@ import logging -import six import chainer import chainer.functions as F import chainer.links as L import numpy as np - +import six from chainer import cuda from espnet.nets.chainer_backend.nets_utils import _subsamplex diff --git a/espnet/nets/chainer_backend/rnn/training.py b/espnet/nets/chainer_backend/rnn/training.py index bbc37d681a1..981dacd1884 100644 --- a/espnet/nets/chainer_backend/rnn/training.py +++ b/espnet/nets/chainer_backend/rnn/training.py @@ -5,18 +5,13 @@ import collections import logging import math -import six - -# chainer related -from chainer import cuda -from chainer import training -from chainer import Variable - -from chainer.training.updaters.multiprocess_parallel_updater import 
gather_grads -from chainer.training.updaters.multiprocess_parallel_updater import gather_params -from chainer.training.updaters.multiprocess_parallel_updater import scatter_grads import numpy as np +import six +# chainer related +from chainer import Variable, cuda, training +from chainer.training.updaters.multiprocess_parallel_updater import ( + gather_grads, gather_params, scatter_grads) # copied from https://github.com/chainer/chainer/blob/master/chainer/optimizer.py diff --git a/espnet/nets/chainer_backend/transformer/attention.py b/espnet/nets/chainer_backend/transformer/attention.py index d26d82fb10f..a79f844e288 100644 --- a/espnet/nets/chainer_backend/transformer/attention.py +++ b/espnet/nets/chainer_backend/transformer/attention.py @@ -2,10 +2,8 @@ """Class Declaration of Transformer's Attention.""" import chainer - import chainer.functions as F import chainer.links as L - import numpy as np MIN_VALUE = float(np.finfo(np.float32).min) diff --git a/espnet/nets/chainer_backend/transformer/decoder.py b/espnet/nets/chainer_backend/transformer/decoder.py index 75c3a7ef410..eae84053dbe 100644 --- a/espnet/nets/chainer_backend/transformer/decoder.py +++ b/espnet/nets/chainer_backend/transformer/decoder.py @@ -2,17 +2,16 @@ """Class Declaration of Transformer's Decoder.""" import chainer - import chainer.functions as F import chainer.links as L +import numpy as np from espnet.nets.chainer_backend.transformer.decoder_layer import DecoderLayer -from espnet.nets.chainer_backend.transformer.embedding import PositionalEncoding +from espnet.nets.chainer_backend.transformer.embedding import \ + PositionalEncoding from espnet.nets.chainer_backend.transformer.layer_norm import LayerNorm from espnet.nets.chainer_backend.transformer.mask import make_history_mask -import numpy as np - class Decoder(chainer.Chain): """Decoder layer. 
diff --git a/espnet/nets/chainer_backend/transformer/decoder_layer.py b/espnet/nets/chainer_backend/transformer/decoder_layer.py index 933290049c2..0223def15b3 100644 --- a/espnet/nets/chainer_backend/transformer/decoder_layer.py +++ b/espnet/nets/chainer_backend/transformer/decoder_layer.py @@ -2,14 +2,13 @@ """Class Declaration of Transformer's Decoder Block.""" import chainer - import chainer.functions as F -from espnet.nets.chainer_backend.transformer.attention import MultiHeadAttention +from espnet.nets.chainer_backend.transformer.attention import \ + MultiHeadAttention from espnet.nets.chainer_backend.transformer.layer_norm import LayerNorm -from espnet.nets.chainer_backend.transformer.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) +from espnet.nets.chainer_backend.transformer.positionwise_feed_forward import \ + PositionwiseFeedForward # noqa: H301 class DecoderLayer(chainer.Chain): diff --git a/espnet/nets/chainer_backend/transformer/embedding.py b/espnet/nets/chainer_backend/transformer/embedding.py index d838c085dad..35d3b10b798 100644 --- a/espnet/nets/chainer_backend/transformer/embedding.py +++ b/espnet/nets/chainer_backend/transformer/embedding.py @@ -3,7 +3,6 @@ import chainer import chainer.functions as F - import numpy as np diff --git a/espnet/nets/chainer_backend/transformer/encoder.py b/espnet/nets/chainer_backend/transformer/encoder.py index c0a8e7e64e7..742dd0dca30 100644 --- a/espnet/nets/chainer_backend/transformer/encoder.py +++ b/espnet/nets/chainer_backend/transformer/encoder.py @@ -1,19 +1,19 @@ # encoding: utf-8 """Class Declaration of Transformer's Encoder.""" -import chainer +import logging +import chainer +import numpy as np from chainer import links as L -from espnet.nets.chainer_backend.transformer.embedding import PositionalEncoding +from espnet.nets.chainer_backend.transformer.embedding import \ + PositionalEncoding from espnet.nets.chainer_backend.transformer.encoder_layer import EncoderLayer from espnet.nets.chainer_backend.transformer.layer_norm import LayerNorm from espnet.nets.chainer_backend.transformer.mask import make_history_mask -from espnet.nets.chainer_backend.transformer.subsampling import Conv2dSubsampling -from espnet.nets.chainer_backend.transformer.subsampling import LinearSampling - -import logging -import numpy as np +from espnet.nets.chainer_backend.transformer.subsampling import ( + Conv2dSubsampling, LinearSampling) class Encoder(chainer.Chain): diff --git a/espnet/nets/chainer_backend/transformer/encoder_layer.py b/espnet/nets/chainer_backend/transformer/encoder_layer.py index b742ef34ec3..a1daa60e02a 100644 --- a/espnet/nets/chainer_backend/transformer/encoder_layer.py +++ b/espnet/nets/chainer_backend/transformer/encoder_layer.py @@ -2,14 +2,13 @@ """Class Declaration of Transformer's Encoder Block.""" import chainer - import chainer.functions as F -from espnet.nets.chainer_backend.transformer.attention import MultiHeadAttention +from espnet.nets.chainer_backend.transformer.attention import \ + MultiHeadAttention from espnet.nets.chainer_backend.transformer.layer_norm import LayerNorm -from espnet.nets.chainer_backend.transformer.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) +from espnet.nets.chainer_backend.transformer.positionwise_feed_forward import \ + PositionwiseFeedForward # noqa: H301 class EncoderLayer(chainer.Chain): diff --git a/espnet/nets/chainer_backend/transformer/label_smoothing_loss.py b/espnet/nets/chainer_backend/transformer/label_smoothing_loss.py 
index 5aebc58a625..c81d819e7eb 100644 --- a/espnet/nets/chainer_backend/transformer/label_smoothing_loss.py +++ b/espnet/nets/chainer_backend/transformer/label_smoothing_loss.py @@ -4,7 +4,6 @@ import logging import chainer - import chainer.functions as F diff --git a/espnet/nets/chainer_backend/transformer/positionwise_feed_forward.py b/espnet/nets/chainer_backend/transformer/positionwise_feed_forward.py index f6d5a7c1a46..f69fcc6cc76 100644 --- a/espnet/nets/chainer_backend/transformer/positionwise_feed_forward.py +++ b/espnet/nets/chainer_backend/transformer/positionwise_feed_forward.py @@ -2,10 +2,8 @@ """Class Declaration of Transformer's Positionwise Feedforward.""" import chainer - import chainer.functions as F import chainer.links as L - import numpy as np diff --git a/espnet/nets/chainer_backend/transformer/subsampling.py b/espnet/nets/chainer_backend/transformer/subsampling.py index 0ba486c871f..03f767d303e 100644 --- a/espnet/nets/chainer_backend/transformer/subsampling.py +++ b/espnet/nets/chainer_backend/transformer/subsampling.py @@ -1,16 +1,16 @@ # encoding: utf-8 """Class Declaration of Transformer's Input layers.""" -import chainer +import logging +import chainer import chainer.functions as F import chainer.links as L - -from espnet.nets.chainer_backend.transformer.embedding import PositionalEncoding - -import logging import numpy as np +from espnet.nets.chainer_backend.transformer.embedding import \ + PositionalEncoding + class Conv2dSubsampling(chainer.Chain): """Convolutional 2D subsampling (to 1/4 length). diff --git a/espnet/nets/chainer_backend/transformer/training.py b/espnet/nets/chainer_backend/transformer/training.py index e6a98651f36..2ac4146c1bd 100644 --- a/espnet/nets/chainer_backend/transformer/training.py +++ b/espnet/nets/chainer_backend/transformer/training.py @@ -4,16 +4,15 @@ import collections import logging import math -import six +import numpy as np +import six from chainer import cuda from chainer import functions as F from chainer import training from chainer.training import extension -from chainer.training.updaters.multiprocess_parallel_updater import gather_grads -from chainer.training.updaters.multiprocess_parallel_updater import gather_params -from chainer.training.updaters.multiprocess_parallel_updater import scatter_grads -import numpy as np +from chainer.training.updaters.multiprocess_parallel_updater import ( + gather_grads, gather_params, scatter_grads) # copied from https://github.com/chainer/chainer/blob/master/chainer/optimizer.py diff --git a/espnet/nets/ctc_prefix_score.py b/espnet/nets/ctc_prefix_score.py index 0c67ecd096d..9b4cf94ed18 100644 --- a/espnet/nets/ctc_prefix_score.py +++ b/espnet/nets/ctc_prefix_score.py @@ -3,10 +3,9 @@ # Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import torch - import numpy as np import six +import torch class CTCPrefixScoreTH(object): diff --git a/espnet/nets/e2e_asr_common.py b/espnet/nets/e2e_asr_common.py index 92f90796a3a..40483e4d5d1 100644 --- a/espnet/nets/e2e_asr_common.py +++ b/espnet/nets/e2e_asr_common.py @@ -9,8 +9,8 @@ import json import logging import sys - from itertools import groupby + import numpy as np import six diff --git a/espnet/nets/pytorch_backend/conformer/argument.py b/espnet/nets/pytorch_backend/conformer/argument.py index d5681565256..03861e4f79a 100644 --- a/espnet/nets/pytorch_backend/conformer/argument.py +++ b/espnet/nets/pytorch_backend/conformer/argument.py @@ -4,8 +4,8 @@ """Conformer 
common arguments.""" -from distutils.util import strtobool import logging +from distutils.util import strtobool def add_arguments_conformer_common(group): diff --git a/espnet/nets/pytorch_backend/conformer/contextual_block_encoder_layer.py b/espnet/nets/pytorch_backend/conformer/contextual_block_encoder_layer.py index 6f02e5ef151..b4f0bd0933c 100644 --- a/espnet/nets/pytorch_backend/conformer/contextual_block_encoder_layer.py +++ b/espnet/nets/pytorch_backend/conformer/contextual_block_encoder_layer.py @@ -5,10 +5,11 @@ @author: Keqi Deng (UCAS) """ -from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm import torch from torch import nn +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm + class ContextualBlockEncoderLayer(nn.Module): """Contexutal Block Encoder layer module. diff --git a/espnet/nets/pytorch_backend/conformer/encoder.py b/espnet/nets/pytorch_backend/conformer/encoder.py index 515cf7e3f7c..4495c2e0d78 100644 --- a/espnet/nets/pytorch_backend/conformer/encoder.py +++ b/espnet/nets/pytorch_backend/conformer/encoder.py @@ -5,31 +5,35 @@ """Encoder definition.""" import logging + import torch from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer from espnet.nets.pytorch_backend.nets_utils import get_activation from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L -from espnet.nets.pytorch_backend.transformer.attention import ( - MultiHeadedAttention, # noqa: H301 - RelPositionMultiHeadedAttention, # noqa: H301 - LegacyRelPositionMultiHeadedAttention, # noqa: H301 -) -from espnet.nets.pytorch_backend.transformer.embedding import ( - PositionalEncoding, # noqa: H301 - ScaledPositionalEncoding, # noqa: H301 - RelPositionalEncoding, # noqa: H301 - LegacyRelPositionalEncoding, # noqa: H301 -) +from espnet.nets.pytorch_backend.transformer.attention import \ + LegacyRelPositionMultiHeadedAttention # noqa: H301 +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention # noqa: H301 +from espnet.nets.pytorch_backend.transformer.attention import \ + RelPositionMultiHeadedAttention # noqa: H301 +from espnet.nets.pytorch_backend.transformer.embedding import \ + LegacyRelPositionalEncoding # noqa: H301 +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding # noqa: H301 +from espnet.nets.pytorch_backend.transformer.embedding import \ + RelPositionalEncoding # noqa: H301 +from espnet.nets.pytorch_backend.transformer.embedding import \ + ScaledPositionalEncoding # noqa: H301 from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm -from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear -from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d -from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( + Conv1dLinear, MultiLayeredConv1d) +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ + PositionwiseFeedForward # noqa: H301 from espnet.nets.pytorch_backend.transformer.repeat import repeat -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling +from espnet.nets.pytorch_backend.transformer.subsampling import \ + Conv2dSubsampling class Encoder(torch.nn.Module): diff --git 
a/espnet/nets/pytorch_backend/conformer/encoder_layer.py b/espnet/nets/pytorch_backend/conformer/encoder_layer.py index bc620261aee..294ccb6c538 100644 --- a/espnet/nets/pytorch_backend/conformer/encoder_layer.py +++ b/espnet/nets/pytorch_backend/conformer/encoder_layer.py @@ -8,7 +8,6 @@ """Encoder self-attention layer definition.""" import torch - from torch import nn from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm diff --git a/espnet/nets/pytorch_backend/ctc.py b/espnet/nets/pytorch_backend/ctc.py index 96b2e4f52b9..d9bcece8866 100644 --- a/espnet/nets/pytorch_backend/ctc.py +++ b/espnet/nets/pytorch_backend/ctc.py @@ -1,10 +1,10 @@ import logging -from packaging.version import parse as V import numpy as np import six import torch import torch.nn.functional as F +from packaging.version import parse as V from espnet.nets.pytorch_backend.nets_utils import to_device diff --git a/espnet/nets/pytorch_backend/e2e_asr.py b/espnet/nets/pytorch_backend/e2e_asr.py index 0008e84d4c4..578ecf7480e 100644 --- a/espnet/nets/pytorch_backend/e2e_asr.py +++ b/espnet/nets/pytorch_backend/e2e_asr.py @@ -4,35 +4,33 @@ """RNN sequence-to-sequence speech recognition model (pytorch).""" import argparse -from itertools import groupby import logging import math import os +from itertools import groupby import chainer -from chainer import reporter import numpy as np import six import torch +from chainer import reporter from espnet.nets.asr_interface import ASRInterface from espnet.nets.e2e_asr_common import label_smoothing_dist from espnet.nets.pytorch_backend.ctc import ctc_for -from espnet.nets.pytorch_backend.frontends.feature_transform import ( - feature_transform_for, # noqa: H301 -) +from espnet.nets.pytorch_backend.frontends.feature_transform import \ + feature_transform_for # noqa: H301 from espnet.nets.pytorch_backend.frontends.frontend import frontend_for -from espnet.nets.pytorch_backend.initialization import lecun_normal_init_parameters -from espnet.nets.pytorch_backend.initialization import set_forget_bias_to_one -from espnet.nets.pytorch_backend.nets_utils import get_subsample -from espnet.nets.pytorch_backend.nets_utils import pad_list -from espnet.nets.pytorch_backend.nets_utils import to_device -from espnet.nets.pytorch_backend.nets_utils import to_torch_tensor -from espnet.nets.pytorch_backend.rnn.argument import ( - add_arguments_rnn_encoder_common, # noqa: H301 - add_arguments_rnn_decoder_common, # noqa: H301 - add_arguments_rnn_attention_common, # noqa: H301 -) +from espnet.nets.pytorch_backend.initialization import ( + lecun_normal_init_parameters, set_forget_bias_to_one) +from espnet.nets.pytorch_backend.nets_utils import (get_subsample, pad_list, + to_device, to_torch_tensor) +from espnet.nets.pytorch_backend.rnn.argument import \ + add_arguments_rnn_attention_common # noqa: H301 +from espnet.nets.pytorch_backend.rnn.argument import \ + add_arguments_rnn_decoder_common # noqa: H301 +from espnet.nets.pytorch_backend.rnn.argument import \ + add_arguments_rnn_encoder_common # noqa: H301 from espnet.nets.pytorch_backend.rnn.attentions import att_for from espnet.nets.pytorch_backend.rnn.decoders import decoder_for from espnet.nets.pytorch_backend.rnn.encoders import encoder_for diff --git a/espnet/nets/pytorch_backend/e2e_asr_conformer.py b/espnet/nets/pytorch_backend/e2e_asr_conformer.py index 4bcbad139e8..ae849e4e827 100644 --- a/espnet/nets/pytorch_backend/e2e_asr_conformer.py +++ b/espnet/nets/pytorch_backend/e2e_asr_conformer.py @@ -10,12 +10,13 @@ """ +from 
espnet.nets.pytorch_backend.conformer.argument import \ + add_arguments_conformer_common # noqa: H301 +from espnet.nets.pytorch_backend.conformer.argument import \ + verify_rel_pos_type # noqa: H301 from espnet.nets.pytorch_backend.conformer.encoder import Encoder -from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E as E2ETransformer -from espnet.nets.pytorch_backend.conformer.argument import ( - add_arguments_conformer_common, # noqa: H301 - verify_rel_pos_type, # noqa: H301 -) +from espnet.nets.pytorch_backend.e2e_asr_transformer import \ + E2E as E2ETransformer class E2E(E2ETransformer): diff --git a/espnet/nets/pytorch_backend/e2e_asr_maskctc.py b/espnet/nets/pytorch_backend/e2e_asr_maskctc.py index 7e7f6c3312d..3f04ea08b51 100644 --- a/espnet/nets/pytorch_backend/e2e_asr_maskctc.py +++ b/espnet/nets/pytorch_backend/e2e_asr_maskctc.py @@ -9,24 +9,24 @@ """ -from itertools import groupby import logging import math - from distutils.util import strtobool +from itertools import groupby + import numpy import torch +from espnet.nets.pytorch_backend.conformer.argument import \ + add_arguments_conformer_common # noqa: H301 from espnet.nets.pytorch_backend.conformer.encoder import Encoder -from espnet.nets.pytorch_backend.conformer.argument import ( - add_arguments_conformer_common, # noqa: H301 -) from espnet.nets.pytorch_backend.e2e_asr import CTC_LOSS_THRESHOLD -from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E as E2ETransformer +from espnet.nets.pytorch_backend.e2e_asr_transformer import \ + E2E as E2ETransformer from espnet.nets.pytorch_backend.maskctc.add_mask_token import mask_uniform from espnet.nets.pytorch_backend.maskctc.mask import square_mask -from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask -from espnet.nets.pytorch_backend.nets_utils import th_accuracy +from espnet.nets.pytorch_backend.nets_utils import (make_non_pad_mask, + th_accuracy) class E2E(E2ETransformer): diff --git a/espnet/nets/pytorch_backend/e2e_asr_mix.py b/espnet/nets/pytorch_backend/e2e_asr_mix.py index 377aabe5162..518d665c067 100644 --- a/espnet/nets/pytorch_backend/e2e_asr_mix.py +++ b/espnet/nets/pytorch_backend/e2e_asr_mix.py @@ -8,38 +8,34 @@ """ import argparse -from itertools import groupby import logging import math import os import sys +from itertools import groupby import numpy as np import six import torch from espnet.nets.asr_interface import ASRInterface -from espnet.nets.e2e_asr_common import get_vgg2l_odim -from espnet.nets.e2e_asr_common import label_smoothing_dist +from espnet.nets.e2e_asr_common import get_vgg2l_odim, label_smoothing_dist from espnet.nets.pytorch_backend.ctc import ctc_for from espnet.nets.pytorch_backend.e2e_asr import E2E as E2EASR from espnet.nets.pytorch_backend.e2e_asr import Reporter -from espnet.nets.pytorch_backend.frontends.feature_transform import ( - feature_transform_for, # noqa: H301 -) +from espnet.nets.pytorch_backend.frontends.feature_transform import \ + feature_transform_for # noqa: H301 from espnet.nets.pytorch_backend.frontends.frontend import frontend_for -from espnet.nets.pytorch_backend.initialization import lecun_normal_init_parameters -from espnet.nets.pytorch_backend.initialization import set_forget_bias_to_one -from espnet.nets.pytorch_backend.nets_utils import get_subsample -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.nets_utils import pad_list -from espnet.nets.pytorch_backend.nets_utils import to_device -from espnet.nets.pytorch_backend.nets_utils 
import to_torch_tensor +from espnet.nets.pytorch_backend.initialization import ( + lecun_normal_init_parameters, set_forget_bias_to_one) +from espnet.nets.pytorch_backend.nets_utils import (get_subsample, + make_pad_mask, pad_list, + to_device, to_torch_tensor) from espnet.nets.pytorch_backend.rnn.attentions import att_for from espnet.nets.pytorch_backend.rnn.decoders import decoder_for -from espnet.nets.pytorch_backend.rnn.encoders import encoder_for as encoder_for_single -from espnet.nets.pytorch_backend.rnn.encoders import RNNP -from espnet.nets.pytorch_backend.rnn.encoders import VGG2L +from espnet.nets.pytorch_backend.rnn.encoders import RNNP, VGG2L +from espnet.nets.pytorch_backend.rnn.encoders import \ + encoder_for as encoder_for_single CTC_LOSS_THRESHOLD = 10000 diff --git a/espnet/nets/pytorch_backend/e2e_asr_mix_transformer.py b/espnet/nets/pytorch_backend/e2e_asr_mix_transformer.py index 4622e9214ae..186000ec945 100644 --- a/espnet/nets/pytorch_backend/e2e_asr_mix_transformer.py +++ b/espnet/nets/pytorch_backend/e2e_asr_mix_transformer.py @@ -16,9 +16,9 @@ 2. PIT is used in CTC to determine the permutation with minimum loss. """ -from argparse import Namespace import logging import math +from argparse import Namespace import numpy import torch @@ -31,13 +31,13 @@ from espnet.nets.pytorch_backend.e2e_asr_mix import E2E as E2EASRMIX from espnet.nets.pytorch_backend.e2e_asr_mix import PIT from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E as E2EASR -from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask -from espnet.nets.pytorch_backend.nets_utils import th_accuracy +from espnet.nets.pytorch_backend.nets_utils import (make_non_pad_mask, + th_accuracy) from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos from espnet.nets.pytorch_backend.transformer.encoder_mix import EncoderMix -from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask -from espnet.nets.pytorch_backend.transformer.mask import target_mask +from espnet.nets.pytorch_backend.transformer.mask import (subsequent_mask, + target_mask) class E2E(E2EASR, ASRInterface, torch.nn.Module): diff --git a/espnet/nets/pytorch_backend/e2e_asr_mulenc.py b/espnet/nets/pytorch_backend/e2e_asr_mulenc.py index 3e7f78366da..9f38b74bdf8 100644 --- a/espnet/nets/pytorch_backend/e2e_asr_mulenc.py +++ b/espnet/nets/pytorch_backend/e2e_asr_mulenc.py @@ -5,27 +5,24 @@ """Define e2e module for multi-encoder network. 
https://arxiv.org/pdf/1811.04903.pdf.""" import argparse -from itertools import groupby import logging import math import os +from itertools import groupby import chainer -from chainer import reporter import numpy as np import torch +from chainer import reporter from espnet.nets.asr_interface import ASRInterface from espnet.nets.e2e_asr_common import label_smoothing_dist from espnet.nets.pytorch_backend.ctc import ctc_for -from espnet.nets.pytorch_backend.nets_utils import get_subsample -from espnet.nets.pytorch_backend.nets_utils import pad_list -from espnet.nets.pytorch_backend.nets_utils import to_device -from espnet.nets.pytorch_backend.nets_utils import to_torch_tensor +from espnet.nets.pytorch_backend.nets_utils import (get_subsample, pad_list, + to_device, to_torch_tensor) from espnet.nets.pytorch_backend.rnn.attentions import att_for from espnet.nets.pytorch_backend.rnn.decoders import decoder_for -from espnet.nets.pytorch_backend.rnn.encoders import Encoder -from espnet.nets.pytorch_backend.rnn.encoders import encoder_for +from espnet.nets.pytorch_backend.rnn.encoders import Encoder, encoder_for from espnet.nets.scorers.ctc import CTCPrefixScorer from espnet.utils.cli_utils import strtobool diff --git a/espnet/nets/pytorch_backend/e2e_asr_transducer.py b/espnet/nets/pytorch_backend/e2e_asr_transducer.py index 9ce0cb45dca..bb3aa98ffcb 100644 --- a/espnet/nets/pytorch_backend/e2e_asr_transducer.py +++ b/espnet/nets/pytorch_backend/e2e_asr_transducer.py @@ -1,44 +1,52 @@ """Transducer speech recognition model (pytorch).""" -from argparse import ArgumentParser -from argparse import Namespace -from dataclasses import asdict import logging import math -import numpy +from argparse import ArgumentParser, Namespace +from dataclasses import asdict from typing import List import chainer +import numpy import torch from espnet.nets.asr_interface import ASRInterface from espnet.nets.beam_search_transducer import BeamSearchTransducer -from espnet.nets.pytorch_backend.nets_utils import get_subsample -from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask -from espnet.nets.pytorch_backend.transducer.arguments import ( - add_auxiliary_task_arguments, # noqa: H301 - add_custom_decoder_arguments, # noqa: H301 - add_custom_encoder_arguments, # noqa: H301 - add_custom_training_arguments, # noqa: H301 - add_decoder_general_arguments, # noqa: H301 - add_encoder_general_arguments, # noqa: H301 - add_rnn_decoder_arguments, # noqa: H301 - add_rnn_encoder_arguments, # noqa: H301 - add_transducer_arguments, # noqa: H301 -) +from espnet.nets.pytorch_backend.nets_utils import (get_subsample, + make_non_pad_mask) +from espnet.nets.pytorch_backend.transducer.arguments import \ + add_auxiliary_task_arguments # noqa: H301 +from espnet.nets.pytorch_backend.transducer.arguments import \ + add_custom_decoder_arguments # noqa: H301 +from espnet.nets.pytorch_backend.transducer.arguments import \ + add_custom_encoder_arguments # noqa: H301 +from espnet.nets.pytorch_backend.transducer.arguments import \ + add_custom_training_arguments # noqa: H301 +from espnet.nets.pytorch_backend.transducer.arguments import \ + add_decoder_general_arguments # noqa: H301 +from espnet.nets.pytorch_backend.transducer.arguments import \ + add_encoder_general_arguments # noqa: H301 +from espnet.nets.pytorch_backend.transducer.arguments import \ + add_rnn_decoder_arguments # noqa: H301 +from espnet.nets.pytorch_backend.transducer.arguments import \ + add_rnn_encoder_arguments # noqa: H301 +from 
espnet.nets.pytorch_backend.transducer.arguments import \ + add_transducer_arguments # noqa: H301 from espnet.nets.pytorch_backend.transducer.custom_decoder import CustomDecoder from espnet.nets.pytorch_backend.transducer.custom_encoder import CustomEncoder -from espnet.nets.pytorch_backend.transducer.error_calculator import ErrorCalculator +from espnet.nets.pytorch_backend.transducer.error_calculator import \ + ErrorCalculator from espnet.nets.pytorch_backend.transducer.initializer import initializer from espnet.nets.pytorch_backend.transducer.rnn_decoder import RNNDecoder from espnet.nets.pytorch_backend.transducer.rnn_encoder import encoder_for -from espnet.nets.pytorch_backend.transducer.transducer_tasks import TransducerTasks -from espnet.nets.pytorch_backend.transducer.utils import get_decoder_input -from espnet.nets.pytorch_backend.transducer.utils import valid_aux_encoder_output_layers -from espnet.nets.pytorch_backend.transformer.attention import ( - MultiHeadedAttention, # noqa: H301 - RelPositionMultiHeadedAttention, # noqa: H301 -) +from espnet.nets.pytorch_backend.transducer.transducer_tasks import \ + TransducerTasks +from espnet.nets.pytorch_backend.transducer.utils import ( + get_decoder_input, valid_aux_encoder_output_layers) +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention # noqa: H301 +from espnet.nets.pytorch_backend.transformer.attention import \ + RelPositionMultiHeadedAttention # noqa: H301 from espnet.nets.pytorch_backend.transformer.mask import target_mask from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport from espnet.utils.fill_missing_args import fill_missing_args diff --git a/espnet/nets/pytorch_backend/e2e_asr_transformer.py b/espnet/nets/pytorch_backend/e2e_asr_transformer.py index b13c7e452b6..a114bfbef9d 100644 --- a/espnet/nets/pytorch_backend/e2e_asr_transformer.py +++ b/espnet/nets/pytorch_backend/e2e_asr_transformer.py @@ -3,42 +3,40 @@ """Transformer speech recognition model (pytorch).""" -from argparse import Namespace import logging import math +from argparse import Namespace import numpy import torch from espnet.nets.asr_interface import ASRInterface from espnet.nets.ctc_prefix_score import CTCPrefixScore -from espnet.nets.e2e_asr_common import end_detect -from espnet.nets.e2e_asr_common import ErrorCalculator +from espnet.nets.e2e_asr_common import ErrorCalculator, end_detect from espnet.nets.pytorch_backend.ctc import CTC -from espnet.nets.pytorch_backend.e2e_asr import CTC_LOSS_THRESHOLD -from espnet.nets.pytorch_backend.e2e_asr import Reporter -from espnet.nets.pytorch_backend.nets_utils import get_subsample -from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask -from espnet.nets.pytorch_backend.nets_utils import th_accuracy +from espnet.nets.pytorch_backend.e2e_asr import CTC_LOSS_THRESHOLD, Reporter +from espnet.nets.pytorch_backend.nets_utils import (get_subsample, + make_non_pad_mask, + th_accuracy) from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos -from espnet.nets.pytorch_backend.transformer.argument import ( - add_arguments_transformer_common, # noqa: H301 -) -from espnet.nets.pytorch_backend.transformer.attention import ( - MultiHeadedAttention, # noqa: H301 - RelPositionMultiHeadedAttention, # noqa: H301 -) +from espnet.nets.pytorch_backend.transformer.argument import \ + add_arguments_transformer_common # noqa: H301 +from 
espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention # noqa: H301 +from espnet.nets.pytorch_backend.transformer.attention import \ + RelPositionMultiHeadedAttention # noqa: H301 from espnet.nets.pytorch_backend.transformer.decoder import Decoder -from espnet.nets.pytorch_backend.transformer.dynamic_conv import DynamicConvolution -from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import DynamicConvolution2D +from espnet.nets.pytorch_backend.transformer.dynamic_conv import \ + DynamicConvolution +from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import \ + DynamicConvolution2D from espnet.nets.pytorch_backend.transformer.encoder import Encoder from espnet.nets.pytorch_backend.transformer.initializer import initialize -from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( - LabelSmoothingLoss, # noqa: H301 -) -from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask -from espnet.nets.pytorch_backend.transformer.mask import target_mask +from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import \ + LabelSmoothingLoss # noqa: H301 +from espnet.nets.pytorch_backend.transformer.mask import (subsequent_mask, + target_mask) from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport from espnet.nets.scorers.ctc import CTCPrefixScorer from espnet.utils.fill_missing_args import fill_missing_args diff --git a/espnet/nets/pytorch_backend/e2e_mt.py b/espnet/nets/pytorch_backend/e2e_mt.py index 9dffdd7ba8d..4e933662117 100644 --- a/espnet/nets/pytorch_backend/e2e_mt.py +++ b/espnet/nets/pytorch_backend/e2e_mt.py @@ -9,22 +9,22 @@ import os import chainer -from chainer import reporter import nltk import numpy as np import torch +from chainer import reporter from espnet.nets.e2e_asr_common import label_smoothing_dist from espnet.nets.mt_interface import MTInterface from espnet.nets.pytorch_backend.initialization import uniform_init_parameters -from espnet.nets.pytorch_backend.nets_utils import get_subsample -from espnet.nets.pytorch_backend.nets_utils import pad_list -from espnet.nets.pytorch_backend.nets_utils import to_device -from espnet.nets.pytorch_backend.rnn.argument import ( - add_arguments_rnn_encoder_common, # noqa: H301 - add_arguments_rnn_decoder_common, # noqa: H301 - add_arguments_rnn_attention_common, # noqa: H301 -) +from espnet.nets.pytorch_backend.nets_utils import (get_subsample, pad_list, + to_device) +from espnet.nets.pytorch_backend.rnn.argument import \ + add_arguments_rnn_attention_common # noqa: H301 +from espnet.nets.pytorch_backend.rnn.argument import \ + add_arguments_rnn_decoder_common # noqa: H301 +from espnet.nets.pytorch_backend.rnn.argument import \ + add_arguments_rnn_encoder_common # noqa: H301 from espnet.nets.pytorch_backend.rnn.attentions import att_for from espnet.nets.pytorch_backend.rnn.decoders import decoder_for from espnet.nets.pytorch_backend.rnn.encoders import encoder_for diff --git a/espnet/nets/pytorch_backend/e2e_mt_transformer.py b/espnet/nets/pytorch_backend/e2e_mt_transformer.py index 5e4b9bb70e1..40a9c640a3a 100644 --- a/espnet/nets/pytorch_backend/e2e_mt_transformer.py +++ b/espnet/nets/pytorch_backend/e2e_mt_transformer.py @@ -3,9 +3,9 @@ """Transformer text translation model (pytorch).""" -from argparse import Namespace import logging import math +from argparse import Namespace import numpy as np import torch @@ -14,23 +14,21 @@ from espnet.nets.e2e_mt_common import ErrorCalculator from espnet.nets.mt_interface import MTInterface from 
espnet.nets.pytorch_backend.e2e_mt import Reporter -from espnet.nets.pytorch_backend.nets_utils import get_subsample -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.nets_utils import th_accuracy -from espnet.nets.pytorch_backend.nets_utils import to_device +from espnet.nets.pytorch_backend.nets_utils import (get_subsample, + make_pad_mask, th_accuracy, + to_device) from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos -from espnet.nets.pytorch_backend.transformer.argument import ( - add_arguments_transformer_common, # noqa: H301 -) -from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.argument import \ + add_arguments_transformer_common # noqa: H301 +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention from espnet.nets.pytorch_backend.transformer.decoder import Decoder from espnet.nets.pytorch_backend.transformer.encoder import Encoder from espnet.nets.pytorch_backend.transformer.initializer import initialize -from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( - LabelSmoothingLoss, # noqa: H301 -) -from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask -from espnet.nets.pytorch_backend.transformer.mask import target_mask +from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import \ + LabelSmoothingLoss # noqa: H301 +from espnet.nets.pytorch_backend.transformer.mask import (subsequent_mask, + target_mask) from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport from espnet.utils.fill_missing_args import fill_missing_args diff --git a/espnet/nets/pytorch_backend/e2e_st.py b/espnet/nets/pytorch_backend/e2e_st.py index 1464c896833..89d0efefe28 100644 --- a/espnet/nets/pytorch_backend/e2e_st.py +++ b/espnet/nets/pytorch_backend/e2e_st.py @@ -8,31 +8,27 @@ import logging import math import os - -import nltk +from itertools import groupby import chainer +import nltk import numpy as np import six import torch - -from itertools import groupby - from chainer import reporter from espnet.nets.e2e_asr_common import label_smoothing_dist from espnet.nets.pytorch_backend.ctc import CTC -from espnet.nets.pytorch_backend.initialization import lecun_normal_init_parameters -from espnet.nets.pytorch_backend.initialization import set_forget_bias_to_one -from espnet.nets.pytorch_backend.nets_utils import get_subsample -from espnet.nets.pytorch_backend.nets_utils import pad_list -from espnet.nets.pytorch_backend.nets_utils import to_device -from espnet.nets.pytorch_backend.nets_utils import to_torch_tensor -from espnet.nets.pytorch_backend.rnn.argument import ( - add_arguments_rnn_encoder_common, # noqa: H301 - add_arguments_rnn_decoder_common, # noqa: H301 - add_arguments_rnn_attention_common, # noqa: H301 -) +from espnet.nets.pytorch_backend.initialization import ( + lecun_normal_init_parameters, set_forget_bias_to_one) +from espnet.nets.pytorch_backend.nets_utils import (get_subsample, pad_list, + to_device, to_torch_tensor) +from espnet.nets.pytorch_backend.rnn.argument import \ + add_arguments_rnn_attention_common # noqa: H301 +from espnet.nets.pytorch_backend.rnn.argument import \ + add_arguments_rnn_decoder_common # noqa: H301 +from espnet.nets.pytorch_backend.rnn.argument import \ + add_arguments_rnn_encoder_common # noqa: H301 from espnet.nets.pytorch_backend.rnn.attentions import att_for from espnet.nets.pytorch_backend.rnn.decoders import 
decoder_for from espnet.nets.pytorch_backend.rnn.encoders import encoder_for diff --git a/espnet/nets/pytorch_backend/e2e_st_conformer.py b/espnet/nets/pytorch_backend/e2e_st_conformer.py index f34bb1f598a..347bc686674 100644 --- a/espnet/nets/pytorch_backend/e2e_st_conformer.py +++ b/espnet/nets/pytorch_backend/e2e_st_conformer.py @@ -9,12 +9,13 @@ """ +from espnet.nets.pytorch_backend.conformer.argument import \ + add_arguments_conformer_common # noqa: H301 +from espnet.nets.pytorch_backend.conformer.argument import \ + verify_rel_pos_type # noqa: H301 from espnet.nets.pytorch_backend.conformer.encoder import Encoder -from espnet.nets.pytorch_backend.e2e_st_transformer import E2E as E2ETransformer -from espnet.nets.pytorch_backend.conformer.argument import ( - add_arguments_conformer_common, # noqa: H301 - verify_rel_pos_type, # noqa: H301 -) +from espnet.nets.pytorch_backend.e2e_st_transformer import \ + E2E as E2ETransformer class E2E(E2ETransformer): diff --git a/espnet/nets/pytorch_backend/e2e_st_transformer.py b/espnet/nets/pytorch_backend/e2e_st_transformer.py index 8c6406cb9ee..b92091ddcc4 100644 --- a/espnet/nets/pytorch_backend/e2e_st_transformer.py +++ b/espnet/nets/pytorch_backend/e2e_st_transformer.py @@ -3,36 +3,34 @@ """Transformer speech recognition model (pytorch).""" -from argparse import Namespace import logging import math -import numpy +from argparse import Namespace +import numpy import torch -from espnet.nets.e2e_asr_common import end_detect from espnet.nets.e2e_asr_common import ErrorCalculator as ASRErrorCalculator +from espnet.nets.e2e_asr_common import end_detect from espnet.nets.e2e_mt_common import ErrorCalculator as MTErrorCalculator from espnet.nets.pytorch_backend.ctc import CTC from espnet.nets.pytorch_backend.e2e_asr import CTC_LOSS_THRESHOLD from espnet.nets.pytorch_backend.e2e_st import Reporter -from espnet.nets.pytorch_backend.nets_utils import get_subsample -from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask -from espnet.nets.pytorch_backend.nets_utils import pad_list -from espnet.nets.pytorch_backend.nets_utils import th_accuracy +from espnet.nets.pytorch_backend.nets_utils import (get_subsample, + make_non_pad_mask, + pad_list, th_accuracy) from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos -from espnet.nets.pytorch_backend.transformer.argument import ( - add_arguments_transformer_common, # noqa: H301 -) -from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.argument import \ + add_arguments_transformer_common # noqa: H301 +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention from espnet.nets.pytorch_backend.transformer.decoder import Decoder from espnet.nets.pytorch_backend.transformer.encoder import Encoder from espnet.nets.pytorch_backend.transformer.initializer import initialize -from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( - LabelSmoothingLoss, # noqa: H301 -) -from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask -from espnet.nets.pytorch_backend.transformer.mask import target_mask +from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import \ + LabelSmoothingLoss # noqa: H301 +from espnet.nets.pytorch_backend.transformer.mask import (subsequent_mask, + target_mask) from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport from espnet.nets.st_interface import STInterface from espnet.utils.fill_missing_args import 
fill_missing_args diff --git a/espnet/nets/pytorch_backend/e2e_tts_fastspeech.py b/espnet/nets/pytorch_backend/e2e_tts_fastspeech.py index 8c9f2bcb232..5cfbea565aa 100644 --- a/espnet/nets/pytorch_backend/e2e_tts_fastspeech.py +++ b/espnet/nets/pytorch_backend/e2e_tts_fastspeech.py @@ -8,22 +8,22 @@ import torch import torch.nn.functional as F -from espnet.asr.asr_utils import get_model_conf -from espnet.asr.asr_utils import torch_load -from espnet.nets.pytorch_backend.fastspeech.duration_calculator import ( - DurationCalculator, # noqa: H301 -) -from espnet.nets.pytorch_backend.fastspeech.duration_predictor import DurationPredictor -from espnet.nets.pytorch_backend.fastspeech.duration_predictor import ( - DurationPredictorLoss, # noqa: H301 -) -from espnet.nets.pytorch_backend.fastspeech.length_regulator import LengthRegulator -from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.asr.asr_utils import get_model_conf, torch_load +from espnet.nets.pytorch_backend.fastspeech.duration_calculator import \ + DurationCalculator # noqa: H301 +from espnet.nets.pytorch_backend.fastspeech.duration_predictor import \ + DurationPredictorLoss # noqa: H301 +from espnet.nets.pytorch_backend.fastspeech.duration_predictor import \ + DurationPredictor +from espnet.nets.pytorch_backend.fastspeech.length_regulator import \ + LengthRegulator +from espnet.nets.pytorch_backend.nets_utils import (make_non_pad_mask, + make_pad_mask) from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet -from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention -from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding -from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.embedding import ( + PositionalEncoding, ScaledPositionalEncoding) from espnet.nets.pytorch_backend.transformer.encoder import Encoder from espnet.nets.pytorch_backend.transformer.initializer import initialize from espnet.nets.tts_interface import TTSInterface diff --git a/espnet/nets/pytorch_backend/e2e_tts_tacotron2.py b/espnet/nets/pytorch_backend/e2e_tts_tacotron2.py index 2e543d932e7..6364ee3f13f 100644 --- a/espnet/nets/pytorch_backend/e2e_tts_tacotron2.py +++ b/espnet/nets/pytorch_backend/e2e_tts_tacotron2.py @@ -10,11 +10,9 @@ import torch.nn.functional as F from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask -from espnet.nets.pytorch_backend.rnn.attentions import AttForward -from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA -from espnet.nets.pytorch_backend.rnn.attentions import AttLoc -from espnet.nets.pytorch_backend.tacotron2.cbhg import CBHG -from espnet.nets.pytorch_backend.tacotron2.cbhg import CBHGLoss +from espnet.nets.pytorch_backend.rnn.attentions import (AttForward, + AttForwardTA, AttLoc) +from espnet.nets.pytorch_backend.tacotron2.cbhg import CBHG, CBHGLoss from espnet.nets.pytorch_backend.tacotron2.decoder import Decoder from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder from espnet.nets.tts_interface import TTSInterface diff --git a/espnet/nets/pytorch_backend/e2e_tts_transformer.py b/espnet/nets/pytorch_backend/e2e_tts_transformer.py index e71f1973fda..4929d7a2963 100644 --- a/espnet/nets/pytorch_backend/e2e_tts_transformer.py +++ 
b/espnet/nets/pytorch_backend/e2e_tts_transformer.py @@ -9,17 +9,19 @@ import torch.nn.functional as F from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import GuidedAttentionLoss -from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import ( - Tacotron2Loss as TransformerLoss, # noqa: H301 -) +from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import \ + Tacotron2Loss as TransformerLoss # noqa: H301 from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet -from espnet.nets.pytorch_backend.tacotron2.decoder import Prenet as DecoderPrenet -from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder as EncoderPrenet -from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.tacotron2.decoder import \ + Prenet as DecoderPrenet +from espnet.nets.pytorch_backend.tacotron2.encoder import \ + Encoder as EncoderPrenet +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention from espnet.nets.pytorch_backend.transformer.decoder import Decoder -from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding -from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding +from espnet.nets.pytorch_backend.transformer.embedding import ( + PositionalEncoding, ScaledPositionalEncoding) from espnet.nets.pytorch_backend.transformer.encoder import Encoder from espnet.nets.pytorch_backend.transformer.initializer import initialize from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask @@ -69,7 +71,8 @@ def forward(self, att_ws, ilens, olens): try: - from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport + from espnet.nets.pytorch_backend.transformer.plot import \ + PlotAttentionReport except (ImportError, TypeError): TTSPlot = None else: @@ -93,9 +96,9 @@ def plotfn( """ import matplotlib.pyplot as plt - from espnet.nets.pytorch_backend.transformer.plot import ( - _plot_and_save_attention, # noqa: H301 - ) + + from espnet.nets.pytorch_backend.transformer.plot import \ + _plot_and_save_attention # noqa: H301 for name, att_ws in attn_dict.items(): for utt_id, att_w in zip(uttid_list, att_ws): diff --git a/espnet/nets/pytorch_backend/e2e_vc_tacotron2.py b/espnet/nets/pytorch_backend/e2e_vc_tacotron2.py index 049d9407f8a..a76678cb4b1 100644 --- a/espnet/nets/pytorch_backend/e2e_vc_tacotron2.py +++ b/espnet/nets/pytorch_backend/e2e_vc_tacotron2.py @@ -4,26 +4,23 @@ """Tacotron2-VC related modules.""" import logging - from distutils.util import strtobool import numpy as np import torch import torch.nn.functional as F -from espnet.nets.pytorch_backend.rnn.attentions import AttForward -from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA -from espnet.nets.pytorch_backend.rnn.attentions import AttLoc -from espnet.nets.pytorch_backend.tacotron2.cbhg import CBHG -from espnet.nets.pytorch_backend.tacotron2.cbhg import CBHGLoss +from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import \ + GuidedAttentionLoss # noqa: H301 +from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import \ + Tacotron2Loss # noqa: H301 +from espnet.nets.pytorch_backend.rnn.attentions import (AttForward, + AttForwardTA, AttLoc) +from espnet.nets.pytorch_backend.tacotron2.cbhg import CBHG, CBHGLoss from espnet.nets.pytorch_backend.tacotron2.decoder import Decoder from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder from espnet.nets.tts_interface import TTSInterface from 
espnet.utils.fill_missing_args import fill_missing_args -from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import ( - GuidedAttentionLoss, # noqa: H301 - Tacotron2Loss, # noqa: H301 -) class Tacotron2(TTSInterface, torch.nn.Module): diff --git a/espnet/nets/pytorch_backend/e2e_vc_transformer.py b/espnet/nets/pytorch_backend/e2e_vc_transformer.py index 99fd3f3962b..20e77e38101 100644 --- a/espnet/nets/pytorch_backend/e2e_vc_transformer.py +++ b/espnet/nets/pytorch_backend/e2e_vc_transformer.py @@ -9,26 +9,28 @@ import torch.nn.functional as F from espnet.nets.pytorch_backend.e2e_asr_transformer import subsequent_mask -from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import ( - Tacotron2Loss as TransformerLoss, # noqa: H301 -) +from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import \ + Tacotron2Loss as TransformerLoss # noqa: H301 +from espnet.nets.pytorch_backend.e2e_tts_transformer import \ + GuidedMultiHeadAttentionLoss # noqa: H301 +from espnet.nets.pytorch_backend.e2e_tts_transformer import \ + TTSPlot # noqa: H301 from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet -from espnet.nets.pytorch_backend.tacotron2.decoder import Prenet as DecoderPrenet -from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder as EncoderPrenet -from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.tacotron2.decoder import \ + Prenet as DecoderPrenet +from espnet.nets.pytorch_backend.tacotron2.encoder import \ + Encoder as EncoderPrenet +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention from espnet.nets.pytorch_backend.transformer.decoder import Decoder -from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding -from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding +from espnet.nets.pytorch_backend.transformer.embedding import ( + PositionalEncoding, ScaledPositionalEncoding) from espnet.nets.pytorch_backend.transformer.encoder import Encoder from espnet.nets.pytorch_backend.transformer.initializer import initialize from espnet.nets.tts_interface import TTSInterface from espnet.utils.cli_utils import strtobool from espnet.utils.fill_missing_args import fill_missing_args -from espnet.nets.pytorch_backend.e2e_tts_transformer import ( - GuidedMultiHeadAttentionLoss, # noqa: H301 - TTSPlot, # noqa: H301 -) class Transformer(TTSInterface, torch.nn.Module): diff --git a/espnet/nets/pytorch_backend/frontends/dnn_beamformer.py b/espnet/nets/pytorch_backend/frontends/dnn_beamformer.py index 1495c81a40d..53332a49c75 100644 --- a/espnet/nets/pytorch_backend/frontends/dnn_beamformer.py +++ b/espnet/nets/pytorch_backend/frontends/dnn_beamformer.py @@ -3,14 +3,13 @@ import torch from torch.nn import functional as F +from torch_complex.tensor import ComplexTensor -from espnet.nets.pytorch_backend.frontends.beamformer import apply_beamforming_vector -from espnet.nets.pytorch_backend.frontends.beamformer import get_mvdr_vector +from espnet.nets.pytorch_backend.frontends.beamformer import \ + get_power_spectral_density_matrix # noqa: H301 from espnet.nets.pytorch_backend.frontends.beamformer import ( - get_power_spectral_density_matrix, # noqa: H301 -) + apply_beamforming_vector, get_mvdr_vector) from espnet.nets.pytorch_backend.frontends.mask_estimator import MaskEstimator -from torch_complex.tensor import ComplexTensor class DNN_Beamformer(torch.nn.Module): diff --git 
a/espnet/nets/pytorch_backend/frontends/dnn_wpe.py b/espnet/nets/pytorch_backend/frontends/dnn_wpe.py index 8bfe599d2f7..02a753a78f5 100644 --- a/espnet/nets/pytorch_backend/frontends/dnn_wpe.py +++ b/espnet/nets/pytorch_backend/frontends/dnn_wpe.py @@ -1,7 +1,7 @@ from typing import Tuple -from pytorch_wpe import wpe_one_iteration import torch +from pytorch_wpe import wpe_one_iteration from torch_complex.tensor import ComplexTensor from espnet.nets.pytorch_backend.frontends.mask_estimator import MaskEstimator diff --git a/espnet/nets/pytorch_backend/frontends/feature_transform.py b/espnet/nets/pytorch_backend/frontends/feature_transform.py index 53915d28815..9fe4ada9b65 100644 --- a/espnet/nets/pytorch_backend/frontends/feature_transform.py +++ b/espnet/nets/pytorch_backend/frontends/feature_transform.py @@ -1,6 +1,4 @@ -from typing import List -from typing import Tuple -from typing import Union +from typing import List, Tuple, Union import librosa import numpy as np diff --git a/espnet/nets/pytorch_backend/frontends/frontend.py b/espnet/nets/pytorch_backend/frontends/frontend.py index 7231f68b35f..e60e3ff8628 100644 --- a/espnet/nets/pytorch_backend/frontends/frontend.py +++ b/espnet/nets/pytorch_backend/frontends/frontend.py @@ -1,7 +1,4 @@ -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import List, Optional, Tuple, Union import numpy import torch diff --git a/espnet/nets/pytorch_backend/frontends/mask_estimator.py b/espnet/nets/pytorch_backend/frontends/mask_estimator.py index 861527c7a90..3ebdb3ff148 100644 --- a/espnet/nets/pytorch_backend/frontends/mask_estimator.py +++ b/espnet/nets/pytorch_backend/frontends/mask_estimator.py @@ -6,8 +6,7 @@ from torch_complex.tensor import ComplexTensor from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.rnn.encoders import RNN -from espnet.nets.pytorch_backend.rnn.encoders import RNNP +from espnet.nets.pytorch_backend.rnn.encoders import RNN, RNNP class MaskEstimator(torch.nn.Module): diff --git a/espnet/nets/pytorch_backend/lm/default.py b/espnet/nets/pytorch_backend/lm/default.py index 01bb26ea4a0..7a161333d4e 100644 --- a/espnet/nets/pytorch_backend/lm/default.py +++ b/espnet/nets/pytorch_backend/lm/default.py @@ -1,10 +1,8 @@ """Default Recurrent Neural Network Languge Model in `lm_train.py`.""" -from typing import Any -from typing import List -from typing import Tuple - import logging +from typing import Any, List, Tuple + import torch import torch.nn as nn import torch.nn.functional as F diff --git a/espnet/nets/pytorch_backend/lm/transformer.py b/espnet/nets/pytorch_backend/lm/transformer.py index 42c2f86d461..3ded10416f1 100644 --- a/espnet/nets/pytorch_backend/lm/transformer.py +++ b/espnet/nets/pytorch_backend/lm/transformer.py @@ -1,16 +1,15 @@ """Transformer language model.""" -from typing import Any -from typing import List -from typing import Tuple - import logging +from typing import Any, List, Tuple + import torch import torch.nn as nn import torch.nn.functional as F from espnet.nets.lm_interface import LMInterface -from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding from espnet.nets.pytorch_backend.transformer.encoder import Encoder from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask from espnet.nets.scorer_interface import BatchScorerInterface diff --git 
a/espnet/nets/pytorch_backend/rnn/attentions.py b/espnet/nets/pytorch_backend/rnn/attentions.py index 3df28169bd9..a92a0a3c193 100644 --- a/espnet/nets/pytorch_backend/rnn/attentions.py +++ b/espnet/nets/pytorch_backend/rnn/attentions.py @@ -1,13 +1,12 @@ """Attention modules for RNN.""" import math -import six +import six import torch import torch.nn.functional as F -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.nets_utils import to_device +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask, to_device def _apply_attention_constraint( diff --git a/espnet/nets/pytorch_backend/rnn/decoders.py b/espnet/nets/pytorch_backend/rnn/decoders.py index 9e7a60716d8..2bcdb3fecd4 100644 --- a/espnet/nets/pytorch_backend/rnn/decoders.py +++ b/espnet/nets/pytorch_backend/rnn/decoders.py @@ -2,24 +2,18 @@ import logging import math import random -import six +from argparse import Namespace import numpy as np +import six import torch import torch.nn.functional as F -from argparse import Namespace - -from espnet.nets.ctc_prefix_score import CTCPrefixScore -from espnet.nets.ctc_prefix_score import CTCPrefixScoreTH +from espnet.nets.ctc_prefix_score import CTCPrefixScore, CTCPrefixScoreTH from espnet.nets.e2e_asr_common import end_detect - +from espnet.nets.pytorch_backend.nets_utils import (mask_by_length, pad_list, + th_accuracy, to_device) from espnet.nets.pytorch_backend.rnn.attentions import att_to_numpy - -from espnet.nets.pytorch_backend.nets_utils import mask_by_length -from espnet.nets.pytorch_backend.nets_utils import pad_list -from espnet.nets.pytorch_backend.nets_utils import th_accuracy -from espnet.nets.pytorch_backend.nets_utils import to_device from espnet.nets.scorer_interface import ScorerInterface MAX_DECODER_OUTPUT = 5 diff --git a/espnet/nets/pytorch_backend/rnn/encoders.py b/espnet/nets/pytorch_backend/rnn/encoders.py index 7ab90e6c3ae..811bfd432c0 100644 --- a/espnet/nets/pytorch_backend/rnn/encoders.py +++ b/espnet/nets/pytorch_backend/rnn/encoders.py @@ -1,15 +1,13 @@ import logging -import six import numpy as np +import six import torch import torch.nn.functional as F -from torch.nn.utils.rnn import pack_padded_sequence -from torch.nn.utils.rnn import pad_packed_sequence +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence from espnet.nets.e2e_asr_common import get_vgg2l_odim -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.nets_utils import to_device +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask, to_device class RNNP(torch.nn.Module): diff --git a/espnet/nets/pytorch_backend/tacotron2/cbhg.py b/espnet/nets/pytorch_backend/tacotron2/cbhg.py index c869e0f8c63..24dec4da653 100644 --- a/espnet/nets/pytorch_backend/tacotron2/cbhg.py +++ b/espnet/nets/pytorch_backend/tacotron2/cbhg.py @@ -8,9 +8,7 @@ import torch import torch.nn.functional as F - -from torch.nn.utils.rnn import pack_padded_sequence -from torch.nn.utils.rnn import pad_packed_sequence +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask diff --git a/espnet/nets/pytorch_backend/tacotron2/decoder.py b/espnet/nets/pytorch_backend/tacotron2/decoder.py index 352635ddd16..85bad0200a9 100644 --- a/espnet/nets/pytorch_backend/tacotron2/decoder.py +++ b/espnet/nets/pytorch_backend/tacotron2/decoder.py @@ -7,7 +7,6 @@ """Tacotron2 decoder related modules.""" import six - import torch import 
torch.nn.functional as F diff --git a/espnet/nets/pytorch_backend/tacotron2/encoder.py b/espnet/nets/pytorch_backend/tacotron2/encoder.py index 148db765cc7..0603941767b 100644 --- a/espnet/nets/pytorch_backend/tacotron2/encoder.py +++ b/espnet/nets/pytorch_backend/tacotron2/encoder.py @@ -7,11 +7,8 @@ """Tacotron2 encoder related modules.""" import six - import torch - -from torch.nn.utils.rnn import pack_padded_sequence -from torch.nn.utils.rnn import pad_packed_sequence +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence def encoder_init(m): diff --git a/espnet/nets/pytorch_backend/transducer/arguments.py b/espnet/nets/pytorch_backend/transducer/arguments.py index feeaec8059f..87fdf0ad557 100644 --- a/espnet/nets/pytorch_backend/transducer/arguments.py +++ b/espnet/nets/pytorch_backend/transducer/arguments.py @@ -1,7 +1,7 @@ """Transducer model arguments.""" -from argparse import _ArgumentGroup import ast +from argparse import _ArgumentGroup from distutils.util import strtobool diff --git a/espnet/nets/pytorch_backend/transducer/blocks.py b/espnet/nets/pytorch_backend/transducer/blocks.py index 86abc21e9a8..b09676ce668 100644 --- a/espnet/nets/pytorch_backend/transducer/blocks.py +++ b/espnet/nets/pytorch_backend/transducer/blocks.py @@ -1,42 +1,34 @@ """Set of methods to create custom architecture.""" -from typing import Any -from typing import Dict -from typing import List -from typing import Tuple -from typing import Union +from typing import Any, Dict, List, Tuple, Union import torch from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule -from espnet.nets.pytorch_backend.conformer.encoder_layer import ( - EncoderLayer as ConformerEncoderLayer, # noqa: H301 -) - +from espnet.nets.pytorch_backend.conformer.encoder_layer import \ + EncoderLayer as ConformerEncoderLayer # noqa: H301 from espnet.nets.pytorch_backend.nets_utils import get_activation - -from espnet.nets.pytorch_backend.transducer.conv1d_nets import CausalConv1d -from espnet.nets.pytorch_backend.transducer.conv1d_nets import Conv1d -from espnet.nets.pytorch_backend.transducer.transformer_decoder_layer import ( - TransformerDecoderLayer, # noqa: H301 -) +from espnet.nets.pytorch_backend.transducer.conv1d_nets import (CausalConv1d, + Conv1d) +from espnet.nets.pytorch_backend.transducer.transformer_decoder_layer import \ + TransformerDecoderLayer # noqa: H301 from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L - -from espnet.nets.pytorch_backend.transformer.attention import ( - MultiHeadedAttention, # noqa: H301 - RelPositionMultiHeadedAttention, # noqa: H301 -) +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention # noqa: H301 +from espnet.nets.pytorch_backend.transformer.attention import \ + RelPositionMultiHeadedAttention # noqa: H301 +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding # noqa: H301 +from espnet.nets.pytorch_backend.transformer.embedding import \ + RelPositionalEncoding # noqa: H301 +from espnet.nets.pytorch_backend.transformer.embedding import \ + ScaledPositionalEncoding # noqa: H301 from espnet.nets.pytorch_backend.transformer.encoder_layer import EncoderLayer -from espnet.nets.pytorch_backend.transformer.embedding import ( - PositionalEncoding, # noqa: H301 - ScaledPositionalEncoding, # noqa: H301 - RelPositionalEncoding, # noqa: H301 -) -from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) +from 
espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ + PositionwiseFeedForward # noqa: H301 from espnet.nets.pytorch_backend.transformer.repeat import MultiSequential -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling +from espnet.nets.pytorch_backend.transformer.subsampling import \ + Conv2dSubsampling def verify_block_arguments( diff --git a/espnet/nets/pytorch_backend/transducer/conv1d_nets.py b/espnet/nets/pytorch_backend/transducer/conv1d_nets.py index 56816e8d04d..0c71c123c88 100644 --- a/espnet/nets/pytorch_backend/transducer/conv1d_nets.py +++ b/espnet/nets/pytorch_backend/transducer/conv1d_nets.py @@ -1,8 +1,6 @@ """Convolution networks definition for custom archictecture.""" -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Optional, Tuple, Union import torch diff --git a/espnet/nets/pytorch_backend/transducer/custom_decoder.py b/espnet/nets/pytorch_backend/transducer/custom_decoder.py index f5b2724ef75..f09f811a9c6 100644 --- a/espnet/nets/pytorch_backend/transducer/custom_decoder.py +++ b/espnet/nets/pytorch_backend/transducer/custom_decoder.py @@ -1,23 +1,17 @@ """Custom decoder definition for Transducer model.""" -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from espnet.nets.pytorch_backend.transducer.blocks import build_blocks -from espnet.nets.pytorch_backend.transducer.utils import check_batch_states -from espnet.nets.pytorch_backend.transducer.utils import check_state -from espnet.nets.pytorch_backend.transducer.utils import pad_sequence +from espnet.nets.pytorch_backend.transducer.utils import (check_batch_states, + check_state, + pad_sequence) from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask -from espnet.nets.transducer_decoder_interface import ExtendedHypothesis -from espnet.nets.transducer_decoder_interface import Hypothesis -from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface +from espnet.nets.transducer_decoder_interface import ( + ExtendedHypothesis, Hypothesis, TransducerDecoderInterface) class CustomDecoder(TransducerDecoderInterface, torch.nn.Module): diff --git a/espnet/nets/pytorch_backend/transducer/custom_encoder.py b/espnet/nets/pytorch_backend/transducer/custom_encoder.py index 109d2071ba9..18e1d375d4f 100644 --- a/espnet/nets/pytorch_backend/transducer/custom_encoder.py +++ b/espnet/nets/pytorch_backend/transducer/custom_encoder.py @@ -1,16 +1,14 @@ """Cutom encoder definition for transducer models.""" -from typing import List -from typing import Tuple -from typing import Union +from typing import List, Tuple, Union import torch from espnet.nets.pytorch_backend.transducer.blocks import build_blocks from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L - from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling +from espnet.nets.pytorch_backend.transformer.subsampling import \ + Conv2dSubsampling class CustomEncoder(torch.nn.Module): diff --git a/espnet/nets/pytorch_backend/transducer/error_calculator.py b/espnet/nets/pytorch_backend/transducer/error_calculator.py index 1d204770cfb..89224b62d70 100644 --- 
a/espnet/nets/pytorch_backend/transducer/error_calculator.py +++ b/espnet/nets/pytorch_backend/transducer/error_calculator.py @@ -1,8 +1,6 @@ """CER/WER computation for Transducer model.""" -from typing import List -from typing import Tuple -from typing import Union +from typing import List, Tuple, Union import torch diff --git a/espnet/nets/pytorch_backend/transducer/initializer.py b/espnet/nets/pytorch_backend/transducer/initializer.py index 8ae47ff471f..4ca566988b3 100644 --- a/espnet/nets/pytorch_backend/transducer/initializer.py +++ b/espnet/nets/pytorch_backend/transducer/initializer.py @@ -1,7 +1,7 @@ """Parameter initialization for Transducer model.""" -from argparse import Namespace import math +from argparse import Namespace import torch diff --git a/espnet/nets/pytorch_backend/transducer/rnn_decoder.py b/espnet/nets/pytorch_backend/transducer/rnn_decoder.py index 401cbe8f808..0046cd5bcdf 100644 --- a/espnet/nets/pytorch_backend/transducer/rnn_decoder.py +++ b/espnet/nets/pytorch_backend/transducer/rnn_decoder.py @@ -1,17 +1,11 @@ """RNN decoder definition for Transducer model.""" -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch -from espnet.nets.transducer_decoder_interface import ExtendedHypothesis -from espnet.nets.transducer_decoder_interface import Hypothesis -from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface +from espnet.nets.transducer_decoder_interface import ( + ExtendedHypothesis, Hypothesis, TransducerDecoderInterface) class RNNDecoder(TransducerDecoderInterface, torch.nn.Module): diff --git a/espnet/nets/pytorch_backend/transducer/rnn_encoder.py b/espnet/nets/pytorch_backend/transducer/rnn_encoder.py index 3fe6a783710..f1bafef33c3 100644 --- a/espnet/nets/pytorch_backend/transducer/rnn_encoder.py +++ b/espnet/nets/pytorch_backend/transducer/rnn_encoder.py @@ -9,20 +9,15 @@ """ from argparse import Namespace -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import List, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F -from torch.nn.utils.rnn import pack_padded_sequence -from torch.nn.utils.rnn import pad_packed_sequence +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence from espnet.nets.e2e_asr_common import get_vgg2l_odim -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.nets_utils import to_device +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask, to_device class RNNP(torch.nn.Module): diff --git a/espnet/nets/pytorch_backend/transducer/transducer_tasks.py b/espnet/nets/pytorch_backend/transducer/transducer_tasks.py index 79dc614bca6..03bcfc5ecf2 100644 --- a/espnet/nets/pytorch_backend/transducer/transducer_tasks.py +++ b/espnet/nets/pytorch_backend/transducer/transducer_tasks.py @@ -1,17 +1,13 @@ """Module implementing Transducer main and auxiliary tasks.""" -from typing import Any -from typing import List -from typing import Optional -from typing import Tuple +from typing import Any, List, Optional, Tuple import torch from espnet.nets.pytorch_backend.nets_utils import pad_list -from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( - LabelSmoothingLoss, # noqa: H301 -) from espnet.nets.pytorch_backend.transducer.joint_network import JointNetwork 
+from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import \ + LabelSmoothingLoss # noqa: H301 class TransducerTasks(torch.nn.Module): diff --git a/espnet/nets/pytorch_backend/transducer/transformer_decoder_layer.py b/espnet/nets/pytorch_backend/transducer/transformer_decoder_layer.py index 9aecce54e0c..64ea4a30d3c 100644 --- a/espnet/nets/pytorch_backend/transducer/transformer_decoder_layer.py +++ b/espnet/nets/pytorch_backend/transducer/transformer_decoder_layer.py @@ -4,11 +4,11 @@ import torch -from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm -from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ + PositionwiseFeedForward # noqa: H301 class TransformerDecoderLayer(torch.nn.Module): diff --git a/espnet/nets/pytorch_backend/transducer/utils.py b/espnet/nets/pytorch_backend/transducer/utils.py index d8bf3bfe336..91815a4016f 100644 --- a/espnet/nets/pytorch_backend/transducer/utils.py +++ b/espnet/nets/pytorch_backend/transducer/utils.py @@ -1,18 +1,14 @@ """Utility functions for Transducer models.""" import os -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Union +from typing import Any, Dict, List, Optional, Union import numpy as np import torch from espnet.nets.pytorch_backend.nets_utils import pad_list -from espnet.nets.transducer_decoder_interface import ExtendedHypothesis -from espnet.nets.transducer_decoder_interface import Hypothesis +from espnet.nets.transducer_decoder_interface import (ExtendedHypothesis, + Hypothesis) def get_decoder_input( diff --git a/espnet/nets/pytorch_backend/transducer/vgg2l.py b/espnet/nets/pytorch_backend/transducer/vgg2l.py index c7eecd23281..fb45d0b5a78 100644 --- a/espnet/nets/pytorch_backend/transducer/vgg2l.py +++ b/espnet/nets/pytorch_backend/transducer/vgg2l.py @@ -1,7 +1,6 @@ """VGG2L module definition for custom encoder.""" -from typing import Tuple -from typing import Union +from typing import Tuple, Union import torch diff --git a/espnet/nets/pytorch_backend/transformer/contextual_block_encoder_layer.py b/espnet/nets/pytorch_backend/transformer/contextual_block_encoder_layer.py index 16957e99820..cc1f13c512f 100644 --- a/espnet/nets/pytorch_backend/transformer/contextual_block_encoder_layer.py +++ b/espnet/nets/pytorch_backend/transformer/contextual_block_encoder_layer.py @@ -7,7 +7,6 @@ """Encoder self-attention layer definition.""" import torch - from torch import nn from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm diff --git a/espnet/nets/pytorch_backend/transformer/decoder.py b/espnet/nets/pytorch_backend/transformer/decoder.py index 5236632665c..024fc8d31c6 100644 --- a/espnet/nets/pytorch_backend/transformer/decoder.py +++ b/espnet/nets/pytorch_backend/transformer/decoder.py @@ -7,26 +7,28 @@ """Decoder definition.""" import logging - -from typing import Any -from typing import List -from typing import Tuple +from typing import Any, List, Tuple import torch from espnet.nets.pytorch_backend.nets_utils import rename_state_dict -from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.attention import \ + 
MultiHeadedAttention from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer -from espnet.nets.pytorch_backend.transformer.dynamic_conv import DynamicConvolution -from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import DynamicConvolution2D -from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.dynamic_conv import \ + DynamicConvolution +from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import \ + DynamicConvolution2D +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm -from espnet.nets.pytorch_backend.transformer.lightconv import LightweightConvolution -from espnet.nets.pytorch_backend.transformer.lightconv2d import LightweightConvolution2D +from espnet.nets.pytorch_backend.transformer.lightconv import \ + LightweightConvolution +from espnet.nets.pytorch_backend.transformer.lightconv2d import \ + LightweightConvolution2D from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask -from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ + PositionwiseFeedForward # noqa: H301 from espnet.nets.pytorch_backend.transformer.repeat import repeat from espnet.nets.scorer_interface import BatchScorerInterface diff --git a/espnet/nets/pytorch_backend/transformer/dynamic_conv.py b/espnet/nets/pytorch_backend/transformer/dynamic_conv.py index 8a2a0c1eaf0..f254d41e520 100644 --- a/espnet/nets/pytorch_backend/transformer/dynamic_conv.py +++ b/espnet/nets/pytorch_backend/transformer/dynamic_conv.py @@ -2,9 +2,8 @@ import numpy import torch -from torch import nn import torch.nn.functional as F - +from torch import nn MIN_VALUE = float(numpy.finfo(numpy.float32).min) diff --git a/espnet/nets/pytorch_backend/transformer/dynamic_conv2d.py b/espnet/nets/pytorch_backend/transformer/dynamic_conv2d.py index f8a4dd6e9f6..401c61d4009 100644 --- a/espnet/nets/pytorch_backend/transformer/dynamic_conv2d.py +++ b/espnet/nets/pytorch_backend/transformer/dynamic_conv2d.py @@ -2,9 +2,8 @@ import numpy import torch -from torch import nn import torch.nn.functional as F - +from torch import nn MIN_VALUE = float(numpy.finfo(numpy.float32).min) diff --git a/espnet/nets/pytorch_backend/transformer/embedding.py b/espnet/nets/pytorch_backend/transformer/embedding.py index 17a39fddec4..7021aa39eff 100644 --- a/espnet/nets/pytorch_backend/transformer/embedding.py +++ b/espnet/nets/pytorch_backend/transformer/embedding.py @@ -7,6 +7,7 @@ """Positional Encoding Module.""" import math + import torch diff --git a/espnet/nets/pytorch_backend/transformer/encoder.py b/espnet/nets/pytorch_backend/transformer/encoder.py index 508bf1aa7a7..da0d1146fc8 100644 --- a/espnet/nets/pytorch_backend/transformer/encoder.py +++ b/espnet/nets/pytorch_backend/transformer/encoder.py @@ -4,27 +4,32 @@ """Encoder definition.""" import logging + import torch from espnet.nets.pytorch_backend.nets_utils import rename_state_dict from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L -from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention -from espnet.nets.pytorch_backend.transformer.dynamic_conv import DynamicConvolution -from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import DynamicConvolution2D -from 
espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.dynamic_conv import \ + DynamicConvolution +from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import \ + DynamicConvolution2D +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding from espnet.nets.pytorch_backend.transformer.encoder_layer import EncoderLayer from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm -from espnet.nets.pytorch_backend.transformer.lightconv import LightweightConvolution -from espnet.nets.pytorch_backend.transformer.lightconv2d import LightweightConvolution2D -from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear -from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d -from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) +from espnet.nets.pytorch_backend.transformer.lightconv import \ + LightweightConvolution +from espnet.nets.pytorch_backend.transformer.lightconv2d import \ + LightweightConvolution2D +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( + Conv1dLinear, MultiLayeredConv1d) +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ + PositionwiseFeedForward # noqa: H301 from espnet.nets.pytorch_backend.transformer.repeat import repeat -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling6 -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling8 +from espnet.nets.pytorch_backend.transformer.subsampling import ( + Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8) def _pre_hook( diff --git a/espnet/nets/pytorch_backend/transformer/encoder_layer.py b/espnet/nets/pytorch_backend/transformer/encoder_layer.py index 863aa6730b3..1554cb4de3f 100644 --- a/espnet/nets/pytorch_backend/transformer/encoder_layer.py +++ b/espnet/nets/pytorch_backend/transformer/encoder_layer.py @@ -7,7 +7,6 @@ """Encoder self-attention layer definition.""" import torch - from torch import nn from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm diff --git a/espnet/nets/pytorch_backend/transformer/encoder_mix.py b/espnet/nets/pytorch_backend/transformer/encoder_mix.py index 4fa2d355545..b0c67998f10 100644 --- a/espnet/nets/pytorch_backend/transformer/encoder_mix.py +++ b/espnet/nets/pytorch_backend/transformer/encoder_mix.py @@ -9,12 +9,15 @@ import torch from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L -from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention -from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding from espnet.nets.pytorch_backend.transformer.encoder import Encoder from espnet.nets.pytorch_backend.transformer.encoder_layer import EncoderLayer from espnet.nets.pytorch_backend.transformer.repeat import repeat -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling +from espnet.nets.pytorch_backend.transformer.subsampling import \ + Conv2dSubsampling class EncoderMix(Encoder, torch.nn.Module): diff --git 
a/espnet/nets/pytorch_backend/transformer/lightconv.py b/espnet/nets/pytorch_backend/transformer/lightconv.py index b249402591e..1f3c2b89180 100644 --- a/espnet/nets/pytorch_backend/transformer/lightconv.py +++ b/espnet/nets/pytorch_backend/transformer/lightconv.py @@ -2,9 +2,8 @@ import numpy import torch -from torch import nn import torch.nn.functional as F - +from torch import nn MIN_VALUE = float(numpy.finfo(numpy.float32).min) diff --git a/espnet/nets/pytorch_backend/transformer/lightconv2d.py b/espnet/nets/pytorch_backend/transformer/lightconv2d.py index 294d23244e4..e7e52241134 100644 --- a/espnet/nets/pytorch_backend/transformer/lightconv2d.py +++ b/espnet/nets/pytorch_backend/transformer/lightconv2d.py @@ -2,9 +2,8 @@ import numpy import torch -from torch import nn import torch.nn.functional as F - +from torch import nn MIN_VALUE = float(numpy.finfo(numpy.float32).min) diff --git a/espnet/nets/pytorch_backend/transformer/longformer_attention.py b/espnet/nets/pytorch_backend/transformer/longformer_attention.py index 82a54c801d1..1610b6ab594 100644 --- a/espnet/nets/pytorch_backend/transformer/longformer_attention.py +++ b/espnet/nets/pytorch_backend/transformer/longformer_attention.py @@ -6,8 +6,7 @@ """Longformer based Local Attention Definition.""" -from longformer.longformer import LongformerConfig -from longformer.longformer import LongformerSelfAttention +from longformer.longformer import LongformerConfig, LongformerSelfAttention from torch import nn diff --git a/espnet/nets/pytorch_backend/transformer/plot.py b/espnet/nets/pytorch_backend/transformer/plot.py index 5946de6cd56..e7a1746e823 100644 --- a/espnet/nets/pytorch_backend/transformer/plot.py +++ b/espnet/nets/pytorch_backend/transformer/plot.py @@ -2,9 +2,9 @@ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import logging +import os import numpy -import os from espnet.asr import asr_utils diff --git a/espnet/nets/pytorch_backend/transformer/subsampling.py b/espnet/nets/pytorch_backend/transformer/subsampling.py index a69bc09445f..3ff363c9f7b 100644 --- a/espnet/nets/pytorch_backend/transformer/subsampling.py +++ b/espnet/nets/pytorch_backend/transformer/subsampling.py @@ -8,7 +8,8 @@ import torch -from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding class TooShortUttError(Exception): diff --git a/espnet/nets/pytorch_backend/transformer/subsampling_without_posenc.py b/espnet/nets/pytorch_backend/transformer/subsampling_without_posenc.py index 239d3f1ade7..1a188c48d95 100644 --- a/espnet/nets/pytorch_backend/transformer/subsampling_without_posenc.py +++ b/espnet/nets/pytorch_backend/transformer/subsampling_without_posenc.py @@ -4,6 +4,7 @@ """Subsampling layer definition.""" import math + import torch diff --git a/espnet/nets/pytorch_backend/wavenet.py b/espnet/nets/pytorch_backend/wavenet.py index 0539518342c..5fb8124e985 100644 --- a/espnet/nets/pytorch_backend/wavenet.py +++ b/espnet/nets/pytorch_backend/wavenet.py @@ -12,7 +12,6 @@ import numpy as np import torch import torch.nn.functional as F - from torch import nn diff --git a/espnet/nets/scorer_interface.py b/espnet/nets/scorer_interface.py index 946ec6be317..fb3a09fa676 100644 --- a/espnet/nets/scorer_interface.py +++ b/espnet/nets/scorer_interface.py @@ -1,11 +1,9 @@ """Scorer interface module.""" -from typing import Any -from typing import List -from typing import Tuple +import warnings +from typing import Any, List, Tuple import torch 
-import warnings class ScorerInterface: diff --git a/espnet/nets/scorers/ctc.py b/espnet/nets/scorers/ctc.py index 1d12ce6e2a2..74c3385ff02 100644 --- a/espnet/nets/scorers/ctc.py +++ b/espnet/nets/scorers/ctc.py @@ -3,8 +3,7 @@ import numpy as np import torch -from espnet.nets.ctc_prefix_score import CTCPrefixScore -from espnet.nets.ctc_prefix_score import CTCPrefixScoreTH +from espnet.nets.ctc_prefix_score import CTCPrefixScore, CTCPrefixScoreTH from espnet.nets.scorer_interface import BatchPartialScorerInterface diff --git a/espnet/nets/scorers/length_bonus.py b/espnet/nets/scorers/length_bonus.py index fe32a616211..490ea84db4c 100644 --- a/espnet/nets/scorers/length_bonus.py +++ b/espnet/nets/scorers/length_bonus.py @@ -1,7 +1,5 @@ """Length bonus module.""" -from typing import Any -from typing import List -from typing import Tuple +from typing import Any, List, Tuple import torch diff --git a/espnet/nets/scorers/ngram.py b/espnet/nets/scorers/ngram.py index 61ed70efdb0..5f92479f3e1 100644 --- a/espnet/nets/scorers/ngram.py +++ b/espnet/nets/scorers/ngram.py @@ -5,8 +5,8 @@ import kenlm import torch -from espnet.nets.scorer_interface import BatchScorerInterface -from espnet.nets.scorer_interface import PartialScorerInterface +from espnet.nets.scorer_interface import (BatchScorerInterface, + PartialScorerInterface) class Ngrambase(ABC): diff --git a/espnet/nets/transducer_decoder_interface.py b/espnet/nets/transducer_decoder_interface.py index eb3ab318dcc..4c723346b50 100644 --- a/espnet/nets/transducer_decoder_interface.py +++ b/espnet/nets/transducer_decoder_interface.py @@ -1,12 +1,7 @@ """Transducer decoder interface module.""" from dataclasses import dataclass -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch diff --git a/espnet/nets/tts_interface.py b/espnet/nets/tts_interface.py index 587d7279237..6ce0d6e6e68 100644 --- a/espnet/nets/tts_interface.py +++ b/espnet/nets/tts_interface.py @@ -7,7 +7,6 @@ from espnet.asr.asr_utils import torch_load - try: import chainer except ImportError: diff --git a/espnet/optimizer/chainer.py b/espnet/optimizer/chainer.py index 0fb6f4b3fab..de58767dcb1 100644 --- a/espnet/optimizer/chainer.py +++ b/espnet/optimizer/chainer.py @@ -5,9 +5,7 @@ from chainer.optimizer_hooks import WeightDecay from espnet.optimizer.factory import OptimizerFactoryInterface -from espnet.optimizer.parser import adadelta -from espnet.optimizer.parser import adam -from espnet.optimizer.parser import sgd +from espnet.optimizer.parser import adadelta, adam, sgd class AdamFactory(OptimizerFactoryInterface): diff --git a/espnet/optimizer/pytorch.py b/espnet/optimizer/pytorch.py index 7914e36b999..946bdf58eba 100644 --- a/espnet/optimizer/pytorch.py +++ b/espnet/optimizer/pytorch.py @@ -4,9 +4,7 @@ import torch from espnet.optimizer.factory import OptimizerFactoryInterface -from espnet.optimizer.parser import adadelta -from espnet.optimizer.parser import adam -from espnet.optimizer.parser import sgd +from espnet.optimizer.parser import adadelta, adam, sgd class AdamFactory(OptimizerFactoryInterface): diff --git a/espnet/st/pytorch_backend/st.py b/espnet/st/pytorch_backend/st.py index 1a56930dcd3..017a23b7942 100644 --- a/espnet/st/pytorch_backend/st.py +++ b/espnet/st/pytorch_backend/st.py @@ -8,39 +8,30 @@ import logging import os -from chainer import training -from chainer.training import extensions 
import numpy as np import torch +from chainer import training +from chainer.training import extensions -from espnet.asr.asr_utils import adadelta_eps_decay -from espnet.asr.asr_utils import adam_lr_decay -from espnet.asr.asr_utils import add_results_to_json -from espnet.asr.asr_utils import CompareValueTrigger -from espnet.asr.asr_utils import restore_snapshot -from espnet.asr.asr_utils import snapshot_object -from espnet.asr.asr_utils import torch_load -from espnet.asr.asr_utils import torch_resume -from espnet.asr.asr_utils import torch_snapshot -from espnet.asr.pytorch_backend.asr_init import load_trained_model -from espnet.asr.pytorch_backend.asr_init import load_trained_modules - +from espnet.asr.asr_utils import (CompareValueTrigger, adadelta_eps_decay, + adam_lr_decay, add_results_to_json, + restore_snapshot, snapshot_object, + torch_load, torch_resume, torch_snapshot) +from espnet.asr.pytorch_backend.asr import \ + CustomConverter as ASRCustomConverter +from espnet.asr.pytorch_backend.asr import CustomEvaluator, CustomUpdater +from espnet.asr.pytorch_backend.asr_init import (load_trained_model, + load_trained_modules) from espnet.nets.pytorch_backend.e2e_asr import pad_list from espnet.nets.st_interface import STInterface -from espnet.utils.dataset import ChainerDataLoader -from espnet.utils.dataset import TransformDataset +from espnet.utils.dataset import ChainerDataLoader, TransformDataset from espnet.utils.deterministic_utils import set_deterministic_pytorch from espnet.utils.dynamic_import import dynamic_import from espnet.utils.io_utils import LoadInputsAndTargets from espnet.utils.training.batchfy import make_batchset from espnet.utils.training.iterators import ShufflingEnabler from espnet.utils.training.tensorboard_logger import TensorboardLogger -from espnet.utils.training.train_utils import check_early_stop -from espnet.utils.training.train_utils import set_early_stop - -from espnet.asr.pytorch_backend.asr import CustomConverter as ASRCustomConverter -from espnet.asr.pytorch_backend.asr import CustomEvaluator -from espnet.asr.pytorch_backend.asr import CustomUpdater +from espnet.utils.training.train_utils import check_early_stop, set_early_stop class CustomConverter(ASRCustomConverter): @@ -183,7 +174,8 @@ def train(args): model.parameters(), lr=args.lr, weight_decay=args.weight_decay ) elif args.opt == "noam": - from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt + from espnet.nets.pytorch_backend.transformer.optimizer import \ + get_std_opt optimizer = get_std_opt( model.parameters(), diff --git a/espnet/transform/transformation.py b/espnet/transform/transformation.py index 1a043b00be4..201117bb61c 100644 --- a/espnet/transform/transformation.py +++ b/espnet/transform/transformation.py @@ -1,16 +1,15 @@ """Transformation module.""" -from collections.abc import Sequence -from collections import OrderedDict import copy -from inspect import signature import io import logging +from collections import OrderedDict +from collections.abc import Sequence +from inspect import signature import yaml from espnet.utils.dynamic_import import dynamic_import - # TODO(karita): inherit TransformInterface # TODO(karita): register cmd arguments in asr_train.py import_alias = dict( diff --git a/espnet/tts/pytorch_backend/tts.py b/espnet/tts/pytorch_backend/tts.py index 09c45479a48..d397bb4f3d3 100644 --- a/espnet/tts/pytorch_backend/tts.py +++ b/espnet/tts/pytorch_backend/tts.py @@ -17,32 +17,23 @@ import kaldiio import numpy as np import torch - from chainer import 
training from chainer.training import extensions -from espnet.asr.asr_utils import get_model_conf -from espnet.asr.asr_utils import snapshot_object -from espnet.asr.asr_utils import torch_load -from espnet.asr.asr_utils import torch_resume -from espnet.asr.asr_utils import torch_snapshot +from espnet.asr.asr_utils import (get_model_conf, snapshot_object, torch_load, + torch_resume, torch_snapshot) from espnet.asr.pytorch_backend.asr_init import load_trained_modules from espnet.nets.pytorch_backend.nets_utils import pad_list from espnet.nets.tts_interface import TTSInterface -from espnet.utils.dataset import ChainerDataLoader -from espnet.utils.dataset import TransformDataset +from espnet.utils.dataset import ChainerDataLoader, TransformDataset +from espnet.utils.deterministic_utils import set_deterministic_pytorch from espnet.utils.dynamic_import import dynamic_import from espnet.utils.io_utils import LoadInputsAndTargets from espnet.utils.training.batchfy import make_batchset from espnet.utils.training.evaluator import BaseEvaluator - -from espnet.utils.deterministic_utils import set_deterministic_pytorch -from espnet.utils.training.train_utils import check_early_stop -from espnet.utils.training.train_utils import set_early_stop - from espnet.utils.training.iterators import ShufflingEnabler - from espnet.utils.training.tensorboard_logger import TensorboardLogger +from espnet.utils.training.train_utils import check_early_stop, set_early_stop class CustomEvaluator(BaseEvaluator): @@ -354,7 +345,8 @@ def train(args): model_params, args.lr, eps=args.eps, weight_decay=args.weight_decay ) elif args.opt == "noam": - from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt + from espnet.nets.pytorch_backend.transformer.optimizer import \ + get_std_opt optimizer = get_std_opt( model_params, args.adim, args.transformer_warmup_steps, args.transformer_lr diff --git a/espnet/utils/cli_utils.py b/espnet/utils/cli_utils.py index c4a4cd15b72..6b7fce6c924 100644 --- a/espnet/utils/cli_utils.py +++ b/espnet/utils/cli_utils.py @@ -1,6 +1,6 @@ +import sys from collections.abc import Sequence from distutils.util import strtobool as dist_strtobool -import sys import numpy diff --git a/espnet/utils/io_utils.py b/espnet/utils/io_utils.py index 6a642796c43..0df663efd02 100644 --- a/espnet/utils/io_utils.py +++ b/espnet/utils/io_utils.py @@ -1,7 +1,7 @@ -from collections import OrderedDict import io import logging import os +from collections import OrderedDict import h5py import kaldiio diff --git a/espnet/utils/training/iterators.py b/espnet/utils/training/iterators.py index 1cabb1f1fa8..7faf441dd01 100644 --- a/espnet/utils/training/iterators.py +++ b/espnet/utils/training/iterators.py @@ -1,10 +1,8 @@ import chainer -from chainer.iterators import MultiprocessIterator -from chainer.iterators import SerialIterator -from chainer.iterators import ShuffleOrderSampler -from chainer.training.extension import Extension - import numpy as np +from chainer.iterators import (MultiprocessIterator, SerialIterator, + ShuffleOrderSampler) +from chainer.training.extension import Extension class ShufflingEnabler(Extension): diff --git a/espnet/utils/training/train_utils.py b/espnet/utils/training/train_utils.py index 38f7cd4feb6..9e8b4fbbc52 100644 --- a/espnet/utils/training/train_utils.py +++ b/espnet/utils/training/train_utils.py @@ -1,6 +1,7 @@ -import chainer import logging +import chainer + def check_early_stop(trainer, epochs): """Checks an early stopping trigger and warns the user if it's the case diff 
--git a/espnet/vc/pytorch_backend/vc.py b/espnet/vc/pytorch_backend/vc.py index bfa3b0d11f3..9295bdaf735 100644 --- a/espnet/vc/pytorch_backend/vc.py +++ b/espnet/vc/pytorch_backend/vc.py @@ -17,32 +17,23 @@ import kaldiio import numpy as np import torch - from chainer import training from chainer.training import extensions -from espnet.asr.asr_utils import get_model_conf -from espnet.asr.asr_utils import snapshot_object -from espnet.asr.asr_utils import torch_load -from espnet.asr.asr_utils import torch_resume -from espnet.asr.asr_utils import torch_snapshot +from espnet.asr.asr_utils import (get_model_conf, snapshot_object, torch_load, + torch_resume, torch_snapshot) from espnet.asr.pytorch_backend.asr_init import load_trained_modules from espnet.nets.pytorch_backend.nets_utils import pad_list from espnet.nets.tts_interface import TTSInterface -from espnet.utils.dataset import ChainerDataLoader -from espnet.utils.dataset import TransformDataset +from espnet.utils.dataset import ChainerDataLoader, TransformDataset +from espnet.utils.deterministic_utils import set_deterministic_pytorch from espnet.utils.dynamic_import import dynamic_import from espnet.utils.io_utils import LoadInputsAndTargets from espnet.utils.training.batchfy import make_batchset from espnet.utils.training.evaluator import BaseEvaluator - -from espnet.utils.deterministic_utils import set_deterministic_pytorch -from espnet.utils.training.train_utils import check_early_stop -from espnet.utils.training.train_utils import set_early_stop - from espnet.utils.training.iterators import ShufflingEnabler - from espnet.utils.training.tensorboard_logger import TensorboardLogger +from espnet.utils.training.train_utils import check_early_stop, set_early_stop class CustomEvaluator(BaseEvaluator): @@ -349,7 +340,8 @@ def train(args): model.parameters(), args.lr, eps=args.eps, weight_decay=args.weight_decay ) elif args.opt == "noam": - from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt + from espnet.nets.pytorch_backend.transformer.optimizer import \ + get_std_opt optimizer = get_std_opt( model.parameters(), diff --git a/espnet2/asr/decoder/abs_decoder.py b/espnet2/asr/decoder/abs_decoder.py index 4ad18d5e368..e46d1c24fcb 100644 --- a/espnet2/asr/decoder/abs_decoder.py +++ b/espnet2/asr/decoder/abs_decoder.py @@ -1,5 +1,4 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Tuple import torch diff --git a/espnet2/asr/decoder/mlm_decoder.py b/espnet2/asr/decoder/mlm_decoder.py index 85cd1d3757f..b9d8df7dc9c 100644 --- a/espnet2/asr/decoder/mlm_decoder.py +++ b/espnet2/asr/decoder/mlm_decoder.py @@ -7,16 +7,17 @@ import torch from typeguard import check_argument_types +from espnet2.asr.decoder.abs_decoder import AbsDecoder from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer -from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm -from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) +from 
espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ + PositionwiseFeedForward # noqa: H301 from espnet.nets.pytorch_backend.transformer.repeat import repeat -from espnet2.asr.decoder.abs_decoder import AbsDecoder class MLMDecoder(AbsDecoder): diff --git a/espnet2/asr/decoder/rnn_decoder.py b/espnet2/asr/decoder/rnn_decoder.py index fc938225f35..3e5ad002493 100644 --- a/espnet2/asr/decoder/rnn_decoder.py +++ b/espnet2/asr/decoder/rnn_decoder.py @@ -5,11 +5,10 @@ import torch.nn.functional as F from typeguard import check_argument_types -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.nets_utils import to_device -from espnet.nets.pytorch_backend.rnn.attentions import initial_att from espnet2.asr.decoder.abs_decoder import AbsDecoder from espnet2.utils.get_default_kwargs import get_default_kwargs +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask, to_device +from espnet.nets.pytorch_backend.rnn.attentions import initial_att def build_attention_list( diff --git a/espnet2/asr/decoder/transformer_decoder.py b/espnet2/asr/decoder/transformer_decoder.py index 1bd74cb76c1..989356d9c71 100644 --- a/espnet2/asr/decoder/transformer_decoder.py +++ b/espnet2/asr/decoder/transformer_decoder.py @@ -2,30 +2,32 @@ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) """Decoder definition.""" -from typing import Any -from typing import List -from typing import Sequence -from typing import Tuple +from typing import Any, List, Sequence, Tuple import torch from typeguard import check_argument_types +from espnet2.asr.decoder.abs_decoder import AbsDecoder from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer -from espnet.nets.pytorch_backend.transformer.dynamic_conv import DynamicConvolution -from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import DynamicConvolution2D -from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.dynamic_conv import \ + DynamicConvolution +from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import \ + DynamicConvolution2D +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm -from espnet.nets.pytorch_backend.transformer.lightconv import LightweightConvolution -from espnet.nets.pytorch_backend.transformer.lightconv2d import LightweightConvolution2D +from espnet.nets.pytorch_backend.transformer.lightconv import \ + LightweightConvolution +from espnet.nets.pytorch_backend.transformer.lightconv2d import \ + LightweightConvolution2D from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask -from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ + PositionwiseFeedForward # noqa: H301 from espnet.nets.pytorch_backend.transformer.repeat import repeat from espnet.nets.scorer_interface import BatchScorerInterface -from espnet2.asr.decoder.abs_decoder import AbsDecoder class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface): diff --git a/espnet2/asr/encoder/abs_encoder.py 
b/espnet2/asr/encoder/abs_encoder.py index 1fb7c97c35b..22a1a103458 100644 --- a/espnet2/asr/encoder/abs_encoder.py +++ b/espnet2/asr/encoder/abs_encoder.py @@ -1,7 +1,5 @@ -from abc import ABC -from abc import abstractmethod -from typing import Optional -from typing import Tuple +from abc import ABC, abstractmethod +from typing import Optional, Tuple import torch diff --git a/espnet2/asr/encoder/conformer_encoder.py b/espnet2/asr/encoder/conformer_encoder.py index c0c3d92fd1c..2c5c2f1c012 100644 --- a/espnet2/asr/encoder/conformer_encoder.py +++ b/espnet2/asr/encoder/conformer_encoder.py @@ -3,46 +3,41 @@ """Conformer encoder definition.""" -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union - import logging -import torch +from typing import List, Optional, Tuple, Union +import torch from typeguard import check_argument_types +from espnet2.asr.ctc import CTC +from espnet2.asr.encoder.abs_encoder import AbsEncoder from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer -from espnet.nets.pytorch_backend.nets_utils import get_activation -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.transformer.attention import ( - MultiHeadedAttention, # noqa: H301 - RelPositionMultiHeadedAttention, # noqa: H301 - LegacyRelPositionMultiHeadedAttention, # noqa: H301 -) -from espnet.nets.pytorch_backend.transformer.embedding import ( - PositionalEncoding, # noqa: H301 - ScaledPositionalEncoding, # noqa: H301 - RelPositionalEncoding, # noqa: H301 - LegacyRelPositionalEncoding, # noqa: H301 -) +from espnet.nets.pytorch_backend.nets_utils import (get_activation, + make_pad_mask) +from espnet.nets.pytorch_backend.transformer.attention import \ + LegacyRelPositionMultiHeadedAttention # noqa: H301 +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention # noqa: H301 +from espnet.nets.pytorch_backend.transformer.attention import \ + RelPositionMultiHeadedAttention # noqa: H301 +from espnet.nets.pytorch_backend.transformer.embedding import \ + LegacyRelPositionalEncoding # noqa: H301 +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding # noqa: H301 +from espnet.nets.pytorch_backend.transformer.embedding import \ + RelPositionalEncoding # noqa: H301 +from espnet.nets.pytorch_backend.transformer.embedding import \ + ScaledPositionalEncoding # noqa: H301 from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm -from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear -from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d -from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( + Conv1dLinear, MultiLayeredConv1d) +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ + PositionwiseFeedForward # noqa: H301 from espnet.nets.pytorch_backend.transformer.repeat import repeat -from espnet.nets.pytorch_backend.transformer.subsampling import check_short_utt -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling2 -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling6 -from 
espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling8 -from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError -from espnet2.asr.ctc import CTC -from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet.nets.pytorch_backend.transformer.subsampling import ( + Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, + Conv2dSubsampling8, TooShortUttError, check_short_utt) class ConformerEncoder(AbsEncoder): diff --git a/espnet2/asr/encoder/contextual_block_conformer_encoder.py b/espnet2/asr/encoder/contextual_block_conformer_encoder.py index 7152e34d44a..158837ec4aa 100644 --- a/espnet2/asr/encoder/contextual_block_conformer_encoder.py +++ b/espnet2/asr/encoder/contextual_block_conformer_encoder.py @@ -5,34 +5,31 @@ @author: Keqi Deng (UCAS) """ -from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule -from espnet.nets.pytorch_backend.conformer.contextual_block_encoder_layer import ( - ContextualBlockEncoderLayer, # noqa: H301 -) -from espnet.nets.pytorch_backend.nets_utils import ( - make_pad_mask, # noqa: H301 - get_activation, # noqa: H301 -) -from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention -from espnet.nets.pytorch_backend.transformer.embedding import StreamPositionalEncoding -from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm -from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear -from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d -from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) -from espnet.nets.pytorch_backend.transformer.repeat import repeat -from espnet.nets.pytorch_backend.transformer.subsampling_without_posenc import ( - Conv2dSubsamplingWOPosEnc, # noqa: H301 -) -from espnet2.asr.encoder.abs_encoder import AbsEncoder import math +from typing import Optional # noqa: H301 +from typing import Tuple # noqa: H301 + import torch from typeguard import check_argument_types -from typing import ( - Optional, # noqa: H301 - Tuple, # noqa: H301 -) + +from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet.nets.pytorch_backend.conformer.contextual_block_encoder_layer import \ + ContextualBlockEncoderLayer # noqa: H301 +from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule +from espnet.nets.pytorch_backend.nets_utils import get_activation # noqa: H301 +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask # noqa: H301 +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.embedding import \ + StreamPositionalEncoding +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( + Conv1dLinear, MultiLayeredConv1d) +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ + PositionwiseFeedForward # noqa: H301 +from espnet.nets.pytorch_backend.transformer.repeat import repeat +from espnet.nets.pytorch_backend.transformer.subsampling_without_posenc import \ + Conv2dSubsamplingWOPosEnc # noqa: H301 class ContextualBlockConformerEncoder(AbsEncoder): diff --git a/espnet2/asr/encoder/contextual_block_transformer_encoder.py b/espnet2/asr/encoder/contextual_block_transformer_encoder.py index ec3b7e28193..7c29d01d3c5 100644 --- a/espnet2/asr/encoder/contextual_block_transformer_encoder.py +++ 
b/espnet2/asr/encoder/contextual_block_transformer_encoder.py @@ -2,28 +2,28 @@ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) """Encoder definition.""" -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention -from espnet.nets.pytorch_backend.transformer.contextual_block_encoder_layer import ( - ContextualBlockEncoderLayer, # noqa: H301 -) -from espnet.nets.pytorch_backend.transformer.embedding import StreamPositionalEncoding -from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm -from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear -from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d -from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) -from espnet.nets.pytorch_backend.transformer.repeat import repeat -from espnet.nets.pytorch_backend.transformer.subsampling_without_posenc import ( - Conv2dSubsamplingWOPosEnc, # noqa: H301 -) -from espnet2.asr.encoder.abs_encoder import AbsEncoder import math +from typing import Optional, Tuple + import torch from typeguard import check_argument_types -from typing import Optional -from typing import Tuple + +from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.contextual_block_encoder_layer import \ + ContextualBlockEncoderLayer # noqa: H301 +from espnet.nets.pytorch_backend.transformer.embedding import \ + StreamPositionalEncoding +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( + Conv1dLinear, MultiLayeredConv1d) +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ + PositionwiseFeedForward # noqa: H301 +from espnet.nets.pytorch_backend.transformer.repeat import repeat +from espnet.nets.pytorch_backend.transformer.subsampling_without_posenc import \ + Conv2dSubsamplingWOPosEnc # noqa: H301 class ContextualBlockTransformerEncoder(AbsEncoder): diff --git a/espnet2/asr/encoder/hubert_encoder.py b/espnet2/asr/encoder/hubert_encoder.py index 2e96da8bf9b..29bea364ba2 100644 --- a/espnet2/asr/encoder/hubert_encoder.py +++ b/espnet2/asr/encoder/hubert_encoder.py @@ -12,18 +12,17 @@ import copy import logging import os +from pathlib import Path +from typing import Optional, Tuple + import torch import yaml - from filelock import FileLock -from pathlib import Path from typeguard import check_argument_types -from typing import Optional -from typing import Tuple +from espnet2.asr.encoder.abs_encoder import AbsEncoder from espnet.nets.pytorch_backend.nets_utils import make_pad_mask from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm -from espnet2.asr.encoder.abs_encoder import AbsEncoder class FairseqHubertEncoder(AbsEncoder): @@ -278,11 +277,10 @@ def __init__( self.use_amp = use_amp try: from fairseq.data.dictionary import Dictionary - from fairseq.models.hubert.hubert import ( - HubertModel, # noqa: H301 - HubertConfig, # noqa: H301 - HubertPretrainingConfig, # noqa: H301 - ) + from fairseq.models.hubert.hubert import HubertConfig # noqa: H301 + from fairseq.models.hubert.hubert import HubertModel # noqa: H301 + from fairseq.models.hubert.hubert import \ + 
HubertPretrainingConfig # noqa: H301 except Exception as e: print("Error: FairSeq is not properly installed.") print("Please install FairSeq: cd ${MAIN_ROOT}/tools && make fairseq.done") diff --git a/espnet2/asr/encoder/longformer_encoder.py b/espnet2/asr/encoder/longformer_encoder.py index 1d9dcfcc864..032c2dece25 100644 --- a/espnet2/asr/encoder/longformer_encoder.py +++ b/espnet2/asr/encoder/longformer_encoder.py @@ -3,36 +3,28 @@ """Conformer encoder definition.""" -from typing import List -from typing import Optional -from typing import Tuple +from typing import List, Optional, Tuple import torch - from typeguard import check_argument_types +from espnet2.asr.ctc import CTC +from espnet2.asr.encoder.conformer_encoder import ConformerEncoder from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer -from espnet.nets.pytorch_backend.nets_utils import get_activation -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.transformer.embedding import ( - PositionalEncoding, # noqa: H301 -) +from espnet.nets.pytorch_backend.nets_utils import (get_activation, + make_pad_mask) +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding # noqa: H301 from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm -from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear -from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d -from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( + Conv1dLinear, MultiLayeredConv1d) +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ + PositionwiseFeedForward # noqa: H301 from espnet.nets.pytorch_backend.transformer.repeat import repeat -from espnet.nets.pytorch_backend.transformer.subsampling import check_short_utt -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling2 -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling6 -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling8 -from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError -from espnet2.asr.ctc import CTC -from espnet2.asr.encoder.conformer_encoder import ConformerEncoder +from espnet.nets.pytorch_backend.transformer.subsampling import ( + Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, + Conv2dSubsampling8, TooShortUttError, check_short_utt) class LongformerEncoder(ConformerEncoder): @@ -228,11 +220,11 @@ def __init__( self.selfattention_layer_type = selfattention_layer_type if selfattention_layer_type == "lf_selfattn": assert pos_enc_layer_type == "abs_pos" - from espnet.nets.pytorch_backend.transformer.longformer_attention import ( - LongformerAttention, # noqa: H301 - ) from longformer.longformer import LongformerConfig + from espnet.nets.pytorch_backend.transformer.longformer_attention import \ + LongformerAttention # noqa: H301 + encoder_selfattn_layer = LongformerAttention config = LongformerConfig( diff --git a/espnet2/asr/encoder/rnn_encoder.py b/espnet2/asr/encoder/rnn_encoder.py index fd57ebfd2d8..38b2e244134 100644 --- a/espnet2/asr/encoder/rnn_encoder.py +++ 
b/espnet2/asr/encoder/rnn_encoder.py @@ -1,15 +1,12 @@ -from typing import Optional -from typing import Sequence -from typing import Tuple +from typing import Optional, Sequence, Tuple import numpy as np import torch from typeguard import check_argument_types -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.rnn.encoders import RNN -from espnet.nets.pytorch_backend.rnn.encoders import RNNP from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.rnn.encoders import RNN, RNNP class RNNEncoder(AbsEncoder): diff --git a/espnet2/asr/encoder/transformer_encoder.py b/espnet2/asr/encoder/transformer_encoder.py index b11cb8c25d3..d91a4901e3d 100644 --- a/espnet2/asr/encoder/transformer_encoder.py +++ b/espnet2/asr/encoder/transformer_encoder.py @@ -3,32 +3,28 @@ """Transformer encoder definition.""" -from typing import List -from typing import Optional -from typing import Tuple +from typing import List, Optional, Tuple import torch from typeguard import check_argument_types +from espnet2.asr.ctc import CTC +from espnet2.asr.encoder.abs_encoder import AbsEncoder from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention -from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding from espnet.nets.pytorch_backend.transformer.encoder_layer import EncoderLayer from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm -from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear -from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d -from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( + Conv1dLinear, MultiLayeredConv1d) +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ + PositionwiseFeedForward # noqa: H301 from espnet.nets.pytorch_backend.transformer.repeat import repeat -from espnet.nets.pytorch_backend.transformer.subsampling import check_short_utt -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling2 -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling6 -from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling8 -from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError -from espnet2.asr.ctc import CTC -from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet.nets.pytorch_backend.transformer.subsampling import ( + Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, + Conv2dSubsampling8, TooShortUttError, check_short_utt) class TransformerEncoder(AbsEncoder): diff --git a/espnet2/asr/encoder/vgg_rnn_encoder.py b/espnet2/asr/encoder/vgg_rnn_encoder.py index 8c36c8cf4f2..c648e0b947f 100644 --- a/espnet2/asr/encoder/vgg_rnn_encoder.py +++ b/espnet2/asr/encoder/vgg_rnn_encoder.py @@ -4,12 +4,10 @@ import torch from typeguard import check_argument_types +from espnet2.asr.encoder.abs_encoder import AbsEncoder from 
espnet.nets.e2e_asr_common import get_vgg2l_odim from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.rnn.encoders import RNN -from espnet.nets.pytorch_backend.rnn.encoders import RNNP -from espnet.nets.pytorch_backend.rnn.encoders import VGG2L -from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet.nets.pytorch_backend.rnn.encoders import RNN, RNNP, VGG2L class VGGRNNEncoder(AbsEncoder): diff --git a/espnet2/asr/encoder/wav2vec2_encoder.py b/espnet2/asr/encoder/wav2vec2_encoder.py index 68cad0ae60f..b6186e2596d 100644 --- a/espnet2/asr/encoder/wav2vec2_encoder.py +++ b/espnet2/asr/encoder/wav2vec2_encoder.py @@ -4,18 +4,17 @@ """Encoder definition.""" import contextlib import copy -from filelock import FileLock import logging import os -from typing import Optional -from typing import Tuple +from typing import Optional, Tuple import torch +from filelock import FileLock from typeguard import check_argument_types +from espnet2.asr.encoder.abs_encoder import AbsEncoder from espnet.nets.pytorch_backend.nets_utils import make_pad_mask from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm -from espnet2.asr.encoder.abs_encoder import AbsEncoder class FairSeqWav2Vec2Encoder(AbsEncoder): diff --git a/espnet2/asr/espnet_model.py b/espnet2/asr/espnet_model.py index 67698e95115..465b5a80153 100644 --- a/espnet2/asr/espnet_model.py +++ b/espnet2/asr/espnet_model.py @@ -1,21 +1,11 @@ -from contextlib import contextmanager import logging -from packaging.version import parse as V -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from contextlib import contextmanager +from typing import Dict, List, Optional, Tuple, Union import torch +from packaging.version import parse as V from typeguard import check_argument_types -from espnet.nets.e2e_asr_common import ErrorCalculator -from espnet.nets.pytorch_backend.nets_utils import th_accuracy -from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos -from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( - LabelSmoothingLoss, # noqa: H301 -) from espnet2.asr.ctc import CTC from espnet2.asr.decoder.abs_decoder import AbsDecoder from espnet2.asr.encoder.abs_encoder import AbsEncoder @@ -28,6 +18,11 @@ from espnet2.layers.abs_normalize import AbsNormalize from espnet2.torch_utils.device_funcs import force_gatherable from espnet2.train.abs_espnet_model import AbsESPnetModel +from espnet.nets.e2e_asr_common import ErrorCalculator +from espnet.nets.pytorch_backend.nets_utils import th_accuracy +from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos +from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import \ + LabelSmoothingLoss # noqa: H301 if V(torch.__version__) >= V("1.6.0"): from torch.cuda.amp import autocast diff --git a/espnet2/asr/frontend/abs_frontend.py b/espnet2/asr/frontend/abs_frontend.py index 538236fe944..8f785e38d9e 100644 --- a/espnet2/asr/frontend/abs_frontend.py +++ b/espnet2/asr/frontend/abs_frontend.py @@ -1,5 +1,4 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Tuple import torch diff --git a/espnet2/asr/frontend/default.py b/espnet2/asr/frontend/default.py index a2aa62c133e..f9ac5245262 100644 --- a/espnet2/asr/frontend/default.py +++ b/espnet2/asr/frontend/default.py @@ -1,7 +1,5 @@ import copy -from typing import Optional -from typing import Tuple 
-from typing import Union +from typing import Optional, Tuple, Union import humanfriendly import numpy as np @@ -9,11 +7,11 @@ from torch_complex.tensor import ComplexTensor from typeguard import check_argument_types -from espnet.nets.pytorch_backend.frontends.frontend import Frontend from espnet2.asr.frontend.abs_frontend import AbsFrontend from espnet2.layers.log_mel import LogMel from espnet2.layers.stft import Stft from espnet2.utils.get_default_kwargs import get_default_kwargs +from espnet.nets.pytorch_backend.frontends.frontend import Frontend class DefaultFrontend(AbsFrontend): diff --git a/espnet2/asr/frontend/fused.py b/espnet2/asr/frontend/fused.py index 365de936fc7..bc08a4ed63f 100644 --- a/espnet2/asr/frontend/fused.py +++ b/espnet2/asr/frontend/fused.py @@ -1,10 +1,12 @@ -from espnet2.asr.frontend.abs_frontend import AbsFrontend -from espnet2.asr.frontend.default import DefaultFrontend -from espnet2.asr.frontend.s3prl import S3prlFrontend +from typing import Tuple + import numpy as np import torch from typeguard import check_argument_types -from typing import Tuple + +from espnet2.asr.frontend.abs_frontend import AbsFrontend +from espnet2.asr.frontend.default import DefaultFrontend +from espnet2.asr.frontend.s3prl import S3prlFrontend class FusedFrontends(AbsFrontend): diff --git a/espnet2/asr/frontend/s3prl.py b/espnet2/asr/frontend/s3prl.py index 6a497e0fab7..1f380ed0db4 100644 --- a/espnet2/asr/frontend/s3prl.py +++ b/espnet2/asr/frontend/s3prl.py @@ -1,19 +1,17 @@ -from argparse import Namespace import copy import logging import os -from typing import Optional -from typing import Tuple -from typing import Union +from argparse import Namespace +from typing import Optional, Tuple, Union import humanfriendly import torch from typeguard import check_argument_types -from espnet.nets.pytorch_backend.frontends.frontend import Frontend -from espnet.nets.pytorch_backend.nets_utils import pad_list from espnet2.asr.frontend.abs_frontend import AbsFrontend from espnet2.utils.get_default_kwargs import get_default_kwargs +from espnet.nets.pytorch_backend.frontends.frontend import Frontend +from espnet.nets.pytorch_backend.nets_utils import pad_list def base_s3prl_setup(args): diff --git a/espnet2/asr/frontend/windowing.py b/espnet2/asr/frontend/windowing.py index 55600ca30d8..200d33e9954 100644 --- a/espnet2/asr/frontend/windowing.py +++ b/espnet2/asr/frontend/windowing.py @@ -4,10 +4,12 @@ """Sliding Window for raw audio input data.""" -from espnet2.asr.frontend.abs_frontend import AbsFrontend +from typing import Tuple + import torch from typeguard import check_argument_types -from typing import Tuple + +from espnet2.asr.frontend.abs_frontend import AbsFrontend class SlidingWindow(AbsFrontend): diff --git a/espnet2/asr/maskctc_model.py b/espnet2/asr/maskctc_model.py index 2a95eec89ea..5e5139b14d1 100644 --- a/espnet2/asr/maskctc_model.py +++ b/espnet2/asr/maskctc_model.py @@ -1,24 +1,13 @@ +import logging from contextlib import contextmanager from itertools import groupby -import logging -from packaging.version import parse as V -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Dict, List, Optional, Tuple, Union import numpy import torch +from packaging.version import parse as V from typeguard import check_argument_types -from espnet.nets.beam_search import Hypothesis -from espnet.nets.e2e_asr_common import ErrorCalculator -from espnet.nets.pytorch_backend.maskctc.add_mask_token import 
mask_uniform -from espnet.nets.pytorch_backend.nets_utils import th_accuracy -from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( - LabelSmoothingLoss, # noqa: H301 -) from espnet2.asr.ctc import CTC from espnet2.asr.decoder.mlm_decoder import MLMDecoder from espnet2.asr.encoder.abs_encoder import AbsEncoder @@ -30,6 +19,12 @@ from espnet2.layers.abs_normalize import AbsNormalize from espnet2.text.token_id_converter import TokenIDConverter from espnet2.torch_utils.device_funcs import force_gatherable +from espnet.nets.beam_search import Hypothesis +from espnet.nets.e2e_asr_common import ErrorCalculator +from espnet.nets.pytorch_backend.maskctc.add_mask_token import mask_uniform +from espnet.nets.pytorch_backend.nets_utils import th_accuracy +from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import \ + LabelSmoothingLoss # noqa: H301 if V(torch.__version__) >= V("1.6.0"): from torch.cuda.amp import autocast diff --git a/espnet2/asr/postencoder/abs_postencoder.py b/espnet2/asr/postencoder/abs_postencoder.py index f5ac03be27b..cebfa3b7021 100644 --- a/espnet2/asr/postencoder/abs_postencoder.py +++ b/espnet2/asr/postencoder/abs_postencoder.py @@ -1,5 +1,4 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Tuple import torch diff --git a/espnet2/asr/postencoder/hugging_face_transformers_postencoder.py b/espnet2/asr/postencoder/hugging_face_transformers_postencoder.py index a8a8177f8fd..80a716593da 100644 --- a/espnet2/asr/postencoder/hugging_face_transformers_postencoder.py +++ b/espnet2/asr/postencoder/hugging_face_transformers_postencoder.py @@ -4,14 +4,15 @@ """Hugging Face Transformers PostEncoder.""" -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder -from typeguard import check_argument_types -from typing import Tuple - import copy import logging +from typing import Tuple + import torch +from typeguard import check_argument_types + +from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask try: from transformers import AutoModel diff --git a/espnet2/asr/preencoder/abs_preencoder.py b/espnet2/asr/preencoder/abs_preencoder.py index 3ecdc6b91f0..67777477e0b 100644 --- a/espnet2/asr/preencoder/abs_preencoder.py +++ b/espnet2/asr/preencoder/abs_preencoder.py @@ -1,5 +1,4 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Tuple import torch diff --git a/espnet2/asr/preencoder/linear.py b/espnet2/asr/preencoder/linear.py index 9c7cc497fca..82dea1ad6dd 100644 --- a/espnet2/asr/preencoder/linear.py +++ b/espnet2/asr/preencoder/linear.py @@ -4,11 +4,12 @@ """Linear Projection.""" -from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder -from typeguard import check_argument_types from typing import Tuple import torch +from typeguard import check_argument_types + +from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder class LinearProjection(AbsPreEncoder): diff --git a/espnet2/asr/preencoder/sinc.py b/espnet2/asr/preencoder/sinc.py index 9a9dfa6e4c0..1cf86def402 100644 --- a/espnet2/asr/preencoder/sinc.py +++ b/espnet2/asr/preencoder/sinc.py @@ -5,15 +5,14 @@ """Sinc convolutions for raw audio input.""" from collections import OrderedDict -from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder -from espnet2.layers.sinc_conv import LogCompression -from 
espnet2.layers.sinc_conv import SincConv +from typing import Optional, Tuple, Union + import humanfriendly import torch from typeguard import check_argument_types -from typing import Optional -from typing import Tuple -from typing import Union + +from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder +from espnet2.layers.sinc_conv import LogCompression, SincConv class LightweightSincConvs(AbsPreEncoder): diff --git a/espnet2/asr/specaug/abs_specaug.py b/espnet2/asr/specaug/abs_specaug.py index 3cbac418fb6..6c9c6d8ea18 100644 --- a/espnet2/asr/specaug/abs_specaug.py +++ b/espnet2/asr/specaug/abs_specaug.py @@ -1,5 +1,4 @@ -from typing import Optional -from typing import Tuple +from typing import Optional, Tuple import torch diff --git a/espnet2/asr/specaug/specaug.py b/espnet2/asr/specaug/specaug.py index 65ed221f220..2077275921d 100644 --- a/espnet2/asr/specaug/specaug.py +++ b/espnet2/asr/specaug/specaug.py @@ -1,11 +1,9 @@ """SpecAugment module.""" -from typing import Optional -from typing import Sequence -from typing import Union +from typing import Optional, Sequence, Union from espnet2.asr.specaug.abs_specaug import AbsSpecAug -from espnet2.layers.mask_along_axis import MaskAlongAxis -from espnet2.layers.mask_along_axis import MaskAlongAxisVariableMaxWidth +from espnet2.layers.mask_along_axis import (MaskAlongAxis, + MaskAlongAxisVariableMaxWidth) from espnet2.layers.time_warp import TimeWarp diff --git a/espnet2/asr/transducer/beam_search_transducer.py b/espnet2/asr/transducer/beam_search_transducer.py index 211b15a0101..4bcd2904b7a 100644 --- a/espnet2/asr/transducer/beam_search_transducer.py +++ b/espnet2/asr/transducer/beam_search_transducer.py @@ -1,23 +1,17 @@ """Search algorithms for Transducer models.""" from dataclasses import dataclass -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch -from espnet.nets.pytorch_backend.transducer.utils import is_prefix -from espnet.nets.pytorch_backend.transducer.utils import recombine_hyps -from espnet.nets.pytorch_backend.transducer.utils import select_k_expansions -from espnet.nets.pytorch_backend.transducer.utils import subtract - from espnet2.asr.decoder.abs_decoder import AbsDecoder from espnet2.asr.transducer.joint_network import JointNetwork +from espnet.nets.pytorch_backend.transducer.utils import (is_prefix, + recombine_hyps, + select_k_expansions, + subtract) @dataclass diff --git a/espnet2/asr/transducer/error_calculator.py b/espnet2/asr/transducer/error_calculator.py index 5c624825a4f..4ddf9cc9b7b 100644 --- a/espnet2/asr/transducer/error_calculator.py +++ b/espnet2/asr/transducer/error_calculator.py @@ -1,7 +1,6 @@ """Error Calculator module for Transducer.""" -from typing import List -from typing import Tuple +from typing import List, Tuple import torch diff --git a/espnet2/asr/transducer/transducer_decoder.py b/espnet2/asr/transducer/transducer_decoder.py index 8543cb22752..1c032bc9983 100644 --- a/espnet2/asr/transducer/transducer_decoder.py +++ b/espnet2/asr/transducer/transducer_decoder.py @@ -1,18 +1,13 @@ """(RNN-)Transducer decoder definition.""" -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from typeguard import check_argument_types from 
espnet2.asr.decoder.abs_decoder import AbsDecoder -from espnet2.asr.transducer.beam_search_transducer import ExtendedHypothesis -from espnet2.asr.transducer.beam_search_transducer import Hypothesis +from espnet2.asr.transducer.beam_search_transducer import (ExtendedHypothesis, + Hypothesis) class TransducerDecoder(AbsDecoder): diff --git a/espnet2/bin/aggregate_stats_dirs.py b/espnet2/bin/aggregate_stats_dirs.py index b79e67c399d..7579b513402 100755 --- a/espnet2/bin/aggregate_stats_dirs.py +++ b/espnet2/bin/aggregate_stats_dirs.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 import argparse import logging -from pathlib import Path import sys -from typing import Iterable -from typing import Union +from pathlib import Path +from typing import Iterable, Union import numpy as np diff --git a/espnet2/bin/asr_align.py b/espnet2/bin/asr_align.py index a9f8823ca57..1bd00ae7ddc 100755 --- a/espnet2/bin/asr_align.py +++ b/espnet2/bin/asr_align.py @@ -5,33 +5,25 @@ import argparse import logging -from pathlib import Path import sys -from typing import Optional -from typing import TextIO -from typing import Union +from pathlib import Path +from typing import List, Optional, TextIO, Union import numpy as np import soundfile import torch -from typeguard import check_argument_types -from typeguard import check_return_type -from typing import List +# imports for CTC segmentation +from ctc_segmentation import (CtcSegmentationParameters, ctc_segmentation, + determine_utterance_segments, prepare_text, + prepare_token_list) +from typeguard import check_argument_types, check_return_type -# imports for inference -from espnet.utils.cli_utils import get_commandline_args from espnet2.tasks.asr import ASRTask from espnet2.torch_utils.device_funcs import to_device from espnet2.utils import config_argparse -from espnet2.utils.types import str2bool -from espnet2.utils.types import str_or_none - -# imports for CTC segmentation -from ctc_segmentation import ctc_segmentation -from ctc_segmentation import CtcSegmentationParameters -from ctc_segmentation import determine_utterance_segments -from ctc_segmentation import prepare_text -from ctc_segmentation import prepare_token_list +from espnet2.utils.types import str2bool, str_or_none +# imports for inference +from espnet.utils.cli_utils import get_commandline_args class CTCSegmentationTask: diff --git a/espnet2/bin/asr_inference.py b/espnet2/bin/asr_inference.py index cbe9e7b3ac5..f025ffc80b4 100755 --- a/espnet2/bin/asr_inference.py +++ b/espnet2/bin/asr_inference.py @@ -1,36 +1,21 @@ #!/usr/bin/env python3 import argparse -from distutils.version import LooseVersion import logging -from pathlib import Path import sys -from typing import Any -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union +from distutils.version import LooseVersion +from pathlib import Path +from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np import torch import torch.quantization -from typeguard import check_argument_types -from typeguard import check_return_type -from typing import List +from typeguard import check_argument_types, check_return_type -from espnet.nets.batch_beam_search import BatchBeamSearch -from espnet.nets.batch_beam_search_online_sim import BatchBeamSearchOnlineSim -from espnet.nets.beam_search import BeamSearch -from espnet.nets.beam_search import Hypothesis -from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError -from espnet.nets.scorer_interface import 
BatchScorerInterface -from espnet.nets.scorers.ctc import CTCPrefixScorer -from espnet.nets.scorers.length_bonus import LengthBonus -from espnet.utils.cli_utils import get_commandline_args from espnet2.asr.transducer.beam_search_transducer import BeamSearchTransducer -from espnet2.asr.transducer.beam_search_transducer import ( - ExtendedHypothesis as ExtTransHypothesis, # noqa: H301 -) -from espnet2.asr.transducer.beam_search_transducer import Hypothesis as TransHypothesis +from espnet2.asr.transducer.beam_search_transducer import \ + ExtendedHypothesis as ExtTransHypothesis # noqa: H301 +from espnet2.asr.transducer.beam_search_transducer import \ + Hypothesis as TransHypothesis from espnet2.fileio.datadir_writer import DatadirWriter from espnet2.tasks.asr import ASRTask from espnet2.tasks.enh_s2t import EnhS2TTask @@ -40,9 +25,16 @@ from espnet2.torch_utils.device_funcs import to_device from espnet2.torch_utils.set_all_random_seed import set_all_random_seed from espnet2.utils import config_argparse -from espnet2.utils.types import str2bool -from espnet2.utils.types import str2triple_str -from espnet2.utils.types import str_or_none +from espnet2.utils.types import str2bool, str2triple_str, str_or_none +from espnet.nets.batch_beam_search import BatchBeamSearch +from espnet.nets.batch_beam_search_online_sim import BatchBeamSearchOnlineSim +from espnet.nets.beam_search import BeamSearch, Hypothesis +from espnet.nets.pytorch_backend.transformer.subsampling import \ + TooShortUttError +from espnet.nets.scorer_interface import BatchScorerInterface +from espnet.nets.scorers.ctc import CTCPrefixScorer +from espnet.nets.scorers.length_bonus import LengthBonus +from espnet.utils.cli_utils import get_commandline_args class Speech2Text: diff --git a/espnet2/bin/asr_inference_k2.py b/espnet2/bin/asr_inference_k2.py index 81b206fc978..830721197b3 100755 --- a/espnet2/bin/asr_inference_k2.py +++ b/espnet2/bin/asr_inference_k2.py @@ -2,24 +2,16 @@ import argparse import datetime import logging -from pathlib import Path import sys -from typing import Any -from typing import Dict -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import k2 import numpy as np import torch -from typeguard import check_argument_types -from typeguard import check_return_type -from typing import List import yaml +from typeguard import check_argument_types, check_return_type -from espnet.utils.cli_utils import get_commandline_args from espnet2.fileio.datadir_writer import DatadirWriter from espnet2.fst.lm_rescore import nbest_am_lm_scores from espnet2.tasks.asr import ASRTask @@ -29,9 +21,8 @@ from espnet2.torch_utils.device_funcs import to_device from espnet2.torch_utils.set_all_random_seed import set_all_random_seed from espnet2.utils import config_argparse -from espnet2.utils.types import str2bool -from espnet2.utils.types import str2triple_str -from espnet2.utils.types import str_or_none +from espnet2.utils.types import str2bool, str2triple_str, str_or_none +from espnet.utils.cli_utils import get_commandline_args def indices_to_split_size(indices, total_elements: int = None): diff --git a/espnet2/bin/asr_inference_maskctc.py b/espnet2/bin/asr_inference_maskctc.py index 20b857482f1..8bbd43d3a6b 100644 --- a/espnet2/bin/asr_inference_maskctc.py +++ b/espnet2/bin/asr_inference_maskctc.py @@ -1,23 +1,14 @@ #!/usr/bin/env python3 import argparse import logging -from pathlib 
import Path import sys -from typing import Any -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union +from pathlib import Path +from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np import torch -from typeguard import check_argument_types -from typeguard import check_return_type -from typing import List +from typeguard import check_argument_types, check_return_type -from espnet.nets.beam_search import Hypothesis -from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError -from espnet.utils.cli_utils import get_commandline_args from espnet2.asr.maskctc_model import MaskCTCInference from espnet2.fileio.datadir_writer import DatadirWriter from espnet2.tasks.asr import ASRTask @@ -26,9 +17,11 @@ from espnet2.torch_utils.device_funcs import to_device from espnet2.torch_utils.set_all_random_seed import set_all_random_seed from espnet2.utils import config_argparse -from espnet2.utils.types import str2bool -from espnet2.utils.types import str2triple_str -from espnet2.utils.types import str_or_none +from espnet2.utils.types import str2bool, str2triple_str, str_or_none +from espnet.nets.beam_search import Hypothesis +from espnet.nets.pytorch_backend.transformer.subsampling import \ + TooShortUttError +from espnet.utils.cli_utils import get_commandline_args class Speech2Text: diff --git a/espnet2/bin/asr_inference_streaming.py b/espnet2/bin/asr_inference_streaming.py index 934c0bff276..098fb6078bd 100755 --- a/espnet2/bin/asr_inference_streaming.py +++ b/espnet2/bin/asr_inference_streaming.py @@ -1,18 +1,19 @@ #!/usr/bin/env python3 import argparse -from espnet.nets.batch_beam_search_online import BatchBeamSearchOnline -from espnet.nets.beam_search import Hypothesis -from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError -from espnet.nets.scorer_interface import BatchScorerInterface -from espnet.nets.scorers.ctc import CTCPrefixScorer -from espnet.nets.scorers.length_bonus import LengthBonus -from espnet.utils.cli_utils import get_commandline_args -from espnet2.asr.encoder.contextual_block_transformer_encoder import ( - ContextualBlockTransformerEncoder, # noqa: H301 -) -from espnet2.asr.encoder.contextual_block_conformer_encoder import ( - ContextualBlockConformerEncoder, # noqa: H301 -) +import logging +import math +import sys +from pathlib import Path +from typing import List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from typeguard import check_argument_types, check_return_type + +from espnet2.asr.encoder.contextual_block_conformer_encoder import \ + ContextualBlockConformerEncoder # noqa: H301 +from espnet2.asr.encoder.contextual_block_transformer_encoder import \ + ContextualBlockTransformerEncoder # noqa: H301 from espnet2.fileio.datadir_writer import DatadirWriter from espnet2.tasks.asr import ASRTask from espnet2.tasks.lm import LMTask @@ -21,22 +22,15 @@ from espnet2.torch_utils.device_funcs import to_device from espnet2.torch_utils.set_all_random_seed import set_all_random_seed from espnet2.utils import config_argparse -from espnet2.utils.types import str2bool -from espnet2.utils.types import str2triple_str -from espnet2.utils.types import str_or_none -import logging -import math -import numpy as np -from pathlib import Path -import sys -import torch -from typeguard import check_argument_types -from typeguard import check_return_type -from typing import List -from typing import Optional -from typing import Sequence -from typing 
import Tuple -from typing import Union +from espnet2.utils.types import str2bool, str2triple_str, str_or_none +from espnet.nets.batch_beam_search_online import BatchBeamSearchOnline +from espnet.nets.beam_search import Hypothesis +from espnet.nets.pytorch_backend.transformer.subsampling import \ + TooShortUttError +from espnet.nets.scorer_interface import BatchScorerInterface +from espnet.nets.scorers.ctc import CTCPrefixScorer +from espnet.nets.scorers.length_bonus import LengthBonus +from espnet.utils.cli_utils import get_commandline_args class Speech2TextStreaming: diff --git a/espnet2/bin/diar_inference.py b/espnet2/bin/diar_inference.py index df44afed1f4..a794689c25f 100755 --- a/espnet2/bin/diar_inference.py +++ b/espnet2/bin/diar_inference.py @@ -2,31 +2,23 @@ import argparse import logging -from pathlib import Path import sys -from typing import Any -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union +from pathlib import Path +from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np import torch from tqdm import trange from typeguard import check_argument_types -from espnet.utils.cli_utils import get_commandline_args from espnet2.fileio.npy_scp import NpyScpWriter from espnet2.tasks.diar import DiarizationTask from espnet2.torch_utils.device_funcs import to_device from espnet2.torch_utils.set_all_random_seed import set_all_random_seed from espnet2.utils import config_argparse -from espnet2.utils.types import humanfriendly_parse_size_or_none -from espnet2.utils.types import int_or_none -from espnet2.utils.types import str2bool -from espnet2.utils.types import str2triple_str -from espnet2.utils.types import str_or_none +from espnet2.utils.types import (humanfriendly_parse_size_or_none, int_or_none, + str2bool, str2triple_str, str_or_none) +from espnet.utils.cli_utils import get_commandline_args class DiarizeSpeech: diff --git a/espnet2/bin/enh_inference.py b/espnet2/bin/enh_inference.py index 2deed3250c5..ddf7f73c1f3 100755 --- a/espnet2/bin/enh_inference.py +++ b/espnet2/bin/enh_inference.py @@ -1,24 +1,18 @@ #!/usr/bin/env python3 import argparse -from itertools import chain import logging -from pathlib import Path import sys -from typing import Any -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union +from itertools import chain +from pathlib import Path +from typing import Any, List, Optional, Sequence, Tuple, Union import humanfriendly import numpy as np import torch +import yaml from tqdm import trange from typeguard import check_argument_types -import yaml -from espnet.utils.cli_utils import get_commandline_args from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainMSE from espnet2.enh.loss.criterions.time_domain import SISNRLoss from espnet2.enh.loss.wrappers.pit_solver import PITSolver @@ -29,10 +23,8 @@ from espnet2.torch_utils.set_all_random_seed import set_all_random_seed from espnet2.train.abs_espnet_model import AbsESPnetModel from espnet2.utils import config_argparse -from espnet2.utils.types import str2bool -from espnet2.utils.types import str2triple_str -from espnet2.utils.types import str_or_none - +from espnet2.utils.types import str2bool, str2triple_str, str_or_none +from espnet.utils.cli_utils import get_commandline_args EPS = torch.finfo(torch.get_default_dtype()).eps diff --git a/espnet2/bin/enh_scoring.py b/espnet2/bin/enh_scoring.py index 
1c42fbf1f6d..5d76f6be42e 100755 --- a/espnet2/bin/enh_scoring.py +++ b/espnet2/bin/enh_scoring.py @@ -2,21 +2,19 @@ import argparse import logging import sys -from typing import List -from typing import Union +from typing import List, Union -from mir_eval.separation import bss_eval_sources import numpy as np -from pystoi import stoi import torch +from mir_eval.separation import bss_eval_sources +from pystoi import stoi from typeguard import check_argument_types -from espnet.utils.cli_utils import get_commandline_args from espnet2.enh.loss.criterions.time_domain import SISNRLoss from espnet2.fileio.datadir_writer import DatadirWriter from espnet2.fileio.sound_scp import SoundScpReader from espnet2.utils import config_argparse - +from espnet.utils.cli_utils import get_commandline_args si_snr_loss = SISNRLoss() diff --git a/espnet2/bin/launch.py b/espnet2/bin/launch.py index 57290c3262d..23acc9c55be 100755 --- a/espnet2/bin/launch.py +++ b/espnet2/bin/launch.py @@ -2,16 +2,15 @@ import argparse import logging import os -from pathlib import Path import shlex import shutil import subprocess import sys import uuid +from pathlib import Path +from espnet2.utils.types import str2bool, str_or_none from espnet.utils.cli_utils import get_commandline_args -from espnet2.utils.types import str2bool -from espnet2.utils.types import str_or_none def get_parser(): diff --git a/espnet2/bin/lm_calc_perplexity.py b/espnet2/bin/lm_calc_perplexity.py index 97ba229afe3..9a878b50cd0 100755 --- a/espnet2/bin/lm_calc_perplexity.py +++ b/espnet2/bin/lm_calc_perplexity.py @@ -1,29 +1,24 @@ #!/usr/bin/env python3 import argparse import logging -from pathlib import Path import sys -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union +from pathlib import Path +from typing import Optional, Sequence, Tuple, Union import numpy as np import torch from torch.nn.parallel import data_parallel from typeguard import check_argument_types -from espnet.utils.cli_utils import get_commandline_args from espnet2.fileio.datadir_writer import DatadirWriter from espnet2.tasks.lm import LMTask from espnet2.torch_utils.device_funcs import to_device from espnet2.torch_utils.forward_adaptor import ForwardAdaptor from espnet2.torch_utils.set_all_random_seed import set_all_random_seed from espnet2.utils import config_argparse -from espnet2.utils.types import float_or_none -from espnet2.utils.types import str2bool -from espnet2.utils.types import str2triple_str -from espnet2.utils.types import str_or_none +from espnet2.utils.types import (float_or_none, str2bool, str2triple_str, + str_or_none) +from espnet.utils.cli_utils import get_commandline_args def calc_perplexity( diff --git a/espnet2/bin/mt_inference.py b/espnet2/bin/mt_inference.py index e523e1e6d47..c20cde3b03c 100755 --- a/espnet2/bin/mt_inference.py +++ b/espnet2/bin/mt_inference.py @@ -1,27 +1,14 @@ #!/usr/bin/env python3 import argparse import logging -from pathlib import Path import sys -from typing import Any -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union +from pathlib import Path +from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np import torch -from typeguard import check_argument_types -from typeguard import check_return_type -from typing import List +from typeguard import check_argument_types, check_return_type -from espnet.nets.batch_beam_search import BatchBeamSearch -from espnet.nets.beam_search import BeamSearch -from 
espnet.nets.beam_search import Hypothesis -from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError -from espnet.nets.scorer_interface import BatchScorerInterface -from espnet.nets.scorers.length_bonus import LengthBonus -from espnet.utils.cli_utils import get_commandline_args from espnet2.fileio.datadir_writer import DatadirWriter from espnet2.tasks.lm import LMTask from espnet2.tasks.mt import MTTask @@ -30,9 +17,14 @@ from espnet2.torch_utils.device_funcs import to_device from espnet2.torch_utils.set_all_random_seed import set_all_random_seed from espnet2.utils import config_argparse -from espnet2.utils.types import str2bool -from espnet2.utils.types import str2triple_str -from espnet2.utils.types import str_or_none +from espnet2.utils.types import str2bool, str2triple_str, str_or_none +from espnet.nets.batch_beam_search import BatchBeamSearch +from espnet.nets.beam_search import BeamSearch, Hypothesis +from espnet.nets.pytorch_backend.transformer.subsampling import \ + TooShortUttError +from espnet.nets.scorer_interface import BatchScorerInterface +from espnet.nets.scorers.length_bonus import LengthBonus +from espnet.utils.cli_utils import get_commandline_args class Text2Text: diff --git a/espnet2/bin/split_scps.py b/espnet2/bin/split_scps.py index 557c70bac2c..ff4f15c3d23 100755 --- a/espnet2/bin/split_scps.py +++ b/espnet2/bin/split_scps.py @@ -1,12 +1,11 @@ #!/usr/bin/env python3 import argparse +import logging +import sys from collections import Counter from itertools import zip_longest -import logging from pathlib import Path -import sys -from typing import List -from typing import Optional +from typing import List, Optional from espnet.utils.cli_utils import get_commandline_args diff --git a/espnet2/bin/st_inference.py b/espnet2/bin/st_inference.py index 4cf9bc4d1a6..25155c3400c 100755 --- a/espnet2/bin/st_inference.py +++ b/espnet2/bin/st_inference.py @@ -1,27 +1,14 @@ #!/usr/bin/env python3 import argparse import logging -from pathlib import Path import sys -from typing import Any -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union +from pathlib import Path +from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np import torch -from typeguard import check_argument_types -from typeguard import check_return_type -from typing import List +from typeguard import check_argument_types, check_return_type -from espnet.nets.batch_beam_search import BatchBeamSearch -from espnet.nets.beam_search import BeamSearch -from espnet.nets.beam_search import Hypothesis -from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError -from espnet.nets.scorer_interface import BatchScorerInterface -from espnet.nets.scorers.length_bonus import LengthBonus -from espnet.utils.cli_utils import get_commandline_args from espnet2.fileio.datadir_writer import DatadirWriter from espnet2.tasks.enh_s2t import EnhS2TTask from espnet2.tasks.lm import LMTask @@ -31,9 +18,14 @@ from espnet2.torch_utils.device_funcs import to_device from espnet2.torch_utils.set_all_random_seed import set_all_random_seed from espnet2.utils import config_argparse -from espnet2.utils.types import str2bool -from espnet2.utils.types import str2triple_str -from espnet2.utils.types import str_or_none +from espnet2.utils.types import str2bool, str2triple_str, str_or_none +from espnet.nets.batch_beam_search import BatchBeamSearch +from espnet.nets.beam_search import BeamSearch, Hypothesis +from 
espnet.nets.pytorch_backend.transformer.subsampling import \ + TooShortUttError +from espnet.nets.scorer_interface import BatchScorerInterface +from espnet.nets.scorers.length_bonus import LengthBonus +from espnet.utils.cli_utils import get_commandline_args class Speech2Text: diff --git a/espnet2/bin/st_inference_streaming.py b/espnet2/bin/st_inference_streaming.py index 8be428f2441..e3b9a4fb8bd 100644 --- a/espnet2/bin/st_inference_streaming.py +++ b/espnet2/bin/st_inference_streaming.py @@ -1,17 +1,19 @@ #!/usr/bin/env python3 import argparse -from espnet.nets.batch_beam_search_online import BatchBeamSearchOnline -from espnet.nets.beam_search import Hypothesis -from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError -from espnet.nets.scorer_interface import BatchScorerInterface -from espnet.nets.scorers.length_bonus import LengthBonus -from espnet.utils.cli_utils import get_commandline_args -from espnet2.asr.encoder.contextual_block_transformer_encoder import ( - ContextualBlockTransformerEncoder, # noqa: H301 -) -from espnet2.asr.encoder.contextual_block_conformer_encoder import ( - ContextualBlockConformerEncoder, # noqa: H301 -) +import logging +import math +import sys +from pathlib import Path +from typing import List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from typeguard import check_argument_types, check_return_type + +from espnet2.asr.encoder.contextual_block_conformer_encoder import \ + ContextualBlockConformerEncoder # noqa: H301 +from espnet2.asr.encoder.contextual_block_transformer_encoder import \ + ContextualBlockTransformerEncoder # noqa: H301 from espnet2.fileio.datadir_writer import DatadirWriter from espnet2.tasks.lm import LMTask from espnet2.tasks.st import STTask @@ -20,22 +22,14 @@ from espnet2.torch_utils.device_funcs import to_device from espnet2.torch_utils.set_all_random_seed import set_all_random_seed from espnet2.utils import config_argparse -from espnet2.utils.types import str2bool -from espnet2.utils.types import str2triple_str -from espnet2.utils.types import str_or_none -import logging -import math -import numpy as np -from pathlib import Path -import sys -import torch -from typeguard import check_argument_types -from typeguard import check_return_type -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union +from espnet2.utils.types import str2bool, str2triple_str, str_or_none +from espnet.nets.batch_beam_search_online import BatchBeamSearchOnline +from espnet.nets.beam_search import Hypothesis +from espnet.nets.pytorch_backend.transformer.subsampling import \ + TooShortUttError +from espnet.nets.scorer_interface import BatchScorerInterface +from espnet.nets.scorers.length_bonus import LengthBonus +from espnet.utils.cli_utils import get_commandline_args class Speech2TextStreaming: diff --git a/espnet2/bin/tokenize_text.py b/espnet2/bin/tokenize_text.py index f22068c8846..890a112aaf2 100755 --- a/espnet2/bin/tokenize_text.py +++ b/espnet2/bin/tokenize_text.py @@ -1,20 +1,18 @@ #!/usr/bin/env python3 import argparse -from collections import Counter import logging -from pathlib import Path import sys -from typing import List -from typing import Optional +from collections import Counter +from pathlib import Path +from typing import List, Optional from typeguard import check_argument_types -from espnet.utils.cli_utils import get_commandline_args from espnet2.text.build_tokenizer import build_tokenizer from espnet2.text.cleaner 
import TextCleaner from espnet2.text.phoneme_tokenizer import g2p_choices -from espnet2.utils.types import str2bool -from espnet2.utils.types import str_or_none +from espnet2.utils.types import str2bool, str_or_none +from espnet.utils.cli_utils import get_commandline_args def field2slice(field: Optional[str]) -> slice: diff --git a/espnet2/bin/tts_inference.py b/espnet2/bin/tts_inference.py index 6e3da15f0de..500152cdfa5 100755 --- a/espnet2/bin/tts_inference.py +++ b/espnet2/bin/tts_inference.py @@ -7,23 +7,15 @@ import shutil import sys import time - -from packaging.version import parse as V from pathlib import Path -from typing import Any -from typing import Dict -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union +from typing import Any, Dict, Optional, Sequence, Tuple, Union import numpy as np import soundfile as sf import torch - +from packaging.version import parse as V from typeguard import check_argument_types -from espnet.utils.cli_utils import get_commandline_args from espnet2.fileio.npy_scp import NpyScpWriter from espnet2.gan_tts.vits import VITS from espnet2.tasks.tts import TTSTask @@ -35,9 +27,8 @@ from espnet2.tts.transformer import Transformer from espnet2.tts.utils import DurationCalculator from espnet2.utils import config_argparse -from espnet2.utils.types import str2bool -from espnet2.utils.types import str2triple_str -from espnet2.utils.types import str_or_none +from espnet2.utils.types import str2bool, str2triple_str, str_or_none +from espnet.utils.cli_utils import get_commandline_args class Text2Speech: @@ -288,7 +279,8 @@ def from_pretrained( if vocoder_tag is not None: if vocoder_tag.startswith("parallel_wavegan/"): try: - from parallel_wavegan.utils import download_pretrained_model + from parallel_wavegan.utils import \ + download_pretrained_model except ImportError: logging.error( diff --git a/espnet2/diar/abs_diar.py b/espnet2/diar/abs_diar.py index 9cb2f2b2cc2..e9ca1ec419e 100644 --- a/espnet2/diar/abs_diar.py +++ b/espnet2/diar/abs_diar.py @@ -1,5 +1,4 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from collections import OrderedDict from typing import Tuple diff --git a/espnet2/diar/attractor/abs_attractor.py b/espnet2/diar/attractor/abs_attractor.py index 914fdb62ea2..ca07033f575 100644 --- a/espnet2/diar/attractor/abs_attractor.py +++ b/espnet2/diar/attractor/abs_attractor.py @@ -1,5 +1,4 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Tuple import torch diff --git a/espnet2/diar/decoder/abs_decoder.py b/espnet2/diar/decoder/abs_decoder.py index bd9a1674144..1fe7bdede1e 100644 --- a/espnet2/diar/decoder/abs_decoder.py +++ b/espnet2/diar/decoder/abs_decoder.py @@ -1,5 +1,4 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Tuple import torch diff --git a/espnet2/diar/espnet_model.py b/espnet2/diar/espnet_model.py index 2017316f70f..d7986ef64fe 100644 --- a/espnet2/diar/espnet_model.py +++ b/espnet2/diar/espnet_model.py @@ -3,16 +3,13 @@ from contextlib import contextmanager from itertools import permutations -from packaging.version import parse as V -from typing import Dict -from typing import Optional -from typing import Tuple +from typing import Dict, Optional, Tuple import numpy as np import torch +from packaging.version import parse as V from typeguard import check_argument_types -from espnet.nets.pytorch_backend.nets_utils 
import to_device from espnet2.asr.encoder.abs_encoder import AbsEncoder from espnet2.asr.frontend.abs_frontend import AbsFrontend from espnet2.asr.specaug.abs_specaug import AbsSpecAug @@ -21,6 +18,7 @@ from espnet2.layers.abs_normalize import AbsNormalize from espnet2.torch_utils.device_funcs import force_gatherable from espnet2.train.abs_espnet_model import AbsESPnetModel +from espnet.nets.pytorch_backend.nets_utils import to_device if V(torch.__version__) >= V("1.6.0"): from torch.cuda.amp import autocast diff --git a/espnet2/enh/abs_enh.py b/espnet2/enh/abs_enh.py index c28745e26d1..7cfd1d89442 100644 --- a/espnet2/enh/abs_enh.py +++ b/espnet2/enh/abs_enh.py @@ -1,5 +1,4 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from collections import OrderedDict from typing import Tuple diff --git a/espnet2/enh/decoder/abs_decoder.py b/espnet2/enh/decoder/abs_decoder.py index 1ab8cb6a557..c01c465c730 100644 --- a/espnet2/enh/decoder/abs_decoder.py +++ b/espnet2/enh/decoder/abs_decoder.py @@ -1,5 +1,4 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Tuple import torch diff --git a/espnet2/enh/decoder/stft_decoder.py b/espnet2/enh/decoder/stft_decoder.py index 93768dd2484..82209244474 100644 --- a/espnet2/enh/decoder/stft_decoder.py +++ b/espnet2/enh/decoder/stft_decoder.py @@ -1,5 +1,5 @@ -from packaging.version import parse as V import torch +from packaging.version import parse as V from torch_complex.tensor import ComplexTensor from espnet2.enh.decoder.abs_decoder import AbsDecoder diff --git a/espnet2/enh/encoder/abs_encoder.py b/espnet2/enh/encoder/abs_encoder.py index ef1afb68213..5e12b053de7 100644 --- a/espnet2/enh/encoder/abs_encoder.py +++ b/espnet2/enh/encoder/abs_encoder.py @@ -1,5 +1,4 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Tuple import torch diff --git a/espnet2/enh/encoder/stft_encoder.py b/espnet2/enh/encoder/stft_encoder.py index 2c1f68934d5..0f126036b0d 100644 --- a/espnet2/enh/encoder/stft_encoder.py +++ b/espnet2/enh/encoder/stft_encoder.py @@ -1,5 +1,5 @@ -from packaging.version import parse as V import torch +from packaging.version import parse as V from torch_complex.tensor import ComplexTensor from espnet2.enh.encoder.abs_encoder import AbsEncoder diff --git a/espnet2/enh/espnet_enh_s2t_model.py b/espnet2/enh/espnet_enh_s2t_model.py index 4d37ce0b0c0..12888d95759 100644 --- a/espnet2/enh/espnet_enh_s2t_model.py +++ b/espnet2/enh/espnet_enh_s2t_model.py @@ -1,13 +1,10 @@ -from contextlib import contextmanager import logging -from packaging.version import parse as V import random -from typing import Dict -from typing import List -from typing import Tuple -from typing import Union +from contextlib import contextmanager +from typing import Dict, List, Tuple, Union import torch +from packaging.version import parse as V from typeguard import check_argument_types from espnet2.asr.espnet_model import ESPnetASRModel diff --git a/espnet2/enh/espnet_model.py b/espnet2/enh/espnet_model.py index 06d9f72902e..cbb3034032e 100644 --- a/espnet2/enh/espnet_model.py +++ b/espnet2/enh/espnet_model.py @@ -1,12 +1,8 @@ """Enhancement model module.""" -from packaging.version import parse as V -from typing import Dict -from typing import List -from typing import Optional -from typing import OrderedDict -from typing import Tuple +from typing import Dict, List, Optional, OrderedDict, Tuple import torch +from 
packaging.version import parse as V from typeguard import check_argument_types from espnet2.enh.decoder.abs_decoder import AbsDecoder @@ -19,7 +15,6 @@ from espnet2.torch_utils.device_funcs import force_gatherable from espnet2.train.abs_espnet_model import AbsESPnetModel - is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") EPS = torch.finfo(torch.get_default_dtype()).eps diff --git a/espnet2/enh/layers/beamformer.py b/espnet2/enh/layers/beamformer.py index 2ceeee6c728..6234f9f98c7 100644 --- a/espnet2/enh/layers/beamformer.py +++ b/espnet2/enh/layers/beamformer.py @@ -1,24 +1,15 @@ """Beamformer module.""" -from packaging.version import parse as V -from typing import List -from typing import Optional -from typing import Union +from typing import List, Optional, Union import torch +from packaging.version import parse as V from torch_complex import functional as FC from torch_complex.tensor import ComplexTensor -from espnet2.enh.layers.complex_utils import cat -from espnet2.enh.layers.complex_utils import complex_norm -from espnet2.enh.layers.complex_utils import einsum -from espnet2.enh.layers.complex_utils import inverse -from espnet2.enh.layers.complex_utils import is_complex -from espnet2.enh.layers.complex_utils import is_torch_complex_tensor -from espnet2.enh.layers.complex_utils import matmul -from espnet2.enh.layers.complex_utils import reverse -from espnet2.enh.layers.complex_utils import solve -from espnet2.enh.layers.complex_utils import to_double - +from espnet2.enh.layers.complex_utils import (cat, complex_norm, einsum, + inverse, is_complex, + is_torch_complex_tensor, matmul, + reverse, solve, to_double) is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") EPS = torch.finfo(torch.double).eps diff --git a/espnet2/enh/layers/complex_utils.py b/espnet2/enh/layers/complex_utils.py index 329eee35d7c..8fc407a51cb 100644 --- a/espnet2/enh/layers/complex_utils.py +++ b/espnet2/enh/layers/complex_utils.py @@ -1,14 +1,11 @@ """Beamformer module.""" -from packaging.version import parse as V -from typing import Sequence -from typing import Tuple -from typing import Union +from typing import Sequence, Tuple, Union import torch +from packaging.version import parse as V from torch_complex import functional as FC from torch_complex.tensor import ComplexTensor - EPS = torch.finfo(torch.double).eps is_torch_1_8_plus = V(torch.__version__) >= V("1.8.0") is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") diff --git a/espnet2/enh/layers/dc_crn.py b/espnet2/enh/layers/dc_crn.py index ba781a4cd45..c4d68715645 100644 --- a/espnet2/enh/layers/dc_crn.py +++ b/espnet2/enh/layers/dc_crn.py @@ -9,8 +9,8 @@ import torch import torch.nn as nn -from espnet2.enh.layers.conv_utils import conv2d_output_shape -from espnet2.enh.layers.conv_utils import convtransp2d_output_shape +from espnet2.enh.layers.conv_utils import (conv2d_output_shape, + convtransp2d_output_shape) class GLSTM(nn.Module): diff --git a/espnet2/enh/layers/dnn_beamformer.py b/espnet2/enh/layers/dnn_beamformer.py index be4c3622e40..c48083dd748 100644 --- a/espnet2/enh/layers/dnn_beamformer.py +++ b/espnet2/enh/layers/dnn_beamformer.py @@ -1,35 +1,28 @@ """DNN beamformer module.""" -from packaging.version import parse as V -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union - import logging +from typing import List, Optional, Tuple, Union + import torch +from packaging.version import parse as V from torch.nn import functional as F from torch_complex.tensor import ComplexTensor -from 
espnet2.enh.layers.beamformer import apply_beamforming_vector -from espnet2.enh.layers.beamformer import blind_analytic_normalization -from espnet2.enh.layers.beamformer import get_gev_vector -from espnet2.enh.layers.beamformer import get_lcmv_vector_with_rtf -from espnet2.enh.layers.beamformer import get_mvdr_vector -from espnet2.enh.layers.beamformer import get_mvdr_vector_with_rtf -from espnet2.enh.layers.beamformer import get_mwf_vector -from espnet2.enh.layers.beamformer import get_rank1_mwf_vector -from espnet2.enh.layers.beamformer import get_rtf_matrix -from espnet2.enh.layers.beamformer import get_sdw_mwf_vector -from espnet2.enh.layers.beamformer import get_WPD_filter_v2 -from espnet2.enh.layers.beamformer import get_WPD_filter_with_rtf -from espnet2.enh.layers.beamformer import perform_WPD_filtering -from espnet2.enh.layers.beamformer import prepare_beamformer_stats -from espnet2.enh.layers.complex_utils import stack -from espnet2.enh.layers.complex_utils import to_double -from espnet2.enh.layers.complex_utils import to_float +from espnet2.enh.layers.beamformer import (apply_beamforming_vector, + blind_analytic_normalization, + get_gev_vector, + get_lcmv_vector_with_rtf, + get_mvdr_vector, + get_mvdr_vector_with_rtf, + get_mwf_vector, + get_rank1_mwf_vector, + get_rtf_matrix, get_sdw_mwf_vector, + get_WPD_filter_v2, + get_WPD_filter_with_rtf, + perform_WPD_filtering, + prepare_beamformer_stats) +from espnet2.enh.layers.complex_utils import stack, to_double, to_float from espnet2.enh.layers.mask_estimator import MaskEstimator - is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") BEAMFORMER_TYPES = ( diff --git a/espnet2/enh/layers/dnn_wpe.py b/espnet2/enh/layers/dnn_wpe.py index f3430087742..ba5f9ccb107 100644 --- a/espnet2/enh/layers/dnn_wpe.py +++ b/espnet2/enh/layers/dnn_wpe.py @@ -1,14 +1,12 @@ -from typing import Tuple -from typing import Union +from typing import Tuple, Union import torch from torch_complex.tensor import ComplexTensor -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet2.enh.layers.complex_utils import to_double -from espnet2.enh.layers.complex_utils import to_float +from espnet2.enh.layers.complex_utils import to_double, to_float from espnet2.enh.layers.mask_estimator import MaskEstimator from espnet2.enh.layers.wpe import wpe_one_iteration +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask class DNN_WPE(torch.nn.Module): diff --git a/espnet2/enh/layers/dprnn.py b/espnet2/enh/layers/dprnn.py index 830e3c59a5e..4c89e3a4ccd 100644 --- a/espnet2/enh/layers/dprnn.py +++ b/espnet2/enh/layers/dprnn.py @@ -9,9 +9,8 @@ import torch -from torch.autograd import Variable import torch.nn as nn - +from torch.autograd import Variable EPS = torch.finfo(torch.get_default_dtype()).eps diff --git a/espnet2/enh/layers/ifasnet.py b/espnet2/enh/layers/ifasnet.py index 076898f4b2d..c6a4f6cbb3d 100644 --- a/espnet2/enh/layers/ifasnet.py +++ b/espnet2/enh/layers/ifasnet.py @@ -11,8 +11,7 @@ import torch.nn as nn from espnet2.enh.layers import dprnn -from espnet2.enh.layers.fasnet import BF_module -from espnet2.enh.layers.fasnet import FaSNet_base +from espnet2.enh.layers.fasnet import BF_module, FaSNet_base # implicit FaSNet (iFaSNet) diff --git a/espnet2/enh/layers/mask_estimator.py b/espnet2/enh/layers/mask_estimator.py index 6f40c66ddfe..6bd69e7ef2f 100644 --- a/espnet2/enh/layers/mask_estimator.py +++ b/espnet2/enh/layers/mask_estimator.py @@ -1,17 +1,14 @@ -from packaging.version import parse as V -from typing import Tuple -from 
typing import Union +from typing import Tuple, Union import numpy as np import torch +from packaging.version import parse as V from torch.nn import functional as F from torch_complex.tensor import ComplexTensor -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.rnn.encoders import RNN -from espnet.nets.pytorch_backend.rnn.encoders import RNNP from espnet2.enh.layers.complex_utils import is_complex - +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.rnn.encoders import RNN, RNNP is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") diff --git a/espnet2/enh/layers/skim.py b/espnet2/enh/layers/skim.py index f095f97495c..09ea339e476 100644 --- a/espnet2/enh/layers/skim.py +++ b/espnet2/enh/layers/skim.py @@ -6,9 +6,7 @@ import torch import torch.nn as nn -from espnet2.enh.layers.dprnn import merge_feature -from espnet2.enh.layers.dprnn import SingleRNN -from espnet2.enh.layers.dprnn import split_feature +from espnet2.enh.layers.dprnn import SingleRNN, merge_feature, split_feature from espnet2.enh.layers.tcn import choose_norm diff --git a/espnet2/enh/layers/wpe.py b/espnet2/enh/layers/wpe.py index e6117b89786..69548eec6a4 100644 --- a/espnet2/enh/layers/wpe.py +++ b/espnet2/enh/layers/wpe.py @@ -1,16 +1,12 @@ -from packaging.version import parse as V -from typing import Tuple -from typing import Union +from typing import Tuple, Union import torch import torch.nn.functional as F import torch_complex.functional as FC +from packaging.version import parse as V from torch_complex.tensor import ComplexTensor -from espnet2.enh.layers.complex_utils import einsum -from espnet2.enh.layers.complex_utils import matmul -from espnet2.enh.layers.complex_utils import reverse - +from espnet2.enh.layers.complex_utils import einsum, matmul, reverse is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") diff --git a/espnet2/enh/loss/criterions/abs_loss.py b/espnet2/enh/loss/criterions/abs_loss.py index c09119c9e07..4e2f1bbb676 100644 --- a/espnet2/enh/loss/criterions/abs_loss.py +++ b/espnet2/enh/loss/criterions/abs_loss.py @@ -1,6 +1,4 @@ -from abc import ABC -from abc import abstractmethod - +from abc import ABC, abstractmethod import torch diff --git a/espnet2/enh/loss/criterions/tf_domain.py b/espnet2/enh/loss/criterions/tf_domain.py index 4c4a91ef5d2..2cbe893ed6a 100644 --- a/espnet2/enh/loss/criterions/tf_domain.py +++ b/espnet2/enh/loss/criterions/tf_domain.py @@ -1,18 +1,15 @@ -from abc import ABC -from abc import abstractmethod -from functools import reduce import math -from packaging.version import parse as V +from abc import ABC, abstractmethod +from functools import reduce import torch import torch.nn.functional as F +from packaging.version import parse as V -from espnet2.enh.layers.complex_utils import complex_norm -from espnet2.enh.layers.complex_utils import is_complex -from espnet2.enh.layers.complex_utils import new_complex_like +from espnet2.enh.layers.complex_utils import (complex_norm, is_complex, + new_complex_like) from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss - is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") EPS = torch.finfo(torch.get_default_dtype()).eps diff --git a/espnet2/enh/loss/criterions/time_domain.py b/espnet2/enh/loss/criterions/time_domain.py index d000b83fbbb..d822af544f9 100644 --- a/espnet2/enh/loss/criterions/time_domain.py +++ b/espnet2/enh/loss/criterions/time_domain.py @@ -1,11 +1,10 @@ -from abc import ABC import logging +from abc import ABC import 
ci_sdr import fast_bss_eval import torch - from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss diff --git a/espnet2/enh/loss/wrappers/abs_wrapper.py b/espnet2/enh/loss/wrappers/abs_wrapper.py index e48a2b7f869..9133d6cc3af 100644 --- a/espnet2/enh/loss/wrappers/abs_wrapper.py +++ b/espnet2/enh/loss/wrappers/abs_wrapper.py @@ -1,8 +1,5 @@ -from abc import ABC -from abc import abstractmethod -from typing import Dict -from typing import List -from typing import Tuple +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple import torch diff --git a/espnet2/enh/separator/abs_separator.py b/espnet2/enh/separator/abs_separator.py index 72fe2eea918..ce68f51c887 100644 --- a/espnet2/enh/separator/abs_separator.py +++ b/espnet2/enh/separator/abs_separator.py @@ -1,9 +1,6 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Dict -from typing import Optional -from typing import Tuple +from typing import Dict, Optional, Tuple import torch diff --git a/espnet2/enh/separator/asteroid_models.py b/espnet2/enh/separator/asteroid_models.py index 2310929c1e3..6f05eb4802c 100644 --- a/espnet2/enh/separator/asteroid_models.py +++ b/espnet2/enh/separator/asteroid_models.py @@ -1,8 +1,6 @@ -from collections import OrderedDict -from typing import Dict -from typing import Optional -from typing import Tuple import warnings +from collections import OrderedDict +from typing import Dict, Optional, Tuple import torch diff --git a/espnet2/enh/separator/conformer_separator.py b/espnet2/enh/separator/conformer_separator.py index 3e3574beade..ad28743baaa 100644 --- a/espnet2/enh/separator/conformer_separator.py +++ b/espnet2/enh/separator/conformer_separator.py @@ -1,21 +1,15 @@ from collections import OrderedDict -from packaging.version import parse as V -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Dict, List, Optional, Tuple, Union import torch +from packaging.version import parse as V from torch_complex.tensor import ComplexTensor -from espnet.nets.pytorch_backend.conformer.encoder import ( - Encoder as ConformerEncoder, # noqa: H301 -) -from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask from espnet2.enh.layers.complex_utils import is_complex from espnet2.enh.separator.abs_separator import AbsSeparator - +from espnet.nets.pytorch_backend.conformer.encoder import \ + Encoder as ConformerEncoder # noqa: H301 +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") diff --git a/espnet2/enh/separator/dan_separator.py b/espnet2/enh/separator/dan_separator.py index d3b3222ae90..0e7bf708312 100644 --- a/espnet2/enh/separator/dan_separator.py +++ b/espnet2/enh/separator/dan_separator.py @@ -1,17 +1,13 @@ from collections import OrderedDict from functools import reduce -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as Fun from torch_complex.tensor import ComplexTensor -from espnet.nets.pytorch_backend.rnn.encoders import RNN from espnet2.enh.separator.abs_separator import AbsSeparator +from espnet.nets.pytorch_backend.rnn.encoders import RNN class DANSeparator(AbsSeparator): diff --git a/espnet2/enh/separator/dc_crn_separator.py 
b/espnet2/enh/separator/dc_crn_separator.py index b3f9be4fddd..99eb1aa7f48 100644 --- a/espnet2/enh/separator/dc_crn_separator.py +++ b/espnet2/enh/separator/dc_crn_separator.py @@ -1,20 +1,14 @@ from collections import OrderedDict -from packaging.version import parse as V -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Dict, List, Optional, Tuple, Union import torch +from packaging.version import parse as V from torch_complex.tensor import ComplexTensor -from espnet2.enh.layers.complex_utils import is_complex -from espnet2.enh.layers.complex_utils import new_complex_like +from espnet2.enh.layers.complex_utils import is_complex, new_complex_like from espnet2.enh.layers.dc_crn import DC_CRN from espnet2.enh.separator.abs_separator import AbsSeparator - EPS = torch.finfo(torch.get_default_dtype()).eps is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") diff --git a/espnet2/enh/separator/dccrn_separator.py b/espnet2/enh/separator/dccrn_separator.py index 74f793d14bd..0591c770b09 100644 --- a/espnet2/enh/separator/dccrn_separator.py +++ b/espnet2/enh/separator/dccrn_separator.py @@ -1,21 +1,15 @@ from collections import OrderedDict -from packaging.version import parse as V -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F +from packaging.version import parse as V from torch_complex.tensor import ComplexTensor -from espnet2.enh.layers.complexnn import complex_cat -from espnet2.enh.layers.complexnn import ComplexBatchNorm -from espnet2.enh.layers.complexnn import ComplexConv2d -from espnet2.enh.layers.complexnn import ComplexConvTranspose2d -from espnet2.enh.layers.complexnn import NavieComplexLSTM +from espnet2.enh.layers.complexnn import (ComplexBatchNorm, ComplexConv2d, + ComplexConvTranspose2d, + NavieComplexLSTM, complex_cat) from espnet2.enh.separator.abs_separator import AbsSeparator is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") diff --git a/espnet2/enh/separator/dpcl_e2e_separator.py b/espnet2/enh/separator/dpcl_e2e_separator.py index 35264c5c137..134dec4c617 100644 --- a/espnet2/enh/separator/dpcl_e2e_separator.py +++ b/espnet2/enh/separator/dpcl_e2e_separator.py @@ -1,15 +1,11 @@ from collections import OrderedDict -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Dict, List, Optional, Tuple, Union import torch from torch_complex.tensor import ComplexTensor -from espnet.nets.pytorch_backend.rnn.encoders import RNN from espnet2.enh.separator.abs_separator import AbsSeparator +from espnet.nets.pytorch_backend.rnn.encoders import RNN class DPCLE2ESeparator(AbsSeparator): diff --git a/espnet2/enh/separator/dpcl_separator.py b/espnet2/enh/separator/dpcl_separator.py index 0eb0abf67e0..c2d4229512d 100644 --- a/espnet2/enh/separator/dpcl_separator.py +++ b/espnet2/enh/separator/dpcl_separator.py @@ -1,15 +1,11 @@ from collections import OrderedDict -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Dict, List, Optional, Tuple, Union import torch from torch_complex.tensor import ComplexTensor -from espnet.nets.pytorch_backend.rnn.encoders import RNN from 
espnet2.enh.separator.abs_separator import AbsSeparator +from espnet.nets.pytorch_backend.rnn.encoders import RNN class DPCLSeparator(AbsSeparator): diff --git a/espnet2/enh/separator/dprnn_separator.py b/espnet2/enh/separator/dprnn_separator.py index d0f446ee36a..228837f97fd 100644 --- a/espnet2/enh/separator/dprnn_separator.py +++ b/espnet2/enh/separator/dprnn_separator.py @@ -1,21 +1,14 @@ from collections import OrderedDict -from packaging.version import parse as V -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Dict, List, Optional, Tuple, Union import torch +from packaging.version import parse as V from torch_complex.tensor import ComplexTensor from espnet2.enh.layers.complex_utils import is_complex -from espnet2.enh.layers.dprnn import DPRNN -from espnet2.enh.layers.dprnn import merge_feature -from espnet2.enh.layers.dprnn import split_feature +from espnet2.enh.layers.dprnn import DPRNN, merge_feature, split_feature from espnet2.enh.separator.abs_separator import AbsSeparator - is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") diff --git a/espnet2/enh/separator/fasnet_separator.py b/espnet2/enh/separator/fasnet_separator.py index deb1a4d1f43..325398a7359 100644 --- a/espnet2/enh/separator/fasnet_separator.py +++ b/espnet2/enh/separator/fasnet_separator.py @@ -1,17 +1,13 @@ from collections import OrderedDict -from packaging.version import parse as V -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple +from typing import Dict, List, Optional, Tuple import torch +from packaging.version import parse as V from espnet2.enh.layers.fasnet import FaSNet_TAC from espnet2.enh.layers.ifasnet import iFaSNet from espnet2.enh.separator.abs_separator import AbsSeparator - is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") diff --git a/espnet2/enh/separator/neural_beamformer.py b/espnet2/enh/separator/neural_beamformer.py index dff26d6f66c..aa4047f24ee 100644 --- a/espnet2/enh/separator/neural_beamformer.py +++ b/espnet2/enh/separator/neural_beamformer.py @@ -1,9 +1,5 @@ from collections import OrderedDict -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Dict, List, Optional, Tuple, Union import torch from torch_complex.tensor import ComplexTensor diff --git a/espnet2/enh/separator/rnn_separator.py b/espnet2/enh/separator/rnn_separator.py index 2a551edbde0..3f4629def5a 100644 --- a/espnet2/enh/separator/rnn_separator.py +++ b/espnet2/enh/separator/rnn_separator.py @@ -1,18 +1,13 @@ from collections import OrderedDict -from packaging.version import parse as V -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Dict, List, Optional, Tuple, Union import torch +from packaging.version import parse as V from torch_complex.tensor import ComplexTensor -from espnet.nets.pytorch_backend.rnn.encoders import RNN from espnet2.enh.layers.complex_utils import is_complex from espnet2.enh.separator.abs_separator import AbsSeparator - +from espnet.nets.pytorch_backend.rnn.encoders import RNN is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") diff --git a/espnet2/enh/separator/skim_separator.py b/espnet2/enh/separator/skim_separator.py index 15dd467ea53..be13531d46c 100644 --- a/espnet2/enh/separator/skim_separator.py +++ 
b/espnet2/enh/separator/skim_separator.py @@ -1,9 +1,5 @@ from collections import OrderedDict -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Dict, List, Optional, Tuple, Union import torch from torch_complex.tensor import ComplexTensor diff --git a/espnet2/enh/separator/svoice_separator.py b/espnet2/enh/separator/svoice_separator.py index 6a7fce40cba..54545ff13ea 100644 --- a/espnet2/enh/separator/svoice_separator.py +++ b/espnet2/enh/separator/svoice_separator.py @@ -1,17 +1,13 @@ -from collections import OrderedDict import math -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple +from collections import OrderedDict +from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from espnet2.enh.layers.dpmulcat import DPMulCat -from espnet2.enh.layers.dprnn import merge_feature -from espnet2.enh.layers.dprnn import split_feature +from espnet2.enh.layers.dprnn import merge_feature, split_feature from espnet2.enh.separator.abs_separator import AbsSeparator diff --git a/espnet2/enh/separator/tcn_separator.py b/espnet2/enh/separator/tcn_separator.py index 12c6db42e42..0e34c540937 100644 --- a/espnet2/enh/separator/tcn_separator.py +++ b/espnet2/enh/separator/tcn_separator.py @@ -1,19 +1,14 @@ from collections import OrderedDict -from packaging.version import parse as V -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Dict, List, Optional, Tuple, Union import torch +from packaging.version import parse as V from torch_complex.tensor import ComplexTensor from espnet2.enh.layers.complex_utils import is_complex from espnet2.enh.layers.tcn import TemporalConvNet from espnet2.enh.separator.abs_separator import AbsSeparator - is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") diff --git a/espnet2/enh/separator/transformer_separator.py b/espnet2/enh/separator/transformer_separator.py index c6dbcf91eaa..93c2ee30fd6 100644 --- a/espnet2/enh/separator/transformer_separator.py +++ b/espnet2/enh/separator/transformer_separator.py @@ -1,26 +1,19 @@ from collections import OrderedDict -from packaging.version import parse as V -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Dict, List, Optional, Tuple, Union import torch +from packaging.version import parse as V from torch_complex.tensor import ComplexTensor - -from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask -from espnet.nets.pytorch_backend.transformer.embedding import ( - PositionalEncoding, # noqa: H301 - ScaledPositionalEncoding, # noqa: H301 -) -from espnet.nets.pytorch_backend.transformer.encoder import ( - Encoder as TransformerEncoder, # noqa: H301 -) from espnet2.enh.layers.complex_utils import is_complex from espnet2.enh.separator.abs_separator import AbsSeparator - +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding # noqa: H301 +from espnet.nets.pytorch_backend.transformer.embedding import \ + ScaledPositionalEncoding # noqa: H301 +from espnet.nets.pytorch_backend.transformer.encoder import \ + Encoder as TransformerEncoder # noqa: H301 is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") diff --git 
a/espnet2/fileio/datadir_writer.py b/espnet2/fileio/datadir_writer.py index bafdf984f19..625c73dbed7 100644 --- a/espnet2/fileio/datadir_writer.py +++ b/espnet2/fileio/datadir_writer.py @@ -1,9 +1,8 @@ +import warnings from pathlib import Path from typing import Union -import warnings -from typeguard import check_argument_types -from typeguard import check_return_type +from typeguard import check_argument_types, check_return_type class DatadirWriter: diff --git a/espnet2/fileio/read_text.py b/espnet2/fileio/read_text.py index e26e7a1c582..830c1651b2d 100644 --- a/espnet2/fileio/read_text.py +++ b/espnet2/fileio/read_text.py @@ -1,8 +1,6 @@ import logging from pathlib import Path -from typing import Dict -from typing import List -from typing import Union +from typing import Dict, List, Union from typeguard import check_argument_types diff --git a/espnet2/fileio/rttm.py b/espnet2/fileio/rttm.py index 5b8a343f3dc..feec3a82f60 100644 --- a/espnet2/fileio/rttm.py +++ b/espnet2/fileio/rttm.py @@ -1,12 +1,9 @@ import collections.abc +import re from pathlib import Path -from typing import Dict -from typing import List -from typing import Tuple -from typing import Union +from typing import Dict, List, Tuple, Union import numpy as np -import re from typeguard import check_argument_types diff --git a/espnet2/fst/lm_rescore.py b/espnet2/fst/lm_rescore.py index 340bd409643..41d662c830f 100644 --- a/espnet2/fst/lm_rescore.py +++ b/espnet2/fst/lm_rescore.py @@ -1,8 +1,7 @@ -from typing import List -from typing import Tuple +import math +from typing import List, Tuple import k2 -import math import torch diff --git a/espnet2/gan_tts/abs_gan_tts.py b/espnet2/gan_tts/abs_gan_tts.py index 248264ecbc9..feee1d293f0 100644 --- a/espnet2/gan_tts/abs_gan_tts.py +++ b/espnet2/gan_tts/abs_gan_tts.py @@ -3,11 +3,8 @@ """GAN-based TTS abstrast class.""" -from abc import ABC -from abc import abstractmethod - -from typing import Dict -from typing import Union +from abc import ABC, abstractmethod +from typing import Dict, Union import torch diff --git a/espnet2/gan_tts/espnet_model.py b/espnet2/gan_tts/espnet_model.py index 5cc1785a4d5..81d898df186 100644 --- a/espnet2/gan_tts/espnet_model.py +++ b/espnet2/gan_tts/espnet_model.py @@ -4,13 +4,10 @@ """GAN-based text-to-speech ESPnet model.""" from contextlib import contextmanager -from packaging.version import parse as V -from typing import Any -from typing import Dict -from typing import Optional +from typing import Any, Dict, Optional import torch - +from packaging.version import parse as V from typeguard import check_argument_types from espnet2.gan_tts.abs_gan_tts import AbsGANTTS diff --git a/espnet2/gan_tts/hifigan/__init__.py b/espnet2/gan_tts/hifigan/__init__.py index c65d1896c03..d99d366ae9a 100644 --- a/espnet2/gan_tts/hifigan/__init__.py +++ b/espnet2/gan_tts/hifigan/__init__.py @@ -1,8 +1,10 @@ from espnet2.gan_tts.hifigan.hifigan import HiFiGANGenerator # NOQA -from espnet2.gan_tts.hifigan.hifigan import HiFiGANMultiPeriodDiscriminator # NOQA -from espnet2.gan_tts.hifigan.hifigan import HiFiGANMultiScaleDiscriminator # NOQA -from espnet2.gan_tts.hifigan.hifigan import ( # NOQA - HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA -) +from espnet2.gan_tts.hifigan.hifigan import \ + HiFiGANMultiPeriodDiscriminator # NOQA +from espnet2.gan_tts.hifigan.hifigan import \ + HiFiGANMultiScaleDiscriminator # NOQA +from espnet2.gan_tts.hifigan.hifigan import \ + HiFiGANMultiScaleMultiPeriodDiscriminator # NOQA from espnet2.gan_tts.hifigan.hifigan import 
HiFiGANPeriodDiscriminator # NOQA -from espnet2.gan_tts.hifigan.hifigan import HiFiGANScaleDiscriminator # NOQA +from espnet2.gan_tts.hifigan.hifigan import \ + HiFiGANScaleDiscriminator # NOQA; NOQA diff --git a/espnet2/gan_tts/hifigan/hifigan.py b/espnet2/gan_tts/hifigan/hifigan.py index 516678366b1..18c311907a2 100644 --- a/espnet2/gan_tts/hifigan/hifigan.py +++ b/espnet2/gan_tts/hifigan/hifigan.py @@ -9,11 +9,7 @@ import copy import logging - -from typing import Any -from typing import Dict -from typing import List -from typing import Optional +from typing import Any, Dict, List, Optional import numpy as np import torch diff --git a/espnet2/gan_tts/hifigan/loss.py b/espnet2/gan_tts/hifigan/loss.py index 083b5de6cb5..d16e12a70f4 100644 --- a/espnet2/gan_tts/hifigan/loss.py +++ b/espnet2/gan_tts/hifigan/loss.py @@ -7,10 +7,7 @@ """ -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F diff --git a/espnet2/gan_tts/hifigan/residual_block.py b/espnet2/gan_tts/hifigan/residual_block.py index c5ac2c4e2f5..a6ac90b1af6 100644 --- a/espnet2/gan_tts/hifigan/residual_block.py +++ b/espnet2/gan_tts/hifigan/residual_block.py @@ -7,9 +7,7 @@ """ -from typing import Any -from typing import Dict -from typing import List +from typing import Any, Dict, List import torch diff --git a/espnet2/gan_tts/jets/alignments.py b/espnet2/gan_tts/jets/alignments.py index 6548dc5009f..e35b63e3383 100644 --- a/espnet2/gan_tts/jets/alignments.py +++ b/espnet2/gan_tts/jets/alignments.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn import torch.nn.functional as F - from numba import jit diff --git a/espnet2/gan_tts/jets/generator.py b/espnet2/gan_tts/jets/generator.py index 75734e49c23..0e3a974d21a 100644 --- a/espnet2/gan_tts/jets/generator.py +++ b/espnet2/gan_tts/jets/generator.py @@ -4,38 +4,31 @@ """Generator module in JETS.""" import logging - -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple +from typing import Any, Dict, List, Optional, Sequence, Tuple import numpy as np import torch import torch.nn.functional as F -from espnet.nets.pytorch_backend.conformer.encoder import ( - Encoder as ConformerEncoder, # noqa: H301 -) -from espnet.nets.pytorch_backend.fastspeech.duration_predictor import DurationPredictor -from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding -from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding -from espnet.nets.pytorch_backend.transformer.encoder import ( - Encoder as TransformerEncoder, # noqa: H301 -) from espnet2.gan_tts.hifigan import HiFiGANGenerator -from espnet2.gan_tts.jets.alignments import AlignmentModule -from espnet2.gan_tts.jets.alignments import average_by_duration -from espnet2.gan_tts.jets.alignments import viterbi_decode +from espnet2.gan_tts.jets.alignments import (AlignmentModule, + average_by_duration, + viterbi_decode) from espnet2.gan_tts.jets.length_regulator import GaussianUpsampling from espnet2.gan_tts.utils import get_random_segments from espnet2.torch_utils.initialize import initialize from espnet2.tts.fastspeech2.variance_predictor import VariancePredictor from espnet2.tts.gst.style_encoder import StyleEncoder 
+from espnet.nets.pytorch_backend.conformer.encoder import \ + Encoder as ConformerEncoder # noqa: H301 +from espnet.nets.pytorch_backend.fastspeech.duration_predictor import \ + DurationPredictor +from espnet.nets.pytorch_backend.nets_utils import (make_non_pad_mask, + make_pad_mask) +from espnet.nets.pytorch_backend.transformer.embedding import ( + PositionalEncoding, ScaledPositionalEncoding) +from espnet.nets.pytorch_backend.transformer.encoder import \ + Encoder as TransformerEncoder # noqa: H301 class JETSGenerator(torch.nn.Module): diff --git a/espnet2/gan_tts/jets/jets.py b/espnet2/gan_tts/jets/jets.py index e2e0e3cd2e6..79f8ff46b29 100644 --- a/espnet2/gan_tts/jets/jets.py +++ b/espnet2/gan_tts/jets/jets.py @@ -3,31 +3,26 @@ """JETS module for GAN-TTS task.""" -from typing import Any -from typing import Dict -from typing import Optional +from typing import Any, Dict, Optional import torch - from typeguard import check_argument_types from espnet2.gan_tts.abs_gan_tts import AbsGANTTS -from espnet2.gan_tts.hifigan import HiFiGANMultiPeriodDiscriminator -from espnet2.gan_tts.hifigan import HiFiGANMultiScaleDiscriminator -from espnet2.gan_tts.hifigan import HiFiGANMultiScaleMultiPeriodDiscriminator -from espnet2.gan_tts.hifigan import HiFiGANPeriodDiscriminator -from espnet2.gan_tts.hifigan import HiFiGANScaleDiscriminator -from espnet2.gan_tts.hifigan.loss import DiscriminatorAdversarialLoss -from espnet2.gan_tts.hifigan.loss import FeatureMatchLoss -from espnet2.gan_tts.hifigan.loss import GeneratorAdversarialLoss -from espnet2.gan_tts.hifigan.loss import MelSpectrogramLoss +from espnet2.gan_tts.hifigan import (HiFiGANMultiPeriodDiscriminator, + HiFiGANMultiScaleDiscriminator, + HiFiGANMultiScaleMultiPeriodDiscriminator, + HiFiGANPeriodDiscriminator, + HiFiGANScaleDiscriminator) +from espnet2.gan_tts.hifigan.loss import (DiscriminatorAdversarialLoss, + FeatureMatchLoss, + GeneratorAdversarialLoss, + MelSpectrogramLoss) from espnet2.gan_tts.jets.generator import JETSGenerator -from espnet2.gan_tts.jets.loss import ForwardSumLoss -from espnet2.gan_tts.jets.loss import VarianceLoss +from espnet2.gan_tts.jets.loss import ForwardSumLoss, VarianceLoss from espnet2.gan_tts.utils import get_segments from espnet2.torch_utils.device_funcs import force_gatherable - AVAILABLE_GENERATERS = { "jets_generator": JETSGenerator, } diff --git a/espnet2/gan_tts/jets/loss.py b/espnet2/gan_tts/jets/loss.py index 066c9fe8829..ed9b8979216 100644 --- a/espnet2/gan_tts/jets/loss.py +++ b/espnet2/gan_tts/jets/loss.py @@ -6,14 +6,13 @@ from typing import Tuple import numpy as np -from scipy.stats import betabinom import torch import torch.nn.functional as F +from scipy.stats import betabinom from typeguard import check_argument_types -from espnet.nets.pytorch_backend.fastspeech.duration_predictor import ( - DurationPredictorLoss, # noqa: H301 -) +from espnet.nets.pytorch_backend.fastspeech.duration_predictor import \ + DurationPredictorLoss # noqa: H301 from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask diff --git a/espnet2/gan_tts/joint/joint_text2wav.py b/espnet2/gan_tts/joint/joint_text2wav.py index 5d85e337642..2059030b477 100644 --- a/espnet2/gan_tts/joint/joint_text2wav.py +++ b/espnet2/gan_tts/joint/joint_text2wav.py @@ -3,33 +3,30 @@ """Joint text-to-wav module for end-to-end training.""" -from typing import Any -from typing import Dict +from typing import Any, Dict import torch - from typeguard import check_argument_types from espnet2.gan_tts.abs_gan_tts import AbsGANTTS -from 
espnet2.gan_tts.hifigan import HiFiGANGenerator -from espnet2.gan_tts.hifigan import HiFiGANMultiPeriodDiscriminator -from espnet2.gan_tts.hifigan import HiFiGANMultiScaleDiscriminator -from espnet2.gan_tts.hifigan import HiFiGANMultiScaleMultiPeriodDiscriminator -from espnet2.gan_tts.hifigan import HiFiGANPeriodDiscriminator -from espnet2.gan_tts.hifigan import HiFiGANScaleDiscriminator -from espnet2.gan_tts.hifigan.loss import DiscriminatorAdversarialLoss -from espnet2.gan_tts.hifigan.loss import FeatureMatchLoss -from espnet2.gan_tts.hifigan.loss import GeneratorAdversarialLoss -from espnet2.gan_tts.hifigan.loss import MelSpectrogramLoss -from espnet2.gan_tts.melgan import MelGANGenerator -from espnet2.gan_tts.melgan import MelGANMultiScaleDiscriminator +from espnet2.gan_tts.hifigan import (HiFiGANGenerator, + HiFiGANMultiPeriodDiscriminator, + HiFiGANMultiScaleDiscriminator, + HiFiGANMultiScaleMultiPeriodDiscriminator, + HiFiGANPeriodDiscriminator, + HiFiGANScaleDiscriminator) +from espnet2.gan_tts.hifigan.loss import (DiscriminatorAdversarialLoss, + FeatureMatchLoss, + GeneratorAdversarialLoss, + MelSpectrogramLoss) +from espnet2.gan_tts.melgan import (MelGANGenerator, + MelGANMultiScaleDiscriminator) from espnet2.gan_tts.melgan.pqmf import PQMF -from espnet2.gan_tts.parallel_wavegan import ParallelWaveGANDiscriminator -from espnet2.gan_tts.parallel_wavegan import ParallelWaveGANGenerator -from espnet2.gan_tts.style_melgan import StyleMelGANDiscriminator -from espnet2.gan_tts.style_melgan import StyleMelGANGenerator -from espnet2.gan_tts.utils import get_random_segments -from espnet2.gan_tts.utils import get_segments +from espnet2.gan_tts.parallel_wavegan import (ParallelWaveGANDiscriminator, + ParallelWaveGANGenerator) +from espnet2.gan_tts.style_melgan import (StyleMelGANDiscriminator, + StyleMelGANGenerator) +from espnet2.gan_tts.utils import get_random_segments, get_segments from espnet2.torch_utils.device_funcs import force_gatherable from espnet2.tts.fastspeech import FastSpeech from espnet2.tts.fastspeech2 import FastSpeech2 diff --git a/espnet2/gan_tts/melgan/melgan.py b/espnet2/gan_tts/melgan/melgan.py index 7b1281d14fd..0c930e1bb71 100644 --- a/espnet2/gan_tts/melgan/melgan.py +++ b/espnet2/gan_tts/melgan/melgan.py @@ -8,10 +8,7 @@ """ import logging - -from typing import Any -from typing import Dict -from typing import List +from typing import Any, Dict, List import numpy as np import torch diff --git a/espnet2/gan_tts/melgan/pqmf.py b/espnet2/gan_tts/melgan/pqmf.py index ef4e053d862..7e504b7dc71 100644 --- a/espnet2/gan_tts/melgan/pqmf.py +++ b/espnet2/gan_tts/melgan/pqmf.py @@ -10,7 +10,6 @@ import numpy as np import torch import torch.nn.functional as F - from scipy.signal import kaiser diff --git a/espnet2/gan_tts/melgan/residual_stack.py b/espnet2/gan_tts/melgan/residual_stack.py index 3fb7e927e87..daeb009c51e 100644 --- a/espnet2/gan_tts/melgan/residual_stack.py +++ b/espnet2/gan_tts/melgan/residual_stack.py @@ -7,8 +7,7 @@ """ -from typing import Any -from typing import Dict +from typing import Any, Dict import torch diff --git a/espnet2/gan_tts/parallel_wavegan/__init__.py b/espnet2/gan_tts/parallel_wavegan/__init__.py index 357235c4847..0a755e29b31 100644 --- a/espnet2/gan_tts/parallel_wavegan/__init__.py +++ b/espnet2/gan_tts/parallel_wavegan/__init__.py @@ -1,4 +1,4 @@ -from espnet2.gan_tts.parallel_wavegan.parallel_wavegan import ( # NOQA - ParallelWaveGANDiscriminator, # NOQA - ParallelWaveGANGenerator, # NOQA -) +from 
espnet2.gan_tts.parallel_wavegan.parallel_wavegan import \ + ParallelWaveGANDiscriminator # NOQA +from espnet2.gan_tts.parallel_wavegan.parallel_wavegan import \ + ParallelWaveGANGenerator # NOQA; NOQA diff --git a/espnet2/gan_tts/parallel_wavegan/parallel_wavegan.py b/espnet2/gan_tts/parallel_wavegan/parallel_wavegan.py index 85b9ac224ae..34da658ca46 100644 --- a/espnet2/gan_tts/parallel_wavegan/parallel_wavegan.py +++ b/espnet2/gan_tts/parallel_wavegan/parallel_wavegan.py @@ -9,18 +9,14 @@ import logging import math - -from typing import Any -from typing import Dict -from typing import Optional +from typing import Any, Dict, Optional import numpy as np import torch from espnet2.gan_tts.parallel_wavegan import upsample -from espnet2.gan_tts.wavenet.residual_block import Conv1d -from espnet2.gan_tts.wavenet.residual_block import Conv1d1x1 -from espnet2.gan_tts.wavenet.residual_block import ResidualBlock +from espnet2.gan_tts.wavenet.residual_block import (Conv1d, Conv1d1x1, + ResidualBlock) class ParallelWaveGANGenerator(torch.nn.Module): diff --git a/espnet2/gan_tts/parallel_wavegan/upsample.py b/espnet2/gan_tts/parallel_wavegan/upsample.py index 4e0acee577c..de8163a4375 100644 --- a/espnet2/gan_tts/parallel_wavegan/upsample.py +++ b/espnet2/gan_tts/parallel_wavegan/upsample.py @@ -7,10 +7,7 @@ """ -from typing import Any -from typing import Dict -from typing import List -from typing import Optional +from typing import Any, Dict, List, Optional import numpy as np import torch diff --git a/espnet2/gan_tts/style_melgan/__init__.py b/espnet2/gan_tts/style_melgan/__init__.py index 8f47edde698..7cf6bac5a9e 100644 --- a/espnet2/gan_tts/style_melgan/__init__.py +++ b/espnet2/gan_tts/style_melgan/__init__.py @@ -1,2 +1,4 @@ -from espnet2.gan_tts.style_melgan.style_melgan import StyleMelGANDiscriminator # NOQA -from espnet2.gan_tts.style_melgan.style_melgan import StyleMelGANGenerator # NOQA +from espnet2.gan_tts.style_melgan.style_melgan import \ + StyleMelGANDiscriminator # NOQA +from espnet2.gan_tts.style_melgan.style_melgan import \ + StyleMelGANGenerator # NOQA diff --git a/espnet2/gan_tts/style_melgan/style_melgan.py b/espnet2/gan_tts/style_melgan/style_melgan.py index 4934a094f23..72cf9eae035 100644 --- a/espnet2/gan_tts/style_melgan/style_melgan.py +++ b/espnet2/gan_tts/style_melgan/style_melgan.py @@ -10,11 +10,7 @@ import copy import logging import math - -from typing import Any -from typing import Dict -from typing import List -from typing import Optional +from typing import Any, Dict, List, Optional import numpy as np import torch diff --git a/espnet2/gan_tts/utils/__init__.py b/espnet2/gan_tts/utils/__init__.py index 591d46fc2a5..b9dd249e0ca 100644 --- a/espnet2/gan_tts/utils/__init__.py +++ b/espnet2/gan_tts/utils/__init__.py @@ -1,2 +1,3 @@ -from espnet2.gan_tts.utils.get_random_segments import get_random_segments # NOQA +from espnet2.gan_tts.utils.get_random_segments import \ + get_random_segments # NOQA from espnet2.gan_tts.utils.get_random_segments import get_segments # NOQA diff --git a/espnet2/gan_tts/vits/duration_predictor.py b/espnet2/gan_tts/vits/duration_predictor.py index 5a480b11344..e86d24cb353 100644 --- a/espnet2/gan_tts/vits/duration_predictor.py +++ b/espnet2/gan_tts/vits/duration_predictor.py @@ -8,17 +8,14 @@ """ import math - from typing import Optional import torch import torch.nn.functional as F -from espnet2.gan_tts.vits.flow import ConvFlow -from espnet2.gan_tts.vits.flow import DilatedDepthSeparableConv -from espnet2.gan_tts.vits.flow import 
ElementwiseAffineFlow -from espnet2.gan_tts.vits.flow import FlipFlow -from espnet2.gan_tts.vits.flow import LogFlow +from espnet2.gan_tts.vits.flow import (ConvFlow, DilatedDepthSeparableConv, + ElementwiseAffineFlow, FlipFlow, + LogFlow) class StochasticDurationPredictor(torch.nn.Module): diff --git a/espnet2/gan_tts/vits/flow.py b/espnet2/gan_tts/vits/flow.py index ef384df3802..e49ee8af33f 100644 --- a/espnet2/gan_tts/vits/flow.py +++ b/espnet2/gan_tts/vits/flow.py @@ -8,14 +8,12 @@ """ import math - -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Optional, Tuple, Union import torch -from espnet2.gan_tts.vits.transform import piecewise_rational_quadratic_transform +from espnet2.gan_tts.vits.transform import \ + piecewise_rational_quadratic_transform class FlipFlow(torch.nn.Module): diff --git a/espnet2/gan_tts/vits/generator.py b/espnet2/gan_tts/vits/generator.py index 4907dbd6162..1142ce1f5c4 100644 --- a/espnet2/gan_tts/vits/generator.py +++ b/espnet2/gan_tts/vits/generator.py @@ -8,22 +8,19 @@ """ import math - -from typing import List -from typing import Optional -from typing import Tuple +from typing import List, Optional, Tuple import numpy as np import torch import torch.nn.functional as F -from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask from espnet2.gan_tts.hifigan import HiFiGANGenerator from espnet2.gan_tts.utils import get_random_segments from espnet2.gan_tts.vits.duration_predictor import StochasticDurationPredictor from espnet2.gan_tts.vits.posterior_encoder import PosteriorEncoder from espnet2.gan_tts.vits.residual_coupling import ResidualAffineCouplingBlock from espnet2.gan_tts.vits.text_encoder import TextEncoder +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask class VITSGenerator(torch.nn.Module): diff --git a/espnet2/gan_tts/vits/monotonic_align/__init__.py b/espnet2/gan_tts/vits/monotonic_align/__init__.py index 59bbf12dba4..e7390c6ea99 100644 --- a/espnet2/gan_tts/vits/monotonic_align/__init__.py +++ b/espnet2/gan_tts/vits/monotonic_align/__init__.py @@ -8,9 +8,7 @@ import numpy as np import torch - -from numba import njit -from numba import prange +from numba import njit, prange try: from .core import maximum_path_c diff --git a/espnet2/gan_tts/vits/monotonic_align/setup.py b/espnet2/gan_tts/vits/monotonic_align/setup.py index 6df5c46d7f2..d044b2794ea 100644 --- a/espnet2/gan_tts/vits/monotonic_align/setup.py +++ b/espnet2/gan_tts/vits/monotonic_align/setup.py @@ -1,11 +1,8 @@ """Setup cython code.""" -from setuptools import Extension -from setuptools import setup - -from setuptools.command.build_ext import build_ext as _build_ext - from Cython.Build import cythonize +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext as _build_ext class build_ext(_build_ext): diff --git a/espnet2/gan_tts/vits/posterior_encoder.py b/espnet2/gan_tts/vits/posterior_encoder.py index 1ae3a8ca332..199b6586000 100644 --- a/espnet2/gan_tts/vits/posterior_encoder.py +++ b/espnet2/gan_tts/vits/posterior_encoder.py @@ -7,14 +7,13 @@ """ -from typing import Optional -from typing import Tuple +from typing import Optional, Tuple import torch -from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask -from espnet2.gan_tts.wavenet.residual_block import Conv1d from espnet2.gan_tts.wavenet import WaveNet +from espnet2.gan_tts.wavenet.residual_block import Conv1d +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask class 
PosteriorEncoder(torch.nn.Module): diff --git a/espnet2/gan_tts/vits/residual_coupling.py b/espnet2/gan_tts/vits/residual_coupling.py index e01bd2c85ac..0a222c8763b 100644 --- a/espnet2/gan_tts/vits/residual_coupling.py +++ b/espnet2/gan_tts/vits/residual_coupling.py @@ -7,9 +7,7 @@ """ -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Optional, Tuple, Union import torch diff --git a/espnet2/gan_tts/vits/text_encoder.py b/espnet2/gan_tts/vits/text_encoder.py index 6e529081d57..b268b2f5397 100644 --- a/espnet2/gan_tts/vits/text_encoder.py +++ b/espnet2/gan_tts/vits/text_encoder.py @@ -8,7 +8,6 @@ """ import math - from typing import Tuple import torch diff --git a/espnet2/gan_tts/vits/transform.py b/espnet2/gan_tts/vits/transform.py index aa7729839cf..9addfa73ae9 100644 --- a/espnet2/gan_tts/vits/transform.py +++ b/espnet2/gan_tts/vits/transform.py @@ -4,12 +4,10 @@ """ +import numpy as np import torch from torch.nn import functional as F -import numpy as np - - DEFAULT_MIN_BIN_WIDTH = 1e-3 DEFAULT_MIN_BIN_HEIGHT = 1e-3 DEFAULT_MIN_DERIVATIVE = 1e-3 diff --git a/espnet2/gan_tts/vits/vits.py b/espnet2/gan_tts/vits/vits.py index e08c486d3fe..d034d5d849e 100644 --- a/espnet2/gan_tts/vits/vits.py +++ b/espnet2/gan_tts/vits/vits.py @@ -5,24 +5,21 @@ from contextlib import contextmanager from distutils.version import LooseVersion -from typing import Any -from typing import Dict -from typing import Optional +from typing import Any, Dict, Optional import torch - from typeguard import check_argument_types from espnet2.gan_tts.abs_gan_tts import AbsGANTTS -from espnet2.gan_tts.hifigan import HiFiGANMultiPeriodDiscriminator -from espnet2.gan_tts.hifigan import HiFiGANMultiScaleDiscriminator -from espnet2.gan_tts.hifigan import HiFiGANMultiScaleMultiPeriodDiscriminator -from espnet2.gan_tts.hifigan import HiFiGANPeriodDiscriminator -from espnet2.gan_tts.hifigan import HiFiGANScaleDiscriminator -from espnet2.gan_tts.hifigan.loss import DiscriminatorAdversarialLoss -from espnet2.gan_tts.hifigan.loss import FeatureMatchLoss -from espnet2.gan_tts.hifigan.loss import GeneratorAdversarialLoss -from espnet2.gan_tts.hifigan.loss import MelSpectrogramLoss +from espnet2.gan_tts.hifigan import (HiFiGANMultiPeriodDiscriminator, + HiFiGANMultiScaleDiscriminator, + HiFiGANMultiScaleMultiPeriodDiscriminator, + HiFiGANPeriodDiscriminator, + HiFiGANScaleDiscriminator) +from espnet2.gan_tts.hifigan.loss import (DiscriminatorAdversarialLoss, + FeatureMatchLoss, + GeneratorAdversarialLoss, + MelSpectrogramLoss) from espnet2.gan_tts.utils import get_segments from espnet2.gan_tts.vits.generator import VITSGenerator from espnet2.gan_tts.vits.loss import KLDivergenceLoss diff --git a/espnet2/gan_tts/wavenet/residual_block.py b/espnet2/gan_tts/wavenet/residual_block.py index e568c7e7aa5..8385bacf60e 100644 --- a/espnet2/gan_tts/wavenet/residual_block.py +++ b/espnet2/gan_tts/wavenet/residual_block.py @@ -8,9 +8,7 @@ """ import math - -from typing import Optional -from typing import Tuple +from typing import Optional, Tuple import torch import torch.nn.functional as F diff --git a/espnet2/gan_tts/wavenet/wavenet.py b/espnet2/gan_tts/wavenet/wavenet.py index cd91cf47710..44533455c00 100644 --- a/espnet2/gan_tts/wavenet/wavenet.py +++ b/espnet2/gan_tts/wavenet/wavenet.py @@ -9,13 +9,11 @@ import logging import math - from typing import Optional import torch -from espnet2.gan_tts.wavenet.residual_block import Conv1d1x1 -from espnet2.gan_tts.wavenet.residual_block import 
ResidualBlock +from espnet2.gan_tts.wavenet.residual_block import Conv1d1x1, ResidualBlock class WaveNet(torch.nn.Module): diff --git a/espnet2/hubert/espnet_model.py b/espnet2/hubert/espnet_model.py index 35468bde93e..4b8b256b555 100644 --- a/espnet2/hubert/espnet_model.py +++ b/espnet2/hubert/espnet_model.py @@ -7,18 +7,12 @@ # Code in Fairseq: https://github.com/pytorch/fairseq/tree/master/examples/hubert from contextlib import contextmanager -from packaging.version import parse as V -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Dict, List, Optional, Tuple, Union import torch +from packaging.version import parse as V from typeguard import check_argument_types -from espnet.nets.e2e_asr_common import ErrorCalculator - from espnet2.asr.encoder.abs_encoder import AbsEncoder from espnet2.asr.frontend.abs_frontend import AbsFrontend from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder @@ -27,6 +21,7 @@ from espnet2.layers.abs_normalize import AbsNormalize from espnet2.torch_utils.device_funcs import force_gatherable from espnet2.train.abs_espnet_model import AbsESPnetModel +from espnet.nets.e2e_asr_common import ErrorCalculator if V(torch.__version__) >= V("1.6.0"): from torch.cuda.amp import autocast diff --git a/espnet2/hubert/hubert_loss.py b/espnet2/hubert/hubert_loss.py index af790177068..76c7c361bd0 100644 --- a/espnet2/hubert/hubert_loss.py +++ b/espnet2/hubert/hubert_loss.py @@ -11,8 +11,8 @@ """Hubert Pretrain Loss module.""" -from torch import nn import torch.nn.functional as F +from torch import nn class HubertPretrainLoss(nn.Module): diff --git a/espnet2/iterators/abs_iter_factory.py b/espnet2/iterators/abs_iter_factory.py index 36e4dd2c521..9f63a210a73 100644 --- a/espnet2/iterators/abs_iter_factory.py +++ b/espnet2/iterators/abs_iter_factory.py @@ -1,5 +1,4 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Iterator diff --git a/espnet2/iterators/chunk_iter_factory.py b/espnet2/iterators/chunk_iter_factory.py index 828710ab92f..7f5d82aa949 100644 --- a/espnet2/iterators/chunk_iter_factory.py +++ b/espnet2/iterators/chunk_iter_factory.py @@ -1,11 +1,5 @@ import logging -from typing import Any -from typing import Dict -from typing import Iterator -from typing import List -from typing import Sequence -from typing import Tuple -from typing import Union +from typing import Any, Dict, Iterator, List, Sequence, Tuple, Union import numpy as np import torch diff --git a/espnet2/iterators/multiple_iter_factory.py b/espnet2/iterators/multiple_iter_factory.py index 28e3d2dcb61..29f174df9b8 100644 --- a/espnet2/iterators/multiple_iter_factory.py +++ b/espnet2/iterators/multiple_iter_factory.py @@ -1,7 +1,5 @@ import logging -from typing import Callable -from typing import Collection -from typing import Iterator +from typing import Callable, Collection, Iterator import numpy as np from typeguard import check_argument_types diff --git a/espnet2/iterators/sequence_iter_factory.py b/espnet2/iterators/sequence_iter_factory.py index 48f61f8c7df..b80aee55345 100644 --- a/espnet2/iterators/sequence_iter_factory.py +++ b/espnet2/iterators/sequence_iter_factory.py @@ -1,6 +1,4 @@ -from typing import Any -from typing import Sequence -from typing import Union +from typing import Any, Sequence, Union import numpy as np from torch.utils.data import DataLoader diff --git a/espnet2/layers/abs_normalize.py 
b/espnet2/layers/abs_normalize.py index f2be748dd7c..c908f38f74c 100644 --- a/espnet2/layers/abs_normalize.py +++ b/espnet2/layers/abs_normalize.py @@ -1,5 +1,4 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Tuple import torch diff --git a/espnet2/layers/global_mvn.py b/espnet2/layers/global_mvn.py index 31635cb4feb..819bce53a46 100644 --- a/espnet2/layers/global_mvn.py +++ b/espnet2/layers/global_mvn.py @@ -1,14 +1,13 @@ from pathlib import Path -from typing import Tuple -from typing import Union +from typing import Tuple, Union import numpy as np import torch from typeguard import check_argument_types -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask from espnet2.layers.abs_normalize import AbsNormalize from espnet2.layers.inversible_interface import InversibleInterface +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask class GlobalMVN(AbsNormalize, InversibleInterface): diff --git a/espnet2/layers/inversible_interface.py b/espnet2/layers/inversible_interface.py index a1a59399aae..30874a87e8f 100644 --- a/espnet2/layers/inversible_interface.py +++ b/espnet2/layers/inversible_interface.py @@ -1,5 +1,4 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Tuple import torch diff --git a/espnet2/layers/label_aggregation.py b/espnet2/layers/label_aggregation.py index fbd845842e6..e5201515301 100644 --- a/espnet2/layers/label_aggregation.py +++ b/espnet2/layers/label_aggregation.py @@ -1,7 +1,7 @@ +from typing import Optional, Tuple + import torch from typeguard import check_argument_types -from typing import Optional -from typing import Tuple from espnet.nets.pytorch_backend.nets_utils import make_pad_mask diff --git a/espnet2/layers/log_mel.py b/espnet2/layers/log_mel.py index 5caeadbe31e..631c83d46c9 100644 --- a/espnet2/layers/log_mel.py +++ b/espnet2/layers/log_mel.py @@ -1,6 +1,7 @@ +from typing import Tuple + import librosa import torch -from typing import Tuple from espnet.nets.pytorch_backend.nets_utils import make_pad_mask diff --git a/espnet2/layers/mask_along_axis.py b/espnet2/layers/mask_along_axis.py index ecff6fa9659..96bd269113d 100644 --- a/espnet2/layers/mask_along_axis.py +++ b/espnet2/layers/mask_along_axis.py @@ -1,8 +1,8 @@ import math +from typing import Sequence, Union + import torch from typeguard import check_argument_types -from typing import Sequence -from typing import Union def mask_along_axis( diff --git a/espnet2/layers/sinc_conv.py b/espnet2/layers/sinc_conv.py index 33df97fbcdf..a31683474b4 100644 --- a/espnet2/layers/sinc_conv.py +++ b/espnet2/layers/sinc_conv.py @@ -4,9 +4,10 @@ """Sinc convolutions.""" import math +from typing import Union + import torch from typeguard import check_argument_types -from typing import Union class LogCompression(torch.nn.Module): diff --git a/espnet2/layers/stft.py b/espnet2/layers/stft.py index 847469bbd4a..9dee3ac681d 100644 --- a/espnet2/layers/stft.py +++ b/espnet2/layers/stft.py @@ -1,17 +1,15 @@ -from packaging.version import parse as V -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Optional, Tuple, Union +import librosa +import numpy as np import torch +from packaging.version import parse as V from torch_complex.tensor import ComplexTensor from typeguard import check_argument_types -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask from espnet2.enh.layers.complex_utils import is_complex from 
espnet2.layers.inversible_interface import InversibleInterface -import librosa -import numpy as np +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") diff --git a/espnet2/layers/utterance_mvn.py b/espnet2/layers/utterance_mvn.py index 4f1adb3e53b..b1d50b7aea6 100644 --- a/espnet2/layers/utterance_mvn.py +++ b/espnet2/layers/utterance_mvn.py @@ -3,8 +3,8 @@ import torch from typeguard import check_argument_types -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask from espnet2.layers.abs_normalize import AbsNormalize +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask class UtteranceMVN(AbsNormalize): diff --git a/espnet2/lm/abs_model.py b/espnet2/lm/abs_model.py index ba5773d0126..5c96c0ed19c 100644 --- a/espnet2/lm/abs_model.py +++ b/espnet2/lm/abs_model.py @@ -1,5 +1,4 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Tuple import torch diff --git a/espnet2/lm/espnet_model.py b/espnet2/lm/espnet_model.py index de6cd114a25..bbaecb8d8ee 100644 --- a/espnet2/lm/espnet_model.py +++ b/espnet2/lm/espnet_model.py @@ -1,15 +1,13 @@ -from typing import Dict -from typing import Optional -from typing import Tuple +from typing import Dict, Optional, Tuple import torch import torch.nn.functional as F from typeguard import check_argument_types -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask from espnet2.lm.abs_model import AbsLM from espnet2.torch_utils.device_funcs import force_gatherable from espnet2.train.abs_espnet_model import AbsESPnetModel +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask class ESPnetLanguageModel(AbsESPnetModel): diff --git a/espnet2/lm/seq_rnn_lm.py b/espnet2/lm/seq_rnn_lm.py index 9af85ed3cc7..05bca746efc 100644 --- a/espnet2/lm/seq_rnn_lm.py +++ b/espnet2/lm/seq_rnn_lm.py @@ -1,6 +1,5 @@ """Sequential implementation of Recurrent Neural Network Language Model.""" -from typing import Tuple -from typing import Union +from typing import Tuple, Union import torch import torch.nn as nn diff --git a/espnet2/lm/transformer_lm.py b/espnet2/lm/transformer_lm.py index 57df87bb11c..444c258ea59 100644 --- a/espnet2/lm/transformer_lm.py +++ b/espnet2/lm/transformer_lm.py @@ -1,14 +1,13 @@ -from typing import Any -from typing import List -from typing import Tuple +from typing import Any, List, Tuple import torch import torch.nn as nn -from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet2.lm.abs_model import AbsLM +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding from espnet.nets.pytorch_backend.transformer.encoder import Encoder from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask -from espnet2.lm.abs_model import AbsLM class TransformerLM(AbsLM): diff --git a/espnet2/main_funcs/average_nbest_models.py b/espnet2/main_funcs/average_nbest_models.py index 4c278e23823..e706456e8ea 100644 --- a/espnet2/main_funcs/average_nbest_models.py +++ b/espnet2/main_funcs/average_nbest_models.py @@ -1,13 +1,10 @@ import logging -from pathlib import Path -from typing import Optional -from typing import Sequence -from typing import Union import warnings +from pathlib import Path +from typing import Collection, Optional, Sequence, Union import torch from typeguard import check_argument_types -from typing import Collection from espnet2.train.reporter import Reporter diff --git 
a/espnet2/main_funcs/calculate_all_attentions.py b/espnet2/main_funcs/calculate_all_attentions.py index 52fe045779b..e238cf7af4e 100644 --- a/espnet2/main_funcs/calculate_all_attentions.py +++ b/espnet2/main_funcs/calculate_all_attentions.py @@ -1,27 +1,15 @@ from collections import defaultdict -from typing import Dict -from typing import List +from typing import Dict, List import torch -from espnet.nets.pytorch_backend.rnn.attentions import AttAdd -from espnet.nets.pytorch_backend.rnn.attentions import AttCov -from espnet.nets.pytorch_backend.rnn.attentions import AttCovLoc -from espnet.nets.pytorch_backend.rnn.attentions import AttDot -from espnet.nets.pytorch_backend.rnn.attentions import AttForward -from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA -from espnet.nets.pytorch_backend.rnn.attentions import AttLoc -from espnet.nets.pytorch_backend.rnn.attentions import AttLoc2D -from espnet.nets.pytorch_backend.rnn.attentions import AttLocRec -from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadAdd -from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadDot -from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadLoc -from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadMultiResLoc -from espnet.nets.pytorch_backend.rnn.attentions import NoAtt -from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention - - from espnet2.train.abs_espnet_model import AbsESPnetModel +from espnet.nets.pytorch_backend.rnn.attentions import ( + AttAdd, AttCov, AttCovLoc, AttDot, AttForward, AttForwardTA, AttLoc, + AttLoc2D, AttLocRec, AttMultiHeadAdd, AttMultiHeadDot, AttMultiHeadLoc, + AttMultiHeadMultiResLoc, NoAtt) +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention @torch.no_grad() diff --git a/espnet2/main_funcs/collect_stats.py b/espnet2/main_funcs/collect_stats.py index 297f7bfda7f..9edfcfac412 100644 --- a/espnet2/main_funcs/collect_stats.py +++ b/espnet2/main_funcs/collect_stats.py @@ -1,11 +1,7 @@ -from collections import defaultdict import logging +from collections import defaultdict from pathlib import Path -from typing import Dict -from typing import Iterable -from typing import List -from typing import Optional -from typing import Tuple +from typing import Dict, Iterable, List, Optional, Tuple import numpy as np import torch diff --git a/espnet2/main_funcs/pack_funcs.py b/espnet2/main_funcs/pack_funcs.py index ffa807e23b6..c13dde41e25 100644 --- a/espnet2/main_funcs/pack_funcs.py +++ b/espnet2/main_funcs/pack_funcs.py @@ -1,15 +1,11 @@ -from datetime import datetime -from io import BytesIO -from io import TextIOWrapper import os -from pathlib import Path import sys import tarfile -from typing import Dict -from typing import Iterable -from typing import Optional -from typing import Union import zipfile +from datetime import datetime +from io import BytesIO, TextIOWrapper +from pathlib import Path +from typing import Dict, Iterable, Optional, Union import yaml diff --git a/espnet2/mt/espnet_model.py b/espnet2/mt/espnet_model.py index 8a493366046..ca1ae36a9be 100644 --- a/espnet2/mt/espnet_model.py +++ b/espnet2/mt/espnet_model.py @@ -1,21 +1,11 @@ -from contextlib import contextmanager import logging -from packaging.version import parse as V -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from contextlib import contextmanager +from typing import Dict, List, Optional, Tuple, Union import 
torch +from packaging.version import parse as V from typeguard import check_argument_types -from espnet.nets.e2e_mt_common import ErrorCalculator as MTErrorCalculator -from espnet.nets.pytorch_backend.nets_utils import th_accuracy -from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos -from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( - LabelSmoothingLoss, # noqa: H301 -) from espnet2.asr.decoder.abs_decoder import AbsDecoder from espnet2.asr.encoder.abs_encoder import AbsEncoder from espnet2.asr.frontend.abs_frontend import AbsFrontend @@ -23,6 +13,11 @@ from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder from espnet2.torch_utils.device_funcs import force_gatherable from espnet2.train.abs_espnet_model import AbsESPnetModel +from espnet.nets.e2e_mt_common import ErrorCalculator as MTErrorCalculator +from espnet.nets.pytorch_backend.nets_utils import th_accuracy +from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos +from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import \ + LabelSmoothingLoss # noqa: H301 if V(torch.__version__) >= V("1.6.0"): from torch.cuda.amp import autocast diff --git a/espnet2/mt/frontend/embedding.py b/espnet2/mt/frontend/embedding.py index b9044c1385f..6b14df5182c 100644 --- a/espnet2/mt/frontend/embedding.py +++ b/espnet2/mt/frontend/embedding.py @@ -4,11 +4,14 @@ """Embedding Frontend for text based inputs.""" -from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding -from espnet2.asr.frontend.abs_frontend import AbsFrontend +from typing import Tuple + import torch from typeguard import check_argument_types -from typing import Tuple + +from espnet2.asr.frontend.abs_frontend import AbsFrontend +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding class Embedding(AbsFrontend): diff --git a/espnet2/samplers/abs_sampler.py b/espnet2/samplers/abs_sampler.py index 2f7aa539b8a..48e60e243f8 100644 --- a/espnet2/samplers/abs_sampler.py +++ b/espnet2/samplers/abs_sampler.py @@ -1,7 +1,5 @@ -from abc import ABC -from abc import abstractmethod -from typing import Iterator -from typing import Tuple +from abc import ABC, abstractmethod +from typing import Iterator, Tuple from torch.utils.data import Sampler diff --git a/espnet2/samplers/build_batch_sampler.py b/espnet2/samplers/build_batch_sampler.py index 0775dd962f7..e9a2b77502d 100644 --- a/espnet2/samplers/build_batch_sampler.py +++ b/espnet2/samplers/build_batch_sampler.py @@ -1,10 +1,6 @@ -from typing import List -from typing import Sequence -from typing import Tuple -from typing import Union +from typing import List, Sequence, Tuple, Union -from typeguard import check_argument_types -from typeguard import check_return_type +from typeguard import check_argument_types, check_return_type from espnet2.samplers.abs_sampler import AbsSampler from espnet2.samplers.folded_batch_sampler import FoldedBatchSampler @@ -13,7 +9,6 @@ from espnet2.samplers.sorted_batch_sampler import SortedBatchSampler from espnet2.samplers.unsorted_batch_sampler import UnsortedBatchSampler - BATCH_TYPES = dict( unsorted="UnsortedBatchSampler has nothing in particular feature and " "just creates mini-batches which has constant batch_size. 
" diff --git a/espnet2/samplers/folded_batch_sampler.py b/espnet2/samplers/folded_batch_sampler.py index 4d2e941e3d4..e1e85cd084e 100644 --- a/espnet2/samplers/folded_batch_sampler.py +++ b/espnet2/samplers/folded_batch_sampler.py @@ -1,13 +1,8 @@ -from typing import Iterator -from typing import List -from typing import Sequence -from typing import Tuple -from typing import Union +from typing import Iterator, List, Sequence, Tuple, Union from typeguard import check_argument_types -from espnet2.fileio.read_text import load_num_sequence_text -from espnet2.fileio.read_text import read_2column_text +from espnet2.fileio.read_text import load_num_sequence_text, read_2column_text from espnet2.samplers.abs_sampler import AbsSampler diff --git a/espnet2/samplers/length_batch_sampler.py b/espnet2/samplers/length_batch_sampler.py index 522a4b49e14..5e1cf6e3e6d 100644 --- a/espnet2/samplers/length_batch_sampler.py +++ b/espnet2/samplers/length_batch_sampler.py @@ -1,7 +1,4 @@ -from typing import Iterator -from typing import List -from typing import Tuple -from typing import Union +from typing import Iterator, List, Tuple, Union from typeguard import check_argument_types diff --git a/espnet2/samplers/num_elements_batch_sampler.py b/espnet2/samplers/num_elements_batch_sampler.py index 46ff177b8f3..31569e2e81f 100644 --- a/espnet2/samplers/num_elements_batch_sampler.py +++ b/espnet2/samplers/num_elements_batch_sampler.py @@ -1,7 +1,4 @@ -from typing import Iterator -from typing import List -from typing import Tuple -from typing import Union +from typing import Iterator, List, Tuple, Union import numpy as np from typeguard import check_argument_types diff --git a/espnet2/samplers/sorted_batch_sampler.py b/espnet2/samplers/sorted_batch_sampler.py index 4649f9a4fd7..be26aa56010 100644 --- a/espnet2/samplers/sorted_batch_sampler.py +++ b/espnet2/samplers/sorted_batch_sampler.py @@ -1,6 +1,5 @@ import logging -from typing import Iterator -from typing import Tuple +from typing import Iterator, Tuple from typeguard import check_argument_types diff --git a/espnet2/samplers/unsorted_batch_sampler.py b/espnet2/samplers/unsorted_batch_sampler.py index 33a22090ac2..32937977f46 100644 --- a/espnet2/samplers/unsorted_batch_sampler.py +++ b/espnet2/samplers/unsorted_batch_sampler.py @@ -1,6 +1,5 @@ import logging -from typing import Iterator -from typing import Tuple +from typing import Iterator, Tuple from typeguard import check_argument_types diff --git a/espnet2/schedulers/abs_scheduler.py b/espnet2/schedulers/abs_scheduler.py index 7395f259c3e..ea79767833a 100644 --- a/espnet2/schedulers/abs_scheduler.py +++ b/espnet2/schedulers/abs_scheduler.py @@ -1,5 +1,4 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod import torch.optim.lr_scheduler as L diff --git a/espnet2/schedulers/noam_lr.py b/espnet2/schedulers/noam_lr.py index 1c9aeb152da..c9402755fa5 100644 --- a/espnet2/schedulers/noam_lr.py +++ b/espnet2/schedulers/noam_lr.py @@ -1,6 +1,6 @@ """Noam learning rate scheduler module.""" -from typing import Union import warnings +from typing import Union import torch from torch.optim.lr_scheduler import _LRScheduler diff --git a/espnet2/st/espnet_model.py b/espnet2/st/espnet_model.py index 743b53d8288..45fd31d54e2 100644 --- a/espnet2/st/espnet_model.py +++ b/espnet2/st/espnet_model.py @@ -1,22 +1,11 @@ -from contextlib import contextmanager import logging -from packaging.version import parse as V -from typing import Dict -from typing import List -from typing import Optional 
-from typing import Tuple -from typing import Union +from contextlib import contextmanager +from typing import Dict, List, Optional, Tuple, Union import torch +from packaging.version import parse as V from typeguard import check_argument_types -from espnet.nets.e2e_asr_common import ErrorCalculator as ASRErrorCalculator -from espnet.nets.e2e_mt_common import ErrorCalculator as MTErrorCalculator -from espnet.nets.pytorch_backend.nets_utils import th_accuracy -from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos -from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( - LabelSmoothingLoss, # noqa: H301 -) from espnet2.asr.ctc import CTC from espnet2.asr.decoder.abs_decoder import AbsDecoder from espnet2.asr.encoder.abs_encoder import AbsEncoder @@ -27,6 +16,12 @@ from espnet2.layers.abs_normalize import AbsNormalize from espnet2.torch_utils.device_funcs import force_gatherable from espnet2.train.abs_espnet_model import AbsESPnetModel +from espnet.nets.e2e_asr_common import ErrorCalculator as ASRErrorCalculator +from espnet.nets.e2e_mt_common import ErrorCalculator as MTErrorCalculator +from espnet.nets.pytorch_backend.nets_utils import th_accuracy +from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos +from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import \ + LabelSmoothingLoss # noqa: H301 if V(torch.__version__) >= V("1.6.0"): from torch.cuda.amp import autocast diff --git a/espnet2/tasks/abs_task.py b/espnet2/tasks/abs_task.py index 0f23feaa93d..8fbedea302d 100644 --- a/espnet2/tasks/abs_task.py +++ b/espnet2/tasks/abs_task.py @@ -1,22 +1,13 @@ """Abstract task module.""" -from abc import ABC -from abc import abstractmethod import argparse -from dataclasses import dataclass import functools import logging import os -from packaging.version import parse as V -from pathlib import Path import sys -from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import humanfriendly import numpy as np @@ -24,21 +15,20 @@ import torch.multiprocessing import torch.nn import torch.optim -from torch.utils.data import DataLoader -from typeguard import check_argument_types -from typeguard import check_return_type import yaml +from packaging.version import parse as V +from torch.utils.data import DataLoader +from typeguard import check_argument_types, check_return_type from espnet import __version__ -from espnet.utils.cli_utils import get_commandline_args from espnet2.iterators.abs_iter_factory import AbsIterFactory from espnet2.iterators.chunk_iter_factory import ChunkIterFactory from espnet2.iterators.multiple_iter_factory import MultipleIterFactory from espnet2.iterators.sequence_iter_factory import SequenceIterFactory from espnet2.main_funcs.collect_stats import collect_stats from espnet2.optimizers.sgd import SGD -from espnet2.samplers.build_batch_sampler import BATCH_TYPES -from espnet2.samplers.build_batch_sampler import build_batch_sampler +from espnet2.samplers.build_batch_sampler import (BATCH_TYPES, + build_batch_sampler) from espnet2.samplers.unsorted_batch_sampler import UnsortedBatchSampler from espnet2.schedulers.noam_lr import NoamLR from espnet2.schedulers.warmup_lr import WarmupLR 
@@ -48,28 +38,22 @@ from espnet2.torch_utils.set_all_random_seed import set_all_random_seed from espnet2.train.abs_espnet_model import AbsESPnetModel from espnet2.train.class_choices import ClassChoices -from espnet2.train.dataset import AbsDataset -from espnet2.train.dataset import DATA_TYPES -from espnet2.train.dataset import ESPnetDataset -from espnet2.train.distributed_utils import DistributedOption -from espnet2.train.distributed_utils import free_port -from espnet2.train.distributed_utils import get_master_port -from espnet2.train.distributed_utils import get_node_rank -from espnet2.train.distributed_utils import get_num_nodes -from espnet2.train.distributed_utils import resolve_distributed_mode +from espnet2.train.dataset import DATA_TYPES, AbsDataset, ESPnetDataset +from espnet2.train.distributed_utils import (DistributedOption, free_port, + get_master_port, get_node_rank, + get_num_nodes, + resolve_distributed_mode) from espnet2.train.iterable_dataset import IterableESPnetDataset from espnet2.train.trainer import Trainer -from espnet2.utils.build_dataclass import build_dataclass from espnet2.utils import config_argparse +from espnet2.utils.build_dataclass import build_dataclass from espnet2.utils.get_default_kwargs import get_default_kwargs from espnet2.utils.nested_dict_action import NestedDictAction -from espnet2.utils.types import humanfriendly_parse_size_or_none -from espnet2.utils.types import int_or_none -from espnet2.utils.types import str2bool -from espnet2.utils.types import str2triple_str -from espnet2.utils.types import str_or_int -from espnet2.utils.types import str_or_none +from espnet2.utils.types import (humanfriendly_parse_size_or_none, int_or_none, + str2bool, str2triple_str, str_or_int, + str_or_none) from espnet2.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump +from espnet.utils.cli_utils import get_commandline_args try: import wandb diff --git a/espnet2/tasks/asr.py b/espnet2/tasks/asr.py index 9ab3c9ca7fd..c13fd75efb9 100644 --- a/espnet2/tasks/asr.py +++ b/espnet2/tasks/asr.py @@ -1,46 +1,34 @@ import argparse import logging -from typing import Callable -from typing import Collection -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple +from typing import Callable, Collection, Dict, List, Optional, Tuple import numpy as np import torch -from typeguard import check_argument_types -from typeguard import check_return_type +from typeguard import check_argument_types, check_return_type from espnet2.asr.ctc import CTC from espnet2.asr.decoder.abs_decoder import AbsDecoder from espnet2.asr.decoder.mlm_decoder import MLMDecoder from espnet2.asr.decoder.rnn_decoder import RNNDecoder +from espnet2.asr.decoder.transformer_decoder import \ + DynamicConvolution2DTransformerDecoder # noqa: H301 +from espnet2.asr.decoder.transformer_decoder import \ + LightweightConvolution2DTransformerDecoder # noqa: H301 +from espnet2.asr.decoder.transformer_decoder import \ + LightweightConvolutionTransformerDecoder # noqa: H301 from espnet2.asr.decoder.transformer_decoder import ( - DynamicConvolution2DTransformerDecoder, # noqa: H301 -) -from espnet2.asr.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder -from espnet2.asr.decoder.transformer_decoder import ( - LightweightConvolution2DTransformerDecoder, # noqa: H301 -) -from espnet2.asr.decoder.transformer_decoder import ( - LightweightConvolutionTransformerDecoder, # noqa: H301 -) -from espnet2.asr.decoder.transformer_decoder import 
TransformerDecoder + DynamicConvolutionTransformerDecoder, TransformerDecoder) from espnet2.asr.encoder.abs_encoder import AbsEncoder from espnet2.asr.encoder.conformer_encoder import ConformerEncoder +from espnet2.asr.encoder.contextual_block_conformer_encoder import \ + ContextualBlockConformerEncoder # noqa: H301 +from espnet2.asr.encoder.contextual_block_transformer_encoder import \ + ContextualBlockTransformerEncoder # noqa: H301 +from espnet2.asr.encoder.hubert_encoder import (FairseqHubertEncoder, + FairseqHubertPretrainEncoder) from espnet2.asr.encoder.longformer_encoder import LongformerEncoder - -from espnet2.asr.encoder.hubert_encoder import FairseqHubertEncoder -from espnet2.asr.encoder.hubert_encoder import FairseqHubertPretrainEncoder from espnet2.asr.encoder.rnn_encoder import RNNEncoder from espnet2.asr.encoder.transformer_encoder import TransformerEncoder -from espnet2.asr.encoder.contextual_block_transformer_encoder import ( - ContextualBlockTransformerEncoder, # noqa: H301 -) -from espnet2.asr.encoder.contextual_block_conformer_encoder import ( - ContextualBlockConformerEncoder, # noqa: H301 -) from espnet2.asr.encoder.vgg_rnn_encoder import VGGRNNEncoder from espnet2.asr.encoder.wav2vec2_encoder import FairSeqWav2Vec2Encoder from espnet2.asr.espnet_model import ESPnetASRModel @@ -51,9 +39,8 @@ from espnet2.asr.frontend.windowing import SlidingWindow from espnet2.asr.maskctc_model import MaskCTCModel from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder -from espnet2.asr.postencoder.hugging_face_transformers_postencoder import ( - HuggingFaceTransformersPostEncoder, # noqa: H301 -) +from espnet2.asr.postencoder.hugging_face_transformers_postencoder import \ + HuggingFaceTransformersPostEncoder # noqa: H301 from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder from espnet2.asr.preencoder.linear import LinearProjection from espnet2.asr.preencoder.sinc import LightweightSincConvs @@ -74,10 +61,8 @@ from espnet2.train.trainer import Trainer from espnet2.utils.get_default_kwargs import get_default_kwargs from espnet2.utils.nested_dict_action import NestedDictAction -from espnet2.utils.types import float_or_none -from espnet2.utils.types import int_or_none -from espnet2.utils.types import str2bool -from espnet2.utils.types import str_or_none +from espnet2.utils.types import (float_or_none, int_or_none, str2bool, + str_or_none) frontend_choices = ClassChoices( name="frontend", diff --git a/espnet2/tasks/diar.py b/espnet2/tasks/diar.py index e01a59532a0..b86aef53de0 100644 --- a/espnet2/tasks/diar.py +++ b/espnet2/tasks/diar.py @@ -1,15 +1,9 @@ import argparse -from typing import Callable -from typing import Collection -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple +from typing import Callable, Collection, Dict, List, Optional, Tuple import numpy as np import torch -from typeguard import check_argument_types -from typeguard import check_return_type +from typeguard import check_argument_types, check_return_type from espnet2.asr.encoder.abs_encoder import AbsEncoder from espnet2.asr.encoder.conformer_encoder import ConformerEncoder @@ -38,9 +32,7 @@ from espnet2.train.trainer import Trainer from espnet2.utils.get_default_kwargs import get_default_kwargs from espnet2.utils.nested_dict_action import NestedDictAction -from espnet2.utils.types import int_or_none -from espnet2.utils.types import str2bool -from espnet2.utils.types import str_or_none +from espnet2.utils.types import int_or_none, str2bool, 
str_or_none frontend_choices = ClassChoices( name="frontend", diff --git a/espnet2/tasks/enh.py b/espnet2/tasks/enh.py index fd0742359da..49d04143695 100644 --- a/espnet2/tasks/enh.py +++ b/espnet2/tasks/enh.py @@ -1,15 +1,9 @@ import argparse -from typing import Callable -from typing import Collection -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple +from typing import Callable, Collection, Dict, List, Optional, Tuple import numpy as np import torch -from typeguard import check_argument_types -from typeguard import check_return_type +from typeguard import check_argument_types, check_return_type from espnet2.enh.decoder.abs_decoder import AbsDecoder from espnet2.enh.decoder.conv_decoder import ConvDecoder @@ -21,16 +15,14 @@ from espnet2.enh.encoder.stft_encoder import STFTEncoder from espnet2.enh.espnet_model import ESPnetEnhancementModel from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss -from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainAbsCoherence -from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainDPCL -from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainL1 -from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainMSE -from espnet2.enh.loss.criterions.time_domain import CISDRLoss -from espnet2.enh.loss.criterions.time_domain import SDRLoss -from espnet2.enh.loss.criterions.time_domain import SISNRLoss -from espnet2.enh.loss.criterions.time_domain import SNRLoss -from espnet2.enh.loss.criterions.time_domain import TimeDomainL1 -from espnet2.enh.loss.criterions.time_domain import TimeDomainMSE +from espnet2.enh.loss.criterions.tf_domain import (FrequencyDomainAbsCoherence, + FrequencyDomainDPCL, + FrequencyDomainL1, + FrequencyDomainMSE) +from espnet2.enh.loss.criterions.time_domain import (CISDRLoss, SDRLoss, + SISNRLoss, SNRLoss, + TimeDomainL1, + TimeDomainMSE) from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper from espnet2.enh.loss.wrappers.dpcl_solver import DPCLSolver from espnet2.enh.loss.wrappers.fixed_order import FixedOrderSolver @@ -59,8 +51,7 @@ from espnet2.train.trainer import Trainer from espnet2.utils.get_default_kwargs import get_default_kwargs from espnet2.utils.nested_dict_action import NestedDictAction -from espnet2.utils.types import str2bool -from espnet2.utils.types import str_or_none +from espnet2.utils.types import str2bool, str_or_none encoder_choices = ClassChoices( name="encoder", diff --git a/espnet2/tasks/enh_s2t.py b/espnet2/tasks/enh_s2t.py index d6a20bac700..561411de186 100644 --- a/espnet2/tasks/enh_s2t.py +++ b/espnet2/tasks/enh_s2t.py @@ -1,17 +1,11 @@ import argparse import copy import logging -from typing import Callable -from typing import Collection -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple +from typing import Callable, Collection, Dict, List, Optional, Tuple import numpy as np import torch -from typeguard import check_argument_types -from typeguard import check_return_type +from typeguard import check_argument_types, check_return_type from espnet2.asr.ctc import CTC from espnet2.asr.espnet_model import ESPnetASRModel @@ -21,34 +15,32 @@ from espnet2.tasks.asr import ASRTask from espnet2.tasks.asr import decoder_choices as asr_decoder_choices_ from espnet2.tasks.asr import encoder_choices as asr_encoder_choices_ -from espnet2.tasks.asr import frontend_choices -from espnet2.tasks.asr import normalize_choices +from espnet2.tasks.asr import frontend_choices, 
normalize_choices from espnet2.tasks.asr import postencoder_choices as asr_postencoder_choices_ from espnet2.tasks.asr import preencoder_choices as asr_preencoder_choices_ from espnet2.tasks.asr import specaug_choices +from espnet2.tasks.enh import EnhancementTask from espnet2.tasks.enh import decoder_choices as enh_decoder_choices_ from espnet2.tasks.enh import encoder_choices as enh_encoder_choices_ -from espnet2.tasks.enh import EnhancementTask from espnet2.tasks.enh import separator_choices as enh_separator_choices_ +from espnet2.tasks.st import STTask from espnet2.tasks.st import decoder_choices as st_decoder_choices_ from espnet2.tasks.st import encoder_choices as st_encoder_choices_ -from espnet2.tasks.st import extra_asr_decoder_choices as st_extra_asr_decoder_choices_ -from espnet2.tasks.st import extra_mt_decoder_choices as st_extra_mt_decoder_choices_ +from espnet2.tasks.st import \ + extra_asr_decoder_choices as st_extra_asr_decoder_choices_ +from espnet2.tasks.st import \ + extra_mt_decoder_choices as st_extra_mt_decoder_choices_ from espnet2.tasks.st import postencoder_choices as st_postencoder_choices_ from espnet2.tasks.st import preencoder_choices as st_preencoder_choices_ -from espnet2.tasks.st import STTask from espnet2.text.phoneme_tokenizer import g2p_choices from espnet2.torch_utils.initialize import initialize from espnet2.train.collate_fn import CommonCollateFn -from espnet2.train.preprocessor import CommonPreprocessor_multi -from espnet2.train.preprocessor import MutliTokenizerCommonPreprocessor +from espnet2.train.preprocessor import (CommonPreprocessor_multi, + MutliTokenizerCommonPreprocessor) from espnet2.train.trainer import Trainer from espnet2.utils.get_default_kwargs import get_default_kwargs from espnet2.utils.nested_dict_action import NestedDictAction -from espnet2.utils.types import int_or_none -from espnet2.utils.types import str2bool -from espnet2.utils.types import str_or_none - +from espnet2.utils.types import int_or_none, str2bool, str_or_none # Enhancement enh_encoder_choices = copy.deepcopy(enh_encoder_choices_) diff --git a/espnet2/tasks/gan_tts.py b/espnet2/tasks/gan_tts.py index d1fdcb6f900..1a139218b8a 100644 --- a/espnet2/tasks/gan_tts.py +++ b/espnet2/tasks/gan_tts.py @@ -5,19 +5,11 @@ import argparse import logging - -from typing import Callable -from typing import Collection -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple +from typing import Callable, Collection, Dict, List, Optional, Tuple import numpy as np import torch - -from typeguard import check_argument_types -from typeguard import check_return_type +from typeguard import check_argument_types, check_return_type from espnet2.gan_tts.abs_gan_tts import AbsGANTTS from espnet2.gan_tts.espnet_model import ESPnetGANTTSModel @@ -27,8 +19,7 @@ from espnet2.layers.abs_normalize import AbsNormalize from espnet2.layers.global_mvn import GlobalMVN from espnet2.layers.utterance_mvn import UtteranceMVN -from espnet2.tasks.abs_task import AbsTask -from espnet2.tasks.abs_task import optim_classes +from espnet2.tasks.abs_task import AbsTask, optim_classes from espnet2.text.phoneme_tokenizer import g2p_choices from espnet2.train.class_choices import ClassChoices from espnet2.train.collate_fn import CommonCollateFn @@ -42,9 +33,7 @@ from espnet2.tts.feats_extract.log_spectrogram import LogSpectrogram from espnet2.utils.get_default_kwargs import get_default_kwargs from espnet2.utils.nested_dict_action import NestedDictAction -from 
espnet2.utils.types import int_or_none -from espnet2.utils.types import str2bool -from espnet2.utils.types import str_or_none +from espnet2.utils.types import int_or_none, str2bool, str_or_none feats_extractor_choices = ClassChoices( "feats_extract", diff --git a/espnet2/tasks/hubert.py b/espnet2/tasks/hubert.py index 2c4fc9634d2..40d3eb69e80 100644 --- a/espnet2/tasks/hubert.py +++ b/espnet2/tasks/hubert.py @@ -7,22 +7,15 @@ # Code in Fairseq: https://github.com/pytorch/fairseq/tree/master/examples/hubert import argparse import logging -from typing import Callable -from typing import Collection -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple +from typing import Callable, Collection, Dict, List, Optional, Tuple import numpy as np import torch -from typeguard import check_argument_types -from typeguard import check_return_type +from typeguard import check_argument_types, check_return_type from espnet2.asr.encoder.abs_encoder import AbsEncoder -from espnet2.asr.encoder.hubert_encoder import ( - FairseqHubertPretrainEncoder, # noqa: H301 -) +from espnet2.asr.encoder.hubert_encoder import \ + FairseqHubertPretrainEncoder # noqa: H301 from espnet2.asr.frontend.abs_frontend import AbsFrontend from espnet2.asr.frontend.default import DefaultFrontend from espnet2.asr.frontend.windowing import SlidingWindow @@ -43,10 +36,8 @@ from espnet2.train.trainer import Trainer from espnet2.utils.get_default_kwargs import get_default_kwargs from espnet2.utils.nested_dict_action import NestedDictAction -from espnet2.utils.types import float_or_none -from espnet2.utils.types import int_or_none -from espnet2.utils.types import str2bool -from espnet2.utils.types import str_or_none +from espnet2.utils.types import (float_or_none, int_or_none, str2bool, + str_or_none) frontend_choices = ClassChoices( name="frontend", diff --git a/espnet2/tasks/lm.py b/espnet2/tasks/lm.py index eea17464ca5..bbb9847bed2 100644 --- a/espnet2/tasks/lm.py +++ b/espnet2/tasks/lm.py @@ -1,16 +1,10 @@ import argparse import logging -from typing import Callable -from typing import Collection -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple +from typing import Callable, Collection, Dict, List, Optional, Tuple import numpy as np import torch -from typeguard import check_argument_types -from typeguard import check_return_type +from typeguard import check_argument_types, check_return_type from espnet2.lm.abs_model import AbsLM from espnet2.lm.espnet_model import ESPnetLanguageModel @@ -25,9 +19,7 @@ from espnet2.train.trainer import Trainer from espnet2.utils.get_default_kwargs import get_default_kwargs from espnet2.utils.nested_dict_action import NestedDictAction -from espnet2.utils.types import str2bool -from espnet2.utils.types import str_or_none - +from espnet2.utils.types import str2bool, str_or_none lm_choices = ClassChoices( "lm", diff --git a/espnet2/tasks/mt.py b/espnet2/tasks/mt.py index 496b48b96e7..a0ada0c3d93 100644 --- a/espnet2/tasks/mt.py +++ b/espnet2/tasks/mt.py @@ -1,43 +1,32 @@ import argparse import logging -from typing import Callable -from typing import Collection -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple +from typing import Callable, Collection, Dict, List, Optional, Tuple import numpy as np import torch -from typeguard import check_argument_types -from typeguard import check_return_type +from typeguard import check_argument_types, 
check_return_type from espnet2.asr.decoder.abs_decoder import AbsDecoder from espnet2.asr.decoder.rnn_decoder import RNNDecoder +from espnet2.asr.decoder.transformer_decoder import \ + DynamicConvolution2DTransformerDecoder # noqa: H301 +from espnet2.asr.decoder.transformer_decoder import \ + LightweightConvolution2DTransformerDecoder # noqa: H301 +from espnet2.asr.decoder.transformer_decoder import \ + LightweightConvolutionTransformerDecoder # noqa: H301 from espnet2.asr.decoder.transformer_decoder import ( - DynamicConvolution2DTransformerDecoder, # noqa: H301 -) -from espnet2.asr.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder -from espnet2.asr.decoder.transformer_decoder import ( - LightweightConvolution2DTransformerDecoder, # noqa: H301 -) -from espnet2.asr.decoder.transformer_decoder import ( - LightweightConvolutionTransformerDecoder, # noqa: H301 -) -from espnet2.asr.decoder.transformer_decoder import TransformerDecoder + DynamicConvolutionTransformerDecoder, TransformerDecoder) from espnet2.asr.encoder.abs_encoder import AbsEncoder from espnet2.asr.encoder.conformer_encoder import ConformerEncoder +from espnet2.asr.encoder.contextual_block_transformer_encoder import \ + ContextualBlockTransformerEncoder # noqa: H301 from espnet2.asr.encoder.rnn_encoder import RNNEncoder from espnet2.asr.encoder.transformer_encoder import TransformerEncoder -from espnet2.asr.encoder.contextual_block_transformer_encoder import ( - ContextualBlockTransformerEncoder, # noqa: H301 -) from espnet2.asr.encoder.vgg_rnn_encoder import VGGRNNEncoder from espnet2.asr.frontend.abs_frontend import AbsFrontend from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder -from espnet2.asr.postencoder.hugging_face_transformers_postencoder import ( - HuggingFaceTransformersPostEncoder, # noqa: H301 -) +from espnet2.asr.postencoder.hugging_face_transformers_postencoder import \ + HuggingFaceTransformersPostEncoder # noqa: H301 from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder from espnet2.asr.preencoder.linear import LinearProjection from espnet2.asr.preencoder.sinc import LightweightSincConvs @@ -52,9 +41,7 @@ from espnet2.train.trainer import Trainer from espnet2.utils.get_default_kwargs import get_default_kwargs from espnet2.utils.nested_dict_action import NestedDictAction -from espnet2.utils.types import int_or_none -from espnet2.utils.types import str2bool -from espnet2.utils.types import str_or_none +from espnet2.utils.types import int_or_none, str2bool, str_or_none frontend_choices = ClassChoices( name="frontend", diff --git a/espnet2/tasks/st.py b/espnet2/tasks/st.py index 2b992f0be4e..0f6619c8842 100644 --- a/espnet2/tasks/st.py +++ b/espnet2/tasks/st.py @@ -1,40 +1,30 @@ import argparse import logging -from typing import Callable -from typing import Collection -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple +from typing import Callable, Collection, Dict, List, Optional, Tuple import numpy as np import torch -from typeguard import check_argument_types -from typeguard import check_return_type +from typeguard import check_argument_types, check_return_type from espnet2.asr.ctc import CTC from espnet2.asr.decoder.abs_decoder import AbsDecoder from espnet2.asr.decoder.rnn_decoder import RNNDecoder +from espnet2.asr.decoder.transformer_decoder import \ + DynamicConvolution2DTransformerDecoder # noqa: H301 +from espnet2.asr.decoder.transformer_decoder import \ + LightweightConvolution2DTransformerDecoder # noqa: 
H301 +from espnet2.asr.decoder.transformer_decoder import \ + LightweightConvolutionTransformerDecoder # noqa: H301 from espnet2.asr.decoder.transformer_decoder import ( - DynamicConvolution2DTransformerDecoder, # noqa: H301 -) -from espnet2.asr.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder -from espnet2.asr.decoder.transformer_decoder import ( - LightweightConvolution2DTransformerDecoder, # noqa: H301 -) -from espnet2.asr.decoder.transformer_decoder import ( - LightweightConvolutionTransformerDecoder, # noqa: H301 -) -from espnet2.asr.decoder.transformer_decoder import TransformerDecoder + DynamicConvolutionTransformerDecoder, TransformerDecoder) from espnet2.asr.encoder.abs_encoder import AbsEncoder from espnet2.asr.encoder.conformer_encoder import ConformerEncoder -from espnet2.asr.encoder.hubert_encoder import FairseqHubertEncoder -from espnet2.asr.encoder.hubert_encoder import FairseqHubertPretrainEncoder +from espnet2.asr.encoder.contextual_block_transformer_encoder import \ + ContextualBlockTransformerEncoder # noqa: H301 +from espnet2.asr.encoder.hubert_encoder import (FairseqHubertEncoder, + FairseqHubertPretrainEncoder) from espnet2.asr.encoder.rnn_encoder import RNNEncoder from espnet2.asr.encoder.transformer_encoder import TransformerEncoder -from espnet2.asr.encoder.contextual_block_transformer_encoder import ( - ContextualBlockTransformerEncoder, # noqa: H301 -) from espnet2.asr.encoder.vgg_rnn_encoder import VGGRNNEncoder from espnet2.asr.encoder.wav2vec2_encoder import FairSeqWav2Vec2Encoder from espnet2.asr.frontend.abs_frontend import AbsFrontend @@ -42,9 +32,8 @@ from espnet2.asr.frontend.s3prl import S3prlFrontend from espnet2.asr.frontend.windowing import SlidingWindow from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder -from espnet2.asr.postencoder.hugging_face_transformers_postencoder import ( - HuggingFaceTransformersPostEncoder, # noqa: H301 -) +from espnet2.asr.postencoder.hugging_face_transformers_postencoder import \ + HuggingFaceTransformersPostEncoder # noqa: H301 from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder from espnet2.asr.preencoder.linear import LinearProjection from espnet2.asr.preencoder.sinc import LightweightSincConvs @@ -63,10 +52,8 @@ from espnet2.train.trainer import Trainer from espnet2.utils.get_default_kwargs import get_default_kwargs from espnet2.utils.nested_dict_action import NestedDictAction -from espnet2.utils.types import float_or_none -from espnet2.utils.types import int_or_none -from espnet2.utils.types import str2bool -from espnet2.utils.types import str_or_none +from espnet2.utils.types import (float_or_none, int_or_none, str2bool, + str_or_none) frontend_choices = ClassChoices( name="frontend", diff --git a/espnet2/tasks/tts.py b/espnet2/tasks/tts.py index b3e51fa0bed..36a73072f93 100644 --- a/espnet2/tasks/tts.py +++ b/espnet2/tasks/tts.py @@ -2,22 +2,13 @@ import argparse import logging -import yaml - from pathlib import Path -from typing import Callable -from typing import Collection -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Callable, Collection, Dict, List, Optional, Tuple, Union import numpy as np import torch - -from typeguard import check_argument_types -from typeguard import check_return_type +import yaml +from typeguard import check_argument_types, check_return_type from espnet2.gan_tts.jets import JETS from espnet2.gan_tts.joint import JointText2Wav @@ -46,9 
+37,7 @@ from espnet2.utils.get_default_kwargs import get_default_kwargs from espnet2.utils.griffin_lim import Spectrogram2Waveform from espnet2.utils.nested_dict_action import NestedDictAction -from espnet2.utils.types import int_or_none -from espnet2.utils.types import str2bool -from espnet2.utils.types import str_or_none +from espnet2.utils.types import int_or_none, str2bool, str_or_none feats_extractor_choices = ClassChoices( "feats_extract", diff --git a/espnet2/text/abs_tokenizer.py b/espnet2/text/abs_tokenizer.py index fc2ccb3c369..21d727d6153 100644 --- a/espnet2/text/abs_tokenizer.py +++ b/espnet2/text/abs_tokenizer.py @@ -1,7 +1,5 @@ -from abc import ABC -from abc import abstractmethod -from typing import Iterable -from typing import List +from abc import ABC, abstractmethod +from typing import Iterable, List class AbsTokenizer(ABC): diff --git a/espnet2/text/build_tokenizer.py b/espnet2/text/build_tokenizer.py index 70c2b868b17..a69375b04c4 100644 --- a/espnet2/text/build_tokenizer.py +++ b/espnet2/text/build_tokenizer.py @@ -1,6 +1,5 @@ from pathlib import Path -from typing import Iterable -from typing import Union +from typing import Iterable, Union from typeguard import check_argument_types diff --git a/espnet2/text/char_tokenizer.py b/espnet2/text/char_tokenizer.py index 765f124cf20..2922b97afaa 100644 --- a/espnet2/text/char_tokenizer.py +++ b/espnet2/text/char_tokenizer.py @@ -1,8 +1,6 @@ -from pathlib import Path -from typing import Iterable -from typing import List -from typing import Union import warnings +from pathlib import Path +from typing import Iterable, List, Union from typeguard import check_argument_types diff --git a/espnet2/text/cleaner.py b/espnet2/text/cleaner.py index 687ff6afd9c..5743a1993e6 100644 --- a/espnet2/text/cleaner.py +++ b/espnet2/text/cleaner.py @@ -1,7 +1,7 @@ from typing import Collection -from jaconv import jaconv import tacotron_cleaner.cleaners +from jaconv import jaconv from typeguard import check_argument_types try: diff --git a/espnet2/text/phoneme_tokenizer.py b/espnet2/text/phoneme_tokenizer.py index 570e165f521..dd3843aaa5b 100644 --- a/espnet2/text/phoneme_tokenizer.py +++ b/espnet2/text/phoneme_tokenizer.py @@ -1,11 +1,8 @@ import logging -from pathlib import Path import re -from typing import Iterable -from typing import List -from typing import Optional -from typing import Union import warnings +from pathlib import Path +from typing import Iterable, List, Optional, Union import g2p_en import jamo @@ -13,7 +10,6 @@ from espnet2.text.abs_tokenizer import AbsTokenizer - g2p_choices = [ None, "g2p_en", @@ -61,9 +57,10 @@ def pyopenjtalk_g2p(text) -> List[str]: def pyopenjtalk_g2p_accent(text) -> List[str]: - import pyopenjtalk import re + import pyopenjtalk + phones = [] for labels in pyopenjtalk.run_frontend(text)[1]: p = re.findall(r"\-(.*?)\+.*?\/A:([0-9\-]+).*?\/F:.*?_([0-9]+)", labels) @@ -73,9 +70,10 @@ def pyopenjtalk_g2p_accent(text) -> List[str]: def pyopenjtalk_g2p_accent_with_pause(text) -> List[str]: - import pyopenjtalk import re + import pyopenjtalk + phones = [] for labels in pyopenjtalk.run_frontend(text)[1]: if labels.split("-")[1].split("+")[0] == "pau": @@ -181,18 +179,15 @@ def _numeric_feature_by_regex(regex, s): def pypinyin_g2p(text) -> List[str]: - from pypinyin import pinyin - from pypinyin import Style + from pypinyin import Style, pinyin phones = [phone[0] for phone in pinyin(text, style=Style.TONE3)] return phones def pypinyin_g2p_phone(text) -> List[str]: - from pypinyin import pinyin - from pypinyin 
import Style - from pypinyin.style._utils import get_finals - from pypinyin.style._utils import get_initials + from pypinyin import Style, pinyin + from pypinyin.style._utils import get_finals, get_initials phones = [ p diff --git a/espnet2/text/sentencepiece_tokenizer.py b/espnet2/text/sentencepiece_tokenizer.py index 0db7110760c..5fcc2fe4cf2 100644 --- a/espnet2/text/sentencepiece_tokenizer.py +++ b/espnet2/text/sentencepiece_tokenizer.py @@ -1,7 +1,5 @@ from pathlib import Path -from typing import Iterable -from typing import List -from typing import Union +from typing import Iterable, List, Union import sentencepiece as spm from typeguard import check_argument_types diff --git a/espnet2/text/token_id_converter.py b/espnet2/text/token_id_converter.py index c9a6b28638b..96bab4874f2 100644 --- a/espnet2/text/token_id_converter.py +++ b/espnet2/text/token_id_converter.py @@ -1,8 +1,5 @@ from pathlib import Path -from typing import Dict -from typing import Iterable -from typing import List -from typing import Union +from typing import Dict, Iterable, List, Union import numpy as np from typeguard import check_argument_types diff --git a/espnet2/text/word_tokenizer.py b/espnet2/text/word_tokenizer.py index 2788bc03e65..30873ef7297 100644 --- a/espnet2/text/word_tokenizer.py +++ b/espnet2/text/word_tokenizer.py @@ -1,8 +1,6 @@ -from pathlib import Path -from typing import Iterable -from typing import List -from typing import Union import warnings +from pathlib import Path +from typing import Iterable, List, Union from typeguard import check_argument_types diff --git a/espnet2/torch_utils/initialize.py b/espnet2/torch_utils/initialize.py index 2c0e7a43579..038c7cfa4a7 100644 --- a/espnet2/torch_utils/initialize.py +++ b/espnet2/torch_utils/initialize.py @@ -3,6 +3,7 @@ """Initialize modules for espnet2 neural networks.""" import math + import torch from typeguard import check_argument_types diff --git a/espnet2/torch_utils/load_pretrained_model.py b/espnet2/torch_utils/load_pretrained_model.py index 49c7bc6b558..4c7573b6be9 100644 --- a/espnet2/torch_utils/load_pretrained_model.py +++ b/espnet2/torch_utils/load_pretrained_model.py @@ -1,8 +1,6 @@ -from typing import Any -from typing import Dict -from typing import Union - import logging +from typing import Any, Dict, Union + import torch import torch.nn import torch.optim diff --git a/espnet2/train/abs_espnet_model.py b/espnet2/train/abs_espnet_model.py index 6fd50603680..9a9a74348c1 100644 --- a/espnet2/train/abs_espnet_model.py +++ b/espnet2/train/abs_espnet_model.py @@ -1,7 +1,5 @@ -from abc import ABC -from abc import abstractmethod -from typing import Dict -from typing import Tuple +from abc import ABC, abstractmethod +from typing import Dict, Tuple import torch diff --git a/espnet2/train/abs_gan_espnet_model.py b/espnet2/train/abs_gan_espnet_model.py index 6e78ecfdca4..323abb85410 100644 --- a/espnet2/train/abs_gan_espnet_model.py +++ b/espnet2/train/abs_gan_espnet_model.py @@ -3,10 +3,8 @@ """ESPnetModel abstract class for GAN-based training.""" -from abc import ABC -from abc import abstractmethod -from typing import Dict -from typing import Union +from abc import ABC, abstractmethod +from typing import Dict, Union import torch diff --git a/espnet2/train/class_choices.py b/espnet2/train/class_choices.py index 821bab8121b..412b33f8453 100644 --- a/espnet2/train/class_choices.py +++ b/espnet2/train/class_choices.py @@ -1,9 +1,6 @@ -from typing import Mapping -from typing import Optional -from typing import Tuple +from typing import 
Mapping, Optional, Tuple -from typeguard import check_argument_types -from typeguard import check_return_type +from typeguard import check_argument_types, check_return_type from espnet2.utils.nested_dict_action import NestedDictAction from espnet2.utils.types import str_or_none diff --git a/espnet2/train/collate_fn.py b/espnet2/train/collate_fn.py index a9a5bbb7792..cc4297a30c5 100644 --- a/espnet2/train/collate_fn.py +++ b/espnet2/train/collate_fn.py @@ -1,13 +1,8 @@ -from typing import Collection -from typing import Dict -from typing import List -from typing import Tuple -from typing import Union +from typing import Collection, Dict, List, Tuple, Union import numpy as np import torch -from typeguard import check_argument_types -from typeguard import check_return_type +from typeguard import check_argument_types, check_return_type from espnet.nets.pytorch_backend.nets_utils import pad_list diff --git a/espnet2/train/dataset.py b/espnet2/train/dataset.py index 0c47366e94a..dfe481f77f1 100644 --- a/espnet2/train/dataset.py +++ b/espnet2/train/dataset.py @@ -1,18 +1,11 @@ -from abc import ABC -from abc import abstractmethod import collections import copy import functools import logging import numbers import re -from typing import Any -from typing import Callable -from typing import Collection -from typing import Dict -from typing import Mapping -from typing import Tuple -from typing import Union +from abc import ABC, abstractmethod +from typing import Any, Callable, Collection, Dict, Mapping, Tuple, Union import h5py import humanfriendly @@ -20,14 +13,12 @@ import numpy as np import torch from torch.utils.data.dataset import Dataset -from typeguard import check_argument_types -from typeguard import check_return_type +from typeguard import check_argument_types, check_return_type from espnet2.fileio.npy_scp import NpyScpReader -from espnet2.fileio.rand_gen_dataset import FloatRandomGenerateDataset -from espnet2.fileio.rand_gen_dataset import IntRandomGenerateDataset -from espnet2.fileio.read_text import load_num_sequence_text -from espnet2.fileio.read_text import read_2column_text +from espnet2.fileio.rand_gen_dataset import (FloatRandomGenerateDataset, + IntRandomGenerateDataset) +from espnet2.fileio.read_text import load_num_sequence_text, read_2column_text from espnet2.fileio.rttm import RttmReader from espnet2.fileio.sound_scp import SoundScpReader from espnet2.utils.sized_dict import SizedDict diff --git a/espnet2/train/gan_trainer.py b/espnet2/train/gan_trainer.py index cc0aa1ba95d..62f20c3be79 100644 --- a/espnet2/train/gan_trainer.py +++ b/espnet2/train/gan_trainer.py @@ -7,28 +7,20 @@ import dataclasses import logging import time - from contextlib import contextmanager -from packaging.version import parse as V -from typing import Dict -from typing import Iterable -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple +from typing import Dict, Iterable, List, Optional, Sequence, Tuple import torch - +from packaging.version import parse as V from typeguard import check_argument_types -from espnet2.schedulers.abs_scheduler import AbsBatchStepScheduler -from espnet2.schedulers.abs_scheduler import AbsScheduler +from espnet2.schedulers.abs_scheduler import (AbsBatchStepScheduler, + AbsScheduler) from espnet2.torch_utils.device_funcs import to_device from espnet2.torch_utils.recursive_op import recursive_average from espnet2.train.distributed_utils import DistributedOption from espnet2.train.reporter import SubReporter -from 
espnet2.train.trainer import Trainer -from espnet2.train.trainer import TrainerOptions +from espnet2.train.trainer import Trainer, TrainerOptions from espnet2.utils.build_dataclass import build_dataclass from espnet2.utils.types import str2bool @@ -36,8 +28,7 @@ from torch.distributed import ReduceOp if V(torch.__version__) >= V("1.6.0"): - from torch.cuda.amp import autocast - from torch.cuda.amp import GradScaler + from torch.cuda.amp import GradScaler, autocast else: # Nothing to do if torch<1.6.0 @contextmanager diff --git a/espnet2/train/iterable_dataset.py b/espnet2/train/iterable_dataset.py index ccf606726f3..7133d749b14 100644 --- a/espnet2/train/iterable_dataset.py +++ b/espnet2/train/iterable_dataset.py @@ -2,12 +2,7 @@ import copy from io import StringIO from pathlib import Path -from typing import Callable -from typing import Collection -from typing import Dict -from typing import Iterator -from typing import Tuple -from typing import Union +from typing import Callable, Collection, Dict, Iterator, Tuple, Union import kaldiio import numpy as np diff --git a/espnet2/train/preprocessor.py b/espnet2/train/preprocessor.py index bdf1c6437e8..0d841b2fd74 100644 --- a/espnet2/train/preprocessor.py +++ b/espnet2/train/preprocessor.py @@ -1,17 +1,11 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from pathlib import Path -from typing import Collection -from typing import Dict -from typing import Iterable -from typing import List -from typing import Union +from typing import Collection, Dict, Iterable, List, Union import numpy as np import scipy.signal import soundfile -from typeguard import check_argument_types -from typeguard import check_return_type +from typeguard import check_argument_types, check_return_type from espnet2.text.build_tokenizer import build_tokenizer from espnet2.text.cleaner import TextCleaner diff --git a/espnet2/train/reporter.py b/espnet2/train/reporter.py index be1d2a51fe5..20865bc81e6 100644 --- a/espnet2/train/reporter.py +++ b/espnet2/train/reporter.py @@ -1,27 +1,19 @@ """Reporter module.""" -from collections import defaultdict -from contextlib import contextmanager import dataclasses import datetime import logging -from packaging.version import parse as V -from pathlib import Path import time -from typing import ContextManager -from typing import Dict -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union import warnings +from collections import defaultdict +from contextlib import contextmanager +from pathlib import Path +from typing import ContextManager, Dict, List, Optional, Sequence, Tuple, Union import humanfriendly import numpy as np import torch -from typeguard import check_argument_types -from typeguard import check_return_type - +from packaging.version import parse as V +from typeguard import check_argument_types, check_return_type Num = Union[float, int, complex, torch.Tensor, np.ndarray] diff --git a/espnet2/train/trainer.py b/espnet2/train/trainer.py index da8ea6144b4..6e1c4b02b32 100644 --- a/espnet2/train/trainer.py +++ b/espnet2/train/trainer.py @@ -1,50 +1,43 @@ """Trainer module.""" import argparse -from contextlib import contextmanager import dataclasses -from dataclasses import is_dataclass import logging -from packaging.version import parse as V -from pathlib import Path import time -from typing import Dict -from typing import Iterable -from typing import List -from typing import Optional -from typing import Sequence 
-from typing import Tuple -from typing import Union +from contextlib import contextmanager +from dataclasses import is_dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union import humanfriendly import numpy as np import torch import torch.nn import torch.optim +from packaging.version import parse as V from typeguard import check_argument_types from espnet2.iterators.abs_iter_factory import AbsIterFactory from espnet2.main_funcs.average_nbest_models import average_nbest_models -from espnet2.main_funcs.calculate_all_attentions import calculate_all_attentions -from espnet2.schedulers.abs_scheduler import AbsBatchStepScheduler -from espnet2.schedulers.abs_scheduler import AbsEpochStepScheduler -from espnet2.schedulers.abs_scheduler import AbsScheduler -from espnet2.schedulers.abs_scheduler import AbsValEpochStepScheduler +from espnet2.main_funcs.calculate_all_attentions import \ + calculate_all_attentions +from espnet2.schedulers.abs_scheduler import (AbsBatchStepScheduler, + AbsEpochStepScheduler, + AbsScheduler, + AbsValEpochStepScheduler) from espnet2.torch_utils.add_gradient_noise import add_gradient_noise from espnet2.torch_utils.device_funcs import to_device from espnet2.torch_utils.recursive_op import recursive_average from espnet2.torch_utils.set_all_random_seed import set_all_random_seed from espnet2.train.abs_espnet_model import AbsESPnetModel from espnet2.train.distributed_utils import DistributedOption -from espnet2.train.reporter import Reporter -from espnet2.train.reporter import SubReporter +from espnet2.train.reporter import Reporter, SubReporter from espnet2.utils.build_dataclass import build_dataclass if torch.distributed.is_available(): from torch.distributed import ReduceOp if V(torch.__version__) >= V("1.6.0"): - from torch.cuda.amp import autocast - from torch.cuda.amp import GradScaler + from torch.cuda.amp import GradScaler, autocast else: # Nothing to do if torch<1.6.0 @contextmanager diff --git a/espnet2/tts/abs_tts.py b/espnet2/tts/abs_tts.py index 08eab189ad8..c1da2478b26 100644 --- a/espnet2/tts/abs_tts.py +++ b/espnet2/tts/abs_tts.py @@ -3,10 +3,8 @@ """Text-to-speech abstrast class.""" -from abc import ABC -from abc import abstractmethod -from typing import Dict -from typing import Tuple +from abc import ABC, abstractmethod +from typing import Dict, Tuple import torch diff --git a/espnet2/tts/espnet_model.py b/espnet2/tts/espnet_model.py index 6cb88fe4b5b..ee41d6e4c72 100644 --- a/espnet2/tts/espnet_model.py +++ b/espnet2/tts/espnet_model.py @@ -4,13 +4,10 @@ """Text-to-speech ESPnet model.""" from contextlib import contextmanager -from packaging.version import parse as V -from typing import Dict -from typing import Optional -from typing import Tuple +from typing import Dict, Optional, Tuple import torch - +from packaging.version import parse as V from typeguard import check_argument_types from espnet2.layers.abs_normalize import AbsNormalize diff --git a/espnet2/tts/fastspeech/fastspeech.py b/espnet2/tts/fastspeech/fastspeech.py index 481b86976fa..27da48b0e29 100644 --- a/espnet2/tts/fastspeech/fastspeech.py +++ b/espnet2/tts/fastspeech/fastspeech.py @@ -4,38 +4,31 @@ """Fastspeech related modules for ESPnet2.""" import logging - -from typing import Dict -from typing import Optional -from typing import Sequence -from typing import Tuple +from typing import Dict, Optional, Sequence, Tuple import torch import torch.nn.functional as F - from typeguard import check_argument_types -from 
espnet.nets.pytorch_backend.conformer.encoder import ( - Encoder as ConformerEncoder, # noqa: H301 -) -from espnet.nets.pytorch_backend.e2e_tts_fastspeech import ( - FeedForwardTransformerLoss as FastSpeechLoss, # NOQA -) -from espnet.nets.pytorch_backend.fastspeech.duration_predictor import DurationPredictor -from espnet.nets.pytorch_backend.fastspeech.length_regulator import LengthRegulator -from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet -from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding -from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding -from espnet.nets.pytorch_backend.transformer.encoder import ( - Encoder as TransformerEncoder, # noqa: H301 -) - from espnet2.torch_utils.device_funcs import force_gatherable from espnet2.torch_utils.initialize import initialize from espnet2.tts.abs_tts import AbsTTS from espnet2.tts.gst.style_encoder import StyleEncoder +from espnet.nets.pytorch_backend.conformer.encoder import \ + Encoder as ConformerEncoder # noqa: H301 +from espnet.nets.pytorch_backend.e2e_tts_fastspeech import \ + FeedForwardTransformerLoss as FastSpeechLoss # NOQA +from espnet.nets.pytorch_backend.fastspeech.duration_predictor import \ + DurationPredictor +from espnet.nets.pytorch_backend.fastspeech.length_regulator import \ + LengthRegulator +from espnet.nets.pytorch_backend.nets_utils import (make_non_pad_mask, + make_pad_mask) +from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet +from espnet.nets.pytorch_backend.transformer.embedding import ( + PositionalEncoding, ScaledPositionalEncoding) +from espnet.nets.pytorch_backend.transformer.encoder import \ + Encoder as TransformerEncoder # noqa: H301 class FastSpeech(AbsTTS): diff --git a/espnet2/tts/fastspeech2/fastspeech2.py b/espnet2/tts/fastspeech2/fastspeech2.py index 06d3b0c6c5f..60a8ba1bfd9 100644 --- a/espnet2/tts/fastspeech2/fastspeech2.py +++ b/espnet2/tts/fastspeech2/fastspeech2.py @@ -4,37 +4,31 @@ """Fastspeech2 related modules for ESPnet2.""" import logging - -from typing import Dict -from typing import Optional -from typing import Sequence -from typing import Tuple +from typing import Dict, Optional, Sequence, Tuple import torch import torch.nn.functional as F - from typeguard import check_argument_types -from espnet.nets.pytorch_backend.conformer.encoder import ( - Encoder as ConformerEncoder, # noqa: H301 -) -from espnet.nets.pytorch_backend.fastspeech.duration_predictor import DurationPredictor -from espnet.nets.pytorch_backend.fastspeech.length_regulator import LengthRegulator -from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet -from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding -from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding -from espnet.nets.pytorch_backend.transformer.encoder import ( - Encoder as TransformerEncoder, # noqa: H301 -) - from espnet2.torch_utils.device_funcs import force_gatherable from espnet2.torch_utils.initialize import initialize from espnet2.tts.abs_tts import AbsTTS from espnet2.tts.fastspeech2.loss import FastSpeech2Loss from espnet2.tts.fastspeech2.variance_predictor import VariancePredictor from espnet2.tts.gst.style_encoder import 
StyleEncoder +from espnet.nets.pytorch_backend.conformer.encoder import \ + Encoder as ConformerEncoder # noqa: H301 +from espnet.nets.pytorch_backend.fastspeech.duration_predictor import \ + DurationPredictor +from espnet.nets.pytorch_backend.fastspeech.length_regulator import \ + LengthRegulator +from espnet.nets.pytorch_backend.nets_utils import (make_non_pad_mask, + make_pad_mask) +from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet +from espnet.nets.pytorch_backend.transformer.embedding import ( + PositionalEncoding, ScaledPositionalEncoding) +from espnet.nets.pytorch_backend.transformer.encoder import \ + Encoder as TransformerEncoder # noqa: H301 class FastSpeech2(AbsTTS): diff --git a/espnet2/tts/fastspeech2/loss.py b/espnet2/tts/fastspeech2/loss.py index 086b856831a..33316b6afb8 100644 --- a/espnet2/tts/fastspeech2/loss.py +++ b/espnet2/tts/fastspeech2/loss.py @@ -6,12 +6,10 @@ from typing import Tuple import torch - from typeguard import check_argument_types -from espnet.nets.pytorch_backend.fastspeech.duration_predictor import ( - DurationPredictorLoss, # noqa: H301 -) +from espnet.nets.pytorch_backend.fastspeech.duration_predictor import \ + DurationPredictorLoss # noqa: H301 from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask diff --git a/espnet2/tts/fastspeech2/variance_predictor.py b/espnet2/tts/fastspeech2/variance_predictor.py index e948c17e0e7..aba9a64576d 100644 --- a/espnet2/tts/fastspeech2/variance_predictor.py +++ b/espnet2/tts/fastspeech2/variance_predictor.py @@ -6,7 +6,6 @@ """Variance predictor related modules.""" import torch - from typeguard import check_argument_types from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm diff --git a/espnet2/tts/feats_extract/abs_feats_extract.py b/espnet2/tts/feats_extract/abs_feats_extract.py index c4a459e5be7..48a2e351307 100644 --- a/espnet2/tts/feats_extract/abs_feats_extract.py +++ b/espnet2/tts/feats_extract/abs_feats_extract.py @@ -1,10 +1,7 @@ -from abc import ABC -from abc import abstractmethod -from typing import Any -from typing import Dict +from abc import ABC, abstractmethod +from typing import Any, Dict, Tuple import torch -from typing import Tuple class AbsFeatsExtract(torch.nn.Module, ABC): diff --git a/espnet2/tts/feats_extract/dio.py b/espnet2/tts/feats_extract/dio.py index 43b5dfae306..4e2974e0a45 100644 --- a/espnet2/tts/feats_extract/dio.py +++ b/espnet2/tts/feats_extract/dio.py @@ -4,23 +4,18 @@ """F0 extractor using DIO + Stonemask algorithm.""" import logging - -from typing import Any -from typing import Dict -from typing import Tuple -from typing import Union +from typing import Any, Dict, Tuple, Union import humanfriendly import numpy as np import pyworld import torch import torch.nn.functional as F - from scipy.interpolate import interp1d from typeguard import check_argument_types -from espnet.nets.pytorch_backend.nets_utils import pad_list from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract +from espnet.nets.pytorch_backend.nets_utils import pad_list class Dio(AbsFeatsExtract): diff --git a/espnet2/tts/feats_extract/energy.py b/espnet2/tts/feats_extract/energy.py index d80f3af53b5..c7f9e0fcc14 100644 --- a/espnet2/tts/feats_extract/energy.py +++ b/espnet2/tts/feats_extract/energy.py @@ -3,20 +3,16 @@ """Energy extractor.""" -from typing import Any -from typing import Dict -from typing import Tuple -from typing import Union +from typing import Any, Dict, Tuple, Union import humanfriendly import torch import torch.nn.functional as F 
- from typeguard import check_argument_types -from espnet.nets.pytorch_backend.nets_utils import pad_list from espnet2.layers.stft import Stft from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract +from espnet.nets.pytorch_backend.nets_utils import pad_list class Energy(AbsFeatsExtract): diff --git a/espnet2/tts/feats_extract/linear_spectrogram.py b/espnet2/tts/feats_extract/linear_spectrogram.py index d8f05d116a0..e8b1a6c0411 100644 --- a/espnet2/tts/feats_extract/linear_spectrogram.py +++ b/espnet2/tts/feats_extract/linear_spectrogram.py @@ -1,7 +1,4 @@ -from typing import Any -from typing import Dict -from typing import Optional -from typing import Tuple +from typing import Any, Dict, Optional, Tuple import torch from typeguard import check_argument_types diff --git a/espnet2/tts/feats_extract/log_mel_fbank.py b/espnet2/tts/feats_extract/log_mel_fbank.py index 2073c8cecc3..b05424713e5 100644 --- a/espnet2/tts/feats_extract/log_mel_fbank.py +++ b/espnet2/tts/feats_extract/log_mel_fbank.py @@ -1,8 +1,4 @@ -from typing import Any -from typing import Dict -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Any, Dict, Optional, Tuple, Union import humanfriendly import torch diff --git a/espnet2/tts/feats_extract/log_spectrogram.py b/espnet2/tts/feats_extract/log_spectrogram.py index fa00ea435f1..f436d6e04fe 100644 --- a/espnet2/tts/feats_extract/log_spectrogram.py +++ b/espnet2/tts/feats_extract/log_spectrogram.py @@ -1,7 +1,4 @@ -from typing import Any -from typing import Dict -from typing import Optional -from typing import Tuple +from typing import Any, Dict, Optional, Tuple import torch from typeguard import check_argument_types diff --git a/espnet2/tts/gst/style_encoder.py b/espnet2/tts/gst/style_encoder.py index 9fcdd9c52cd..6993ff09427 100644 --- a/espnet2/tts/gst/style_encoder.py +++ b/espnet2/tts/gst/style_encoder.py @@ -3,14 +3,13 @@ """Style encoder of GST-Tacotron.""" -from typeguard import check_argument_types from typing import Sequence import torch +from typeguard import check_argument_types -from espnet.nets.pytorch_backend.transformer.attention import ( - MultiHeadedAttention as BaseMultiHeadedAttention, # NOQA -) +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention as BaseMultiHeadedAttention # NOQA class StyleEncoder(torch.nn.Module): diff --git a/espnet2/tts/tacotron2/tacotron2.py b/espnet2/tts/tacotron2/tacotron2.py index a178b9079fd..50588e061be 100644 --- a/espnet2/tts/tacotron2/tacotron2.py +++ b/espnet2/tts/tacotron2/tacotron2.py @@ -4,28 +4,22 @@ """Tacotron 2 related modules for ESPnet2.""" import logging - -from typing import Dict -from typing import Optional -from typing import Sequence -from typing import Tuple +from typing import Dict, Optional, Sequence, Tuple import torch import torch.nn.functional as F - from typeguard import check_argument_types -from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import GuidedAttentionLoss -from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2Loss -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.rnn.attentions import AttForward -from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA -from espnet.nets.pytorch_backend.rnn.attentions import AttLoc -from espnet.nets.pytorch_backend.tacotron2.decoder import Decoder -from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder from espnet2.torch_utils.device_funcs import force_gatherable from 
espnet2.tts.abs_tts import AbsTTS from espnet2.tts.gst.style_encoder import StyleEncoder +from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import (GuidedAttentionLoss, + Tacotron2Loss) +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.rnn.attentions import (AttForward, + AttForwardTA, AttLoc) +from espnet.nets.pytorch_backend.tacotron2.decoder import Decoder +from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder class Tacotron2(AbsTTS): diff --git a/espnet2/tts/transformer/transformer.py b/espnet2/tts/transformer/transformer.py index f6d1f13cdb1..71bf0cb71f8 100644 --- a/espnet2/tts/transformer/transformer.py +++ b/espnet2/tts/transformer/transformer.py @@ -3,33 +3,32 @@ """Transformer-TTS related modules.""" -from typing import Dict -from typing import Optional -from typing import Sequence -from typing import Tuple +from typing import Dict, Optional, Sequence, Tuple import torch import torch.nn.functional as F - from typeguard import check_argument_types -from espnet.nets.pytorch_backend.e2e_tts_transformer import GuidedMultiHeadAttentionLoss -from espnet.nets.pytorch_backend.e2e_tts_transformer import TransformerLoss -from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask -from espnet.nets.pytorch_backend.nets_utils import make_pad_mask -from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet -from espnet.nets.pytorch_backend.tacotron2.decoder import Prenet as DecoderPrenet -from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder as EncoderPrenet -from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention -from espnet.nets.pytorch_backend.transformer.decoder import Decoder -from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding -from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding -from espnet.nets.pytorch_backend.transformer.encoder import Encoder -from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask from espnet2.torch_utils.device_funcs import force_gatherable from espnet2.torch_utils.initialize import initialize from espnet2.tts.abs_tts import AbsTTS from espnet2.tts.gst.style_encoder import StyleEncoder +from espnet.nets.pytorch_backend.e2e_tts_transformer import ( + GuidedMultiHeadAttentionLoss, TransformerLoss) +from espnet.nets.pytorch_backend.nets_utils import (make_non_pad_mask, + make_pad_mask) +from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet +from espnet.nets.pytorch_backend.tacotron2.decoder import \ + Prenet as DecoderPrenet +from espnet.nets.pytorch_backend.tacotron2.encoder import \ + Encoder as EncoderPrenet +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.decoder import Decoder +from espnet.nets.pytorch_backend.transformer.embedding import ( + PositionalEncoding, ScaledPositionalEncoding) +from espnet.nets.pytorch_backend.transformer.encoder import Encoder +from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask class Transformer(AbsTTS): diff --git a/espnet2/tts/utils/__init__.py b/espnet2/tts/utils/__init__.py index 0b512d822e8..f2861a681fb 100644 --- a/espnet2/tts/utils/__init__.py +++ b/espnet2/tts/utils/__init__.py @@ -1,4 +1,3 @@ from espnet2.tts.utils.duration_calculator import DurationCalculator # NOQA -from espnet2.tts.utils.parallel_wavegan_pretrained_vocoder import ( # NOQA - ParallelWaveGANPretrainedVocoder, # NOQA -) +from 
espnet2.tts.utils.parallel_wavegan_pretrained_vocoder import \ + ParallelWaveGANPretrainedVocoder # NOQA; NOQA diff --git a/espnet2/tts/utils/parallel_wavegan_pretrained_vocoder.py b/espnet2/tts/utils/parallel_wavegan_pretrained_vocoder.py index 5ac5c48cda8..4019c7943d7 100644 --- a/espnet2/tts/utils/parallel_wavegan_pretrained_vocoder.py +++ b/espnet2/tts/utils/parallel_wavegan_pretrained_vocoder.py @@ -5,14 +5,11 @@ import logging import os - from pathlib import Path -from typing import Optional -from typing import Union - -import yaml +from typing import Optional, Union import torch +import yaml class ParallelWaveGANPretrainedVocoder(torch.nn.Module): diff --git a/espnet2/utils/griffin_lim.py b/espnet2/utils/griffin_lim.py index c9b08cd1235..ab7c9097e49 100644 --- a/espnet2/utils/griffin_lim.py +++ b/espnet2/utils/griffin_lim.py @@ -6,15 +6,14 @@ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import logging - from functools import partial -from packaging.version import parse as V -from typeguard import check_argument_types from typing import Optional import librosa import numpy as np import torch +from packaging.version import parse as V +from typeguard import check_argument_types EPS = 1e-10 diff --git a/espnet2/utils/types.py b/espnet2/utils/types.py index 6b36f9c4b87..4d6ec7c3f42 100644 --- a/espnet2/utils/types.py +++ b/espnet2/utils/types.py @@ -1,7 +1,5 @@ from distutils.util import strtobool -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Optional, Tuple, Union import humanfriendly diff --git a/setup.py b/setup.py index 9fb44f87b25..a0ed245dd25 100644 --- a/setup.py +++ b/setup.py @@ -4,9 +4,7 @@ import os -from setuptools import find_packages -from setuptools import setup - +from setuptools import find_packages, setup requirements = { "install": [ @@ -90,6 +88,7 @@ "flake8>=3.7.8", "flake8-docstrings>=1.3.1", "black", + "isort", ], "doc": [ "Jinja2<3.1", diff --git a/test/espnet2/asr/decoder/test_rnn_decoder.py b/test/espnet2/asr/decoder/test_rnn_decoder.py index aa07c88b079..8df889e6bdd 100644 --- a/test/espnet2/asr/decoder/test_rnn_decoder.py +++ b/test/espnet2/asr/decoder/test_rnn_decoder.py @@ -1,8 +1,8 @@ import pytest import torch -from espnet.nets.beam_search import BeamSearch from espnet2.asr.decoder.rnn_decoder import RNNDecoder +from espnet.nets.beam_search import BeamSearch @pytest.mark.parametrize("context_residual", [True, False]) diff --git a/test/espnet2/asr/decoder/test_transformer_decoder.py b/test/espnet2/asr/decoder/test_transformer_decoder.py index d01c5b07a64..52768e74324 100644 --- a/test/espnet2/asr/decoder/test_transformer_decoder.py +++ b/test/espnet2/asr/decoder/test_transformer_decoder.py @@ -1,22 +1,19 @@ import pytest import torch +from espnet2.asr.ctc import CTC +from espnet2.asr.decoder.transformer_decoder import \ + DynamicConvolution2DTransformerDecoder # noqa: H301 +from espnet2.asr.decoder.transformer_decoder import \ + LightweightConvolution2DTransformerDecoder # noqa: H301 +from espnet2.asr.decoder.transformer_decoder import \ + LightweightConvolutionTransformerDecoder # noqa: H301 +from espnet2.asr.decoder.transformer_decoder import ( + DynamicConvolutionTransformerDecoder, TransformerDecoder) from espnet.nets.batch_beam_search import BatchBeamSearch from espnet.nets.batch_beam_search_online_sim import BatchBeamSearchOnlineSim from espnet.nets.beam_search import BeamSearch from espnet.nets.scorers.ctc import CTCPrefixScorer -from espnet2.asr.ctc import CTC -from 
espnet2.asr.decoder.transformer_decoder import ( - DynamicConvolution2DTransformerDecoder, # noqa: H301 -) -from espnet2.asr.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder -from espnet2.asr.decoder.transformer_decoder import ( - LightweightConvolution2DTransformerDecoder, # noqa: H301 -) -from espnet2.asr.decoder.transformer_decoder import ( - LightweightConvolutionTransformerDecoder, # noqa: H301 -) -from espnet2.asr.decoder.transformer_decoder import TransformerDecoder @pytest.mark.parametrize("input_layer", ["linear", "embed"]) diff --git a/test/espnet2/asr/encoder/test_contextual_block_transformer_encoder.py b/test/espnet2/asr/encoder/test_contextual_block_transformer_encoder.py index b440e38b296..229cc18e70a 100644 --- a/test/espnet2/asr/encoder/test_contextual_block_transformer_encoder.py +++ b/test/espnet2/asr/encoder/test_contextual_block_transformer_encoder.py @@ -1,9 +1,8 @@ import pytest import torch -from espnet2.asr.encoder.contextual_block_transformer_encoder import ( - ContextualBlockTransformerEncoder, # noqa: H301 -) +from espnet2.asr.encoder.contextual_block_transformer_encoder import \ + ContextualBlockTransformerEncoder # noqa: H301 @pytest.mark.parametrize("input_layer", ["linear", "conv2d", "embed", None]) diff --git a/test/espnet2/asr/encoder/test_longformer_encoder.py b/test/espnet2/asr/encoder/test_longformer_encoder.py index 8df5f5fc212..40e94fc7916 100644 --- a/test/espnet2/asr/encoder/test_longformer_encoder.py +++ b/test/espnet2/asr/encoder/test_longformer_encoder.py @@ -1,7 +1,8 @@ -from espnet2.asr.encoder.longformer_encoder import LongformerEncoder import pytest import torch +from espnet2.asr.encoder.longformer_encoder import LongformerEncoder + pytest.importorskip("longformer") diff --git a/test/espnet2/asr/frontend/test_fused.py b/test/espnet2/asr/frontend/test_fused.py index 4c35cfb5c03..2e5dd6cb4c1 100644 --- a/test/espnet2/asr/frontend/test_fused.py +++ b/test/espnet2/asr/frontend/test_fused.py @@ -1,6 +1,6 @@ -from espnet2.asr.frontend.fused import FusedFrontends import torch +from espnet2.asr.frontend.fused import FusedFrontends frontend1 = {"frontend_type": "default", "n_mels": 80, "n_fft": 512} frontend2 = {"frontend_type": "default", "hop_length": 128} diff --git a/test/espnet2/asr/frontend/test_s3prl.py b/test/espnet2/asr/frontend/test_s3prl.py index 2c0f66e1ee6..cbf79da1eed 100644 --- a/test/espnet2/asr/frontend/test_s3prl.py +++ b/test/espnet2/asr/frontend/test_s3prl.py @@ -1,6 +1,5 @@ -from packaging.version import parse as V - import torch +from packaging.version import parse as V from espnet2.asr.frontend.s3prl import S3prlFrontend diff --git a/test/espnet2/asr/postencoder/test_hugging_face_transformers_postencoder.py b/test/espnet2/asr/postencoder/test_hugging_face_transformers_postencoder.py index c3dbeaef4d3..af88b70fc9a 100644 --- a/test/espnet2/asr/postencoder/test_hugging_face_transformers_postencoder.py +++ b/test/espnet2/asr/postencoder/test_hugging_face_transformers_postencoder.py @@ -1,9 +1,8 @@ import pytest import torch -from espnet2.asr.postencoder.hugging_face_transformers_postencoder import ( - HuggingFaceTransformersPostEncoder, # noqa: H301 -) +from espnet2.asr.postencoder.hugging_face_transformers_postencoder import \ + HuggingFaceTransformersPostEncoder # noqa: H301 @pytest.mark.parametrize( diff --git a/test/espnet2/asr/preencoder/test_linear.py b/test/espnet2/asr/preencoder/test_linear.py index bd1d29a9c5c..e4bceb644b3 100644 --- a/test/espnet2/asr/preencoder/test_linear.py +++ 
b/test/espnet2/asr/preencoder/test_linear.py @@ -1,6 +1,7 @@ -from espnet2.asr.preencoder.linear import LinearProjection import torch +from espnet2.asr.preencoder.linear import LinearProjection + def test_linear_projection_forward(): idim = 400 diff --git a/test/espnet2/asr/preencoder/test_sinc.py b/test/espnet2/asr/preencoder/test_sinc.py index 518a6520b0b..0de81993ec2 100644 --- a/test/espnet2/asr/preencoder/test_sinc.py +++ b/test/espnet2/asr/preencoder/test_sinc.py @@ -1,7 +1,7 @@ -from espnet2.asr.preencoder.sinc import LightweightSincConvs -from espnet2.asr.preencoder.sinc import SpatialDropout import torch +from espnet2.asr.preencoder.sinc import LightweightSincConvs, SpatialDropout + def test_spatial_dropout(): dropout = SpatialDropout() diff --git a/test/espnet2/asr/test_maskctc_model.py b/test/espnet2/asr/test_maskctc_model.py index 4631f9be539..708c54d2f18 100644 --- a/test/espnet2/asr/test_maskctc_model.py +++ b/test/espnet2/asr/test_maskctc_model.py @@ -5,8 +5,7 @@ from espnet2.asr.decoder.mlm_decoder import MLMDecoder from espnet2.asr.encoder.conformer_encoder import ConformerEncoder from espnet2.asr.encoder.transformer_encoder import TransformerEncoder -from espnet2.asr.maskctc_model import MaskCTCInference -from espnet2.asr.maskctc_model import MaskCTCModel +from espnet2.asr.maskctc_model import MaskCTCInference, MaskCTCModel @pytest.mark.parametrize("encoder_arch", [TransformerEncoder, ConformerEncoder]) diff --git a/test/espnet2/bin/test_aggregate_stats_dirs.py b/test/espnet2/bin/test_aggregate_stats_dirs.py index 6584f8a57a0..6e598babce8 100644 --- a/test/espnet2/bin/test_aggregate_stats_dirs.py +++ b/test/espnet2/bin/test_aggregate_stats_dirs.py @@ -2,8 +2,7 @@ import pytest -from espnet2.bin.aggregate_stats_dirs import get_parser -from espnet2.bin.aggregate_stats_dirs import main +from espnet2.bin.aggregate_stats_dirs import get_parser, main def test_get_parser(): diff --git a/test/espnet2/bin/test_asr_align.py b/test/espnet2/bin/test_asr_align.py index a3cad5b6872..1feabefc638 100644 --- a/test/espnet2/bin/test_asr_align.py +++ b/test/espnet2/bin/test_asr_align.py @@ -1,15 +1,13 @@ """Tests for asr_align.py.""" +import string from argparse import ArgumentParser from pathlib import Path -import string import numpy as np import pytest -from espnet2.bin.asr_align import CTCSegmentation -from espnet2.bin.asr_align import CTCSegmentationTask -from espnet2.bin.asr_align import get_parser -from espnet2.bin.asr_align import main +from espnet2.bin.asr_align import (CTCSegmentation, CTCSegmentationTask, + get_parser, main) from espnet2.tasks.asr import ASRTask diff --git a/test/espnet2/bin/test_asr_inference.py b/test/espnet2/bin/test_asr_inference.py index 03dc1cb9285..e4cd0ed527f 100644 --- a/test/espnet2/bin/test_asr_inference.py +++ b/test/espnet2/bin/test_asr_inference.py @@ -1,19 +1,17 @@ +import string from argparse import ArgumentParser from pathlib import Path -import string import numpy as np import pytest import yaml -from espnet.nets.beam_search import Hypothesis -from espnet2.bin.asr_inference import get_parser -from espnet2.bin.asr_inference import main -from espnet2.bin.asr_inference import Speech2Text +from espnet2.bin.asr_inference import Speech2Text, get_parser, main from espnet2.bin.asr_inference_streaming import Speech2TextStreaming from espnet2.tasks.asr import ASRTask from espnet2.tasks.enh_s2t import EnhS2TTask from espnet2.tasks.lm import LMTask +from espnet.nets.beam_search import Hypothesis def test_get_parser(): diff --git 
a/test/espnet2/bin/test_asr_inference_k2.py b/test/espnet2/bin/test_asr_inference_k2.py index 823c5ce848b..3a76240d00f 100644 --- a/test/espnet2/bin/test_asr_inference_k2.py +++ b/test/espnet2/bin/test_asr_inference_k2.py @@ -1,6 +1,6 @@ +import string from argparse import ArgumentParser from pathlib import Path -import string import numpy as np import pytest @@ -8,7 +8,6 @@ from espnet2.tasks.asr import ASRTask from espnet2.tasks.lm import LMTask - pytest.importorskip("k2") diff --git a/test/espnet2/bin/test_asr_inference_maskctc.py b/test/espnet2/bin/test_asr_inference_maskctc.py index 21a1d0392b4..ff52e5a440a 100644 --- a/test/espnet2/bin/test_asr_inference_maskctc.py +++ b/test/espnet2/bin/test_asr_inference_maskctc.py @@ -1,15 +1,13 @@ +import string from argparse import ArgumentParser from pathlib import Path -import string import numpy as np import pytest -from espnet.nets.beam_search import Hypothesis -from espnet2.bin.asr_inference_maskctc import get_parser -from espnet2.bin.asr_inference_maskctc import main -from espnet2.bin.asr_inference_maskctc import Speech2Text +from espnet2.bin.asr_inference_maskctc import Speech2Text, get_parser, main from espnet2.tasks.asr import ASRTask +from espnet.nets.beam_search import Hypothesis def test_get_parser(): diff --git a/test/espnet2/bin/test_asr_train.py b/test/espnet2/bin/test_asr_train.py index 066c28865c1..1188a628313 100644 --- a/test/espnet2/bin/test_asr_train.py +++ b/test/espnet2/bin/test_asr_train.py @@ -2,8 +2,7 @@ import pytest -from espnet2.bin.asr_train import get_parser -from espnet2.bin.asr_train import main +from espnet2.bin.asr_train import get_parser, main def test_get_parser(): diff --git a/test/espnet2/bin/test_diar_inference.py b/test/espnet2/bin/test_diar_inference.py index 8781200eb72..3275aca9dbf 100644 --- a/test/espnet2/bin/test_diar_inference.py +++ b/test/espnet2/bin/test_diar_inference.py @@ -4,9 +4,7 @@ import pytest import torch -from espnet2.bin.diar_inference import DiarizeSpeech -from espnet2.bin.diar_inference import get_parser -from espnet2.bin.diar_inference import main +from espnet2.bin.diar_inference import DiarizeSpeech, get_parser, main from espnet2.tasks.diar import DiarizationTask diff --git a/test/espnet2/bin/test_diar_train.py b/test/espnet2/bin/test_diar_train.py index 9f0cd5dff2f..ddb991df968 100644 --- a/test/espnet2/bin/test_diar_train.py +++ b/test/espnet2/bin/test_diar_train.py @@ -2,8 +2,7 @@ import pytest -from espnet2.bin.diar_train import get_parser -from espnet2.bin.diar_train import main +from espnet2.bin.diar_train import get_parser, main def test_get_parser(): diff --git a/test/espnet2/bin/test_enh_inference.py b/test/espnet2/bin/test_enh_inference.py index 2bad3cae4ea..95d788784f9 100644 --- a/test/espnet2/bin/test_enh_inference.py +++ b/test/espnet2/bin/test_enh_inference.py @@ -1,14 +1,12 @@ +import string from argparse import ArgumentParser from pathlib import Path -import string import pytest import torch import yaml -from espnet2.bin.enh_inference import get_parser -from espnet2.bin.enh_inference import main -from espnet2.bin.enh_inference import SeparateSpeech +from espnet2.bin.enh_inference import SeparateSpeech, get_parser, main from espnet2.enh.encoder.stft_encoder import STFTEncoder from espnet2.tasks.enh import EnhancementTask from espnet2.tasks.enh_s2t import EnhS2TTask diff --git a/test/espnet2/bin/test_enh_s2t_train.py b/test/espnet2/bin/test_enh_s2t_train.py index 2cd4fe6f94f..b75431df8b9 100644 --- a/test/espnet2/bin/test_enh_s2t_train.py +++ 
b/test/espnet2/bin/test_enh_s2t_train.py @@ -2,8 +2,7 @@ import pytest -from espnet2.bin.enh_s2t_train import get_parser -from espnet2.bin.enh_s2t_train import main +from espnet2.bin.enh_s2t_train import get_parser, main def test_get_parser(): diff --git a/test/espnet2/bin/test_enh_scoring.py b/test/espnet2/bin/test_enh_scoring.py index a4e6f31ae88..fb4ec500e53 100644 --- a/test/espnet2/bin/test_enh_scoring.py +++ b/test/espnet2/bin/test_enh_scoring.py @@ -2,8 +2,7 @@ import pytest -from espnet2.bin.enh_scoring import get_parser -from espnet2.bin.enh_scoring import main +from espnet2.bin.enh_scoring import get_parser, main def test_get_parser(): diff --git a/test/espnet2/bin/test_enh_train.py b/test/espnet2/bin/test_enh_train.py index 15620a92851..23939a826e6 100644 --- a/test/espnet2/bin/test_enh_train.py +++ b/test/espnet2/bin/test_enh_train.py @@ -2,8 +2,7 @@ import pytest -from espnet2.bin.enh_train import get_parser -from espnet2.bin.enh_train import main +from espnet2.bin.enh_train import get_parser, main def test_get_parser(): diff --git a/test/espnet2/bin/test_hubert_train.py b/test/espnet2/bin/test_hubert_train.py index 912cb4cae68..d74afcd197d 100644 --- a/test/espnet2/bin/test_hubert_train.py +++ b/test/espnet2/bin/test_hubert_train.py @@ -2,8 +2,7 @@ import pytest -from espnet2.bin.hubert_train import get_parser -from espnet2.bin.hubert_train import main +from espnet2.bin.hubert_train import get_parser, main def test_get_parser(): diff --git a/test/espnet2/bin/test_lm_calc_perplexity.py b/test/espnet2/bin/test_lm_calc_perplexity.py index 51126a783f7..e8010aad525 100644 --- a/test/espnet2/bin/test_lm_calc_perplexity.py +++ b/test/espnet2/bin/test_lm_calc_perplexity.py @@ -2,8 +2,7 @@ import pytest -from espnet2.bin.lm_calc_perplexity import get_parser -from espnet2.bin.lm_calc_perplexity import main +from espnet2.bin.lm_calc_perplexity import get_parser, main def test_get_parser(): diff --git a/test/espnet2/bin/test_lm_train.py b/test/espnet2/bin/test_lm_train.py index ff1c7dce247..1889a985087 100644 --- a/test/espnet2/bin/test_lm_train.py +++ b/test/espnet2/bin/test_lm_train.py @@ -2,8 +2,7 @@ import pytest -from espnet2.bin.lm_train import get_parser -from espnet2.bin.lm_train import main +from espnet2.bin.lm_train import get_parser, main def test_get_parser(): diff --git a/test/espnet2/bin/test_pack.py b/test/espnet2/bin/test_pack.py index 0e242de8036..fb98168f365 100755 --- a/test/espnet2/bin/test_pack.py +++ b/test/espnet2/bin/test_pack.py @@ -2,8 +2,7 @@ import pytest -from espnet2.bin.pack import get_parser -from espnet2.bin.pack import main +from espnet2.bin.pack import get_parser, main def test_get_parser(): diff --git a/test/espnet2/bin/test_st_inference.py b/test/espnet2/bin/test_st_inference.py index 3910479456b..f43d2be403c 100644 --- a/test/espnet2/bin/test_st_inference.py +++ b/test/espnet2/bin/test_st_inference.py @@ -1,15 +1,13 @@ +import string from argparse import ArgumentParser from pathlib import Path -import string import numpy as np import pytest -from espnet.nets.beam_search import Hypothesis -from espnet2.bin.st_inference import get_parser -from espnet2.bin.st_inference import main -from espnet2.bin.st_inference import Speech2Text +from espnet2.bin.st_inference import Speech2Text, get_parser, main from espnet2.tasks.st import STTask +from espnet.nets.beam_search import Hypothesis def test_get_parser(): diff --git a/test/espnet2/bin/test_st_train.py b/test/espnet2/bin/test_st_train.py index 5be899f0a38..5fd51bdec02 100644 --- 
a/test/espnet2/bin/test_st_train.py +++ b/test/espnet2/bin/test_st_train.py @@ -2,8 +2,7 @@ import pytest -from espnet2.bin.st_train import get_parser -from espnet2.bin.st_train import main +from espnet2.bin.st_train import get_parser, main def test_get_parser(): diff --git a/test/espnet2/bin/test_tokenize_text.py b/test/espnet2/bin/test_tokenize_text.py index 18ecc79dd59..42b59a5359c 100755 --- a/test/espnet2/bin/test_tokenize_text.py +++ b/test/espnet2/bin/test_tokenize_text.py @@ -2,8 +2,7 @@ import pytest -from espnet2.bin.tokenize_text import get_parser -from espnet2.bin.tokenize_text import main +from espnet2.bin.tokenize_text import get_parser, main def test_get_parser(): diff --git a/test/espnet2/bin/test_tts_inference.py b/test/espnet2/bin/test_tts_inference.py index 7c0c5e281a9..b8d47e4eb5e 100644 --- a/test/espnet2/bin/test_tts_inference.py +++ b/test/espnet2/bin/test_tts_inference.py @@ -1,12 +1,10 @@ +import string from argparse import ArgumentParser from pathlib import Path -import string import pytest -from espnet2.bin.tts_inference import get_parser -from espnet2.bin.tts_inference import main -from espnet2.bin.tts_inference import Text2Speech +from espnet2.bin.tts_inference import Text2Speech, get_parser, main from espnet2.tasks.tts import TTSTask diff --git a/test/espnet2/bin/test_tts_train.py b/test/espnet2/bin/test_tts_train.py index 236bae3fe27..5b5bf18e58a 100644 --- a/test/espnet2/bin/test_tts_train.py +++ b/test/espnet2/bin/test_tts_train.py @@ -2,8 +2,7 @@ import pytest -from espnet2.bin.tts_train import get_parser -from espnet2.bin.tts_train import main +from espnet2.bin.tts_train import get_parser, main def test_get_parser(): diff --git a/test/espnet2/enh/decoder/test_stft_decoder.py b/test/espnet2/enh/decoder/test_stft_decoder.py index 4389d7b858f..3443bfc073e 100644 --- a/test/espnet2/enh/decoder/test_stft_decoder.py +++ b/test/espnet2/enh/decoder/test_stft_decoder.py @@ -1,5 +1,4 @@ import pytest - import torch from torch_complex import ComplexTensor diff --git a/test/espnet2/enh/layers/test_complex_utils.py b/test/espnet2/enh/layers/test_complex_utils.py index 6404f33eaa3..32a31a5c7f4 100644 --- a/test/espnet2/enh/layers/test_complex_utils.py +++ b/test/espnet2/enh/layers/test_complex_utils.py @@ -1,20 +1,13 @@ -from packaging.version import parse as V - import numpy as np import pytest import torch import torch_complex.functional as FC +from packaging.version import parse as V from torch_complex.tensor import ComplexTensor -from espnet2.enh.layers.complex_utils import cat -from espnet2.enh.layers.complex_utils import complex_norm -from espnet2.enh.layers.complex_utils import einsum -from espnet2.enh.layers.complex_utils import inverse -from espnet2.enh.layers.complex_utils import matmul -from espnet2.enh.layers.complex_utils import solve -from espnet2.enh.layers.complex_utils import stack -from espnet2.enh.layers.complex_utils import trace - +from espnet2.enh.layers.complex_utils import (cat, complex_norm, einsum, + inverse, matmul, solve, stack, + trace) is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") # invertible matrix diff --git a/test/espnet2/enh/layers/test_conv_utils.py b/test/espnet2/enh/layers/test_conv_utils.py index 7e7ea22672c..876e9f375a5 100644 --- a/test/espnet2/enh/layers/test_conv_utils.py +++ b/test/espnet2/enh/layers/test_conv_utils.py @@ -1,8 +1,8 @@ import pytest import torch -from espnet2.enh.layers.conv_utils import conv2d_output_shape -from espnet2.enh.layers.conv_utils import convtransp2d_output_shape +from 
espnet2.enh.layers.conv_utils import (conv2d_output_shape, + convtransp2d_output_shape) @pytest.mark.parametrize("input_dim", [(10, 17), (10, 33)]) diff --git a/test/espnet2/enh/layers/test_enh_layers.py b/test/espnet2/enh/layers/test_enh_layers.py index 3d4f0a84ead..3ee389f1559 100644 --- a/test/espnet2/enh/layers/test_enh_layers.py +++ b/test/espnet2/enh/layers/test_enh_layers.py @@ -1,15 +1,13 @@ -from packaging.version import parse as V - import numpy as np import pytest import torch import torch_complex.functional as FC +from packaging.version import parse as V from torch_complex.tensor import ComplexTensor -from espnet2.enh.layers.beamformer import generalized_eigenvalue_decomposition -from espnet2.enh.layers.beamformer import get_rtf -from espnet2.enh.layers.beamformer import gev_phase_correction -from espnet2.enh.layers.beamformer import signal_framing +from espnet2.enh.layers.beamformer import ( + generalized_eigenvalue_decomposition, get_rtf, gev_phase_correction, + signal_framing) from espnet2.enh.layers.complex_utils import solve from espnet2.layers.stft import Stft diff --git a/test/espnet2/enh/loss/criterions/test_tf_domain.py b/test/espnet2/enh/loss/criterions/test_tf_domain.py index 117a16545db..a12b68686f1 100644 --- a/test/espnet2/enh/loss/criterions/test_tf_domain.py +++ b/test/espnet2/enh/loss/criterions/test_tf_domain.py @@ -1,15 +1,13 @@ -from packaging.version import parse as V import pytest import torch - +from packaging.version import parse as V from torch_complex import ComplexTensor -from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainAbsCoherence -from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainCrossEntropy -from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainDPCL -from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainL1 -from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainMSE - +from espnet2.enh.loss.criterions.tf_domain import (FrequencyDomainAbsCoherence, + FrequencyDomainCrossEntropy, + FrequencyDomainDPCL, + FrequencyDomainL1, + FrequencyDomainMSE) is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") diff --git a/test/espnet2/enh/loss/criterions/test_time_domain.py b/test/espnet2/enh/loss/criterions/test_time_domain.py index 208b23ab85f..5a1e94111be 100644 --- a/test/espnet2/enh/loss/criterions/test_time_domain.py +++ b/test/espnet2/enh/loss/criterions/test_time_domain.py @@ -1,12 +1,10 @@ import pytest import torch -from espnet2.enh.loss.criterions.time_domain import CISDRLoss -from espnet2.enh.loss.criterions.time_domain import SDRLoss -from espnet2.enh.loss.criterions.time_domain import SISNRLoss -from espnet2.enh.loss.criterions.time_domain import SNRLoss -from espnet2.enh.loss.criterions.time_domain import TimeDomainL1 -from espnet2.enh.loss.criterions.time_domain import TimeDomainMSE +from espnet2.enh.loss.criterions.time_domain import (CISDRLoss, SDRLoss, + SISNRLoss, SNRLoss, + TimeDomainL1, + TimeDomainMSE) @pytest.mark.parametrize("criterion_class", [CISDRLoss, SISNRLoss, SNRLoss, SDRLoss]) diff --git a/test/espnet2/enh/loss/wrappers/test_multilayer_pit_solver.py b/test/espnet2/enh/loss/wrappers/test_multilayer_pit_solver.py index 3505a007eee..63db4587e38 100644 --- a/test/espnet2/enh/loss/wrappers/test_multilayer_pit_solver.py +++ b/test/espnet2/enh/loss/wrappers/test_multilayer_pit_solver.py @@ -1,5 +1,4 @@ import pytest - import torch from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainL1 diff --git a/test/espnet2/enh/loss/wrappers/test_pit_solver.py 
b/test/espnet2/enh/loss/wrappers/test_pit_solver.py index ddba099e17e..edc60f4e69e 100644 --- a/test/espnet2/enh/loss/wrappers/test_pit_solver.py +++ b/test/espnet2/enh/loss/wrappers/test_pit_solver.py @@ -2,8 +2,8 @@ import torch import torch.nn.functional as F -from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainCrossEntropy -from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainL1 +from espnet2.enh.loss.criterions.tf_domain import (FrequencyDomainCrossEntropy, + FrequencyDomainL1) from espnet2.enh.loss.wrappers.pit_solver import PITSolver diff --git a/test/espnet2/enh/separator/test_beamformer.py b/test/espnet2/enh/separator/test_beamformer.py index eddf317ee86..9e58b428ad5 100644 --- a/test/espnet2/enh/separator/test_beamformer.py +++ b/test/espnet2/enh/separator/test_beamformer.py @@ -1,12 +1,11 @@ -from packaging.version import parse as V import pytest import torch +from packaging.version import parse as V from espnet2.enh.encoder.stft_encoder import STFTEncoder from espnet2.enh.layers.dnn_beamformer import BEAMFORMER_TYPES from espnet2.enh.separator.neural_beamformer import NeuralBeamformer - is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") random_speech = torch.tensor( [ diff --git a/test/espnet2/enh/separator/test_conformer_separator.py b/test/espnet2/enh/separator/test_conformer_separator.py index 2ba800acaa2..b9e0b924d65 100644 --- a/test/espnet2/enh/separator/test_conformer_separator.py +++ b/test/espnet2/enh/separator/test_conformer_separator.py @@ -1,5 +1,4 @@ import pytest - import torch from torch import Tensor from torch_complex.tensor import ComplexTensor diff --git a/test/espnet2/enh/separator/test_dan_separator.py b/test/espnet2/enh/separator/test_dan_separator.py index 675176bea8a..2ea1767ad46 100644 --- a/test/espnet2/enh/separator/test_dan_separator.py +++ b/test/espnet2/enh/separator/test_dan_separator.py @@ -1,5 +1,4 @@ import pytest - import torch from torch import Tensor from torch_complex import ComplexTensor diff --git a/test/espnet2/enh/separator/test_dc_crn_separator.py b/test/espnet2/enh/separator/test_dc_crn_separator.py index 8f60b62399a..21cb54bbd6f 100644 --- a/test/espnet2/enh/separator/test_dc_crn_separator.py +++ b/test/espnet2/enh/separator/test_dc_crn_separator.py @@ -1,13 +1,11 @@ -from packaging.version import parse as V import pytest - import torch +from packaging.version import parse as V from torch_complex import ComplexTensor from espnet2.enh.layers.complex_utils import is_complex from espnet2.enh.separator.dc_crn_separator import DC_CRNSeparator - is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") diff --git a/test/espnet2/enh/separator/test_dccrn_separator.py b/test/espnet2/enh/separator/test_dccrn_separator.py index 3a075ac42ba..d30ba0c9ee0 100644 --- a/test/espnet2/enh/separator/test_dccrn_separator.py +++ b/test/espnet2/enh/separator/test_dccrn_separator.py @@ -1,7 +1,6 @@ -from packaging.version import parse as V import pytest - import torch +from packaging.version import parse as V from torch_complex import ComplexTensor from espnet2.enh.separator.dccrn_separator import DCCRNSeparator diff --git a/test/espnet2/enh/separator/test_dpcl_e2e_separator.py b/test/espnet2/enh/separator/test_dpcl_e2e_separator.py index c470a9ee83f..574bc26b22f 100644 --- a/test/espnet2/enh/separator/test_dpcl_e2e_separator.py +++ b/test/espnet2/enh/separator/test_dpcl_e2e_separator.py @@ -1,5 +1,4 @@ import pytest - import torch from torch import Tensor from torch_complex import ComplexTensor diff --git 
a/test/espnet2/enh/separator/test_dpcl_separator.py b/test/espnet2/enh/separator/test_dpcl_separator.py index 19304579af8..3c7693492e0 100644 --- a/test/espnet2/enh/separator/test_dpcl_separator.py +++ b/test/espnet2/enh/separator/test_dpcl_separator.py @@ -1,5 +1,4 @@ import pytest - import torch from torch_complex import ComplexTensor diff --git a/test/espnet2/enh/separator/test_dprnn_separator.py b/test/espnet2/enh/separator/test_dprnn_separator.py index 24b653d562a..e4441b20650 100644 --- a/test/espnet2/enh/separator/test_dprnn_separator.py +++ b/test/espnet2/enh/separator/test_dprnn_separator.py @@ -1,5 +1,4 @@ import pytest - import torch from torch import Tensor from torch_complex import ComplexTensor diff --git a/test/espnet2/enh/separator/test_fasnet_separator.py b/test/espnet2/enh/separator/test_fasnet_separator.py index 603dc9ce680..bfe21aaed38 100644 --- a/test/espnet2/enh/separator/test_fasnet_separator.py +++ b/test/espnet2/enh/separator/test_fasnet_separator.py @@ -1,5 +1,4 @@ import pytest - import torch from torch import Tensor diff --git a/test/espnet2/enh/separator/test_rnn_separator.py b/test/espnet2/enh/separator/test_rnn_separator.py index 0371ffed0bf..62478c300f0 100644 --- a/test/espnet2/enh/separator/test_rnn_separator.py +++ b/test/espnet2/enh/separator/test_rnn_separator.py @@ -1,5 +1,4 @@ import pytest - import torch from torch import Tensor from torch_complex import ComplexTensor diff --git a/test/espnet2/enh/separator/test_skim_separator.py b/test/espnet2/enh/separator/test_skim_separator.py index e1594cd5620..ce21e4254b9 100644 --- a/test/espnet2/enh/separator/test_skim_separator.py +++ b/test/espnet2/enh/separator/test_skim_separator.py @@ -1,5 +1,4 @@ import pytest - import torch from torch import Tensor from torch_complex import ComplexTensor diff --git a/test/espnet2/enh/separator/test_svoice_separator.py b/test/espnet2/enh/separator/test_svoice_separator.py index b2fb191856c..45d79c0e3d0 100644 --- a/test/espnet2/enh/separator/test_svoice_separator.py +++ b/test/espnet2/enh/separator/test_svoice_separator.py @@ -1,5 +1,4 @@ import pytest - import torch from torch import Tensor diff --git a/test/espnet2/enh/separator/test_tcn_separator.py b/test/espnet2/enh/separator/test_tcn_separator.py index f2babeda466..380f858d180 100644 --- a/test/espnet2/enh/separator/test_tcn_separator.py +++ b/test/espnet2/enh/separator/test_tcn_separator.py @@ -1,5 +1,4 @@ import pytest - import torch from torch import Tensor from torch_complex import ComplexTensor diff --git a/test/espnet2/enh/separator/test_transformer_separator.py b/test/espnet2/enh/separator/test_transformer_separator.py index 474bbff14f5..2dfa6a346d3 100644 --- a/test/espnet2/enh/separator/test_transformer_separator.py +++ b/test/espnet2/enh/separator/test_transformer_separator.py @@ -1,5 +1,4 @@ import pytest - import torch from torch import Tensor from torch_complex import ComplexTensor diff --git a/test/espnet2/enh/test_espnet_enh_s2t_model.py b/test/espnet2/enh/test_espnet_enh_s2t_model.py index 5f7df398130..383376c5104 100644 --- a/test/espnet2/enh/test_espnet_enh_s2t_model.py +++ b/test/espnet2/enh/test_espnet_enh_s2t_model.py @@ -14,7 +14,6 @@ from espnet2.enh.loss.wrappers.fixed_order import FixedOrderSolver from espnet2.enh.separator.rnn_separator import RNNSeparator - enh_stft_encoder = STFTEncoder( n_fft=32, hop_length=16, diff --git a/test/espnet2/enh/test_espnet_model.py b/test/espnet2/enh/test_espnet_model.py index 906b42bbac3..d257d544535 100644 --- a/test/espnet2/enh/test_espnet_model.py +++ 
b/test/espnet2/enh/test_espnet_model.py @@ -1,7 +1,6 @@ -from packaging.version import parse as V - import pytest import torch +from packaging.version import parse as V from espnet2.enh.decoder.conv_decoder import ConvDecoder from espnet2.enh.decoder.null_decoder import NullDecoder @@ -10,8 +9,8 @@ from espnet2.enh.encoder.null_encoder import NullEncoder from espnet2.enh.encoder.stft_encoder import STFTEncoder from espnet2.enh.espnet_model import ESPnetEnhancementModel -from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainL1 -from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainMSE +from espnet2.enh.loss.criterions.tf_domain import (FrequencyDomainL1, + FrequencyDomainMSE) from espnet2.enh.loss.criterions.time_domain import SISNRLoss from espnet2.enh.loss.wrappers.fixed_order import FixedOrderSolver from espnet2.enh.loss.wrappers.multilayer_pit_solver import MultiLayerPITSolver @@ -25,7 +24,6 @@ from espnet2.enh.separator.tcn_separator import TCNSeparator from espnet2.enh.separator.transformer_separator import TransformerSeparator - is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") diff --git a/test/espnet2/fileio/test_npy_scp.py b/test/espnet2/fileio/test_npy_scp.py index 4f81b68ed64..f867a25bd53 100644 --- a/test/espnet2/fileio/test_npy_scp.py +++ b/test/espnet2/fileio/test_npy_scp.py @@ -3,10 +3,8 @@ import numpy as np import pytest -from espnet2.fileio.npy_scp import NpyScpReader -from espnet2.fileio.npy_scp import NpyScpWriter -from espnet2.fileio.sound_scp import SoundScpReader -from espnet2.fileio.sound_scp import SoundScpWriter +from espnet2.fileio.npy_scp import NpyScpReader, NpyScpWriter +from espnet2.fileio.sound_scp import SoundScpReader, SoundScpWriter def test_NpyScpReader(tmp_path: Path): diff --git a/test/espnet2/fileio/test_read_text.py b/test/espnet2/fileio/test_read_text.py index feace34bdf7..ffd3a81259a 100644 --- a/test/espnet2/fileio/test_read_text.py +++ b/test/espnet2/fileio/test_read_text.py @@ -3,8 +3,7 @@ import numpy as np import pytest -from espnet2.fileio.read_text import load_num_sequence_text -from espnet2.fileio.read_text import read_2column_text +from espnet2.fileio.read_text import load_num_sequence_text, read_2column_text def test_read_2column_text(tmp_path: Path): diff --git a/test/espnet2/gan_tts/hifigan/test_hifigan.py b/test/espnet2/gan_tts/hifigan/test_hifigan.py index 1bfc7308103..33ed75281c0 100644 --- a/test/espnet2/gan_tts/hifigan/test_hifigan.py +++ b/test/espnet2/gan_tts/hifigan/test_hifigan.py @@ -7,12 +7,12 @@ import pytest import torch -from espnet2.gan_tts.hifigan import HiFiGANGenerator -from espnet2.gan_tts.hifigan import HiFiGANMultiScaleMultiPeriodDiscriminator -from espnet2.gan_tts.hifigan.loss import DiscriminatorAdversarialLoss -from espnet2.gan_tts.hifigan.loss import FeatureMatchLoss -from espnet2.gan_tts.hifigan.loss import GeneratorAdversarialLoss -from espnet2.gan_tts.hifigan.loss import MelSpectrogramLoss +from espnet2.gan_tts.hifigan import (HiFiGANGenerator, + HiFiGANMultiScaleMultiPeriodDiscriminator) +from espnet2.gan_tts.hifigan.loss import (DiscriminatorAdversarialLoss, + FeatureMatchLoss, + GeneratorAdversarialLoss, + MelSpectrogramLoss) def make_hifigan_generator_args(**kwargs): diff --git a/test/espnet2/gan_tts/joint/test_joint_text2wav.py b/test/espnet2/gan_tts/joint/test_joint_text2wav.py index f0ed087da20..a82b215afba 100644 --- a/test/espnet2/gan_tts/joint/test_joint_text2wav.py +++ b/test/espnet2/gan_tts/joint/test_joint_text2wav.py @@ -3,10 +3,9 @@ """Test VITS related modules.""" 
-from packaging.version import parse as V - import pytest import torch +from packaging.version import parse as V from espnet2.gan_tts.joint import JointText2Wav diff --git a/test/espnet2/gan_tts/melgan/test_melgan.py b/test/espnet2/gan_tts/melgan/test_melgan.py index 81d5874007b..ed5496fa780 100644 --- a/test/espnet2/gan_tts/melgan/test_melgan.py +++ b/test/espnet2/gan_tts/melgan/test_melgan.py @@ -7,11 +7,11 @@ import pytest import torch -from espnet2.gan_tts.hifigan.loss import DiscriminatorAdversarialLoss -from espnet2.gan_tts.hifigan.loss import FeatureMatchLoss -from espnet2.gan_tts.hifigan.loss import GeneratorAdversarialLoss -from espnet2.gan_tts.melgan import MelGANGenerator -from espnet2.gan_tts.melgan import MelGANMultiScaleDiscriminator +from espnet2.gan_tts.hifigan.loss import (DiscriminatorAdversarialLoss, + FeatureMatchLoss, + GeneratorAdversarialLoss) +from espnet2.gan_tts.melgan import (MelGANGenerator, + MelGANMultiScaleDiscriminator) def make_melgan_generator_args(**kwargs): diff --git a/test/espnet2/gan_tts/parallel_wavegan/test_parallel_wavegan.py b/test/espnet2/gan_tts/parallel_wavegan/test_parallel_wavegan.py index 098ce45ea8c..d7103ba9c67 100644 --- a/test/espnet2/gan_tts/parallel_wavegan/test_parallel_wavegan.py +++ b/test/espnet2/gan_tts/parallel_wavegan/test_parallel_wavegan.py @@ -7,10 +7,10 @@ import pytest import torch -from espnet2.gan_tts.hifigan.loss import DiscriminatorAdversarialLoss -from espnet2.gan_tts.hifigan.loss import GeneratorAdversarialLoss -from espnet2.gan_tts.parallel_wavegan import ParallelWaveGANDiscriminator -from espnet2.gan_tts.parallel_wavegan import ParallelWaveGANGenerator +from espnet2.gan_tts.hifigan.loss import (DiscriminatorAdversarialLoss, + GeneratorAdversarialLoss) +from espnet2.gan_tts.parallel_wavegan import (ParallelWaveGANDiscriminator, + ParallelWaveGANGenerator) def make_generator_args(**kwargs): @@ -134,9 +134,8 @@ def test_parallel_wavegan_generator_and_discriminator(dict_g, dict_d): not is_parallel_wavegan_available, reason="parallel_wavegan is not installed." ) def test_parallel_wavegan_compatibility(): - from parallel_wavegan.models import ( - ParallelWaveGANGenerator as PWGParallelWaveGANGenerator, # NOQA - ) + from parallel_wavegan.models import \ + ParallelWaveGANGenerator as PWGParallelWaveGANGenerator # NOQA model_pwg = PWGParallelWaveGANGenerator(**make_generator_args()) model_espnet2 = ParallelWaveGANGenerator(**make_generator_args()) diff --git a/test/espnet2/gan_tts/style_melgan/test_style_melgan.py b/test/espnet2/gan_tts/style_melgan/test_style_melgan.py index 8f8f3f546f2..b3450aadb3c 100644 --- a/test/espnet2/gan_tts/style_melgan/test_style_melgan.py +++ b/test/espnet2/gan_tts/style_melgan/test_style_melgan.py @@ -7,10 +7,10 @@ import pytest import torch -from espnet2.gan_tts.hifigan.loss import DiscriminatorAdversarialLoss -from espnet2.gan_tts.hifigan.loss import GeneratorAdversarialLoss -from espnet2.gan_tts.style_melgan import StyleMelGANDiscriminator -from espnet2.gan_tts.style_melgan import StyleMelGANGenerator +from espnet2.gan_tts.hifigan.loss import (DiscriminatorAdversarialLoss, + GeneratorAdversarialLoss) +from espnet2.gan_tts.style_melgan import (StyleMelGANDiscriminator, + StyleMelGANGenerator) def make_style_melgan_generator_args(**kwargs): @@ -124,7 +124,8 @@ def test_style_melgan_trainable(dict_g, dict_d): not is_parallel_wavegan_available, reason="parallel_wavegan is not installed." 
) def test_parallel_wavegan_compatibility(): - from parallel_wavegan.models import StyleMelGANGenerator as PWGStyleMelGANGenerator + from parallel_wavegan.models import \ + StyleMelGANGenerator as PWGStyleMelGANGenerator model_pwg = PWGStyleMelGANGenerator(**make_style_melgan_generator_args()) model_espnet2 = StyleMelGANGenerator(**make_style_melgan_generator_args()) diff --git a/test/espnet2/hubert/test_hubert_loss.py b/test/espnet2/hubert/test_hubert_loss.py index f51aecafb1c..34174e598f2 100644 --- a/test/espnet2/hubert/test_hubert_loss.py +++ b/test/espnet2/hubert/test_hubert_loss.py @@ -1,10 +1,9 @@ import pytest import torch +from espnet2.asr.encoder.hubert_encoder import \ + FairseqHubertPretrainEncoder # noqa: H301 from espnet2.hubert.hubert_loss import HubertPretrainLoss # noqa: H301 -from espnet2.asr.encoder.hubert_encoder import ( - FairseqHubertPretrainEncoder, # noqa: H301 -) pytest.importorskip("fairseq") diff --git a/test/espnet2/iterators/test_chunk_iter_factory.py b/test/espnet2/iterators/test_chunk_iter_factory.py index 5011dc8c3c5..74f51970c70 100644 --- a/test/espnet2/iterators/test_chunk_iter_factory.py +++ b/test/espnet2/iterators/test_chunk_iter_factory.py @@ -1,8 +1,8 @@ +import numpy as np + from espnet2.iterators.chunk_iter_factory import ChunkIterFactory from espnet2.train.collate_fn import CommonCollateFn -import numpy as np - class Dataset: def __init__(self): diff --git a/test/espnet2/layers/test_sinc_filters.py b/test/espnet2/layers/test_sinc_filters.py index c6ee0244383..97ca33d5a43 100644 --- a/test/espnet2/layers/test_sinc_filters.py +++ b/test/espnet2/layers/test_sinc_filters.py @@ -1,9 +1,7 @@ import torch -from espnet2.layers.sinc_conv import BarkScale -from espnet2.layers.sinc_conv import LogCompression -from espnet2.layers.sinc_conv import MelScale -from espnet2.layers.sinc_conv import SincConv +from espnet2.layers.sinc_conv import (BarkScale, LogCompression, MelScale, + SincConv) def test_log_compression(): diff --git a/test/espnet2/lm/test_seq_rnn_lm.py b/test/espnet2/lm/test_seq_rnn_lm.py index 480b9d549b3..bedc8eb54c9 100644 --- a/test/espnet2/lm/test_seq_rnn_lm.py +++ b/test/espnet2/lm/test_seq_rnn_lm.py @@ -1,9 +1,9 @@ import pytest import torch +from espnet2.lm.seq_rnn_lm import SequentialRNNLM from espnet.nets.batch_beam_search import BatchBeamSearch from espnet.nets.beam_search import BeamSearch -from espnet2.lm.seq_rnn_lm import SequentialRNNLM @pytest.mark.parametrize("rnn_type", ["LSTM", "GRU", "RNN_TANH", "RNN_RELU"]) diff --git a/test/espnet2/lm/test_transformer_lm.py b/test/espnet2/lm/test_transformer_lm.py index 2f52785fc25..36adce59f88 100644 --- a/test/espnet2/lm/test_transformer_lm.py +++ b/test/espnet2/lm/test_transformer_lm.py @@ -1,9 +1,9 @@ import pytest import torch +from espnet2.lm.transformer_lm import TransformerLM from espnet.nets.batch_beam_search import BatchBeamSearch from espnet.nets.beam_search import BeamSearch -from espnet2.lm.transformer_lm import TransformerLM @pytest.mark.parametrize("pos_enc", ["sinusoidal", None]) diff --git a/test/espnet2/main_funcs/test_calculate_all_attentions.py b/test/espnet2/main_funcs/test_calculate_all_attentions.py index e33dbd303ef..2b63b22549f 100644 --- a/test/espnet2/main_funcs/test_calculate_all_attentions.py +++ b/test/espnet2/main_funcs/test_calculate_all_attentions.py @@ -4,11 +4,13 @@ import pytest import torch -from espnet.nets.pytorch_backend.rnn.attentions import AttAdd -from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention from 
espnet2.asr.decoder.rnn_decoder import RNNDecoder
-from espnet2.main_funcs.calculate_all_attentions import calculate_all_attentions
+from espnet2.main_funcs.calculate_all_attentions import \
+    calculate_all_attentions
 from espnet2.train.abs_espnet_model import AbsESPnetModel
+from espnet.nets.pytorch_backend.rnn.attentions import AttAdd
+from espnet.nets.pytorch_backend.transformer.attention import \
+    MultiHeadedAttention


 class Dummy(AbsESPnetModel):
diff --git a/test/espnet2/main_funcs/test_pack_funcs.py b/test/espnet2/main_funcs/test_pack_funcs.py
index 839a0e24e19..0e5d2edb943 100644
--- a/test/espnet2/main_funcs/test_pack_funcs.py
+++ b/test/espnet2/main_funcs/test_pack_funcs.py
@@ -1,12 +1,11 @@
-from pathlib import Path
 import tarfile
+from pathlib import Path

 import pytest
 import yaml

-from espnet2.main_funcs.pack_funcs import find_path_and_change_it_recursive
-from espnet2.main_funcs.pack_funcs import pack
-from espnet2.main_funcs.pack_funcs import unpack
+from espnet2.main_funcs.pack_funcs import (find_path_and_change_it_recursive,
+                                           pack, unpack)


 def test_find_path_and_change_it_recursive():
diff --git a/test/espnet2/text/test_sentencepiece_tokenizer.py b/test/espnet2/text/test_sentencepiece_tokenizer.py
index 7baea4191f1..eabf741fdf3 100644
--- a/test/espnet2/text/test_sentencepiece_tokenizer.py
+++ b/test/espnet2/text/test_sentencepiece_tokenizer.py
@@ -1,5 +1,5 @@
-from pathlib import Path
 import string
+from pathlib import Path

 import pytest
 import sentencepiece as spm
diff --git a/test/espnet2/torch_utils/test_device_funcs.py b/test/espnet2/torch_utils/test_device_funcs.py
index 2ddce8de3d7..69b6274db17 100644
--- a/test/espnet2/torch_utils/test_device_funcs.py
+++ b/test/espnet2/torch_utils/test_device_funcs.py
@@ -4,8 +4,7 @@
 import pytest
 import torch

-from espnet2.torch_utils.device_funcs import force_gatherable
-from espnet2.torch_utils.device_funcs import to_device
+from espnet2.torch_utils.device_funcs import force_gatherable, to_device

 x = torch.tensor(10)
diff --git a/test/espnet2/train/test_collate_fn.py b/test/espnet2/train/test_collate_fn.py
index 8c69fcb9061..75841910824 100644
--- a/test/espnet2/train/test_collate_fn.py
+++ b/test/espnet2/train/test_collate_fn.py
@@ -1,8 +1,7 @@
 import numpy as np
 import pytest

-from espnet2.train.collate_fn import common_collate_fn
-from espnet2.train.collate_fn import CommonCollateFn
+from espnet2.train.collate_fn import CommonCollateFn, common_collate_fn


 @pytest.mark.parametrize(
diff --git a/test/espnet2/train/test_distributed_utils.py b/test/espnet2/train/test_distributed_utils.py
index c52fed773eb..8c62a1f125c 100644
--- a/test/espnet2/train/test_distributed_utils.py
+++ b/test/espnet2/train/test_distributed_utils.py
@@ -1,14 +1,13 @@
 import argparse
+import unittest.mock
 from concurrent.futures.process import ProcessPoolExecutor
 from concurrent.futures.thread import ThreadPoolExecutor
-import unittest.mock

 import pytest

 from espnet2.tasks.abs_task import AbsTask
-from espnet2.train.distributed_utils import DistributedOption
-from espnet2.train.distributed_utils import free_port
-from espnet2.train.distributed_utils import resolve_distributed_mode
+from espnet2.train.distributed_utils import (DistributedOption, free_port,
+                                             resolve_distributed_mode)
 from espnet2.utils.build_dataclass import build_dataclass
diff --git a/test/espnet2/train/test_reporter.py b/test/espnet2/train/test_reporter.py
index 9cd796d665c..ec1f7efc14a 100644
--- a/test/espnet2/train/test_reporter.py
+++ b/test/espnet2/train/test_reporter.py
@@ -1,16 +1,13 @@
 import logging
-from pathlib import Path
 import uuid
+from pathlib import Path

 import numpy as np
 import pytest
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from espnet2.train.reporter import aggregate
-from espnet2.train.reporter import Average
-from espnet2.train.reporter import ReportedValue
-from espnet2.train.reporter import Reporter
+from espnet2.train.reporter import Average, ReportedValue, Reporter, aggregate


 @pytest.mark.parametrize("weight1,weight2", [(None, None), (19, np.array(9))])
diff --git a/test/espnet2/tts/feats_extract/test_log_mel_fbank.py b/test/espnet2/tts/feats_extract/test_log_mel_fbank.py
index 28135a0c42e..c9d7c69b6c4 100644
--- a/test/espnet2/tts/feats_extract/test_log_mel_fbank.py
+++ b/test/espnet2/tts/feats_extract/test_log_mel_fbank.py
@@ -1,8 +1,8 @@
 import numpy as np
 import torch

-from espnet.transform.spectrogram import logmelspectrogram
 from espnet2.tts.feats_extract.log_mel_fbank import LogMelFbank
+from espnet.transform.spectrogram import logmelspectrogram


 def test_forward():
diff --git a/test/espnet2/tts/feats_extract/test_log_spectrogram.py b/test/espnet2/tts/feats_extract/test_log_spectrogram.py
index 7c30a7185b2..77c97f6797a 100644
--- a/test/espnet2/tts/feats_extract/test_log_spectrogram.py
+++ b/test/espnet2/tts/feats_extract/test_log_spectrogram.py
@@ -1,8 +1,8 @@
 import numpy as np
 import torch

-from espnet.transform.spectrogram import spectrogram
 from espnet2.tts.feats_extract.log_spectrogram import LogSpectrogram
+from espnet.transform.spectrogram import spectrogram


 def test_forward():
diff --git a/test/espnet2/utils/test_build_dataclass.py b/test/espnet2/utils/test_build_dataclass.py
index 17606933a3b..a5984c6e0d3 100644
--- a/test/espnet2/utils/test_build_dataclass.py
+++ b/test/espnet2/utils/test_build_dataclass.py
@@ -1,5 +1,5 @@
-from argparse import Namespace
 import dataclasses
+from argparse import Namespace

 import pytest
diff --git a/test/espnet2/utils/test_sized_dict.py b/test/espnet2/utils/test_sized_dict.py
index 3b275d9afb6..67f5dd6267e 100644
--- a/test/espnet2/utils/test_sized_dict.py
+++ b/test/espnet2/utils/test_sized_dict.py
@@ -5,8 +5,7 @@
 import pytest
 import torch.multiprocessing

-from espnet2.utils.sized_dict import get_size
-from espnet2.utils.sized_dict import SizedDict
+from espnet2.utils.sized_dict import SizedDict, get_size


 def test_get_size():
diff --git a/test/espnet2/utils/test_types.py b/test/espnet2/utils/test_types.py
index cc3d1fbe9fc..519d23f6d71 100644
--- a/test/espnet2/utils/test_types.py
+++ b/test/espnet2/utils/test_types.py
@@ -3,15 +3,10 @@

 import pytest

-from espnet2.utils.types import float_or_none
-from espnet2.utils.types import humanfriendly_parse_size_or_none
-from espnet2.utils.types import int_or_none
-from espnet2.utils.types import remove_parenthesis
-from espnet2.utils.types import str2bool
-from espnet2.utils.types import str2pair_str
-from espnet2.utils.types import str2triple_str
-from espnet2.utils.types import str_or_int
-from espnet2.utils.types import str_or_none
+from espnet2.utils.types import (float_or_none,
+                                 humanfriendly_parse_size_or_none, int_or_none,
+                                 remove_parenthesis, str2bool, str2pair_str,
+                                 str2triple_str, str_or_int, str_or_none)


 @contextmanager
diff --git a/test/test_asr_init.py b/test/test_asr_init.py
index b9254828ab2..c2136524d2e 100644
--- a/test/test_asr_init.py
+++ b/test/test_asr_init.py
@@ -10,10 +10,9 @@
 import torch

 import espnet.nets.pytorch_backend.lm.default as lm_pytorch
-
 from espnet.asr.asr_utils import torch_save
-from espnet.asr.pytorch_backend.asr_init import freeze_modules
-from espnet.asr.pytorch_backend.asr_init import load_trained_modules
+from espnet.asr.pytorch_backend.asr_init import (freeze_modules,
+                                                 load_trained_modules)
 from espnet.nets.beam_search_transducer import BeamSearchTransducer
 from espnet.nets.pytorch_backend.nets_utils import pad_list
diff --git a/test/test_batch_beam_search.py b/test/test_batch_beam_search.py
index dd40842197f..7b692f12edc 100644
--- a/test/test_batch_beam_search.py
+++ b/test/test_batch_beam_search.py
@@ -1,20 +1,17 @@
+import os
 from argparse import Namespace
+from test.test_beam_search import prepare, transformer_args

 import numpy
-import os
 import pytest
 import torch

-from espnet.nets.batch_beam_search import BatchBeamSearch
-from espnet.nets.batch_beam_search import BeamSearch
+from espnet.nets.batch_beam_search import BatchBeamSearch, BeamSearch
 from espnet.nets.beam_search import Hypothesis
 from espnet.nets.lm_interface import dynamic_import_lm
 from espnet.nets.scorers.length_bonus import LengthBonus
 from espnet.nets.scorers.ngram import NgramFullScorer
-from test.test_beam_search import prepare
-from test.test_beam_search import transformer_args
-


 def test_batchfy_hyp():
     vocab_size = 5
diff --git a/test/test_custom_transducer.py b/test/test_custom_transducer.py
index bf6101365cd..69ce111a86e 100644
--- a/test/test_custom_transducer.py
+++ b/test/test_custom_transducer.py
@@ -1,18 +1,18 @@
 # coding: utf-8

 import argparse
-from packaging.version import parse as V
+import json
 import tempfile
-import json

 import pytest
 import torch
+from packaging.version import parse as V

-from espnet.asr.pytorch_backend.asr_init import load_trained_model
 import espnet.lm.pytorch_backend.extlm as extlm_pytorch
+import espnet.nets.pytorch_backend.lm.default as lm_pytorch
+from espnet.asr.pytorch_backend.asr_init import load_trained_model
 from espnet.nets.beam_search_transducer import BeamSearchTransducer
 from espnet.nets.pytorch_backend.e2e_asr_transducer import E2E
-import espnet.nets.pytorch_backend.lm.default as lm_pytorch
 from espnet.nets.pytorch_backend.transducer.blocks import build_blocks

 is_torch_1_5_plus = V(torch.__version__) >= V("1.5.0")
diff --git a/test/test_e2e_asr.py b/test/test_e2e_asr.py
index a9f29478298..33292c26858 100644
--- a/test/test_e2e_asr.py
+++ b/test/test_e2e_asr.py
@@ -9,20 +9,20 @@
 import importlib
 import os
 import tempfile
+from test.utils_test import make_dummy_json

 import chainer
 import numpy as np
 import pytest
 import torch

-from espnet.asr import asr_utils
 import espnet.nets.chainer_backend.e2e_asr as ch_asr
 import espnet.nets.pytorch_backend.e2e_asr as th_asr
+from espnet.asr import asr_utils
 from espnet.nets.pytorch_backend.nets_utils import pad_list
 from espnet.nets.pytorch_backend.streaming.segment import SegmentStreamingE2E
 from espnet.nets.pytorch_backend.streaming.window import WindowStreamingE2E
 from espnet.utils.training.batchfy import make_batchset
-from test.utils_test import make_dummy_json


 def make_arg(**kwargs):
@@ -744,6 +744,7 @@ def test_multi_gpu_trainable(module):
         loss.backward(loss.new_ones(ngpu))  # trainable
     else:
         import copy
+
         import cupy

         losses = []
diff --git a/test/test_e2e_asr_conformer.py b/test/test_e2e_asr_conformer.py
index 783b89d445b..72e8bec6da8 100644
--- a/test/test_e2e_asr_conformer.py
+++ b/test/test_e2e_asr_conformer.py
@@ -1,4 +1,5 @@
 import argparse
+
 import pytest
 import torch
diff --git a/test/test_e2e_asr_maskctc.py b/test/test_e2e_asr_maskctc.py
index f9154f3ff27..93031d7db3a 100644
--- a/test/test_e2e_asr_maskctc.py
+++ b/test/test_e2e_asr_maskctc.py
@@ -1,4 +1,5 @@
 import argparse
+
 import pytest
 import torch
diff --git a/test/test_e2e_asr_mulenc.py b/test/test_e2e_asr_mulenc.py
index 88a04bac456..bd049f6f78a 100644
--- a/test/test_e2e_asr_mulenc.py
+++ b/test/test_e2e_asr_mulenc.py
@@ -9,6 +9,7 @@
 import importlib
 import os
 import tempfile
+from test.utils_test import make_dummy_json

 import numpy as np
 import pytest
@@ -16,7 +17,6 @@

 from espnet.nets.pytorch_backend.nets_utils import pad_list
 from espnet.utils.training.batchfy import make_batchset
-from test.utils_test import make_dummy_json


 def make_arg(num_encs, **kwargs):
diff --git a/test/test_e2e_asr_transducer.py b/test/test_e2e_asr_transducer.py
index 4a115433cfd..d17116d7b28 100644
--- a/test/test_e2e_asr_transducer.py
+++ b/test/test_e2e_asr_transducer.py
@@ -1,19 +1,19 @@
 # coding: utf-8

 import argparse
-from packaging.version import parse as V
+import json
 import tempfile
-import json

 import numpy as np
 import pytest
 import torch
+from packaging.version import parse as V

-from espnet.asr.pytorch_backend.asr_init import load_trained_model
 import espnet.lm.pytorch_backend.extlm as extlm_pytorch
+import espnet.nets.pytorch_backend.lm.default as lm_pytorch
+from espnet.asr.pytorch_backend.asr_init import load_trained_model
 from espnet.nets.beam_search_transducer import BeamSearchTransducer
 from espnet.nets.pytorch_backend.e2e_asr_transducer import E2E
-import espnet.nets.pytorch_backend.lm.default as lm_pytorch
 from espnet.nets.pytorch_backend.nets_utils import pad_list

 is_torch_1_4_plus = V(torch.__version__) >= V("1.4.0")
diff --git a/test/test_e2e_asr_transformer.py b/test/test_e2e_asr_transformer.py
index 6fd338eefb3..c10e39761b7 100644
--- a/test/test_e2e_asr_transformer.py
+++ b/test/test_e2e_asr_transformer.py
@@ -1,4 +1,5 @@
 import argparse
+
 import chainer
 import numpy
 import pytest
@@ -7,10 +8,10 @@
 import espnet.nets.chainer_backend.e2e_asr_transformer as ch
 import espnet.nets.pytorch_backend.e2e_asr_transformer as th
 from espnet.nets.pytorch_backend.nets_utils import rename_state_dict
-from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
-from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
-from espnet.nets.pytorch_backend.transformer.mask import target_mask
 from espnet.nets.pytorch_backend.transformer import plot
+from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
+from espnet.nets.pytorch_backend.transformer.mask import (subsequent_mask,
+                                                          target_mask)


 def test_sequential():
diff --git a/test/test_e2e_compatibility.py b/test/test_e2e_compatibility.py
index ea1f1e3b5f1..7930b0bf095 100644
--- a/test/test_e2e_compatibility.py
+++ b/test/test_e2e_compatibility.py
@@ -8,20 +8,18 @@

 import importlib
 import os
-from os.path import join
 import re
 import shutil
 import subprocess
 import tempfile
+from os.path import join

 import chainer
 import numpy as np
 import pytest
 import torch

-from espnet.asr.asr_utils import chainer_load
-from espnet.asr.asr_utils import get_model_conf
-from espnet.asr.asr_utils import torch_load
+from espnet.asr.asr_utils import chainer_load, get_model_conf, torch_load


 def download_zip_from_google_drive(download_dir, file_id):
diff --git a/test/test_e2e_mt.py b/test/test_e2e_mt.py
index 4c2158b1856..5d31336f397 100644
--- a/test/test_e2e_mt.py
+++ b/test/test_e2e_mt.py
@@ -9,6 +9,7 @@
 import importlib
 import os
 import tempfile
+from test.utils_test import make_dummy_json_mt

 import chainer
 import numpy as np
@@ -17,7 +18,6 @@

 from espnet.nets.pytorch_backend.nets_utils import pad_list
 from espnet.utils.training.batchfy import make_batchset
-from test.utils_test import make_dummy_json_mt


 def make_arg(**kwargs):
diff --git a/test/test_e2e_mt_transformer.py b/test/test_e2e_mt_transformer.py
index cf2ad01a9ec..698e01ae08a 100644
--- a/test/test_e2e_mt_transformer.py
+++ b/test/test_e2e_mt_transformer.py
@@ -4,6 +4,7 @@
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

 import argparse
+
 import pytest
 import torch
diff --git a/test/test_e2e_st.py b/test/test_e2e_st.py
index f3e53369128..fec33f139a4 100644
--- a/test/test_e2e_st.py
+++ b/test/test_e2e_st.py
@@ -9,6 +9,7 @@
 import importlib
 import os
 import tempfile
+from test.utils_test import make_dummy_json_st

 import chainer
 import numpy as np
@@ -17,7 +18,6 @@

 from espnet.nets.pytorch_backend.nets_utils import pad_list
 from espnet.utils.training.batchfy import make_batchset
-from test.utils_test import make_dummy_json_st


 def make_arg(**kwargs):
diff --git a/test/test_e2e_st_conformer.py b/test/test_e2e_st_conformer.py
index be0246ce1fc..a6e35c172cf 100644
--- a/test/test_e2e_st_conformer.py
+++ b/test/test_e2e_st_conformer.py
@@ -3,6 +3,7 @@
 # Copyright 2019 Hirofumi Inaguma
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

 import argparse
+
 import pytest
 import torch
diff --git a/test/test_e2e_st_transformer.py b/test/test_e2e_st_transformer.py
index e10622993ef..ea3cb6799db 100644
--- a/test/test_e2e_st_transformer.py
+++ b/test/test_e2e_st_transformer.py
@@ -3,6 +3,7 @@
 # Copyright 2019 Hirofumi Inaguma
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

 import argparse
+
 import pytest
 import torch
diff --git a/test/test_e2e_tts_fastspeech.py b/test/test_e2e_tts_fastspeech.py
index 6b66902746c..da9463cc171 100644
--- a/test/test_e2e_tts_fastspeech.py
+++ b/test/test_e2e_tts_fastspeech.py
@@ -8,20 +8,20 @@
 import os
 import shutil
 import tempfile
-
 from argparse import Namespace

 import numpy as np
 import pytest
 import torch

-from espnet.nets.pytorch_backend.e2e_tts_fastspeech import FeedForwardTransformer
+from espnet.nets.pytorch_backend.e2e_tts_fastspeech import \
+    FeedForwardTransformer
 from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2
 from espnet.nets.pytorch_backend.e2e_tts_transformer import Transformer
-from espnet.nets.pytorch_backend.fastspeech.duration_calculator import (
-    DurationCalculator,  # noqa: H301
-)
-from espnet.nets.pytorch_backend.fastspeech.length_regulator import LengthRegulator
+from espnet.nets.pytorch_backend.fastspeech.duration_calculator import \
+    DurationCalculator  # noqa: H301
+from espnet.nets.pytorch_backend.fastspeech.length_regulator import \
+    LengthRegulator
 from espnet.nets.pytorch_backend.nets_utils import pad_list
diff --git a/test/test_e2e_tts_tacotron2.py b/test/test_e2e_tts_tacotron2.py
index dd226b9f0a4..07aa18b0fd9 100644
--- a/test/test_e2e_tts_tacotron2.py
+++ b/test/test_e2e_tts_tacotron2.py
@@ -3,15 +3,14 @@
 # Copyright 2019 Tomoki Hayashi
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

-from __future__ import print_function
-from __future__ import division
+from __future__ import division, print_function
+
+from argparse import Namespace

 import numpy as np
 import pytest
 import torch

-from argparse import Namespace
-
 from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2
 from espnet.nets.pytorch_backend.nets_utils import pad_list
diff --git a/test/test_e2e_tts_transformer.py b/test/test_e2e_tts_transformer.py
index d4013ebd3cd..c61987247ad 100644
--- a/test/test_e2e_tts_transformer.py
+++ b/test/test_e2e_tts_transformer.py
@@ -4,14 +4,14 @@
 # Copyright 2019 Tomoki Hayashi
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

+from argparse import Namespace
+
 import numpy as np
 import pytest
 import torch

-from argparse import Namespace
-
-from espnet.nets.pytorch_backend.e2e_tts_transformer import subsequent_mask
-from espnet.nets.pytorch_backend.e2e_tts_transformer import Transformer
+from espnet.nets.pytorch_backend.e2e_tts_transformer import (Transformer,
+                                                             subsequent_mask)
 from espnet.nets.pytorch_backend.nets_utils import pad_list
diff --git a/test/test_e2e_vc_tacotron2.py b/test/test_e2e_vc_tacotron2.py
index abc61d9aff4..26faa463987 100644
--- a/test/test_e2e_vc_tacotron2.py
+++ b/test/test_e2e_vc_tacotron2.py
@@ -4,15 +4,14 @@
 # Copyright 2020 Wen-Chin Huang
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

-from __future__ import print_function
-from __future__ import division
+from __future__ import division, print_function
+
+from argparse import Namespace

 import numpy as np
 import pytest
 import torch

-from argparse import Namespace
-
 from espnet.nets.pytorch_backend.e2e_vc_tacotron2 import Tacotron2
 from espnet.nets.pytorch_backend.nets_utils import pad_list
diff --git a/test/test_e2e_vc_transformer.py b/test/test_e2e_vc_transformer.py
index 37e3a4ad808..5b0dc34ea2c 100644
--- a/test/test_e2e_vc_transformer.py
+++ b/test/test_e2e_vc_transformer.py
@@ -4,15 +4,15 @@
 # Copyright 2020 Wen-Chin Huang
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

+from argparse import Namespace
 from math import floor
+
 import numpy as np
 import pytest
 import torch

-from argparse import Namespace
-
-from espnet.nets.pytorch_backend.e2e_vc_transformer import subsequent_mask
-from espnet.nets.pytorch_backend.e2e_vc_transformer import Transformer
+from espnet.nets.pytorch_backend.e2e_vc_transformer import (Transformer,
+                                                            subsequent_mask)
 from espnet.nets.pytorch_backend.nets_utils import pad_list
diff --git a/test/test_lm.py b/test/test_lm.py
index 06d0ac68f93..df7c3e790ee 100644
--- a/test/test_lm.py
+++ b/test/test_lm.py
@@ -1,17 +1,16 @@
+from test.test_beam_search import prepare, rnn_args
+
 import chainer
 import numpy
 import pytest
 import torch

 import espnet.lm.chainer_backend.lm as lm_chainer
+import espnet.nets.pytorch_backend.lm.default as lm_pytorch
 from espnet.nets.beam_search import beam_search
 from espnet.nets.lm_interface import dynamic_import_lm
-import espnet.nets.pytorch_backend.lm.default as lm_pytorch
 from espnet.nets.scorers.length_bonus import LengthBonus
-from test.test_beam_search import prepare
-from test.test_beam_search import rnn_args
-


 def transfer_lstm(ch_lstm, th_lstm):
     ch_lstm.upward.W.data[:] = 1
diff --git a/test/test_multi_spkrs.py b/test/test_multi_spkrs.py
index a11430679b0..647cb7b5384 100644
--- a/test/test_multi_spkrs.py
+++ b/test/test_multi_spkrs.py
@@ -5,11 +5,11 @@

 import argparse
 import importlib
-import numpy
 import re
-import torch

+import numpy
 import pytest
+import torch


 def make_arg(**kwargs):
diff --git a/test/test_ngram.py b/test/test_ngram.py
index 306e9b277ae..fd9b40a5a60 100644
--- a/test/test_ngram.py
+++ b/test/test_ngram.py
@@ -1,8 +1,8 @@
 import os

-import pytest
-
 from math import isclose

+import pytest
+
 kenlm = pytest.importorskip("kenlm")
diff --git a/test/test_positional_encoding.py b/test/test_positional_encoding.py
index 6637a5245c1..305232d62c4 100644
--- a/test/test_positional_encoding.py
+++ b/test/test_positional_encoding.py
@@ -1,10 +1,8 @@
 import pytest
 import torch
-
-from espnet.nets.pytorch_backend.transformer.embedding import LearnableFourierPosEnc
-from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
-from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding
+from espnet.nets.pytorch_backend.transformer.embedding import (
+    LearnableFourierPosEnc, PositionalEncoding, ScaledPositionalEncoding)


 @pytest.mark.parametrize(
diff --git a/test/test_recog.py b/test/test_recog.py
index 465331eaf2a..1a4539d58eb 100644
--- a/test/test_recog.py
+++ b/test/test_recog.py
@@ -10,8 +10,8 @@
 import torch

 import espnet.lm.pytorch_backend.extlm as extlm_pytorch
-from espnet.nets.pytorch_backend import e2e_asr
 import espnet.nets.pytorch_backend.lm.default as lm_pytorch
+from espnet.nets.pytorch_backend import e2e_asr


 def make_arg(**kwargs):
diff --git a/test/test_scheduler.py b/test/test_scheduler.py
index 893e8a53495..fadbfb94bc8 100644
--- a/test/test_scheduler.py
+++ b/test/test_scheduler.py
@@ -1,12 +1,12 @@
-from espnet.scheduler.chainer import ChainerScheduler
-from espnet.scheduler.pytorch import PyTorchScheduler
-from espnet.scheduler import scheduler
-
 import chainer
 import numpy
 import pytest
 import torch

+from espnet.scheduler import scheduler
+from espnet.scheduler.chainer import ChainerScheduler
+from espnet.scheduler.pytorch import PyTorchScheduler
+

 @pytest.mark.parametrize("name", scheduler.SCHEDULER_DICT.keys())
 def test_scheduler(name):
diff --git a/test/test_sentencepiece.py b/test/test_sentencepiece.py
index ecfbd92ef99..0b57bc443f0 100644
--- a/test/test_sentencepiece.py
+++ b/test/test_sentencepiece.py
@@ -2,7 +2,6 @@

 import sentencepiece as spm

-
 root = os.path.dirname(os.path.abspath(__file__))
diff --git a/test/test_transformer_decode.py b/test/test_transformer_decode.py
index b63fab4784b..5c92c019095 100644
--- a/test/test_transformer_decode.py
+++ b/test/test_transformer_decode.py
@@ -6,7 +6,6 @@
 from espnet.nets.pytorch_backend.transformer.encoder import Encoder
 from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask

-
 RTOL = 1e-4
diff --git a/test/test_utils.py b/test/test_utils.py
index 7103bd2e012..5673da4ccc5 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,13 +1,13 @@
 #!/usr/bin/env python3
+from test.utils_test import make_dummy_json
+
 import h5py
 import kaldiio
 import numpy as np
 import pytest

-from espnet.utils.io_utils import LoadInputsAndTargets
-from espnet.utils.io_utils import SoundHDF5File
+from espnet.utils.io_utils import LoadInputsAndTargets, SoundHDF5File
 from espnet.utils.training.batchfy import make_batchset
-from test.utils_test import make_dummy_json


 @pytest.mark.parametrize("swap_io", [True, False])
diff --git a/utils/addjson.py b/utils/addjson.py
index 9649352c42f..aef2a0359be 100755
--- a/utils/addjson.py
+++ b/utils/addjson.py
@@ -10,7 +10,6 @@
 import json
 import logging
 import sys
-
 from distutils.util import strtobool

 from espnet.utils.cli_utils import get_commandline_args
diff --git a/utils/apply-cmvn.py b/utils/apply-cmvn.py
index bf5c6ac05a3..cde56aabd52 100755
--- a/utils/apply-cmvn.py
+++ b/utils/apply-cmvn.py
@@ -1,15 +1,14 @@
 #!/usr/bin/env python3
 import argparse
-from distutils.util import strtobool
 import logging
+from distutils.util import strtobool

 import kaldiio
 import numpy

 from espnet.transform.cmvn import CMVN
 from espnet.utils.cli_readers import file_reader_helper
-from espnet.utils.cli_utils import get_commandline_args
-from espnet.utils.cli_utils import is_scipy_wav_style
+from espnet.utils.cli_utils import get_commandline_args, is_scipy_wav_style
 from espnet.utils.cli_writers import file_writer_helper
diff --git a/utils/calculate_rtf.py b/utils/calculate_rtf.py
index 6be8dffd8eb..e6e2fb1efa0 100755
--- a/utils/calculate_rtf.py
+++ b/utils/calculate_rtf.py
@@ -6,10 +6,11 @@
 import argparse
 import codecs
-from dateutil import parser
 import glob
 import os

+from dateutil import parser
+

 def get_parser():
     parser = argparse.ArgumentParser(description="calculate real time factor (RTF)")
diff --git a/utils/compute-cmvn-stats.py b/utils/compute-cmvn-stats.py
index 067daec6cfb..a2f8a586ed7 100755
--- a/utils/compute-cmvn-stats.py
+++ b/utils/compute-cmvn-stats.py
@@ -7,8 +7,7 @@

 from espnet.transform.transformation import Transformation
 from espnet.utils.cli_readers import file_reader_helper
-from espnet.utils.cli_utils import get_commandline_args
-from espnet.utils.cli_utils import is_scipy_wav_style
+from espnet.utils.cli_utils import get_commandline_args, is_scipy_wav_style
 from espnet.utils.cli_writers import file_writer_helper
diff --git a/utils/compute-fbank-feats.py b/utils/compute-fbank-feats.py
index fabded4d5ac..ae1f9a55f0a 100755
--- a/utils/compute-fbank-feats.py
+++ b/utils/compute-fbank-feats.py
@@ -4,17 +4,17 @@
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

 import argparse
-from distutils.util import strtobool
 import logging
+from distutils.util import strtobool

 import kaldiio
 import numpy
 import resampy

+from espnet2.utils.types import int_or_none
 from espnet.transform.spectrogram import logmelspectrogram
 from espnet.utils.cli_utils import get_commandline_args
 from espnet.utils.cli_writers import file_writer_helper
-from espnet2.utils.types import int_or_none


 def get_parser():
diff --git a/utils/compute-stft-feats.py b/utils/compute-stft-feats.py
index fe4cdc563b2..c8264b00d11 100755
--- a/utils/compute-stft-feats.py
+++ b/utils/compute-stft-feats.py
@@ -4,17 +4,17 @@
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

 import argparse
-from distutils.util import strtobool
 import logging
+from distutils.util import strtobool

 import kaldiio
 import numpy
 import resampy

+from espnet2.utils.types import int_or_none
 from espnet.transform.spectrogram import spectrogram
 from espnet.utils.cli_utils import get_commandline_args
 from espnet.utils.cli_writers import file_writer_helper
-from espnet2.utils.types import int_or_none


 def get_parser():
diff --git a/utils/convert_fbank_to_wav.py b/utils/convert_fbank_to_wav.py
index ccb4a9c439b..73e6edb3947 100755
--- a/utils/convert_fbank_to_wav.py
+++ b/utils/convert_fbank_to_wav.py
@@ -7,16 +7,14 @@
 import logging
 import os

-from packaging.version import parse as V
-
 import librosa
 import numpy as np
+from packaging.version import parse as V
 from scipy.io.wavfile import write

 from espnet.utils.cli_readers import file_reader_helper
 from espnet.utils.cli_utils import get_commandline_args

-
 EPS = 1e-10
diff --git a/utils/copy-feats.py b/utils/copy-feats.py
index 1a43d5737db..ad400ba601e 100755
--- a/utils/copy-feats.py
+++ b/utils/copy-feats.py
@@ -1,12 +1,11 @@
 #!/usr/bin/env python3

 import argparse
-from distutils.util import strtobool
 import logging
+from distutils.util import strtobool

 from espnet.transform.transformation import Transformation
 from espnet.utils.cli_readers import file_reader_helper
-from espnet.utils.cli_utils import get_commandline_args
-from espnet.utils.cli_utils import is_scipy_wav_style
+from espnet.utils.cli_utils import get_commandline_args, is_scipy_wav_style
 from espnet.utils.cli_writers import file_writer_helper
diff --git a/utils/dump-pcm.py b/utils/dump-pcm.py
index a942a323528..df5a1ecf79c 100755
--- a/utils/dump-pcm.py
+++ b/utils/dump-pcm.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 import argparse
-from distutils.util import strtobool
 import logging
+from distutils.util import strtobool

 import kaldiio
 import numpy
diff --git a/utils/eval-source-separation.py b/utils/eval-source-separation.py
index 6ba1c024cd5..780e3c6b12e 100755
--- a/utils/eval-source-separation.py
+++ b/utils/eval-source-separation.py
@@ -1,21 +1,21 @@
 #!/usr/bin/env python3
 import argparse
-from collections import OrderedDict
-from distutils.util import strtobool
 import itertools
 import logging
 import os
-from pathlib import Path
 import shutil
 import subprocess
 import sys
-from tempfile import TemporaryDirectory
 import warnings
+from collections import OrderedDict
+from distutils.util import strtobool
+from pathlib import Path
+from tempfile import TemporaryDirectory

 import museval
 import numpy as np
-from pystoi.stoi import stoi
 import soundfile
+from pystoi.stoi import stoi

 from espnet.utils.cli_utils import get_commandline_args
diff --git a/utils/eval_perm_free_error.py b/utils/eval_perm_free_error.py
index 2f1b15132b2..27814bed545 100755
--- a/utils/eval_perm_free_error.py
+++ b/utils/eval_perm_free_error.py
@@ -8,10 +8,10 @@
 import json
 import logging
 import re
-import six
 import sys

 import numpy as np
+import six


 def permutationDFS(source, start, res):
diff --git a/utils/feat-to-shape.py b/utils/feat-to-shape.py
index 559abcd9e25..5c34bc363a9 100755
--- a/utils/feat-to-shape.py
+++ b/utils/feat-to-shape.py
@@ -5,8 +5,7 @@

 from espnet.transform.transformation import Transformation
 from espnet.utils.cli_readers import file_reader_helper
-from espnet.utils.cli_utils import get_commandline_args
-from espnet.utils.cli_utils import is_scipy_wav_style
+from espnet.utils.cli_utils import get_commandline_args, is_scipy_wav_style


 def get_parser():
diff --git a/utils/feats2npy.py b/utils/feats2npy.py
index 456e244e735..72efb99cddf 100755
--- a/utils/feats2npy.py
+++ b/utils/feats2npy.py
@@ -2,11 +2,12 @@
 # coding: utf-8

 import argparse
-from kaldiio import ReadHelper
-import numpy as np
 import os
-from os.path import join
 import sys
+from os.path import join
+
+import numpy as np
+from kaldiio import ReadHelper


 def get_parser():
diff --git a/utils/generate_wav_from_fbank.py b/utils/generate_wav_from_fbank.py
index 1664b418797..a4b6fbd2f2a 100755
--- a/utils/generate_wav_from_fbank.py
+++ b/utils/generate_wav_from_fbank.py
@@ -15,13 +15,11 @@
 import numpy as np
 import pysptk
 import torch
-
 from scipy.io.wavfile import write
 from sklearn.preprocessing import StandardScaler

-from espnet.nets.pytorch_backend.wavenet import decode_mu_law
-from espnet.nets.pytorch_backend.wavenet import encode_mu_law
-from espnet.nets.pytorch_backend.wavenet import WaveNet
+from espnet.nets.pytorch_backend.wavenet import (WaveNet, decode_mu_law,
+                                                 encode_mu_law)
 from espnet.utils.cli_readers import file_reader_helper
 from espnet.utils.cli_utils import get_commandline_args
diff --git a/utils/json2sctm.py b/utils/json2sctm.py
index e482f958a9c..45c6085ffac 100644
--- a/utils/json2sctm.py
+++ b/utils/json2sctm.py
@@ -6,7 +6,6 @@
 import subprocess
 import sys

-
 is_python2 = sys.version_info[0] == 2
@@ -29,9 +28,7 @@ def get_parser():

 def main(args):
-    from utils import json2trn
-    from utils import trn2ctm
-    from utils import trn2stm
+    from utils import json2trn, trn2ctm, trn2stm

     parser = get_parser()
     args = parser.parse_args(args)
diff --git a/utils/make_pair_json.py b/utils/make_pair_json.py
index 236fc6a839b..4002a010ba7 100755
--- a/utils/make_pair_json.py
+++ b/utils/make_pair_json.py
@@ -5,10 +5,10 @@
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

 import argparse
-from io import open
 import json
 import logging
 import sys
+from io import open

 from espnet.utils.cli_utils import get_commandline_args
diff --git a/utils/mcd_calculate.py b/utils/mcd_calculate.py
index 4504f2eb396..d2575a02e7d 100755
--- a/utils/mcd_calculate.py
+++ b/utils/mcd_calculate.py
@@ -11,14 +11,13 @@
 import multiprocessing as mp
 import os

-from fastdtw import fastdtw
 import numpy as np
 import pysptk
 import pyworld as pw
 import scipy
+from fastdtw import fastdtw
 from scipy.io import wavfile
-from scipy.signal import firwin
-from scipy.signal import lfilter
+from scipy.signal import firwin, lfilter


 def find_files(root_dir, query="*.wav", include_root_dir=True):
diff --git a/utils/merge_scp2json.py b/utils/merge_scp2json.py
index 8ee4aef48be..269e803238b 100755
--- a/utils/merge_scp2json.py
+++ b/utils/merge_scp2json.py
@@ -4,11 +4,11 @@

 import argparse
 import codecs
-from distutils.util import strtobool
-from io import open
 import json
 import logging
 import sys
+from distutils.util import strtobool
+from io import open

 from espnet.utils.cli_utils import get_commandline_args
diff --git a/utils/spm_train b/utils/spm_train
index 0b247aee0dc..134a0b1d30a 100755
--- a/utils/spm_train
+++ b/utils/spm_train
@@ -8,6 +8,5 @@
 import sys

 import sentencepiece as spm
-
 if __name__ == "__main__":
     spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))
diff --git a/utils/text2vocabulary.py b/utils/text2vocabulary.py
index b0737d460cd..b45bc645e69 100755
--- a/utils/text2vocabulary.py
+++ b/utils/text2vocabulary.py
@@ -6,9 +6,10 @@
 import argparse
 import codecs
 import logging
-import six
 import sys

+import six
+
 is_python2 = sys.version_info[0] == 2