Skip to content

Commit

Permalink
Merge pull request #3212 from flairNLP/fix_tars_loading
Browse files Browse the repository at this point in the history
Fix tars loading
  • Loading branch information
alanakbik authored Apr 21, 2023
2 parents e944346 + 4f1d4cd commit 9037a72
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
4 changes: 4 additions & 0 deletions flair/embeddings/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,10 @@ def __setstate__(self, state):
config = None

if config_state_dict:
# Some models (e.g. the TARS model) lost this information and store the string "None" as `_name_or_path`; restore it from the model name saved in the state.
if config_state_dict.get("_name_or_path") == "None":
config_state_dict["_name_or_path"] = state.get("model", "None")

model_type = config_state_dict.get("model_type", "bert")
config_class = CONFIG_MAPPING[model_type]
config = config_class.from_dict(config_state_dict)
Expand Down
35 changes: 24 additions & 11 deletions flair/models/tars_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def _get_state_dict(self):
"current_task": self._current_task,
"tag_type": self.get_current_label_type(),
"tag_dictionary": self.get_current_label_dictionary(),
"tars_model": self.tars_model,
"tars_embeddings": self.tars_model.embeddings.save_embeddings(use_state_dict=False),
"num_negative_labels_to_sample": self.num_negative_labels_to_sample,
"prefix": self.prefix,
"task_specific_attributes": self._task_specific_attributes,
Expand All @@ -437,13 +437,18 @@ def _fetch_model(model_name) -> str:

@classmethod
def _init_model_with_state_dict(cls, state, **kwargs):
tars_embeddings = state.get("tars_embeddings")

if tars_embeddings is None:
tars_model = state["tars_model"]
tars_embeddings = tars_model.embeddings
# init new TARS classifier
model = super()._init_model_with_state_dict(
state,
task_name=state.get("current_task"),
label_dictionary=state.get("tag_dictionary"),
label_type=state.get("tag_type"),
embeddings=state.get("tars_model").embeddings,
embeddings=tars_embeddings,
num_negative_labels_to_sample=state.get("num_negative_labels_to_sample"),
prefix=state.get("prefix"),
**kwargs,
Expand Down Expand Up @@ -730,22 +735,30 @@ def _get_state_dict(self):
model_state = {
**super()._get_state_dict(),
"current_task": self._current_task,
"label_type": self.get_current_label_type(),
"label_dictionary": self.get_current_label_dictionary(),
"tars_model": self.tars_model,
"tars_embeddings": self.tars_model.embeddings.save_embeddings(use_state_dict=False),
"num_negative_labels_to_sample": self.num_negative_labels_to_sample,
"task_specific_attributes": self._task_specific_attributes,
}
if self._current_task is not None:
model_state.update(
{
"label_type": self.get_current_label_type(),
"label_dictionary": self.get_current_label_dictionary(),
}
)
return model_state

@classmethod
def _init_model_with_state_dict(cls, state, **kwargs):
# get the serialized embeddings
tars_model = state.get("tars_model")
if hasattr(tars_model, "embeddings"):
embeddings = tars_model.embeddings
else:
embeddings = tars_model.document_embeddings
tars_embeddings = state.get("tars_embeddings")

if tars_embeddings is None:
tars_model = state["tars_model"]
if hasattr(tars_model, "embeddings"):
tars_embeddings = tars_model.embeddings
else:
tars_embeddings = tars_model.document_embeddings

# remap state dict for models serialized with Flair <= 0.11.3
import re
Expand All @@ -762,7 +775,7 @@ def _init_model_with_state_dict(cls, state, **kwargs):
task_name=state["current_task"],
label_dictionary=state.get("label_dictionary"),
label_type=state.get("label_type", "default_label"),
embeddings=embeddings,
embeddings=tars_embeddings,
num_negative_labels_to_sample=state.get("num_negative_labels_to_sample"),
**kwargs,
)
Expand Down

0 comments on commit 9037a72

Please sign in to comment.