diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py index cb09895824..efc92a04ef 100644 --- a/TTS/tts/layers/xtts/stream_generator.py +++ b/TTS/tts/layers/xtts/stream_generator.py @@ -376,7 +376,7 @@ def generate( # noqa: PLR0911 elif is_sample_gen_mode: # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config) + logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device) # 12. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -401,7 +401,7 @@ def generate( # noqa: PLR0911 ) elif is_sample_gen_stream_mode: # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config) + logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device) # 12. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -463,7 +463,7 @@ def generate( # noqa: PLR0911 elif is_beam_sample_gen_mode: # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config) + logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device) if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") diff --git a/pyproject.toml b/pyproject.toml index fc748ff46b..a8c52fc176 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,7 +68,7 @@ dependencies = [ "gruut[de,es,fr]==2.2.3", # Tortoise "einops>=0.6.0", - "transformers>=4.41.1", + "transformers>=4.42.0", # Bark "encodec>=0.1.1", # XTTS