Skip to content

Commit

Permalink
Minor refinements for some stale but recently merged PRs (#1354)
Browse files (browse the repository at this point in the history)
* incorporate #1269

* incorporate #1301

* black formatted

* incorporate #1162

* black formatted
  • Loading branch information
JinZr authored Oct 31, 2023
1 parent c970df5 commit 23913f6
Show file tree
Hide file tree
Showing 14 changed files with 57 additions and 68 deletions.
2 changes: 1 addition & 1 deletion egs/aishell/ASR/zipformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,7 +1128,7 @@ def run(rank, world_size, args):

if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

Expand Down
2 changes: 1 addition & 1 deletion egs/gigaspeech/ASR/zipformer/asr_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def train_dataloaders(
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
Expand Down
2 changes: 1 addition & 1 deletion egs/gigaspeech/ASR/zipformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,7 +1164,7 @@ def run(rank, world_size, args):

if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

Expand Down
12 changes: 4 additions & 8 deletions egs/libriheavy/ASR/zipformer_prompt_asr/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def batched_params(self, param_group, group_params_names):

yield tuples # <-- calling code will do the actual optimization here!

for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
for (stacked_params, _state, _names), batch in zip(tuples, batches):
for i, p in enumerate(batch): # batch is list of Parameter
p.copy_(stacked_params[i])

Expand Down Expand Up @@ -181,7 +181,6 @@ def __init__(
size_update_period=4,
clipping_update_period=100,
):

defaults = dict(
lr=lr,
clipping_scale=clipping_scale,
Expand Down Expand Up @@ -327,9 +326,7 @@ def step(self, closure=None):
batch = True

for group, group_params_names in zip(self.param_groups, self.parameters_names):

with self.batched_params(group["params"], group_params_names) as batches:

# batches is list of pairs (stacked_param, state). stacked_param is like
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
# a stacking dim, it is not a real dim.
Expand Down Expand Up @@ -429,7 +426,7 @@ def _get_clipping_scale(
clipping_update_period = group["clipping_update_period"]

tot_sumsq = torch.tensor(0.0, device=first_p.device)
for (p, state, param_names) in tuples:
for p, state, param_names in tuples:
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
Expand Down Expand Up @@ -514,7 +511,7 @@ def _show_gradient_dominating_parameter(
from tuples, we still pass it to save some time.
"""
all_sumsq_orig = {}
for (p, state, batch_param_names) in tuples:
for p, state, batch_param_names in tuples:
# p is a stacked batch parameters.
batch_grad = p.grad
if p.numel() == p.shape[0]: # a batch of scalars
Expand All @@ -530,7 +527,6 @@ def _show_gradient_dominating_parameter(
for name, sumsq_orig, rms, grad in zip(
batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
):

proportion_orig = sumsq_orig / tot_sumsq
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)

Expand Down Expand Up @@ -1106,7 +1102,7 @@ def _test_scaled_adam(hidden_dim: int):

# if epoch == 130:
# opts = diagnostics.TensorDiagnosticOptions(
# 2 ** 22
# 512
# ) # allow 4 megabytes per sub-module
# diagnostic = diagnostics.attach_diagnostics(m, opts)

Expand Down
2 changes: 1 addition & 1 deletion egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1194,7 +1194,7 @@ def run(rank, world_size, args):

if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

Expand Down
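[File header missing from this extract: the hunk below belongs to the 14th changed file, whose path and change counts were not captured.]
Original file line number Diff line number Diff line change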
Expand Up @@ -1565,7 +1565,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
args.max_duration = 100
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

Expand Down
8 changes: 4 additions & 4 deletions egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
Expand Down Expand Up @@ -225,7 +225,7 @@ def train_dataloaders(
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
Expand Down Expand Up @@ -307,8 +307,8 @@ def train_dataloaders(
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
Expand Down
3 changes: 2 additions & 1 deletion egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
import argparse
import logging
import math
import pprint
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import pprint

import k2
import sentencepiece as spm
import torch
Expand Down
1 change: 0 additions & 1 deletion egs/librispeech/ASR/tiny_transducer_ctc/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ def moving_avg(self, x: Tensor) -> Tensor:
return y

def forward(self, x: Tensor) -> Tensor:

assert len(x.shape) == 3, "Input is not a 3D tensor!"
y = self.exponential_moving_avg(x)
y = y.permute(0, 2, 1) # make channel last for squeeze op
Expand Down
31 changes: 10 additions & 21 deletions egs/librispeech/ASR/tiny_transducer_ctc/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,17 @@
import logging
from pathlib import Path

import sentencepiece as spm
import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import UniqLexicon
from icefall.utils import str2bool
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -143,13 +143,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_bpe_500",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand Down Expand Up @@ -189,17 +186,9 @@ def main():

logging.info(f"device: {device}")

if "lang_bpe" in str(params.lang_dir):
sp = spm.SentencePieceProcessor()
sp.load(params.lang_dir + "/bpe.model")
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
else:
assert "lang_phone" in str(params.lang_dir)
phone_lexicon = UniqLexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(phone_lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>

logging.info(params)

Expand Down
4 changes: 1 addition & 3 deletions egs/librispeech/ASR/tiny_transducer_ctc/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:


def add_model_arguments(parser: argparse.ArgumentParser):

parser.add_argument(
"--encoder-dim",
type=int,
Expand Down Expand Up @@ -405,7 +404,6 @@ def get_params() -> AttributeDict:


def get_encoder_model(params: AttributeDict) -> nn.Module:

encoder = Conv1dNet(
output_dim=params.encoder_dim,
input_dim=params.feature_dim,
Expand Down Expand Up @@ -1043,7 +1041,7 @@ def run(rank, world_size, args):

if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

Expand Down
21 changes: 9 additions & 12 deletions egs/librispeech/ASR/zipformer_ctc/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import logging
from pathlib import Path

import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_ctc_model, get_params
Expand All @@ -33,8 +34,7 @@
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -90,11 +90,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_bpe_500",
help="""It contains language related input files such as "lexicon.txt"
""",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand All @@ -113,17 +112,15 @@ def get_parser():
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)

params = get_params()
params.update(vars(args))

logging.info(params)
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>

lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
params.vocab_size = num_classes
logging.info(params)

device = torch.device("cpu")
if torch.cuda.is_available():
Expand Down
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/zipformer_ctc/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,7 +947,7 @@ def run(rank, world_size, args):

if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

Expand Down
33 changes: 21 additions & 12 deletions icefall/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,14 @@ def get_count(count):

if stats_type == "eigs":
try:
if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
eigs, _ = torch.linalg.eigh(stats)
else:
eigs, _ = torch.symeig(stats)
stats = eigs.abs().sqrt()
except: # noqa
print(
"Error getting eigenvalues, trying another method."
)
if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'):
print("Error getting eigenvalues, trying another method.")
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"):
eigs, _ = torch.linalg.eig(stats)
eigs = eigs.abs()
else:
Expand Down Expand Up @@ -579,10 +577,15 @@ def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
)
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
class_name=get_class_name(_module))

if isinstance(o, Tensor) and o.dtype in (
torch.float32,
torch.float16,
torch.float64,
):
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(
o, class_name=get_class_name(_module)
)

def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
if isinstance(_output, tuple) and len(_output) == 1:
_output = _output[0]
Expand All @@ -596,9 +599,15 @@ def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
)
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
class_name=get_class_name(_module))
if isinstance(o, Tensor) and o.dtype in (
torch.float32,
torch.float16,
torch.float64,
):
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
o, class_name=get_class_name(_module)
)

module.register_forward_hook(forward_hook)
module.register_backward_hook(backward_hook)

Expand Down

0 comments on commit 23913f6

Please sign in to comment.