From 7ddfac24d716b152cd025ffcf96efced7575ac50 Mon Sep 17 00:00:00 2001
From: pangyoki
Date: Mon, 7 Aug 2023 16:23:13 +0800
Subject: [PATCH 1/3] auto detect device in pytorch model

---
 xinference/model/llm/core.py         |  8 +++
 xinference/model/llm/llm_family.py   |  8 +--
 xinference/model/llm/pytorch/core.py | 83 +++++++++++++++-------------
 3 files changed, 55 insertions(+), 44 deletions(-)

diff --git a/xinference/model/llm/core.py b/xinference/model/llm/core.py
index 5e46843295..49c91bcaf4 100644
--- a/xinference/model/llm/core.py
+++ b/xinference/model/llm/core.py
@@ -53,6 +53,14 @@ def _is_darwin_and_apple_silicon():
     def _is_linux():
         return platform.system() == "Linux"
 
+    @staticmethod
+    def _is_darwin():
+        return platform.system() == "Darwin"
+
+    @staticmethod
+    def _is_arm():
+        return platform.processor() == "arm"
+
     @abstractmethod
     def load(self):
         raise NotImplementedError
diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py
index 1c385416f0..083bf6ff78 100644
--- a/xinference/model/llm/llm_family.py
+++ b/xinference/model/llm/llm_family.py
@@ -160,13 +160,9 @@ def _is_linux():
 
 
 def _has_cuda_device():
-    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
-    if cuda_visible_devices:
-        return True
-    else:
-        from xorbits._mars.resource import cuda_count
+    from xorbits._mars.resource import cuda_count
 
-        return cuda_count() > 0
+    return cuda_count() > 0
 
 
 def get_user_defined_llm_families():
diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py
index af6fa907f9..e424f9c245 100644
--- a/xinference/model/llm/pytorch/core.py
+++ b/xinference/model/llm/pytorch/core.py
@@ -26,7 +26,7 @@
     EmbeddingUsage,
 )
 from ..core import LLM
-from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..llm_family import LLMFamilyV1, LLMSpecV1, _has_cuda_device
 from ..utils import ChatModelMixin
 
 logger = logging.getLogger(__name__)
@@ -73,6 +73,7 @@ def __init__(
         self._pytorch_model_config: PytorchModelConfig = self._sanitize_model_config(
             pytorch_model_config
         )
+        self._device = self._select_device()
 
     def _sanitize_model_config(
         self, pytorch_model_config: Optional[PytorchModelConfig]
@@ -80,16 +81,13 @@ def _sanitize_model_config(
         if pytorch_model_config is None:
             pytorch_model_config = PytorchModelConfig()
         pytorch_model_config.setdefault("revision", "main")
+        pytorch_model_config.setdefault("device", "auto")
         pytorch_model_config.setdefault("gpus", None)
         pytorch_model_config.setdefault("num_gpus", 1)
         pytorch_model_config.setdefault("gptq_ckpt", None)
         pytorch_model_config.setdefault("gptq_wbits", 16)
         pytorch_model_config.setdefault("gptq_groupsize", -1)
         pytorch_model_config.setdefault("gptq_act_order", False)
-        if self._is_darwin_and_apple_silicon():
-            pytorch_model_config.setdefault("device", "mps")
-        else:
-            pytorch_model_config.setdefault("device", "cuda")
         return pytorch_model_config
 
     def _sanitize_generate_config(
@@ -105,6 +103,28 @@ def _sanitize_generate_config(
         pytorch_generate_config["model"] = self.model_uid
         return pytorch_generate_config
 
+    def _select_device(self):
+        device = self._pytorch_model_config.get("device", "auto")
+        if device == "auto":
+            if self._is_darwin():
+                return "mps" if self._is_arm() else "cpu"
+            return "cuda" if _has_cuda_device() else "cpu"
+        elif device == "cuda":
+            if not _has_cuda_device():
+                raise ValueError(
+                    "No cuda device is detected in your environment, please set device to cpu"
+                )
+        elif device == "mps":
+            if not self._is_darwin_and_apple_silicon():
+                raise ValueError(
+                    "mps is only used on Mac M1/M2 machines, please set device to cpu"
+                )
+        elif device == "cpu":
+            pass
+        else:
+            raise ValueError(f"Device {device} is not supported")
+        return device
+
     def _load_model(self, kwargs: dict):
         try:
             from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -142,23 +162,19 @@ def load(self):
         quantization = self.quantization
         num_gpus = self._pytorch_model_config.get("num_gpus", 1)
 
-        if self._is_darwin_and_apple_silicon():
-            device = self._pytorch_model_config.get("device", "mps")
-        else:
-            device = self._pytorch_model_config.get("device", "cuda")
-        if device == "cpu":
+        if self._device == "cpu":
             kwargs = {"torch_dtype": torch.float32}
-        elif device == "cuda":
+        elif self._device == "cuda":
             kwargs = {"torch_dtype": torch.float16}
-        elif device == "mps":
+        elif self._device == "mps":
             kwargs = {"torch_dtype": torch.float16}
         else:
-            raise ValueError(f"Device {device} is not supported in temporary")
+            raise ValueError(f"Device {self._device} is not supported")
 
         kwargs["revision"] = self._pytorch_model_config.get("revision", "main")
 
         if quantization != "none":
-            if device == "cuda" and self._is_linux():
+            if self._device == "cuda" and self._is_linux():
                 kwargs["device_map"] = "auto"
                 if quantization == "4-bit":
                     kwargs["load_in_4bit"] = True
@@ -178,7 +194,7 @@ def load(self):
             else:
                 self._model, self._tokenizer = load_compress_model(
                     model_path=self.model_path,
-                    device=device,
+                    device=self._device,
                     torch_dtype=kwargs["torch_dtype"],
                     use_fast=self._use_fast_tokenizer,
                     revision=kwargs["revision"],
@@ -189,9 +205,9 @@
         self._model, self._tokenizer = self._load_model(kwargs)
 
         if (
-            device == "cuda" and num_gpus == 1 and quantization == "none"
-        ) or device == "mps":
-            self._model.to(device)
+            self._device == "cuda" and num_gpus == 1 and quantization == "none"
+        ) or self._device == "mps":
+            self._model.to(self._device)
         logger.debug(f"Model Memory: {self._model.get_memory_footprint()}")
 
     @classmethod
@@ -222,21 +238,21 @@ def generate(
         )
 
         def generator_wrapper(
-            prompt: str, device: str, generate_config: PytorchGenerateConfig
+            prompt: str, generate_config: PytorchGenerateConfig
         ) -> Iterator[CompletionChunk]:
             if "falcon" in self.model_family.model_name:
                 for completion_chunk, _ in generate_stream_falcon(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                 ):
                     yield completion_chunk
             elif "chatglm" in self.model_family.model_name:
                 for completion_chunk, _ in generate_stream_chatglm(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                 ):
                     yield completion_chunk
             else:
                 for completion_chunk, _ in generate_stream(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                 ):
                     yield completion_chunk
@@ -250,24 +266,20 @@ def generator_wrapper(
         assert self._tokenizer is not None
 
         stream = generate_config.get("stream", False)
-        if self._is_darwin_and_apple_silicon():
-            device = self._pytorch_model_config.get("device", "mps")
-        else:
-            device = self._pytorch_model_config.get("device", "cuda")
         if not stream:
             if "falcon" in self.model_family.model_name:
                 for completion_chunk, completion_usage in generate_stream_falcon(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                 ):
                     pass
             elif "chatglm" in self.model_family.model_name:
                 for completion_chunk, completion_usage in generate_stream_chatglm(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                 ):
                     pass
             else:
                 for completion_chunk, completion_usage in generate_stream(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                ):
                     pass
             completion = Completion(
@@ -280,7 +292,7 @@
             )
             return completion
         else:
-            return generator_wrapper(prompt, device, generate_config)
+            return generator_wrapper(prompt, generate_config)
 
     def create_embedding(self, input: Union[str, List[str]]) -> Embedding:
         try:
@@ -291,11 +303,6 @@
                 "Could not import torch. Please install it with `pip install torch`."
             ) from e
 
-        if self._is_darwin_and_apple_silicon():
-            device = self._pytorch_model_config.get("device", "mps")
-        else:
-            device = self._pytorch_model_config.get("device", "cuda")
-
         if isinstance(input, str):
             inputs = [input]
         else:
@@ -308,8 +315,8 @@
             encoding = tokenizer.batch_encode_plus(
                 inputs, padding=True, return_tensors="pt"
             )
-            input_ids = encoding["input_ids"].to(device)
-            attention_mask = encoding["attention_mask"].to(device)
+            input_ids = encoding["input_ids"].to(self._device)
+            attention_mask = encoding["attention_mask"].to(self._device)
             model_output = self._model(
                 input_ids, attention_mask, output_hidden_states=True
             )
@@ -342,7 +349,7 @@
             embedding = []
             token_num = 0
             for index, text in enumerate(inputs):
-                input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
+                input_ids = tokenizer.encode(text, return_tensors="pt").to(self._device)
                 model_output = self._model(input_ids, output_hidden_states=True)
                 if is_chatglm:
                     data = (model_output.hidden_states[-1].transpose(0, 1))[0]
From 7050b451302ce134fb05377a88bcf2fd81123b9a Mon Sep 17 00:00:00 2001
From: pangyoki
Date: Tue, 8 Aug 2023 12:20:26 +0800
Subject: [PATCH 2/3] fix comment

---
 xinference/model/llm/llm_family.py   |  2 ++
 xinference/model/llm/pytorch/core.py | 10 ++++------
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py
index 083bf6ff78..4a5cfa1013 100644
--- a/xinference/model/llm/llm_family.py
+++ b/xinference/model/llm/llm_family.py
@@ -160,6 +160,8 @@ def _is_linux():
 
 
 def _has_cuda_device():
+    # `cuda_count` already takes the GPUs specified by
+    # `CUDA_VISIBLE_DEVICES` into account.
    from xorbits._mars.resource import cuda_count
 
     return cuda_count() > 0
diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py
index c3d634cae2..7e5a4a425b 100644
--- a/xinference/model/llm/pytorch/core.py
+++ b/xinference/model/llm/pytorch/core.py
@@ -73,7 +73,7 @@ def __init__(
         self._pytorch_model_config: PytorchModelConfig = self._sanitize_model_config(
             pytorch_model_config
         )
-        self._device = self._select_device()
+        self._device = self._pytorch_model_config["device"]
 
     def _sanitize_model_config(
         self, pytorch_model_config: Optional[PytorchModelConfig]
@@ -81,13 +81,13 @@ def _sanitize_model_config(
         if pytorch_model_config is None:
             pytorch_model_config = PytorchModelConfig()
         pytorch_model_config.setdefault("revision", "main")
-        pytorch_model_config.setdefault("device", "auto")
         pytorch_model_config.setdefault("gpus", None)
         pytorch_model_config.setdefault("num_gpus", 1)
         pytorch_model_config.setdefault("gptq_ckpt", None)
         pytorch_model_config.setdefault("gptq_wbits", 16)
         pytorch_model_config.setdefault("gptq_groupsize", -1)
         pytorch_model_config.setdefault("gptq_act_order", False)
+        pytorch_model_config["device"] = self._select_device()
         return pytorch_model_config
 
     def _sanitize_generate_config(
@@ -111,13 +111,11 @@ def _select_device(self):
             return "cuda" if _has_cuda_device() else "cpu"
         elif device == "cuda":
             if not _has_cuda_device():
-                raise ValueError(
-                    "No cuda device is detected in your environment, please set device to cpu"
-                )
+                raise ValueError("No cuda device is detected in your environment")
         elif device == "mps":
             if not self._is_darwin_and_apple_silicon():
                 raise ValueError(
-                    "mps is only used on Mac M1/M2 machines, please set device to cpu"
+                    "mps is only available for Mac computers with Apple silicon"
                 )
         elif device == "cpu":
             pass
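
Note on patch 2: the new comment answers the natural review question about patch 1 deleting the explicit `CUDA_VISIBLE_DEVICES` branch from `_has_cuda_device` — `cuda_count` already sees the masked device list, so the special case was redundant. The same effect can be observed with PyTorch's own probe (illustrative only; xorbits' `cuda_count` is what the code actually calls):

    import os

    # Mask all GPUs before anything initializes CUDA; a cuda_count-style
    # probe in this process now reports zero devices.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

    import torch

    print(torch.cuda.device_count())  # 0, even on a machine with GPUs
    print(torch.cuda.is_available())  # False

The other half of the patch moves device resolution out of `__init__` and into `_sanitize_model_config`, so the sanitized config and `self._device` can never disagree.
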
From ee81d317f2be63e4ec1bc3b1e002b5b89fd4b8b8 Mon Sep 17 00:00:00 2001
From: pangyoki
Date: Tue, 8 Aug 2023 12:59:38 +0800
Subject: [PATCH 3/3] fix

---
 xinference/model/llm/pytorch/core.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py
index 7e5a4a425b..0645220b91 100644
--- a/xinference/model/llm/pytorch/core.py
+++ b/xinference/model/llm/pytorch/core.py
@@ -87,7 +87,10 @@ def _sanitize_model_config(
         pytorch_model_config.setdefault("gptq_wbits", 16)
         pytorch_model_config.setdefault("gptq_groupsize", -1)
         pytorch_model_config.setdefault("gptq_act_order", False)
-        pytorch_model_config["device"] = self._select_device()
+        pytorch_model_config.setdefault("device", "auto")
+        pytorch_model_config["device"] = self._select_device(
+            pytorch_model_config["device"]
+        )
         return pytorch_model_config
 
     def _sanitize_generate_config(
@@ -103,8 +106,7 @@ def _sanitize_generate_config(
         pytorch_generate_config["model"] = self.model_uid
         return pytorch_generate_config
 
-    def _select_device(self):
-        device = self._pytorch_model_config.get("device", "auto")
+    def _select_device(self, device):
         if device == "auto":
             if self._is_darwin():
                 return "mps" if self._is_arm() else "cpu"
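
Note on patch 3: it fixes an ordering bug introduced by patch 2. There, `_sanitize_model_config` calls `self._select_device()`, which reads `self._pytorch_model_config` — but `__init__` only assigns that attribute after `_sanitize_model_config` returns, so model construction would raise `AttributeError`. Passing the device value in as an argument breaks the cycle. A stripped-down repro, with hypothetical class names and the real selection logic elided:

    class Model:
        """Patch 2 shape: the selector reads config state off `self`."""

        def __init__(self, config: dict):
            self._config = self._sanitize(config)  # sanitize runs first...

        def _sanitize(self, config: dict) -> dict:
            config["device"] = self._select_device()
            return config

        def _select_device(self) -> str:
            # ...so this lookup runs before self._config exists.
            return self._config.get("device", "auto")


    class FixedModel(Model):
        """Patch 3 shape: the value is passed in, nothing is read off `self`."""

        def _sanitize(self, config: dict) -> dict:
            config.setdefault("device", "auto")
            config["device"] = self._select_device_from(config["device"])
            return config

        def _select_device_from(self, device: str) -> str:
            return "cpu" if device == "auto" else device  # real logic elided


    try:
        Model({})
    except AttributeError as exc:
        print(exc)  # 'Model' object has no attribute '_config'

    print(FixedModel({})._config["device"])  # cpu

After all three patches, `device` defaults to `"auto"` inside the sanitized config, `_select_device` resolves it from the passed-in value alone, and `self._device` simply reads the resolved entry.
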