refactor(stream_generator): update special tokens for transformers>=4.41.1

Fixes #31. The handling of special tokens in `transformers` was changed in
huggingface/transformers#30624 and
huggingface/transformers#30746. This updates the XTTS
streaming code accordingly.
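
For context, a minimal sketch of the new behavior, assuming transformers>=4.41.1. GPT-2 and its eos id 50256 are stand-ins for the XTTS GPT backbone, and `_prepare_special_tokens` is a private transformers helper, so this mirrors the diff below rather than a stable public API:

# Sketch only, not part of the commit.
from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("gpt2")  # stand-in for the XTTS GPT
generation_config = GenerationConfig(eos_token_id=50256)  # pad_token_id left unset

# One call replaces the hand-rolled block removed below: it validates the
# special tokens, applies the pad_token_id -> eos_token_id fallback (logging
# the usual open-ended-generation warning), and stores the ids as tensors on
# the requested device.
model._prepare_special_tokens(generation_config, kwargs_has_attention_mask=False, device="cpu")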
eginhard committed Jun 16, 2024
1 parent 063e9e9 commit 47aa14a
Showing 2 changed files with 6 additions and 14 deletions.
TTS/tts/layers/xtts/stream_generator.py (5 additions, 13 deletions)
@@ -151,18 +151,7 @@ def generate(  # noqa: PLR0911
         # 2. Set generation parameters if not already defined
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-
-        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
-            if model_kwargs.get("attention_mask", None) is None:
-                logger.warning(
-                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
-                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
-                )
-            eos_token_id = generation_config.eos_token_id
-            if isinstance(eos_token_id, list):
-                eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
-            generation_config.pad_token_id = eos_token_id
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
 
         # 3. Define model inputs
         # inputs_tensor has to be defined
@@ -174,6 +163,9 @@ def generate(  # noqa: PLR0911
         )
         batch_size = inputs_tensor.shape[0]
 
+        device = inputs_tensor.device
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
         # 4. Define other model kwargs
         model_kwargs["output_attentions"] = generation_config.output_attentions
         model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
@@ -182,7 +174,7 @@ def generate(  # noqa: PLR0911
         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs
 
-        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
+        if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                 inputs_tensor,
                 generation_config.pad_token_id,
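
The patched path is driven by XTTS streaming inference. A usage sketch, with the checkpoint paths and reference clip as placeholders (entry points as documented for XTTS in this repository):

import torch

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Placeholder paths: point these at a local XTTS-v2 checkpoint directory.
config = XttsConfig()
config.load_json("/path/to/xtts/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)

# Conditioning latents from a short reference recording (placeholder file).
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])

# inference_stream() runs the generate() loop in stream_generator.py above.
chunks = model.inference_stream(
    "It took me quite a long time to develop a voice.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
)
wav = torch.cat(list(chunks), dim=0)  # each chunk is a 1-D torch.Tensor of audio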
pyproject.toml (1 addition, 1 deletion)

@@ -69,7 +69,7 @@ dependencies = [
"gruut[de,es,fr]==2.2.3",
# Tortoise
"einops>=0.6.0",
"transformers>=4.33.0,<4.41.0",
"transformers>=4.41.1",
# Bark
"encodec>=0.1.1",
# XTTS
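
Older transformers releases still use the pre-#30624 special-token handling, so the new lower bound is a hard requirement. A hypothetical fail-fast guard (not part of this commit) that a downstream script could add:

from packaging.version import Version

import transformers

if Version(transformers.__version__) < Version("4.41.1"):
    raise RuntimeError(
        f"transformers {transformers.__version__} is installed, but the patched "
        "XTTS streaming code requires transformers>=4.41.1 (see issue #31)."
    )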
