Skip to content

Commit

Permalink
fix: no need to dtype A in jamba (#32924)
Browse files Browse the repository at this point in the history
Co-authored-by: Gal Cohen <galc@ai21.com>
  • Loading branch information
2 people authored and ArthurZucker committed Aug 22, 2024
1 parent c1df7f8 commit 3d8cba8
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/models/jamba/modeling_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ def __init__(self, config: JambaConfig, layer_idx):

# S4D real initialization. These are not discretized!
# The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
A = torch.arange(1, self.ssm_state_size + 1)[None, :]
A = A.expand(self.intermediate_size, -1).contiguous()

self.A_log = nn.Parameter(torch.log(A))
Expand Down

0 comments on commit 3d8cba8

Please sign in to comment.