[Hardware][TPU] Support MoE with Pallas GMM kernel #6457

Merged · 26 commits · Jul 16, 2024
4 changes: 3 additions & 1 deletion Dockerfile.tpu
@@ -1,11 +1,13 @@
-ARG NIGHTLY_DATE="20240601"
+ARG NIGHTLY_DATE="20240713"
 ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"

 FROM $BASE_IMAGE
 WORKDIR /workspace

 # Install aiohttp separately to avoid build errors.
 RUN pip install aiohttp
+# Install NumPy 1 instead of NumPy 2.
+RUN pip install "numpy<2"
 # Install the TPU and Pallas dependencies.
 RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
 RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
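As a quick sanity check of the image built from this Dockerfile (a sketch, not part of the diff; it assumes you run it inside the built container or on an equivalent TPU VM), the installed stack can be imported directly:

# Illustrative import check for the dependencies installed above.
import numpy as np
import torch
import torch_xla   # installed via torch_xla[tpu]
import jax         # pulled in alongside the Pallas/JAX nightly wheels

print(torch.__version__, torch_xla.__version__, jax.__version__)
assert np.__version__.startswith("1."), "expected NumPy 1.x, not NumPy 2"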
4 changes: 2 additions & 2 deletions docs/source/getting_started/tpu-installation.rst
@@ -56,7 +56,7 @@ First, install the dependencies:
$ pip uninstall torch torch-xla -y

$ # Install PyTorch and PyTorch XLA.
-$ export DATE="+20240601"
+$ export DATE="+20240713"
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl

@@ -85,7 +85,7 @@ Next, build vLLM from source. This will only take a few seconds:
ImportError: libopenblas.so.0: cannot open shared object file: No such file or directory


-You can install OpenBLAS with the following command:
+Please install OpenBLAS with the following command:

.. code-block:: console

18 changes: 18 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -104,6 +104,24 @@ def forward_cpu(self, *args, **kwargs):
raise NotImplementedError(
"The CPU backend currently does not support MoE.")

def forward_tpu(
self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
num_expert_group: Optional[int],
topk_group: Optional[int],
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
return fused_moe(x, w1, w2, router_logits, top_k, renormalize)


class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.
62 changes: 62 additions & 0 deletions vllm/model_executor/layers/fused_moe/moe_pallas.py
@@ -0,0 +1,62 @@
import torch
import torch.nn.functional as F
from torch_xla.experimental.custom_kernel import _histogram


def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> torch.Tensor:
"""
Args:
hidden_states: [*, hidden_size]
w1: [num_experts, intermediate_size * 2, hidden_size]
w2: [num_experts, hidden_size, intermediate_size]
gating_output: [*, num_experts]
"""
orig_shape = hidden_states.shape
hidden_size = hidden_states.shape[-1]
num_tokens = hidden_states.shape[:-1].numel()
num_experts = w1.shape[0]
intermediate_size = w2.shape[-1]
device = hidden_states.device
dtype = hidden_states.dtype
assert (num_tokens * topk) % 16 == 0, (
"The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
f"16 but got {num_tokens * topk}")

hidden_states = hidden_states.view(num_tokens, hidden_size)
gating_output = gating_output.view(num_tokens, num_experts)
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
topk_weights, topk_indices = topk_weights.topk(topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(dtype)

topk_indices = topk_indices.flatten()
topk_argsort_indices = topk_indices.argsort()
topk_argsort_revert_indices = topk_argsort_indices.argsort()
token_indices = torch.arange(num_tokens,
device=device).repeat_interleave(topk)
token_indices = token_indices[topk_argsort_indices]
group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)

# NOTE(woosuk): The GMM Pallas kernel requires a different weight layout
# from HF Transformers.
w1 = w1.transpose(1, 2)
w2 = w2.transpose(1, 2)

x = hidden_states[token_indices]
x = torch.ops.xla.gmm(x, w1, group_sizes)
x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
x = torch.ops.xla.gmm(x, w2, group_sizes)
x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)

x = x * topk_weights.unsqueeze_(dim=-1)
x = x.sum(dim=-2)
x = x.reshape(orig_shape)
return x
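For orientation, a hypothetical call to fused_moe could look like the sketch below. The shapes are invented (loosely Mixtral-like) purely to match the docstring above, and num_tokens * topk = 16 satisfies the kernel's alignment check; the gmm op itself only executes on an XLA/TPU device:

# Illustrative shapes only; real values come from the model and scheduler.
num_tokens, hidden_size, intermediate_size = 8, 4096, 14336
num_experts, topk = 8, 2   # 8 tokens * top-2 experts = 16, a multiple of 16

hidden_states = torch.randn(num_tokens, hidden_size)
gating_output = torch.randn(num_tokens, num_experts)
w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)
w2 = torch.randn(num_experts, hidden_size, intermediate_size)

out = fused_moe(hidden_states, w1, w2, gating_output, topk, renormalize=True)
assert out.shape == (num_tokens, hidden_size)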
9 changes: 4 additions & 5 deletions vllm/worker/tpu_model_runner.py
@@ -598,11 +598,10 @@ def _get_padded_prefill_len(x: int) -> int:


def _get_padded_batch_size(batch_size: int) -> int:
-    if batch_size <= 2:
-        return batch_size
-    elif batch_size <= 4:
-        return 4
-    elif batch_size <= 8:
+    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
+    # To meet this requirement in the simplest way, we set the minimal batch
+    # size to 8.
+    if batch_size <= 8:
         return 8
     else:
         return ((batch_size + 15) // 16) * 16
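A few worked values for the new padding rule (illustrative assertions, not part of the PR). With the top-2 routing typical of Mixtral-style models, the minimum batch of 8 gives 8 * 2 = 16 scheduled tokens, which meets the GMM kernel's multiple-of-16 requirement:

assert _get_padded_batch_size(1) == 8
assert _get_padded_batch_size(8) == 8
assert _get_padded_batch_size(9) == 16    # ((9 + 15) // 16) * 16
assert _get_padded_batch_size(20) == 32   # ((20 + 15) // 16) * 16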