File tree Expand file tree Collapse file tree 1 file changed +5
-0
lines changed
Expand file tree Collapse file tree 1 file changed +5
-0
lines changed Original file line number Diff line number Diff line change @@ -230,6 +230,11 @@ def get_model_tokenizer_from_local(model_dir: str,
230230 if model is None :
231231 if model_info .task_type == 'seq_cls' and not model_meta .is_reward :
232232 context = partial (patch_automodel_for_sequence_classification , model_meta = model_meta )
233+ elif model_info .task_type == 'seq_cls' and model_meta .is_reward and model_config .num_labels > 1 :
234+ logger .warning ('You are using a reward model for seq_cls task and num_labels > 1, '
235+ 'ignore_mismatched_sizes will be set to True' )
236+ model_kwargs ['ignore_mismatched_sizes' ] = True
237+ context = partial (patch_automodel_for_sequence_classification , model_meta = model_meta )
233238 else :
234239 context = partial (patch_automodel , automodel_class = automodel_class , model_info = model_info )
235240 with context ():
You can’t perform that action at this time.
0 commit comments