Skip to content

Commit b2474b3

Browse files
author
gaohongkui
committed
🐛 fix: fix reward model train seq_cls
1 parent a52cd65 commit b2474b3

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

swift/llm/model/register.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,12 @@ def get_model_tokenizer_from_local(model_dir: str,
228228
if model is None:
229229
if model_info.task_type == 'seq_cls' and not model_meta.is_reward:
230230
context = partial(patch_automodel_for_sequence_classification, model_meta=model_meta)
231+
elif model_info.task_type == 'seq_cls' and model_meta.is_reward and model_config.num_labels > 1:
232+
logger.warning(
233+
'You are using a seq_cls reward model and num_labels > 1, ignore_mismatched_sizes will be set to True'
234+
)
235+
model_kwargs['ignore_mismatched_sizes'] = True
236+
context = partial(patch_automodel_for_sequence_classification, model_meta=model_meta)
231237
else:
232238
context = partial(patch_automodel, automodel_class=automodel_class, model_info=model_info)
233239
with context():

0 commit comments

Comments
 (0)