Poor Audio Quality with input_values Input in Parler_TTS #81
Comments
we should remove |
Change the code after this comment: "# revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask"

if "input_values" in model_kwargs:
    mask = (output_ids != generation_config.bos_token_id) & (output_ids != generation_config.pad_token_id)
else:
    _, mask = self.decoder.build_delay_pattern_mask(
        input_ids,
        bos_token_id=generation_config.bos_token_id,
        pad_token_id=generation_config.pad_token_id,
        max_length=output_ids.shape[1],
    )
    mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)

I haven't looked into the details, but for now it works. I found this bug by comparing the output_ids with the original input_ids encoded by DAC: there are some wrong delays in output_ids.
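To see why filtering the special tokens directly out of output_ids can recover the codes, here is a toy, pure-Python sketch of a delay pattern and its reversal. It is not the Parler-TTS implementation: the token ids (BOS, PAD), the helper names, and the exact shift of k + 1 steps per codebook are assumptions made for illustration, and the library's real offsets may differ.

```python
# Hypothetical special token ids, chosen so they never collide with real codes.
BOS, PAD = 1025, 1024


def apply_delay_pattern(codes, bos=BOS, pad=PAD):
    """Shift codebook k right by k + 1 steps: BOS before the codes, PAD after."""
    num_codebooks = len(codes)
    num_frames = len(codes[0])
    total = num_frames + num_codebooks  # length of every delayed row
    delayed = []
    for k, row in enumerate(codes):
        lead = [bos] * (k + 1)
        tail = [pad] * (total - num_frames - (k + 1))
        delayed.append(lead + list(row) + tail)
    return delayed


def revert_delay_pattern(delayed, bos=BOS, pad=PAD):
    """Undo the delay by keeping only tokens that are neither BOS nor PAD."""
    return [[tok for tok in row if tok != bos and tok != pad] for row in delayed]


# Three codebooks, four frames each (made-up values).
codes = [[11, 12, 13, 14], [21, 22, 23, 24], [31, 32, 33, 34]]
delayed = apply_delay_pattern(codes)
restored = revert_delay_pattern(delayed)
```

In this sketch the reversal depends only on where BOS and PAD actually sit in the delayed sequence, not on a mask rebuilt from a separate input. That matches the idea of the change above: when input_values is passed, the mask is computed from output_ids itself.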
Credits: 1. ylacombe - Add input_values to DACModel - dac_wrapper/modeling_dac.py - huggingface#110 (comment) 2. stg2015 - Delay mask adjustment for input_values - modeling_parler_tts.py - huggingface#81 (comment)
* Prep for Voice Steering feature
  Credits:
  1. ylacombe - Add input_values to DACModel - dac_wrapper/modeling_dac.py - #110 (comment)
  2. stg2015 - Delay mask adjustment for input_values - modeling_parler_tts.py - #81 (comment)
* Prep for voice steering/cloning w/ fix for non-streaming generation
* Applied simpler input handling per Guppy16's suggestion
* Applied Guppy16's suggested optimization
* Applied Guppy17's suggested optimization for voice steering
* Update parler_tts/modeling_parler_tts.py
---------
Co-authored-by: apresence <apresence@gmail.com>
Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
I am using the Parler_TTS model with a reference audio (input_values) during inference, similar to MusicGen, to perform continuation tasks:
model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids, input_values=input_values)
While the model continues in the style of the reference audio, the resulting audio quality is poor and contains a lot of noise.
Why does the audio quality degrade when using a reference audio input, and how can this be improved?
Thank you!