Skip to content

Commit

Permalink
if output is tuple like facebook/hf-seamless-m4t-medium, waveform is the first element (#29722)

* if output is tuple like facebook/hf-seamless-m4t-medium, waveform is the first element

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

* add test and fix batch issue

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

* add dict output support for seamless_m4t

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
  • Loading branch information
sywangyi authored Apr 5, 2024
1 parent 8b52fa6 commit 79d62b2
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3496,7 +3496,6 @@ def generate(
self.device
)
kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids

# second generation
unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech)
output_unit_ids = unit_ids.detach().clone()
Expand Down
5 changes: 4 additions & 1 deletion src/transformers/pipelines/pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,12 @@ def __next__(self):
# Try to infer the size of the batch
if isinstance(processed, torch.Tensor):
first_tensor = processed
elif isinstance(processed, tuple):
first_tensor = processed[0]
else:
key = list(processed.keys())[0]
first_tensor = processed[key]

if isinstance(first_tensor, list):
observed_batch_size = len(first_tensor)
else:
Expand All @@ -140,7 +143,7 @@ def __next__(self):
# elements.
self.loader_batch_size = observed_batch_size
# Setting internal index to unwrap the batch
self._loader_batch_data = processed
self._loader_batch_data = processed[0] if isinstance(processed, tuple) else processed
self._loader_batch_index = 0
return self.loader_batch_item()
else:
Expand Down
5 changes: 4 additions & 1 deletion src/transformers/pipelines/text_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,10 @@ def _sanitize_parameters(

def postprocess(self, waveform):
output_dict = {}

if isinstance(waveform, dict):
waveform = waveform["waveform"]
elif isinstance(waveform, tuple):
waveform = waveform[0]
output_dict["audio"] = waveform.cpu().float().numpy()
output_dict["sampling_rate"] = self.sampling_rate

Expand Down
21 changes: 21 additions & 0 deletions tests/pipelines/test_pipelines_text_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,27 @@ def test_small_musicgen_pt(self):
audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)

@slow
@require_torch
def test_medium_seamless_m4t_pt(self):
    """Smoke-test the text-to-audio pipeline with facebook/hf-seamless-m4t-medium.

    Covers single input, a two-example list, and explicit batching
    (batch_size=2), both with and without intermediate token ids.
    """
    speech_generator = pipeline(task="text-to-audio", model="facebook/hf-seamless-m4t-medium", framework="pt")

    param_variants = (
        {"tgt_lang": "eng"},
        {"return_intermediate_token_ids": True, "tgt_lang": "eng"},
    )
    for params in param_variants:
        # Single input: expect one dict holding the raw waveform and its rate.
        single = speech_generator("This is a test", forward_params=params)
        self.assertEqual({"audio": ANY(np.ndarray), "sampling_rate": 16000}, single)

        texts = ["This is a test", "This is a second test"]

        # Two examples side-by-side (no batching requested).
        paired = speech_generator(texts, forward_params=params)
        self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], [out["audio"] for out in paired])

        # Same inputs with batch_size=2 must still yield one output per input.
        batched = speech_generator(texts, forward_params=params, batch_size=2)
        self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], [out["audio"] for out in batched])

@slow
@require_torch
def test_small_bark_pt(self):
Expand Down

0 comments on commit 79d62b2

Please sign in to comment.