Implement pitch expressiveness controlling mechanism (#97)
* Add expressiveness in model `forward()`

* Support inference with static or dynamic expressiveness

* Fix assignment of `retake`

* Format code

* Add `expressiveness` in ONNX model

* Swap input order

* Fix typo

* Adapt latest updates from main branch
yqzhishen authored Aug 11, 2023
1 parent 566ad4a commit 38bc407
Showing 6 changed files with 71 additions and 8 deletions.
11 changes: 10 additions & 1 deletion deployment/exporters/variance_exporter.py
@@ -20,6 +20,7 @@ def __init__(
device: Union[str, torch.device] = 'cpu',
cache_dir: Path = None,
ckpt_steps: int = None,
expose_expr: bool = False,
export_spk: List[Tuple[str, Dict[str, float]]] = None,
freeze_spk: Tuple[str, Dict[str, float]] = None
):
@@ -58,6 +59,7 @@ def __init__(
if self.model.predict_variances else None

# Attributes for exporting
self.expose_expr = expose_expr
self.freeze_spk: Tuple[str, Dict[str, float]] = freeze_spk \
if hparams['use_spk_id'] else None
self.export_spk: List[Tuple[str, Dict[str, float]]] = export_spk \
@@ -264,6 +266,10 @@ def _torch_export_model(self):
note_midi,
note_dur,
pitch,
*([
torch.ones_like(pitch)
if self.expose_expr else []
]),
retake,
*([torch.rand(
1, 15, hparams['hidden_size'],
@@ -274,7 +280,9 @@
input_names=[
'encoder_out', 'ph_dur',
'note_midi', 'note_dur',
'pitch', 'retake',
'pitch',
*(['expr'] if self.expose_expr else []),
'retake',
*(['spk_embed'] if input_spk_embed else [])
],
output_names=[
@@ -293,6 +301,7 @@
'note_dur': {
1: 'n_notes'
},
**({'expr': {1: 'n_frames'}} if self.expose_expr else {}),
'pitch': {
1: 'n_frames'
},
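(Usage sketch, not part of the commit: with `expose_expr` enabled, the exported ONNX pitch predictor gains an optional per-frame `expr` input between `pitch` and `retake`. The file name, shapes, and dtypes below are assumptions for illustration only.)

```python
import numpy as np
import onnxruntime as ort

# Hypothetical artifact name and toy sizes; real shapes depend on the project.
session = ort.InferenceSession('variance.pitch.onnx')
n_tokens, n_notes, n_frames, hidden = 5, 3, 100, 256

feeds = {
    'encoder_out': np.zeros((1, n_tokens, hidden), np.float32),
    'ph_dur': np.full((1, n_tokens), 20, np.int64),
    'note_midi': np.full((1, n_notes), 60., np.float32),
    'note_dur': np.full((1, n_notes), 33, np.int64),
    'pitch': np.full((1, n_frames), 60., np.float32),
    'expr': np.full((1, n_frames), 0.8, np.float32),  # only present with --expose_expr
    'retake': np.ones((1, n_frames), dtype=bool),
}
pitch_pred = session.run(None, feeds)[0]
```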
17 changes: 14 additions & 3 deletions deployment/modules/toplevel.py
@@ -162,14 +162,25 @@ def forward_mel2x_gather(self, x_src, x_dur, x_dim=None):

def forward_pitch_preprocess(
self, encoder_out, ph_dur, note_midi, note_dur,
pitch=None, retake=None, spk_embed=None
pitch=None, expr=None, retake=None, spk_embed=None
):
condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
condition += self.pitch_retake_embed(retake.long())
if expr is None:
retake_embed = self.pitch_retake_embed(retake.long())
else:
retake_true_embed = self.pitch_retake_embed(
torch.ones(1, 1, dtype=torch.long, device=encoder_out.device)
) # [B=1, T=1] => [B=1, T=1, H]
retake_false_embed = self.pitch_retake_embed(
torch.zeros(1, 1, dtype=torch.long, device=encoder_out.device)
) # [B=1, T=1] => [B=1, T=1, H]
expr = (expr * retake)[:, :, None] # [B, T, 1]
retake_embed = expr * retake_true_embed + (1. - expr) * retake_false_embed
pitch_cond = condition + retake_embed
frame_midi_pitch = self.forward_mel2x_gather(note_midi, note_dur, x_dim=None)
base_pitch = self.smooth(frame_midi_pitch)
base_pitch = base_pitch * retake + pitch * ~retake
pitch_cond = condition + self.base_pitch_embed(base_pitch[:, :, None])
pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
if hparams['use_spk_id'] and spk_embed is not None:
pitch_cond += spk_embed
return pitch_cond, base_pitch
19 changes: 18 additions & 1 deletion inference/ds_variance.py
@@ -212,6 +212,22 @@ def preprocess_input(
summary['pitch'] = 'manual'
elif self.auto_completion_mode or self.global_predict_pitch:
summary['pitch'] = 'auto'

# Load expressiveness
expr = param.get('expr', 1.)
if isinstance(expr, (int, float, bool)):
summary['expr'] = f'static({expr:.3f})'
batch['expr'] = torch.FloatTensor([expr]).to(self.device)[:, None] # [B=1, T=1]
else:
summary['expr'] = 'dynamic'
expr = resample_align_curve(
np.array(expr.split(), np.float32),
original_timestep=float(param['expr_timestep']),
target_timestep=self.timestep,
align_length=T_s
)
batch['expr'] = torch.from_numpy(expr.astype(np.float32)).to(self.device)[None]

else:
summary['pitch'] = 'ignored'

@@ -235,6 +251,7 @@ def forward_model(self, sample):
ph_dur = sample['ph_dur']
mel2ph = sample['mel2ph']
base_pitch = sample['base_pitch']
expr = sample.get('expr')
pitch = sample.get('pitch')

if hparams['use_spk_id']:
@@ -255,7 +272,7 @@

dur_pred, pitch_pred, variance_pred = self.model(
txt_tokens, midi=midi, ph2word=ph2word, word_dur=word_dur, ph_dur=ph_dur,
mel2ph=mel2ph, base_pitch=base_pitch, pitch=pitch,
mel2ph=mel2ph, base_pitch=base_pitch, pitch=pitch, pitch_expr=expr,
ph_spk_mix_embed=ph_spk_mix_embed, spk_mix_embed=spk_mix_embed,
infer=True
)
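(For reference, a hedged sketch of the two forms the new `expr` parameter can take in a .ds segment, inferred from the branch in `preprocess_input()` above; all other keys of a real segment are omitted.)

```python
# Static form: a single number in [0, 1] covers the whole segment.
static_segment = {'expr': 0.8}

# Dynamic form: a space-separated curve plus its timestep, which
# resample_align_curve() aligns to the frame level as shown above.
dynamic_segment = {
    'expr': '1.0 1.0 0.75 0.5 0.9',
    'expr_timestep': '0.005',  # seconds per curve point; parsed with float()
}
```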
20 changes: 17 additions & 3 deletions modules/toplevel.py
@@ -114,7 +114,8 @@ def __init__(self, vocab_size):

def forward(
self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, mel2ph=None,
base_pitch=None, pitch=None, pitch_retake=None, variance_retake: Dict[str, Tensor] = None,
base_pitch=None, pitch=None, pitch_expr=None, pitch_retake=None,
variance_retake: Dict[str, Tensor] = None,
spk_id=None, infer=True, **kwargs
):
if self.use_spk_id:
@@ -151,10 +152,23 @@ def forward(

if self.predict_pitch:
if pitch_retake is None:
pitch_retake_embed = self.pitch_retake_embed(torch.ones_like(mel2ph))
pitch_retake = torch.ones_like(mel2ph, dtype=torch.bool)
else:
pitch_retake_embed = self.pitch_retake_embed(pitch_retake.long())
print(base_pitch, pitch, pitch_retake)
base_pitch = base_pitch * pitch_retake + pitch * ~pitch_retake

if pitch_expr is None:
pitch_retake_embed = self.pitch_retake_embed(pitch_retake.long())
else:
retake_true_embed = self.pitch_retake_embed(
torch.ones(1, 1, dtype=torch.long, device=txt_tokens.device)
) # [B=1, T=1] => [B=1, T=1, H]
retake_false_embed = self.pitch_retake_embed(
torch.zeros(1, 1, dtype=torch.long, device=txt_tokens.device)
) # [B=1, T=1] => [B=1, T=1, H]
pitch_expr = (pitch_expr * pitch_retake)[:, :, None] # [B, T, 1]
pitch_retake_embed = pitch_expr * retake_true_embed + (1. - pitch_expr) * retake_false_embed

pitch_cond = condition + pitch_retake_embed
pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
if infer:
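(In effect, `pitch_expr` interpolates per frame between the two entries of the retake embedding table. A self-contained toy sketch of that arithmetic, with a standalone embedding table and made-up sizes, rather than the module itself:)

```python
import torch

retake_embed = torch.nn.Embedding(2, 4)  # toy stand-in for pitch_retake_embed
expr = torch.tensor([[1.0, 0.5, 0.0]])   # [B=1, T=3] per-frame expressiveness
retake = torch.ones(1, 3, dtype=torch.bool)

true_embed = retake_embed(torch.ones(1, 1, dtype=torch.long))    # [1, 1, H]
false_embed = retake_embed(torch.zeros(1, 1, dtype=torch.long))  # [1, 1, H]

weight = (expr * retake)[:, :, None]                       # [B, T, 1]
blend = weight * true_embed + (1. - weight) * false_embed  # [B, T, H]

# expr = 1 reproduces the plain retake embedding (full freedom to redraw pitch);
# expr = 0 falls back to the non-retake embedding, pinning frames to base pitch.
assert torch.allclose(blend[:, 0], true_embed[:, 0])
assert torch.allclose(blend[:, 2], false_embed[:, 0])
```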
4 changes: 4 additions & 0 deletions scripts/export.py
@@ -140,6 +140,8 @@ def acoustic(
@click.option('--exp', type=str, required=True, metavar='<exp>', help='Choose an experiment to export.')
@click.option('--ckpt', type=int, required=False, metavar='<steps>', help='Checkpoint training steps.')
@click.option('--out', type=str, required=False, metavar='<dir>', help='Output directory for the artifacts.')
@click.option('--expose_expr', is_flag=True, show_default=True,
help='Expose pitch expressiveness control functionality.')
@click.option('--export_spk', type=str, required=False, multiple=True, metavar='<mix>',
help='(for multi-speaker models) Export one or more speaker or speaker mix keys.')
@click.option('--freeze_spk', type=str, required=False, metavar='<mix>',
@@ -148,6 +150,7 @@ def variance(
exp: str,
ckpt: int = None,
out: str = None,
expose_expr: bool = False,
export_spk: List[str] = None,
freeze_spk: str = None
):
@@ -177,6 +180,7 @@
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
cache_dir=root_dir / 'deployment' / 'cache',
ckpt_steps=ckpt,
expose_expr=expose_expr,
export_spk=export_spk_mix,
freeze_spk=freeze_spk_mix
)
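(Usage note, an assumed invocation based on the options above rather than text from the commit: something like `python scripts/export.py variance --exp <exp> --expose_expr` should produce an ONNX graph with the extra `expr` input, while omitting the flag keeps the exported interface unchanged.)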
8 changes: 8 additions & 0 deletions scripts/infer.py
@@ -142,6 +142,7 @@ def acoustic(
@click.option('--title', type=str, required=False, help='Title of output file')
@click.option('--num', type=int, required=False, default=1, help='Number of runs')
@click.option('--key', type=int, required=False, default=0, help='Key transition of pitch')
@click.option('--expr', type=float, required=False, help='Static expressiveness control')
@click.option('--seed', type=int, required=False, default=-1, help='Random seed of the inference')
@click.option('--speedup', type=int, required=False, default=0, help='Diffusion acceleration ratio')
def variance(
@@ -154,6 +155,7 @@
title: str,
num: int,
key: int,
expr: float,
seed: int,
speedup: int
):
@@ -167,6 +169,9 @@
if (not out or out.resolve() == proj.parent.resolve()) and not title:
name += '_variance'

if expr is not None:
assert 0 <= expr <= 1, 'Expressiveness must be in [0, 1].'

with open(proj, 'r', encoding='utf-8') as f:
params = json.load(f)

@@ -202,6 +207,9 @@

spk_mix = parse_commandline_spk_mix(spk) if hparams['use_spk_id'] and spk is not None else None
for param in params:
if expr is not None:
param['expr'] = expr

if spk_mix is not None:
param['ph_spk_mix_backup'] = param.get('ph_spk_mix')
param['spk_mix_backup'] = param.get('spk_mix')
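(Usage note, an assumed invocation inferred from the new option: `python scripts/infer.py variance <proj>.ds --exp <exp> --expr 0.8` would apply a static expressiveness of 0.8 to every segment, overriding any `expr` already present in the file; dynamic per-frame curves remain a .ds-file-only feature, since the CLI flag is a single float.)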
