Commit bbf86c5
s2a: fixed DDP training without causal_encoder
jpc committed Feb 28, 2024
1 parent d23339b commit bbf86c5
Showing 2 changed files with 12 additions and 2 deletions.
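
The commit title points at a common DistributedDataParallel pitfall: with the default find_unused_parameters=False, DDP expects every registered parameter to take part in the backward pass. When causal_encoder is disabled, the hidden_to_emb projection was still registered but never touched in forward(), so its parameters never received gradients and multi-GPU training failed during gradient synchronization. The diffs below create the layer only when it can actually be used (or when the new force_hidden_to_emb flag is set for backwards compatibility). The following is a minimal, illustrative sketch of the assumed failure mode and of the fix; ToyModel is a stand-in, not the real SADelARTransformer:

# Illustrative sketch only; a stand-in module, not the repository's model.
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self, causal_encoder=False, force_hidden_to_emb=False):
        super().__init__()
        self.causal_encoder = causal_encoder
        self.emb_to_hidden = nn.Linear(16, 32)
        # Pre-fix behaviour: this layer was registered unconditionally, even
        # though forward() only uses it when causal_encoder is enabled. Its
        # parameters then never receive gradients and DDP (with the default
        # find_unused_parameters=False) raises a RuntimeError in backward().
        # Post-fix behaviour: only register the layer when it will be used.
        if causal_encoder or force_hidden_to_emb:
            self.hidden_to_emb = nn.Linear(32, 16)

    def forward(self, x):
        h = self.emb_to_hidden(x)
        if self.causal_encoder:
            # only the causal-encoder auxiliary path needs hidden_to_emb
            return self.hidden_to_emb(h)
        return h

# Wrapped in torch.nn.parallel.DistributedDataParallel, the pre-fix model
# trained without causal_encoder needed find_unused_parameters=True as a
# workaround; after the fix the unused layer simply does not exist.
model = ToyModel(causal_encoder=False)
print(sum(p.numel() for p in model.parameters()))  # no hidden_to_emb params
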
@@ -401,6 +401,9 @@
     "\n",
     "    random :bool = False\n",
     "    random_finetune :bool = False\n",
+    "    \n",
+    "    # backwards compat\n",
+    "    force_hidden_to_emb: bool = False\n",
     "\n",
     "    def __post_init__(self):\n",
     "        # randomize the hyperparams if requested\n",
@@ -431,6 +434,7 @@
     "        old_default('rope', False)\n",
     "        old_default('linear_heads', True)\n",
     "        old_default('causal_encoder', False)\n",
+    "        old_default('force_hidden_to_emb', True)\n",
     "        return args\n",
     "    \n",
     "class SADelARTransformer(nn.Module):\n",
@@ -462,7 +466,8 @@
     "        self.semantic_embedding = nn.Embedding(stoks_codes, stoks_width)\n",
     "        if self.emb_factor:\n",
     "            self.emb_to_hidden = nn.Linear(stoks_width, width)\n",
-    "            self.hidden_to_emb = nn.Linear(width, stoks_width)\n",
+    "            if self.tunables.causal_encoder or self.tunables.force_hidden_to_emb:\n",
+    "                self.hidden_to_emb = nn.Linear(width, stoks_width)\n",
     "        \n",
     "        if self.spk_factor:\n",
     "            self.spk_to_hidden = nn.Linear(spk_width, width)\n",
7 changes: 6 additions & 1 deletion whisperspeech/s2a_delar_mup_wds_mlang.py
@@ -188,6 +188,9 @@ class Tunables:

     random :bool = False
     random_finetune :bool = False
+
+    # backwards compat
+    force_hidden_to_emb: bool = False

     def __post_init__(self):
         # randomize the hyperparams if requested
@@ -218,6 +221,7 @@ def old_default(name, value):
         old_default('rope', False)
         old_default('linear_heads', True)
         old_default('causal_encoder', False)
+        old_default('force_hidden_to_emb', True)
         return args

 class SADelARTransformer(nn.Module):
@@ -249,7 +253,8 @@ def __init__(self, depth=3, ctx_n=2250,
         self.semantic_embedding = nn.Embedding(stoks_codes, stoks_width)
         if self.emb_factor:
             self.emb_to_hidden = nn.Linear(stoks_width, width)
-            self.hidden_to_emb = nn.Linear(width, stoks_width)
+            if self.tunables.causal_encoder or self.tunables.force_hidden_to_emb:
+                self.hidden_to_emb = nn.Linear(width, stoks_width)

         if self.spk_factor:
             self.spk_to_hidden = nn.Linear(spk_width, width)
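
The second half of the change keeps old checkpoints loadable: models saved before this commit always contained a hidden_to_emb weight, so the new force_hidden_to_emb flag defaults to False for new configs but is set to True by the old_default upgrade path whenever it is missing from a stored hyperparameter dict. Below is a minimal sketch of that pattern; the Tunables fields are trimmed down and the enclosing upgrade staticmethod is an assumed name, since the hunk above only shows the inner old_default helper.

# Sketch of the old_default() backwards-compat pattern from the diff above.
from dataclasses import dataclass

@dataclass
class Tunables:
    causal_encoder: bool = False
    # backwards compat
    force_hidden_to_emb: bool = False

    @staticmethod
    def upgrade(args):  # method name assumed; not shown in the hunk header
        args = {k: v for k, v in args.items()}
        def old_default(name, value):
            # a hyperparameter missing from an old checkpoint gets the value
            # that reproduces the old behaviour
            if name not in args:
                args[name] = value
        old_default('causal_encoder', False)
        old_default('force_hidden_to_emb', True)  # old models always created hidden_to_emb
        return args

# An old checkpoint knows nothing about the new flag, so it is upgraded to
# force_hidden_to_emb=True and hidden_to_emb is still instantiated on load;
# freshly created configs keep the new default of False.
print(Tunables(**Tunables.upgrade({'causal_encoder': False})))
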
