Commit

Merge pull request vllm-project#2 from slyalin/optimum_models_after_reorg

Re-enable optimum-intel path
ilya-lavrenov authored Mar 18, 2024
2 parents a920809 + b98f5ba commit e913d6b
Showing 3 changed files with 30 additions and 7 deletions.
15 changes: 15 additions & 0 deletions vllm/executor/openvino_executor.py
@@ -419,6 +419,21 @@ def _init_distributed_environment(self) -> None:
        ensure_model_parallel_initialized(self.parallel_config.tensor_parallel_size,
                                          self.parallel_config.pipeline_parallel_size)

    def __del__(self):
        # TODO: This cleanup would be better placed in a wrapper around the
        # optimum-based model inside the OpenVINO model loader, but that
        # requires more code because the wrapper would have to be a fully
        # functional substitute for torch.nn.Module. Keeping the code here
        # is not robust: self.model_runner is not our class instance, and
        # it may be modified so that the model is no longer kept in the
        # self.model_runner.model attribute.
        if not (hasattr(self.model_runner, 'model') and hasattr(self.model_runner.model, 'model')):
            return
        pt_model = self.model_runner.model
        if hasattr(pt_model, 'ov_node_factory'):
            # Drop the infer request and the patched ov.Model first, then
            # collect garbage, so nothing still references ops created by
            # the node factory when it is released.
            del pt_model._ov_request
            del pt_model.model
            if gc:  # the gc module may already be unloaded at interpreter shutdown
                gc.collect()
            del pt_model.ov_node_factory


class OpenVINOExecutor(ExecutorBase):

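Note: the TODO in the __del__ hunk above proposes moving this cleanup into a wrapper around the optimum-based model inside the OpenVINO model loader. A minimal sketch of that idea, assuming a hypothetical OVModelWrapper class that is not part of this commit and whose names are illustrative only:

import gc

import torch


class OVModelWrapper(torch.nn.Module):
    """Hypothetical wrapper owning the OpenVINO artifacts of an
    optimum-intel model, so the teardown lives next to the objects
    it manages instead of in the executor's __del__."""

    def __init__(self, pt_model):
        super().__init__()
        self.pt_model = pt_model  # optimum-intel OVModelForCausalLM

    def forward(self, *args, **kwargs):
        return self.pt_model(*args, **kwargs)

    def __del__(self):
        # Same teardown order as the executor's __del__ above: drop the
        # infer request and the patched ov.Model, collect garbage, then
        # release the node factory that created the custom ops.
        if not hasattr(self.pt_model, 'ov_node_factory'):
            return
        del self.pt_model._ov_request
        del self.pt_model.model
        if gc:  # module may already be unloaded at interpreter shutdown
            gc.collect()
        del self.pt_model.ov_node_factory

As the TODO notes, a full substitute for torch.nn.Module would need more than forward delegation (parameters, state_dict handling, and so on), which is why the commit keeps the cleanup in the executor for now.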
1 change: 1 addition & 0 deletions vllm/model_executor/layers/sampler.py
@@ -64,6 +64,7 @@ def forward(
        if self.logits_as_hidden_states:
            logits = hidden_states
            if is_openvino_optimum_intel():
                # TODO: Fuse this step into the model inference
                logits = _prune_hidden_states(logits, sampling_metadata)
        else:
            hidden_states = _prune_hidden_states(hidden_states,
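Note: vLLM's _prune_hidden_states gathers only the rows that sampling actually needs; the optimum-intel branch applies the same gather to the logits the model already returns, and the TODO proposes fusing the gather into the model graph itself. A simplified standalone sketch of the gather (not the exact vLLM implementation):

import torch


def prune_hidden_states(hidden_states: torch.Tensor,
                        selected_token_indices: torch.Tensor) -> torch.Tensor:
    # Keep only the rows at positions where a token will be sampled,
    # e.g. the last position of each sequence.
    return hidden_states.index_select(0, selected_token_indices)


# Example: 6 token positions, sampling at positions 2 and 5.
hidden = torch.randn(6, 8)
pruned = prune_hidden_states(hidden, torch.tensor([2, 5]))
assert pruned.shape == (2, 8)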
21 changes: 14 additions & 7 deletions vllm/model_executor/openvino_model_loader.py
@@ -1,6 +1,6 @@
"""Utilities for selecting and loading models."""
from functools import partial
from typing import Optional

import math
import torch
import numpy as np
@@ -11,6 +11,8 @@
from vllm.sequence import SamplerOutput
from vllm.utils import is_openvino_optimum_intel

import openvino as ov


def _flattenize_inputs(inputs):
"""
@@ -53,7 +55,7 @@ def ov_wrapper(self, *args, **kwargs) -> torch.Tensor:


def patch_stateful_model(
        model: torch.nn.Module,
        model: ov.Model,
        factory):
    print('TRANSFORMING OPTIMUM-INTEL MODEL TO vLLM COMPATIBLE FORM')
    from openvino.runtime.passes import Manager, MatcherPass, WrapType, Matcher, AnyInput, Or
@@ -194,7 +196,14 @@ def __init__(self):
        seq = WrapType("opset13.Gather", [kv_shape, AnyInput(), AnyInput()])

        def callback(m: Matcher) -> bool:
            replace_node(m.get_match_root(), max_context_len)
            gather = m.get_match_root()
            target_type = gather.get_output_element_type(0)
            if max_context_len.get_output_element_type(0) != target_type:
                print(f'Converting {max_context_len.get_output_element_type(0)} of max_context_len to {target_type}')
                replacement = opset13.convert(max_context_len, target_type)
            else:
                replacement = max_context_len
            replace_node(gather, replacement)
            print("DETECTED PATTERN FOR max_sequence_length, CONNECTED TO A DEDICATED PARAMETER")
            return True

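Note: the callback above replaces the earlier unconditional replace_node(m.get_match_root(), max_context_len) call with a type-aware version that inserts a Convert when the Gather's element type differs from the parameter's. For readers unfamiliar with OpenVINO matcher passes, here is a minimal self-contained sketch of the same mechanism; the pattern is simplified (the commit's real pattern constrains the Gather to read from the KV-cache shape), and the import paths assume the OpenVINO 2024.x Python API:

from openvino.runtime import opset13
from openvino.runtime.passes import Manager, MatcherPass, Matcher, WrapType, AnyInput
from openvino.runtime.utils import replace_node


class ReplaceGatherWithParam(MatcherPass):
    """Illustrative pass: rewire any matched Gather to a dedicated
    parameter, converting the parameter's type when needed."""

    def __init__(self, max_context_len):
        super().__init__()
        # Simplified pattern; the diff above matches a Gather over the
        # KV-cache shape rather than arbitrary inputs.
        pattern = WrapType("opset13.Gather", [AnyInput(), AnyInput(), AnyInput()])

        def callback(m: Matcher) -> bool:
            gather = m.get_match_root()
            target_type = gather.get_output_element_type(0)
            if max_context_len.get_output_element_type(0) != target_type:
                replacement = opset13.convert(max_context_len, target_type)
            else:
                replacement = max_context_len
            replace_node(gather, replacement)
            return True

        self.register_matcher(Matcher(pattern, "ReplaceGatherWithParam"), callback)


# Usage sketch on an ov.Model, given a scalar max_context_len parameter:
#     manager = Manager()
#     manager.register_pass(ReplaceGatherWithParam(max_context_len_param))
#     manager.run_passes(model)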
@@ -270,7 +279,6 @@ def _patch_model_with_openvino(
    from vllm.model_executor.layers.attention.attention import Attention
    from openvino.frontend.pytorch import ModuleExtension
    from openvino import Core, convert_model, Type, PartialShape
    from functools import partial

    # Avoid usage of vllm._C.ops

@@ -426,7 +434,7 @@ def get_model(model_config: ModelConfig,

    pt_model = None

    if is_openvino_optimum_intel() and False:
    if is_openvino_optimum_intel():
        import openvino as ov
        from optimum.intel import OVModelForCausalLM
        pt_model = OVModelForCausalLM.from_pretrained(model_config.model, export=True, compile=False, load_in_8bit=False, trust_remote_code=True)  # need stateful because it also enables SDPA
@@ -438,9 +446,8 @@
        patch_stateful_model(pt_model.model, pt_model.ov_node_factory)
        core = ov.Core()
        ov_compiled = core.compile_model(pt_model.model, "CPU")
        pt_model.ov_request = ov_compiled.create_infer_request()
        pt_model._ov_request = ov_compiled.create_infer_request()

        from functools import partial
        pt_model._openvino_patch_orig_forward = pt_model.forward
        pt_model.forward = partial(ov_wrapper, pt_model)

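Note: pieced together, the re-enabled optimum-intel path in get_model follows a load/patch/compile/wrap sequence. A condensed sketch of that flow using only the calls visible in this diff, assuming patch_stateful_model and ov_wrapper from this module; load_optimum_model is an illustrative name, not a function in this file:

from functools import partial

import openvino as ov
from optimum.intel import OVModelForCausalLM


def load_optimum_model(model_id: str):
    # Export through optimum-intel; the stateful export is wanted because
    # it also enables SDPA (see the comment in the diff above).
    pt_model = OVModelForCausalLM.from_pretrained(
        model_id, export=True, compile=False,
        load_in_8bit=False, trust_remote_code=True)

    # Rewrite the stateful ov.Model into the vLLM-compatible form.
    # pt_model.ov_node_factory is set up in code elided from this diff.
    patch_stateful_model(pt_model.model, pt_model.ov_node_factory)

    # Compile and keep a dedicated infer request; this commit renames the
    # attribute from ov_request to _ov_request.
    ov_compiled = ov.Core().compile_model(pt_model.model, "CPU")
    pt_model._ov_request = ov_compiled.create_infer_request()

    # Route forward() through OpenVINO, keeping the original reachable.
    pt_model._openvino_patch_orig_forward = pt_model.forward
    pt_model.forward = partial(ov_wrapper, pt_model)
    return pt_model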
