Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable passing in external position ids #1493

Merged
merged 3 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,7 @@ def forward(
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
sequence_id: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
output_attentions: Optional[bool] = None,
Expand Down Expand Up @@ -856,12 +857,16 @@ def forward(
)

if self.learned_pos_emb or (self.rope and self.rope_impl == 'hf'):
pos = torch.arange(
past_position,
S + past_position,
dtype=torch.long,
device=input_device,
).unsqueeze(0)
if position_ids is None:
pos = torch.arange(
past_position,
S + past_position,
dtype=torch.long,
device=input_device,
).unsqueeze(0)
else:
pos = position_ids

if attention_mask is not None:
# adjust the position indices to account for padding tokens
pos = torch.clamp(
Expand Down Expand Up @@ -1121,6 +1126,7 @@ def forward(
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
gupta-abhay marked this conversation as resolved.
Show resolved Hide resolved
sequence_id: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
Expand All @@ -1140,6 +1146,7 @@ def forward(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
sequence_id=sequence_id,
return_dict=return_dict,
output_attentions=output_attentions,
Expand Down Expand Up @@ -1441,6 +1448,7 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
return self.model(
input_ids=batch.get('input_ids', None),
attention_mask=batch.get('attention_mask', None),
position_ids=batch.get('position_ids', None),
sequence_id=batch.get('sequence_id', None),
inputs_embeds=batch.get('inputs_embeds', None),
)
Expand Down
2 changes: 2 additions & 0 deletions tests/models/test_mpt_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def forward(
input_ids: torch.LongTensor,
past_key_values: Optional[list[tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
sequence_id: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
Expand All @@ -42,6 +43,7 @@ def forward(
input_ids,
past_key_values,
attention_mask,
position_ids,
sequence_id,
labels,
return_dict,
Expand Down
Loading