diff --git a/espnet2/tasks/gan_tts.py b/espnet2/tasks/gan_tts.py index bfdc343407f..d7eb4a0395f 100644 --- a/espnet2/tasks/gan_tts.py +++ b/espnet2/tasks/gan_tts.py @@ -221,7 +221,9 @@ def build_collate_fn( ]: assert check_argument_types() return CommonCollateFn( - float_pad_value=0.0, int_pad_value=0, not_sequence=["spembs", "sids"] + float_pad_value=0.0, + int_pad_value=0, + not_sequence=["spembs", "sids", "lids"], ) @classmethod @@ -260,10 +262,25 @@ def optional_data_names( cls, train: bool = True, inference: bool = False ) -> Tuple[str, ...]: if not inference: - retval = ("spembs", "sids", "durations", "pitch", "energy") + retval = ( + "spembs", + "durations", + "pitch", + "energy", + "sids", + "lids", + ) else: # Inference mode - retval = ("spembs", "sids", "speech", "durations", "pitch", "energy") + retval = ( + "spembs", + "speech", + "durations", + "pitch", + "energy", + "sids", + "lids", + ) return retval @classmethod diff --git a/espnet2/tasks/tts.py b/espnet2/tasks/tts.py index df21faf9365..ecaf981773f 100644 --- a/espnet2/tasks/tts.py +++ b/espnet2/tasks/tts.py @@ -260,10 +260,25 @@ def optional_data_names( cls, train: bool = True, inference: bool = False ) -> Tuple[str, ...]: if not inference: - retval = ("spembs", "durations", "pitch", "energy", "sids", "lids") + retval = ( + "spembs", + "durations", + "pitch", + "energy", + "sids", + "lids", + ) else: # Inference mode - retval = ("spembs", "speech", "durations", "sids", "lids") + retval = ( + "spembs", + "speech", + "durations", + "pitch", + "energy", + "sids", + "lids", + ) return retval @classmethod