diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index 23c37b7fc9..849b16763e 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -1156,6 +1156,10 @@ def __setstate__(self, state):
         config = None

         if config_state_dict:
+            # some models like the tars model somehow lost this information.
+            if config_state_dict.get("_name_or_path") == "None":
+                config_state_dict["_name_or_path"] = state.get("model", "None")
+
             model_type = config_state_dict.get("model_type", "bert")
             config_class = CONFIG_MAPPING[model_type]
             config = config_class.from_dict(config_state_dict)
diff --git a/flair/models/tars_model.py b/flair/models/tars_model.py
index 3dfc871e2e..9003bd92a8 100644
--- a/flair/models/tars_model.py
+++ b/flair/models/tars_model.py
@@ -417,7 +417,7 @@ def _get_state_dict(self):
             "current_task": self._current_task,
             "tag_type": self.get_current_label_type(),
             "tag_dictionary": self.get_current_label_dictionary(),
-            "tars_model": self.tars_model,
+            "tars_embeddings": self.tars_model.embeddings.save_embeddings(use_state_dict=False),
             "num_negative_labels_to_sample": self.num_negative_labels_to_sample,
             "prefix": self.prefix,
             "task_specific_attributes": self._task_specific_attributes,
@@ -437,13 +437,18 @@ def _fetch_model(model_name) -> str:
     @classmethod
     def _init_model_with_state_dict(cls, state, **kwargs):
+        tars_embeddings = state.get("tars_embeddings")
+
+        if tars_embeddings is None:
+            tars_model = state["tars_model"]
+            tars_embeddings = tars_model.embeddings

         # init new TARS classifier
         model = super()._init_model_with_state_dict(
             state,
             task_name=state.get("current_task"),
             label_dictionary=state.get("tag_dictionary"),
             label_type=state.get("tag_type"),
-            embeddings=state.get("tars_model").embeddings,
+            embeddings=tars_embeddings,
             num_negative_labels_to_sample=state.get("num_negative_labels_to_sample"),
             prefix=state.get("prefix"),
             **kwargs,
@@ -730,22 +735,30 @@ def _get_state_dict(self):
         model_state = {
             **super()._get_state_dict(),
             "current_task": self._current_task,
-            "label_type": self.get_current_label_type(),
-            "label_dictionary": self.get_current_label_dictionary(),
-            "tars_model": self.tars_model,
+            "tars_embeddings": self.tars_model.embeddings.save_embeddings(use_state_dict=False),
             "num_negative_labels_to_sample": self.num_negative_labels_to_sample,
             "task_specific_attributes": self._task_specific_attributes,
         }
+        if self._current_task is not None:
+            model_state.update(
+                {
+                    "label_type": self.get_current_label_type(),
+                    "label_dictionary": self.get_current_label_dictionary(),
+                }
+            )
         return model_state

     @classmethod
     def _init_model_with_state_dict(cls, state, **kwargs):
         # get the serialized embeddings
-        tars_model = state.get("tars_model")
-        if hasattr(tars_model, "embeddings"):
-            embeddings = tars_model.embeddings
-        else:
-            embeddings = tars_model.document_embeddings
+        tars_embeddings = state.get("tars_embeddings")
+
+        if tars_embeddings is None:
+            tars_model = state["tars_model"]
+            if hasattr(tars_model, "embeddings"):
+                tars_embeddings = tars_model.embeddings
+            else:
+                tars_embeddings = tars_model.document_embeddings

         # remap state dict for models serialized with Flair <= 0.11.3
         import re
@@ -762,7 +775,7 @@ def _init_model_with_state_dict(cls, state, **kwargs):
             task_name=state["current_task"],
             label_dictionary=state.get("label_dictionary"),
             label_type=state.get("label_type", "default_label"),
-            embeddings=embeddings,
+            embeddings=tars_embeddings,
             num_negative_labels_to_sample=state.get("num_negative_labels_to_sample"),
             **kwargs,
         )
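
The core of this change is the backward-compatible lookup in _init_model_with_state_dict: checkpoints written after this diff store only the transformer embeddings under "tars_embeddings" (via save_embeddings(use_state_dict=False)), while older checkpoints still carry the whole inner model under "tars_model". A minimal sketch of that fallback, assuming a hypothetical helper name (_resolve_tars_embeddings is not part of the PR):

def _resolve_tars_embeddings(state: dict):
    # new-style checkpoints (this PR): embeddings are serialized directly
    tars_embeddings = state.get("tars_embeddings")

    if tars_embeddings is None:
        # legacy checkpoints: the whole inner model was pickled under "tars_model";
        # TARSClassifier checkpoints expose the embeddings as "document_embeddings"
        tars_model = state["tars_model"]
        if hasattr(tars_model, "embeddings"):
            tars_embeddings = tars_model.embeddings
        else:
            tars_embeddings = tars_model.document_embeddings

    return tars_embeddings

Dropping the pickled tars_model from the state dict is what makes the _name_or_path repair in transformer.py necessary: with only the embeddings saved, the config has to be reconstructed from its dict form, so a "_name_or_path" that was serialized as the string "None" is restored from the stored model name instead.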