Skip to content

Commit

Permalink
Merge pull request ggerganov#23 from awgu/pt2
Browse files Browse the repository at this point in the history
Register `freqs_cis` as non-persistent buffer
  • Loading branch information
karpathy authored Jul 24, 2023
2 parents 3bfa566 + af3b5c0 commit bd9e837
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ def __init__(self, params: ModelArgs):
self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying

# some useful precompute for the RoPE relative positional embeddings. TODO why * 2 here? confuse
- self.freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
+ freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
+ self.register_buffer("freqs_cis", freqs_cis, persistent=False)

# init all weights
self.apply(self._init_weights)
Expand All @@ -215,7 +216,6 @@ def _init_weights(self, module):
def forward(self, tokens, targets=None):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
- self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[:seqlen]

for layer in self.layers:
Expand Down

0 comments on commit bd9e837

Please sign in to comment.