Skip to content

Commit a2a858d

Browse files
authored
🐛 fix: fix reward model train seq_cls (#3921)
1 parent 634c15d commit a2a858d

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

swift/llm/model/register.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,11 @@ def get_model_tokenizer_from_local(model_dir: str,
230230
if model is None:
231231
if model_info.task_type == 'seq_cls' and not model_meta.is_reward:
232232
context = partial(patch_automodel_for_sequence_classification, model_meta=model_meta)
233+
elif model_info.task_type == 'seq_cls' and model_meta.is_reward and model_config.num_labels > 1:
234+
logger.warning('You are using a reward model for seq_cls task and num_labels > 1, '
235+
'ignore_mismatched_sizes will be set to True')
236+
model_kwargs['ignore_mismatched_sizes'] = True
237+
context = partial(patch_automodel_for_sequence_classification, model_meta=model_meta)
233238
else:
234239
context = partial(patch_automodel, automodel_class=automodel_class, model_info=model_info)
235240
with context():

0 commit comments

Comments
 (0)