@@ -1,8 +1,9 @@
-from typing import Union, Optional
+from typing import Optional
 
 import torch
 import torch.nn as nn
 from transformers import AutoConfig
+from transformers import PretrainedConfig
 
 from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
 from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
@@ -22,6 +23,7 @@
     'opt': OPTForCausalLM,
     'stablelm': GPTNeoXForCausalLM,
     'pythia': GPTNeoXForCausalLM,
+    'dolly-v2': GPTNeoXForCausalLM,
 }
 
 _MEMORY_ANALYZERS = {
@@ -30,19 +32,38 @@
     'opt': OPTMemoryAnalyzer,
     'stablelm': GPTNeoXMemoryAnalyzer,
     'pythia': GPTNeoXMemoryAnalyzer,
+    'dolly-v2': GPTNeoXMemoryAnalyzer,
 }
 
 
+def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
+    config_dtype: torch.dtype = getattr(config, 'torch_dtype', torch.float32)
+    if dtype == 'default':
+        if config_dtype == torch.float32:
+            # Following the common practice, we use float16 for float32 models.
+            torch_dtype = torch.float16
+        else:
+            torch_dtype = config_dtype
+    else:
+        torch_dtype = get_torch_dtype(dtype)
+        if torch_dtype != config_dtype and config_dtype != torch.float32:
+            # TODO(woosuk): Allow using float16 for bfloat16 models and
+            # vice versa. Print a warning message and continue.
+            raise ValueError(
+                f'Cannot use {torch_dtype} for {config_dtype} model.')
+    return torch_dtype
+
+
 def get_model(
     model_name: str,
-    dtype: Union[torch.dtype, str],
+    dtype: str,
     cache_dir: Optional[str],
     use_dummy_weights: bool,
     use_np_cache: bool,
 ) -> nn.Module:
-    torch_dtype = get_torch_dtype(dtype)
-    torch.set_default_dtype(torch_dtype)
     config = AutoConfig.from_pretrained(model_name)
+    torch_dtype = _get_dtype(config, dtype)
+    torch.set_default_dtype(torch_dtype)
     for model_class_name, model_class in _MODELS.items():
         if model_class_name in model_name:
             if use_dummy_weights:
@@ -66,12 +87,13 @@ def get_model(
 def get_memory_analyzer(
     model_name: str,
     block_size: int,
-    dtype: Union[torch.dtype, str],
+    dtype: str,
     gpu_memory: int,
     cpu_memory: int,
     tensor_parallel_size: int = 1,
 ) -> CacheFlowMemoryAnalyzer:
-    torch_dtype = get_torch_dtype(dtype)
+    config = AutoConfig.from_pretrained(model_name)
+    torch_dtype = _get_dtype(config, dtype)
     for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
         if model_class in model_name:
             return memory_analyzer(
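
Note: the new _get_dtype helper centralizes dtype resolution. Under 'default' it keeps the checkpoint's dtype, except that float32 checkpoints are served in float16; an explicitly requested dtype must agree with any non-float32 checkpoint. A minimal sketch of that behavior follows; the SimpleNamespace stand-ins and the assumption that get_torch_dtype('float16') returns torch.float16 are illustrative, not part of this commit.

# Illustrative sketch only: SimpleNamespace stands in for the config
# object whose torch_dtype attribute _get_dtype inspects; real configs
# come from AutoConfig.from_pretrained.
from types import SimpleNamespace

import torch

fp32_like = SimpleNamespace(torch_dtype=torch.float32)
bf16_like = SimpleNamespace(torch_dtype=torch.bfloat16)

# float32 checkpoints are served in float16 under 'default'.
assert _get_dtype(fp32_like, 'default') == torch.float16
# Non-float32 checkpoints keep their native dtype under 'default'.
assert _get_dtype(bf16_like, 'default') == torch.bfloat16

# An explicit dtype that conflicts with a non-float32 checkpoint raises,
# assuming get_torch_dtype('float16') maps to torch.float16.
try:
    _get_dtype(bf16_like, 'float16')
except ValueError as err:
    print(err)  # Cannot use torch.float16 for torch.bfloat16 model.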
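
The dispatch in get_model and get_memory_analyzer is plain substring matching on the model name, which is why the single new 'dolly-v2' entry is enough to route dolly-v2 checkpoints (Pythia-based, hence GPT-NeoX architecture) to the existing GPTNeoXForCausalLM and GPTNeoXMemoryAnalyzer. A hedged sketch of that lookup; resolve_model_class is a hypothetical helper and the checkpoint names are examples.

# Hypothetical helper mirroring the substring dispatch in get_model above.
def resolve_model_class(model_name: str):
    for key, model_class in _MODELS.items():
        if key in model_name:
            return model_class
    raise ValueError(f'Unsupported model name: {model_name}')

assert resolve_model_class('databricks/dolly-v2-12b') is GPTNeoXForCausalLM
assert resolve_model_class('facebook/opt-13b') is OPTForCausalLM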