Commit e14dae4

isort
AmirHussein96 committed Apr 5, 2024
1 parent 891cf55 commit e14dae4
Showing 18 changed files with 62 additions and 48 deletions.
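
This commit applies isort-style import ordering: standard-library imports come first, third-party imports follow, each group is alphabetized, and the groups are separated by a single blank line. A minimal sketch of that behaviour, assuming isort's default settings (the exact isort configuration used by this repository is not shown in the commit, and the messy string below is purely illustrative):

# Hypothetical illustration only; not part of the commit.
import isort

messy = (
    "from lhotse import CutSet\n"
    "import logging\n"
    "import argparse\n"
)
print(isort.code(messy))
# Prints:
# import argparse
# import logging
#
# from lhotse import CutSet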
1 change: 1 addition & 0 deletions egs/seame/ASR/local/cer.py
@@ -9,6 +9,7 @@
"""

import argparse

import jiwer


5 changes: 2 additions & 3 deletions egs/seame/ASR/local/compute_fbank_gpu_seame.py
@@ -23,20 +23,19 @@
The generated fbank features are saved in data_seame/fbank.
"""

import argparse
import logging
import os
from pathlib import Path
import argparse

from lhotse import CutSet, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached

from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached


def get_args():
3 changes: 1 addition & 2 deletions egs/seame/ASR/local/compute_fbank_gpu_seame_sample.py
@@ -23,13 +23,12 @@
The generated fbank features are saved in data_seame/fbank.
"""

import argparse
import logging
import os
from pathlib import Path
import argparse

from lhotse import CutSet, LilcomChunkyWriter

from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
5 changes: 3 additions & 2 deletions egs/seame/ASR/local/cuts_validate.py
@@ -1,11 +1,12 @@
#!/usr/bin/python

from lhotse import RecordingSet, SupervisionSet, CutSet
import argparse
import logging
from lhotse.qa import fix_manifests, validate_recordings_and_supervisions
import pdb

from lhotse import CutSet, RecordingSet, SupervisionSet
from lhotse.qa import fix_manifests, validate_recordings_and_supervisions


def get_parser():
parser = argparse.ArgumentParser(
2 changes: 1 addition & 1 deletion egs/seame/ASR/local/prepare_lang_bpe.py
@@ -35,6 +35,7 @@
"""

import argparse
import pdb
from pathlib import Path
from typing import Dict, List, Tuple

@@ -50,7 +51,6 @@
)

from icefall.utils import str2bool
import pdb


def lexicon_to_fst_no_sil(
5 changes: 3 additions & 2 deletions egs/seame/ASR/local/prepare_transcripts.py
@@ -5,12 +5,13 @@
This script prepares transcript_words.txt from cutset
"""

from lhotse import CutSet
import argparse
import logging
import os
import pdb
from pathlib import Path
import os

from lhotse import CutSet


def get_parser():
3 changes: 2 additions & 1 deletion egs/seame/ASR/local/sample_hours.py
@@ -5,12 +5,13 @@
Sample data given duration in seconds.
"""

from lhotse import RecordingSet, SupervisionSet, CutSet
import argparse
import logging
import os
from pathlib import Path

from lhotse import CutSet, RecordingSet, SupervisionSet


def get_parser():
parser = argparse.ArgumentParser(
1 change: 1 addition & 0 deletions egs/seame/ASR/local/train_bpe_model.py
@@ -29,6 +29,7 @@
import shutil
from pathlib import Path
from typing import Dict

import sentencepiece as spm


11 changes: 8 additions & 3 deletions egs/seame/ASR/local/wer_lang.py
@@ -5,11 +5,16 @@
Compute WER per language
"""

import sys, codecs, math, pickle, unicodedata, re
from collections import Counter
import argparse
import codecs
import math
import pickle
import re
import sys
import unicodedata
from collections import Counter, defaultdict

from kaldialign import align
from collections import defaultdict


def get_parser():
5 changes: 2 additions & 3 deletions egs/seame/ASR/zipformer/decode.py
@@ -64,6 +64,8 @@
import logging
import math
import os
import re
import string
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@@ -105,9 +107,6 @@
str2bool,
write_error_stats,
)
import string
import re


LOG_EPS = math.log(1e-10)

14 changes: 7 additions & 7 deletions egs/seame/ASR/zipformer/profile.py
@@ -22,24 +22,24 @@

import argparse
import logging
import sentencepiece as spm
import torch

from typing import Tuple
from torch import Tensor, nn

from icefall.utils import make_pad_mask
from icefall.profiler import get_model_profile
import sentencepiece as spm
import torch
from scaling import BiasNorm
from torch import Tensor, nn
from train import (
add_model_arguments,
get_encoder_embed,
get_encoder_model,
get_joiner_model,
add_model_arguments,
get_params,
)
from zipformer import BypassModule

from icefall.profiler import get_model_profile
from icefall.utils import make_pad_mask


def get_parser():
parser = argparse.ArgumentParser(
6 changes: 3 additions & 3 deletions egs/seame/ASR/zipformer_hat/decode.py
@@ -52,6 +52,8 @@
import logging
import math
import os
import re
import string
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@@ -64,8 +66,8 @@
from beam_search import (
greedy_search_batch,
modified_beam_search,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_lm_rescore_LODR,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR,
)
from train import add_model_arguments, get_model, get_params
@@ -86,8 +88,6 @@
str2bool,
write_error_stats,
)
import string
import re

LOG_EPS = math.log(1e-10)

2 changes: 1 addition & 1 deletion egs/seame/ASR/zipformer_hat/model.py
@@ -22,9 +22,9 @@
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear

from icefall.utils import add_sos, make_pad_mask
from scaling import ScaledLinear


class AsrModel(nn.Module):
14 changes: 7 additions & 7 deletions egs/seame/ASR/zipformer_hat_lid/decode.py
@@ -75,22 +75,28 @@
import logging
import math
import os
import re
import string
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import k2
import matplotlib.pyplot as plt
import seaborn as sns
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import SeameAsrDataModule
from beam_search import (
greedy_search_batch,
modified_beam_search,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_lm_rescore_LODR,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR,
)
from kaldialign import align
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from train import add_model_arguments, get_model, get_params

from icefall import ContextGraph, LmScorer, NgramLm
@@ -109,12 +115,6 @@
str2bool,
write_error_stats,
)
from kaldialign import align
from sklearn.metrics import f1_score, classification_report, confusion_matrix
import string
import re
import seaborn as sns
import matplotlib.pyplot as plt

LOG_EPS = math.log(1e-10)

3 changes: 2 additions & 1 deletion egs/seame/ASR/zipformer_hat_lid/joiner.py
@@ -14,10 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.nn as nn
from scaling import ScaledLinear
from typing import Optional


class Joiner(nn.Module):
2 changes: 1 addition & 1 deletion egs/seame/ASR/zipformer_hat_lid/model.py
@@ -23,9 +23,9 @@
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear

from icefall.utils import add_sos, make_pad_mask
from scaling import ScaledLinear


class AsrModel(nn.Module):
3 changes: 2 additions & 1 deletion egs/seame/ASR/zipformer_hat_lid/train.py
@@ -100,7 +100,7 @@
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
from torch.optim import Optimizer

import k2
import optim
import sentencepiece as spm
@@ -120,6 +120,7 @@
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2

25 changes: 15 additions & 10 deletions egs/seame/ASR/zipformer_hat_lid/zipformer.py
@@ -17,28 +17,33 @@
# limitations under the License.

import copy
import logging
import math
import random
import warnings
from typing import List, Optional, Tuple, Union
import logging

import torch
import random
from encoder_interface import EncoderInterface
from scaling import (
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
)
from scaling import (
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
)
from scaling import (
ActivationDropoutAndLinear,
Balancer,
BiasNorm,
Dropout2,
ChunkCausalDepthwiseConv1d,
ActivationDropoutAndLinear,
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
Dropout2,
FloatLike,
ScheduledFloat,
Whiten,
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
convert_num_channels,
limit_param_value,
penalize_abs_values_gt,
softmax,
ScheduledFloat,
FloatLike,
limit_param_value,
convert_num_channels,
)
from torch import Tensor, nn
