Sentencepiece #93

Merged (10 commits, Jan 5, 2021)
10 changes: 8 additions & 2 deletions examples/conformer/train_ga_subword_conformer.py
@@ -44,6 +44,9 @@
parser.add_argument("--acs", type=int, default=None,
help="Train accumulation steps")

parser.add_argument("--sentence_piece", default=False, action="store_true",
help="Whether to use `SentencePiece` model")

parser.add_argument("--devices", type=int, nargs="*", default=[0],
help="Devices' ids to apply distributed training")

@@ -68,15 +71,18 @@
from tensorflow_asr.configs.config import Config
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer
from tensorflow_asr.runners.transducer_runners import TransducerTrainerGA
from tensorflow_asr.models.conformer import Conformer
from tensorflow_asr.optimizers.schedules import TransformerSchedule

config = Config(args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)

if args.subwords and os.path.exists(args.subwords):
if args.sentence_piece:
print("Loading SentencePiece model ...")
text_featurizer = SentencePieceFeaturizer.load_from_file(config.decoder_config, args.subwords)
elif args.subwords and os.path.exists(args.subwords):
print("Loading subwords ...")
text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
else:
10 changes: 8 additions & 2 deletions examples/conformer/train_subword_conformer.py
@@ -35,6 +35,9 @@
parser.add_argument("--tfrecords", default=False, action="store_true",
help="Whether to use tfrecords")

parser.add_argument("--sentence_piece", default=False, action="store_true",
help="Whether to use `SentencePiece` model")

parser.add_argument("--tbs", type=int, default=None,
help="Train batch size per replica")

@@ -65,15 +68,18 @@
from tensorflow_asr.configs.config import Config
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer
from tensorflow_asr.runners.transducer_runners import TransducerTrainer
from tensorflow_asr.models.conformer import Conformer
from tensorflow_asr.optimizers.schedules import TransformerSchedule

config = Config(args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)

if args.subwords and os.path.exists(args.subwords):
if args.sentence_piece:
print("Loading SentencePiece model ...")
text_featurizer = SentencePieceFeaturizer.load_from_file(config.decoder_config, args.subwords)
elif args.subwords and os.path.exists(args.subwords):
print("Loading subwords ...")
text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
else:
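Both training scripts wire the flag the same way; a usage sketch with hypothetical paths, reusing the scripts' existing --config and --subwords arguments (with --sentence_piece, --subwords now points at the trained SentencePiece model file):

# example usage: python examples/conformer/train_subword_conformer.py \
#   --config /path/to/config.yml --sentence_piece --subwords /path/to/sentencepiece.model
# Without --sentence_piece, the scripts keep the original SubwordFeaturizer behaviour.
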
109 changes: 109 additions & 0 deletions scripts/create_mls_trans.py
@@ -0,0 +1,109 @@
# Copyright 2020 M. Yusuf Sarıgöz (@monatis)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import librosa
import tqdm
import tensorflow as tf

# example usage: python create_mls_trans.py --dataset-home /mnt/datasets/mls --language polish --opus

base_url = "https://dl.fbaipublicfiles.com/mls/"

langs = [
"dutch",
"english",
"german",
"french",
"italian",
"portuguese",
"polish",
"spanish"
]

splits = [
"dev",
"test",
"train"
]

chars = set()

def prepare_split(dataset_dir, split, opus=False):
# Setup necessary paths
split_home = os.path.join(dataset_dir, split)
transcripts_infile = os.path.join(split_home, 'transcripts.txt')
transcripts_outfile = os.path.join(split_home, 'transcripts_tfasr.tsv')
audio_home = os.path.join(split_home, "audio")
extension = ".opus" if opus else ".flac"
transcripts = []

# Make paths absolute, get durations and read chars to form alphabet later on
with open(transcripts_infile, 'r', encoding='utf8') as infile:
for line in tqdm.tqdm(infile.readlines(), desc=f"Reading from {transcripts_infile}..."):
file_id, transcript = line.strip().split('\t')
speaker_id, book_id, _ = file_id.split('_')
audio_path = os.path.join(audio_home, speaker_id, book_id, f"{file_id}{extension}")
y, sr = librosa.load(audio_path, sr=None)
duration = librosa.get_duration(y=y, sr=sr)
transcripts.append(f"{audio_path}\t{duration:.2f}\t{transcript}\n")
for char in transcript:
chars.add(char)

# Write transcripts to file
with open(transcripts_outfile, 'w', encoding='utf8') as outfile:
outfile.write("PATH\tDURATION\tTRANSCRIPT\n")
for t in tqdm.tqdm(transcripts, desc=f"Writing to {transcripts_outfile}"):
outfile.write(t)


def make_alphabet_file(filepath, chars_list, lang):
print(f"Writing alphabet to {filepath}...")
with open(filepath, 'w', encoding='utf8') as outfile:
outfile.write(f"# Alphabet file for language {lang}\n")
outfile.write("Automatically generated. Do not edit\n#\n")
for char in sorted(list(chars_list)):
outfile.write(f"{char}\n")

outfile.write("# end of file")


if __name__ == "__main__":
ap = argparse.ArgumentParser(description="Download and prepare MLS dataset in a given language")
ap.add_argument("--dataset-home", "-d", help="Path to home directory to download and prepare dataset. Default to ~/.keras", default=None, required=False)
ap.add_argument("--language", "-l", type=str, choices=langs, help="Any name of language included in MLS", default=None, required=True)
ap.add_argument("--opus", help="Whether to use dataset in opus format or not", default=False, action='store_true')

args = ap.parse_args()
fname = "mls_{}{}.tar.gz".format(args.language, "_opus" if args.opus else "")
subdir = fname[:-7]
dataset_home = os.path.abspath(args.dataset_home) if args.dataset_home else os.path.expanduser(os.path.join("~", ".keras"))  # fall back to ~/.keras as documented
dataset_dir = os.path.join(dataset_home, subdir)
full_url = base_url + fname

downloaded_file = tf.keras.utils.get_file(
fname,
full_url,
cache_subdir=dataset_home,
extract=True
)

print(f"Dataset extracted to {dataset_dir}. Preparing...")

for split in splits:
prepare_split(dataset_dir=dataset_dir, split=split, opus=args.opus)

make_alphabet_file(os.path.join(dataset_dir, "alphabet.txt"), chars, args.language)
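For orientation, a short sketch of the MLS layout this script expects and the files it writes, inferred from the path handling above (names are illustrative):

# expected input, per split (dev/test/train):
#   <dataset_dir>/<split>/transcripts.txt                            lines of "<speaker>_<book>_<utt>\t<transcript>"
#   <dataset_dir>/<split>/audio/<speaker>/<book>/<speaker>_<book>_<utt>.flac   (.opus when --opus is given)
# files written:
#   <dataset_dir>/<split>/transcripts_tfasr.tsv                      header PATH\tDURATION\tTRANSCRIPT
#   <dataset_dir>/alphabet.txt                                       one character per line, '#' lines are comments
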
32 changes: 32 additions & 0 deletions scripts/generate_vocab_sentencepiece.py
@@ -0,0 +1,32 @@
import os
import argparse
from tensorflow_asr.utils import setup_environment, setup_strategy

setup_environment()
import tensorflow as tf

DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

tf.keras.backend.clear_session()

parser = argparse.ArgumentParser(prog="Vocab Training with SentencePiece")

parser.add_argument("--config", type=str, default=DEFAULT_YAML,
help="The file path of model configuration file")

parser.add_argument("--devices", type=int, nargs="*", default=[0],
help="Devices' ids to apply distributed training")

args = parser.parse_args()

strategy = setup_strategy(args.devices)

from tensorflow_asr.configs.config import Config
from tensorflow_asr.featurizers.text_featurizers import SentencePieceFeaturizer

config = Config(args.config, learning=True)

print("Generating subwords ...")
text_featurizer = SentencePieceFeaturizer.build_from_corpus(
config.decoder_config
)
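Everything the trainer needs comes from decoder_config, so an invocation only has to point at the YAML file (a sketch with a hypothetical path, in the same example-usage style as create_mls_trans.py):

# example usage: python scripts/generate_vocab_sentencepiece.py --config /path/to/config.yml
# On success, <output_path_prefix>.model, .vocab and .txt are written next to the configured prefix.
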
3 changes: 2 additions & 1 deletion setup.py
@@ -29,7 +29,8 @@
"tqdm>=4.54.1",
"colorama>=0.4.4",
"nlpaug>=1.1.1",
"nltk>=3.5"
"nltk>=3.5",
"sentencepiece>=0.1.94"
]

setuptools.setup(
3 changes: 3 additions & 0 deletions tensorflow_asr/configs/config.py
@@ -28,6 +28,9 @@ def __init__(self, config: dict = None):
self.norm_score = config.get("norm_score", True)
self.lm_config = config.get("lm_config", {})
self.additional_properties = config.get("additional_properties", {})
self.output_path_prefix = preprocess_paths(config.get("output_path_prefix", None))
self.model_type = config.get("model_type", None)
self.corpus_files = config.get("corpus_files", None)


class DatasetConfig:
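For reference, a minimal sketch of the decoder_config entries the SentencePiece path reads (values are illustrative placeholders; target_vocab_size is an existing DecoderConfig field that build_from_corpus forwards as vocab_size):

decoder_config = {
    "output_path_prefix": "/data/vocab/sentencepiece",  # prefix for the generated .model/.vocab/.txt files
    "model_type": "unigram",                            # unigram, bpe, char or word
    "target_vocab_size": 256,                           # forwarded to SentencePieceTrainer as vocab_size
    "corpus_files": [                                   # TSV transcripts; header line skipped, text taken from the last column
        "/data/mls/polish/train/transcripts_tfasr.tsv",
    ],
}
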
153 changes: 152 additions & 1 deletion tensorflow_asr/featurizers/text_featurizers.py
@@ -16,7 +16,8 @@
import abc
import codecs
import unicodedata

from multiprocessing import cpu_count
import sentencepiece as sp
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tds
@@ -300,3 +301,153 @@ def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor:
indices = self.normalize_indices(indices)
upoints = tf.gather_nd(self.upoints, tf.expand_dims(indices, axis=-1))
return tf.gather_nd(upoints, tf.where(tf.not_equal(upoints, 0)))


class SentencePieceFeaturizer(TextFeaturizer):
"""
Extract text feature based on sentence piece package.
"""
UNK_TOKEN, UNK_TOKEN_ID = "<unk>", 1
BOS_TOKEN, BOS_TOKEN_ID = "<s>", 2
EOS_TOKEN, EOS_TOKEN_ID = "</s>", 3
PAD_TOKEN, PAD_TOKEN_ID = "<pad>", 0 # unused, by default

def __init__(self, decoder_config: dict, model=None):
super().__init__(decoder_config)
self.model = model
self.blank = 0 # treats blank as 0 (pad)
self.upoints = None
# vocab size
self.num_classes = self.model.get_piece_size()
# create upoints
self.__init_upoints()

def __init_upoints(self):
text = [""]
for idx in range(1, self.num_classes):
text.append(self.model.decode_ids([idx]))
self.upoints = tf.strings.unicode_decode(text, "UTF-8")
self.upoints = self.upoints.to_tensor() # [num_classes, max_subword_length]

@classmethod
def build_from_corpus(cls, decoder_config: dict):
"""
--model_prefix: output model name prefix. <model_name>.model and <model_name>.vocab are generated.
--vocab_size: vocabulary size, e.g., 8000, 16000, or 32000
--model_type: model type. Choose from unigram (default), bpe, char, or word.
The input sentence must be pretokenized when using word type."""
decoder_cfg = DecoderConfig(decoder_config)
# Train SentencePiece Model
def corpus_iterator():
for file in decoder_cfg.corpus_files:
with open(file, "r", encoding="utf-8") as f:
lines = f.read().splitlines()
lines = lines[1:]
for line in lines:
line = line.split("\t")
yield line[-1]

sp.SentencePieceTrainer.Train(
sentence_iterator=corpus_iterator(),
model_prefix=decoder_cfg.output_path_prefix,
model_type=decoder_cfg.model_type,
vocab_size=decoder_cfg.target_vocab_size,
num_threads=cpu_count(),
unk_id=cls.UNK_TOKEN_ID,
bos_id=cls.BOS_TOKEN_ID,
eos_id=cls.EOS_TOKEN_ID,
pad_id=cls.PAD_TOKEN_ID,
unk_surface='__UNKNOWN__' # change default unk surface U+2047("⁇") by "__UNKNOWN__"
)
# Export fairseq dictionary
processor = sp.SentencePieceProcessor()
processor.Load(decoder_cfg.output_path_prefix + ".model")
vocab = {i: processor.IdToPiece(i) for i in range(processor.GetPieceSize())}
assert (
vocab.get(cls.UNK_TOKEN_ID) == cls.UNK_TOKEN
and vocab.get(cls.BOS_TOKEN_ID) == cls.BOS_TOKEN
and vocab.get(cls.EOS_TOKEN_ID) == cls.EOS_TOKEN
)
vocab = {
i: s
for i, s in vocab.items()
if s not in {cls.UNK_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.PAD_TOKEN}
}
with open(decoder_cfg.output_path_prefix + ".txt", "w", encoding="utf-8") as f_out:
for _, s in sorted(vocab.items(), key=lambda x: x[0]):
f_out.write(f"{s} 1\n")

return cls(decoder_config, processor)

@classmethod
def load_from_file(cls, decoder_config: dict, filename: str = None):
if filename is not None:
filename_prefix = os.path.splitext(preprocess_paths(filename))[0]
else:
filename_prefix = decoder_config.get("output_path_prefix", None)
processor = sp.SentencePieceProcessor()
processor.load(filename_prefix + ".model")
return cls(decoder_config, processor)

def extract(self, text: str) -> tf.Tensor:
"""
Convert string to a list of integers
# encode: text => id
sp.encode_as_pieces('This is a test') --> ['▁This', '▁is', '▁a', '▁t', 'est']
sp.encode_as_ids('This is a test') --> [209, 31, 9, 375, 586]
Args:
text: string (sequence of characters)

Returns:
sequence of ints in tf.Tensor
"""
text = self.preprocess_text(text)
text = text.strip() # remove leading/trailing whitespace
indices = self.model.encode_as_ids(text)
return tf.convert_to_tensor(indices, dtype=tf.int32)

def iextract(self, indices: tf.Tensor) -> tf.Tensor:
"""
Convert list of indices to string
# decode: id => text
sp.decode_pieces(['▁This', '▁is', '▁a', '▁t', 'est']) --> This is a test
sp.decode_ids([209, 31, 9, 375, 586]) --> This is a test

Args:
indices: tf.Tensor with dim [B, None]

Returns:
transcripts: tf.Tensor of dtype tf.string with dim [B]
"""
indices = self.normalize_indices(indices)
with tf.device("/CPU:0"): # string data is not supported on GPU
def decode(x):
if x[0] == self.blank: x = x[1:]
return self.model.decode_ids(x)

text = tf.map_fn(
lambda x: tf.numpy_function(decode, inp=[x], Tout=tf.string),
indices,
fn_output_signature=tf.TensorSpec([], dtype=tf.string)
)
return text

@tf.function(
input_signature=[
tf.TensorSpec([None], dtype=tf.int32)
]
)
def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor:
"""
Transform Predicted Indices to Unicode Code Points (for using tflite)
Args:
indices: tf.Tensor of Classes in shape [None]

Returns:
unicode code points transcript with dtype tf.int32 and shape [None]
"""
with tf.name_scope("indices2upoints"):
indices = self.normalize_indices(indices)
upoints = tf.gather_nd(self.upoints, tf.expand_dims(indices, axis=-1))
return tf.gather_nd(upoints, tf.where(tf.not_equal(upoints, 0)))
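A minimal usage sketch of the new featurizer, assuming a model already trained with scripts/generate_vocab_sentencepiece.py (paths are hypothetical):

import tensorflow as tf

from tensorflow_asr.configs.config import Config
from tensorflow_asr.featurizers.text_featurizers import SentencePieceFeaturizer

config = Config("/path/to/config.yml", learning=True)
featurizer = SentencePieceFeaturizer.load_from_file(config.decoder_config, "/path/to/sentencepiece.model")

ids = featurizer.extract("this is a test")   # 1-D tf.int32 tensor of subword ids
batch = tf.expand_dims(ids, axis=0)          # iextract expects indices shaped [B, None]
text = featurizer.iextract(batch)            # [B] tf.string tensor of decoded transcripts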
