File tree Expand file tree Collapse file tree 1 file changed +6
-0
lines changed
Expand file tree Collapse file tree 1 file changed +6
-0
lines changed Original file line number Diff line number Diff line change @@ -228,6 +228,12 @@ def get_model_tokenizer_from_local(model_dir: str,
228228 if model is None :
229229 if model_info .task_type == 'seq_cls' and not model_meta .is_reward :
230230 context = partial (patch_automodel_for_sequence_classification , model_meta = model_meta )
231+ elif model_info .task_type == 'seq_cls' and model_meta .is_reward and model_config .num_labels > 1 :
232+ logger .warning (
233+ 'You are using a seq_cls reward model and num_labels > 1, ignore_mismatched_sizes will be set to True'
234+ )
235+ model_kwargs ['ignore_mismatched_sizes' ] = True
236+ context = partial (patch_automodel_for_sequence_classification , model_meta = model_meta )
231237 else :
232238 context = partial (patch_automodel , automodel_class = automodel_class , model_info = model_info )
233239 with context ():
You can’t perform that action at this time.
0 commit comments