config.py
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

from ...import_utils import torch_version
from ...system_utils import is_rocm_system
from ..config import BackendConfig

DEVICE_MAPS = ["auto", "sequential"]
AMP_DTYPES = ["bfloat16", "float16"]
TORCH_DTYPES = ["bfloat16", "float16", "float32", "auto"]

QUANTIZATION_CONFIGS = {"bnb": {"llm_int8_threshold": 0.0}, "gptq": {}, "awq": {}}
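
# Note (added for clarity): QUANTIZATION_CONFIGS holds per-scheme default kwargs.
# In __post_init__ below, a user-supplied `quantization_config` is merged on top
# of the scheme's defaults, so user values win on key collisions; if no user
# config is given, the defaults are left unapplied.
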
@dataclass
class PyTorchConfig(BackendConfig):
    name: str = "pytorch"
    version: Optional[str] = torch_version()
    _target_: str = "optimum_benchmark.backends.pytorch.backend.PyTorchBackend"

    # load options
    no_weights: bool = False
    device_map: Optional[str] = None
    torch_dtype: Optional[str] = None

    # optimization options
    eval_mode: bool = True
    to_bettertransformer: bool = False
    low_cpu_mem_usage: Optional[bool] = None
    attn_implementation: Optional[str] = None
    cache_implementation: Optional[str] = None

    # automatic mixed precision options
    autocast_enabled: bool = False
    autocast_dtype: Optional[str] = None

    # torch compile options
    torch_compile: bool = False
    torch_compile_target: str = "forward"
    torch_compile_config: Dict[str, Any] = field(default_factory=dict)

    # quantization options
    quantization_scheme: Optional[str] = None
    quantization_config: Dict[str, Any] = field(default_factory=dict)

    # distributed inference options
    deepspeed_inference: bool = False
    deepspeed_inference_config: Dict[str, Any] = field(default_factory=dict)

    # peft options
    peft_type: Optional[str] = None
    peft_config: Dict[str, Any] = field(default_factory=dict)
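
    # Comment added for clarity: the checks below run after
    # BackendConfig.__post_init__ and reject invalid device maps, dtypes, and
    # quantization schemes early, before any backend is instantiated.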
    def __post_init__(self):
        super().__post_init__()

        if self.model_kwargs.get("torch_dtype", None) is not None:
            raise ValueError(
                "`torch_dtype` is an explicit argument in the PyTorch backend config. "
                "Please remove it from the `model_kwargs` and set it in the backend config directly."
            )

        if self.device_map is not None and self.device_map not in DEVICE_MAPS:
            raise ValueError(f"`device_map` must be one of {DEVICE_MAPS}. Got {self.device_map} instead.")

        if self.torch_dtype is not None and self.torch_dtype not in TORCH_DTYPES:
            raise ValueError(f"`torch_dtype` must be one of {TORCH_DTYPES}. Got {self.torch_dtype} instead.")

        if self.autocast_dtype is not None and self.autocast_dtype not in AMP_DTYPES:
            raise ValueError(f"`autocast_dtype` must be one of {AMP_DTYPES}. Got {self.autocast_dtype} instead.")

        if self.quantization_scheme is not None:
            if self.quantization_scheme not in QUANTIZATION_CONFIGS:
                raise ValueError(
                    f"`quantization_scheme` must be one of {list(QUANTIZATION_CONFIGS.keys())}. "
                    f"Got {self.quantization_scheme} instead."
                )

            if self.quantization_scheme == "bnb" and is_rocm_system():
                raise ValueError("BitsAndBytes is not supported on ROCm GPUs. Please disable it.")

            if self.quantization_config:
                # merge the scheme's defaults with the user-provided config (user keys win)
                scheme_defaults = QUANTIZATION_CONFIGS[self.quantization_scheme]
                self.quantization_config = {**scheme_defaults, **self.quantization_config}
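

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module).
# Constructing a full PyTorchConfig requires whatever BackendConfig.__post_init__
# expects (e.g. a `model` name), so this sketch only demonstrates the merge
# semantics applied to `quantization_config` above: scheme defaults first, then
# the user config, so user-supplied keys win on collision.
if __name__ == "__main__":
    user_config = {"load_in_8bit": True, "llm_int8_threshold": 6.0}
    merged = {**QUANTIZATION_CONFIGS["bnb"], **user_config}
    # the user's llm_int8_threshold (6.0) overrides the default (0.0)
    print(merged)  # {'llm_int8_threshold': 6.0, 'load_in_8bit': True}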