diff --git a/egs2/TEMPLATE/asr1/asr.sh b/egs2/TEMPLATE/asr1/asr.sh index 65a0048eed9..763aceb7a34 100755 --- a/egs2/TEMPLATE/asr1/asr.sh +++ b/egs2/TEMPLATE/asr1/asr.sh @@ -755,7 +755,7 @@ if ! "${skip_train}"; then log "LM collect-stats started... log: '${_logdir}/stats.*.log'" # NOTE: --*_shape_file doesn't require length information if --batch_type=unsorted, # but it's used only for deciding the sample ids. - # shellcheck disable=SC2086 + # shellcheck disable=SC2046,SC2086 ${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \ ${python} -m espnet2.bin.lm_train \ --collect_stats true \ @@ -967,7 +967,7 @@ if ! "${skip_train}"; then # NOTE: --*_shape_file doesn't require length information if --batch_type=unsorted, # but it's used only for deciding the sample ids. - # shellcheck disable=SC2086 + # shellcheck disable=SC2046,SC2086 ${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \ ${python} -m espnet2.bin.asr_train \ --collect_stats true \ @@ -1242,7 +1242,7 @@ if ! "${skip_eval}"; then # 2. Submit decoding jobs log "Decoding started... log: '${_logdir}/asr_inference.*.log'" - # shellcheck disable=SC2086 + # shellcheck disable=SC2046,SC2086 ${_cmd} --gpu "${_ngpu}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ ${python} -m ${asr_inference_tool} \ --batch_size ${batch_size} \ diff --git a/espnet/asr/pytorch_backend/asr.py b/espnet/asr/pytorch_backend/asr.py index a83d9a27dc1..0effaaaa893 100644 --- a/espnet/asr/pytorch_backend/asr.py +++ b/espnet/asr/pytorch_backend/asr.py @@ -4,12 +4,12 @@ """Training/decoding definition for the speech recognition task.""" import copy -from packaging.version import parse as V import itertools import json import logging import math import os +from packaging.version import parse as V from chainer import reporter as reporter_module from chainer import training @@ -999,9 +999,7 @@ def recog(args): # Dunno why but weight_observer from dynamic quantized module must have # dtype=torch.qint8 with torch < 1.5 although dtype=torch.float16 is supported. - if args.quantize_dtype == "float16" and torch.__version__ < V( - "1.5.0" - ): + if args.quantize_dtype == "float16" and torch.__version__ < V("1.5.0"): raise ValueError( "float16 dtype for dynamic quantization is not supported with torch " "version < 1.5.0. Switching to qint8 dtype instead." diff --git a/espnet/asr/pytorch_backend/recog.py b/espnet/asr/pytorch_backend/recog.py index 68fea23a144..b64131d1ad2 100644 --- a/espnet/asr/pytorch_backend/recog.py +++ b/espnet/asr/pytorch_backend/recog.py @@ -62,9 +62,7 @@ def recog_v2(args): "Quantized LSTM in ESPnet is only supported with torch 1.4+." ) - if args.quantize_dtype == "float16" and torch.__version__ < V( - "1.5.0" - ): + if args.quantize_dtype == "float16" and torch.__version__ < V("1.5.0"): raise ValueError( "float16 dtype for dynamic quantization is not supported with torch " "version < 1.5.0. Switching to qint8 dtype instead." diff --git a/espnet/nets/pytorch_backend/ctc.py b/espnet/nets/pytorch_backend/ctc.py index c974df09b7a..96b2e4f52b9 100644 --- a/espnet/nets/pytorch_backend/ctc.py +++ b/espnet/nets/pytorch_backend/ctc.py @@ -1,5 +1,5 @@ -from packaging.version import parse as V import logging +from packaging.version import parse as V import numpy as np import six @@ -28,11 +28,7 @@ def __init__(self, odim, eprojs, dropout_rate, ctc_type="warpctc", reduce=True): self.probs = None # for visualization # In case of Pytorch >= 1.7.0, CTC will be always builtin - self.ctc_type = ( - ctc_type - if V(torch.__version__) < V("1.7.0") - else "builtin" - ) + self.ctc_type = ctc_type if V(torch.__version__) < V("1.7.0") else "builtin" if ctc_type != self.ctc_type: logging.warning(f"CTC was set to {self.ctc_type} due to PyTorch version.") diff --git a/espnet/nets/pytorch_backend/e2e_tts_fastspeech.py b/espnet/nets/pytorch_backend/e2e_tts_fastspeech.py index c5a3069e53c..8c9f2bcb232 100644 --- a/espnet/nets/pytorch_backend/e2e_tts_fastspeech.py +++ b/espnet/nets/pytorch_backend/e2e_tts_fastspeech.py @@ -576,7 +576,7 @@ def _forward( alpha=1.0, ): # forward encoder - x_masks = self._source_mask(ilens) + x_masks = self._source_mask(ilens).to(xs.device) hs, _ = self.encoder(xs, x_masks) # (B, Tmax, adim) # integrate speaker embedding @@ -603,7 +603,7 @@ def _forward( olens_in = olens.new([olen // self.reduction_factor for olen in olens]) else: olens_in = olens - h_masks = self._source_mask(olens_in) + h_masks = self._source_mask(olens_in).to(xs.device) else: h_masks = None zs, _ = self.decoder(hs, h_masks) # (B, Lmax, adim) @@ -816,7 +816,7 @@ def _source_mask(self, ilens): [1, 1, 1, 0, 0]]], dtype=torch.uint8) """ - x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) + x_masks = make_non_pad_mask(ilens) return x_masks.unsqueeze(-2) def _load_teacher_model(self, model_path): diff --git a/espnet/nets/pytorch_backend/e2e_vc_transformer.py b/espnet/nets/pytorch_backend/e2e_vc_transformer.py index c4e0144d412..99fd3f3962b 100644 --- a/espnet/nets/pytorch_backend/e2e_vc_transformer.py +++ b/espnet/nets/pytorch_backend/e2e_vc_transformer.py @@ -673,7 +673,7 @@ def forward(self, xs, ilens, ys, labels, olens, spembs=None, *args, **kwargs): xs_ds, ilens_ds = xs, ilens # forward encoder - x_masks = self._source_mask(ilens_ds) + x_masks = self._source_mask(ilens_ds).to(xs.device) hs, hs_masks = self.encoder(xs_ds, x_masks) # integrate speaker embedding @@ -701,7 +701,7 @@ def forward(self, xs, ilens, ys, labels, olens, spembs=None, *args, **kwargs): ilens_ds_st = ilens_ds # forward decoder - y_masks = self._target_mask(olens_in) + y_masks = self._target_mask(olens_in).to(xs.device) zs, _ = self.decoder(ys_in, y_masks, hs_int, hs_masks) # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim) before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) @@ -977,7 +977,7 @@ def calculate_all_attentions( xs_ds, ilens_ds = xs, ilens # forward encoder - x_masks = self._source_mask(ilens_ds) + x_masks = self._source_mask(ilens_ds).to(xs.device) hs, hs_masks = self.encoder(xs_ds, x_masks) # integrate speaker embedding @@ -996,7 +996,7 @@ def calculate_all_attentions( ys_in = self._add_first_frame_and_remove_last_frame(ys_in) # forward decoder - y_masks = self._target_mask(olens_in) + y_masks = self._target_mask(olens_in).to(xs.device) zs, _ = self.decoder(ys_in, y_masks, hs, hs_masks) # calculate final outputs @@ -1099,7 +1099,7 @@ def _source_mask(self, ilens): [[1, 1, 1, 0, 0]]], dtype=torch.uint8) """ - x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) + x_masks = make_non_pad_mask(ilens) return x_masks.unsqueeze(-2) def _target_mask(self, olens): @@ -1128,7 +1128,7 @@ def _target_mask(self, olens): [1, 1, 1, 0, 0]]], dtype=torch.uint8) """ - y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device) + y_masks = make_non_pad_mask(olens) s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0) return y_masks.unsqueeze(-2) & s_masks diff --git a/espnet/nets/pytorch_backend/nets_utils.py b/espnet/nets/pytorch_backend/nets_utils.py index 3a7b1e079bc..638b0b0bf23 100644 --- a/espnet/nets/pytorch_backend/nets_utils.py +++ b/espnet/nets/pytorch_backend/nets_utils.py @@ -151,10 +151,7 @@ def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None): raise ValueError("length_dim cannot be 0: {}".format(length_dim)) if not isinstance(lengths, list): - lengths = lengths.tolist() - else: - assert isinstance(lengths, torch.Tensor), type(lengths) - lengths = lengths.long() + lengths = lengths.long().tolist() bs = int(len(lengths)) if maxlen is None: diff --git a/espnet2/asr/espnet_model.py b/espnet2/asr/espnet_model.py index 5756598d2ff..67698e95115 100644 --- a/espnet2/asr/espnet_model.py +++ b/espnet2/asr/espnet_model.py @@ -1,6 +1,6 @@ from contextlib import contextmanager -from packaging.version import parse as V import logging +from packaging.version import parse as V from typing import Dict from typing import List from typing import Optional diff --git a/espnet2/asr/maskctc_model.py b/espnet2/asr/maskctc_model.py index 10d91de94c5..2a95eec89ea 100644 --- a/espnet2/asr/maskctc_model.py +++ b/espnet2/asr/maskctc_model.py @@ -1,7 +1,7 @@ from contextlib import contextmanager -from packaging.version import parse as V from itertools import groupby import logging +from packaging.version import parse as V from typing import Dict from typing import List from typing import Optional diff --git a/espnet2/diar/espnet_model.py b/espnet2/diar/espnet_model.py index 92b434e7642..2017316f70f 100644 --- a/espnet2/diar/espnet_model.py +++ b/espnet2/diar/espnet_model.py @@ -2,8 +2,8 @@ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) from contextlib import contextmanager -from packaging.version import parse as V from itertools import permutations +from packaging.version import parse as V from typing import Dict from typing import Optional from typing import Tuple diff --git a/espnet2/enh/espnet_enh_s2t_model.py b/espnet2/enh/espnet_enh_s2t_model.py index c2e05654fce..4d37ce0b0c0 100644 --- a/espnet2/enh/espnet_enh_s2t_model.py +++ b/espnet2/enh/espnet_enh_s2t_model.py @@ -1,6 +1,6 @@ from contextlib import contextmanager -from packaging.version import parse as V import logging +from packaging.version import parse as V import random from typing import Dict from typing import List diff --git a/espnet2/enh/loss/criterions/tf_domain.py b/espnet2/enh/loss/criterions/tf_domain.py index cb81d7cf25d..4c4a91ef5d2 100644 --- a/espnet2/enh/loss/criterions/tf_domain.py +++ b/espnet2/enh/loss/criterions/tf_domain.py @@ -1,8 +1,8 @@ from abc import ABC from abc import abstractmethod -from packaging.version import parse as V from functools import reduce import math +from packaging.version import parse as V import torch import torch.nn.functional as F diff --git a/espnet2/mt/espnet_model.py b/espnet2/mt/espnet_model.py index b937cbe3dfd..8a493366046 100644 --- a/espnet2/mt/espnet_model.py +++ b/espnet2/mt/espnet_model.py @@ -1,6 +1,6 @@ from contextlib import contextmanager -from packaging.version import parse as V import logging +from packaging.version import parse as V from typing import Dict from typing import List from typing import Optional diff --git a/espnet2/st/espnet_model.py b/espnet2/st/espnet_model.py index fb8fcfdaee9..743b53d8288 100644 --- a/espnet2/st/espnet_model.py +++ b/espnet2/st/espnet_model.py @@ -1,6 +1,6 @@ from contextlib import contextmanager -from packaging.version import parse as V import logging +from packaging.version import parse as V from typing import Dict from typing import List from typing import Optional diff --git a/espnet2/tasks/abs_task.py b/espnet2/tasks/abs_task.py index 54c4cd26a43..0f23feaa93d 100644 --- a/espnet2/tasks/abs_task.py +++ b/espnet2/tasks/abs_task.py @@ -3,10 +3,10 @@ from abc import abstractmethod import argparse from dataclasses import dataclass -from packaging.version import parse as V import functools import logging import os +from packaging.version import parse as V from pathlib import Path import sys from typing import Any diff --git a/espnet2/train/reporter.py b/espnet2/train/reporter.py index 65b4ac6a9d8..be1d2a51fe5 100644 --- a/espnet2/train/reporter.py +++ b/espnet2/train/reporter.py @@ -3,8 +3,8 @@ from contextlib import contextmanager import dataclasses import datetime -from packaging.version import parse as V import logging +from packaging.version import parse as V from pathlib import Path import time from typing import ContextManager diff --git a/espnet2/train/trainer.py b/espnet2/train/trainer.py index 6fe2726880d..da8ea6144b4 100644 --- a/espnet2/train/trainer.py +++ b/espnet2/train/trainer.py @@ -3,8 +3,8 @@ from contextlib import contextmanager import dataclasses from dataclasses import is_dataclass -from packaging.version import parse as V import logging +from packaging.version import parse as V from pathlib import Path import time from typing import Dict diff --git a/espnet2/utils/griffin_lim.py b/espnet2/utils/griffin_lim.py index 3d4a948b7aa..c9b08cd1235 100644 --- a/espnet2/utils/griffin_lim.py +++ b/espnet2/utils/griffin_lim.py @@ -7,8 +7,8 @@ import logging -from packaging.version import parse as V from functools import partial +from packaging.version import parse as V from typeguard import check_argument_types from typing import Optional