diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py index 36691dc4..c3aa0dda 100644 --- a/ultravox/tools/ds_tool/ds_tool.py +++ b/ultravox/tools/ds_tool/ds_tool.py @@ -104,20 +104,24 @@ def _map_sample(self, sample): # just ds_tool textgen -d ylacombe/expresso -u fixie-ai/expresso -c continuation -T @expresso_template.txt @dataclasses.dataclass class DatasetToolArgs: + # HF source dataset parameters dataset_name: str = simple_parsing.field(alias="-d") dataset_subset: Optional[str] = simple_parsing.field(default=None, alias="-S") dataset_split: Optional[str] = simple_parsing.field(default=None, alias="-s") + # Local processing parameters shuffle: bool = simple_parsing.field(default=False) shuffle_seed: int = simple_parsing.field(default=42) num_samples: Optional[int] = simple_parsing.field(default=None, alias="-n") num_workers: int = simple_parsing.field(default=16, alias="-w") + # HF destination dataset parameters upload_name: Optional[str] = simple_parsing.field(default=None, alias="-u") + # eg if the original split="train", but we want to upload it as "validation" + upload_split: Optional[str] = simple_parsing.field(default=None) upload_branch: Optional[str] = simple_parsing.field(default="main", alias="-B") num_shards: Optional[int] = simple_parsing.field(default=None, alias="-N") private: bool = simple_parsing.field(default=False) - token: Optional[str] = None task: Union[TtsTask, TextGenerationTask] = simple_parsing.subgroups( @@ -126,6 +130,11 @@ class DatasetToolArgs: positional=True, ) + def __post_init__(self): + assert ( + not self.upload_split or self.dataset_split + ), "Must specify dataset_split when using upload_split" + def main(args: DatasetToolArgs): ds_name = args.dataset_name @@ -150,6 +159,7 @@ def main(args: DatasetToolArgs): "token": token, "revision": args.upload_branch, "private": args.private, + "split": args.upload_split, } if args.num_shards is not None: hub_args["num_shards"] = {split: args.num_shards for split in data_dict.keys()} diff --git a/ultravox/tools/ds_tool/tts.py b/ultravox/tools/ds_tool/tts.py index e6b8b9e4..78e11c1c 100644 --- a/ultravox/tools/ds_tool/tts.py +++ b/ultravox/tools/ds_tool/tts.py @@ -9,6 +9,8 @@ import soundfile as sf RANDOM_VOICE_KEY = "random" +REQUEST_TIMEOUT = 30 +NUM_RETRIES = 3 def _make_ssml(voice: str, text: str): @@ -23,6 +25,10 @@ def _make_ssml(voice: str, text: str): class Client(abc.ABC): def __init__(self, sample_rate: int = 16000): self._session = requests.Session() + retries = requests.adapters.Retry(total=NUM_RETRIES) + self._session.mount( + "https://", requests.adapters.HTTPAdapter(max_retries=retries) + ) self._sample_rate = sample_rate @abc.abstractmethod @@ -30,7 +36,9 @@ def tts(self, text: str, voice: Optional[str] = None) -> bytes: raise NotImplementedError def _post(self, url: str, headers: Dict[str, str], json: Dict[str, Any]): - response = self._session.post(url, headers=headers, json=json) + response = self._session.post( + url, headers=headers, json=json, timeout=REQUEST_TIMEOUT + ) response.raise_for_status() return response diff --git a/ultravox/tools/infer_tool.py b/ultravox/tools/infer_tool.py index 0aa9427c..cf8b21f8 100644 --- a/ultravox/tools/infer_tool.py +++ b/ultravox/tools/infer_tool.py @@ -78,6 +78,11 @@ class InferArgs: # JSON output json: bool = simple_parsing.field(default=False) + def __post_init__(self): + if self.prompt and self.prompt.startswith("@"): + with open(self.prompt[1:], "r") as f: + self.prompt = f.read() + def run_tui( index: int,