From 9be1e17028bf2f5f9af6fa1fcfb18745cfddf366 Mon Sep 17 00:00:00 2001 From: juberti Date: Sat, 22 Jun 2024 18:59:30 -0700 Subject: [PATCH 1/2] Minor fixes to ds_tool and infer_tool - --upload_split param to allow the dest split to be different than the src split - allow @file syntax for --prompt - add retries and timeouts to TTS requests --- ultravox/tools/ds_tool.py | 7 +++++++ ultravox/tools/infer_tool.py | 5 +++++ ultravox/tools/tts.py | 11 +++++++++-- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/ultravox/tools/ds_tool.py b/ultravox/tools/ds_tool.py index b4d47c05..a25a91f5 100644 --- a/ultravox/tools/ds_tool.py +++ b/ultravox/tools/ds_tool.py @@ -98,6 +98,7 @@ class DatasetToolArgs: num_workers: int = simple_parsing.field(default=16, alias="-w") upload_name: Optional[str] = simple_parsing.field(default=None, alias="-u") + upload_split: Optional[str] = simple_parsing.field(default=None) upload_branch: Optional[str] = simple_parsing.field(default="main", alias="-B") num_shards: Optional[int] = simple_parsing.field(default=None, alias="-N") private: bool = simple_parsing.field(default=False) @@ -110,6 +111,11 @@ class DatasetToolArgs: positional=True, ) + def __post_init__(self): + assert ( + not self.upload_split or self.dataset_split + ), "Must specify dataset_split when using upload_split" + def main(args: DatasetToolArgs): ds_name = args.dataset_name @@ -132,6 +138,7 @@ def main(args: DatasetToolArgs): "token": token, "revision": args.upload_branch, "private": args.private, + "split": args.upload_split, } if args.num_shards is not None: hub_args["num_shards"] = {split: args.num_shards for split in data_dict.keys()} diff --git a/ultravox/tools/infer_tool.py b/ultravox/tools/infer_tool.py index 0aa9427c..cf8b21f8 100644 --- a/ultravox/tools/infer_tool.py +++ b/ultravox/tools/infer_tool.py @@ -78,6 +78,11 @@ class InferArgs: # JSON output json: bool = simple_parsing.field(default=False) + def __post_init__(self): + if self.prompt and 
self.prompt.startswith("@"): + with open(self.prompt[1:], "r") as f: + self.prompt = f.read() + def run_tui( index: int, diff --git a/ultravox/tools/tts.py b/ultravox/tools/tts.py index 3dc690d5..b36a5fde 100644 --- a/ultravox/tools/tts.py +++ b/ultravox/tools/tts.py @@ -9,6 +9,8 @@ import soundfile as sf RANDOM_VOICE_KEY = "random" +REQUEST_TIMEOUT = 30 +NUM_RETRIES = 3 def _make_ssml(voice: str, text: str): @@ -23,6 +25,10 @@ def _make_ssml(voice: str, text: str): class Client(abc.ABC): def __init__(self, sample_rate: int = 16000): self._session = requests.Session() + retries = requests.adapters.Retry(total=NUM_RETRIES) + self._session.mount( + "https://", requests.adapters.HTTPAdapter(max_retries=retries) + ) self._sample_rate = sample_rate @abc.abstractmethod @@ -30,7 +36,9 @@ def tts(self, text: str, voice: Optional[str] = None): raise NotImplementedError def _post(self, url: str, headers: Dict[str, str], json: Dict[str, Any]): - response = self._session.post(url, headers=headers, json=json) + response = self._session.post( + url, headers=headers, json=json, timeout=REQUEST_TIMEOUT + ) response.raise_for_status() return response @@ -132,7 +140,6 @@ def tts(self, text: str, voice: Optional[str] = None): i = np.random.randint(len(self.ALL_VOICES)) + os.getpid() voice = self.ALL_VOICES[i % len(self.ALL_VOICES)] url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice}/stream?output_format=pcm_16000" - print("url", url) headers = {"xi-api-key": os.environ["ELEVEN_API_KEY"]} body = { "text": text, From 8f0c7cbf6c9b3f0dcd066012f0084a47f7e9ac24 Mon Sep 17 00:00:00 2001 From: juberti Date: Mon, 24 Jun 2024 15:46:01 -0700 Subject: [PATCH 2/2] docs --- ultravox/tools/ds_tool/ds_tool.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py index c1ea2243..c3aa0dda 100644 --- a/ultravox/tools/ds_tool/ds_tool.py +++ b/ultravox/tools/ds_tool/ds_tool.py @@ -104,21 +104,24 @@ def 
_map_sample(self, sample): # just ds_tool textgen -d ylacombe/expresso -u fixie-ai/expresso -c continuation -T @expresso_template.txt @dataclasses.dataclass class DatasetToolArgs: + # HF source dataset parameters dataset_name: str = simple_parsing.field(alias="-d") dataset_subset: Optional[str] = simple_parsing.field(default=None, alias="-S") dataset_split: Optional[str] = simple_parsing.field(default=None, alias="-s") + # Local processing parameters shuffle: bool = simple_parsing.field(default=False) shuffle_seed: int = simple_parsing.field(default=42) num_samples: Optional[int] = simple_parsing.field(default=None, alias="-n") num_workers: int = simple_parsing.field(default=16, alias="-w") + # HF destination dataset parameters upload_name: Optional[str] = simple_parsing.field(default=None, alias="-u") + # e.g. if the original split="train", but we want to upload it as "validation" upload_split: Optional[str] = simple_parsing.field(default=None) upload_branch: Optional[str] = simple_parsing.field(default="main", alias="-B") num_shards: Optional[int] = simple_parsing.field(default=None, alias="-N") private: bool = simple_parsing.field(default=False) - token: Optional[str] = None task: Union[TtsTask, TextGenerationTask] = simple_parsing.subgroups(