Skip to content

Commit

Permalink
fix code style
Browse files Browse the repository at this point in the history
  • Loading branch information
kamo-naoyuki committed May 12, 2022
1 parent 934b161 commit bb0d0aa
Show file tree
Hide file tree
Showing 18 changed files with 29 additions and 40 deletions.
6 changes: 3 additions & 3 deletions egs2/TEMPLATE/asr1/asr.sh
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ if ! "${skip_train}"; then
log "LM collect-stats started... log: '${_logdir}/stats.*.log'"
# NOTE: --*_shape_file doesn't require length information if --batch_type=unsorted,
# but it's used only for deciding the sample ids.
# shellcheck disable=SC2086
# shellcheck disable=SC2046,SC2086
${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \
${python} -m espnet2.bin.lm_train \
--collect_stats true \
Expand Down Expand Up @@ -967,7 +967,7 @@ if ! "${skip_train}"; then
# NOTE: --*_shape_file doesn't require length information if --batch_type=unsorted,
# but it's used only for deciding the sample ids.

# shellcheck disable=SC2086
# shellcheck disable=SC2046,SC2086
${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \
${python} -m espnet2.bin.asr_train \
--collect_stats true \
Expand Down Expand Up @@ -1242,7 +1242,7 @@ if ! "${skip_eval}"; then

# 2. Submit decoding jobs
log "Decoding started... log: '${_logdir}/asr_inference.*.log'"
# shellcheck disable=SC2086
# shellcheck disable=SC2046,SC2086
${_cmd} --gpu "${_ngpu}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
${python} -m ${asr_inference_tool} \
--batch_size ${batch_size} \
Expand Down
6 changes: 2 additions & 4 deletions espnet/asr/pytorch_backend/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
"""Training/decoding definition for the speech recognition task."""

import copy
from packaging.version import parse as V
import itertools
import json
import logging
import math
import os
from packaging.version import parse as V

from chainer import reporter as reporter_module
from chainer import training
Expand Down Expand Up @@ -999,9 +999,7 @@ def recog(args):

# Dunno why but weight_observer from dynamic quantized module must have
# dtype=torch.qint8 with torch < 1.5 although dtype=torch.float16 is supported.
if args.quantize_dtype == "float16" and torch.__version__ < V(
"1.5.0"
):
if args.quantize_dtype == "float16" and torch.__version__ < V("1.5.0"):
raise ValueError(
"float16 dtype for dynamic quantization is not supported with torch "
"version < 1.5.0. Switching to qint8 dtype instead."
Expand Down
4 changes: 1 addition & 3 deletions espnet/asr/pytorch_backend/recog.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ def recog_v2(args):
"Quantized LSTM in ESPnet is only supported with torch 1.4+."
)

if args.quantize_dtype == "float16" and torch.__version__ < V(
"1.5.0"
):
if args.quantize_dtype == "float16" and torch.__version__ < V("1.5.0"):
raise ValueError(
"float16 dtype for dynamic quantization is not supported with torch "
"version < 1.5.0. Switching to qint8 dtype instead."
Expand Down
8 changes: 2 additions & 6 deletions espnet/nets/pytorch_backend/ctc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from packaging.version import parse as V
import logging
from packaging.version import parse as V

import numpy as np
import six
Expand Down Expand Up @@ -28,11 +28,7 @@ def __init__(self, odim, eprojs, dropout_rate, ctc_type="warpctc", reduce=True):
self.probs = None # for visualization

# In case of Pytorch >= 1.7.0, CTC will be always builtin
self.ctc_type = (
ctc_type
if V(torch.__version__) < V("1.7.0")
else "builtin"
)
self.ctc_type = ctc_type if V(torch.__version__) < V("1.7.0") else "builtin"

if ctc_type != self.ctc_type:
logging.warning(f"CTC was set to {self.ctc_type} due to PyTorch version.")
Expand Down
6 changes: 3 additions & 3 deletions espnet/nets/pytorch_backend/e2e_tts_fastspeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def _forward(
alpha=1.0,
):
# forward encoder
x_masks = self._source_mask(ilens)
x_masks = self._source_mask(ilens).to(xs.device)
hs, _ = self.encoder(xs, x_masks) # (B, Tmax, adim)

# integrate speaker embedding
Expand All @@ -603,7 +603,7 @@ def _forward(
olens_in = olens.new([olen // self.reduction_factor for olen in olens])
else:
olens_in = olens
h_masks = self._source_mask(olens_in)
h_masks = self._source_mask(olens_in).to(xs.device)
else:
h_masks = None
zs, _ = self.decoder(hs, h_masks) # (B, Lmax, adim)
Expand Down Expand Up @@ -816,7 +816,7 @@ def _source_mask(self, ilens):
[1, 1, 1, 0, 0]]], dtype=torch.uint8)
"""
x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
x_masks = make_non_pad_mask(ilens)
return x_masks.unsqueeze(-2)

def _load_teacher_model(self, model_path):
Expand Down
12 changes: 6 additions & 6 deletions espnet/nets/pytorch_backend/e2e_vc_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def forward(self, xs, ilens, ys, labels, olens, spembs=None, *args, **kwargs):
xs_ds, ilens_ds = xs, ilens

# forward encoder
x_masks = self._source_mask(ilens_ds)
x_masks = self._source_mask(ilens_ds).to(xs.device)
hs, hs_masks = self.encoder(xs_ds, x_masks)

# integrate speaker embedding
Expand Down Expand Up @@ -701,7 +701,7 @@ def forward(self, xs, ilens, ys, labels, olens, spembs=None, *args, **kwargs):
ilens_ds_st = ilens_ds

# forward decoder
y_masks = self._target_mask(olens_in)
y_masks = self._target_mask(olens_in).to(xs.device)
zs, _ = self.decoder(ys_in, y_masks, hs_int, hs_masks)
# (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim)
before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
Expand Down Expand Up @@ -977,7 +977,7 @@ def calculate_all_attentions(
xs_ds, ilens_ds = xs, ilens

# forward encoder
x_masks = self._source_mask(ilens_ds)
x_masks = self._source_mask(ilens_ds).to(xs.device)
hs, hs_masks = self.encoder(xs_ds, x_masks)

# integrate speaker embedding
Expand All @@ -996,7 +996,7 @@ def calculate_all_attentions(
ys_in = self._add_first_frame_and_remove_last_frame(ys_in)

# forward decoder
y_masks = self._target_mask(olens_in)
y_masks = self._target_mask(olens_in).to(xs.device)
zs, _ = self.decoder(ys_in, y_masks, hs, hs_masks)

# calculate final outputs
Expand Down Expand Up @@ -1099,7 +1099,7 @@ def _source_mask(self, ilens):
[[1, 1, 1, 0, 0]]], dtype=torch.uint8)
"""
x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
x_masks = make_non_pad_mask(ilens)
return x_masks.unsqueeze(-2)

def _target_mask(self, olens):
Expand Down Expand Up @@ -1128,7 +1128,7 @@ def _target_mask(self, olens):
[1, 1, 1, 0, 0]]], dtype=torch.uint8)
"""
y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
y_masks = make_non_pad_mask(olens)
s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0)
return y_masks.unsqueeze(-2) & s_masks

Expand Down
5 changes: 1 addition & 4 deletions espnet/nets/pytorch_backend/nets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,7 @@ def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
raise ValueError("length_dim cannot be 0: {}".format(length_dim))

if not isinstance(lengths, list):
lengths = lengths.tolist()
else:
assert isinstance(lengths, torch.Tensor), type(lengths)
lengths = lengths.long()
lengths = lengths.long().tolist()

bs = int(len(lengths))
if maxlen is None:
Expand Down
2 changes: 1 addition & 1 deletion espnet2/asr/espnet_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from contextlib import contextmanager
from packaging.version import parse as V
import logging
from packaging.version import parse as V
from typing import Dict
from typing import List
from typing import Optional
Expand Down
2 changes: 1 addition & 1 deletion espnet2/asr/maskctc_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from contextlib import contextmanager
from packaging.version import parse as V
from itertools import groupby
import logging
from packaging.version import parse as V
from typing import Dict
from typing import List
from typing import Optional
Expand Down
2 changes: 1 addition & 1 deletion espnet2/diar/espnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

from contextlib import contextmanager
from packaging.version import parse as V
from itertools import permutations
from packaging.version import parse as V
from typing import Dict
from typing import Optional
from typing import Tuple
Expand Down
2 changes: 1 addition & 1 deletion espnet2/enh/espnet_enh_s2t_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from contextlib import contextmanager
from packaging.version import parse as V
import logging
from packaging.version import parse as V
import random
from typing import Dict
from typing import List
Expand Down
2 changes: 1 addition & 1 deletion espnet2/enh/loss/criterions/tf_domain.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from abc import ABC
from abc import abstractmethod
from packaging.version import parse as V
from functools import reduce
import math
from packaging.version import parse as V

import torch
import torch.nn.functional as F
Expand Down
2 changes: 1 addition & 1 deletion espnet2/mt/espnet_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from contextlib import contextmanager
from packaging.version import parse as V
import logging
from packaging.version import parse as V
from typing import Dict
from typing import List
from typing import Optional
Expand Down
2 changes: 1 addition & 1 deletion espnet2/st/espnet_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from contextlib import contextmanager
from packaging.version import parse as V
import logging
from packaging.version import parse as V
from typing import Dict
from typing import List
from typing import Optional
Expand Down
2 changes: 1 addition & 1 deletion espnet2/tasks/abs_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from abc import abstractmethod
import argparse
from dataclasses import dataclass
from packaging.version import parse as V
import functools
import logging
import os
from packaging.version import parse as V
from pathlib import Path
import sys
from typing import Any
Expand Down
2 changes: 1 addition & 1 deletion espnet2/train/reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from contextlib import contextmanager
import dataclasses
import datetime
from packaging.version import parse as V
import logging
from packaging.version import parse as V
from pathlib import Path
import time
from typing import ContextManager
Expand Down
2 changes: 1 addition & 1 deletion espnet2/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from contextlib import contextmanager
import dataclasses
from dataclasses import is_dataclass
from packaging.version import parse as V
import logging
from packaging.version import parse as V
from pathlib import Path
import time
from typing import Dict
Expand Down
2 changes: 1 addition & 1 deletion espnet2/utils/griffin_lim.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import logging

from packaging.version import parse as V
from functools import partial
from packaging.version import parse as V
from typeguard import check_argument_types
from typing import Optional

Expand Down

0 comments on commit bb0d0aa

Please sign in to comment.