refactor inference #1245

Merged (4 commits), Jun 28, 2024
Changes from all commits
swift/llm/infer.py (2 changes: 1 addition & 1 deletion)
@@ -243,7 +243,7 @@ def prepare_model_template(args: InferArguments,

def read_media_file(infer_kwargs: Dict[str, Any], infer_media_type: Literal['none', 'round', 'dialogue']) -> None:
    text = 'Input a media path or URL <<< '
-    images = infer_kwargs.get('images', [])
+    images = infer_kwargs.get('images') or []
    if infer_media_type == 'none':
        return
    if infer_media_type == 'round' or len(images) == 0:
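The same `get(...) or []` fix recurs in the two files below. The distinction it relies on: `dict.get(key, default)` falls back only when the key is absent, while `dict.get(key) or default` also replaces an explicitly stored `None` (or any other falsy value). A minimal sketch, with invented dictionary contents:

    infer_kwargs = {'images': None}   # key present, value None
    infer_kwargs.get('images', [])    # -> None: key exists, so the default is ignored
    infer_kwargs.get('images') or []  # -> []: the stored None is replaced

Without the change, a caller passing an explicit `images=None` would make the later `len(images)` raise a TypeError.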
swift/llm/utils/argument.py (2 changes: 1 addition & 1 deletion)
@@ -1299,7 +1299,7 @@ class DeployArguments(InferArguments):
    def __post_init__(self):
        super().__post_init__()
        model_info = MODEL_MAPPING[self.model_type]
-        tags = model_info.get('tags', [])
+        tags = model_info.get('tags') or []
        self.is_multimodal = 'multi-modal' in tags


swift/llm/utils/preprocess.py (2 changes: 1 addition & 1 deletion)
@@ -197,7 +197,7 @@ def preprocess(self, d: Dict[str, Any]) -> Dict[str, Any]:
        response = conversations[-1][self.value_key]
        system = sys
        history = h
-        tools = d.get('tools', [])
+        tools = d.get('tools') or []
        row = {'system': system, 'history': history, 'history_roles': hr}
        row.update({
            'query': query,
swift/llm/utils/utils.py (205 changes: 94 additions & 111 deletions)
@@ -542,38 +542,22 @@ def __next__(self) -> List[int]:
        return value


-@torch.inference_mode()
-def inference_stream(model: PreTrainedModel,
-                     template: Template,
-                     query: str,
-                     history: Optional[History] = None,
-                     system: Optional[str] = None,
-                     images: Optional[List[str]] = None,
-                     *,
-                     generation_config: Optional[GenerationConfig] = None,
-                     stop_words: Optional[StopWords] = None,
-                     generation_info: Optional[Dict[str, int]] = None,
-                     adapter_names: Optional[List[str]] = None,
-                     **kwargs) -> Iterator[Tuple[str, History]]:
-    """
-    generation_config: Priority: generation_config > model.generation_config.
-    """
+def _prepare_inputs(model: PreTrainedModel,
+                    template: Template,
+                    query: str,
+                    history: History,
+                    system: Optional[str] = None,
+                    images: Optional[List[str]] = None,
+                    *,
+                    generation_config: Optional[GenerationConfig] = None,
+                    stop_words: Optional[StopWords] = None,
+                    adapter_names: Optional[List[str]] = None,
+                    **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any], int]:
    if stop_words is None:
        stop_words = []
-    if history is None:
-        history = []
-    else:
-        history = deepcopy(history)
    if images is None:
        images = []

    # agent support
    is_observation = history[-1][-1].endswith('Observation:') if history and history[-1][-1] else False
    if is_observation:
        history[-1][-1] = history[-1][-1] + query
-        act_length = len(history[-1][-1])
        query = None

    example = {
        'query': query,
        'history': history,
@@ -587,7 +571,7 @@
    truncation_strategy = kwargs.pop('truncation_strategy', 'delete')
    if len(inputs) == 0 and truncation_strategy == 'delete':
        # input_ids exceeds `max_length`. Please increase the value of `max_length`.
-        return '', history
+        return {}, tokenizer_kwargs, 0

    inputs.pop('labels', None)
    tokenizer = template.tokenizer
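Note the new sentinel: when truncation deletes an over-long prompt, `_prepare_inputs` signals failure with an empty inputs dict and a zero token count instead of the old `('', history)` tuple, and each public entry point converts that back into its own return shape. A hedged caller-side sketch (names as in the hunks below):

    inputs, tokenizer_kwargs, token_len = _prepare_inputs(model, template, query, history)
    if len(inputs) == 0:
        # the encoded prompt exceeded max_length and was deleted
        return '', history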
@@ -606,11 +590,8 @@
        inputs['token_type_ids'] = torch.tensor(inputs['token_type_ids'])[None]
    model.eval()
    if generation_config is None:
-        generation_config = getattr(model, 'generation_config', None)
+        generation_config = getattr(model, 'generation_config')
    generation_config = deepcopy(generation_config)
-    if generation_config.num_beams != 1:
-        error_msg = 'Streaming generation does not support beam search.'
-        raise ValueError(error_msg)

    if tokenizer.eos_token_id is not None:
        generation_config.eos_token_id = tokenizer.eos_token_id
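The deepcopy matters because `model.generation_config` is shared, mutable state on the model, and the surrounding lines mutate eos/pad/bos ids and `max_new_tokens` per request. A small sketch of the hazard being avoided (illustrative values, not the repo's defaults):

    from copy import deepcopy
    from transformers import GenerationConfig

    shared = GenerationConfig()          # stands in for model.generation_config
    local = deepcopy(shared)             # per-request copy, as above
    local.eos_token_id = 2               # request-specific override
    assert shared.eos_token_id is None   # the shared config is untouched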
@@ -627,21 +608,69 @@
                raise AssertionError('Current sentence length exceeds ' f'the model max_length: {max_length}')
    if template.suffix[-1] not in stop_words:
        stop_words.append(template.suffix[-1])
-    stopping_criteria = StoppingCriteriaList([StopWordsCriteria(tokenizer, stop_words, **tokenizer_kwargs)])
    inputs = to_device(inputs, device)
-    if generation_info is not None:
-        generation_info['num_prompt_tokens'] = token_len
    if 'inputs_embeds' in inputs:
        inputs.pop('input_ids', None)
-    streamer = TokenListIteratorStreamer()
    if adapter_names is not None:
        inputs['adapter_names'] = adapter_names
-    generation_kwargs = {
-        'streamer': streamer,
-        'generation_config': generation_config,
-        'stopping_criteria': stopping_criteria,
-        **inputs
-    }
+
+    stopping_criteria = StoppingCriteriaList([StopWordsCriteria(tokenizer, stop_words, **tokenizer_kwargs)])
+    inputs['stopping_criteria'] = stopping_criteria
+    inputs['generation_config'] = generation_config
+    return inputs, tokenizer_kwargs, token_len
+
+
+@torch.inference_mode()
+def inference_stream(model: PreTrainedModel,
+                     template: Template,
+                     query: str,
+                     history: Optional[History] = None,
+                     system: Optional[str] = None,
+                     images: Optional[List[str]] = None,
+                     *,
+                     generation_config: Optional[GenerationConfig] = None,
+                     stop_words: Optional[StopWords] = None,
+                     generation_info: Optional[Dict[str, int]] = None,
+                     adapter_names: Optional[List[str]] = None,
+                     **kwargs) -> Iterator[Tuple[str, History]]:
+    """
+    generation_config: Priority: generation_config > model.generation_config.
+    """
+    if history is None:
+        history = []
+    else:
+        history = deepcopy(history)
+    inputs, tokenizer_kwargs, token_len = _prepare_inputs(
+        model,
+        template,
+        query,
+        history,
+        system,
+        images,
+        generation_config=generation_config,
+        stop_words=stop_words,
+        adapter_names=adapter_names,
+        **kwargs)
+    if len(inputs) == 0:
+        return '', history
+    if generation_info is None:
+        generation_info = {}
+    generation_info['num_prompt_tokens'] = token_len
+
+    # agent support
+    is_observation = history[-1][-1].endswith('Observation:') if history and history[-1][-1] else False
+    if is_observation:
+        history[-1][-1] = history[-1][-1] + query
+        act_length = len(history[-1][-1])
+        query = None
+
+    generation_config = inputs['generation_config']
+    if generation_config.num_beams != 1:
+        error_msg = 'Streaming generation does not support beam search.'
+        raise ValueError(error_msg)
+
+    streamer = TokenListIteratorStreamer()
+    generation_kwargs = {'streamer': streamer, **inputs}
    _model_generate = model.generate
    if is_torch_npu_available():
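The 'agent support' block is intentionally duplicated: `_prepare_inputs` needs it to build the encoded example, while `inference_stream` keeps `act_length`, presumably to slice the streamed text later (that use falls outside the visible hunks). A worked example of the convention, with invented history contents:

    # The last assistant turn ends with 'Observation:', i.e. a tool call is pending.
    history = [['What is the weather?', 'Action: get_weather\nObservation:']]
    query = ' Sunny, 25C'  # the tool output arrives as the next "query"

    is_observation = history[-1][-1].endswith('Observation:')  # True
    if is_observation:
        history[-1][-1] = history[-1][-1] + query  # splice the tool output into the turn
        act_length = len(history[-1][-1])
        query = None                               # no new user turn is opened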
@@ -667,8 +696,7 @@ def _model_generate(*args, **kwargs):
        except StopIteration:
            is_finished = True
        generate_ids = template.get_generate_ids(torch.tensor(raw_generate_ids)[None], token_len)
-        if generation_info is not None:
-            generation_info['num_generated_tokens'] = len(generate_ids)
+        generation_info['num_generated_tokens'] = len(generate_ids)
        response = template.generate_ids_to_response(
            generate_ids,
            is_finished,
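Taken together: `inference_stream` now packs everything into `inputs` via `_prepare_inputs`, runs `model.generate` against a `TokenListIteratorStreamer`, and yields a growing `(response, history)` pair per decoded chunk. A hedged consumption sketch, assuming a `model` and `template` already loaded with this library:

    gen = inference_stream(model, template, 'Introduce yourself.')
    for response, history in gen:
        print(f'\r{response}', end='')  # response is cumulative, not a delta
    print()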
@@ -702,58 +730,38 @@ def inference(model: PreTrainedModel,
"""
    generation_config: Priority: generation_config > model.generation_config.
    """
-    if stop_words is None:
-        stop_words = []
    if history is None:
        history = []
    else:
        history = deepcopy(history)
-    if images is None:
-        images = []
+    inputs, tokenizer_kwargs, token_len = _prepare_inputs(
+        model,
+        template,
+        query,
+        history,
+        system,
+        images,
+        generation_config=generation_config,
+        stop_words=stop_words,
+        adapter_names=adapter_names,
+        **kwargs)
+    if len(inputs) == 0:
+        return '', history
+    if generation_info is None:
+        generation_info = {}
+    generation_info['num_prompt_tokens'] = token_len

    # agent support
    is_observation = history[-1][-1].endswith('Observation:') if history and history[-1][-1] else False
    if is_observation:
        history[-1][-1] = history[-1][-1] + query
        query = None

-    example = {
-        'query': query,
-        'history': history,
-        'system': system,
-        'images': images,  # for vl. str.
-        'tools': kwargs.pop('tools', None)
-    }
-    template.model = model
-    inputs, tokenizer_kwargs = template.encode(example)
-
-    truncation_strategy = kwargs.pop('truncation_strategy', 'delete')
-    if len(inputs) == 0 and truncation_strategy == 'delete':
-        # input_ids exceeds `max_length`. Please increase the value of `max_length`.
-        return '', history
-
-    inputs.pop('labels', None)
-    tokenizer = template.tokenizer
-    device = next(model.parameters()).device
-    if 'input_ids' in inputs:
-        input_ids = torch.tensor(inputs['input_ids'])[None]
-        inputs['input_ids'] = input_ids
-        token_len = input_ids.shape[1]
-    if 'inputs_embeds' in inputs:
-        inputs_embeds = inputs['inputs_embeds'][None]
-        inputs['inputs_embeds'] = inputs_embeds
-        token_len = inputs_embeds.shape[1]
-
-    inputs['attention_mask'] = torch.ones(token_len)[None]
-    if 'token_type_ids' in inputs:
-        inputs['token_type_ids'] = torch.tensor(inputs['token_type_ids'])[None]
-    model.eval()
-    if generation_config is None:
-        generation_config = getattr(model, 'generation_config', None)
-    generation_config = deepcopy(generation_config)
    if stream and not verbose:
        logger.warning('Please set verbose to True to support TextStreamer, or use `inference_stream.`')
        stream = False
+    streamer = None
+    tokenizer = template.tokenizer
    if stream:
        streamer = TextStreamer(tokenizer, skip_prompt=True)
    if verbose:
@@ -762,37 +770,12 @@
            print(
                f'{prompt_prefix}{safe_tokenizer_decode(tokenizer, input_ids[0], **tokenizer_kwargs)}{output_prefix}',
                end='')
-        elif 'query' in example:
-            query = example['query']
+        else:
            print(f'[QUERY]{query}\n{output_prefix}', end='')
-    if tokenizer.eos_token_id is not None:
-        generation_config.eos_token_id = tokenizer.eos_token_id
-    if tokenizer.pad_token_id is not None:
-        generation_config.pad_token_id = tokenizer.pad_token_id
-    if tokenizer.bos_token_id is not None:
-        generation_config.bos_token_id = tokenizer.bos_token_id
-    if generation_config.max_new_tokens is not None:
-        generation_config.max_length = 20  # fix max_length, max_new_tokens warning
-        max_length = get_max_model_len(model.config)
-        if max_length and token_len + generation_config.max_new_tokens > max_length:
-            generation_config.max_new_tokens = max_length - token_len
-            if generation_config.max_new_tokens <= 0:
-                raise AssertionError('Current sentence length exceeds ' f'the model max_length: {max_length}')
-    if template.suffix[-1] not in stop_words:
-        stop_words.append(template.suffix[-1])
-    stopping_criteria = StoppingCriteriaList([StopWordsCriteria(tokenizer, stop_words, **tokenizer_kwargs)])
-    inputs = to_device(inputs, device)
-    if generation_info is not None:
-        generation_info['num_prompt_tokens'] = token_len
-    if 'inputs_embeds' in inputs:
-        inputs.pop('input_ids', None)
-    if adapter_names is not None:
-        inputs['adapter_names'] = adapter_names
-    generate_ids = model.generate(
-        streamer=streamer, generation_config=generation_config, stopping_criteria=stopping_criteria, **inputs)
+
+    generate_ids = model.generate(streamer=streamer, **inputs)
    generate_ids = template.get_generate_ids(generate_ids, token_len)
-    if generation_info is not None:
-        generation_info['num_generated_tokens'] = len(generate_ids)
+    generation_info['num_generated_tokens'] = len(generate_ids)
    if verbose and stream is False:
        response = tokenizer.decode(generate_ids, **tokenizer_kwargs)
        print(response)
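For the non-streaming path, a usage sketch of the refactored bookkeeping, again assuming a prepared `model` and `template` (query text invented). `generation_info` is filled in place with the two counters set above:

    generation_info = {}
    response, history = inference(
        model, template, 'What is 1 + 1?', generation_info=generation_info)
    print(response)
    print(generation_info)  # {'num_prompt_tokens': ..., 'num_generated_tokens': ...}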