diff --git a/eole/modules/embeddings.py b/eole/modules/embeddings.py index dff44369..b828ea28 100644 --- a/eole/modules/embeddings.py +++ b/eole/modules/embeddings.py @@ -168,10 +168,9 @@ def forward(self, source, step=None): if step == 0 or step is None: # reset self.past_length = 0 - past_length = self.past_length # TODO position_ids = torch.arange( past_length, - source.size(-1) + past_length, + source.size(-1) + self.past_length, dtype=torch.long, device=source.device, )