Commit of work for final sprint for Round 1 paper submission. #27

Merged: 1 commit, Jun 30, 2021
3 changes: 3 additions & 0 deletions .gitmodules
@@ -4,3 +4,6 @@
[submodule "kaldi"]
path = third_party/kaldi
url = git://github.com/galv/kaldi.git
[submodule "third_party/pydub"]
path = third_party/pydub
url = https://github.com/galv/pydub.git
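
Note: the new submodule points at a fork (galv/pydub) rather than upstream pydub, presumably because the fork accepts the subprocess_timeout keyword that align_lib.py below passes to AudioSegment.from_file and export; upstream pydub does not appear to expose such a parameter. A minimal sketch of that assumed usage pattern:

    # Sketch only: subprocess_timeout is assumed to be a fork-specific keyword
    # that bounds how long pydub's decoding subprocess may run before raising
    # subprocess.TimeoutExpired.
    import subprocess
    from pydub import AudioSegment

    def load_segment_or_none(path, timeout_s=100):
        try:
            return AudioSegment.from_file(path, subprocess_timeout=timeout_s)
        except subprocess.TimeoutExpired:
            # Treat a stalled decode (e.g. a hung gcsfuse read) as missing data.
            return None
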
2 changes: 2 additions & 0 deletions galvasr2/BUILD
@@ -29,6 +29,7 @@ py_binary(
"align/spark/dsalign_lib.py",
"align/spark/event_listener.py",
"align/spark/schemas.py",
"align/spark/timeout.py",
] +
glob(["align/*.py"]),
deps = [
@@ -48,6 +49,7 @@ py_binary(
"align/spark/align_lib.py",
"align/spark/event_listener.py",
"align/spark/schemas.py",
"align/spark/timeout.py",
] +
glob(["align/*.py"]),
deps = [
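
Note: align/spark/timeout.py is added to both py_binary targets here, but its contents are not part of this diff. Judging from the call site in align_lib.py (timeout(os.path.exists, (path,), timeout_duration=100)), a plausible sketch of the helper, assumptions and all, is:

    # Hypothetical reconstruction of galvasr2/align/spark/timeout.py; only the
    # call signature is taken from align_lib.py, the body is an assumption.
    import threading

    def timeout(func, args=(), kwargs=None, timeout_duration=10, default=None):
        """Run func(*args, **kwargs) in a daemon thread and return its result,
        or `default` if it does not finish within timeout_duration seconds."""
        kwargs = kwargs or {}
        result = [default]

        def _target():
            result[0] = func(*args, **kwargs)

        worker = threading.Thread(target=_target, daemon=True)
        worker.start()
        worker.join(timeout_duration)
        return result[0]
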
44 changes: 27 additions & 17 deletions galvasr2/align/spark/align_cuda_decoder.py
@@ -17,6 +17,7 @@
import pyspark
import pyspark.sql.functions as F
import pyspark.sql.types as T
import tensorflow as tf

from galvasr2.align.spark.align_lib import fix_text_udf, load_audio_id_text_id_mapping, load_transcripts, prepare_create_audio_segments_udf, TemporaryMountDirectory
from galvasr2.align.spark.dsalign_lib import prepare_align_udf
@@ -45,17 +46,19 @@
'Input directory. Exact format of this is a bit undefined right now and will likely change.')


def create_wav_scp(wav_scp_file_name: str, rows: List[pyspark.Row], base_path: str):
def create_wav_scp(wav_scp_file_name: str, rows: List[pyspark.Row], base_path: str, ctm_path: str):
with open(wav_scp_file_name, "w") as fh:
lines = []
for row in rows:
import tqdm
for row in tqdm.tqdm(rows):
key = os.path.join(row.identifier, row.audio_document_id)
if ctm_path is not None:
output_file_name = os.path.join(ctm_path, row.kaldi_normalized_uttid + ".ctm")
if tf.io.gfile.exists(output_file_name):
continue
path = os.path.join(base_path, key)
value = f"/usr/bin/sox \"{path}\" -t wav --channels 1 --rate 8000 --encoding signed --bits 16 - |"
line = f"{row.kaldi_normalized_uttid} {value}\n"
lines.append(line)
fh.write(line)
# shuffle(lines)
# fh.writelines(lines)

def split_wav_scp(posix_wav_scp, work_dir, num_splits):
@@ -68,7 +71,7 @@ def split_wav_scp(posix_wav_scp, work_dir, num_splits):
file_handles.append(open(file_names[-1], "w"))
with open(posix_wav_scp) as fh:
for i, line in enumerate(fh):
i = i % 4
i = i % num_splits
file_handles[i].write(line)
for fh in file_handles:
fh.close()
@@ -83,7 +86,9 @@ def main(argv):
.config("spark.sql.execution.arrow.pyspark.enabled", "true")\
.config("spark.driver.extraJavaOptions", "-Dio.netty.tryReflectionSetAccessible=true")\
.config("spark.executor.extraJavaOptions", "-Dio.netty.tryReflectionSetAccessible=true")\
.config('spark.driver.memory', '310g')\
.config("spark.history.fs.logDirectory", "/spark-events")\
.config("spark.sql.execution.arrow.maxRecordsPerBatch", "1000")\
.getOrCreate()
spark.sparkContext.setLogLevel("INFO") # "ALL" for very verbose logging
logging.getLogger("py4j").setLevel(logging.ERROR)
@@ -96,6 +101,7 @@ def main(argv):
if not FLAGS.work_dir.startswith("gs://"):
os.makedirs(FLAGS.work_dir, exist_ok=True)
wav_scp = os.path.join(FLAGS.work_dir, "wav.scp")
ctm_out_dir = os.path.join(FLAGS.work_dir, "decoder_ctm_dir")
if FLAGS.stage <= 0:
catalogue_df = catalogue_df.cache()
# catalogue_df.write.mode("overwrite").format("csv").options(header="true").save(key_int_mapping)
@@ -107,7 +113,7 @@ def main(argv):
unmount_cmd=["fusermount", "-u"]) as temp_dir_name:
posix_wav_scp = re.sub(r'^{0}'.format(FLAGS.input_gcs_bucket),
temp_dir_name, wav_scp)
create_wav_scp(posix_wav_scp, training_sample_rows, FLAGS.input_dir)
create_wav_scp(posix_wav_scp, training_sample_rows, FLAGS.input_dir, ctm_out_dir)

# /development/lingvo-source/output_ctm_dir/

@@ -117,7 +123,6 @@ def main(argv):
# Can get 266x RTF with this configuration. Keep it?
# batch size of 100 and num channels of 100 works just fine

ctm_out_dir = os.path.join(FLAGS.work_dir, "decoder_ctm_dir")
if FLAGS.stage <= 1:
if not FLAGS.work_dir.startswith("gs://"):
os.makedirs(ctm_out_dir, exist_ok=True)
@@ -180,7 +185,6 @@ def run_gpu(posix_wav_scp_shard, gpu_number):
# /opt/kaldi/egs/aspire/s5/exp/tdnn_7b_chain_online/graph_pp/phones/word_boundary.int
# /opt/kaldi/egs/aspire/s5/exp/tdnn_7b_chain_online/graph_pp/phones/word_boundary.int
# --word-boundary-rxfilename=/opt/kaldi/egs/aspire/s5/exp/tdnn_7b_chain_online/graph_pp/phones/word_boundary.int \
assert False
if FLAGS.stage <= 2:
FAKE_WORDS = {"<eps>", "<unk>", "[laughter]", "[noise]", "<s>", "</s>", "#0"}
alphabet_set = set()
@@ -200,7 +204,8 @@ def run_gpu(posix_wav_scp_shard, gpu_number):
# TODO: Add options to DSAlign here
dsalign_args = dsalign_main.parse_args("")

align_udf = prepare_align_udf(dsalign_args, alphabet_path)
alphabet_normalized_path = "/development/lingvo-source/alphabet2.txt"
align_udf = prepare_align_udf(dsalign_args, alphabet_normalized_path)

ctm_df = spark.read.format("binaryFile").option("pathGlobFilter", "*.ctm").load(ctm_out_dir)
ctm_df = ctm_df.withColumn("kaldi_normalized_uttid", F.regexp_replace(F.reverse(F.split(ctm_df.path, "/"))[0], r"[.]ctm$", ""))
@@ -210,8 +215,8 @@ def run_gpu(posix_wav_scp_shard, gpu_number):
downsampled_catalogue_df = ctm_df.drop("ctm_content")

training_sample_rows = downsampled_catalogue_df.collect()
# training_sample_rows = training_sample_rows[:10]
transcripts_df = load_transcripts(spark, FLAGS.input_gcs_path, training_sample_rows)
# TODO: Fix this. We need to recover the original identifier for each ctm file.
ctm_df = ctm_df.join(transcripts_df, ['identifier', 'text_document_id'])

# alignments_df = ctm_df.select(align_udf(F.concat(ctm_df.identifier, F.lit("/"), ctm_df.text_document_id),
@@ -220,19 +225,22 @@ def run_gpu(posix_wav_scp_shard, gpu_number):
alignments_df = ctm_df.withColumn("alignments",
align_udf(F.concat(ctm_df.identifier, F.lit("/"), ctm_df.text_document_id),
F.concat(ctm_df.identifier, F.lit("/"), ctm_df.audio_document_id),
ctm_df.transcript, ctm_df.ctm_content))
ctm_df.transcript, ctm_df.ctm_content)).drop('ctm_content')
print("GALVEZ:schema")
alignments_df.printSchema()
import sys; sys.stdout.flush()

alignments_df.write.mode("overwrite").format("json").save(os.path.join(FLAGS.work_dir, "alignments_json"))

pass
training_data_export_dir = os.path.join(FLAGS.work_dir, "training_data_export")
if FLAGS.stage <= 3:
alignments_df = spark.read.json(os.path.join(FLAGS.work_dir, "alignments_json"))
# We would like the number of partitions to be some large multiple
# of the number of executors. Not every audio file is the same
# length, so this helps with load balancing.
alignments_df = alignments_df.repartition(960)
create_audio_segments_udf = prepare_create_audio_segments_udf(gs_bucket=FLAGS.input_gcs_bucket,
output_dir=os.path.join(FLAGS.work_dir, "training_set")
output_dir=os.path.join(FLAGS.work_dir, "training_set_wav")
)
audio_paths = F.concat(F.lit(FLAGS.input_gcs_path), F.lit("/"),
alignments_df.identifier, F.lit("/"),
@@ -245,12 +253,14 @@ def run_gpu(posix_wav_scp_shard, gpu_number):
alignments_df.audio_document_id,
alignments_df.text_document_id,
F.struct(
alignments_df.alignments.label,
alignments_df.output_paths
alignments_df.alignments.label.alias('label'),
alignments_df.output_paths,
F.expr("transform(arrays_zip(alignments.end_ms, alignments.start_ms), x -> x.end_ms - x.start_ms)").alias('duration_ms')
# (alignments_df.alignments.end_ms - alignments_df.alignments.start_ms).alias('duration_ms'),
).alias('training_data')
)
# coalesce(1) seems to make the create_audio_segments_udf function run serially
output_df.write.mode("overwrite").json(os.path.join(FLAGS.work_dir, "dataset_manifest"))
output_df.write.mode("overwrite").json(os.path.join(FLAGS.work_dir, "dataset_manifest_wav"))

if __name__ == '__main__':
app.run(main)
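
Note: the duration_ms field in the manifest is computed with Spark SQL higher-order functions rather than plain column arithmetic because start_ms and end_ms are arrays, one entry per aligned segment. A self-contained toy illustration of the same arrays_zip/transform pattern (local SparkSession and made-up data, not part of this PR):

    # Toy example of the expression used above to turn parallel
    # start_ms/end_ms arrays into per-segment durations.
    from pyspark.sql import SparkSession
    import pyspark.sql.functions as F

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame(
        [([0, 1000, 2500], [800, 2000, 4000])],
        ["start_ms", "end_ms"],
    )
    durations = df.select(
        F.expr(
            "transform(arrays_zip(end_ms, start_ms), x -> x.end_ms - x.start_ms)"
        ).alias("duration_ms")
    )
    durations.show(truncate=False)  # [800, 1000, 1500]
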
75 changes: 66 additions & 9 deletions galvasr2/align/spark/align_lib.py
@@ -9,6 +9,7 @@
import subprocess
import sys
import tempfile
import threading
from typing import List, Tuple
import wave

@@ -17,6 +18,7 @@
import langid
import numpy as np
import pandas as pd
import pydub
from pydub import AudioSegment
import pyspark
from matching.games import HospitalResident
@@ -35,6 +37,7 @@
IntegerType, LongType)

from galvasr2.align.spark.schemas import ARCHIVE_ORG_SCHEMA
from galvasr2.align.spark.timeout import timeout

def DecodeToWavPipe(input_bytes, fmt):
cmd = f'sox -t {fmt} - -t wav --channels 1 --rate 16000 --encoding signed --bits 16 -'
@@ -298,6 +301,9 @@ def load_audio_id_text_id_mapping(spark, input_catalogue_path: str):

# from IPython import embed; embed()

print("GALVEZ:json=", df.count())
print("GALVEZ:exploded=", filtered_exploded_df.count())

text_df = filtered_exploded_df.select(
filtered_exploded_df.identifier,
filtered_exploded_df.exploded_files["name"].alias("text_document_id"),
@@ -316,11 +322,10 @@ def load_audio_id_text_id_mapping(spark, input_catalogue_path: str):
(filtered_exploded_df.exploded_files["name"].endswith('.mp3'))
)

# The cause of the problem
# This creates duplicates, doesn't it? Ooops
joined_df = audio_df.join(text_df, "identifier")
joined_df = joined_df.withColumn("levenshtein", F.levenshtein(joined_df.audio_document_id, joined_df.text_document_id))
audio_to_text_mapping_df = joined_df.groupBy("identifier").applyInPandas(fuzzy_matching, schema=FUZZY_MATCHING_RETURN_TYPE)
# This is needed for the hours-per-license job. We should probably avoid adding one-off columns inside this function.
licenses_df = df.select(df.identifier, df.metadata.licenseurl.alias('licenseurl'))
audio_to_text_mapping_df = audio_to_text_mapping_df.join(licenses_df, ['identifier'])
return audio_to_text_mapping_df
@@ -400,8 +405,18 @@ def __init__(self, mount_cmd: List, unmount_cmd: List):
subprocess.check_call(mount_cmd + [self.name])

def __exit__(self, exc, value, tb):
subprocess.check_call(self._unmount_cmd + [self.name])
super().__exit__(exc, value, tb)
try:
subprocess.check_call(self._unmount_cmd + [self.name])
unmounted = True
except subprocess.CalledProcessError:
try:
subprocess.check_call(["umount", self.name])
unmounted = True
except subprocess.CalledProcessError:
unmounted = False
print(f"WARNING: Failed to unmount {self.name}. Not removing temporary directory. Trying to remove a temporary directory that fails to unmount results in: OSError: [Errno 5] Input/output error")
if unmounted:
super().__exit__(exc, value, tb)

def _prepare_soxi_udf(soxi_flags, spark_return_type, python_return_type):
@pandas_udf(spark_return_type)
Expand All @@ -418,20 +433,23 @@ def get_soxi_info_udf(audio_file_series: pd.Series) -> pd.Series:
cmd = f"soxi {soxi_flags} \"{audio_file}\""
try:
duration = subprocess.check_output(shlex.split(cmd), stderr=subprocess.DEVNULL, timeout=10) # 10 second timeout
# print("GALVEZ:value=")
# print(duration)
duration = python_return_type(duration.rstrip(b'\n'))
except subprocess.CalledProcessError:
# WARNING: Assumes that your return type default constructor returns a "reasonable" value.
# May return None instead?
duration = python_return_type()
except subprocess.TimeoutExpired:
print(f"Restarting on {audio_file}")
# Call again. Sometimes gcsfuse just stalls, so we need restartability
return get_soxi_info_udf(audio_file_series)
durations.append(duration)
return pd.Series(durations)
return get_soxi_info_udf

get_audio_seconds_udf = _prepare_soxi_udf("-D", DoubleType(), float)
get_audio_sample_rate_udf = _prepare_soxi_udf("-r", DoubleType(), float)
get_audio_sample_rate_udf = _prepare_soxi_udf("-r", StringType(), str)
get_audio_annotations_udf = _prepare_soxi_udf("-a", BinaryType(), bytes)

def prepare_create_audio_segments_udf(gs_bucket: str, output_dir: str):
@@ -444,18 +462,57 @@ def create_audio_segments_udf(audio_file_gcs_paths: pd.Series, identifier_series: pd.Series, audio_document_id_series: pd.Series, start_ms_arrays: pd.Series, end_ms_arrays: pd.Series) -> pd.Series:
with TemporaryMountDirectory(
mount_cmd=["gcsfuse", "--implicit-dirs", gs_bucket.lstrip("gs://")],
unmount_cmd=["fusermount", "-u"]) as temp_dir_name:
if not output_dir.startswith("gs://"):
posix_output_dir = output_dir
else:
posix_output_dir = re.sub(r'^{0}'.format(gs_bucket), temp_dir_name, output_dir)
for audio_file_gcs_path, identifier, audio_document_id, start_ms_array, end_ms_array in zip(audio_file_gcs_paths, identifier_series, audio_document_id_series, start_ms_arrays, end_ms_arrays):
chunk_paths.append([])
audio_file_path = re.sub(r'^{0}'.format(gs_bucket), temp_dir_name, audio_file_gcs_path)
source = AudioSegment.from_file(audio_file_path)
this_file_output_dir = os.path.join(output_dir, identifier)
print(f"GALVEZ:audio_file_path={audio_file_path}")
try:
source = AudioSegment.from_file(audio_file_path, subprocess_timeout=100)
except subprocess.TimeoutExpired:
print("GALVEZ:timed out, need to retry")
return create_audio_segments_udf.func(audio_file_gcs_paths,
identifier_series,
audio_document_id_series,
start_ms_arrays,
end_ms_arrays)
except pydub.exceptions.CouldntDecodeError:
print(f"GALVEZ:problematic audio_file_path={audio_file_path}")
continue
identifier = identifier.replace('/', '_')
this_file_output_dir = os.path.join(posix_output_dir, identifier)
os.makedirs(this_file_output_dir, exist_ok=True)
base, _ = os.path.splitext(audio_document_id)
# We have to handle cases where audio_document_id contains a slash, like this one: collateral/gov.house.oversight.2007.03.19.iphone.mp3
# We could alternatively keep the nested directory structure, but I worry about users of the dataset naively writing a glob pattern like "*/*.mp3", rather than "*/*/*.mp3".
# Furthermore, there could be an arbitrary number of "/" characters. That is hard to handle programmatically.
base = base.replace('/', '_')
last_write_file_name = f"{base}-{(len(start_ms_array) - 1):04d}.wav"
already_done = timeout(os.path.exists, (os.path.join(this_file_output_dir, last_write_file_name),), timeout_duration=100)
if already_done:
pass
for i, (start_ms, end_ms) in enumerate(zip(start_ms_array, end_ms_array)):
# Flac encoding probably good
write_file_name = f"{base}-{i:04d}.wav"
fh = source[start_ms:end_ms].export(os.path.join(output_dir, identifier, write_file_name))
fh.close()
if not already_done:
write_path = os.path.join(this_file_output_dir, write_file_name)
try:
fh = source[start_ms:end_ms].export(write_path, format="wav", subprocess_timeout=100)
except subprocess.TimeoutExpired:
print("GALVEZ:timed out 2, need to retry")
return create_audio_segments_udf.func(audio_file_gcs_paths,
identifier_series,
audio_document_id_series,
start_ms_arrays,
end_ms_arrays)
except pydub.exceptions.CouldntEncodeError:
print(f"GALVEZ: Couldn't encode {write_path} [{start_ms}:{end_ms}]")
continue
else:
fh.close()
chunk_paths[-1].append(os.path.join(identifier, write_file_name))
return pd.Series(chunk_paths)
return create_audio_segments_udf
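
Note: both get_soxi_info_udf and create_audio_segments_udf above retry the whole batch whenever a subprocess times out, on the theory that gcsfuse occasionally stalls; in the diff the retry is unbounded recursion. A bounded variant of the same idiom, as a hypothetical refactor rather than code from this PR:

    # Hypothetical helper: retry a whole pandas-UDF batch a limited number of
    # times when an underlying subprocess (sox, or ffmpeg via pydub) stalls.
    import subprocess

    def retry_batch_on_timeout(batch_fn, *batch_args, max_retries=3):
        last_error = None
        for attempt in range(1, max_retries + 1):
            try:
                return batch_fn(*batch_args)
            except subprocess.TimeoutExpired as error:
                last_error = error
                print(f"Batch timed out (attempt {attempt}/{max_retries}); retrying")
        # Give up after max_retries instead of recursing forever.
        raise last_error
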