Enable passing in external position ids (#1493)
gupta-abhay authored Aug 28, 2024
1 parent 0db4425 commit bf6cfdf
Showing 3 changed files with 50 additions and 6 deletions.
20 changes: 14 additions & 6 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -760,6 +760,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         use_cache: Optional[bool] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
     ) -> BaseModelOutputWithPast:
         return_dict = (
             return_dict if return_dict is not None else self.config.return_dict
@@ -856,12 +857,16 @@ def forward(
             )
 
         if self.learned_pos_emb or (self.rope and self.rope_impl == 'hf'):
-            pos = torch.arange(
-                past_position,
-                S + past_position,
-                dtype=torch.long,
-                device=input_device,
-            ).unsqueeze(0)
+            if position_ids is None:
+                pos = torch.arange(
+                    past_position,
+                    S + past_position,
+                    dtype=torch.long,
+                    device=input_device,
+                ).unsqueeze(0)
+            else:
+                pos = position_ids
+
             if attention_mask is not None:
                 # adjust the position indices to account for padding tokens
                 pos = torch.clamp(
@@ -1128,6 +1133,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         use_cache: Optional[bool] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
     ) -> CausalLMOutputWithPast:
         return_dict = (
             return_dict if return_dict is not None else self.config.return_dict
@@ -1146,6 +1152,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             use_cache=use_cache,
             inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
         )
 
         if self.lm_head is not None:
@@ -1443,6 +1450,7 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
             attention_mask=batch.get('attention_mask', None),
             sequence_id=batch.get('sequence_id', None),
             inputs_embeds=batch.get('inputs_embeds', None),
+            position_ids=batch.get('position_ids', None),
         )
 
     def loss(self, outputs: CausalLMOutputWithPast,
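
Note: with this change, callers of the MPT forward pass can supply their own position_ids instead of relying on the internally computed arange; when position_ids is omitted, the model still falls back to torch.arange(past_position, S + past_position), so existing callers are unaffected. The sketch below is illustrative only and is not part of the commit: the import path and class names come from the repository, but the config values, the attn_impl='torch' choice, and the packed-sequence position pattern are assumptions.

    # Minimal usage sketch (not from the diff): pass explicit position_ids
    # to an MPT model built from a small placeholder config.
    import torch

    from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

    config = MPTConfig(
        d_model=128,           # placeholder sizes, not a recommended config
        n_heads=4,
        n_layers=2,
        max_seq_len=64,
        vocab_size=100,
        attn_config={'attn_impl': 'torch'},  # avoid flash-attn dependency on CPU
    )
    model = MPTForCausalLM(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 8))

    # Hypothetical packed batch: two documents of length 4, so positions
    # restart at 0 for the second document instead of the default 0..7.
    position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3]], dtype=torch.long)

    with torch.no_grad():
        out = model(input_ids=input_ids, position_ids=position_ids)
    print(out.logits.shape)  # (1, 8, vocab_size)
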
34 changes: 34 additions & 0 deletions tests/models/test_model.py
@@ -2941,3 +2941,37 @@ def test_hf_rotary_child_class_builds():
 
     assert torch.all(cos == cos_mp)
     assert torch.all(sin == sin_mp)
+
+
+@pytest.mark.parametrize(
+    'conf_path',
+    [
+        'scripts/train/yamls/pretrain/testing.yaml',
+    ],
+)
+def test_position_ids_fwd_pass(
+    request: pytest.FixtureRequest,
+    conf_path: str,
+    batch_size: int = 2,
+):
+    test_cfg, model, _ = _get_objs(request=request, conf_path=conf_path)
+    model.eval()
+
+    # run a forward where we do not pass the position_ids
+    batch = gen_random_batch(batch_size, test_cfg)
+    outputs = model(batch)
+    loss_no_ids = model.loss(outputs, batch)
+    assert isinstance(loss_no_ids, torch.Tensor)
+
+    # run a forward where we explicitly pass the position_ids
+    input_ids = batch['input_ids']
+    _, S = input_ids.size()
+    pos = torch.arange(0, S, dtype=torch.long,
+                       device=input_ids.device).unsqueeze(0)
+    batch['position_ids'] = pos
+
+    outputs = model(batch)
+    loss_ids = model.loss(outputs, batch)
+    assert isinstance(loss_ids, torch.Tensor)
+
+    assert torch.eq(loss_no_ids, loss_ids)
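
Since ComposerMPTCausalLM.forward now reads position_ids from the batch via batch.get('position_ids', None), a dataloader or collator can attach custom positions per example. Below is a hypothetical helper, not part of the commit, that restarts positions at 0 for each packed document given a sequence_id tensor; the function name and the restart-per-document convention are assumptions for illustration.

    # Illustrative helper (assumed, not from the commit): derive per-document
    # position ids from a sequence_id tensor, so a collator could add
    # 'position_ids' to the batch dict consumed above.
    import torch

    def positions_from_sequence_id(sequence_id: torch.Tensor) -> torch.Tensor:
        """Restart positions at 0 whenever sequence_id changes along dim -1."""
        B, S = sequence_id.shape
        arange = torch.arange(S, device=sequence_id.device).expand(B, S)
        # mark the first token of each document
        is_start = torch.ones_like(sequence_id, dtype=torch.bool)
        is_start[:, 1:] = sequence_id[:, 1:] != sequence_id[:, :-1]
        # running index of the most recent document start
        start_idx = torch.where(is_start, arange, torch.zeros_like(arange))
        start_idx = torch.cummax(start_idx, dim=-1).values
        return (arange - start_idx).long()

    seq_id = torch.tensor([[0, 0, 0, 1, 1, 2, 2, 2]])
    print(positions_from_sequence_id(seq_id))
    # tensor([[0, 1, 2, 0, 1, 0, 1, 2]])
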
2 changes: 2 additions & 0 deletions tests/models/test_mpt_gen.py
@@ -37,6 +37,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         use_cache: Optional[bool] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
     ):
         result = super().forward(
             input_ids,
@@ -49,6 +50,7 @@ def forward(
             output_hidden_states,
             use_cache,
             inputs_embeds,
+            position_ids,
         )
         # Modify the logits to select the next token.
         if dist.get_global_rank() == 0:
