Add IPEX Allreduce
bigPYJ1151 committed Jul 12, 2024
1 parent 3d67c66 commit b45597d
Showing 7 changed files with 41 additions and 27 deletions.
5 changes: 2 additions & 3 deletions Dockerfile.cpu
@@ -13,10 +13,9 @@ RUN pip install intel-openmp
 
 ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so:$LD_PRELOAD"
 
-
 RUN echo 'ulimit -c 0' >> ~/.bashrc
 
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl
+RUN pip install --proxy http://child-prc.intel.com:913 http://mlpc.intel.com/downloads/cpu/ipex-2.4/rc0/intel_extension_for_pytorch-2.4.0-cp310-cp310-manylinux2014_x86_64.whl
 
 RUN pip install --upgrade pip \
     && pip install wheel packaging ninja "setuptools>=49.4.0" numpy
@@ -27,7 +26,7 @@ COPY ./ /workspace/vllm
 
 WORKDIR /workspace/vllm
 
-RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
+RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/test/cpu
 
 # Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ...
 ARG VLLM_CPU_DISABLE_AVX512
26 changes: 13 additions & 13 deletions csrc/cpu/utils.cpp
@@ -47,19 +47,19 @@ void init_cpu_threads_env(const std::string& cpu_ids) {
   }
 
   // OMP threads binding
-  omp_set_num_threads((int)omp_cpu_ids.size());
-  torch::set_num_threads((int)omp_cpu_ids.size());
-  TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads());
-  TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());
-#pragma omp parallel for schedule(static, 1)
-  for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
-    cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size);
-    size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size);
-    CPU_ZERO_S(size, mask);
-    CPU_SET_S(omp_cpu_ids[i], size, mask);
-    sched_setaffinity(0, sizeof(cpu_set_t), mask);
-    CPU_FREE(mask);
-  }
+  omp_set_num_threads((int)omp_cpu_ids.size());
+  torch::set_num_threads((int)omp_cpu_ids.size());
+  TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads());
+  TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());
+#pragma omp parallel for schedule(static, 1)
+  for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
+    cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size);
+    size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size);
+    CPU_ZERO_S(size, mask);
+    CPU_SET_S(omp_cpu_ids[i], size, mask);
+    sched_setaffinity(0, sizeof(cpu_set_t), mask);
+    CPU_FREE(mask);
+  }
 
   numa_free_nodemask(omp_cpu_mask);
 }
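The loop above binds each OpenMP thread to exactly one logical CPU from the allowed set. As a point of reference, a minimal Python analogue of the same affinity call (Linux-only; pin_current_thread is a hypothetical helper, not part of vLLM):

    import os

    def pin_current_thread(cpu_id: int) -> None:
        # Restrict the caller to a single logical CPU, mirroring the
        # sched_setaffinity(0, ...) call in the C++ loop above.
        # A pid of 0 means "the calling thread/process" (Linux-only API).
        os.sched_setaffinity(0, {cpu_id})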
2 changes: 0 additions & 2 deletions docs/source/getting_started/cpu-installation.rst
@@ -88,8 +88,6 @@ Intel Extension for PyTorch
 
 - `Intel Extension for PyTorch (IPEX) <https://github.com/intel/intel-extension-for-pytorch>`_ extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware.
 
-- IPEX after the ``2.3.0`` can be enabled in the CPU backend by default if it is installed.
-
 .. _cpu_backend_performance_tips:
 
 Performance tips
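The deleted bullet documented that the CPU backend enables IPEX automatically once it is installed. A minimal sketch of that style of optional-dependency check (hypothetical helper, not vLLM's actual code):

    def ipex_available() -> bool:
        # IPEX is an optional dependency: enable its code paths only
        # when the package can actually be imported.
        try:
            import intel_extension_for_pytorch  # noqa: F401
            return True
        except ImportError:
            return False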
4 changes: 2 additions & 2 deletions requirements-cpu.txt
@@ -2,6 +2,6 @@
 -r requirements-common.txt
 
 # Dependencies for x86_64 CPUs
-torch == 2.3.1+cpu; platform_machine != "ppc64le"
-torchvision == 0.18.1+cpu; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
+torch == 2.4.0; platform_machine != "ppc64le"
+torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
 triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
3 changes: 3 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -289,6 +289,9 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         pynccl_comm = self.pynccl_comm
         if (pynccl_comm is not None and not pynccl_comm.disabled):
             pynccl_comm.all_reduce(input_)
+        elif input_.is_cpu:
+            import intel_extension_for_pytorch as ipex
+            ipex.distributed.all_reduce(input_, group=self.device_group)
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
         return input_
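The new elif routes CPU tensors through ipex.distributed.all_reduce, which can use a shared-memory implementation within a node, while CUDA tensors keep the pynccl/torch.distributed paths. A standalone sketch of the same dispatch, assuming a process group is already initialized:

    import torch
    import torch.distributed as dist

    def cpu_aware_all_reduce(input_: torch.Tensor, group=None) -> torch.Tensor:
        # In-place all-reduce, mirroring the branch added above: prefer
        # the IPEX implementation for CPU tensors, otherwise fall back
        # to the default torch.distributed collective.
        if input_.is_cpu:
            import intel_extension_for_pytorch as ipex
            ipex.distributed.all_reduce(input_, group=group)
        else:
            dist.all_reduce(input_, group=group)
        return input_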
26 changes: 21 additions & 5 deletions vllm/executor/cpu_executor.py
@@ -35,7 +35,7 @@ def _init_executor(self) -> None:
 
         # Disable torch async compiling which won't work with daemonic processes
         os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
-
+        # Intel OpenMP setting
         ld_prealod_str = os.getenv("LD_PRELOAD", "")
         if "libiomp5.so" in ld_prealod_str:
@@ -49,6 +49,10 @@ def _init_executor(self) -> None:
         os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
         os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"
 
+        # To hint IPEX uses shared memory based AllReduce
+        os.environ["LOCAL_WORLD_SIZE"] = str(
+            self.parallel_config.tensor_parallel_size)
+
         self.model_config = _verify_and_get_model_config(self.model_config)
         self.cache_config = _verify_and_get_cache_config(self.cache_config)
         self.scheduler_config = _verify_and_get_scheduler_config(
@@ -250,16 +254,28 @@ def list_loras(self) -> Set[int]:
 
     def add_prompt_adapter(
             self, prompt_adapter_request: PromptAdapterRequest) -> bool:
-        return self.driver_worker.add_prompt_adapter(prompt_adapter_request)
+        return all(
+            self._run_workers(
+                "add_prompt_adapter",
+                prompt_adapter_request,
+            ))
 
     def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
-        return self.driver_worker.remove_prompt_adapter(prompt_adapter_id)
+        return all(
+            self._run_workers(
+                "remove_prompt_adapter",
+                prompt_adapter_id,
+            ))
 
     def list_prompt_adapters(self) -> Set[int]:
-        return self.driver_worker.list_prompt_adapters()
+        return self.driver_method_invoker(self.driver_worker,
+                                          "list_prompt_adapters")
 
     def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
-        return self.driver_worker.pin_prompt_adapter(prompt_adapter_id)
+        return all(self._run_workers(
+            "pin_prompt_adapter",
+            prompt_adapter_id,
+        ))
 
     def check_health(self) -> None:
         """Raises an error if engine is unhealthy."""
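The prompt-adapter methods now fan out to every worker rather than touching only the driver, and all(...) reports success only if every worker succeeds. A reduced sketch of the pattern (MiniExecutor and its workers are stand-ins, not vLLM classes):

    from typing import Any, List

    class MiniExecutor:
        def __init__(self, workers: List[Any]) -> None:
            self.workers = workers

        def _run_workers(self, method: str, *args) -> List[Any]:
            # Invoke the named method on every worker and collect the
            # results, like the executor's _run_workers helper.
            return [getattr(w, method)(*args) for w in self.workers]

        def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
            # Succeeds only if removal succeeded on every worker.
            return all(self._run_workers("remove_prompt_adapter",
                                         prompt_adapter_id))

Note that list_prompt_adapters instead queries only the driver via driver_method_invoker, presumably because the adapter set is kept identical across workers.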
2 changes: 0 additions & 2 deletions vllm/worker/cpu_worker.py
@@ -1,6 +1,5 @@
 """A CPU worker class."""
 from typing import Dict, List, Optional, Tuple
-import os
 
 import torch
 import torch.distributed
@@ -183,7 +182,6 @@ def __init__(
         self.cache_engine: List[CPUCacheEngine]
         self.cpu_cache: List[List[torch.Tensor]]
 
-
     def init_device(self) -> None:
         if self.local_omp_cpuid != "all":
             torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
