support do_sample parameter #2375

Merged
merged 8 commits into InternLM:main on Sep 2, 2024
Conversation

irexyc (Collaborator) commented Aug 26, 2024

Motivation

  • merge EngineGenerationConfig / GenerationConfig
  • align gen_config logic with transformers
  • prevent modification of gen_config.random_seed
# transformers batch generate
from transformers import LlamaTokenizer, AutoModelForCausalLM
import torch

tokenizer = LlamaTokenizer.from_pretrained("/mnt/140/llama2/huggingface/llama-2-7b-chat/")
model = AutoModelForCausalLM.from_pretrained("/mnt/140/llama2/huggingface/llama-2-7b-chat/", torch_dtype=torch.float16, device_map="auto")

tokenizer.padding_side = "left"

# Define PAD Token = EOS Token
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

# use two identical prompts so greedy vs. sampled outputs can be compared
sentences = [
    "Hello, my dog is a little",
    "Hello, my dog is a little",
]

inputs = tokenizer(sentences, return_tensors="pt", padding=True).to(model.device)

output_sequences = model.generate(**inputs, do_sample=False, max_new_tokens=20)
out = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
print(f'{out[0]}\n{out[1]}\n')

output_sequences = model.generate(**inputs, do_sample=True, max_new_tokens=20)
out = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
print(f'{out[0]}\n{out[1]}\n')
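
For comparison, here is a rough lmdeploy-side counterpart of the transformers demo above. This is only a sketch: it assumes the pipeline / GenerationConfig interface exposed after this PR and reuses the same local checkpoint path; it is not code taken from the PR itself.

# sketch: do_sample=False should give identical, deterministic completions
# for the two identical prompts; do_sample=True should sample and usually differ
from lmdeploy import pipeline, GenerationConfig

pipe = pipeline("/mnt/140/llama2/huggingface/llama-2-7b-chat/")

prompts = ["Hello, my dog is a little", "Hello, my dog is a little"]

greedy = pipe(prompts, gen_config=GenerationConfig(do_sample=False, max_new_tokens=20))
sampled = pipe(prompts, gen_config=GenerationConfig(do_sample=True, max_new_tokens=20))

print(f'{greedy[0].text}\n{greedy[1].text}\n')
print(f'{sampled[0].text}\n{sampled[1].text}\n')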

lvhan028 (Collaborator) commented

Please summarize the demo results as a conclusion.

lvhan028 (Collaborator) commented

api_server.py should also be updated. The snippet below shows how api_server currently instantiates GenerationConfig. Now that do_sample has been added with a default value of False, api_server will always request greedy search.

    gen_config = GenerationConfig(
        max_new_tokens=request.max_tokens if request.max_tokens else 512,
        logprobs=request.logprobs,
        top_k=request.top_k,
        top_p=request.top_p,
        temperature=request.temperature,
        repetition_penalty=request.repetition_penalty,
        ignore_eos=request.ignore_eos,
        stop_words=request.stop,
        skip_special_tokens=request.skip_special_tokens,
        random_seed=random_seed)

@@ -360,6 +360,7 @@ async def chat_completions_v1(request: ChatCompletionRequest,

    gen_config = GenerationConfig(
        max_new_tokens=request.max_tokens,
        do_sample=True,
Collaborator

shall we make it a request option?

Collaborator

I think it's unnecessary, since users can still set top_k=1 or temperature=0.0 to invoke greedy search.
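
For illustration, a minimal client-side sketch (not from the PR): it assumes an lmdeploy api_server running locally on the default port with an OpenAI-compatible endpoint and a hypothetical served model name. Even with do_sample=True hard-coded on the server side, pinning the sampling parameters in the request falls back to greedy decoding.

# assumes: lmdeploy api_server at http://localhost:23333; the model name is hypothetical
from openai import OpenAI

client = OpenAI(base_url='http://localhost:23333/v1', api_key='none')

resp = client.chat.completions.create(
    model='llama-2-7b-chat',
    messages=[{'role': 'user', 'content': 'Hello, my dog is a little'}],
    temperature=0.0,               # or extra_body={'top_k': 1}
    max_tokens=20,
)
print(resp.choices[0].message.content)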

bad_words = special_word_token_ids(self.bad_words) or []
stop_words.extend(self.stop_words_ids or [])
bad_words.extend(self.bad_words_ids or [])
self.stop_words_ids = list(set(stop_words)) or None
Collaborator

stop_words could be a list of lists, which cannot be passed to set:

a = [[123], [456]]
b = set(a) # TypeError: unhashable type: 'list'

Collaborator Author

Why could stop_words be a list of lists?

Collaborator

I tested the following case:

gen_config = GenerationConfig(
    stop_words=chat_template.stop_words,
    stop_token_ids=[[92542], [92540]]  # list of list
)
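
A possible workaround (a sketch, not the PR's implementation) is to convert the inner lists to hashable tuples before deduplicating, so lists of token ids are still accepted:

# hypothetical helper, not part of the PR
def dedupe_word_ids(word_ids):
    # tuples are hashable, so dict.fromkeys can dedupe while keeping order;
    # restore the original list-of-lists shape afterwards
    return [list(t) for t in dict.fromkeys(tuple(w) for w in word_ids)]

print(dedupe_word_ids([[92542], [92540], [92542]]))  # [[92542], [92540]]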

stop_words=special_word_token_ids(gen_config.stop_words),
bad_words=special_word_token_ids(gen_config.bad_words),
logits_processors=gen_config.logits_processors)
stop_words = special_word_token_ids(self.stop_words) or []
Collaborator

We may assert that self.stop_words_ids and self.bad_words_ids are None for now.
Otherwise, it could bring side effects if users pass in unexpected token ids.
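
A minimal sketch of such a guard, using a hypothetical trimmed-down stand-in for GenerationConfig (not the PR's code), assuming the check runs during post-init validation:

from dataclasses import dataclass
from typing import List, Optional

# hypothetical stand-in for GenerationConfig, only to show where the guard would sit
@dataclass
class GenerationConfig:
    stop_words_ids: Optional[List[List[int]]] = None
    bad_words_ids: Optional[List[List[int]]] = None

    def __post_init__(self):
        # reject user-supplied token-id lists until they are officially supported
        assert self.stop_words_ids is None, 'stop_words_ids is not supported yet'
        assert self.bad_words_ids is None, 'bad_words_ids is not supported yet'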

lvhan028 merged commit 7519a35 into InternLM:main on Sep 2, 2024 (4 of 5 checks passed)