Commit 60f5177

fix

tastelikefeet committed Jan 15, 2025
1 parent cbcabe0 commit 60f5177
Showing 4 changed files with 13 additions and 9 deletions.
5 changes: 3 additions & 2 deletions swift/llm/infer/infer_engine/pt_engine.py
@@ -283,7 +283,8 @@ def _infer_forward(self,
         if template.mode == 'seq_cls':
             preds, logprobs = template.decode_seq_cls(logits)
         elif template.mode == 'prm':
-            preds, logprobs = template.decode_prm(logits)
+            preds = template.decode_prm(inputs['input_ids'], logits)
+            logprobs = [None] * len(preds)
         else:
             raise ValueError(f'Unsupported mode: {template.mode}')
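
Note on the changed contract: decode_prm now receives the input ids alongside the logits and returns only the predictions, with logprobs padded out as None. A minimal sketch of what a decoder with this signature might do, assuming the PRM emits a score wherever a step-separator token appears (the token id and the class layout are assumptions, not the repository's actual implementation):

import torch

def decode_prm(input_ids: torch.Tensor, logits: torch.Tensor,
               step_token_id: int) -> list:
    # Assumed scheme: keep the positive-class probability at each position
    # holding the step-separator token, yielding one score list per sample.
    probs = torch.softmax(logits, dim=-1)
    preds = []
    for ids, p in zip(input_ids, probs):
        mask = ids == step_token_id
        preds.append(p[mask, 1].tolist())  # assumption: class 1 = "step is correct"
    return preds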

@@ -441,7 +442,7 @@ def _gen_wrapper():

             return _gen_wrapper()
         else:
-            infer_func = self._infer_prm if template.mode == 'seq_cls' else self._infer_full
+            infer_func = self._infer_forward if template.mode in ('seq_cls', 'prm') else self._infer_full
             return self._update_metrics(infer_func(**kwargs), metrics)
 
     def infer(
3 changes: 2 additions & 1 deletion swift/llm/model/model/qwen.py
@@ -696,8 +696,9 @@ def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx
             Model('Qwen/Qwen2.5-Math-PRM-72B', 'Qwen/Qwen2.5-Math-PRM-72B'),
         ]),
     ],
-    TemplateType.qwen,
+    TemplateType.qwen2_5_math_prm,
     get_model_tokenizer_reward_model,
+    task_type='prm',
     architectures=['Qwen2ForProcessRewardModel'],
     requires=['transformers>=4.37'],
 ))
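
For context, a hedged usage sketch of how a PRM registered this way might be invoked. PtEngine and InferRequest are real swift.llm entry points, but the step-separated message format and the response shape are assumptions:

from swift.llm import InferRequest, PtEngine

engine = PtEngine('Qwen/Qwen2.5-Math-PRM-7B')
request = InferRequest(messages=[
    {'role': 'user', 'content': '1 + 1 = ?'},
    # Assumption: steps delimited by the <cot-process> special token
    # introduced elsewhere in this commit.
    {'role': 'assistant', 'content': 'Step 1: add the units.<cot-process>Step 2: 1 + 1 = 2.<cot-process>'},
])
resp_list = engine.infer([request])
print(resp_list[0])  # assumed to carry one reward score per step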
3 changes: 3 additions & 0 deletions swift/llm/model/register.py
@@ -64,6 +64,7 @@ class ModelMeta:

     is_multimodal: bool = False
     is_reward: bool = False
+    task_type: Optional[str] = None
 
     # File patterns to ignore when downloading the model.
     ignore_patterns: List[str] = field(default_factory=list)
@@ -391,6 +392,8 @@ def get_model_info_meta(
         task_type = 'seq_cls'
     if task_type == 'seq_cls':
         assert num_labels is not None, 'Please pass the parameter `num_labels`.'
+    if model_meta.task_type is not None:
+        task_type = model_meta.task_type
     model_info.task_type = task_type
     model_info.num_labels = num_labels
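
A condensed, standalone restatement of the resolution order this hunk implements (names simplified from get_model_info_meta): a task_type pinned on ModelMeta, such as 'prm' for the Qwen2.5-Math-PRM models, overrides whatever was passed in or inferred.

from typing import Optional

def resolve_task_type(task_type: Optional[str], num_labels: Optional[int],
                      meta_task_type: Optional[str]) -> Optional[str]:
    if task_type == 'seq_cls':
        assert num_labels is not None, 'Please pass the parameter `num_labels`.'
    # The per-model task_type on ModelMeta wins over the caller's value.
    if meta_task_type is not None:
        task_type = meta_task_type
    return task_type

assert resolve_task_type('seq_cls', num_labels=2, meta_task_type='prm') == 'prm'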

11 changes: 5 additions & 6 deletions swift/llm/template/base.py
@@ -34,7 +34,7 @@ class MaxLengthError(ValueError):


 class Template(ProcessorMixin):
-    special_tokens = ['<image>', '<video>', '<audio>', '<bbox>', '<ref-object>']
+    special_tokens = ['<image>', '<video>', '<audio>', '<bbox>', '<ref-object>', '<cot-process>']
     special_keys = ['images', 'videos', 'audios', 'objects']
     grounding_type = 'norm_1000'

@@ -111,7 +111,7 @@ def __init__(

         self.mode: Literal['pt', 'vllm', 'lmdeploy',  # infer
                            'train', 'rlhf', 'kto',  # train
-                           'seq_cls'] = 'pt'
+                           'seq_cls', 'prm'] = 'pt'
         if self.model_info.task_type != 'causal_lm':
             self.mode = self.model_info.task_type
         self._handles = []
@@ -237,7 +237,7 @@ def encode(self,
             encoded = Template._encode(self, inputs)
             for key in ['images', 'audios', 'videos']:
                 encoded[key] = getattr(inputs, key)
-        elif self.mode in {'pt', 'train'}:
+        elif self.mode in {'pt', 'train', 'prm'}:
             encoded = self._encode(inputs)
         elif self.mode == 'seq_cls':
             encoded = self._seq_cls_encode(inputs)
@@ -365,8 +365,7 @@ def _concat_context_list(
     def _simplify_context_list(self, context_list: List[Context], loss_scale_list: List[float],
                                inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
         """Merge anything in the context to simplify the inputs"""
-        if inputs.is_multimodal:
-            context_list, loss_scale_list = self._split_special_tokens(context_list, loss_scale_list)
+        context_list, loss_scale_list = self._split_special_tokens(context_list, loss_scale_list)
         context_list, loss_scale_list = self._pre_tokenize(context_list, loss_scale_list, inputs)
 
         res: List[Context] = []  # result of context_list
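
Dropping the is_multimodal guard means text-only inputs (such as PRM step sequences using the new <cot-process> token) now also get split on special tokens. For intuition, a simplified sketch of the splitting step; the real _split_special_tokens also carries loss-scale bookkeeping:

import re

SPECIAL_TOKENS = ['<image>', '<video>', '<audio>', '<bbox>', '<ref-object>', '<cot-process>']

def split_special_tokens(context_list: list) -> list:
    # Split each textual context so that every special token becomes a
    # standalone element the template can process individually.
    pattern = '(' + '|'.join(map(re.escape, SPECIAL_TOKENS)) + ')'
    result = []
    for ctx in context_list:
        result.extend(part for part in re.split(pattern, ctx) if part)
    return result

print(split_special_tokens(['Step 1 ...<cot-process>Step 2 ...<cot-process>']))
# -> ['Step 1 ...', '<cot-process>', 'Step 2 ...', '<cot-process>']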
@@ -833,7 +832,7 @@ def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int
             return self._rlhf_data_collator(batch, padding_to=padding_to)
         elif self.mode == 'kto':
             return self._kto_data_collator(batch, padding_to=padding_to)
-        elif self.mode in {'pt', 'train'}:
+        elif self.mode in {'pt', 'train', 'prm'}:
             return self._data_collator(batch, padding_to=padding_to)
         elif self.mode == 'seq_cls':
             return self._seq_cls_data_collator(batch, padding_to=padding_to)
