diff --git a/deployment/exporters/variance_exporter.py b/deployment/exporters/variance_exporter.py
index 939c42f3..db16f2e4 100644
--- a/deployment/exporters/variance_exporter.py
+++ b/deployment/exporters/variance_exporter.py
@@ -20,6 +20,7 @@ def __init__(
             device: Union[str, torch.device] = 'cpu',
             cache_dir: Path = None,
             ckpt_steps: int = None,
+            expose_expr: bool = False,
             export_spk: List[Tuple[str, Dict[str, float]]] = None,
             freeze_spk: Tuple[str, Dict[str, float]] = None
     ):
@@ -58,6 +59,7 @@ def __init__(
             if self.model.predict_variances else None
 
         # Attributes for exporting
+        self.expose_expr = expose_expr
         self.freeze_spk: Tuple[str, Dict[str, float]] = freeze_spk \
             if hparams['use_spk_id'] else None
         self.export_spk: List[Tuple[str, Dict[str, float]]] = export_spk \
@@ -264,6 +266,7 @@
                 note_midi,
                 note_dur,
                 pitch,
+                *([torch.ones_like(pitch)] if self.expose_expr else []),
                 retake,
                 *([torch.rand(
                     1, 15, hparams['hidden_size'],
@@ -274,7 +277,9 @@
             input_names=[
                 'encoder_out', 'ph_dur', 'note_midi', 'note_dur',
-                'pitch', 'retake',
+                'pitch',
+                *(['expr'] if self.expose_expr else []),
+                'retake',
                 *(['spk_embed'] if input_spk_embed else [])
             ],
             output_names=[
@@ -293,6 +298,7 @@
                 'note_dur': {
                     1: 'n_notes'
                 },
+                **({'expr': {1: 'n_frames'}} if self.expose_expr else {}),
                 'pitch': {
                     1: 'n_frames'
                 },
diff --git a/deployment/modules/toplevel.py b/deployment/modules/toplevel.py
index 6ee081ec..2cbbda8f 100644
--- a/deployment/modules/toplevel.py
+++ b/deployment/modules/toplevel.py
@@ -162,14 +162,25 @@ def forward_mel2x_gather(self, x_src, x_dur, x_dim=None):
 
     def forward_pitch_preprocess(
             self, encoder_out, ph_dur, note_midi, note_dur,
-            pitch=None, retake=None, spk_embed=None
+            pitch=None, expr=None, retake=None, spk_embed=None
     ):
         condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
-        condition += self.pitch_retake_embed(retake.long())
+        if expr is None:
+            retake_embed = self.pitch_retake_embed(retake.long())
+        else:
+            retake_true_embed = self.pitch_retake_embed(
+                torch.ones(1, 1, dtype=torch.long, device=encoder_out.device)
+            )  # [B=1, T=1] => [B=1, T=1, H]
+            retake_false_embed = self.pitch_retake_embed(
+                torch.zeros(1, 1, dtype=torch.long, device=encoder_out.device)
+            )  # [B=1, T=1] => [B=1, T=1, H]
+            expr = (expr * retake)[:, :, None]  # [B, T, 1]
+            retake_embed = expr * retake_true_embed + (1. - expr) * retake_false_embed
+        pitch_cond = condition + retake_embed
         frame_midi_pitch = self.forward_mel2x_gather(note_midi, note_dur, x_dim=None)
         base_pitch = self.smooth(frame_midi_pitch)
         base_pitch = base_pitch * retake + pitch * ~retake
-        pitch_cond = condition + self.base_pitch_embed(base_pitch[:, :, None])
+        pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
         if hparams['use_spk_id'] and spk_embed is not None:
             pitch_cond += spk_embed
         return pitch_cond, base_pitch
diff --git a/inference/ds_variance.py b/inference/ds_variance.py
index 41d67f82..fdc99929 100644
--- a/inference/ds_variance.py
+++ b/inference/ds_variance.py
@@ -212,6 +212,22 @@ def preprocess_input(
                 summary['pitch'] = 'manual'
             elif self.auto_completion_mode or self.global_predict_pitch:
                 summary['pitch'] = 'auto'
+
+            # Load expressiveness
+            expr = param.get('expr', 1.)
+            if isinstance(expr, (int, float, bool)):
+                summary['expr'] = f'static({expr:.3f})'
+                batch['expr'] = torch.FloatTensor([expr]).to(self.device)[:, None]  # [B=1, T=1]
+            else:
+                summary['expr'] = 'dynamic'
+                expr = resample_align_curve(
+                    np.array(expr.split(), np.float32),
+                    original_timestep=float(param['expr_timestep']),
+                    target_timestep=self.timestep,
+                    align_length=T_s
+                )
+                batch['expr'] = torch.from_numpy(expr.astype(np.float32)).to(self.device)[None]
+
         else:
             summary['pitch'] = 'ignored'
 
@@ -235,6 +251,7 @@ def forward_model(self, sample):
         ph_dur = sample['ph_dur']
         mel2ph = sample['mel2ph']
         base_pitch = sample['base_pitch']
+        expr = sample.get('expr')
         pitch = sample.get('pitch')
 
         if hparams['use_spk_id']:
@@ -255,7 +272,7 @@ def forward_model(self, sample):
         dur_pred, pitch_pred, variance_pred = self.model(
             txt_tokens, midi=midi, ph2word=ph2word, word_dur=word_dur, ph_dur=ph_dur,
-            mel2ph=mel2ph, base_pitch=base_pitch, pitch=pitch,
+            mel2ph=mel2ph, base_pitch=base_pitch, pitch=pitch, pitch_expr=expr,
             ph_spk_mix_embed=ph_spk_mix_embed, spk_mix_embed=spk_mix_embed,
             infer=True
         )
 
diff --git a/modules/toplevel.py b/modules/toplevel.py
index a93ed1e3..01c52c09 100644
--- a/modules/toplevel.py
+++ b/modules/toplevel.py
@@ -114,7 +114,8 @@ def __init__(self, vocab_size):
 
     def forward(
             self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, mel2ph=None,
-            base_pitch=None, pitch=None, pitch_retake=None, variance_retake: Dict[str, Tensor] = None,
+            base_pitch=None, pitch=None, pitch_expr=None, pitch_retake=None,
+            variance_retake: Dict[str, Tensor] = None,
             spk_id=None, infer=True, **kwargs
     ):
         if self.use_spk_id:
@@ -151,10 +152,22 @@ def forward(
 
         if self.predict_pitch:
             if pitch_retake is None:
-                pitch_retake_embed = self.pitch_retake_embed(torch.ones_like(mel2ph))
+                pitch_retake = torch.ones_like(mel2ph, dtype=torch.bool)
             else:
-                pitch_retake_embed = self.pitch_retake_embed(pitch_retake.long())
                 base_pitch = base_pitch * pitch_retake + pitch * ~pitch_retake
+
+            if pitch_expr is None:
+                pitch_retake_embed = self.pitch_retake_embed(pitch_retake.long())
+            else:
+                retake_true_embed = self.pitch_retake_embed(
+                    torch.ones(1, 1, dtype=torch.long, device=txt_tokens.device)
+                )  # [B=1, T=1] => [B=1, T=1, H]
+                retake_false_embed = self.pitch_retake_embed(
+                    torch.zeros(1, 1, dtype=torch.long, device=txt_tokens.device)
+                )  # [B=1, T=1] => [B=1, T=1, H]
+                pitch_expr = (pitch_expr * pitch_retake)[:, :, None]  # [B, T, 1]
+                pitch_retake_embed = pitch_expr * retake_true_embed + (1. - pitch_expr) * retake_false_embed
+
             pitch_cond = condition + pitch_retake_embed
             pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
             if infer:
diff --git a/scripts/export.py b/scripts/export.py
index 29b4f3f5..f7e6de5a 100644
--- a/scripts/export.py
+++ b/scripts/export.py
@@ -140,6 +140,8 @@ def acoustic(
 @click.option('--exp', type=str, required=True, metavar='', help='Choose an experiment to export.')
 @click.option('--ckpt', type=int, required=False, metavar='', help='Checkpoint training steps.')
 @click.option('--out', type=str, required=False, metavar='', help='Output directory for the artifacts.')
+@click.option('--expose_expr', is_flag=True, show_default=True,
+              help='Expose pitch expressiveness control functionality.')
 @click.option('--export_spk', type=str, required=False, multiple=True, metavar='',
               help='(for multi-speaker models) Export one or more speaker or speaker mix keys.')
 @click.option('--freeze_spk', type=str, required=False, metavar='',
@@ -148,6 +150,7 @@ def variance(
         exp: str,
         ckpt: int = None,
         out: str = None,
+        expose_expr: bool = False,
         export_spk: List[str] = None,
         freeze_spk: str = None
 ):
@@ -177,6 +180,7 @@ def variance(
         device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
         cache_dir=root_dir / 'deployment' / 'cache',
         ckpt_steps=ckpt,
+        expose_expr=expose_expr,
         export_spk=export_spk_mix,
         freeze_spk=freeze_spk_mix
     )
diff --git a/scripts/infer.py b/scripts/infer.py
index 00389c22..d83bb931 100644
--- a/scripts/infer.py
+++ b/scripts/infer.py
@@ -142,6 +142,7 @@ def acoustic(
 @click.option('--title', type=str, required=False, help='Title of output file')
 @click.option('--num', type=int, required=False, default=1, help='Number of runs')
 @click.option('--key', type=int, required=False, default=0, help='Key transition of pitch')
+@click.option('--expr', type=float, required=False, help='Static expressiveness control')
 @click.option('--seed', type=int, required=False, default=-1, help='Random seed of the inference')
 @click.option('--speedup', type=int, required=False, default=0, help='Diffusion acceleration ratio')
 def variance(
@@ -154,6 +155,7 @@ def variance(
         title: str,
         num: int,
         key: int,
+        expr: float,
         seed: int,
         speedup: int
 ):
@@ -167,6 +169,9 @@ def variance(
     if (not out or out.resolve() == proj.parent.resolve()) and not title:
         name += '_variance'
 
+    if expr is not None:
+        assert 0 <= expr <= 1, 'Expressiveness must be in [0, 1].'
+
     with open(proj, 'r', encoding='utf-8') as f:
         params = json.load(f)
 
@@ -202,6 +207,9 @@ def variance(
     spk_mix = parse_commandline_spk_mix(spk) if hparams['use_spk_id'] and spk is not None else None
 
     for param in params:
+        if expr is not None:
+            param['expr'] = expr
+
         if spk_mix is not None:
             param['ph_spk_mix_backup'] = param.get('ph_spk_mix')
             param['spk_mix_backup'] = param.get('spk_mix')
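
Note on the mechanism: modules/toplevel.py and deployment/modules/toplevel.py apply the same per-frame blend. The hard retake embedding lookup is replaced by the linear interpolation expr * emb(retake=True) + (1 - expr) * emb(retake=False), with expr zeroed on frames that are not retaken. The following is a minimal, self-contained sketch of just this interpolation; the tensor sizes are toy values and the freshly initialized embedding table merely stands in for the model's trained pitch_retake_embed:

import torch

hidden_size = 4  # toy size; the real model uses hparams['hidden_size']
pitch_retake_embed = torch.nn.Embedding(2, hidden_size)

retake = torch.tensor([[0, 0, 1, 1, 1, 0]], dtype=torch.bool)  # [B=1, T=6]
expr = torch.full((1, 6), 0.8)  # static expressiveness on every frame

# Embeddings of the two retake states, each [1, 1, H]
retake_true_embed = pitch_retake_embed(torch.ones(1, 1, dtype=torch.long))
retake_false_embed = pitch_retake_embed(torch.zeros(1, 1, dtype=torch.long))

# expr only acts on retaken frames: elsewhere the weight is 0, which
# selects retake_false_embed exactly, as in the old code path
weight = (expr * retake)[:, :, None]  # [B, T, 1]
retake_embed = weight * retake_true_embed + (1. - weight) * retake_false_embed

print(retake_embed.shape)  # torch.Size([1, 6, 4])

At expr == 1 the blend collapses to the previous hard lookup (emb(1) on retaken frames, emb(0) elsewhere), so models exported without --expose_expr and projects that omit expr behave as before; at expr == 0 a retaken frame is conditioned as if it were kept, so the predicted pitch stays close to the note-derived base pitch.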