Skip to content

Commit bc884ae

Browse files
committed
granite speech
1 parent f2b360e commit bc884ae

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

src/transformers/models/granite_speech/modeling_granite_speech.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,11 +330,15 @@ def _init_weights(self, module: nn.Module):
330330
module.weight.data.normal_(mean=0.0, std=std)
331331
if module.bias is not None:
332332
module.bias.data.zero_()
333-
334333
elif isinstance(module, nn.Embedding):
335334
module.weight.data.normal_(mean=0.0, std=std)
336335
if module.padding_idx is not None:
337336
module.weight.data[module.padding_idx].zero_()
337+
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
338+
module.weight.data.fill_(1.0)
339+
module.bias.data.zero_()
340+
elif isinstance(module, GraniteSpeechEncoderProjector):
341+
module.query.data.normal_()
338342

339343

340344
GRANITE_SPEECH_INPUTS_DOCSTRING = r"""

0 commit comments

Comments
 (0)