Skip to content

Commit

Permalink
modify config & fix ci error
Browse files Browse the repository at this point in the history
  • Loading branch information
YushiUeda committed Mar 6, 2022
1 parent 1b31f46 commit d537330
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 97 deletions.
2 changes: 1 addition & 1 deletion egs2/swbd_sentiment/asr1/conf/decode_asr.yaml
Original file line number Diff line number Diff line change
@@ -1 +1 @@
beam_size: 10
beam_size: 1

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ frontend: s3prl
frontend_conf:
frontend_conf:
upstream: wav2vec2_large_ll60k # Note: If the upstream is changed, please change the input_size in the preencoder.
# If using hubert, change the above line to "upstream: hubert_large_ll60k"
download_dir: ./hub
multilayer_feature: True

Expand Down
16 changes: 9 additions & 7 deletions egs2/swbd_sentiment/asr1/local/prepare_sentiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def majorityvote(line):

def normalize_transcript(transcript):
# remove punctuation except apostrophes
transcript = re.sub("(\.|\,|\?|\!|\-|\:|\;)", " \\1 ", transcript)
transcript = re.sub("\.|\,|\?|\!|\-|\:|\;", "", transcript)
transcript = re.sub(r"(\.|\,|\?|\!|\-|\:|\;)", " \\1 ", transcript)
transcript = re.sub(r"\.|\,|\?|\!|\-|\:|\;", "", transcript)
# remove tag (e.g. [LAUGHTER])
transcript = re.sub("\[.+\]", "", transcript)
transcript = re.sub(r"\[.+\]", "", transcript)
# Detect valid apostrophe cases and split those into two words
transcript = re.sub("([a-z])'([a-z])", "\\1 '\\2", transcript)
# Clean up special cases of standalone apostrophes
Expand Down Expand Up @@ -58,7 +58,8 @@ def process_data(
prev_linenum_wf = 0
for linenum, line_sf in enumerate(sf):
if linenum >= start_linenum and linenum < end_linenum:
# "sw02005_0[tab]0.0[tab]11.287375[tab]Neutral-{Questioning}#Neutral-{No emotion}#Neutral-{No emotion}"
# "sw02005_0[tab]0.0[tab]11.287375[tab]
# Neutral-{Questioning}#Neutral-{No emotion}#Neutral-{No emotion}"
utt_id_sf, start, end, sentiment = line_sf.strip().split("\t")
# "sw02005_0" -> "sw02005"
reco_id_sf = utt_id_sf.split("_")[0]
Expand All @@ -67,15 +68,17 @@ def process_data(
tf.seek(0)
for linenum_tf, line_tf in enumerate(tf):
if linenum_tf >= prev_linenum_tf:
# "sw02001-A_018732-018950 oh i see uh-huh" -> "sw02001-A_018732-018950" "oh i see uh-huh"
# "sw02001-A_018732-018950 oh i see uh-huh"
# -> "sw02001-A_018732-018950" "oh i see uh-huh"
utt_id_tf, transcript = line_tf.strip("\n").split(" ", 1)
# "sw02001-A_018732-018950" -> "sw02001-A" "018732-018950"
spk_id_tf, time_id = utt_id_tf.split("_")
# "sw02001-A" -> "sw02001"
reco_id_tf = spk_id_tf.split("-")[0]
# "018732-018950" -> "018732" "018950"
start_time_id, end_time_id = time_id.split("-")
# in case start and end time slightly differ in text and sentiment annotation
# in case start and end time slightly differ
# in text and sentiment annotation
eps = 0.05
if (
reco_id_tf == reco_id_sf
Expand All @@ -84,7 +87,6 @@ def process_data(
and end_time_id >= float2str(float(end) - eps)
and end_time_id <= float2str(float(end) + eps)
):
# print("{} {} {} {}".format(utt_id_tf, reco_id_sf, start_time_id, end_time_id))
# normalize transcript
transcript = normalize_transcript(transcript)
utt2spk_list.append(
Expand Down

0 comments on commit d537330

Please sign in to comment.