
Commit 189ae23

Use dtype from model config & Add Dolly V2 (#63)
1 parent e548c14 commit 189ae23

2 files changed (+33, -7)


cacheflow/master/server.py

Lines changed: 5 additions & 1 deletion
@@ -214,7 +214,11 @@ def add_server_arguments(parser: argparse.ArgumentParser):
                         help='save a numpy copy of model weights for faster loading')
     parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
     # NOTE(woosuk): FlashAttention does not support float32.
-    parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type')
+    parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
+                        help=('data type for model weights and activations. '
+                              'The "default" option will use FP16 precision '
+                              'for FP32 and FP16 models, and BF16 precision '
+                              'for BF16 models.'))
     # Parallel arguments
     parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
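For reference, a minimal self-contained sketch of how the new flag parses. Only the --dtype argument from add_server_arguments is reproduced here; the surrounding parser setup is a standalone stand-in, not the server's actual entry point:

import argparse

# Sketch: reproduces only the new --dtype argument added in this commit.
parser = argparse.ArgumentParser()
parser.add_argument('--dtype', type=str, default='default',
                    choices=['default', 'half', 'bfloat16'],
                    help='data type for model weights and activations.')

print(parser.parse_args([]).dtype)                   # -> 'default'
print(parser.parse_args(['--dtype', 'half']).dtype)  # -> 'half'
# 'float32' is deliberately not a choice: per the NOTE(woosuk) above,
# FlashAttention does not support float32, so argparse exits with
# "invalid choice: 'float32'" if it is requested.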

cacheflow/models/model_utils.py

Lines changed: 28 additions & 6 deletions
@@ -1,8 +1,9 @@
-from typing import Union, Optional
+from typing import Optional

 import torch
 import torch.nn as nn
 from transformers import AutoConfig
+from transformers import PretrainedConfig

 from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
 from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
@@ -22,6 +23,7 @@
     'opt': OPTForCausalLM,
     'stablelm': GPTNeoXForCausalLM,
     'pythia': GPTNeoXForCausalLM,
+    'dolly-v2': GPTNeoXForCausalLM,
 }

 _MEMORY_ANALYZERS = {
@@ -30,19 +32,38 @@
     'opt': OPTMemoryAnalyzer,
     'stablelm': GPTNeoXMemoryAnalyzer,
     'pythia': GPTNeoXMemoryAnalyzer,
+    'dolly-v2': GPTNeoXMemoryAnalyzer,
 }


+def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
+    config_dtype: torch.dtype = getattr(config, 'torch_dtype', torch.float32)
+    if dtype == 'default':
+        if config_dtype == torch.float32:
+            # Following the common practice, we use float16 for float32 models.
+            torch_dtype = torch.float16
+        else:
+            torch_dtype = config_dtype
+    else:
+        torch_dtype = get_torch_dtype(dtype)
+        if torch_dtype != config_dtype and config_dtype != torch.float32:
+            # TODO(woosuk): Allow using float16 for bfloat16 models and
+            # vice versa. Print a warning message and continue.
+            raise ValueError(
+                f'Cannot use {torch_dtype} for {config_dtype} model.')
+    return torch_dtype
+
+
 def get_model(
     model_name: str,
-    dtype: Union[torch.dtype, str],
+    dtype: str,
     cache_dir: Optional[str],
     use_dummy_weights: bool,
     use_np_cache: bool,
 ) -> nn.Module:
-    torch_dtype = get_torch_dtype(dtype)
-    torch.set_default_dtype(torch_dtype)
     config = AutoConfig.from_pretrained(model_name)
+    torch_dtype = _get_dtype(config, dtype)
+    torch.set_default_dtype(torch_dtype)
     for model_class_name, model_class in _MODELS.items():
         if model_class_name in model_name:
             if use_dummy_weights:
@@ -66,12 +87,13 @@ def get_model(
 def get_memory_analyzer(
     model_name: str,
     block_size: int,
-    dtype: Union[torch.dtype, str],
+    dtype: str,
     gpu_memory: int,
     cpu_memory: int,
     tensor_parallel_size: int = 1,
 ) -> CacheFlowMemoryAnalyzer:
-    torch_dtype = get_torch_dtype(dtype)
+    config = AutoConfig.from_pretrained(model_name)
+    torch_dtype = _get_dtype(config, dtype)
     for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
         if model_class in model_name:
             return memory_analyzer(
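A self-contained sketch of the new resolution logic follows. get_torch_dtype is stubbed locally (in model_utils.py it is imported from elsewhere in the package), and resolve_dtype takes the config's dtype directly rather than a PretrainedConfig; the BF16 example assumes a checkpoint whose config declares torch_dtype=bfloat16, which is presumably what Dolly V2 ships with and what motivates the BF16 branch:

import torch

# Stub for the string -> torch.dtype helper that model_utils.py imports;
# only the two explicit CLI choices need to be covered here.
_STR_TO_DTYPE = {'half': torch.float16, 'bfloat16': torch.bfloat16}

def get_torch_dtype(dtype: str) -> torch.dtype:
    return _STR_TO_DTYPE[dtype]

def resolve_dtype(config_dtype: torch.dtype, dtype: str) -> torch.dtype:
    # Mirrors _get_dtype above: 'default' keeps the dtype declared in the
    # model config, except that float32 checkpoints are downcast to float16.
    if dtype == 'default':
        return torch.float16 if config_dtype == torch.float32 else config_dtype
    torch_dtype = get_torch_dtype(dtype)
    if torch_dtype != config_dtype and config_dtype != torch.float32:
        # An explicit request that conflicts with a non-FP32 checkpoint
        # is rejected, as in the commit.
        raise ValueError(f'Cannot use {torch_dtype} for {config_dtype} model.')
    return torch_dtype

assert resolve_dtype(torch.float32, 'default') == torch.float16    # FP32 -> FP16
assert resolve_dtype(torch.bfloat16, 'default') == torch.bfloat16  # BF16 kept
try:
    resolve_dtype(torch.bfloat16, 'half')  # conflicting explicit request
except ValueError as e:
    print(e)  # Cannot use torch.float16 for torch.bfloat16 model.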
