Commit 577b7a7
support reduce_dialogue_repetition
ming1753 committed Jan 6, 2025
1 parent 608d4be commit 577b7a7
Showing 2 changed files with 13 additions and 1 deletion.
llm/server/server/data/processor.py (2 changes: 1 addition & 1 deletion)
@@ -282,7 +282,7 @@ def _load_tokenizer(self):
         """
         if self.config.use_hf_tokenizer:
             from transformers import AutoTokenizer
-            return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=False, vocab_file=os.path.join(self.config.model_dir, "sentencepiece.bpe.model"))
+            return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=False)
         else:
             from paddlenlp.transformers import AutoTokenizer
             return AutoTokenizer.from_pretrained(self.config.model_dir)
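With the hard-coded vocab_file argument removed, the Hugging Face tokenizer now resolves all of its files from self.config.model_dir on its own instead of being pointed at a specific sentencepiece.bpe.model. A minimal sketch of the resulting load path (the directory is a placeholder):

    # Sketch only: load a slow (non-fast) Hugging Face tokenizer from a local
    # model directory; tokenizer files are discovered from the directory itself.
    from transformers import AutoTokenizer

    model_dir = "/path/to/model_dir"  # placeholder, stands in for config.model_dir
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)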
llm/server/server/engine/infer.py (12 changes: 12 additions & 0 deletions)
@@ -52,6 +52,8 @@ def __init__(self, args):
         self.args.num_attention_heads = self.get_value(self.model_cfg, ["num_attention_heads", "n_head"])
         self.args.hidden_size = self.model_cfg["hidden_size"]
 
+        self.reduce_dialogue_repetition = int(os.environ.get("REDUCE_DIALOGUE_REPETITION", 0))
+
         self.nranks = dist.get_world_size()
         self.init_dist_env()
         self.rank = fleet.worker_index()
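Usage note (a sketch, not part of the commit): since the flag is read from the process environment when the engine object is constructed, enabling the behaviour presumably comes down to setting the variable before the inference server starts, for example:

    # Sketch: turn the feature on for the current process before the engine
    # is constructed; __init__ will then read the variable and get 1.
    import os

    os.environ["REDUCE_DIALOGUE_REPETITION"] = "1"  # any non-zero integer string enables it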
@@ -246,6 +248,12 @@ def init_inputs(self):
         self.share_inputs['free_list_len'] = paddle.full(
             shape=[1], fill_value=self.free_list_len, dtype="int32")
 
+        if self.reduce_dialogue_repetition:
+            self.share_inputs["first_token_ids"] = paddle.full(
+                shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
+            self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
+                shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
+
     def dy_input_preprocess(self, tasks):
         """
         dynamic insertion
@@ -279,6 +287,10 @@ def dy_input_preprocess(self, tasks):
             self.share_inputs['max_length'][idx:idx + 1] = max_dec_len
             self.share_inputs['stop_flags'][idx:idx + 1] = False
 
+            if self.reduce_dialogue_repetition:
+                self.share_inputs['first_token_ids'][idx:idx + 1] = self.share_inputs['input_ids'][idx:idx + 1, :1]
+                self.share_inputs["ori_seq_lens_encoder"][idx:idx + 1] = length
+
             if "infer_seed" in task:
                 self.share_inputs['infer_seed'][idx:idx + 1] = task['infer_seed']
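Taken together, the infer.py hunks follow a preallocate-then-slot-write pattern: init_inputs creates fixed-shape tensors sized by max_batch_size once, and dy_input_preprocess fills the slice for whichever slot (idx) a newly inserted request occupies, recording its first input token id and its original encoder (prompt) length. A self-contained sketch of that pattern, with made-up batch size and token ids:

    # Illustrative sketch of the preallocate / per-slot write pattern above.
    import paddle

    max_batch_size = 4
    first_token_ids = paddle.full(shape=[max_batch_size, 1], fill_value=-1, dtype="int64")
    ori_seq_lens_encoder = paddle.full(shape=[max_batch_size, 1], fill_value=0, dtype="int32")

    # A fake request lands in slot idx: remember its first token id and prompt length.
    idx = 2
    input_ids = paddle.to_tensor([[101, 2009, 2003, 102]], dtype="int64")
    first_token_ids[idx:idx + 1] = input_ids[:, :1]
    ori_seq_lens_encoder[idx:idx + 1] = input_ids.shape[1]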
