
Commit c56d317

style
1 parent e5d5ecc commit c56d317

File tree

6 files changed: +327, -363 lines

src/transformers/models/glm4/modeling_glm4.py

Lines changed: 3 additions & 1 deletion
@@ -407,6 +407,8 @@ def _init_weights(self, module):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, Glm4RMSNorm):
+            module.weight.data.fill_(1.0)
 
 
 GLM4_INPUTS_DOCSTRING = r"""
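
The first hunk adds an explicit branch for Glm4RMSNorm to _init_weights, so re-initializing the model resets the norm's scale parameter to its identity value of 1.0 instead of leaving it untouched. Below is a minimal sketch of the pattern only; TinyRMSNorm and init_weights are hypothetical stand-ins, not the actual Glm4 classes.

# Sketch only: illustrates the weight-init pattern added in this commit,
# not the actual Glm4 implementation.
import torch
from torch import nn

class TinyRMSNorm(nn.Module):  # hypothetical stand-in for Glm4RMSNorm
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # multiplicative scale
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        return self.weight * hidden_states * torch.rsqrt(variance + self.variance_epsilon)

def init_weights(module, std=0.02):
    # Mirrors the branch added above: embedding weights are redrawn,
    # while norm scales are reset to their identity value of 1.0.
    if isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, TinyRMSNorm):
        module.weight.data.fill_(1.0)

norm = TinyRMSNorm(16)
norm.weight.data.normal_()   # pretend the scale was perturbed
init_weights(norm)
print(torch.allclose(norm.weight, torch.ones(16)))  # True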
@@ -678,7 +680,7 @@ def _update_causal_mask(
         if (
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
-            and attention_mask.device.type in ["cuda", "xpu"]
+            and attention_mask.device.type in ["cuda", "xpu", "npu"]
             and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
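
The second hunk widens the device allowlist in _update_causal_mask: the SDPA workaround that un-masks fully masked rows now also runs when the attention mask lives on an Ascend NPU ("npu") device, in addition to CUDA and XPU. Below is a minimal sketch of just the gating condition; needs_unmask_workaround is a hypothetical helper, not part of transformers, and the surrounding mask-editing logic is omitted.

# Sketch only: reproduces the condition changed in this commit.
import torch

def needs_unmask_workaround(attn_implementation, attention_mask, output_attentions):
    return (
        attn_implementation == "sdpa"
        and attention_mask is not None
        and attention_mask.device.type in ["cuda", "xpu", "npu"]  # "npu" is the new entry
        and not output_attentions
    )

mask = torch.ones(1, 8, dtype=torch.bool)  # CPU tensor, so device.type == "cpu"
print(needs_unmask_workaround("sdpa", mask, False))  # False: workaround is CUDA/XPU/NPU-only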

0 commit comments
