Fix Mixtral-related issues #570

Merged 18 commits on Apr 10, 2024
2 changes: 2 additions & 0 deletions .github/workflows/run-tests.yaml
@@ -16,6 +16,8 @@ jobs:
- { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
- { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.8' }
- { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.11' }
fail-fast: false
runs-on: ${{ matrix.os }}-latest
timeout-minutes: 20
2 changes: 1 addition & 1 deletion src/petals/client/remote_generation.py
@@ -38,7 +38,7 @@ def update_seen(self, new_seen: int) -> None:
self.seen_tokens += new_seen

def reorder_cache(self, beam_idx):
pass
raise NotImplementedError("Beam search reordering is not implemented yet")


_skipped_tokens = ContextVar("skipped_tokens", default=0)
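For context: during beam search, the generation loop reorders cached key/value tensors so they follow the surviving beams, which is exactly what this class does not support yet, hence the explicit error instead of a silent `pass`. A minimal sketch of what such a reordering does for one layer's cache (illustrative only, not part of this PR):

```python
import torch

def reorder_layer_cache(key: torch.Tensor, value: torch.Tensor, beam_idx: torch.LongTensor):
    # key/value: [batch_size * num_beams, num_heads, seq_len, head_dim];
    # keep only the rows of the beams that survived this step.
    return key.index_select(0, beam_idx), value.index_select(0, beam_idx)
```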
6 changes: 6 additions & 0 deletions src/petals/models/bloom/block.py
@@ -9,6 +9,8 @@
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor

from petals.utils.misc import is_dummy


class WrappedBloomBlock(BloomBlock):
def forward(
@@ -22,6 +24,10 @@ def forward(
):
assert attention_mask is None, "Non-causal attention masks are not supported yet"
batch_size, seq_length = hidden_states.shape[:2]
if layer_past is not None and is_dummy(layer_past[0]):
# Bloom cannot use the cache if it was misconstructed (e.g., contains dummy tensors).
# In this case, fall back to the cache-free code path:
layer_past = None
past_length = 0 if layer_past is None else layer_past[0].shape[-1]
seq_length_with_past = seq_length + past_length
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
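The fallback relies on `is_dummy()` from `petals.utils.misc` (its definition appears in the misc.py hunk below): a placeholder with zero elements means "no usable past", so the block recomputes as if no cache were passed. A self-contained sketch of the semantics, with a hypothetical key shape for the non-dummy case:

```python
import torch

def is_dummy(tensor: torch.Tensor) -> bool:
    # Same check as petals/utils/misc.py: a placeholder has no elements.
    return tensor.numel() == 0

assert is_dummy(torch.empty((0, 0, 0)))        # DUMMY_KEY_PAST-style placeholder -> ignore the cache
assert not is_dummy(torch.zeros(1, 8, 4, 64))  # hypothetical real key tensor -> use the cache
```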
12 changes: 6 additions & 6 deletions src/petals/models/mixtral/block.py
@@ -1,3 +1,4 @@
import json
from typing import Optional, Tuple

import torch
@@ -33,16 +34,15 @@ def forward(
past_key_values_length = 0

past_key_value = layer_past

if past_key_value is not None:
past_key_values_length = past_key_value[0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
_past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
past_key_value = DynamicCache()
for idx in range(self.layer_idx):
past_key_value.update(
torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx
)
past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx)
past_key_value.key_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[0]]
past_key_value.value_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[1]]
past_key_value._seen_tokens = past_key_values_length

if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
@@ -83,7 +83,7 @@ def forward(

if use_cache:
present_key_value = outputs[-1]
present_key_value = present_key_value.to_legacy_cache()[self.layer_idx]
present_key_value = present_key_value[self.layer_idx]
present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
outputs = outputs[:-1] + (present_key_value,)

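The net effect: instead of calling `DynamicCache.update()` once per preceding layer with throw-away tensors, the per-layer lists are filled directly so that only this block's slot holds real data, and the later `to_legacy_cache()` round-trip is dropped. A rough sketch of the idea with simplified names and hypothetical shapes (the `_seen_tokens` attribute name can differ across transformers versions):

```python
import torch
from transformers.cache_utils import DynamicCache

layer_idx = 2
key = torch.randn(1, 8, 5, 64)    # hypothetical [batch, num_heads, past_len, head_dim]
value = torch.randn(1, 8, 5, 64)

cache = DynamicCache()
cache.key_cache = [torch.empty(0) for _ in range(layer_idx)] + [key]
cache.value_cache = [torch.empty(0) for _ in range(layer_idx)] + [value]
cache._seen_tokens = key.shape[2]  # past_key_values_length
```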
21 changes: 15 additions & 6 deletions src/petals/models/mixtral/model.py
@@ -122,14 +122,20 @@ def forward(
def word_embeddings(self) -> nn.Embedding: # For compatibility with RemoteGenerationMixin
return self.embed_tokens

@property
def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests
return nn.Identity()

@property
def h(self) -> RemoteSequential: # For compatibility with RemoteGenerationMixin
return self.layers

@property
def ln_f(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests
return self.norm

class DistributedMixtralForCausalLM(
DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
):

class DistributedMixtralForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM):
_keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected

@@ -151,9 +157,12 @@ def transformer(self) -> DistributedMixtralModel:  # For compatibility with Remo
return self.model


class DistributedMixtralForSequenceClassification(
DefaultRevisionMixin, FromPretrainedMixin, MixtralForSequenceClassification
):
class DistributedMixtralForSequenceClassification(FromPretrainedMixin, MixtralForSequenceClassification):
_keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected

config_class = DistributedMixtralConfig

def __init__(self, config: DistributedMixtralConfig):
MixtralPreTrainedModel.__init__(self, config)
self.num_labels = config.num_labels
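The new `word_embeddings_layernorm` and `ln_f` properties are a naming shim: RemoteGenerationMixin and the tests address models through Bloom-style attribute names, and Mixtral has no LayerNorm right after its embeddings, so `nn.Identity()` stands in as a pass-through. A tiny sketch of that behavior (illustrative only):

```python
import torch
import torch.nn as nn

# Mixtral applies no post-embedding LayerNorm, so the compatibility property
# returns an identity module: hidden states pass through unchanged.
word_embeddings_layernorm = nn.Identity()
hidden_states = torch.randn(2, 5, 16)  # hypothetical [batch, seq, hidden] shape
assert torch.equal(word_embeddings_layernorm(hidden_states), hidden_states)
```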
17 changes: 15 additions & 2 deletions src/petals/server/block_utils.py
@@ -2,8 +2,9 @@

import torch
from accelerate import init_empty_weights
from transformers import PretrainedConfig
from transformers import PretrainedConfig, PreTrainedModel

from petals.models.mixtral.block import WrappedMixtralBlock
from petals.utils.convert_block import QuantType
from petals.utils.misc import get_size_in_bytes

@@ -32,7 +33,7 @@ def get_block_size(
), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations'

with init_empty_weights(include_buffers=True):
block = config.block_class(config)
block = get_model_block(config)
n_params = sum(param.numel() for param in block.parameters())

if location == "memory":
@@ -50,3 +51,15 @@ def get_block_size(
bytes_per_value = get_size_in_bytes(dtype)

return round(n_params * bytes_per_value * (1 + eps))


def get_model_block(config, layer_idx: int = 0):
"""
Create a model block based on the config's block class.
The `layer_idx` argument is only needed by specific classes, such as Mixtral;
other block constructors ignore it (it is not passed to them).
"""
if config.block_class == WrappedMixtralBlock:
config = PreTrainedModel._autoset_attn_implementation(config)
return config.block_class(config, layer_idx)
return config.block_class(config)
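Callers updated elsewhere in this PR (from_pretrained.py, throughput.py, peft.py, and the tests) now go through this helper instead of `config.block_class(config)`. A usage sketch under assumptions: the model name comes from the CI matrix above and is used purely for illustration.

```python
from accelerate import init_empty_weights

from petals.server.block_utils import get_model_block
from petals.utils.auto_config import AutoDistributedConfig

config = AutoDistributedConfig.from_pretrained("artek0chumak/TestMixtral")
with init_empty_weights(include_buffers=True):
    # Mixtral configs also get their attention implementation set and receive
    # the layer index; other models simply ignore layer_idx.
    block = get_model_block(config, layer_idx=0)
n_params = sum(p.numel() for p in block.parameters())
```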
8 changes: 2 additions & 6 deletions src/petals/server/from_pretrained.py
@@ -24,7 +24,7 @@

from petals.constants import DTYPE_MAP
from petals.models.mixtral import WrappedMixtralBlock
from petals.server.block_utils import resolve_block_dtype
from petals.server.block_utils import get_model_block, resolve_block_dtype
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.hf_auth import always_needs_auth
@@ -52,11 +52,7 @@ def load_pretrained_block(
torch_dtype = resolve_block_dtype(config, torch_dtype)

with init_empty_weights():
if config.block_class == WrappedMixtralBlock:
config = PreTrainedModel._autoset_attn_implementation(config)
block = config.block_class(config, block_index)
else:
block = config.block_class(config)
block = get_model_block(config, layer_idx=block_index)

block_prefix = f"{config.block_prefix}.{block_index}."
state_dict = _load_state_dict_from_repo(
18 changes: 13 additions & 5 deletions src/petals/server/throughput.py
@@ -13,9 +13,10 @@
from hivemind.utils.logging import get_logger
from transformers import PretrainedConfig

from petals.server.block_utils import resolve_block_dtype
from petals.server.block_utils import get_model_block, resolve_block_dtype
from petals.utils.convert_block import QuantType, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
from petals.utils.misc import DUMMY_KEY_PAST

logger = get_logger(__name__)

@@ -201,18 +202,25 @@ def measure_compute_rps(
if not tensor_parallel_devices:
tensor_parallel_devices = (device,)
with torch.inference_mode():
block = config.block_class(config).to(dtype)
block = get_model_block(config)
block = block.to(dtype)
block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)

cache = None
cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype))
elapsed = 0
dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
_, cache = block.forward(dummy_input, use_cache=True) # Skip the 1st step to exclude the initialization time

# Skip the 1st step to exclude the initialization time
def step(cache_):
outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
return outputs[1] if inference else None

cache = step(cache)
synchronize(device)

start_time = time.perf_counter()
for _ in range(n_steps):
_, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
cache = step(cache)
synchronize(device)
elapsed = time.perf_counter() - start_time
device_rps = n_steps * n_tokens / elapsed
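The benchmark now primes the block with dummy key/value placeholders and funnels both the warm-up call and the timed calls through one `step()` helper, so the measured path matches real inference. A condensed sketch of the pattern (simplified signature; the real function also handles tensor parallelism, quantization, and device synchronization):

```python
import time
import torch

def measure_rps(block, n_steps: int, n_tokens: int, hidden_size: int, device, dtype, inference: bool = True):
    dummy_input = torch.randn(1, n_tokens, hidden_size, device=device, dtype=dtype)
    cache = (torch.empty(0, dtype=dtype), torch.empty(0, dtype=dtype))  # dummy KV placeholders

    def step(cache_):
        outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
        return outputs[1] if inference else None

    cache = step(cache)  # warm-up step, excluded from timing
    start = time.perf_counter()
    for _ in range(n_steps):
        cache = step(cache)
    return n_steps * n_tokens / (time.perf_counter() - start)
```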
2 changes: 2 additions & 0 deletions src/petals/utils/misc.py
@@ -4,6 +4,8 @@

DUMMY_INT64 = torch.empty(0, dtype=torch.int64)

DUMMY_KEY_PAST = torch.empty((0, 0, 0))


def is_dummy(tensor: torch.Tensor) -> bool:
return tensor.numel() == 0
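`DUMMY_KEY_PAST` is the shared zero-element placeholder that throughput.py and the tests above cast to the working dtype; since an empty tensor stays empty after a dtype cast, `is_dummy()` still recognizes it. For example:

```python
import torch

DUMMY_KEY_PAST = torch.empty((0, 0, 0))

cache = (DUMMY_KEY_PAST.to(torch.bfloat16), DUMMY_KEY_PAST.to(torch.bfloat16))
assert all(t.numel() == 0 for t in cache)  # both halves are still treated as dummies
```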
4 changes: 2 additions & 2 deletions src/petals/utils/peft.py
@@ -17,7 +17,7 @@
from safetensors.torch import load_file
from transformers.utils import get_file_from_repo

from petals.server.block_utils import resolve_block_dtype
from petals.server.block_utils import get_model_block, resolve_block_dtype
from petals.utils.convert_block import QuantType
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.misc import get_size_in_bytes
@@ -273,7 +273,7 @@ def estimate_adapter_memory_per_block(
) -> int:
"""Get the number of extra bytes used to store a set of adapters per given block"""
with init_empty_weights(include_buffers=True):
block = block_config.block_class(block_config)
block = get_model_block(block_config)
base_block_parameters = sum(p.numel() for p in block.parameters())
create_lora_adapter(block, quant_type=QuantType.NONE)

9 changes: 6 additions & 3 deletions tests/test_chained_calls.py
@@ -10,6 +10,7 @@
from petals import AutoDistributedConfig
from petals.client.remote_sequential import RemoteSequential
from petals.server.from_pretrained import load_pretrained_block
from petals.utils.misc import DUMMY_KEY_PAST
from test_utils import *


@@ -54,12 +55,14 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
outputs_inference = torch.cat(outputs_inference, dim=1)

dtype = torch.float32
ref_blocks = [
load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
load_pretrained_block(MODEL_NAME, 3, torch_dtype=dtype),
load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype),
]
outputs_ref = []
caches = [None, None]
cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype))
caches = [cache, cache]
for i in range(inputs.shape[1]):
new_caches = []
hidden_states = inputs[:, i : i + 1, :]
4 changes: 4 additions & 0 deletions tests/test_full_model.py
@@ -141,6 +141,10 @@ def test_sampling(tokenizer, model, ref_model, max_new_tokens=10):
), f"Sampling is not identical to HF with {options=}, {multiple_calls=}, {inputs.shape=}"


@pytest.mark.skipif(
"bloom" not in MODEL_NAME.lower(),
reason="Mixtral and Llama use DynamicCache, which can change based on beam search choices",
)
@pytest.mark.forked
def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
8 changes: 5 additions & 3 deletions tests/test_optimized_layers.py
@@ -7,6 +7,7 @@
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

from petals.server.block_utils import get_model_block
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, convert_block
from test_utils import MODEL_NAME
@@ -195,8 +196,9 @@ def test_optimized_block(device):
dtype = torch.bfloat16
quant_type = QuantType.NONE

block = config.block_class(config).to(dtype)
block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
block_idx = 1
block = get_model_block(config, layer_idx=block_idx).to(dtype)
block = convert_block(block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)

if config.model_type == "falcon":
unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
@@ -206,7 +208,7 @@ def test_optimized_block(device):
pytest.skip(f"This test is not applicable to {config.model_type} models")

unopt_block = convert_block(
unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
unopt_block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
)

unopt_block.load_state_dict(block.state_dict())