fix _update_model_kwargs_for_generation #29560

Closed
wants to merge 1 commit

Conversation

Jintao-Huang
Contributor

What does this PR do?

The default value of model_inputs is None, and a later line reads model_kwargs["cache_position"] = model_inputs.get("cache_position", None), which raises an AttributeError when model_inputs is None. So, if model_inputs is None, it should be replaced with {}.
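A minimal sketch of the proposed guard (the signature is abridged and the surrounding body is paraphrased; the exact change is in the commit):

# Sketch of _update_model_kwargs_for_generation with the guard (abridged, not the full method)
def _update_model_kwargs_for_generation(self, outputs, model_kwargs,
                                        is_encoder_decoder=False,
                                        standardize_cache_format=False,
                                        model_inputs=None):
    # Guard against the default None so the .get() call below cannot raise AttributeError
    model_inputs = model_inputs if model_inputs is not None else {}
    ...
    model_kwargs["cache_position"] = model_inputs.get("cache_position", None)  # safe even when model_inputs was None
    ...
    return model_kwargs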

When running the following code against the main branch, an error occurs; it runs successfully with transformers==4.38.*.

# pip install ms-swift -U
import os

from swift.llm import (
    get_model_tokenizer, get_template, inference, ModelType,
    get_default_template_type, inference_stream
)
from swift.utils import seed_everything
import torch

model_type = ModelType.qwen1half_0_5b_chat
template_type = get_default_template_type(model_type)
print(f'template_type: {template_type}')  # template_type: qwen

kwargs = {}
# kwargs['use_flash_attn'] = True  # use flash_attn

model, tokenizer = get_model_tokenizer(model_type, torch.float16,
                                       model_kwargs={'device_map': 'auto'}, **kwargs)
template = get_template(template_type, tokenizer)
seed_everything(42)

model.generation_config.max_new_tokens = 128

query = 'hello'
gen = inference_stream(model, template, query)
print_idx = 0
print(f'query: {query}\nresponse: ', end='')
for response, history in gen:
    delta = response[print_idx:]
    print(delta, end='', flush=True)
    print_idx = len(response)
print()
print(f'history: {history}')

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@gante

@gante
Member

gante commented Mar 13, 2024

Hi @Jintao-Huang 👋

This change is not needed after #29467 gets merged in :) As such, I'm not approving/merging this PR, so as to not generate conflicts.
