
Commit f75aa72

liangfu, gnovack, and AoyuQC authored
[Neuron] Add custom_ops for neuron backend (#13246)
Signed-off-by: Liangfu Chen <liangfc@amazon.com>
Co-authored-by: George Novack <gnovack@amazon.com>
Co-authored-by: Aoyu Zhang <aoyuzhan@amazon.com>
1 parent 340e39e commit f75aa72

File tree

9 files changed: +346 −3 lines changed


tests/neuron/test_activation.py

Lines changed: 42 additions & 0 deletions
```python
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.activation import FastGELU, SiluAndMul
from vllm.platforms import current_platform


@pytest.mark.parametrize("activation", ["silu_and_mul", "gelu_fast"])
@pytest.mark.parametrize("num_tokens,d,dtype", [
    (7, 512, torch.half),
    (7, 512, torch.float),
    (83, 512, torch.half),
])
@torch.inference_mode()
def test_act_and_mul(
    activation: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
) -> None:
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    current_platform.seed_everything(0)
    torch.set_default_device("cpu")
    x = torch.randn(num_tokens, 2 * d, dtype=dtype).to(device=device)
    if activation == "silu_and_mul":
        layer = SiluAndMul()
        fn = layer.forward_native
    elif activation == "gelu_fast":
        layer = FastGELU()
        fn = F.gelu
    else:
        raise NotImplementedError(
            f"activation {activation} is not implemented.")
    assert x.is_xla, "input tensor under testing is expected to be XLA tensor."
    out = layer.to(device=device).forward_neuron(x)
    ref_out = fn(x.cpu())
    torch.testing.assert_close(out.cpu(), ref_out, atol=0.01, rtol=0.0)
```

tests/neuron/test_layernorm.py

Lines changed: 56 additions & 0 deletions
```python
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform


@pytest.mark.parametrize("num_tokens,hidden_size,add_residual,dtype", [
    (7, 8, False, torch.half),
    (83, 768, False, torch.half),
    (83, 768, True, torch.half),
    (83, 768, True, torch.bfloat16),
    (83, 768, True, torch.float32),
])
@torch.inference_mode()
def test_rms_norm(
    num_tokens: int,
    hidden_size: int,
    add_residual: bool,
    dtype: torch.dtype,
) -> None:
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    current_platform.seed_everything(0)
    torch.set_default_device("cpu")
    layer = RMSNorm(hidden_size).to(dtype=dtype)
    layer.weight.data.normal_(mean=1.0, std=0.1)
    scale = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype).to(device=device)
    x *= scale
    residual = torch.randn_like(x) * scale if add_residual else None

    residual_cpu = residual.cpu() if add_residual else None
    ref_out = layer.to(device="cpu").forward_native(x.cpu(), residual_cpu)
    assert x.is_xla, "input tensor under testing is expected to be XLA tensor."
    out = layer.to(device=device)(x, residual)

    # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
    # numerical errors than other operators because they involve reductions.
    # Therefore, we use a larger tolerance.
    if add_residual:
        assert out[0].is_xla, "output tensor is expected to be XLA tensor"
        torch.testing.assert_close(out[0].cpu(),
                                   ref_out[0],
                                   atol=1e-2,
                                   rtol=1e-2)
        torch.testing.assert_close(out[1].cpu(),
                                   ref_out[1],
                                   atol=1e-2,
                                   rtol=1e-2)
    else:
        assert out.is_xla, "output tensor is expected to be XLA tensor"
        torch.testing.assert_close(out.cpu(), ref_out, atol=1e-2, rtol=1e-2)
```
tests/neuron/test_logits_processor.py

Lines changed: 95 additions & 0 deletions
```python
# SPDX-License-Identifier: Apache-2.0

import random
from typing import Tuple
from unittest.mock import patch

import pytest
import torch

from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available


class MockLogitsProcessor(LogitsProcessor):

    def __init__(self, vocab_size: int, scale: float,
                 fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size, scale=scale)
        self.fake_logits = fake_logits.clone()

    def forward(self, *args, **kwargs):
        with patch(
                "vllm.model_executor.layers.logits_processor._prune_hidden_states",
                lambda x, y: x
        ), patch(
                "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits",
                lambda *args, **kwargs: self.fake_logits):
            return super().forward(*args, **kwargs)


def _prepare_test(
        batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             dtype=input_tensor.dtype)
    logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
    return input_tensor, fake_logits, logits_processor


RANDOM_SEEDS = list(range(8))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_logits_processors(seed: int):
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    set_random_seed(seed)
    torch.set_default_device("cpu")
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, logits_processor = _prepare_test(batch_size)

    # This sample logits processor gives infinite score to the i-th token,
    # where i is the length of the input sequence.
    # We therefore expect the output token sequence to be [0, 1, 2, ...]
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    seq_group_metadata_list = []
    seq_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0,
                                               logits_processors=[pick_ith]),
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available())
    logits_processor_output = logits_processor(
        lm_head=None,
        hidden_states=input_tensor,
        sampling_metadata=sampling_metadata)

    fake_logits *= logits_processor.scale
    torch.testing.assert_close(logits_processor_output[:, 1],
                               fake_logits[:, 1],
                               rtol=1e-4,
                               atol=0.0)
```

tests/neuron/test_prefix_prefill.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -345,6 +345,7 @@ def test_contexted_kv_attention(
 
     torch.manual_seed(0)
     torch.set_printoptions(sci_mode=False)
+    torch.set_default_device("cpu")
     dtype = torch.float32
 
     min_ctx_len = 32
@@ -438,9 +439,9 @@ def pad_to_next_power_of_2(a):
 
     # transform block table
     active_block_table = get_active_block_tables(
-        block_table,
-        torch.tensor(query_lens),
-        torch.tensor(seq_lens),
+        block_table.cpu(),
+        torch.tensor(query_lens).cpu(),
+        torch.tensor(seq_lens).cpu(),
         block_size,
         num_active_blocks,
     )
```
tests/neuron/test_rotary_embedding.py

Lines changed: 58 additions & 0 deletions
```python
# SPDX-License-Identifier: Apache-2.0
"""
Tests for miscellaneous utilities
"""

import pytest
import torch

from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform


@pytest.mark.parametrize(
    "max_position,is_neox_style,rotary_dim,head_size,seq_len", [
        (16, False, 32, 32, 1024),
        (16, False, 32, 128, 1024),
        (16, True, 32, 32, 1024),
        (16, True, 32, 128, 1024),
    ])
def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
                                  head_size, seq_len):
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    current_platform.seed_everything(0)
    torch.set_default_device("cpu")

    batch_size = 1
    base = 10000
    num_heads = 8

    rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                          is_neox_style, torch.float32)

    positions = torch.randint(0,
                              max_position, (batch_size, seq_len),
                              device="cpu")
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=torch.float32,
                        device="cpu")
    key = torch.randn_like(query)

    assert positions.is_cpu, \
        "reference input tensor is expected to be CPU tensor."
    ref_query, ref_key = rot.to(device="cpu").forward_native(
        positions, query, key)
    out_query, out_key = rot.to(device=device).forward_neuron(
        positions.to(device=device), query.to(device=device),
        key.to(device=device))
    assert out_query.is_xla and out_key.is_xla, \
        "output tensor is expected to be XLA tensor"
    torch.testing.assert_close(out_query.cpu(),
                               ref_query,
                               atol=1e-2,
                               rtol=1e-2)
    torch.testing.assert_close(out_key.cpu(), ref_key, atol=1e-2, rtol=1e-2)
```

vllm/model_executor/custom_op.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -59,6 +59,11 @@ def forward_hpu(self, *args, **kwargs):
         # PyTorch-native implementation.
         return self.forward_native(*args, **kwargs)
 
+    def forward_neuron(self, *args, **kwargs):
+        # By default, we assume that Neuron ops are compatible with the
+        # PyTorch-native implementation.
+        return self.forward_native(*args, **kwargs)
+
     def forward_oot(self, *args, **kwargs):
         # By default, we assume that OOT ops are compatible with the
         # PyTorch-native implementation.
@@ -88,6 +93,8 @@ def dispatch_forward(self):
             return self.forward_tpu
         elif current_platform.is_xpu():
             return self.forward_xpu
+        elif current_platform.is_neuron():
+            return self.forward_neuron
         elif current_platform.is_out_of_tree():
             return self.forward_oot
         else:
```
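For orientation, here is a minimal sketch (not part of this commit; the op name `scaled_relu_demo` and class `ScaledReluDemo` are hypothetical) of how a `CustomOp` subclass interacts with the new hook: on a Neuron platform, `dispatch_forward()` now returns `forward_neuron`, and any subclass that does not override it inherits the base-class fallback to `forward_native`.

```python
# Sketch only: illustrates the dispatch path added in this commit using a
# hypothetical op. On Neuron, dispatch_forward() picks forward_neuron;
# elsewhere the usual platform-specific method is chosen.
import torch

from vllm.model_executor.custom_op import CustomOp


@CustomOp.register("scaled_relu_demo")  # hypothetical name, for illustration
class ScaledReluDemo(CustomOp):

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Reference PyTorch implementation, also the default fallback.
        return 2.0 * torch.relu(x)

    def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
        # A Neuron/XLA-friendly variant would go here; this sketch simply
        # reuses the native math, which is what the base class does when
        # forward_neuron is not overridden.
        return self.forward_native(x)
```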

vllm/model_executor/layers/activation.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -89,6 +89,13 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         self.op(out, x)
         return out
 
+    def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        x_reshaped = x.view(-1, x.shape[-1])
+        s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d])
+        result = s * x_reshaped[:, d:]
+        return result.view(*x.shape[:-1], d)
+
 
 @CustomOp.register("mul_and_silu")
 class MulAndSilu(CustomOp):
```
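Because the new `forward_neuron` path of `SiluAndMul` is built entirely from plain PyTorch ops (`view`, `sigmoid`, elementwise multiply), its math can be sanity-checked against `forward_native` without a Neuron device. A CPU-only sketch (not part of the commit, and separate from the XLA-based test above):

```python
# Sketch: silu(x) == x * sigmoid(x), so forward_neuron should agree with
# forward_native when run on ordinary CPU tensors.
import torch

from vllm.model_executor.layers.activation import SiluAndMul

layer = SiluAndMul()
x = torch.randn(7, 2 * 512, dtype=torch.float32)
out_neuron = layer.forward_neuron(x)   # splits last dim, silu(x1) * x2
out_native = layer.forward_native(x)   # reference implementation
torch.testing.assert_close(out_neuron, out_native, atol=1e-6, rtol=1e-6)
```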

vllm/model_executor/layers/logits_processor.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -53,6 +53,7 @@ def __init__(self,
         # Whether to use gather or all-gather to gather the logits.
         parallel_config = get_current_vllm_config().parallel_config
         self.use_all_gather = current_platform.is_tpu() \
+            or current_platform.is_neuron() \
             or envs.VLLM_USE_V1 \
             or parallel_config.distributed_executor_backend == "external_launcher"  # noqa
 
```
