ENH: auto detect device in pytorch model (#322)
pangyoki authored Aug 8, 2023
1 parent 797ad49 commit 22146b0
Showing 3 changed files with 57 additions and 44 deletions.
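
With this change, the `device` entry in `PytorchModelConfig` defaults to `"auto"` and is resolved by a new `_select_device` helper; explicit `"cuda"`, `"mps"`, and `"cpu"` values are still accepted and validated. A minimal sketch of the values involved (the config keys are taken from this diff; the literal dict shape is only illustrative):

```python
# Illustrative pytorch_model_config; only the "device" handling is new in this commit.
pytorch_model_config = {
    "device": "auto",    # new default: resolved to "cuda", "mps", or "cpu" at load time
    # "device": "cuda",  # raises ValueError if no CUDA device is detected
    # "device": "mps",   # raises ValueError unless on a Mac with Apple silicon
    "revision": "main",
    "num_gpus": 1,
}
```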
8 changes: 8 additions & 0 deletions xinference/model/llm/core.py
@@ -53,6 +53,14 @@ def _is_darwin_and_apple_silicon():
    def _is_linux():
        return platform.system() == "Linux"

+    @staticmethod
+    def _is_darwin():
+        return platform.system() == "Darwin"
+
+    @staticmethod
+    def _is_arm():
+        return platform.processor() == "arm"
+
    @abstractmethod
    def load(self):
        raise NotImplementedError
10 changes: 4 additions & 6 deletions xinference/model/llm/llm_family.py
@@ -160,13 +160,11 @@ def _is_linux():


def _has_cuda_device():
-    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
-    if cuda_visible_devices:
-        return True
-    else:
-        from xorbits._mars.resource import cuda_count
+    # `cuda_count` method already contains the logic for the
+    # number of GPUs specified by `CUDA_VISIBLE_DEVICES`.
+    from xorbits._mars.resource import cuda_count

-        return cuda_count() > 0
+    return cuda_count() > 0


def get_user_defined_llm_families():
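The comment added above is the rationale for the simplification: `cuda_count` from `xorbits._mars.resource` already honours `CUDA_VISIBLE_DEVICES`, so the separate environment-variable branch was redundant. A rough stand-in using only PyTorch (for readers without xorbits installed; not the code this repository uses):

```python
# Stand-in sketch: torch.cuda.device_count() also respects CUDA_VISIBLE_DEVICES,
# so a single call covers both the "env var set" and "probe the driver" cases.
import torch

def has_cuda_device() -> bool:
    return torch.cuda.device_count() > 0
```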
83 changes: 45 additions & 38 deletions xinference/model/llm/pytorch/core.py
@@ -26,7 +26,7 @@
    EmbeddingUsage,
)
from ..core import LLM
-from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..llm_family import LLMFamilyV1, LLMSpecV1, _has_cuda_device
from ..utils import ChatModelMixin

logger = logging.getLogger(__name__)
@@ -73,6 +73,7 @@ def __init__(
        self._pytorch_model_config: PytorchModelConfig = self._sanitize_model_config(
            pytorch_model_config
        )
+        self._device = self._pytorch_model_config["device"]

    def _sanitize_model_config(
        self, pytorch_model_config: Optional[PytorchModelConfig]
@@ -86,10 +87,10 @@ def _sanitize_model_config(
        pytorch_model_config.setdefault("gptq_wbits", 16)
        pytorch_model_config.setdefault("gptq_groupsize", -1)
        pytorch_model_config.setdefault("gptq_act_order", False)
-        if self._is_darwin_and_apple_silicon():
-            pytorch_model_config.setdefault("device", "mps")
-        else:
-            pytorch_model_config.setdefault("device", "cuda")
+        pytorch_model_config.setdefault("device", "auto")
+        pytorch_model_config["device"] = self._select_device(
+            pytorch_model_config["device"]
+        )
        return pytorch_model_config

    def _sanitize_generate_config(
@@ -105,6 +106,25 @@ def _sanitize_generate_config(
        pytorch_generate_config["model"] = self.model_uid
        return pytorch_generate_config

+    def _select_device(self, device):
+        if device == "auto":
+            if self._is_darwin():
+                return "mps" if self._is_arm() else "cpu"
+            return "cuda" if _has_cuda_device() else "cpu"
+        elif device == "cuda":
+            if not _has_cuda_device():
+                raise ValueError("No cuda device is detected in your environment")
+        elif device == "mps":
+            if not self._is_darwin_and_apple_silicon():
+                raise ValueError(
+                    "mps is only available for Mac computers with Apple silicon"
+                )
+        elif device == "cpu":
+            pass
+        else:
+            raise ValueError(f"Device {device} is not supported in temporary")
+        return device
+
    def _load_model(self, kwargs: dict):
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer
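In plain terms, the new `_select_device` resolves `"auto"` to `"mps"` on Apple-silicon Macs, `"cpu"` on Intel Macs, and `"cuda"` or `"cpu"` elsewhere depending on GPU availability, while explicit `"cuda"`/`"mps"` requests are validated and anything else is rejected. A self-contained sketch of the same decision tree, using `torch.cuda.is_available()` as a stand-in for `_has_cuda_device` (assumption: the two checks agree in practice):

```python
# Standalone sketch mirroring _select_device(); not the class method itself.
import platform

import torch

def select_device(device: str = "auto") -> str:
    is_darwin = platform.system() == "Darwin"
    is_apple_silicon = is_darwin and platform.processor() == "arm"
    if device == "auto":
        if is_darwin:
            return "mps" if is_apple_silicon else "cpu"
        return "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda" and not torch.cuda.is_available():
        raise ValueError("No cuda device is detected in your environment")
    if device == "mps" and not is_apple_silicon:
        raise ValueError("mps is only available for Mac computers with Apple silicon")
    if device not in ("cuda", "mps", "cpu"):
        raise ValueError(f"Device {device} is not supported")
    return device
```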
@@ -142,23 +162,19 @@ def load(self):

        quantization = self.quantization
        num_gpus = self._pytorch_model_config.get("num_gpus", 1)
-        if self._is_darwin_and_apple_silicon():
-            device = self._pytorch_model_config.get("device", "mps")
-        else:
-            device = self._pytorch_model_config.get("device", "cuda")

-        if device == "cpu":
+        if self._device == "cpu":
            kwargs = {"torch_dtype": torch.float32}
-        elif device == "cuda":
+        elif self._device == "cuda":
            kwargs = {"torch_dtype": torch.float16}
-        elif device == "mps":
+        elif self._device == "mps":
            kwargs = {"torch_dtype": torch.float16}
        else:
-            raise ValueError(f"Device {device} is not supported in temporary")
+            raise ValueError(f"Device {self._device} is not supported in temporary")
        kwargs["revision"] = self._pytorch_model_config.get("revision", "main")

        if quantization != "none":
-            if device == "cuda" and self._is_linux():
+            if self._device == "cuda" and self._is_linux():
                kwargs["device_map"] = "auto"
                if quantization == "4-bit":
                    kwargs["load_in_4bit"] = True
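The dtype per resolved device is unchanged by this commit; the branches above only switch from a local `device` variable to `self._device`. For reference, the mapping they implement:

```python
import torch

# Device -> torch_dtype mapping used by load(), restated as a plain dict.
DTYPE_BY_DEVICE = {
    "cpu": torch.float32,   # full precision on CPU
    "cuda": torch.float16,  # half precision on NVIDIA GPUs
    "mps": torch.float16,   # half precision on Apple silicon
}
```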
@@ -185,7 +201,7 @@ def load(self):
            else:
                self._model, self._tokenizer = load_compress_model(
                    model_path=self.model_path,
-                    device=device,
+                    device=self._device,
                    torch_dtype=kwargs["torch_dtype"],
                    use_fast=self._use_fast_tokenizer,
                    revision=kwargs["revision"],
@@ -196,9 +212,9 @@ def load(self):
            self._model, self._tokenizer = self._load_model(kwargs)

        if (
-            device == "cuda" and num_gpus == 1 and quantization == "none"
-        ) or device == "mps":
-            self._model.to(device)
+            self._device == "cuda" and num_gpus == 1 and quantization == "none"
+        ) or self._device == "mps":
+            self._model.to(self._device)
        logger.debug(f"Model Memory: {self._model.get_memory_footprint()}")

    @classmethod
@@ -229,21 +245,21 @@ def generate(
        )

        def generator_wrapper(
-            prompt: str, device: str, generate_config: PytorchGenerateConfig
+            prompt: str, generate_config: PytorchGenerateConfig
        ) -> Iterator[CompletionChunk]:
            if "falcon" in self.model_family.model_name:
                for completion_chunk, _ in generate_stream_falcon(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                ):
                    yield completion_chunk
            elif "chatglm" in self.model_family.model_name:
                for completion_chunk, _ in generate_stream_chatglm(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                ):
                    yield completion_chunk
            else:
                for completion_chunk, _ in generate_stream(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                ):
                    yield completion_chunk

@@ -257,24 +273,20 @@ def generator_wrapper(
        assert self._tokenizer is not None

        stream = generate_config.get("stream", False)
-        if self._is_darwin_and_apple_silicon():
-            device = self._pytorch_model_config.get("device", "mps")
-        else:
-            device = self._pytorch_model_config.get("device", "cuda")
        if not stream:
            if "falcon" in self.model_family.model_name:
                for completion_chunk, completion_usage in generate_stream_falcon(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                ):
                    pass
            elif "chatglm" in self.model_family.model_name:
                for completion_chunk, completion_usage in generate_stream_chatglm(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                ):
                    pass
            else:
                for completion_chunk, completion_usage in generate_stream(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                ):
                    pass
            completion = Completion(
@@ -287,7 +299,7 @@ def generator_wrapper(
            )
            return completion
        else:
-            return generator_wrapper(prompt, device, generate_config)
+            return generator_wrapper(prompt, generate_config)

    def create_embedding(self, input: Union[str, List[str]]) -> Embedding:
        try:
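Because `generator_wrapper` now closes over `self._device`, callers only pass the prompt and the generate config; the `stream` flag still decides whether `generate` returns a `Completion` or an iterator of `CompletionChunk`s. A hedged usage sketch (the method name and the `stream` key come from this diff; everything else is assumed):

```python
# Hypothetical call sites; `model` stands for a loaded PyTorch-backed xinference model.
completion = model.generate("Hello", {"stream": False})  # single Completion
for chunk in model.generate("Hello", {"stream": True}):  # CompletionChunk stream
    print(chunk)
```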
@@ -298,11 +310,6 @@ def create_embedding(self, input: Union[str, List[str]]) -> Embedding:
                "Could not import torch. Please install it with `pip install torch`."
            ) from e

-        if self._is_darwin_and_apple_silicon():
-            device = self._pytorch_model_config.get("device", "mps")
-        else:
-            device = self._pytorch_model_config.get("device", "cuda")
-
        if isinstance(input, str):
            inputs = [input]
        else:
@@ -315,8 +322,8 @@ def create_embedding(self, input: Union[str, List[str]]) -> Embedding:
            encoding = tokenizer.batch_encode_plus(
                inputs, padding=True, return_tensors="pt"
            )
-            input_ids = encoding["input_ids"].to(device)
-            attention_mask = encoding["attention_mask"].to(device)
+            input_ids = encoding["input_ids"].to(self._device)
+            attention_mask = encoding["attention_mask"].to(self._device)
            model_output = self._model(
                input_ids, attention_mask, output_hidden_states=True
            )
@@ -349,7 +356,7 @@ def create_embedding(self, input: Union[str, List[str]]) -> Embedding:
            embedding = []
            token_num = 0
            for index, text in enumerate(inputs):
-                input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
+                input_ids = tokenizer.encode(text, return_tensors="pt").to(self._device)
                model_output = self._model(input_ids, output_hidden_states=True)
                if is_chatglm:
                    data = (model_output.hidden_states[-1].transpose(0, 1))[0]