
Add Text-To-Speech pipeline #24952

Merged: 37 commits, merged on Aug 17, 2023

Conversation

ylacombe (Contributor) commented Jul 20, 2023

What does this PR do?

Until recently, there was only one TTS model in Transformers. Recent (Bark) and upcoming (FastSpeechConformer2) additions have enriched, and will further enrich, the set of TTS models in Transformers.
This makes it a good time to add a text-to-speech pipeline to Transformers.

This PR tentatively proposes:

  • Adding a text-to-speech pipeline whose design can be adapted in line with future TTS additions (a usage sketch follows this list).
  • Adding an AutoModelForTextToSpeech class.
  • Adding a processor task to the pipeline code to facilitate use of the processor.
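For illustration, a minimal sketch of the intended usage (the task name and output format were still under discussion at this point, and the checkpoint is only an example):

from transformers import pipeline

# Hypothetical usage of the proposed TTS pipeline
synthesizer = pipeline("text-to-speech", model="suno/bark-small")
audio = synthesizer("Hello, my dog is cute")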

My conception of the architecture for now:

  • Backward compatibility with FastSpeechConformer2, retaining the ability to use its hacked generate_speech method.
  • Forward compatibility with future TTS models, counting on these models exposing a generate method to produce audio.
  • Possible compatibility with other TTA (text-to-audio) models such as MusicGen.

What I'm counting on:

  • Future models should have a generate method, even if they are not AR models per se (for the moment, FastSpeechConformer2 is not AR and has no such method) or count on an additional head model (FastSpeechConformer2 needs a vocoder on top to go from a spectrogram to audio - see discussion here).
  • Future models will use a Processor even if they only need a tokenizer, to allow easy use of other conditional inputs such as audio or speaker embeddings. The processor must also be added to PROCESSOR_MAPPING (not the case for MusicGen atm). A sketch of this Processor-based flow follows below.
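As a concrete example of the Processor-based flow assumed above, using Bark (the voice preset name is illustrative):

from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small")

# The processor prepares both the text and the optional speaker embeddings
# (called "history_prompt" for Bark), so a pipeline could forward its output as is
inputs = processor("Hello, my dog is cute", voice_preset="v2/en_speaker_6")
audio = model.generate(**inputs)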

I'm open to further discuss the architecture and to make some changes!

EDIT: for reference, I've made another design choice following internal discussions. It is discussed here.

Fixes #22487

Note: I was inspired by @LysandreJik's draft of a TTS pipeline.

Before submitting

Who can review?

Hey @sanchit-gandhi and @Narsil, I think you're the right people to talk to before the core review!

ylacombe (Contributor Author) left a comment:

I've added some remarks on my code; I'm open to discussing it!

Comment on lines 1005 to 1012
MODEL_FOR_TEXT_TO_SPEECH_MAPPING_NAMES = OrderedDict(
    [
        # Model for Text-To-Speech mapping
        ("bark", "BarkModel"),
        ("speecht5", "SpeechT5ForTextToSpeech"),
    ]
)

Contributor Author:

We could add MusicGen if its processor is added to PROCESSOR_MAPPING.

@@ -156,6 +160,7 @@
"sentiment-analysis": "text-classification",
"ner": "token-classification",
"vqa": "visual-question-answering",
"text-to-audio": "text-to-speech",
Contributor Author:

The pipeline can also be used as a text-to-audio pipeline!
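For illustration, with the alias above both task names resolve to the same pipeline (a sketch; the alias direction was flipped later in this PR):

from transformers import pipeline

tts = pipeline("text-to-speech", model="suno/bark-small")
tta = pipeline("text-to-audio", model="suno/bark-small")  # same underlying pipeline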

Contributor:

I would actually put it the other way around: text-to-audio is more general, which we tend to prefer.

That way, when the audio produced becomes music, it still works.

Contributor:

@osanseviero FYI, do we already have a task name for this?

Contributor Author:

I will put it the other way around!

Contributor:

So far we've said text-to-speech (https://huggingface.co/models?pipeline_tag=text-to-speech&sort=trending); we don't have a more general task name.

from transformers import Pipeline, SpeechT5HifiGan


ONLY_ONE_SPEAKER_EMBEDDINGS_LIST = ["bark"]
Contributor Author:

Some models can only be used with one speaker embedding, even when batching.


ONLY_ONE_SPEAKER_EMBEDDINGS_LIST = ["bark"]

SPEAKER_EMBEDDINGS_KEY_MAPPING = {"bark": "history_prompt"}
Contributor Author:

Bark's speaker embeddings attribute name was not consistent throughout the input -> Pipeline -> Model pass.

Contributor Author:

It could be the case for other models as well.

Comment on lines 60 to 66
elif self.is_speecht5 and isinstance(vocoder, str):
    vocoder = SpeechT5HifiGan.from_pretrained(vocoder).to(self.model.device)
elif self.is_speecht5 and not isinstance(vocoder, SpeechT5HifiGan):
    raise ValueError(
        """Must pass a valid vocoder to the TTSPipeline if speecht5 is used.
        Try passing a repo_id or an instance of SpeechT5HifiGan."""
    )
Contributor Author:

The vocoder should only be used with SpeechT5. Other models should maybe have a specific class with a vocoder head.
WDYT?
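For context, a sketch of the SpeechT5 flow being special-cased here (the zero tensor is a placeholder; real x-vectors come from a speaker embedding dataset):

import torch
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

inputs = processor(text="Hello, my dog is cute", return_tensors="pt")
speaker_embeddings = torch.zeros(1, 512)  # placeholder x-vector

# generate_speech returns a spectrogram, unless a vocoder is passed,
# in which case it returns the waveform directly
speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)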

Comment on lines 89 to 96
def preprocess(self, text, speaker_embeddings=None, **kwargs):
    if self.is_speecht5:
        inputs = self.processor(text=text, return_tensors="pt")
        if speaker_embeddings is None:
            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)

        return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings}
Contributor Author:

For now, only speaker_embeddings is taken into account, but future models could have other conditional inputs.

Comment on lines 111 to 129
def get_test_pipeline(self, model, tokenizer, processor):
    speech_generator = TextToSpeechPipeline(model=model, tokenizer=tokenizer)
    return speech_generator, ["This is a test", "Another test"]

def run_pipeline_test(self, speech_generator, _):
    outputs = speech_generator("This is a test")
    self.assertEqual(
        outputs,
        ANY(torch.Tensor),
    )

    outputs = speech_generator(["This is great !", "Something else"], num_return_sequences=2, do_sample=True)
    self.assertEqual(
        outputs,
        [
            ANY(torch.Tensor),
            ANY(torch.Tensor),
        ],
    )
Contributor Author:

This is not tested as long as there are no TTS models on https://huggingface.co/hf-internal-testing.

What should I do?

ydshieh (Collaborator):

I can run the script to create them. So it's the Bark model right?

Contributor Author:

Hi @ydshieh, yes indeed! Many thanks!

HuggingFaceDocBuilderDev commented Jul 20, 2023:

The documentation is not available anymore as the PR was closed or merged.

Narsil (Contributor) left a comment:

Thanks for the PR.

It's already very nice work.

Biggest things:

  • REMOVE everything related to Processor; a pipeline cannot rely on any object which doesn't have a standard API (self.tokenizer(text) should always work, which is not the case for a processor).
  • REMOVE every instance of self.xxx = xxx; pipelines are stateless, and every arg should be in _sanitize_parameters and expected to be received as a named argument.
  • REMOVE any model-specific code. Pipelines should be model-agnostic and rely on invariants of the underlying classes. If those invariants are not correct, we should fix those classes (generate_speech for instance).


@@ -398,6 +410,7 @@
NO_FEATURE_EXTRACTOR_TASKS = set()
NO_IMAGE_PROCESSOR_TASKS = set()
NO_TOKENIZER_TASKS = set()
NO_PROCESSOR_TASKS = set()
Contributor:

Suggested change: remove the line NO_PROCESSOR_TASKS = set()

Contributor:

Pipelines never use a processor, and they cannot.
Processors are handy tools, but they just do too much and have no stable API: you can send text, images, audio and whatnot, with no way for the caller to know what the correct contract is (and inspecting the signature is kind of a sin).

FeatureExtractor and Tokenizer should be enough I think.

sanchit-gandhi (Contributor) commented Jul 24, 2023:

FeatureExtractor and Tokenizer should be enough I think

The Tokenizer is required for all models to pre-process the text inputs to input ids. The FeatureExtractor allows us to convert an audio prompt to a log-mel spectrogram, so is only used for particular models (not Bark for instance).

Neither prepares the official speaker embeddings; that is only done by the processor (as discussed in the Bark PR).

So the design hurdle here is figuring out how we can prepare the speaker embeddings within the pipeline without the processor class. I feel quite strongly that the pipeline should be able to handle any speaker embeddings internally: if the user has to prepare the speaker embeddings outside of the pipeline, the complexity of using the API is more or less the same as using the model + processor, so there's no real point in switching to the pipeline.

@@ -407,11 +420,18 @@
if values["type"] == "text":
NO_FEATURE_EXTRACTOR_TASKS.add(task)
NO_IMAGE_PROCESSOR_TASKS.add(task)
NO_PROCESSOR_TASKS.add(task)
Contributor:

Suggested change: remove the line NO_PROCESSOR_TASKS.add(task)

@@ -800,6 +823,7 @@ def pipeline(
load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
load_image_processor = type(model_config) in IMAGE_PROCESSOR_MAPPING or image_processor is not None
load_processor = type(model_config) in PROCESSOR_MAPPING or processor is not None
Contributor:

Suggested change: remove the line load_processor = type(model_config) in PROCESSOR_MAPPING or processor is not None

Comment on lines 117 to 135
if self.is_speecht5:
    inputs = model_inputs["input_ids"]

    speaker_embeddings = model_inputs["speaker_embeddings"]

    with torch.no_grad():
        speech = self.model.generate_speech(inputs, speaker_embeddings, vocoder=self.vocoder)
else:
    if self.only_one_speaker_embeddings:
        speaker_embeddings_key = SPEAKER_EMBEDDINGS_KEY_MAPPING.get(self.model_type, "speaker_embeddings")

        # check batch_size > 1
        if len(model_inputs["input_ids"]) > 1 and model_inputs.get(speaker_embeddings_key, None) is not None:
            model_inputs[speaker_embeddings_key] = model_inputs[speaker_embeddings_key][0]

    with torch.no_grad():
        speech = self.model.generate(**model_inputs, **kwargs)

return speech
Contributor:

Suggested change: replace the block above with

    speech = self.model.generate(**model_inputs)
    return speech

torch.no_grad is already done for you.
You cannot really inspect the tensors, since batching might be occurring for you.
self.vocoder is not allowed (since storing state on self isn't).

This function should be model-agnostic. Every model that implements AutoModelForAudioXX (or something along those lines) should have the same API, and that's what the pipeline should rely on.

Contributor Author:

Hi @Narsil, regarding:

You cannot really inspect the tensors, since batching might be occurring for you.

Batching is taken care of for us, but it forces additional inputs (for example, speaker_embeddings) to be repeated, which is not the expected behavior.
In that specific case, you have to inspect the tensors to know whether batching occurred.

sanchit-gandhi (Contributor) commented Jul 24, 2023:

This can boil down to one if/else depending on whether the model does auto-regressive generation or a single forward pass:

if self.model.can_generate:
    speech = self.model.generate(**model_inputs)
else:
    speech = self.model(**model_inputs)

This is somewhat similar to the differentiation we make between CTC and Seq2Seq models in the ASR pipeline:

if self.type in {"seq2seq", "seq2seq_whisper"}:

Let's try and handle the vocoder in the SpeechT5 modelling code, and the speaker embeddings outside of this function (e.g. in the processor)


def _sanitize_parameters(
    self,
    **generate_kwargs,
Contributor:

Do not pass generate_kwargs as is.

We're having lots of issues on other pipelines that did that.

Either whitelist the subset you want to allow or send a raw dict:

pipeline(..., generate_kwargs={...})

The issue with doing that is that you don't know which params are or are not used, and at any point generate might add support for new kwargs that could clash with some in the pipeline.

This is already the case with max_length, which exists for generate but also for tokenization, with a different meaning.
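A sketch of the raw-dict option (whether the final pipeline exposes generate_kwargs exactly this way is an assumption at this point in the review):

from transformers import pipeline

pipe = pipeline("text-to-speech", model="suno/bark-small")
# Generation arguments stay namespaced, so they cannot clash with
# tokenization kwargs such as max_length
audio = pipe("Hello!", generate_kwargs={"do_sample": True, "temperature": 0.7})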

    return preprocess_params, forward_params, postprocess_params

def postprocess(self, speech):
    return speech
Contributor:

Pipelines should return an object usable by non-ML people.

I don't know what this object is from this code, but I'm pretty sure it's a tensor with no further information.

We can return a numpy array, but we need some way to communicate to the user how the audio is encoded so that they can save it as a file for instance: f32le and the sampling rate at least.
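For instance, a self-describing return value along these lines would be enough for a user to save a file (the key names are an assumption at this point in the discussion):

import numpy as np
import scipy.io.wavfile

# Dummy pipeline output carrying both the raw f32le samples and the rate
output = {"audio": np.zeros(16000, dtype=np.float32), "sampling_rate": 16000}

scipy.io.wavfile.write("speech.wav", rate=output["sampling_rate"], data=output["audio"])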

ylacombe (Contributor Author) commented Jul 21, 2023:

@Narsil, I actually added a self.sampling_rate attribute in __init__ to allow the user to call pipe.sampling_rate if ever needed to generate or save audio!

I'm not sure about other ways to do it, do you have ideas?

(BTW, I will modify it to return an np array!)

@require_torch
def test_small_model_pt(self):
    speech_generator = pipeline(
        task="text-to-speech", model="microsoft/speecht5_tts", framework="pt", vocoder="microsoft/speecht5_hifigan"
Contributor:

It's super dubious that we need 2 distinct models to run the pipeline.

I don't know how to fix it, but one model should know about the other, or one should be a sound default for all models.

ylacombe (Contributor Author) commented Jul 21, 2023:

This is a specificity of speechT5 that we shouldn't allow again for new TTS models, and that's why there are some special cases in my code!

I still need to address how to deal with speechT5!

Comment on lines 85 to 87
self.assertEqual(
    outputs,
    ANY(torch.Tensor),
Contributor:

This test is testing a specific model; its output is deterministic, and therefore we should actually check the output values.

There are some helpers to round the values to account for small uninteresting differences.

Contributor Author:

The output is deterministic, but it's an audio array, so it is way too long to check all the output values. I can still do it with the first few values though!
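For instance, a sketch of checking only the first few samples (the values below are placeholders, not taken from the model):

import numpy as np

audio = np.array([0.00038, -0.00012, 0.00051, 0.00009])  # stand-in for the generated waveform
expected = [0.00038, -0.00012, 0.00051]  # placeholder reference values
np.testing.assert_allclose(audio[:3], expected, atol=1e-4)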

ylacombe (Contributor Author):

Hi @Narsil, thanks for your fast review!
Basically, I will refactor my code to meet your expectations!
There are still 2 things I'd like to discuss first, which I talked about in the comments:

  1. speechT5 specific case: speechT5 was introduced 5 months ago and has two issues: it uses a .generate_speech method instead of a .generate, and it needs an additional vocoder on top in order to actually produce audio signals instead of a spectrogram. What's the best way to stay consistent with it and with the pipeline logic? Should I still introduce model-specific code, or should I work on modifying speechT5 instead? Modifying speechT5 might be problematic since it was the first TTS model introduced, so users might be used to its API, and because it would leave BarkModel as the only TTS model supported in the pipeline for a short time.
  2. speaker_embeddings and other Processor-related utilities: how can we stay consistent with the library while keeping some of the benefits of the Processor, and continue to use speaker embeddings in an easy way? I fear it might add unnecessary difficulty for users to forward speaker_embeddings arguments. WDYT?

Anyways, many thanks again for the review!

Narsil (Contributor) commented Jul 21, 2023:

Hi @Narsil, thanks for your fast review! Basically, I will refactor my code to meet your expectations! There are still 2 things I'd like to discuss first, which I talked about in the comments:

1. **`speechT5` specific case:** `speechT5` was introduced 5 months ago and has two issues: it uses a `.generate_speech` method instead of a `.generate`, and it needs an additional vocoder on top in order to actually produce audio signals instead of a spectrogram. What's the best way to stay consistent with it and with the pipeline logic? Should I still introduce model-specific code, or should I work on modifying `speechT5` instead? Modifying `speechT5` might be problematic since it was the first TTS model introduced, so users might be used to its API, and because it would leave `BarkModel` as the only TTS model supported in the pipeline for a short time.

I will let maintainers focusing on audio answer that, @sanchit-gandhi I think.
But what I do know is that not relying on invariants within transformers makes pipelines play a never-ending game of catch-up for every model thrown into the mix. Pipelines see AutoModelFor* classes, which should have a consistent API we can rely on.

I remember talks about splitting generate and generate_speech to allow differentiation between the two.

For the vocoder, I don't know how, but it should be invisible to users.
In ASR we've had ngrams added to the configuration for instance, which makes them loadable automatically.

2. **`speaker_embeddings`** and other `Processor`-related utilities: how can we stay consistent with the library while keeping some of the benefits of the Processor, and continue to use speaker embeddings in an easy way? I fear it might add unnecessary difficulty for users to forward `speaker_embeddings` arguments. WDYT?

Again, there might already be solutions.
But loading random data from a random dataset within preprocess is not really sustainable.

My suggestion to put this in user code alleviates that constraint.

But in general, having speaker_embeddings for TTS should always be purely optional imo.

Anyways, many thanks again for the review!

ylacombe (Contributor Author):

Thanks @Narsil, I will wait for @sanchit-gandhi's opinion on it then!
What about this comment?

nsarrazin added a commit to huggingface/huggingface.js that referenced this pull request Jul 27, 2023
The models I hardcoded were outdated for certain tools so I removed
them. Now the tools will use the model suggestions from the API instead
so we're always up to date.

I left the hardcoded model on `textToSpeech` since the suggested models
in the API don't work for now. (See [this
pr](huggingface/transformers#24952))
ylacombe (Contributor Author) left a comment:

Hi @Narsil, @sanchit-gandhi, and everyone interested in the matter!

Following an internal discussion with @Narsil, I've completely reworked the TTS pipeline to best suit the current Pipelines paradigm.

Among other things, here are the principal modifications:

  • Refactor Text-To-Speech as Text-To-Audio, since it is more general. Text-To-Speech can still be used as an alias.
  • Introduce two types of TTS models, text-to-waveform and text-to-spectrogram. This better suits the way TTS models can be classified, and makes it easy to introduce a vocoder for spectrogram models into the pipeline (see the sketch after this list).
  • Completely remove the use of Processors and only use Tokenizers. I will discuss this in further detail below.
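A minimal sketch of the dispatch this classification enables (function and variable names are illustrative, not the PR's actual code):

from typing import Callable, Optional
import numpy as np

def synthesize(generate: Callable[[], np.ndarray], vocoder: Optional[Callable] = None) -> np.ndarray:
    # Text-to-spectrogram models need a vocoder to reach a waveform;
    # text-to-waveform models return audio directly
    output = generate()
    return vocoder(output) if vocoder is not None else output

# Toy usage: a "spectrogram model" plus an identity "vocoder"
waveform = synthesize(lambda: np.zeros(100, dtype=np.float32), vocoder=lambda s: s)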

Regarding the use of Tokenizers instead of Processors, I think it is still worth discussing for this pipeline.

Pros:

  • Processors can mix all kinds of inputs and are therefore very difficult to maintain and generalize. In particular, they would introduce a lot of complexity and boundary cases. A simple example of this complexity is that some models simply don't use a Processor (but only a Tokenizer - e.g. the upcoming VITS), other models use a Processor with text only and optional speaker embeddings (e.g. Bark), and other models use a Processor for all types of input (e.g. Musicgen with text and audio). Another example of added complexity is that certain models can only take one speaker embedding per batch. Since the pipeline automatically assumes that inputs must be stacked, this feature adds further complexity and if-then conditions; in addition, it requires keeping a list of the models that behave this way. A final example is that speaker embeddings already come in a variety of types (dict, tensor, etc.). New models could probably use other variable types, which would require even more pipeline maintenance.
  • On the other hand, simply using Tokenizers is straightforward and fits well with the text-to-audio pipeline name.
  • Pipelines are assumed to be general and stateless, meaning that a user shouldn't have to worry about which parameters to use and how they're called, nor about how a specific model works under the hood and why only certain models take speaker embeddings as input.

Cons:

  • If users want to use a conditional input parameter such as a speaker embedding or a conditioning audio, they have to figure out how to call it and use it on their own.
  • The pipeline is reduced to its bare minimum (which can also be seen as an advantage). Users may be disappointed, and this may reduce usage.

An example of how to use a speaker embedding for Bark is highlighted in the test_pipelines_text_to_audio.py comments.

I'd personally be in favor of this additional complexity, as users would prefer to easily use the particularities of TTS models, which looks like it could be of greater interest than simply generating bare audio. But for now I've reworked the code aiming for simplicity, unlike the first version of the pipeline, in order to have a working POC and to continue iterating if necessary.

Either way, it also requires modifying the current TTS models and thinking of API architecture guidelines for future TTS models.

In terms of current models:

  • SpeechT5ForTextToSpeech.generate_speech must be refactored to generate, in order to be usable in the current pipeline.
  • TTS/TTA models that use a Processor should also be usable by calling AutoTokenizer. For some models (e.g. Musicgen) it's straightforward, simply requiring a line in tokenization_auto.py (a sketch of such a line follows this list); for others it requires additional rethinking (see the comment about Bark in the test_pipelines_text_to_audio.py comments below).
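The kind of one-line registration meant for Musicgen, whose text side uses a T5 tokenizer (a sketch; the exact entry must follow the TOKENIZER_MAPPING_NAMES conventions in tokenization_auto.py):

# Entry mapping the model type to its (slow, fast) tokenizer classes
musicgen_entry = ("musicgen", ("T5Tokenizer", "T5TokenizerFast"))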

At the moment, you can try the pipeline with Bark and Musicgen following the examples in test_pipelines_text_to_audio.py.

I would be happy to get your opinion on the matter and to further discuss how to solve the current limitations.

Comment on lines 55 to 71
self.sampling_rate = sampling_rate
if self.vocoder is not None:
    self.sampling_rate = self.vocoder.config.sampling_rate

if self.sampling_rate is None:
    # get sampling_rate from config and generation config
    self.sampling_rate = None

    config = self.model.config.to_dict()
    gen_config = self.model.__dict__.get("generation_config", None)
    if gen_config is not None:
        config.update(gen_config.to_dict())

    for sampling_rate_name in ["sample_rate", "sampling_rate"]:
        sampling_rate = config.get(sampling_rate_name, None)
        if sampling_rate is not None:
            self.sampling_rate = sampling_rate
Contributor Author:

sampling_rate is a must-have for a complete audio signal, since it is absolutely required to play back or save the audio corresponding to a waveform.

Thus this code snippet tries to get the sampling rate from the vocoder if there is one, and otherwise from the model config or generation config.

Contributor:

Agree that it's a necessity! Good with me to check in the config / generation config for an appropriate value

Comment on lines 101 to 107
preprocess_params = {
    "max_length": 256,
    "add_special_tokens": False,
    "return_attention_mask": True,
    "return_token_type_ids": False,
    "padding": "max_length",
}
Contributor Author:

At the moment, one of the limitations of the TTS pipeline is that the user must do some of the operations that were initially done in the Processor.

For example, BarkProcessor passes these preprocess_params under the hood.

I'm not sure yet how to best correct this behaviour, which adds unnecessary and unclear difficulty for the user. Some models can't be used without those parameters.

Contributor Author:

I guess one way to solve it for Bark would be to create a BarkTokenizer which calls BertTokenizer under the hood and sets these parameters. A sketch of this idea follows.
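A sketch of that idea (BarkTokenizer here is a hypothetical class, written for illustration):

from transformers import BertTokenizer

class BarkTokenizer(BertTokenizer):
    # Hypothetical tokenizer baking in the kwargs BarkProcessor normally sets
    def __call__(self, text, **kwargs):
        defaults = {
            "max_length": 256,
            "add_special_tokens": False,
            "return_attention_mask": True,
            "return_token_type_ids": False,
            "padding": "max_length",
        }
        defaults.update(kwargs)  # explicit user kwargs take priority
        return super().__call__(text, **defaults)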

Comment on lines 158 to 164
processor = AutoProcessor.from_pretrained("suno/bark-small")

temp_inp = processor("hey, how are you?", voice_preset="v2/en_speaker_5")

history_prompt = temp_inp["history_prompt"]

forward_params["history_prompt"] = history_prompt
Contributor Author:

Another limitation of the current TTS pipeline is the way speaker embeddings, or any additional conditional input parameter, must be passed.

Here, to use Bark speaker embeddings, one must first do a blank call to BarkProcessor.

Narsil (Contributor) left a comment:

This looks MUCH better I think, no?

self.sampling_rate is not ideal, but since I see it flows from the vocoder it's perfectly fine (there's no way the user can override it at query time anyway).

ydshieh (Collaborator) commented Aug 3, 2023:

I haven't generated the tiny models for bark yet. I will do it today 🙏 (I can't guarantee they will be generated smoothly; usually they should already be on the Hub, and if not, it means the creation process has some issue for this model).

ydshieh (Collaborator) commented Aug 3, 2023:

Hi @ylacombe. There are a few issues that block the creation of the tiny model (of bark, for pipeline testing).

The first one is that tests/models/bark/test_modeling_bark.py has no BarkModelTest and BarkModelTester, only ones for the component models (fine, coarse, semantic).

Are those component models also used as standalone models? Or are they really just components of BarkModel, with users expected to use BarkModel rather than those components?

More importantly, which model types does the pipeline implemented in this PR need?

Thanks in advance.

ylacombe (Contributor Author) commented Aug 3, 2023:

Hi @ydshieh, thanks for your help on the matter!

There's only BarkModelIntegrationTests for now. The other sub-models are used as components for BarkModel. Users are expected to use BarkModel, which will ultimately be used in the pipeline.

Let me know if I can help you with anything!

ylacombe (Contributor Author) left a comment:

Hi @sgugger, @sanchit-gandhi and @Narsil, this PR is ready for a final review!

There haven't been many changes since my last set of commits, except to respond to @sanchit-gandhi's feedback and to find an easier way to use Bark, as discussed in the comment below!

Many thanks!

Comment on lines +78 to +91
if self.model.config.model_type == "bark":
    # bark Tokenizer is called with BarkProcessor which uses those kwargs
    new_kwargs = {
        "max_length": self.model.generation_config.semantic_config.get("max_input_semantic_length", 256),
        "add_special_tokens": False,
        "return_attention_mask": True,
        "return_token_type_ids": False,
        "padding": "max_length",
    }

    # priority is given to kwargs
    new_kwargs.update(kwargs)

    kwargs = new_kwargs
Contributor Author:

This was done to address the limitation discussed in this comment. It can still be overridden by users if they pass the kwargs values they want, so it is not restrictive.

At the end of the day, the solution mentioned here doesn't work, because the official Bark repositories set "tokenizer_class": "BertTokenizer" in tokenizer_config.json.
It would thus be a problem for people using v4.31 if I changed this to a future BarkTokenizer, so that solution is not possible.

Contributor:

Fine for me: since we can't fix it for the model / on the Hub, this is the only way to use the tokenizer for Bark.

sgugger (Collaborator) left a comment:

Thanks for working on this! Let's keep the code compact in terms of vertical space in the test file please ;-)

@@ -0,0 +1,151 @@
from typing import List, Union

from transformers import Pipeline, SpeechT5HifiGan
Collaborator:

Should be relative imports.

Comment on lines 110 to 114
def __call__(
    self,
    text_inputs: Union[str, List[str]],
    **forward_params,
):
Collaborator:

Nit: fits in one line.

Comment on lines 59 to 62
self.assertEqual(
    ANY(np.ndarray),
    audio,
)
Collaborator:

Fits in one line.

Comment on lines 143 to 151
# test using a speaker embedding

processor = AutoProcessor.from_pretrained("suno/bark-small")

temp_inp = processor("hey, how are you?", voice_preset="v2/en_speaker_5")

history_prompt = temp_inp["history_prompt"]

forward_params["history_prompt"] = history_prompt
Collaborator:

Those are too many new lines. Only put new lines when it makes sense to separate blocks of code, please.

ANY(np.ndarray),
],
audio,
)
Collaborator:

Fits in one line.

Comment on lines 217 to 220
self.assertEqual(
    ANY(np.ndarray),
    outputs["audio"],
)
Collaborator:

Fits in one line.

Comment on lines 222 to 225
forward_params = {
    "num_return_sequences": 2,
    "do_sample": True,
}
Collaborator:

Fits in one line.

Comment on lines 230 to 236
self.assertEqual(
    [
        ANY(np.ndarray),
        ANY(np.ndarray),
    ],
    audio,
)
Collaborator:

Fits in one line.

sanchit-gandhi (Contributor) left a comment:

LGTM, looking clean! Thanks for iterating on this @ylacombe. Would love a final review from @Narsil to confirm one minor design decision before merging.

Successfully merging this pull request may close these issues.

Support text-to-speech in pipeline function and in Optimum
7 participants