diff --git a/ultravox/tools/ds_tool.py b/ultravox/tools/ds_tool.py index dbb148fd..b4d47c05 100644 --- a/ultravox/tools/ds_tool.py +++ b/ultravox/tools/ds_tool.py @@ -8,8 +8,8 @@ from ultravox.tools import tts -chat_client = openai.Client() -tts_client = tts.AzureTts() +tts_client: tts.Client +chat_client: openai.Client DEFAULT_TEXTGEN_TEMPLATE = """Passage: {passage} @@ -22,14 +22,18 @@ @dataclasses.dataclass class TtsTask: + implementation: str = simple_parsing.field(default="azure", alias="-i") column_name: str = simple_parsing.field(default="question", alias="-c") audio_column_name: Optional[str] = simple_parsing.field(default=None, alias="-a") voice: Optional[str] = simple_parsing.field(default=None, alias="-V") sample_rate: int = simple_parsing.field(default=16000, alias="-r") def __post_init__(self): + # The TTS client is separate from the task to avoid pickling issues when multiprocessing. + global tts_client if self.audio_column_name is None: self.audio_column_name = f"{self.column_name}_audio" + tts_client = tts.create_client(self.implementation, self.sample_rate) def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Dataset: print(f'TTS mapping "{self.column_name}" to "{self.audio_column_name}"...') @@ -40,7 +44,7 @@ def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Datas def _map_sample(self, sample): text = sample[self.column_name] text = text["text"] if isinstance(text, dict) else text - sample[self.audio_column_name] = tts_client.tts(text) + sample[self.audio_column_name] = tts_client.tts(text, self.voice) return sample @@ -50,10 +54,15 @@ class TextGenerationTask: template: str = simple_parsing.field(default=DEFAULT_TEXTGEN_TEMPLATE, alias="-T") language_model: str = simple_parsing.field(default="gpt-4o", alias="-m") + base_url: Optional[str] = simple_parsing.field(default=None, alias="-b") + api_key: Optional[str] = simple_parsing.field(default=None, alias="-k") max_tokens: int = 128 temperature: float = 0 def __post_init__(self): + # The OAI client is separate from the task to avoid pickling issues when multiprocessing. + global chat_client + chat_client = openai.Client(base_url=self.base_url, api_key=self.api_key) if self.template.startswith("@"): with open(self.template[1:], "r") as template_file: self.template = template_file.read() @@ -75,9 +84,10 @@ def _map_sample(self, sample): # This script is used to either generate audio samples from text using a TTS model, or to generate text samples using a text generation model. -# Ex: just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -c question -a audio --token $HF_WRITE_TOKEN -# Ex: just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/boolq-audio -c explanation -# Ex: just ds_tool textgen -d ylacombe/expresso -u fixie-ai/expresso -c continuation -T @expresso_template.txt +# Example usages: +# just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -c question -a audio --token $HF_WRITE_TOKEN +# just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -c explanation -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct +# just ds_tool textgen -d ylacombe/expresso -u fixie-ai/expresso -c continuation -T @expresso_template.txt @dataclasses.dataclass class DatasetToolArgs: dataset_name: str = simple_parsing.field(alias="-d") @@ -88,7 +98,7 @@ class DatasetToolArgs: num_workers: int = simple_parsing.field(default=16, alias="-w") upload_name: Optional[str] = simple_parsing.field(default=None, alias="-u") - upload_branch: Optional[str] = simple_parsing.field(default="main", alias="-b") + upload_branch: Optional[str] = simple_parsing.field(default="main", alias="-B") num_shards: Optional[int] = simple_parsing.field(default=None, alias="-N") private: bool = simple_parsing.field(default=False) diff --git a/ultravox/tools/tts.py b/ultravox/tools/tts.py index ca855d65..3dc690d5 100644 --- a/ultravox/tools/tts.py +++ b/ultravox/tools/tts.py @@ -1,12 +1,15 @@ +import abc import io import os -from typing import Optional +from typing import Any, Dict, Optional from xml.sax import saxutils import numpy as np import requests import soundfile as sf +RANDOM_VOICE_KEY = "random" + def _make_ssml(voice: str, text: str): return f""" @@ -17,19 +20,68 @@ def _make_ssml(voice: str, text: str): """ -class AzureTts: - DEFAULT_VOICE = "en-US-JennyNeural" - - def __init__(self): +class Client(abc.ABC): + def __init__(self, sample_rate: int = 16000): self._session = requests.Session() + self._sample_rate = sample_rate + + @abc.abstractmethod + def tts(self, text: str, voice: Optional[str] = None): + raise NotImplementedError + + def _post(self, url: str, headers: Dict[str, str], json: Dict[str, Any]): + response = self._session.post(url, headers=headers, json=json) + response.raise_for_status() + return response - def tts(self, text: str, voice: Optional[str] = None, sample_rate: int = 16000): + def _handle_pcm_response(self, response: requests.Response): + pcm_array = np.frombuffer(response.content, dtype=np.int16) + wav_bytes = io.BytesIO() + sf.write(wav_bytes, pcm_array, self._sample_rate, format="wav") + return wav_bytes.getvalue() + + +class AzureTts(Client): + DEFAULT_VOICE = "en-US-JennyNeural" + ALL_VOICES = [ + "en-US-AvaNeural", + "en-US-AndrewNeural", + "en-US-EmmaNeural", + "en-US-BrianNeural", + "en-US-JennyNeural", + "en-US-GuyNeural", + "en-US-AriaNeural", + "en-US-DavisNeural", + "en-US-JaneNeural", + "en-US-JasonNeural", + "en-US-SaraNeural", + "en-US-TonyNeural", + "en-US-NancyNeural", + "en-US-AmberNeural", + "en-US-AnaNeural", + "en-US-AshleyNeural", + "en-US-BrandonNeural", + "en-US-ChristopherNeural", + "en-US-CoraNeural", + "en-US-ElizabethNeural", + "en-US-EricNeural", + "en-US-JacobNeural", + "en-US-MichelleNeural", + "en-US-MonicaNeural", + "en-US-RogerNeural", + ] + + def tts(self, text: str, voice: Optional[str] = None): voice = voice or self.DEFAULT_VOICE + if voice == RANDOM_VOICE_KEY: + voice = np.random.choice(self.ALL_VOICES) + assert voice region = "westus" api_key = os.environ.get("AZURE_TTS_API_KEY") or os.environ.get( "AZURE_WESTUS_TTS_API_KEY" ) - output_format = f"raw-{sample_rate // 1000}khz-16bit-mono-pcm" + assert api_key, "Please set the AZURE_TTS_API_KEY environment variable." + output_format = f"raw-{self._sample_rate // 1000}khz-16bit-mono-pcm" url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1" headers = { "Ocp-Apim-Subscription-Key": api_key, @@ -38,10 +90,64 @@ def tts(self, text: str, voice: Optional[str] = None, sample_rate: int = 16000): "User-Agent": "MyTTS", } body = _make_ssml(voice, text) - response = self._session.post(url, headers=headers, data=body) - response.raise_for_status() + return self._handle_pcm_response(self._post(url, headers, body)) - pcm_array = np.frombuffer(response.content, dtype=np.int16) - wav_bytes = io.BytesIO() - sf.write(wav_bytes, pcm_array, sample_rate, format="wav") - return wav_bytes.getvalue() + +class ElevenTts(Client): + DEFAULT_VOICE = "21m00Tcm4TlvDq8ikWAM" + DEFAULT_MODEL = "eleven_multilingual_v2" + ALL_VOICES = [ + "21m00Tcm4TlvDq8ikWAM", + "29vD33N1CtxCmqQRPOHJ", + "2EiwWnXFnvU5JabPnv8n", + "5Q0t7uMcjvnagumLfvZi", + "AZnzlk1XvdvUeBnXmlld", + "CYw3kZ02Hs0563khs1Fj", + "D38z5RcWu1voky8WS1ja", + "EXAVITQu4vr4xnSDxMaL", + "ErXwobaYiN019PkySvjV", + "GBv7mTt0atIp3Br8iCZE", + "IKne3meq5aSn9XLyUdCD", + "JBFqnCBsd6RMkjVDRZzb", + "LcfcDJNUP1GQjkzn1xUU", + "MF3mGyEYCl7XYWbV9V6O", + "N2lVS1w4EtoT3dr4eOWO", + "ODq5zmih8GrVes37Dizd", + "SOYHLrjzK2X1ezoPC6cr", + "TX3LPaxmHKxFdv7VOQHJ", + "ThT5KcBeYPX3keUQqHPh", + "TxGEqnHWrfWFTfGW9XjX", + "VR6AewLTigWG4xSOukaG", + "XB0fDUnXU5powFXDhCwa", + "Xb7hH8MSUJpSbSDYk0k2", + "XrExE9yKIg1WjnnlVkGX", + "ZQe5CZNOzWyzPSCn5a3c", + "Zlb1dXrM653N07WRdFW3", + ] + + def tts(self, text: str, voice: Optional[str] = None): + voice = voice or self.DEFAULT_VOICE + if voice == RANDOM_VOICE_KEY: + # Every process has same random seed, so we mix in the PID here for more variation. + i = np.random.randint(len(self.ALL_VOICES)) + os.getpid() + voice = self.ALL_VOICES[i % len(self.ALL_VOICES)] + url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice}/stream?output_format=pcm_16000" + print("url", url) + headers = {"xi-api-key": os.environ["ELEVEN_API_KEY"]} + body = { + "text": text, + "model_id": self.DEFAULT_MODEL, + "voice_settings": { + "stability": 0.5, + "similarity_boost": False, + }, + } + return self._handle_pcm_response(self._post(url, headers, body)) + + +def create_client(implementation: str, sample_rate: int): + if implementation == "azure": + return AzureTts(sample_rate=sample_rate) + elif implementation == "eleven": + return ElevenTts(sample_rate=sample_rate) + raise ValueError(f"Unknown TTS implementation: {implementation}")