diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index cda9169b4c..06b64101c3 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -760,6 +760,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         use_cache: Optional[bool] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
     ) -> BaseModelOutputWithPast:
         return_dict = (
             return_dict if return_dict is not None else self.config.return_dict
@@ -856,12 +857,16 @@ def forward(
         )

         if self.learned_pos_emb or (self.rope and self.rope_impl == 'hf'):
-            pos = torch.arange(
-                past_position,
-                S + past_position,
-                dtype=torch.long,
-                device=input_device,
-            ).unsqueeze(0)
+            if position_ids is None:
+                pos = torch.arange(
+                    past_position,
+                    S + past_position,
+                    dtype=torch.long,
+                    device=input_device,
+                ).unsqueeze(0)
+            else:
+                pos = position_ids
+
             if attention_mask is not None:
                 # adjust the position indices to account for padding tokens
                 pos = torch.clamp(
@@ -1128,6 +1133,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         use_cache: Optional[bool] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
     ) -> CausalLMOutputWithPast:
         return_dict = (
             return_dict if return_dict is not None else self.config.return_dict
@@ -1146,6 +1152,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             use_cache=use_cache,
             inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
         )

         if self.lm_head is not None:
@@ -1443,6 +1450,7 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
             attention_mask=batch.get('attention_mask', None),
             sequence_id=batch.get('sequence_id', None),
             inputs_embeds=batch.get('inputs_embeds', None),
+            position_ids=batch.get('position_ids', None),
         )

     def loss(self, outputs: CausalLMOutputWithPast,
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index ac1bdacf4e..eeb6bf0d90 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -2941,3 +2941,37 @@ def test_hf_rotary_child_class_builds():

     assert torch.all(cos == cos_mp)
     assert torch.all(sin == sin_mp)
+
+
+@pytest.mark.parametrize(
+    'conf_path',
+    [
+        'scripts/train/yamls/pretrain/testing.yaml',
+    ],
+)
+def test_position_ids_fwd_pass(
+    request: pytest.FixtureRequest,
+    conf_path: str,
+    batch_size: int = 2,
+):
+    test_cfg, model, _ = _get_objs(request=request, conf_path=conf_path)
+    model.eval()
+
+    # run a forward where we do not pass the position_ids
+    batch = gen_random_batch(batch_size, test_cfg)
+    outputs = model(batch)
+    loss_no_ids = model.loss(outputs, batch)
+    assert isinstance(loss_no_ids, torch.Tensor)
+
+    # run a forward where we explicitly pass the position_ids
+    input_ids = batch['input_ids']
+    _, S = input_ids.size()
+    pos = torch.arange(0, S, dtype=torch.long,
+                       device=input_ids.device).unsqueeze(0)
+    batch['position_ids'] = pos
+
+    outputs = model(batch)
+    loss_ids = model.loss(outputs, batch)
+    assert isinstance(loss_ids, torch.Tensor)
+
+    assert torch.eq(loss_no_ids, loss_ids)
diff --git a/tests/models/test_mpt_gen.py b/tests/models/test_mpt_gen.py
index 379f4b34bd..134ca35ec0 100644
--- a/tests/models/test_mpt_gen.py
+++ b/tests/models/test_mpt_gen.py
@@ -37,6 +37,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         use_cache: Optional[bool] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
    ):
         result = super().forward(
             input_ids,
@@ -49,6 +50,7 @@ def forward(
             output_hidden_states,
             use_cache,
             inputs_embeds,
+            position_ids,
         )
         # Modify the logits to select the next token.
         if dist.get_global_rank() == 0:
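
Below is a minimal usage sketch (not part of the diff) of what this change enables: `ComposerMPTCausalLM.forward` now reads `batch.get('position_ids', None)` and passes it through to `MPTModel.forward`, which uses it in place of the default `torch.arange(past_position, S + past_position)` positions. The `model` and `batch` objects are assumed to already exist (e.g. built the same way as in `test_position_ids_fwd_pass` above), and the helper function name is purely illustrative.

```python
import torch


def forward_with_explicit_positions(model, batch):
    """Illustrative helper: run a ComposerMPTCausalLM forward pass with
    explicit position_ids instead of the model's internally built ones."""
    input_ids = batch['input_ids']
    batch_size, seq_len = input_ids.size()

    # Without this key, MPTModel.forward builds
    # torch.arange(past_position, S + past_position) itself; supplying
    # position_ids overrides that default (e.g. to restart positions per
    # packed segment).
    batch['position_ids'] = torch.arange(
        seq_len,
        dtype=torch.long,
        device=input_ids.device,
    ).unsqueeze(0).expand(batch_size, seq_len)

    # ComposerMPTCausalLM.forward picks up batch.get('position_ids', None)
    return model(batch)
```

With contiguous positions like the ones built here, the output matches the default path exactly, which is what the new `test_position_ids_fwd_pass` test checks by comparing the two losses.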