Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
JinZr committed Aug 11, 2023
1 parent 14f0cb5 commit bf6fb9f
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ for m in ctc-decoding 1best; do
--model-filename $repo/exp/jit_trace.pt \
--words-file $repo/data/lang_bpe_500/words.txt \
--HLG $repo/data/lang_bpe_500/HLG.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
--G $repo/data/lm/G_4_gram.pt \
--method $m \
--sample-rate 16000 \
Expand Down
2 changes: 1 addition & 1 deletion .github/scripts/test-ncnn-export.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ log "Export via torch.jit.trace()"
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
- \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@
import logging
from pathlib import Path

+ import k2
  import onnxruntime
- import sentencepiece as spm
import torch
import torch.nn as nn
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
Expand All @@ -154,7 +154,7 @@
find_checkpoints,
load_checkpoint,
)
- from icefall.utils import str2bool
+ from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -211,10 +211,10 @@ def get_parser():
)

parser.add_argument(
- "--bpe-model",
+ "--tokens",
  type=str,
- default="data/lang_bpe_500/bpe.model",
- help="Path to the BPE model",
+ default="data/lang_bpe_500/tokens.txt",
+ help="Path to the tokens.txt",
)

parser.add_argument(
Expand Down Expand Up @@ -675,12 +675,14 @@ def main():

logging.info(f"device: {device}")

- sp = spm.SentencePieceProcessor()
- sp.load(params.bpe_model)
+ # Load tokens.txt here
+ token_table = k2.SymbolTable.from_file(params.tokens)

+ # Load id of the <blk> token and the vocab size
  # <blk> is defined in local/train_bpe_model.py
- params.blank_id = sp.piece_to_id("<blk>")
- params.vocab_size = sp.get_piece_size()
+ params.blank_id = token_table["<blk>"]
+ params.unk_id = token_table["<unk>"]
+ params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>

logging.info(params)

Expand Down

0 comments on commit bf6fb9f

Please sign in to comment.