change LooseVersion to parse
kamo-naoyuki committed May 12, 2022
1 parent f899a05 commit 1c344a9
Showing 59 changed files with 133 additions and 163 deletions.
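The whole commit swaps distutils.version.LooseVersion for packaging.version.parse. distutils is deprecated by PEP 632 and removed in Python 3.12, and packaging's parse implements PEP 440 ordering, which LooseVersion's naive component comparison does not. A minimal illustration of the behavioural differences involved (not taken from the diff; requires the packaging package):

from packaging.version import parse

# PEP 440: a pre-release sorts before its final release ...
assert parse("1.10.0a1") < parse("1.10.0")
# ... whereas LooseVersion("1.10.0a1") > LooseVersion("1.10.0"), because it
# compares the component lists [1, 10, 0, 'a', 1] > [1, 10, 0].

# PEP 440: trailing zeros are insignificant ...
assert parse("1.6.0") == parse("1.6")
# ... whereas LooseVersion("1.6.0") > LooseVersion("1.6"), again because the
# longer component list wins.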
2 changes: 1 addition & 1 deletion ci/install.sh
@@ -48,7 +48,7 @@ python3 -m pip freeze
# Check pytorch version
python3 <<EOF
import torch
from distutils.version import LooseVersion as L
from packaging.version import parse as L
version = '$TH_VERSION'.split(".")
next_version = f"{version[0]}.{version[1]}.{int(version[2]) + 1}"
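The heredoc continues past the displayed hunk; presumably it asserts that the installed torch lies between $TH_VERSION and the next patch release. A rough sketch of such a bounded check with parse (the actual assertion in ci/install.sh is not shown above, so treat the assert below as an assumption):

import torch
from packaging.version import parse as L

th_version = "1.11.0"  # stands in for the expanded $TH_VERSION
major, minor, patch = th_version.split(".")
next_version = f"{major}.{minor}.{int(patch) + 1}"

# Accept 1.11.0 and local builds such as 1.11.0+cu113, reject 1.11.1 and later.
assert L(th_version) <= L(torch.__version__) < L(next_version), (
    f"torch {torch.__version__} does not match the requested {th_version}"
)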
14 changes: 7 additions & 7 deletions ci/test_integration_espnet2.sh
@@ -40,13 +40,13 @@ echo "==== use_streaming, feats_type=raw, token_types=bpe, model_conf.extract_fe
--feats_normalize "utterance_mvn" --lm-args "--max_epoch=1" --python "${python}" \
--asr-args "--model_conf extract_feats_in_collect_stats=false --max_epoch=1 --encoder=contextual_block_transformer --decoder=transformer
--encoder_conf block_size=40 --encoder_conf hop_size=16 --encoder_conf look_ahead=16"

if python3 -c "import k2" &> /dev/null; then
echo "==== use_k2, num_paths > nll_batch_size, feats_type=raw, token_types=bpe, model_conf.extract_feats_in_collect_stats=False, normalize=utt_mvn ==="
./run.sh --num_paths 500 --nll_batch_size 20 --use_k2 true --ngpu 0 --stage 12 --stop-stage 13 --skip-upload false --feats-type "raw" --token-type "bpe" \
--feats_normalize "utterance_mvn" --lm-args "--max_epoch=1" --python "${python}" \
--asr-args "--model_conf extract_feats_in_collect_stats=false --max_epoch=1"

echo "==== use_k2, num_paths == nll_batch_size, feats_type=raw, token_types=bpe, model_conf.extract_feats_in_collect_stats=False, normalize=utt_mvn ==="
./run.sh --num_paths 20 --nll_batch_size 20 --use_k2 true --ngpu 0 --stage 12 --stop-stage 13 --skip-upload false --feats-type "raw" --token-type "bpe" \
--feats_normalize "utterance_mvn" --lm-args "--max_epoch=1" --python "${python}" \
@@ -68,15 +68,15 @@ rm -rf exp dump data
# NOTE(kan-bayashi): pytorch 1.4 - 1.6 works but 1.6 has a problem with CPU,
# so we test this recipe using only pytorch > 1.6 here.
# See also: https://github.com/pytorch/pytorch/issues/42446
if python3 -c 'import torch as t; from distutils.version import LooseVersion as L; assert L(t.__version__) > L("1.6")' &> /dev/null; then
if python3 -c 'import torch as t; from packaging.version import parse as L; assert L(t.__version__) > L("1.6")' &> /dev/null; then
./run.sh --fs 22050 --tts_task gan_tts --feats_extract linear_spectrogram --feats_normalize none --inference_model latest.pth \
--ngpu 0 --stop-stage 8 --skip-upload false --train-args "--num_iters_per_epoch 1 --max_epoch 1" --python "${python}"
rm -rf exp dump data
fi
cd "${cwd}"

# [ESPnet2] test enh recipe
if python -c 'import torch as t; from distutils.version import LooseVersion as L; assert L(t.__version__) >= L("1.2.0")' &> /dev/null; then
if python -c 'import torch as t; from packaging.version import parse as L; assert L(t.__version__) >= L("1.2.0")' &> /dev/null; then
cd ./egs2/mini_an4/enh1
echo "==== [ESPnet2] ENH ==="
./run.sh --stage 1 --stop-stage 1 --python "${python}"
@@ -101,7 +101,7 @@ if python3 -c "import fairseq" &> /dev/null; then
fi

# [ESPnet2] test enh_asr1 recipe
if python -c 'import torch as t; from distutils.version import LooseVersion as L; assert L(t.__version__) >= L("1.2.0")' &> /dev/null; then
if python -c 'import torch as t; from packaging.version import parse as L; assert L(t.__version__) >= L("1.2.0")' &> /dev/null; then
cd ./egs2/mini_an4/enh_asr1
echo "==== [ESPnet2] ENH_ASR ==="
./run.sh --ngpu 0 --stage 0 --stop-stage 15 --skip-upload_hf false --feats-type "raw" --spk-num 1 --enh_asr_args "--max_epoch=1 --enh_separator_conf num_spk=1" --python "${python}"
@@ -122,7 +122,7 @@ done
for t in ${token_types}; do
./run.sh --stage 5 --stop-stage 5 --tgt_token_type "${t}" --src_token_type "${t}" --python "${python}"
done
for t in ${feats_types}; do
for t2 in ${token_types}; do
echo "==== feats_type=${t}, token_types=${t2} ==="
./run.sh --ngpu 0 --stage 6 --stop-stage 13 --skip-upload false --feats-type "${t}" --tgt_token_type "${t2}" --src_token_type "${t2}" \
@@ -147,7 +147,7 @@ cd "${cwd}"
# [ESPnet2] Validate configuration files
echo "<blank>" > dummy_token_list
echo "==== [ESPnet2] Validation configuration files ==="
if python3 -c 'import torch as t; from distutils.version import LooseVersion as L; assert L(t.__version__) >= L("1.8.0")' &> /dev/null; then
if python3 -c 'import torch as t; from packaging.version import parse as L; assert L(t.__version__) >= L("1.8.0")' &> /dev/null; then
for f in egs2/*/asr1/conf/train_asr*.yaml; do
if [ "$f" == "egs2/fsc/asr1/conf/train_asr.yaml" ]; then
if ! python3 -c "import s3prl" > /dev/null; then
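All of these shell gates use the same idiom: run a one-line Python assert via python -c and branch on its exit status. One subtle effect of the swap in the "> 1.6" gate above: LooseVersion("1.6.0") > LooseVersion("1.6") is true, while packaging treats 1.6.0 and 1.6 as equal, so the parsed check is stricter about a plain 1.6.0 release. A self-contained sketch of the same gate in plain Python (the helper name is illustrative, not ESPnet API):

import subprocess
import sys

def torch_newer_than(version: str) -> bool:
    """True if the installed torch is strictly newer than `version`."""
    code = (
        "import torch; from packaging.version import parse as L; "
        f"assert L(torch.__version__) > L('{version}')"
    )
    return subprocess.run([sys.executable, "-c", code], capture_output=True).returncode == 0

if torch_newer_than("1.6"):
    print("run the GAN-TTS recipe")
else:
    print("skip: torch <= 1.6")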
16 changes: 8 additions & 8 deletions egs2/aishell4/enh1/local/generate_fe_trainingdata.py.patch
@@ -2,9 +2,9 @@
+++ generate_fe_trainingdata.new.py
@@ -1,8 +1,8 @@
#!/usr/bin/env python

-import io
+from distutils.version import LooseVersion
+from packaging.version import parse as V
import os
-import subprocess
+import sys
@@ -14,27 +14,27 @@
@@ -12,6 +12,10 @@
import librosa
import argparse

+
+is_py_3_3_plus = LooseVersion(sys.version) > LooseVersion("3.3")
+is_py_3_3_plus = V(sys.version) > V("3.3")
+
+
def get_line_context(file_path, line_number):
return linecache.getline(file_path, line_number).strip()

@@ -119,7 +123,7 @@
return data / max_val

def add_noise(clean, noise, rir, snr):
- random.seed(time.clock())
+ random.seed(time.perf_counter() if is_py_3_3_plus else time.clock())
if len(noise.shape) == 1 and len(clean.shape) > 1:
noise = add_reverb(noise, rir[:, 16:24])
noise = noise[:-7999]
@@ -189,7 +193,7 @@

for i in range(args.wavnum):

- random.seed(time.clock())
+ random.seed(time.perf_counter() if is_py_3_3_plus else time.clock())
wav1idx = random.randint(0, len(open(wavlist1,'r').readlines())-1)
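Two notes on the patched lines: time.clock() was removed in Python 3.8, hence the fall-back to time.perf_counter(); and parse(sys.version) is fragile, because sys.version is a full banner string such as "3.10.4 (main, ...) [GCC 9.4.0]" — older packaging releases turn it into a LegacyVersion (which sorts below every PEP 440 version), and packaging >= 22 raises InvalidVersion. A more robust variant of the same check, as an alternative sketch rather than what the patch ships:

import random
import sys
import time

# sys.version_info avoids parsing the free-form sys.version banner.
is_py_3_3_plus = sys.version_info >= (3, 3)

def reseed():
    # time.clock() was removed in Python 3.8; perf_counter() replaces it.
    random.seed(time.perf_counter() if is_py_3_3_plus else time.clock())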
2 changes: 1 addition & 1 deletion egs2/fsc/asr1/run.sh
@@ -9,7 +9,7 @@ train_set="train"
valid_set="valid"
test_sets="test valid"

if python3 -c 'import torch as t; from distutils.version import LooseVersion as L; assert L(t.__version__) >= L("1.7.0")' &> /dev/null; then
if python3 -c 'import torch as t; from packaging.version import parse as L; assert L(t.__version__) >= L("1.7.0")' &> /dev/null; then
asr_config=conf/train_asr.yaml
else
asr_config=conf/tuning/train_asr_transformer_adam_specaug.yaml #s3prl is installed when pytorch > 1.7. Hence using default frontend
2 changes: 1 addition & 1 deletion egs2/fsc_challenge/asr1/run.sh
@@ -9,7 +9,7 @@ train_set="train"
valid_set="valid"
test_sets="test valid"

if python3 -c 'import torch as t; from distutils.version import LooseVersion as L; assert L(t.__version__) >= L("1.7.0")' &> /dev/null; then
if python3 -c 'import torch as t; from packaging.version import parse as L; assert L(t.__version__) >= L("1.7.0")' &> /dev/null; then
asr_config=conf/train_asr.yaml
else
asr_config=conf/tuning/train_asr_transformer_adam_specaug.yaml #s3prl is installed when pytorch > 1.7. Hence using default frontend
2 changes: 1 addition & 1 deletion egs2/fsc_unseen/asr1/run.sh
@@ -9,7 +9,7 @@ train_set="train"
valid_set="valid"
test_sets="test valid"

if python3 -c 'import torch as t; from distutils.version import LooseVersion as L; assert L(t.__version__) >= L("1.7.0")' &> /dev/null; then
if python3 -c 'import torch as t; from packaging.version import parse as L; assert L(t.__version__) >= L("1.7.0")' &> /dev/null; then
asr_config=conf/train_asr.yaml
else
asr_config=conf/tuning/train_asr_transformer_adam_specaug.yaml #s3prl is installed when pytorch > 1.7. Hence using default frontend
6 changes: 3 additions & 3 deletions espnet/asr/pytorch_backend/asr.py
@@ -4,7 +4,7 @@
"""Training/decoding definition for the speech recognition task."""

import copy
from distutils.version import LooseVersion
from packaging.version import parse as V
import itertools
import json
import logging
@@ -989,7 +989,7 @@ def recog(args):
# It seems quantized LSTM only supports non-packed sequence before torch 1.4.0.
# Reference issue: https://github.com/pytorch/pytorch/issues/27963
if (
torch.__version__ < LooseVersion("1.4.0")
torch.__version__ < V("1.4.0")
and "lstm" in train_args.etype
and torch.nn.LSTM in q_config
):
@@ -999,7 +999,7 @@ def recog(args):

# Dunno why but weight_observer from dynamic quantized module must have
# dtype=torch.qint8 with torch < 1.5 although dtype=torch.float16 is supported.
if args.quantize_dtype == "float16" and torch.__version__ < LooseVersion(
if args.quantize_dtype == "float16" and torch.__version__ < V(
"1.5.0"
):
raise ValueError(
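In both hunks only the right-hand side is wrapped, so the left operand stays whatever torch.__version__ is. LooseVersion tolerated a bare string there because its reflected comparison coerces it, but a packaging Version only compares against another Version; recent torch masks the issue by making __version__ a TorchVersion str subclass with its own comparison logic, while a plain version string fails. A small illustration (it does not alter what the commit does):

from packaging.version import parse as V

try:
    "1.10.2" < V("1.4.0")            # plain str < Version
except TypeError as exc:
    print("raises:", exc)            # '<' not supported between 'str' and 'Version'

print(V("1.10.2") < V("1.4.0"))      # False -- wrapping both sides is the safe form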
6 changes: 3 additions & 3 deletions espnet/asr/pytorch_backend/recog.py
@@ -1,6 +1,6 @@
"""V2 backend for `asr_recog.py` using py:class:`espnet.nets.beam_search.BeamSearch`."""

from distutils.version import LooseVersion
from packaging.version import parse as V
import json
import logging

@@ -54,15 +54,15 @@ def recog_v2(args):

# See https://github.com/espnet/espnet/pull/3616 for more information.
if (
torch.__version__ < LooseVersion("1.4.0")
torch.__version__ < V("1.4.0")
and "lstm" in train_args.etype
and torch.nn.LSTM in q_config
):
raise ValueError(
"Quantized LSTM in ESPnet is only supported with torch 1.4+."
)

if args.quantize_dtype == "float16" and torch.__version__ < LooseVersion(
if args.quantize_dtype == "float16" and torch.__version__ < V(
"1.5.0"
):
raise ValueError(
4 changes: 2 additions & 2 deletions espnet/nets/pytorch_backend/ctc.py
@@ -1,4 +1,4 @@
from distutils.version import LooseVersion
from packaging.version import parse as V
import logging

import numpy as np
@@ -30,7 +30,7 @@ def __init__(self, odim, eprojs, dropout_rate, ctc_type="warpctc", reduce=True):
# In case of Pytorch >= 1.7.0, CTC will be always builtin
self.ctc_type = (
ctc_type
if LooseVersion(torch.__version__) < LooseVersion("1.7.0")
if V(torch.__version__) < V("1.7.0")
else "builtin"
)

12 changes: 6 additions & 6 deletions espnet/nets/pytorch_backend/e2e_tts_transformer.py
@@ -714,7 +714,7 @@ def forward(self, xs, ilens, ys, labels, olens, spembs=None, *args, **kwargs):
labels = labels[:, :max_olen]

# forward encoder
x_masks = self._source_mask(ilens)
x_masks = self._source_mask(ilens).to(xs.device)
hs, h_masks = self.encoder(xs, x_masks)

# integrate speaker embedding
@@ -732,7 +732,7 @@ def forward(self, xs, ilens, ys, labels, olens, spembs=None, *args, **kwargs):
ys_in = self._add_first_frame_and_remove_last_frame(ys_in)

# forward decoder
y_masks = self._target_mask(olens_in)
y_masks = self._target_mask(olens_in).to(xs.device)
zs, _ = self.decoder(ys_in, y_masks, hs, h_masks)
# (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim)
before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
@@ -975,7 +975,7 @@ def calculate_all_attentions(
self.eval()
with torch.no_grad():
# forward encoder
x_masks = self._source_mask(ilens)
x_masks = self._source_mask(ilens).to(xs.device)
hs, h_masks = self.encoder(xs, x_masks)

# integrate speaker embedding
@@ -994,7 +994,7 @@ def calculate_all_attentions(
ys_in = self._add_first_frame_and_remove_last_frame(ys_in)

# forward decoder
y_masks = self._target_mask(olens_in)
y_masks = self._target_mask(olens_in).to(xs.device)
zs, _ = self.decoder(ys_in, y_masks, hs, h_masks)

# calculate final outputs
@@ -1097,7 +1097,7 @@ def _source_mask(self, ilens):
[[1, 1, 1, 0, 0]]], dtype=torch.uint8)
"""
x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
x_masks = make_non_pad_mask(ilens)
return x_masks.unsqueeze(-2)

def _target_mask(self, olens):
@@ -1126,7 +1126,7 @@ def _target_mask(self, olens):
[1, 1, 1, 0, 0]]], dtype=torch.uint8)
"""
y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
y_masks = make_non_pad_mask(olens)
s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0)
return y_masks.unsqueeze(-2) & s_masks

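Besides the version handling, these hunks also shuffle where the encoder/decoder masks get their device: one variant transfers the mask inside _source_mask/_target_mask via next(self.parameters()).device, the other builds it on CPU and sends it to the input's device at the call site with .to(xs.device); the paired lines above show both forms. A stripped-down sketch of the call-site pattern (a simplified stand-in, not the ESPnet helpers themselves):

import torch

def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Boolean mask that is True on real frames and False on padding."""
    max_len = int(lengths.max())
    steps = torch.arange(max_len)                     # built on CPU
    return steps.unsqueeze(0) < lengths.unsqueeze(1)  # shape (B, max_len)

ilens = torch.tensor([5, 3, 4])
xs = torch.randn(3, 5, 80)                            # would live on the GPU in training
x_masks = make_non_pad_mask(ilens).unsqueeze(-2).to(xs.device)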
4 changes: 2 additions & 2 deletions espnet2/asr/espnet_model.py
@@ -1,5 +1,5 @@
from contextlib import contextmanager
from distutils.version import LooseVersion
from packaging.version import parse as V
import logging
from typing import Dict
from typing import List
@@ -29,7 +29,7 @@
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
if V(torch.__version__) >= V("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
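torch.cuda.amp.autocast only exists from torch 1.6.0, which is what the parsed-version gate guards; the else branch (outside the displayed hunk) supplies a stand-in so the training code can write `with autocast(False):` unconditionally. A sketch of that fallback shape, assuming a no-op context manager much like the one ESPnet defines:

from contextlib import contextmanager

import torch
from packaging.version import parse as V

if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0: fall back to a no-op context manager.
    @contextmanager
    def autocast(enabled=True):
        yield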
4 changes: 2 additions & 2 deletions espnet2/asr/maskctc_model.py
@@ -1,5 +1,5 @@
from contextlib import contextmanager
from distutils.version import LooseVersion
from packaging.version import parse as V
from itertools import groupby
import logging
from typing import Dict
@@ -31,7 +31,7 @@
from espnet2.text.token_id_converter import TokenIDConverter
from espnet2.torch_utils.device_funcs import force_gatherable

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
if V(torch.__version__) >= V("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
4 changes: 2 additions & 2 deletions espnet2/bin/tts_inference.py
@@ -8,7 +8,7 @@
import sys
import time

from distutils.version import LooseVersion
from packaging.version import parse as V
from pathlib import Path
from typing import Any
from typing import Dict
@@ -300,7 +300,7 @@ def from_pretrained(
from parallel_wavegan import __version__

# NOTE(kan-bayashi): Filelock download is supported from 0.5.2
assert LooseVersion(__version__) > LooseVersion("0.5.1"), (
assert V(__version__) > V("0.5.1"), (
"Please install the latest parallel_wavegan "
"via `pip install -U parallel_wavegan`."
)
4 changes: 2 additions & 2 deletions espnet2/diar/espnet_model.py
@@ -2,7 +2,7 @@
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

from contextlib import contextmanager
from distutils.version import LooseVersion
from packaging.version import parse as V
from itertools import permutations
from typing import Dict
from typing import Optional
@@ -22,7 +22,7 @@
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
if V(torch.__version__) >= V("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
4 changes: 2 additions & 2 deletions espnet2/enh/decoder/stft_decoder.py
@@ -1,11 +1,11 @@
from distutils.version import LooseVersion
from packaging.version import parse as V
import torch
from torch_complex.tensor import ComplexTensor

from espnet2.enh.decoder.abs_decoder import AbsDecoder
from espnet2.layers.stft import Stft

is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")


class STFTDecoder(AbsDecoder):
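The is_torch_1_9_plus flag is computed once at import time and consulted wherever the STFT code needs behaviour only available in torch >= 1.9. parse also copes with the local build tags torch ships: a "+cuXXX" suffix still satisfies the bound, while a nightly "a0" pre-release does not. A short check of those orderings (version strings are illustrative):

import torch
from packaging.version import parse as V

is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")

# A local build tag still compares against its base release ...
assert V("1.9.0+cu111") >= V("1.9.0")
# ... but a pre-release / nightly tag sorts below the final release.
assert not V("1.9.0a0+d34db33f") >= V("1.9.0")

print("native complex path enabled:", is_torch_1_9_plus)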
4 changes: 2 additions & 2 deletions espnet2/enh/encoder/stft_encoder.py
@@ -1,11 +1,11 @@
from distutils.version import LooseVersion
from packaging.version import parse as V
import torch
from torch_complex.tensor import ComplexTensor

from espnet2.enh.encoder.abs_encoder import AbsEncoder
from espnet2.layers.stft import Stft

is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")


class STFTEncoder(AbsEncoder):
4 changes: 2 additions & 2 deletions espnet2/enh/espnet_enh_s2t_model.py
@@ -1,5 +1,5 @@
from contextlib import contextmanager
from distutils.version import LooseVersion
from packaging.version import parse as V
import logging
import random
from typing import Dict
@@ -16,7 +16,7 @@
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
if V(torch.__version__) >= V("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
4 changes: 2 additions & 2 deletions espnet2/enh/espnet_model.py
@@ -1,5 +1,5 @@
"""Enhancement model module."""
from distutils.version import LooseVersion
from packaging.version import parse as V
from typing import Dict
from typing import List
from typing import Optional
@@ -20,7 +20,7 @@
from espnet2.train.abs_espnet_model import AbsESPnetModel


is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")

EPS = torch.finfo(torch.get_default_dtype()).eps
