Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support tensor parallel #2

Merged
merged 21 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ pip install -e .
## Run

```bash
python server.py
ray start --head
python server.py [--tensor-parallel-size <N>]
```
2 changes: 1 addition & 1 deletion cacheflow/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def forward(
output[:num_prompt_tokens],
query[:num_prompt_tokens],
key[:num_prompt_tokens],
value[:num_prompt_tokens],
value[:num_prompt_tokens],
input_metadata.prompt_lens,
)

Expand Down
6 changes: 5 additions & 1 deletion cacheflow/models/input_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,8 @@ def __repr__(self) -> str:
f'num_generation_tokens={self.num_generation_tokens}, '
f'num_valid_tokens={self.num_valid_tokens}, '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'max_context_len={self.max_context_len})')
f'max_context_len={self.max_context_len}), '
f'prompt_lens={self.prompt_lens}, '
f'slot_mapping={self.slot_mapping}, '
f'context_lens={self.context_lens}, '
f'block_tables={self.block_tables})')
37 changes: 19 additions & 18 deletions cacheflow/models/memory_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ def __init__(
model_name: str,
block_size: int,
dtype: torch.dtype,
tensor_parallel_size: int,
) -> None:
self.model_name = model_name
self.block_size = block_size
self.dtype = dtype
self.tensor_parallel_size = tensor_parallel_size

# TODO(woosuk): Support tensor parallelism.
config = AutoConfig.from_pretrained(model_name)
self.num_layers = config.num_hidden_layers
self.hidden_size = config.hidden_size
Expand All @@ -47,27 +48,26 @@ def __init__(
self.vocab_size = config.vocab_size
self.max_position = config.max_position_embeddings

assert self.embedding_size == self.hidden_size
zhuohan123 marked this conversation as resolved.
Show resolved Hide resolved

def _get_param_size(self) -> int:
# TODO(woosuk): Support tensor parallelism.
zhuohan123 marked this conversation as resolved.
Show resolved Hide resolved
word_embedding = self.vocab_size * self.embedding_size
if self.embedding_size != self.vocab_size:
# Project in/out.
word_embedding += 2 * self.embedding_size * self.vocab_size
word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size
position_embedding = self.max_position * self.hidden_size

ln1 = 2 * self.hidden_size
q = self.hidden_size * self.hidden_size + self.hidden_size
k = self.hidden_size * self.hidden_size + self.hidden_size
v = self.hidden_size * self.hidden_size + self.hidden_size
out = self.hidden_size * self.hidden_size + self.hidden_size
q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
mha = ln1 + q + k + v + out

ln2 = 2 * self.hidden_size
ffn1 = self.hidden_size * self.ffn_size + self.ffn_size
ffn2 = self.ffn_size * self.hidden_size + self.hidden_size
ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
ffn = ln2 + ffn1 + ffn2

total = (word_embedding + position_embedding +
total = (word_embedding + position_embedding +
self.num_layers * (mha + ffn))
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
Expand All @@ -76,15 +76,16 @@ def _get_max_act_size(
self,
max_num_batched_tokens: int,
) -> int:
# TODO(woosuk): Support tensor parallelism.
# NOTE: We approxmiately calculate the maximum activation size by
# 1) estimating the maximum activation tensor size during inference, and
# 2) multiplying it by 4.
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
qkv = 3 * (max_num_batched_tokens * self.hidden_size)
ffn = max_num_batched_tokens * self.ffn_size
max_act = 4 * max(qkv, ffn)
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
max_act = 2 * max(qkv, ffn) + 2 * residual
zhuohan123 marked this conversation as resolved.
Show resolved Hide resolved
dtype_size = get_dtype_size(self.dtype)
return dtype_size * max_act

Expand Down
19 changes: 13 additions & 6 deletions cacheflow/models/model_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Union

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoConfig

from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
Expand All @@ -21,24 +23,29 @@
def get_model(
model_name: str,
dtype: Union[torch.dtype, str],
path: str = '/tmp/transformers',
zhuohan123 marked this conversation as resolved.
Show resolved Hide resolved
) -> nn.Module:
torch_dtype = get_torch_dtype(dtype)
for model_class, hf_model in _MODELS.items():
if model_class in model_name:
model = hf_model.from_pretrained(
model_name, torch_dtype=torch_dtype)
return model.eval()
torch.set_default_dtype(torch_dtype)
zhuohan123 marked this conversation as resolved.
Show resolved Hide resolved
config = AutoConfig.from_pretrained(model_name)
for model_class_name, model_class in _MODELS.items():
if model_class_name in model_name:
model = model_class(config)
weights_dir = model_class.download_weights(model_name, path=path)
model.load_weights(weights_dir)
return model.eval(), torch_dtype
zhuohan123 marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f'Unsupported model name: {model_name}')


def get_memory_analyzer(
model_name: str,
block_size: int,
dtype: Union[torch.dtype, str],
tensor_parallel_size: int = 1,
) -> CacheFlowMemoryAnalyzer:
torch_dtype = get_torch_dtype(dtype)
for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
if model_class in model_name:
return memory_analyzer(
model_name, block_size, torch_dtype)
model_name, block_size, torch_dtype, tensor_parallel_size)
raise ValueError(f'Unsupported model name: {model_name}')
Loading