Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor fixes to ds_tool and infer_tool #36

Merged
merged 3 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion ultravox/tools/ds_tool/ds_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,20 +104,24 @@ def _map_sample(self, sample):
# just ds_tool textgen -d ylacombe/expresso -u fixie-ai/expresso -c continuation -T @expresso_template.txt
@dataclasses.dataclass
class DatasetToolArgs:
# HF source dataset parameters
dataset_name: str = simple_parsing.field(alias="-d")
dataset_subset: Optional[str] = simple_parsing.field(default=None, alias="-S")
dataset_split: Optional[str] = simple_parsing.field(default=None, alias="-s")

# Local processing parameters
shuffle: bool = simple_parsing.field(default=False)
shuffle_seed: int = simple_parsing.field(default=42)
num_samples: Optional[int] = simple_parsing.field(default=None, alias="-n")
num_workers: int = simple_parsing.field(default=16, alias="-w")

# HF destination dataset parameters
upload_name: Optional[str] = simple_parsing.field(default=None, alias="-u")
# eg if the original split="train", but we want to upload it as "validation"
upload_split: Optional[str] = simple_parsing.field(default=None)
upload_branch: Optional[str] = simple_parsing.field(default="main", alias="-B")
num_shards: Optional[int] = simple_parsing.field(default=None, alias="-N")
private: bool = simple_parsing.field(default=False)

token: Optional[str] = None

task: Union[TtsTask, TextGenerationTask] = simple_parsing.subgroups(
Expand All @@ -126,6 +130,11 @@ class DatasetToolArgs:
positional=True,
)

def __post_init__(self):
    """Validate option combinations once the dataclass fields are populated.

    ``upload_split`` may only be given together with ``dataset_split``.

    Raises:
        ValueError: if ``upload_split`` is set while ``dataset_split`` is
            empty or ``None``.
    """
    # An explicit raise (rather than ``assert``) keeps the check active even
    # when Python runs with -O, which strips assert statements.
    if self.upload_split and not self.dataset_split:
        raise ValueError("Must specify dataset_split when using upload_split")
Comment on lines +134 to +136
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I didn't catch this the first time, but why can't upload_split be equal to dataset_split by default if not specified?



def main(args: DatasetToolArgs):
ds_name = args.dataset_name
Expand All @@ -150,6 +159,7 @@ def main(args: DatasetToolArgs):
"token": token,
"revision": args.upload_branch,
"private": args.private,
"split": args.upload_split,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, you didn't fix the conflict here!
Now on line 170 you'll be sending the split argument twice, and with different values!

}
if args.num_shards is not None:
hub_args["num_shards"] = {split: args.num_shards for split in data_dict.keys()}
Expand Down
10 changes: 9 additions & 1 deletion ultravox/tools/ds_tool/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import soundfile as sf

# Special voice name; presumably asks a client to pick a voice at random --
# TODO confirm against the code that consumes it (not visible in this hunk).
RANDOM_VOICE_KEY = "random"
# Passed as ``timeout=`` to the POST issued by Client._post; requests
# interprets this value in seconds.
REQUEST_TIMEOUT = 30
# Passed as ``Retry(total=...)`` for the HTTPS adapter mounted in Client.__init__.
NUM_RETRIES = 3


def _make_ssml(voice: str, text: str):
Expand All @@ -23,14 +25,20 @@ def _make_ssml(voice: str, text: str):
class Client(abc.ABC):
def __init__(self, sample_rate: int = 16000):
    """Create the HTTP session used by this client and store the sample rate.

    The session retries ``https://`` requests according to a ``Retry``
    policy built with ``total=NUM_RETRIES``.

    Args:
        sample_rate: stored on the instance as ``_sample_rate``.
    """
    self._session = requests.Session()
    # Build the retry policy and adapter separately so the mount call
    # stays on one line.
    retry_policy = requests.adapters.Retry(total=NUM_RETRIES)
    https_adapter = requests.adapters.HTTPAdapter(max_retries=retry_policy)
    self._session.mount("https://", https_adapter)
    self._sample_rate = sample_rate

@abc.abstractmethod
def tts(self, text: str, voice: Optional[str] = None) -> bytes:
    """Abstract hook that concrete clients must override.

    The base implementation only raises ``NotImplementedError``.

    Args:
        text: input text for the client.
        voice: optional voice selector; defaults to ``None``.

    Returns:
        bytes, per the annotation (presumably encoded audio -- confirm
        against the concrete subclasses).
    """
    raise NotImplementedError

def _post(self, url: str, headers: Dict[str, str], json: Dict[str, Any]):
    """POST a JSON payload through the shared session and return the response.

    Args:
        url: endpoint to send the request to.
        headers: HTTP headers to include with the request.
        json: payload forwarded as the ``json=`` argument of ``session.post``.

    Returns:
        The response object returned by ``self._session.post``.

    Raises:
        Whatever ``response.raise_for_status()`` raises for an error status
        (``requests.HTTPError`` per the requests documentation), plus any
        exception raised by the request itself, such as a timeout.
    """
    # Send exactly one request, bounded by REQUEST_TIMEOUT. The previous
    # text also issued an un-timed POST first, which duplicated every
    # request and could block forever on a stalled connection.
    response = self._session.post(
        url, headers=headers, json=json, timeout=REQUEST_TIMEOUT
    )
    response.raise_for_status()
    return response

Expand Down
5 changes: 5 additions & 0 deletions ultravox/tools/infer_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ class InferArgs:
# JSON output
json: bool = simple_parsing.field(default=False)

def __post_init__(self):
    """Expand an ``@<path>`` prompt into the contents of the named file.

    When ``self.prompt`` is non-empty and starts with ``"@"``, the rest of
    the string is treated as a file path and ``self.prompt`` is replaced by
    that file's full text. Any other prompt (including ``None`` or an empty
    string) is left untouched.

    Raises:
        OSError: if the referenced file cannot be opened.
    """
    if self.prompt and self.prompt.startswith("@"):
        # Decode explicitly as UTF-8 so the prompt text does not depend on
        # the platform's locale default encoding.
        with open(self.prompt[1:], "r", encoding="utf-8") as f:
            self.prompt = f.read()


def run_tui(
index: int,
Expand Down
Loading