Skip to content

Commit

Permalink
solve pre-commit issue
Browse files Browse the repository at this point in the history
  • Loading branch information
AoyuQC committed Jan 28, 2025
1 parent 900efe4 commit 14367f0
Show file tree
Hide file tree
Showing 11 changed files with 423 additions and 530 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ target/

# Jupyter Notebook
.ipynb_checkpoints
.ipynb

# IPython
profile_default/
Expand Down
445 changes: 389 additions & 56 deletions examples/neuron_v1.py

Large diffs are not rendered by default.

38 changes: 13 additions & 25 deletions examples/offline_model_neuron.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
import os
import tempfile

from vllm import LLM, SamplingParams
from vllm import SamplingParams
from vllm.attention.backends.neuron_attn import NeuronAttentionBackend
from vllm.config import VllmConfig
from vllm.distributed.communication_op import tensor_model_parallel_all_gather
from vllm.distributed.parallel_state import ensure_model_parallel_initialized, init_distributed_environment
# from vllm.config import VllmConfig
# from vllm.distributed.communication_op import tensor_model_parallel_all_gather
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
init_distributed_environment
)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.logits_processor import _prune_hidden_states
# from vllm.model_executor.layers.logits_processor import _prune_hidden_states
from vllm.model_executor.model_loader import get_model

import torch
import torch_neuronx
import torch.nn as nn
# import torch_neuronx
# import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.neuron.compiler import neuron_argmax
# from vllm.neuron.compiler import neuron_argmax

# creates XLA hlo graphs for all the context length buckets.
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
Expand Down Expand Up @@ -68,7 +71,7 @@
)

attn_backend = NeuronAttentionBackend
vllm_config=config.create_engine_config()
vllm_config = config.create_engine_config()
device = xm.xla_device()
model = get_model(vllm_config=vllm_config)
model = model.eval().to(device)
Expand All @@ -86,7 +89,6 @@ def forward(
inputs_embeds,
sampling_metadata
):
# hidden_states, (attn_input, q, k, v, attn_out, mlp_output, mlp_input) = model(
hidden_states = model(
input_ids,
positions,
Expand All @@ -97,13 +99,6 @@ def forward(
)

return hidden_states
# hidden_states = hidden_states.flatten(0, 1)
# logits = model.compute_logits(hidden_states, sampling_metadata)[-1, :100]
# argmax_token_ids = neuron_argmax(logits, dim=-1, keepdim=True)
# argmax_token_ids = argmax_token_ids.repeat(1, 1)
# return argmax_token_i
return logits


compiled_model = torch.compile(forward,
backend="openxla",
Expand Down Expand Up @@ -161,11 +156,4 @@ def forward(
inputs_embeds=None,
sampling_metadata=sampling_metadata
)
print(output)
# print("Q:", q, q.shape)
# # print("W_Q:", w_q, w_q.shape)
# print("Attn input:", attn_input, attn_input.shape)
# print("K:", k, k.shape)
# print("attn_out:", attn_out, attn_out.shape)
# print("mlp_input:", mlp_input, mlp_input.shape)
# print("mlp_output:", mlp_output, mlp_output.shape)
print(output)
Loading

0 comments on commit 14367f0

Please sign in to comment.