Inference API for 🐶Bark #2685

Merged · 15 commits · Jun 28, 2023
4 changes: 3 additions & 1 deletion .pylintrc
@@ -169,7 +169,9 @@ disable=missing-docstring,
comprehension-escape,
duplicate-code,
not-callable,
-import-outside-toplevel
+import-outside-toplevel,
+logging-fstring-interpolation,
+logging-not-lazy

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
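The two new disables cover pylint's checks on string formatting inside logging calls. A minimal illustration of the patterns they permit (not code from this PR; the logger name is arbitrary):

import logging

logger = logging.getLogger("TTS")
model_name = "bark"

# logging-fstring-interpolation would flag the f-string below,
# and logging-not-lazy would flag the eager % formatting.
logger.info(f"Loading model {model_name}")
logger.info("Loading model %s" % model_name)

# The lazy form pylint prefers defers formatting until the record is emitted.
logger.info("Loading model %s", model_name)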
159 changes: 86 additions & 73 deletions TTS/.models.json

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion TTS/api.py
@@ -346,7 +346,7 @@ def list_models():

def download_model_by_name(self, model_name: str):
model_path, config_path, model_item = self.manager.download_model(model_name)
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["github_rls_url"], list)):
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
# return model directory if there are multiple files
# we assume that the model knows how to load itself
return None, None, None, None, model_path
@@ -584,6 +584,8 @@ def tts_to_file(
Speed factor to use for 🐸Coqui Studio models, between 0.0 and 2.0. Defaults to None.
file_path (str, optional):
Output file path. Defaults to "output.wav".
kwargs (dict, optional):
Additional arguments for the model.
"""
self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)

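For context, a rough sketch of how the updated API might be driven for Bark. The model identifier and the Bark-specific keyword arguments (voice_dir, speaker) are assumptions here, since the registered name lives in the TTS/.models.json diff that is not rendered above:

from TTS.api import TTS

# Assumed model identifier; the authoritative name is in TTS/.models.json.
tts = TTS(model_name="tts_models/multilingual/multi-dataset/bark", gpu=True)

# Extra keyword arguments now ride through the **kwargs of tts_to_file and are
# handed to the model; voice_dir/speaker are illustrative Bark cloning options.
tts.tts_to_file(
    text="Hello, this is Bark speaking through the Coqui TTS API.",
    file_path="bark_output.wav",
    voice_dir="bark_voices/",
    speaker="speaker_0",
)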
2 changes: 1 addition & 1 deletion TTS/bin/synthesize.py
@@ -356,7 +356,7 @@ def main():
vc_config_path = config_path

# tts model with multiple files to be loaded from the directory path
-if model_item.get("author", None) == "fairseq" or isinstance(model_item["github_rls_url"], list):
+if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list):
model_dir = model_path
tts_path = None
tts_config_path = None
105 changes: 105 additions & 0 deletions TTS/tts/configs/bark_config.py
@@ -0,0 +1,105 @@
import os
from dataclasses import dataclass
from typing import Dict

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.layers.bark.model import GPTConfig
from TTS.tts.layers.bark.model_fine import FineGPTConfig
from TTS.tts.models.bark import BarkAudioConfig
from TTS.utils.generic_utils import get_user_data_dir


@dataclass
class BarkConfig(BaseTTSConfig):
"""Bark TTS configuration

Args:
model (str): model name used to register the model. Defaults to "bark".
audio (BarkAudioConfig): audio configuration. Defaults to BarkAudioConfig().
num_chars (int): number of characters in the alphabet. Defaults to 0.
semantic_config (GPTConfig): semantic configuration. Defaults to GPTConfig().
fine_config (FineGPTConfig): fine configuration. Defaults to FineGPTConfig().
coarse_config (GPTConfig): coarse configuration. Defaults to GPTConfig().
CONTEXT_WINDOW_SIZE (int): GPT context window size. Defaults to 1024.
SEMANTIC_RATE_HZ (float): semantic tokens rate in Hz. Defaults to 49.9.
SEMANTIC_VOCAB_SIZE (int): semantic vocabulary size. Defaults to 10_000.
CODEBOOK_SIZE (int): encodec codebook size. Defaults to 1024.
N_COARSE_CODEBOOKS (int): number of coarse codebooks. Defaults to 2.
N_FINE_CODEBOOKS (int): number of fine codebooks. Defaults to 8.
COARSE_RATE_HZ (int): coarse tokens rate in Hz. Defaults to 75.
SAMPLE_RATE (int): sample rate. Defaults to 24_000.
USE_SMALLER_MODELS (bool): use smaller models. Defaults to False.
TEXT_ENCODING_OFFSET (int): text encoding offset. Defaults to 10_048.
SEMANTIC_PAD_TOKEN (int): semantic pad token. Defaults to 10_000.
TEXT_PAD_TOKEN (int): text pad token. Defaults to 129_595.
TEXT_EOS_TOKEN (int): text end-of-sentence token. Defaults to 10_049.
TEXT_SOS_TOKEN (int): text start-of-sentence token. Defaults to 10_050.
SEMANTIC_INFER_TOKEN (int): semantic infer token. Defaults to 129_599.
COARSE_SEMANTIC_PAD_TOKEN (int): coarse semantic pad token. Defaults to 12_048.
COARSE_INFER_TOKEN (int): coarse infer token. Defaults to 12_050.
REMOTE_BASE_URL (str): remote base URL for the Bark checkpoints. Defaults to "https://huggingface.co/erogol/bark/tree/main/".
REMOTE_MODEL_PATHS (Dict): remote model paths. Defaults to None.
LOCAL_MODEL_PATHS (Dict): local model paths. Defaults to None.
SMALL_REMOTE_MODEL_PATHS (Dict): small remote model paths. Defaults to None.
CACHE_DIR (str): local cache directory. Defaults to get_user_data_dir().
DEF_SPEAKER_DIR (str): default speaker directory to store speaker values for voice cloning. Defaults to get_user_data_dir().
"""

model: str = "bark"
audio: BarkAudioConfig = BarkAudioConfig()
num_chars: int = 0
semantic_config: GPTConfig = GPTConfig()
fine_config: FineGPTConfig = FineGPTConfig()
coarse_config: GPTConfig = GPTConfig()
CONTEXT_WINDOW_SIZE: int = 1024
SEMANTIC_RATE_HZ: float = 49.9
SEMANTIC_VOCAB_SIZE: int = 10_000
CODEBOOK_SIZE: int = 1024
N_COARSE_CODEBOOKS: int = 2
N_FINE_CODEBOOKS: int = 8
COARSE_RATE_HZ: int = 75
SAMPLE_RATE: int = 24_000
USE_SMALLER_MODELS: bool = False

TEXT_ENCODING_OFFSET: int = 10_048
SEMANTIC_PAD_TOKEN: int = 10_000
TEXT_PAD_TOKEN: int = 129_595
SEMANTIC_INFER_TOKEN: int = 129_599
COARSE_SEMANTIC_PAD_TOKEN: int = 12_048
COARSE_INFER_TOKEN: int = 12_050

REMOTE_BASE_URL = "https://huggingface.co/erogol/bark/tree/main/"
REMOTE_MODEL_PATHS: Dict = None
LOCAL_MODEL_PATHS: Dict = None
SMALL_REMOTE_MODEL_PATHS: Dict = None
CACHE_DIR: str = str(get_user_data_dir("tts/suno/bark_v0"))
DEF_SPEAKER_DIR: str = str(get_user_data_dir("tts/bark_v0/speakers"))

def __post_init__(self):
self.REMOTE_MODEL_PATHS = {
"text": {
"path": os.path.join(self.REMOTE_BASE_URL, "text_2.pt"),
"checksum": "54afa89d65e318d4f5f80e8e8799026a",
},
"coarse": {
"path": os.path.join(self.REMOTE_BASE_URL, "coarse_2.pt"),
"checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
},
"fine": {
"path": os.path.join(self.REMOTE_BASE_URL, "fine_2.pt"),
"checksum": "59d184ed44e3650774a2f0503a48a97b",
},
}
self.LOCAL_MODEL_PATHS = {
"text": os.path.join(self.CACHE_DIR, "text_2.pt"),
"coarse": os.path.join(self.CACHE_DIR, "coarse_2.pt"),
"fine": os.path.join(self.CACHE_DIR, "fine_2.pt"),
"hubert_tokenizer": os.path.join(self.CACHE_DIR, "tokenizer.pth"),
"hubert": os.path.join(self.CACHE_DIR, "hubert.pt"),
}
self.SMALL_REMOTE_MODEL_PATHS = {
"text": {"path": os.path.join(self.REMOTE_BASE_URL, "text.pt")},
"coarse": {"path": os.path.join(self.REMOTE_BASE_URL, "coarse.pt")},
"fine": {"path": os.path.join(self.REMOTE_BASE_URL, "fine.pt")},
}
self.sample_rate = self.SAMPLE_RATE # pylint: disable=attribute-defined-outside-init
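A small sketch of what __post_init__ produces; the exact cache location depends on get_user_data_dir on the host machine:

from TTS.tts.configs.bark_config import BarkConfig

config = BarkConfig()

# The checkpoint maps are filled in after construction, and SAMPLE_RATE is
# mirrored into the base-config `sample_rate` field.
print(config.sample_rate)                           # 24000
print(config.LOCAL_MODEL_PATHS["text"])             # <CACHE_DIR>/text_2.pt
print(config.REMOTE_MODEL_PATHS["coarse"]["path"])  # <REMOTE_BASE_URL>coarse_2.pt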
Empty file added TTS/tts/layers/bark/__init__.py
Empty file.
Empty file.
35 changes: 35 additions & 0 deletions TTS/tts/layers/bark/hubert/hubert_manager.py
@@ -0,0 +1,35 @@
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer

import os.path
import shutil
import urllib.request

import huggingface_hub


class HubertManager:
@staticmethod
def make_sure_hubert_installed(
download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", model_path: str = ""
):
if not os.path.isfile(model_path):
print("Downloading HuBERT base model")
urllib.request.urlretrieve(download_url, model_path)
print("Downloaded HuBERT")
return model_path
return None

@staticmethod
def make_sure_tokenizer_installed(
model: str = "quantifier_hubert_base_ls960_14.pth",
repo: str = "GitMylo/bark-voice-cloning",
model_path: str = "",
):
model_dir = os.path.dirname(model_path)
if not os.path.isfile(model_path):
print("Downloading HuBERT custom tokenizer")
huggingface_hub.hf_hub_download(repo, model, local_dir=model_dir, local_dir_use_symlinks=False)
shutil.move(os.path.join(model_dir, model), model_path)
print("Downloaded tokenizer")
return model_path
return None
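A usage sketch pairing the manager with the paths derived by BarkConfig above; both helpers return None and skip the download when the target file already exists:

from TTS.tts.configs.bark_config import BarkConfig
from TTS.tts.layers.bark.hubert.hubert_manager import HubertManager

config = BarkConfig()

# Fetch the base HuBERT checkpoint and the custom tokenizer into the Bark
# cache directory (paths come from BarkConfig.LOCAL_MODEL_PATHS).
HubertManager.make_sure_hubert_installed(model_path=config.LOCAL_MODEL_PATHS["hubert"])
HubertManager.make_sure_tokenizer_installed(model_path=config.LOCAL_MODEL_PATHS["hubert_tokenizer"])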
101 changes: 101 additions & 0 deletions TTS/tts/layers/bark/hubert/kmeans_hubert.py
@@ -0,0 +1,101 @@
"""
Modified HuBERT model without kmeans.
Original author: https://github.com/lucidrains/
Modified by: https://www.github.com/gitmylo/
License: MIT
"""

# Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py

import logging
from pathlib import Path

import fairseq
import torch
from einops import pack, unpack
from torch import nn
from torchaudio.functional import resample

logging.root.setLevel(logging.ERROR)


def round_down_nearest_multiple(num, divisor):
return num // divisor * divisor


def curtail_to_multiple(t, mult, from_left=False):
data_len = t.shape[-1]
rounded_seq_len = round_down_nearest_multiple(data_len, mult)
seq_slice = slice(None, rounded_seq_len) if not from_left else slice(-rounded_seq_len, None)
return t[..., seq_slice]


def exists(val):
return val is not None


def default(val, d):
return val if exists(val) else d


class CustomHubert(nn.Module):
"""
checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
or you can train your own
"""

def __init__(self, checkpoint_path, target_sample_hz=16000, seq_len_multiple_of=None, output_layer=9, device=None):
super().__init__()
self.target_sample_hz = target_sample_hz
self.seq_len_multiple_of = seq_len_multiple_of
self.output_layer = output_layer

if device is not None:
self.to(device)

model_path = Path(checkpoint_path)

assert model_path.exists(), f"path {checkpoint_path} does not exist"

checkpoint = torch.load(checkpoint_path)
load_model_input = {checkpoint_path: checkpoint}
model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)

if device is not None:
model[0].to(device)

self.model = model[0]
self.model.eval()

@property
def groups(self):
return 1

@torch.no_grad()
def forward(self, wav_input, flatten=True, input_sample_hz=None):
device = wav_input.device

if exists(input_sample_hz):
wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)

if exists(self.seq_len_multiple_of):
wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)

embed = self.model(
wav_input,
features_only=True,
mask=False, # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code
output_layer=self.output_layer,
)

embed, packed_shape = pack([embed["x"]], "* d")

# codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())

codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device) # .long()

if flatten:
return codebook_indices

(codebook_indices,) = unpack(codebook_indices, packed_shape, "*")
return codebook_indices
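A hedged usage sketch (file names are illustrative; fairseq and torchaudio must be installed): the wrapper resamples to 16 kHz internally when input_sample_hz is passed and returns layer-9 features flattened to (frames, dim).

import torchaudio

from TTS.tts.layers.bark.hubert.kmeans_hubert import CustomHubert

# Illustrative checkpoint path; this is the hubert_base_ls960.pt file that
# HubertManager.make_sure_hubert_installed downloads.
hubert = CustomHubert(checkpoint_path="hubert_base_ls960.pt", device="cpu")

wav, sr = torchaudio.load("speaker_sample.wav")  # (channels, samples)
wav = wav.mean(dim=0, keepdim=True)              # mix down to mono, keep batch dim

# forward() resamples to target_sample_hz and returns flattened features.
features = hubert.forward(wav, input_sample_hz=sr)
print(features.shape)  # roughly (num_frames, feature_dim)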