Mod vllm generate #833

Merged
merged 12 commits on Dec 13, 2024
4 changes: 2 additions & 2 deletions gptqmodel/integration/src/vllm/gptq_marlin.py
@@ -72,7 +72,7 @@ def __init__(
def update_config(self, prefix: str):
bits = self.weight_bits
# check for variable/dynamic config
if len(self.dynamic) > 0 and prefix:
if self.dynamic and len(self.dynamic) > 0 and prefix:
bits = self.dynamic_get(prefix, "bits", bits)
self.group_size = self.dynamic_get(prefix, "group_size", self.group_size)
self.desc_act = self.dynamic_get(prefix, "actorder", self.desc_act)
@@ -156,7 +156,7 @@ def dynamic_get(self, layer_name: str, key: str = None, default_value: Union[int
def get_quant_method(self, layer: torch.nn.Module,
prefix: str
) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod", UnquantizedLinearMethod]]:
if self.dynamic_get(layer_name=prefix) == False: # noqa: E712
if self.dynamic and self.dynamic_get(layer_name=prefix) == False: # noqa: E712
return UnquantizedLinearMethod()

if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
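Both hunks in gptq_marlin.py guard `self.dynamic` before calling `len()` or `dynamic_get()`, since that attribute can be `None` when no per-layer (dynamic) quantization overrides are configured, and `len(None)` would raise a `TypeError`. A minimal standalone sketch of the pattern, not part of this diff (the plain-dict `dynamic` mapping is a simplification of the real config):

    from typing import Dict, Optional

    def resolve_bits(dynamic: Optional[Dict[str, Dict[str, int]]], prefix: str, default_bits: int) -> int:
        # Truthiness check first: `dynamic` may be None, and len(None) would raise TypeError.
        if dynamic and len(dynamic) > 0 and prefix:
            return dynamic.get(prefix, {}).get("bits", default_bits)
        return default_bits

    print(resolve_bits(None, "model.layers.0.self_attn.q_proj", 4))  # -> 4, no crash on None
    print(resolve_bits({"model.layers.0.self_attn.q_proj": {"bits": 8}},
                       "model.layers.0.self_attn.q_proj", 4))        # -> 8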
12 changes: 10 additions & 2 deletions gptqmodel/models/base.py
@@ -703,13 +703,21 @@ def tmp(_, inp, out):
@property
def device(self):
if not self.hf_device_map:
return self.model.device
if hasattr(self.model, "device"):
return self.model.device
elif hasattr(self.model, "llm_engine"):
return self.model.llm_engine.device_config.device_type
else:
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = [d for d in self.hf_device_map.values() if d not in {"disk"}][0]
return torch.device(device)

def to(self, device: Union[str, torch.device]):
self.model.to(device)
if hasattr(self.model, "to"):
self.model.to(device)
else:
logger.warning(f"{self.model.__class__.__name__} does not support the to() method")
return self

def forward(self, *args, **kwargs):
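The new `device` property and `to()` in base.py tolerate wrapped backends that are not plain `nn.Module`s, e.g. a vLLM `LLM` object that exposes `llm_engine` but neither `.device` nor `.to()`. A rough standalone sketch of the same fallback chain, using a hypothetical stand-in object rather than a real vLLM engine:

    import torch
    from types import SimpleNamespace

    # Hypothetical stand-in for a vLLM LLM wrapper: has .llm_engine, but no .device and no .to()
    fake_vllm_model = SimpleNamespace(
        llm_engine=SimpleNamespace(device_config=SimpleNamespace(device_type="cuda"))
    )

    def resolve_device(model) -> torch.device:
        # Same fallback order as the patched property: .device, then the vLLM engine's
        # device config, then a cuda/cpu default.
        if hasattr(model, "device"):
            return model.device
        if hasattr(model, "llm_engine"):
            return torch.device(model.llm_engine.device_config.device_type)
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(resolve_device(fake_vllm_model))  # -> device(type='cuda')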
100 changes: 71 additions & 29 deletions gptqmodel/utils/vllm.py
@@ -4,33 +4,49 @@

try:
from vllm import LLM, SamplingParams

VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False

VLLM_INSTALL_HINT = "vLLM not installed. Please install via `pip install -U vllm`."


# returns SamplingParams, but we can't use this type hint since vLLM is an optional dependency
def convert_hf_params_to_vllm(hf_params: Dict[str, Any]):
if not VLLM_AVAILABLE:
raise ValueError(VLLM_INSTALL_HINT)
sampling_params = SamplingParams()

if hf_params.get('num_return_sequences', None):
sampling_params.n = hf_params.get('num_return_sequences')

if hf_params.get('repetition_penalty', None):
sampling_params.repetition_penalty = hf_params.get('repetition_penalty')

if hf_params.get('temperature', None):
sampling_params.temperature = hf_params.get('temperature')

if hf_params.get('top_k', None):
sampling_params.top_k = hf_params.get('top_k')

if hf_params.get('top_p', None):
sampling_params.top_p = hf_params.get('top_p')

if hf_params.get('max_length', None):
sampling_params.max_tokens = hf_params.get('max_length')

if hf_params.get('min_length', None):
sampling_params.min_tokens = hf_params.get('min_length')

if hf_params.get('eos_token_id', None):
sampling_params.stop_token_ids = [hf_params.get('eos_token_id'), None]

return sampling_params

params = {
'n': hf_params.get('num_return_sequences', 1),
'repetition_penalty': hf_params.get('repetition_penalty', 1.0),
'temperature': hf_params.get('temperature', 1.0),
'top_k': hf_params.get('top_k', -1),
'top_p': hf_params.get('top_p', 1.0),
'max_tokens': hf_params.get('max_length', 16),
'min_tokens': hf_params.get('min_length', 0),
'stop_token_ids': [hf_params.get('eos_token_id'), None],
}
return SamplingParams(**params)

def load_model_by_vllm(
model,
**kwargs,
model,
**kwargs,
):
if not VLLM_AVAILABLE:
raise ValueError(VLLM_INSTALL_HINT)
@@ -42,23 +58,49 @@ def load_model_by_vllm(

return model

@torch.inference_mode
def vllm_generate(
model,
**kwargs,
):

@torch.inference_mode()
def vllm_generate(model, **kwargs):
if not VLLM_AVAILABLE:
raise ValueError(VLLM_INSTALL_HINT)

prompts = kwargs.pop("prompts", None)
sampling_params = kwargs.pop("sampling_params", None)
# Extract and validate prompts
prompts = kwargs.pop("prompts", None) or kwargs.pop("input_ids", None)
if prompts is None:
raise ValueError("Either prompts or input_ids must be provided")

sampling_params = kwargs.get("sampling_params")
if not isinstance(sampling_params, SamplingParams):
hf_params = {key: kwargs[key] for key in [
'num_return_sequences', 'repetition_penalty', 'temperature',
'top_k', 'top_p', 'max_length', 'min_length', 'eos_token_id'
] if key in kwargs}
sampling_params = convert_hf_params_to_vllm(hf_params)

outputs = model.generate(prompts, sampling_params)
return outputs
hf_params = {
key: kwargs.get(key) for key in [
'num_return_sequences', 'repetition_penalty', 'temperature',
'top_k', 'top_p', 'max_length', 'min_length', 'eos_token_id'
]
}
sampling_params = convert_hf_params_to_vllm({k: v for k, v in hf_params.items() if v is not None})

# Convert prompts to vLLM format
if isinstance(prompts, torch.Tensor):
req_results = model.generate(prompt_token_ids=prompts.tolist(), sampling_params=sampling_params)
elif isinstance(prompts, list):
if isinstance(prompts[0], list) or isinstance(prompts[0], int):
req_results = model.generate(prompt_token_ids=prompts, sampling_params=sampling_params)
else:
req_results = model.generate(prompts=prompts, sampling_params=sampling_params)
elif isinstance(prompts, str):
req_results = model.generate(prompts=prompts, sampling_params=sampling_params)
else:
raise ValueError(f"Invalid input type for vllm_generate, type is {type(prompts)}")

outputs = []
for result in req_results:
combined_token_ids = result.prompt_token_ids + list(result.outputs[0].token_ids)
outputs.append(combined_token_ids)

pad_token_id = model.get_tokenizer().pad_token_id
if pad_token_id is None:
pad_token_id = model.get_tokenizer().eos_token_id
max_length = max(len(sublist) for sublist in outputs)
padded_list = [sublist + [pad_token_id] * (max_length - len(sublist)) for sublist in outputs]

return torch.Tensor(padded_list).to(torch.uint32)
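Taken together, `vllm_generate` now accepts either text prompts or token ids plus HF-style generation kwargs, converts the kwargs to `SamplingParams` via `convert_hf_params_to_vllm`, and returns a padded tensor of prompt-plus-completion token ids instead of vLLM `RequestOutput` objects. A rough usage sketch, assuming vLLM is installed and `model` has already been loaded with `backend=BACKEND.VLLM` (the loading call is omitted; kwarg forwarding by the wrapper's `generate()` is as exercised in tests/test_vllm.py below):

    import torch

    tokenizer = model.get_tokenizer()
    prompt = "The capital of France is"

    # Text prompt + HF-style kwargs, converted internally to SamplingParams
    out_ids = model.generate(prompts=[prompt], temperature=0.8, top_p=0.95, max_length=16)

    # The result is a padded tensor of prompt + completion token ids, so decode it
    print(tokenizer.decode(out_ids[0], skip_special_tokens=True))

    # Token ids are also accepted; a tensor is routed through vLLM's prompt_token_ids path
    input_ids = torch.tensor([tokenizer.encode(prompt)])
    out_ids = model.generate(prompts=input_ids, temperature=0.8, top_p=0.95, max_length=16)
    print(tokenizer.decode(out_ids[0], skip_special_tokens=True))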
55 changes: 30 additions & 25 deletions tests/test_vllm.py
@@ -25,7 +25,8 @@ class TestLoadVLLM(unittest.TestCase):
@classmethod
def setUpClass(self):
if importlib.util.find_spec("flashinfer") is None:
subprocess.check_call([sys.executable, "-m", "pip", "install", "flashinfer", "-i", f"https://flashinfer.ai/whl/cu{torch.version.cuda.replace('.', '')}/torch{'.'.join(torch.__version__.split('.')[:2])}"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "flashinfer", "-i",
f"https://flashinfer.ai/whl/cu{torch.version.cuda.replace('.', '')}/torch{'.'.join(torch.__version__.split('.')[:2])}"])

if importlib.util.find_spec("vllm") is None:
subprocess.check_call([sys.executable, "-m", "pip", "install", "vllm>=0.6.2"])
@@ -36,7 +37,7 @@ def setUpClass(self):
self.prompts = [
"The capital of France is",
]
self.sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
self.sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16, top_k=1)

def release_vllm_model(self):
from vllm.distributed.parallel_state import destroy_model_parallel # noqa: E402
@@ -52,25 +53,29 @@ def test_load_vllm(self):
backend=BACKEND.VLLM,
gpu_memory_utilization=0.2,
)

tokenizer = model.get_tokenizer()

outputs = model.generate(
prompts=self.prompts,
sampling_params=self.sampling_params,
)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
self.assertEquals(generated_text, " Paris, which is also the capital of France.")
outputs_param = model.generate(

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)[len(self.prompts[0]):]
print(f"Prompt: {self.prompts!r}, Generated text: {generated_text!r}")
self.assertEquals(generated_text, " Paris.\n\n2. The capital of the United States is Washington, D")

outputs = model.generate(
prompts=self.prompts,
temperature=0.8,
top_p=0.95,
max_length=16,
top_k=1,
)
for output in outputs_param:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
self.assertEquals(generated_text, " ___________.\n6. City Name: Paris, France\n7. C")

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)[len(self.prompts[0]):]
print(f"Prompt: {self.prompts!r}, Generated text: {generated_text!r}")
self.assertEquals(generated_text, " Paris.\n\n2. The capital of the United States is Washington, D")

del model
self.release_vllm_model()
Expand All @@ -82,17 +87,16 @@ def test_load_shared_vllm(self):
backend=BACKEND.VLLM,
gpu_memory_utilization=0.2,
)
tokenizer = model.get_tokenizer()
outputs = model.generate(
prompts=self.prompts,
temperature=0.8,
top_p=0.95,
)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
self.assertEquals(generated_text,
" Paris, which is also known as the city of love.")

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)[len(self.prompts[0]):]
print(f"Prompt: {self.prompts!r}, Generated text: {generated_text!r}")
self.assertEquals(generated_text, " Paris, which is also known as the city of love.")

del model
self.release_vllm_model()
@@ -140,6 +144,8 @@ def test_dynamic(self):
gpu_memory_utilization=0.2,
)

tokenizer = model.get_tokenizer()

for name, submodule in model.named_modules():
if name == 'model.model.layers.0.self_attn.q_proj' and isinstance(submodule,
BaseQuantLinear): # module 0 was skipped
@@ -150,12 +156,11 @@
temperature=0.8,
top_p=0.95,
)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
self.assertEquals(generated_text,
" Paris, which is also the country's largest city.")

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)[len(self.prompts[0]):]
print(f"Prompt: {self.prompts!r}, Generated text: {generated_text!r}")
self.assertEquals(generated_text,
" Paris, which is also the country's largest city.")

del model
self.release_vllm_model()