diff --git a/TTS/.models.json b/TTS/.models.json index a893f708f1..ba7b5f6289 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -5,12 +5,12 @@ "xtts_v1": { "description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.", "hf_url": [ - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/model.pth", - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/config.json", - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/vocab.json" + "https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth", + "https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/config.json", + "https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/vocab.json" ], "default_vocoder": null, - "commit": "e9a1953e", + "commit": "e5140314", "license": "CPML", "contact": "info@coqui.ai", "tos_required": true diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index 2a821a5d00..88ce100c72 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -172,7 +172,7 @@ def get_grad_norm_parameter_groups(self): "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()), } - def init_gpt_for_inference(self, kv_cache=True): + def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False): seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1 gpt_config = GPT2Config( vocab_size=self.max_mel_tokens, @@ -195,6 +195,17 @@ def init_gpt_for_inference(self, kv_cache=True): ) self.gpt.wte = self.mel_embedding + if use_deepspeed: + import deepspeed + self.ds_engine = deepspeed.init_inference( + model=self.gpt_inference.half(), # Transformers models + mp_size=1, # Number of GPU + dtype=torch.float32, # desired data type of output + replace_method="auto", # Lets DS autmatically identify the layer to replace + replace_with_kernel_inject=True, # replace the model with the kernel injector + ) + self.gpt_inference = self.ds_engine.module.eval() + def set_inputs_and_targets(self, input, start_token, stop_token): inp = F.pad(input, (1, 0), value=start_token) tar = F.pad(input, (0, 1), value=stop_token) @@ -543,3 +554,14 @@ def generate( if "return_dict_in_generate" in hf_generate_kwargs: return gen.sequences[:, gpt_inputs.shape[1] :], gen return gen[:, gpt_inputs.shape[1] :] + + def get_generator(self, fake_inputs, **hf_generate_kwargs): + return self.gpt_inference.generate_stream( + fake_inputs, + bos_token_id=self.start_audio_token, + pad_token_id=self.stop_audio_token, + eos_token_id=self.stop_audio_token, + max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens, + do_stream=True, + **hf_generate_kwargs, + ) diff --git a/TTS/tts/layers/xtts/hifigan_decoder.py b/TTS/tts/layers/xtts/hifigan_decoder.py new file mode 100644 index 0000000000..6439b455a0 --- /dev/null +++ b/TTS/tts/layers/xtts/hifigan_decoder.py @@ -0,0 +1,742 @@ +import torch +from torch import nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm, weight_norm +import torchaudio + +from TTS.utils.io import load_fsspec + + +LRELU_SLOPE = 0.1 + + +def get_padding(k, d): + return int((k * d - d) / 2) + + +class ResBlock1(torch.nn.Module): + """Residual Block Type 1. It has 3 convolutional layers in each convolutional block. + + Network:: + + x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o + |--------------------------------------------------------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. + """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + + def forward(self, x): + """ + Args: + x (Tensor): input tensor. + Returns: + Tensor: output tensor. + Shapes: + x: [B, C, T] + """ + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + """Residual Block Type 2. It has 1 convolutional layers in each convolutional block. + + Network:: + + x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o + |---------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. + """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class HifiganGenerator(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + resblock_type, + resblock_dilation_sizes, + resblock_kernel_sizes, + upsample_kernel_sizes, + upsample_initial_channel, + upsample_factors, + inference_padding=5, + cond_channels=0, + conv_pre_weight_norm=True, + conv_post_weight_norm=True, + conv_post_bias=True, + cond_in_each_up_layer=False, + ): + r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) + + Network: + x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o + .. -> zI ---| + resblockN_kNx1 -> zN ---' + + Args: + in_channels (int): number of input tensor channels. + out_channels (int): number of output tensor channels. + resblock_type (str): type of the `ResBlock`. '1' or '2'. + resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`. + resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`. + upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution. + upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2 + for each consecutive upsampling layer. + upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer. + inference_padding (int): constant padding applied to the input at inference time. Defaults to 5. + """ + super().__init__() + self.inference_padding = inference_padding + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_factors) + self.cond_in_each_up_layer = cond_in_each_up_layer + + # initial upsampling layers + self.conv_pre = weight_norm( + Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock1 if resblock_type == "1" else ResBlock2 + # upsampling layers + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + # MRF blocks + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d)) + # post convolution layer + self.conv_post = weight_norm( + Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias) + ) + if cond_channels > 0: + self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1) + + if not conv_pre_weight_norm: + remove_weight_norm(self.conv_pre) + + if not conv_post_weight_norm: + remove_weight_norm(self.conv_post) + + if self.cond_in_each_up_layer: + self.conds = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + self.conds.append(nn.Conv1d(cond_channels, ch, 1)) + + def forward(self, x, g=None): + """ + Args: + x (Tensor): feature input tensor. + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + o = self.conv_pre(x) + if hasattr(self, "cond_layer"): + o = o + self.cond_layer(g) + for i in range(self.num_upsamples): + o = F.leaky_relu(o, LRELU_SLOPE) + o = self.ups[i](o) + + if self.cond_in_each_up_layer: + o = o + self.conds[i](g) + + z_sum = None + for j in range(self.num_kernels): + if z_sum is None: + z_sum = self.resblocks[i * self.num_kernels + j](o) + else: + z_sum += self.resblocks[i * self.num_kernels + j](o) + o = z_sum / self.num_kernels + o = F.leaky_relu(o) + o = self.conv_post(o) + o = torch.tanh(o) + return o + + @torch.no_grad() + def inference(self, c): + """ + Args: + x (Tensor): conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + c = c.to(self.conv_pre.weight.device) + c = torch.nn.functional.pad( + c, (self.inference_padding, self.inference_padding), "replicate" + ) + return self.forward(c) + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + self.remove_weight_norm() + +class SELayer(nn.Module): + def __init__(self, channel, reduction=8): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +class SEBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): + super(SEBasicBlock, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.se = SELayer(planes, reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.relu(out) + out = self.bn1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + return out + + +def set_init_dict(model_dict, checkpoint_state, c): + # Partial initialization: if there is a mismatch with new and old layer, it is skipped. + for k, v in checkpoint_state.items(): + if k not in model_dict: + print(" | > Layer missing in the model definition: {}".format(k)) + # 1. filter out unnecessary keys + pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} + # 2. filter out different size layers + pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()} + # 3. skip reinit layers + if c.has("reinit_layers") and c.reinit_layers is not None: + for reinit_layer_name in c.reinit_layers: + pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} + # 4. overwrite entries in the existing state dict + model_dict.update(pretrained_dict) + print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) + return model_dict + + +class PreEmphasis(nn.Module): + def __init__(self, coefficient=0.97): + super().__init__() + self.coefficient = coefficient + self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0)) + + def forward(self, x): + assert len(x.size()) == 2 + + x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect") + return torch.nn.functional.conv1d(x, self.filter).squeeze(1) + + + +class ResNetSpeakerEncoder(nn.Module): + """This is copied from 🐸TTS to remove it from the dependencies. + """ + + # pylint: disable=W0102 + def __init__( + self, + input_dim=64, + proj_dim=512, + layers=[3, 4, 6, 3], + num_filters=[32, 64, 128, 256], + encoder_type="ASP", + log_input=False, + use_torch_spec=False, + audio_config=None, + ): + super(ResNetSpeakerEncoder, self).__init__() + + self.encoder_type = encoder_type + self.input_dim = input_dim + self.log_input = log_input + self.use_torch_spec = use_torch_spec + self.audio_config = audio_config + self.proj_dim = proj_dim + + self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) + self.relu = nn.ReLU(inplace=True) + self.bn1 = nn.BatchNorm2d(num_filters[0]) + + self.inplanes = num_filters[0] + self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0]) + self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2)) + self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2)) + self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2)) + + self.instancenorm = nn.InstanceNorm1d(input_dim) + + if self.use_torch_spec: + self.torch_spec = torch.nn.Sequential( + PreEmphasis(audio_config["preemphasis"]), + torchaudio.transforms.MelSpectrogram( + sample_rate=audio_config["sample_rate"], + n_fft=audio_config["fft_size"], + win_length=audio_config["win_length"], + hop_length=audio_config["hop_length"], + window_fn=torch.hamming_window, + n_mels=audio_config["num_mels"], + ), + ) + + else: + self.torch_spec = None + + outmap_size = int(self.input_dim / 8) + + self.attention = nn.Sequential( + nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), + nn.Softmax(dim=2), + ) + + if self.encoder_type == "SAP": + out_dim = num_filters[3] * outmap_size + elif self.encoder_type == "ASP": + out_dim = num_filters[3] * outmap_size * 2 + else: + raise ValueError("Undefined encoder") + + self.fc = nn.Linear(out_dim, proj_dim) + + self._init_layers() + + def _init_layers(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def create_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + # pylint: disable=R0201 + def new_parameter(self, *size): + out = nn.Parameter(torch.FloatTensor(*size)) + nn.init.xavier_normal_(out) + return out + + def forward(self, x, l2_norm=False): + """Forward pass of the model. + + Args: + x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` + to compute the spectrogram on-the-fly. + l2_norm (bool): Whether to L2-normalize the outputs. + + Shapes: + - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` + """ + x.squeeze_(1) + # if you torch spec compute it otherwise use the mel spec computed by the AP + if self.use_torch_spec: + x = self.torch_spec(x) + + if self.log_input: + x = (x + 1e-6).log() + x = self.instancenorm(x).unsqueeze(1) + + x = self.conv1(x) + x = self.relu(x) + x = self.bn1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = x.reshape(x.size()[0], -1, x.size()[-1]) + + w = self.attention(x) + + if self.encoder_type == "SAP": + x = torch.sum(x * w, dim=2) + elif self.encoder_type == "ASP": + mu = torch.sum(x * w, dim=2) + sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5)) + x = torch.cat((mu, sg), 1) + + x = x.view(x.size()[0], -1) + x = self.fc(x) + + if l2_norm: + x = torch.nn.functional.normalize(x, p=2, dim=1) + return x + + def load_checkpoint( + self, + checkpoint_path: str, + eval: bool = False, + use_cuda: bool = False, + criterion=None, + cache=False, + ): + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + try: + self.load_state_dict(state["model"]) + print(" > Model fully restored. ") + except (KeyError, RuntimeError) as error: + # If eval raise the error + if eval: + raise error + + print(" > Partial model initialization.") + model_dict = self.state_dict() + model_dict = set_init_dict(model_dict, state["model"]) + self.load_state_dict(model_dict) + del model_dict + + # load the criterion for restore_path + if criterion is not None and "criterion" in state: + try: + criterion.load_state_dict(state["criterion"]) + except (KeyError, RuntimeError) as error: + print(" > Criterion load ignored because of:", error) + + if use_cuda: + self.cuda() + if criterion is not None: + criterion = criterion.cuda() + + if eval: + self.eval() + assert not self.training + + if not eval: + return criterion, state["step"] + return criterion + +class HifiDecoder(torch.nn.Module): + def __init__( + self, + input_sample_rate=22050, + output_sample_rate=24000, + output_hop_length=256, + ar_mel_length_compression=1024, + decoder_input_dim=1024, + resblock_type_decoder="1", + resblock_dilation_sizes_decoder=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + resblock_kernel_sizes_decoder=[3, 7, 11], + upsample_rates_decoder=[8, 8, 2, 2], + upsample_initial_channel_decoder=512, + upsample_kernel_sizes_decoder=[16, 16, 4, 4], + d_vector_dim=512, + cond_d_vector_in_each_upsampling_layer=True, + speaker_encoder_audio_config={ + "fft_size": 512, + "win_length": 400, + "hop_length": 160, + "sample_rate": 16000, + "preemphasis": 0.97, + "num_mels": 64, + }, + ): + super().__init__() + self.input_sample_rate = input_sample_rate + self.output_sample_rate = output_sample_rate + self.output_hop_length = output_hop_length + self.ar_mel_length_compression = ar_mel_length_compression + self.speaker_encoder_audio_config = speaker_encoder_audio_config + self.waveform_decoder = HifiganGenerator( + decoder_input_dim, + 1, + resblock_type_decoder, + resblock_dilation_sizes_decoder, + resblock_kernel_sizes_decoder, + upsample_kernel_sizes_decoder, + upsample_initial_channel_decoder, + upsample_rates_decoder, + inference_padding=0, + cond_channels=d_vector_dim, + conv_pre_weight_norm=False, + conv_post_weight_norm=False, + conv_post_bias=False, + cond_in_each_up_layer=cond_d_vector_in_each_upsampling_layer, + ) + self.speaker_encoder = ResNetSpeakerEncoder( + input_dim=64, + proj_dim=512, + log_input=True, + use_torch_spec=True, + audio_config=speaker_encoder_audio_config, + ) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, latents, g=None): + """ + Args: + x (Tensor): feature input tensor (GPT latent). + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + + z = torch.nn.functional.interpolate( + latents.transpose(1, 2), + scale_factor=[self.ar_mel_length_compression / self.output_hop_length], + mode="linear", + ).squeeze(1) + # upsample to the right sr + if self.output_sample_rate != self.input_sample_rate: + z = torch.nn.functional.interpolate( + z, + scale_factor=[self.output_sample_rate / self.input_sample_rate], + mode="linear", + ).squeeze(0) + o = self.waveform_decoder(z, g=g) + return o + + @torch.no_grad() + def inference(self, c, g): + """ + Args: + x (Tensor): feature input tensor (GPT latent). + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + return self.forward(c, g=g) + + def load_checkpoint( + self, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + # remove unused keys + state = state["model"] + states_keys = list(state.keys()) + for key in states_keys: + if "waveform_decoder." not in key and "speaker_encoder." not in key: + del state[key] + + self.load_state_dict(state) + if eval: + self.eval() + assert not self.training + self.waveform_decoder.remove_weight_norm() diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py new file mode 100644 index 0000000000..8bdd2291ff --- /dev/null +++ b/TTS/tts/layers/xtts/stream_generator.py @@ -0,0 +1,1057 @@ +# Adapted from: https://github.com/LowinLi/transformers-stream-generator + +from transformers import ( + GenerationConfig, + GenerationMixin, + LogitsProcessorList, + StoppingCriteriaList, + DisjunctiveConstraint, + BeamSearchScorer, + PhrasalConstraint, + ConstrainedBeamSearchScorer, + PreTrainedModel, +) +import numpy as np +import random +import warnings +import inspect +from transformers.generation.utils import GenerateOutput, SampleOutput, logger +import torch +from typing import Callable, List, Optional, Union +from torch import nn +import torch.distributed as dist +import copy + + +def setup_seed(seed): + if seed == -1: + return + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + +class StreamGenerationConfig(GenerationConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.do_stream = kwargs.pop("do_stream", False) + + +class NewGenerationMixin(GenerationMixin): + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[StreamGenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[ + Callable[[int, torch.Tensor], List[int]] + ] = None, + synced_gpus: Optional[bool] = False, + seed=0, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + r""" + + Generates sequences of token ids for models with a language modeling head. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + kwargs: + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchDecoderOnlyOutput`], + - [`~generation.SampleDecoderOnlyOutput`], + - [`~generation.BeamSearchDecoderOnlyOutput`], + - [`~generation.BeamSampleDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchEncoderDecoderOutput`], + - [`~generation.SampleEncoderDecoderOutput`], + - [`~generation.BeamSearchEncoderDecoderOutput`], + - [`~generation.BeamSampleEncoderDecoderOutput`] + """ + #setup_seed(seed) + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + self._validate_model_class() + + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + if generation_config is None: + # legacy: users may modify the model configuration to control generation -- update the generation config + # model attribute accordingly, if it was created from the model config + if self.generation_config._from_model_config: + new_generation_config = StreamGenerationConfig.from_model_config( + self.config + ) + if new_generation_config != self.generation_config: + warnings.warn( + "You have modified the pretrained model configuration to control generation. This is a" + " deprecated strategy to control generation and will be removed soon, in a future version." + " Please use a generation configuration file (see" + " https://huggingface.co/docs/transformers/main_classes/text_generation)" + ) + self.generation_config = new_generation_config + generation_config = self.generation_config + + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update( + **kwargs + ) # All unused kwargs must be model kwargs + # self._validate_model_kwargs(model_kwargs.copy()) + + # 2. Set generation parameters if not already defined + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + + if ( + generation_config.pad_token_id is None + and generation_config.eos_token_id is not None + ): + if model_kwargs.get("attention_mask", None) is None: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + logger.warning( + f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation." + ) + generation_config.pad_token_id = eos_token_id + + # 3. Define model inputs + # inputs_tensor has to be defined + # model_input_name is defined if model-specific keyword input is passed + # otherwise model_input_name is None + # all model-specific keyword inputs are removed from `model_kwargs` + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + + # 4. Define other model kwargs + model_kwargs["output_attentions"] = generation_config.output_attentions + model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + model_kwargs["use_cache"] = generation_config.use_cache + + accepts_attention_mask = "attention_mask" in set( + inspect.signature(self.forward).parameters.keys() + ) + requires_attention_mask = "encoder_outputs" not in model_kwargs + + if ( + model_kwargs.get("attention_mask", None) is None + and requires_attention_mask + and accepts_attention_mask + ): + model_kwargs[ + "attention_mask" + ] = self._prepare_attention_mask_for_generation( + inputs_tensor, + generation_config.pad_token_id, + generation_config.eos_token_id, + ) + + # decoder-only models should use left-padding for generation + if not self.config.is_encoder_decoder: + if ( + generation_config.pad_token_id is not None + and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) + > 0 + ): + logger.warning( + "A decoder-only architecture is being used, but right-padding was detected! For correct " + "generation results, please set `padding_side='left'` when initializing the tokenizer." + ) + + if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created + # and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name + ) + + # 5. Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + input_ids = self._prepare_decoder_input_ids_for_generation( + batch_size, + decoder_start_token_id=generation_config.decoder_start_token_id, + bos_token_id=generation_config.bos_token_id, + model_kwargs=model_kwargs, + device=inputs_tensor.device, + ) + else: + # if decoder-only then inputs_tensor has to be `input_ids` + input_ids = inputs_tensor + + # 6. Prepare `max_length` depending on other stopping criteria. + input_ids_seq_length = input_ids.shape[-1] + has_default_max_length = ( + kwargs.get("max_length") is None + and generation_config.max_length is not None + ) + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to" + f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the" + " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif has_default_max_length and generation_config.max_new_tokens is not None: + generation_config.max_length = ( + generation_config.max_new_tokens + input_ids_seq_length + ) + elif ( + not has_default_max_length and generation_config.max_new_tokens is not None + ): + raise ValueError( + "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" + " limit to the generated output length. Remove one of those arguments. Please refer to the" + " documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + + if ( + generation_config.min_length is not None + and generation_config.min_length > generation_config.max_length + ): + raise ValueError( + f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" + f" the maximum length ({generation_config.max_length})" + ) + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = ( + "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + ) + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 7. determine generation mode + is_constraint_gen_mode = ( + generation_config.constraints is not None + or generation_config.force_words_ids is not None + ) + + is_contrastive_search_gen_mode = ( + generation_config.top_k is not None + and generation_config.top_k > 1 + and generation_config.do_sample is False + and generation_config.penalty_alpha is not None + and generation_config.penalty_alpha > 0 + ) + + is_greedy_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_sample_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + and generation_config.do_stream is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_sample_gen_stream_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_stream is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_beam_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_beam_sample_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_group_beam_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups > 1) + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + + if generation_config.num_beam_groups > generation_config.num_beams: + raise ValueError( + "`num_beam_groups` has to be smaller or equal to `num_beams`" + ) + if is_group_beam_gen_mode and generation_config.do_sample is True: + raise ValueError( + "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." + ) + + if self.device.type != input_ids.device.type: + warnings.warn( + "You are calling .generate() with the `input_ids` being on a device type different" + f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" + f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." + " Please make sure that you have put `input_ids` to the" + f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" + " running `.generate()`.", + UserWarning, + ) + # 8. prepare distribution pre_processing samplers + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + # 9. prepare stopping criteria + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + # 10. go into different generation modes + if is_greedy_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " greedy search." + ) + + # 11. run greedy search + return self.greedy_search( + input_ids, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_contrastive_search_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " contrastive search." + ) + + return self.contrastive_search( + input_ids, + top_k=generation_config.top_k, + penalty_alpha=generation_config.penalty_alpha, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_sample_gen_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. run sample + return self.sample( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + elif is_sample_gen_stream_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. run sample + return self.sample_stream( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + elif is_beam_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_beam_sample_gen_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + # 12. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size * generation_config.num_return_sequences, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + ) + + # 13. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams + * generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 14. run beam sample + return self.beam_sample( + input_ids, + beam_scorer, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_group_beam_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if generation_config.num_beams % generation_config.num_beam_groups != 0: + raise ValueError( + "`num_beams` should be divisible by `num_beam_groups` for group beam search." + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + has_default_typical_p = ( + kwargs.get("typical_p") is None and generation_config.typical_p == 1.0 + ) + if not has_default_typical_p: + raise ValueError( + "Decoder argument `typical_p` is not supported with beam groups." + ) + + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + max_length=stopping_criteria.max_length, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + num_beam_groups=generation_config.num_beam_groups, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.group_beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_constraint_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + if generation_config.num_beams <= 1: + raise ValueError( + "`num_beams` needs to be greater than 1 for constrained generation." + ) + + if generation_config.do_sample: + raise ValueError( + "`do_sample` needs to be false for constrained generation." + ) + + if ( + generation_config.num_beam_groups is not None + and generation_config.num_beam_groups > 1 + ): + raise ValueError( + "`num_beam_groups` not supported yet for constrained generation." + ) + + final_constraints = [] + if generation_config.constraints is not None: + final_constraints = generation_config.constraints + + if generation_config.force_words_ids is not None: + + def typeerror(): + raise ValueError( + "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" + f"of positive integers, but is {generation_config.force_words_ids}." + ) + + if ( + not isinstance(generation_config.force_words_ids, list) + or len(generation_config.force_words_ids) == 0 + ): + typeerror() + + for word_ids in generation_config.force_words_ids: + if isinstance(word_ids[0], list): + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any( + not isinstance(token_ids, list) for token_ids in word_ids + ): + typeerror() + if any( + any( + (not isinstance(token_id, int) or token_id < 0) + for token_id in token_ids + ) + for token_ids in word_ids + ): + typeerror() + + constraint = DisjunctiveConstraint(word_ids) + else: + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any( + (not isinstance(token_id, int) or token_id < 0) + for token_id in word_ids + ): + typeerror() + + constraint = PhrasalConstraint(word_ids) + final_constraints.append(constraint) + + # 11. prepare beam search scorer + constrained_beam_scorer = ConstrainedBeamSearchScorer( + constraints=final_constraints, + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.constrained_beam_search( + input_ids, + constrained_beam_scorer=constrained_beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + @torch.no_grad() + def sample_stream( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[SampleOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + + + In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead. + For an overview of generation strategies and code examples, check the [following + guide](./generation_strategies). + + + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... TopKLogitsWarper, + ... TemperatureLogitsWarper, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token + >>> model.config.pad_token_id = model.config.eos_token_id + >>> model.generation_config.pad_token_id = model.config.eos_token_id + + >>> input_prompt = "Today is a beautiful day, and" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id), + ... ] + ... ) + >>> # instantiate logits processors + >>> logits_warper = LogitsProcessorList( + ... [ + ... TopKLogitsWarper(50), + ... TemperatureLogitsWarper(0.7), + ... ] + ... ) + + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> outputs = model.sample( + ... input_ids, + ... logits_processor=logits_processor, + ... logits_warper=logits_warper, + ... stopping_criteria=stopping_criteria, + ... ) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] + ```""" + # init values + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length + ) + logits_warper = ( + logits_warper if logits_warper is not None else LogitsProcessorList() + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + eos_token_id = ( + eos_token_id + if eos_token_id is not None + else self.generation_config.eos_token_id + ) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None + ) + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + this_peer_finished = False # used by synced_gpus only + # auto-regressive generation + while True: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0 + ).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError( + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." + ) + next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( + 1 - unfinished_sequences + ) + yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1]) + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul( + (sum(next_tokens != i for i in eos_token_id)).long() + ) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + +def init_stream_support(): + """Overload PreTrainedModel for streaming.""" + PreTrainedModel.generate_stream = NewGenerationMixin.generate + PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream + + +if __name__ == "__main__": + from transformers import PreTrainedModel + from transformers import AutoTokenizer, AutoModelForCausalLM + + PreTrainedModel.generate = NewGenerationMixin.generate + PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream + model = AutoModelForCausalLM.from_pretrained( + "bigscience/bloom-560m", torch_dtype=torch.float16 + ) + + tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") + model = model.to("cuda:0") + model = model.eval() + prompt_text = "hello? \n" + input_ids = tokenizer( + prompt_text, return_tensors="pt", add_special_tokens=False + ).input_ids + input_ids = input_ids.to("cuda:0") + + with torch.no_grad(): + result = model.generate( + input_ids, + max_new_tokens=200, + do_sample=True, + top_k=30, + top_p=0.85, + temperature=0.35, + repetition_penalty=1.2, + early_stopping=True, + seed=0, + ) + print(tokenizer.decode(result, skip_special_tokens=True)) + generator = model.generate( + input_ids, + max_new_tokens=200, + do_sample=True, + top_k=30, + top_p=0.85, + temperature=0.35, + repetition_penalty=1.2, + early_stopping=True, + seed=0, + do_stream=True, + ) + stream_result = "" + for x in generator: + chunk = tokenizer.decode(x, skip_special_tokens=True) + stream_result += chunk + print(stream_result) diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index 8dd81facab..a279528925 100644 --- a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -224,7 +224,10 @@ def preprocess_text(self, txt, lang): txt = " ".join([result["kana"] for result in results]) txt = basic_cleaners(txt) elif lang == "en": + if txt[:4] == "[en]": + txt = txt[4:] txt = english_cleaners(txt) + txt = "[en]" + txt elif lang == "ar": txt = arabic_cleaners(txt) elif lang == "zh-cn": diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index a23a0f5f65..2b48074477 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -13,9 +13,12 @@ from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer from TTS.tts.layers.xtts.vocoder import UnivNetGenerator +from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder +from TTS.tts.layers.xtts.stream_generator import init_stream_support from TTS.tts.models.base_tts import BaseTTS from TTS.utils.io import load_fsspec +init_stream_support() def load_audio(audiopath, sr=22050): """ @@ -195,13 +198,12 @@ class XttsArgs(Coqpit): Args: gpt_batch_size (int): The size of the auto-regressive batch. enable_redaction (bool, optional): Whether to enable redaction. Defaults to True. - lazy_load (bool, optional): Whether to load models on demand. It reduces VRAM usage. Defaults to False. kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True. gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None. clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None. decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None. num_chars (int, optional): The maximum number of characters to generate. Defaults to 255. - vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet. + use_hifigan (bool, optional): Whether to use hifigan or diffusion + univnet as a decoder. Defaults to True. For GPT model: ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604. @@ -231,12 +233,12 @@ class XttsArgs(Coqpit): gpt_batch_size: int = 1 enable_redaction: bool = False - lazy_load: bool = True kv_cache: bool = True gpt_checkpoint: str = None clvp_checkpoint: str = None decoder_checkpoint: str = None num_chars: int = 255 + use_hifigan: bool = True # XTTS GPT Encoder params tokenizer_file: str = "" @@ -266,6 +268,15 @@ class XttsArgs(Coqpit): diff_layer_drop: int = 0 diff_unconditioned_percentage: int = 0 + # HifiGAN Decoder params + input_sample_rate: int = 22050 + output_sample_rate: int = 24000 + output_hop_length: int = 256 + ar_mel_length_compression: int = 1024 + decoder_input_dim: int = 1024 + d_vector_dim: int = 512 + cond_d_vector_in_each_upsampling_layer: bool = True + # constants duration_const: int = 102400 @@ -285,7 +296,6 @@ class Xtts(BaseTTS): def __init__(self, config: Coqpit): super().__init__(config, ap=None, tokenizer=None) - self.lazy_load = self.args.lazy_load self.mel_stats_path = None self.config = config self.gpt_checkpoint = self.args.gpt_checkpoint @@ -295,7 +305,6 @@ def __init__(self, config: Coqpit): self.tokenizer = VoiceBpeTokenizer() self.gpt = None - self.diffusion_decoder = None self.init_models() self.register_buffer("mel_stats", torch.ones(80)) @@ -322,40 +331,39 @@ def init_models(self): stop_audio_token=self.args.gpt_stop_audio_token, ) - self.diffusion_decoder = DiffusionTts( - model_channels=self.args.diff_model_channels, - num_layers=self.args.diff_num_layers, - in_channels=self.args.diff_in_channels, - out_channels=self.args.diff_out_channels, - in_latent_channels=self.args.diff_in_latent_channels, - in_tokens=self.args.diff_in_tokens, - dropout=self.args.diff_dropout, - use_fp16=self.args.diff_use_fp16, - num_heads=self.args.diff_num_heads, - layer_drop=self.args.diff_layer_drop, - unconditioned_percentage=self.args.diff_unconditioned_percentage, - ) - self.vocoder = UnivNetGenerator() + if self.args.use_hifigan: + self.hifigan_decoder = HifiDecoder( + input_sample_rate=self.args.input_sample_rate, + output_sample_rate=self.args.output_sample_rate, + output_hop_length=self.args.output_hop_length, + ar_mel_length_compression=self.args.ar_mel_length_compression, + decoder_input_dim=self.args.decoder_input_dim, + d_vector_dim=self.args.d_vector_dim, + cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer, + ) + + else: + self.diffusion_decoder = DiffusionTts( + model_channels=self.args.diff_model_channels, + num_layers=self.args.diff_num_layers, + in_channels=self.args.diff_in_channels, + out_channels=self.args.diff_out_channels, + in_latent_channels=self.args.diff_in_latent_channels, + in_tokens=self.args.diff_in_tokens, + dropout=self.args.diff_dropout, + use_fp16=self.args.diff_use_fp16, + num_heads=self.args.diff_num_heads, + layer_drop=self.args.diff_layer_drop, + unconditioned_percentage=self.args.diff_unconditioned_percentage, + ) + self.vocoder = UnivNetGenerator() @property def device(self): return next(self.parameters()).device - @contextmanager - def lazy_load_model(self, model): - """Context to load a model on demand. - - Args: - model (nn.Module): The model to be loaded. - """ - if self.lazy_load: - yield model - else: - m = model.to(self.device) - yield m - m = model.cpu() - + @torch.inference_mode() def get_gpt_cond_latents(self, audio_path: str, length: int = 3): """Compute the conditioning latents for the GPT model from the given audio. @@ -370,6 +378,7 @@ def get_gpt_cond_latents(self, audio_path: str, length: int = 3): cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False) return cond_latent.transpose(1, 2) + @torch.inference_mode() def get_diffusion_cond_latents( self, audio_path, @@ -389,20 +398,33 @@ def get_diffusion_cond_latents( ) diffusion_conds.append(cond_mel) diffusion_conds = torch.stack(diffusion_conds, dim=1) - with self.lazy_load_model(self.diffusion_decoder) as diffusion: - diffusion_latent = diffusion.get_conditioning(diffusion_conds) + diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds) return diffusion_latent + @torch.inference_mode() + def get_speaker_embedding( + self, + audio_path + ): + audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"]) + speaker_embedding = self.hifigan_decoder.speaker_encoder.forward( + audio.to(self.device), l2_norm=True + ).unsqueeze(-1).to(self.device) + return speaker_embedding + def get_conditioning_latents( self, audio_path, gpt_cond_len=3, - ): + ): + speaker_embedding = None + diffusion_cond_latents = None + if self.args.use_hifigan: + speaker_embedding = self.get_speaker_embedding(audio_path) + else: + diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path) gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len) # [1, 1024, T] - diffusion_cond_latents = self.get_diffusion_cond_latents( - audio_path, - ) - return gpt_cond_latents.to(self.device), diffusion_cond_latents.to(self.device) + return gpt_cond_latents, diffusion_cond_latents, speaker_embedding def synthesize(self, text, config, speaker_wav, language, **kwargs): """Synthesize speech with the given input text. @@ -447,10 +469,10 @@ def inference_with_config(self, text, config, ref_audio_path, language, **kwargs "decoder_sampler": config.decoder_sampler, } settings.update(kwargs) # allow overriding of preset settings with kwargs - return self.inference(text, ref_audio_path, language, **settings) + return self.full_inference(text, ref_audio_path, language, **settings) - @torch.no_grad() - def inference( + @torch.inference_mode() + def full_inference( self, text, ref_audio_path, @@ -525,81 +547,202 @@ def inference( Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. Sample rate is 24kHz. """ - text = f"[{language}]{text.strip().lower()}" - text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) - - assert ( - text_tokens.shape[-1] < self.args.gpt_max_text_tokens - ), " ❗ XTTS can only generate text with a maximum of 400 tokens." - ( gpt_cond_latent, diffusion_conditioning, + speaker_embedding ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len) - - diffuser = load_discrete_vocoder_diffuser( - desired_diffusion_steps=decoder_iterations, + return self.inference( + text, + language, + gpt_cond_latent, + speaker_embedding, + diffusion_conditioning, + temperature=temperature, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + top_k=top_k, + top_p=top_p, + do_sample=do_sample, + decoder_iterations=decoder_iterations, cond_free=cond_free, cond_free_k=cond_free_k, - sampler=decoder_sampler, + diffusion_temperature=diffusion_temperature, + decoder_sampler=decoder_sampler, + **hf_generate_kwargs, ) + + @torch.inference_mode() + def inference( + self, + text, + language, + gpt_cond_latent, + speaker_embedding, + diffusion_conditioning, + # GPT inference + temperature=0.65, + length_penalty=1, + repetition_penalty=2.0, + top_k=50, + top_p=0.85, + do_sample=True, + # Decoder inference + decoder_iterations=100, + cond_free=True, + cond_free_k=2, + diffusion_temperature=1.0, + decoder_sampler="ddim", + **hf_generate_kwargs, + ): + text = f"[{language}]{text.strip().lower()}" + text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) - with torch.no_grad(): - self.gpt = self.gpt.to(self.device) - with self.lazy_load_model(self.gpt) as gpt: - gpt_codes = gpt.generate( - cond_latents=gpt_cond_latent, - text_inputs=text_tokens, - input_tokens=None, - do_sample=do_sample, - top_p=top_p, - top_k=top_k, - temperature=temperature, - num_return_sequences=self.gpt_batch_size, - length_penalty=length_penalty, - repetition_penalty=repetition_penalty, - output_attentions=False, - **hf_generate_kwargs, - ) + assert ( + text_tokens.shape[-1] < self.args.gpt_max_text_tokens + ), " ❗ XTTS can only generate text with a maximum of 400 tokens." - with self.lazy_load_model(self.gpt) as gpt: - expected_output_len = torch.tensor( - [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device - ) - text_len = torch.tensor([text_tokens.shape[-1]], device=self.device) - gpt_latents = gpt( - text_tokens, - text_len, - gpt_codes, - expected_output_len, - cond_latents=gpt_cond_latent, - return_attentions=False, - return_latent=True, - ) - silence_token = 83 - ctokens = 0 - for k in range(gpt_codes.shape[-1]): - if gpt_codes[0, k] == silence_token: - ctokens += 1 - else: - ctokens = 0 - if ctokens > 8: - gpt_latents = gpt_latents[:, :k] - break - - with self.lazy_load_model(self.diffusion_decoder) as diffusion: + if not self.args.use_hifigan: + diffuser = load_discrete_vocoder_diffuser( + desired_diffusion_steps=decoder_iterations, + cond_free=cond_free, + cond_free_k=cond_free_k, + sampler=decoder_sampler, + ) + + with torch.no_grad(): + gpt_codes = self.gpt.generate( + cond_latents=gpt_cond_latent, + text_inputs=text_tokens, + input_tokens=None, + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_return_sequences=self.gpt_batch_size, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + output_attentions=False, + **hf_generate_kwargs, + ) + expected_output_len = torch.tensor( + [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device + ) + text_len = torch.tensor([text_tokens.shape[-1]], device=self.device) + gpt_latents = self.gpt( + text_tokens, + text_len, + gpt_codes, + expected_output_len, + cond_latents=gpt_cond_latent, + return_attentions=False, + return_latent=True, + ) + silence_token = 83 + ctokens = 0 + for k in range(gpt_codes.shape[-1]): + if gpt_codes[0, k] == silence_token: + ctokens += 1 + else: + ctokens = 0 + if ctokens > 8: + gpt_latents = gpt_latents[:, :k] + break + + if self.args.use_hifigan: + wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding) + else: mel = do_spectrogram_diffusion( - diffusion, + self.diffusion_decoder, diffuser, gpt_latents, diffusion_conditioning, temperature=diffusion_temperature, ) - with self.lazy_load_model(self.vocoder) as vocoder: - wav = vocoder.inference(mel) + wav = self.vocoder.inference(mel) return {"wav": wav.cpu().numpy().squeeze()} + def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len): + """Handle chunk formatting in streaming mode""" + wav_chunk = wav_gen[:-overlap_len] + if wav_gen_prev is not None: + wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len] + if wav_overlap is not None: + crossfade_wav = wav_chunk[:overlap_len] + crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device) + wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device) + wav_chunk[:overlap_len] += crossfade_wav + wav_overlap = wav_gen[-overlap_len:] + wav_gen_prev = wav_gen + return wav_chunk, wav_gen_prev, wav_overlap + + @torch.inference_mode() + def inference_stream( + self, + text, + language, + gpt_cond_latent, + speaker_embedding, + # Streaming + stream_chunk_size=20, + overlap_wav_len=1024, + # GPT inference + temperature=0.65, + length_penalty=1, + repetition_penalty=2.0, + top_k=50, + top_p=0.85, + do_sample=True, + # Decoder inference + **hf_generate_kwargs, + ): + assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream." + text = f"[{language}]{text.strip().lower()}" + text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) + + fake_inputs = self.gpt.compute_embeddings( + gpt_cond_latent.to(self.device), + text_tokens, + ) + gpt_generator = self.gpt.get_generator( + fake_inputs=fake_inputs, + top_k=top_k, + top_p=top_p, + temperature=temperature, + do_sample=do_sample, + num_beams=1, + num_return_sequences=1, + length_penalty=float(length_penalty), + repetition_penalty=float(repetition_penalty), + output_attentions=False, + output_hidden_states=True, + **hf_generate_kwargs, + ) + + last_tokens = [] + all_latents = [] + wav_gen_prev = None + wav_overlap = None + is_end = False + + while not is_end: + try: + x, latent = next(gpt_generator) + last_tokens += [x] + all_latents += [latent] + except StopIteration: + is_end = True + + if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size): + gpt_latents = torch.cat(all_latents, dim=0)[None, :] + wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) + wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( + wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len + ) + last_tokens = [] + yield wav_chunk + def forward(self): raise NotImplementedError("XTTS Training is not implemented") @@ -616,7 +759,14 @@ def eval(self): # pylint: disable=redefined-builtin super().eval() def load_checkpoint( - self, config, checkpoint_dir=None, checkpoint_path=None, vocab_path=None, eval=False, strict=True + self, + config, + checkpoint_dir=None, + checkpoint_path=None, + vocab_path=None, + eval=True, + strict=True, + use_deepspeed=False, ): """ Loads a checkpoint from disk and initializes the model's state and tokenizer. @@ -626,7 +776,7 @@ def load_checkpoint( checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None. checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None. vocab_path (str, optional): The path to the vocabulary file. Defaults to None. - eval (bool, optional): Whether to set the model to evaluation mode. Defaults to False. + eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True. strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True. Returns: @@ -636,19 +786,26 @@ def load_checkpoint( model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth") vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json") - if os.path.exists(os.path.join(checkpoint_dir, "vocab.json")): - self.tokenizer = VoiceBpeTokenizer(vocab_file=os.path.join(checkpoint_dir, "vocab.json")) + if os.path.exists(vocab_path): + self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path) self.init_models() if eval: self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache) - self.load_state_dict(load_fsspec(model_path, map_location=self.device)["model"], strict=strict) + + checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"] + ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else ["hifigan_decoder"] + for key in list(checkpoint.keys()): + if key.split(".")[0] in ignore_keys: + del checkpoint[key] + self.load_state_dict(checkpoint, strict=strict) if eval: - self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache) + if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval() + if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval() + if hasattr(self, "vocoder"): self.vocoder.eval() + self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed) self.gpt.eval() - self.diffusion_decoder.eval() - self.vocoder.eval() def train_step(self): raise NotImplementedError("XTTS Training is not implemented") diff --git a/docs/source/models/xtts.md b/docs/source/models/xtts.md index 85a3afbabf..ff6bcf974a 100644 --- a/docs/source/models/xtts.md +++ b/docs/source/models/xtts.md @@ -28,7 +28,8 @@ This model is licensed under [Coqui Public Model License](https://coqui.ai/cpml) Come and join in our 🐸Community. We're active on [Discord](https://discord.gg/fBC58unbKE) and [Twitter](https://twitter.com/coqui_ai). You can also mail us at info@coqui.ai. -Using 🐸TTS API: +### Inference +#### 🐸TTS API ```python from TTS.api import TTS @@ -39,16 +40,9 @@ tts.tts_to_file(text="It took me quite a long time to develop a voice, and now t file_path="output.wav", speaker_wav="/path/to/target/speaker.wav", language="en") - -# generate speech by cloning a voice using custom settings -tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", - file_path="output.wav", - speaker_wav="/path/to/target/speaker.wav", - language="en", - decoder_iterations=30) ``` -Using 🐸TTS Command line: +#### 🐸TTS Command line ```console tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 \ @@ -58,25 +52,85 @@ Using 🐸TTS Command line: --use_cuda true ``` -Using model directly: +#### model directly + +If you want to be able to run with `use_deepspeed=True` and enjoy the speedup, you need to install deepspeed first. + +```console +pip install deepspeed==0.8.3 +``` + +```python +import os +import torch +import torchaudio +from TTS.tts.configs.xtts_config import XttsConfig +from TTS.tts.models.xtts import Xtts + +print("Loading model...") +config = XttsConfig() +config.load_json("/path/to/xtts/config.json") +model = Xtts.init_from_config(config) +model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True) +model.cuda() + +print("Computing speaker latents...") +gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav") + +print("Inference...") +out = model.inference( + "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.", + "en", + gpt_cond_latent, + speaker_embedding, + diffusion_conditioning, + temperature=0.7, # Add custom parameters here +) +torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000) +``` + + +#### streaming inference + +Here the goal is to stream the audio as it is being generated. This is useful for real-time applications. +Streaming inference is typically slower than regular inference, but it allows to get a first chunk of audio faster. + ```python +import os +import time +import torch +import torchaudio from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.models.xtts import Xtts +print("Loading model...") config = XttsConfig() config.load_json("/path/to/xtts/config.json") model = Xtts.init_from_config(config) -model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True) +model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True) model.cuda() -outputs = model.synthesize( +print("Computing speaker latents...") +gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav") + +print("Inference...") +t0 = time.time() +chunks = model.inference_stream( "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.", - config, - speaker_wav="/data/TTS-public/_refclips/3.wav", - gpt_cond_len=3, - language="en", + "en", + gpt_cond_latent, + speaker_embedding ) + +wav_chuncks = [] +for i, chunk in enumerate(chunks): + if i == 0: + print(f"Time to first chunck: {time.time() - t0}") + print(f"Received chunk {i} of audio length {chunk.shape[-1]}") + wav_chuncks.append(chunk) +wav = torch.cat(wav_chuncks, dim=0) +torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000) ``` diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py index 9c62827641..dc16d7932f 100644 --- a/tests/zoo_tests/test_models.py +++ b/tests/zoo_tests/test_models.py @@ -93,6 +93,34 @@ def test_xtts(): f'--speaker_wav "{speaker_wav}" --language_idx "en"' ) +def test_xtts_streaming(): + """Testing the new inference_stream method""" + from TTS.tts.configs.xtts_config import XttsConfig + from TTS.tts.models.xtts import Xtts + speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav") + model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1") + config = XttsConfig() + config.load_json(os.path.join(model_path, "config.json")) + model = Xtts.init_from_config(config) + model.load_checkpoint(config, checkpoint_dir=model_path) + model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) + + print("Computing speaker latents...") + gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav) + + print("Inference...") + chunks = model.inference_stream( + "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.", + "en", + gpt_cond_latent, + speaker_embedding + ) + wav_chuncks = [] + for i, chunk in enumerate(chunks): + if i == 0: + assert chunk.shape[-1] > 5000 + wav_chuncks.append(chunk) + assert len(wav_chuncks) > 1 def test_tortoise(): output_path = os.path.join(get_tests_output_path(), "output.wav")