Do not update past_key_values in place #652

Merged: 18 commits, Oct 12, 2023
Changes from 4 commits
9 changes: 5 additions & 4 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -422,6 +422,7 @@ def forward(
             )
 
         # initialize the past key values cache if it should be used
+        presents = () if use_cache else None
         if use_cache and past_key_values is None:
             past_key_values = [() for _ in range(self.config.n_layers)
                               ]  # type: ignore
@@ -434,16 +435,16 @@ def forward(
                 all_hidden_states = all_hidden_states + (x,)
             past_key_value = (past_key_values[b_idx]
                               if past_key_values is not None else None)
-            x, attn_weights, past_key_value = block(
+            x, attn_weights, present = block(
                 x,
                 past_key_value=past_key_value,
                 attn_bias=attn_bias,
                 attention_mask=attention_mask,
                 is_causal=self.is_causal,
                 output_attentions=bool(output_attentions),
             )
-            if past_key_values is not None:
-                past_key_values[b_idx] = past_key_value
+            if use_cache:
+                presents += (present,)
 
             if output_attentions:
                 assert all_self_attns is not None  # pyright
@@ -458,7 +459,7 @@ def forward(
 
         return BaseModelOutputWithPast(
             last_hidden_state=x,
-            past_key_values=past_key_values,
+            past_key_values=presents,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
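For readers skimming the diff: the change stops overwriting the caller-owned past_key_values list and instead accumulates each block's returned cache entry into a fresh presents tuple, which is what gets returned as past_key_values. A minimal, standalone sketch contrasting the two patterns (illustrative only, not llm-foundry code; fake_block, forward_in_place, and forward_new_tuple are made-up names standing in for an MPT block and the model forward):

# Sketch only: contrasts in-place cache mutation with building a new tuple.
from typing import List, Optional, Tuple

KVPair = Tuple[str, str]  # stand-in for a (key, value) tensor pair

def fake_block(past: Optional[KVPair]) -> KVPair:
    # Placeholder for an MPT block: returns an updated key/value pair.
    return ('new_k', 'new_v')

def forward_in_place(past_key_values: Optional[List[KVPair]], n_layers: int = 2):
    # Old behaviour: the list passed in by the caller is overwritten layer by layer.
    if past_key_values is None:
        past_key_values = [None] * n_layers
    for b_idx in range(n_layers):
        past_key_values[b_idx] = fake_block(past_key_values[b_idx])  # mutates caller state
    return past_key_values

def forward_new_tuple(past_key_values: Optional[List[KVPair]], n_layers: int = 2,
                      use_cache: bool = True):
    # New behaviour: read from past_key_values, write into a fresh presents tuple.
    presents = () if use_cache else None
    for b_idx in range(n_layers):
        past = past_key_values[b_idx] if past_key_values is not None else None
        present = fake_block(past)
        if use_cache:
            presents += (present,)
    return presents

caller_cache = [('k0', 'v0'), ('k1', 'v1')]
forward_in_place(list(caller_cache))  # safe here only because we copied the list first
assert forward_new_tuple(caller_cache) == (('new_k', 'new_v'), ('new_k', 'new_v'))
assert caller_cache == [('k0', 'v0'), ('k1', 'v1')]  # caller's cache is left untouched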
91 changes: 90 additions & 1 deletion tests/test_hf_mpt_gen.py
@@ -1,17 +1,23 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Dict
+from typing import Any, Dict, List, Optional, Tuple
 
 import pytest
 from composer.core.precision import get_precision_context
 from composer.utils import get_device, reproducibility
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
+from composer.utils import dist
 
 from llmfoundry import COMPOSER_MODEL_REGISTRY
 from llmfoundry.utils import build_tokenizer
+import torch
 
+from transformers import AutoModelForCausalLM
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from unittest.mock import patch
+from llmfoundry.models.mpt.modeling_mpt import MPTForCausalLM, MPTConfig
 
 @pytest.mark.gpu
 @pytest.mark.parametrize('device', ['cpu', 'gpu'])
@@ -72,3 +78,86 @@ def test_init_hfhub_mpt(device: str, attn_impl: str):
 
 def test_init_hfhub_mpt_cpu():
     test_init_hfhub_mpt(device='cpu', attn_impl='torch')
+
+EOS_TOKEN_ID = 0
+
+class MockMPTForCausalLM(MPTForCausalLM):
+    """Class that overrides the forward of MPTForCausalLM."""
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
+        attention_mask: Optional[torch.ByteTensor] = None,
+        prefix_mask: Optional[torch.ByteTensor] = None,
+        sequence_id: Optional[torch.LongTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        return_dict: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        use_cache: Optional[bool] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ):
+        result = super().forward(input_ids, past_key_values, attention_mask,
+                                 prefix_mask, sequence_id, labels, return_dict,
+                                 output_attentions, output_hidden_states,
+                                 use_cache, inputs_embeds)
+        # Modify the logits to select the next token.
+        if dist.get_global_rank() == 0:
+            # Rank 0 hits EOS immediately.
+            result.logits[:, :, EOS_TOKEN_ID] = torch.inf
+        else:
+            # Other ranks do not hit EOS.
+            result.logits[:, :, EOS_TOKEN_ID] = -torch.inf
+        return result
+
+def mock_from_config(config: MPTConfig, **_):
+    config_dict = config.to_dict()
+    config = MPTConfig.from_dict(config_dict)
+    return MockMPTForCausalLM._from_config(config)
+
+@pytest.mark.world_size(2)
+@pytest.mark.gpu
+@patch.object(AutoModelForCausalLM, 'from_config', new=mock_from_config)
+def test_mpt_generate_multi_gpu():
+    """Tests mpt generation with multiple gpus and generations of different
+    lengths."""
+    composer_device = get_device('gpu')
+    dist.initialize_dist(composer_device)
+    with open('scripts/train/yamls/pretrain/testing.yaml') as f:
+        test_cfg = om.load(f)
+
+    assert isinstance(test_cfg, DictConfig)
+    reproducibility.seed_all(test_cfg.get('seed', 42))
+
+    test_cfg.model = DictConfig({
+        'name': 'hf_causal_lm',
+        'pretrained_model_name_or_path': 'mosaicml/mpt-7b',
+        'pretrained': False,
+        'config_overrides': {
+            'd_model': 128,
+            'n_heads': 4,
+            'n_layers': 2,
+            'expansion_ratio': 2,
+            'no_bias': False,
+            'use_cache': False
+        }
+    })
+
+    # build tokenizer
+    tokenizer_name = test_cfg.tokenizer.name
+    tokenizer = build_tokenizer(tokenizer_name, {'max_seq_len': 15})
+
+    # build model
+    model = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model,
+                                                         tokenizer)
+    model = composer_device.module_to_device(model)
+
+    model.model = FSDP(model.model)
+
+    _ = model.generate(
+        composer_device.tensor_to_device(
+            tokenizer('hello', return_tensors='pt')['input_ids']),
+        max_new_tokens=10,
+        eos_token_id=EOS_TOKEN_ID,
+        use_cache=True,
+        synced_gpus=True,
+    )