auto select best device (#822)
* auto select best device

* ruff

* fix wrong check

* update readme

* fix circular import
CSY-ModelCloud authored Dec 12, 2024
1 parent 5452e0b commit dffa089
Showing 2 changed files with 6 additions and 2 deletions.
README.md (4 changes: 2 additions & 2 deletions)

@@ -130,7 +130,7 @@ Below is a basic sample using `GPTQModel` to quantize a llm model and perform po
 ```py
 from datasets import load_dataset
 from transformers import AutoTokenizer
-from gptqmodel import GPTQModel, QuantizeConfig, get_best_device
+from gptqmodel import GPTQModel, QuantizeConfig

 model_id = "meta-llama/Llama-3.2-1B-Instruct"
 quant_path = "Llama-3.2-1B-Instruct-gptqmodel-4bit"
@@ -154,7 +154,7 @@ model.quantize(calibration_dataset)

 model.save(quant_path)

-model = GPTQModel.load(quant_path, device=get_best_device())
+model = GPTQModel.load(quant_path)

 result = model.generate(
   **tokenizer(
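With this change, the README example lets the library pick the device instead of calling `get_best_device()` by hand. Because `auto.py` only fills in `device` when neither `device` nor `device_map` is given (see the diff below), an explicit device should still override the automatic choice. A minimal sketch of both call styles, where `"cuda:0"` is an assumed example value rather than something taken from this commit:

```py
from gptqmodel import GPTQModel

quant_path = "Llama-3.2-1B-Instruct-gptqmodel-4bit"

# New default: no device argument, the library auto-selects the best one.
model = GPTQModel.load(quant_path)

# Explicit override should still work; "cuda:0" is an assumed example value.
model = GPTQModel.load(quant_path, device="cuda:0")
```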
gptqmodel/models/auto.py (4 changes: 4 additions & 0 deletions)

@@ -9,6 +9,7 @@
 from huggingface_hub import list_repo_files
 from transformers import AutoConfig

+from ._const import get_best_device
 from ..utils import BACKEND, EVAL
 from ..utils.logger import setup_logger
 from ..utils.model import check_and_get_model_type
@@ -147,6 +148,9 @@ def load(
                 is_quantized = True
                 break

+        if not device and not device_map:
+            device = get_best_device()
+
         if is_quantized:
             return cls.from_quantized(
                 model_id_or_path=model_id_or_path,
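For context, the new default comes from `get_best_device()`, imported from `._const`; its implementation is not part of this diff. The sketch below is only an assumption about what such a helper commonly does, probing PyTorch backends in priority order:

```py
# Hypothetical sketch; the real gptqmodel.models._const.get_best_device()
# is not shown in this commit and may differ.
import torch

def get_best_device() -> torch.device:
    # Prefer a CUDA (or ROCm) GPU when one is visible to PyTorch.
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    # Fall back to Apple Silicon's MPS backend when available.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    # Intel XPU builds expose torch.xpu; guard with hasattr for older PyTorch.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu:0")
    # Otherwise run on CPU.
    return torch.device("cpu")
```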
