Skip to content

Commit

Permalink
GH-2632: Fixing paramselection code to work with changes in Flair v0.10
Browse files: browse the repository at this point in the history
  • Loading branch information
tadejmagajna committed Feb 14, 2022
1 parent f0c9f42 commit 2995cc4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
16 changes: 7 additions & 9 deletions flair/hyperparameter/param_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
SEQUENCE_TAGGER_PARAMETERS,
TRAINING_PARAMETERS,
DOCUMENT_EMBEDDING_PARAMETERS,
MODEL_TRAINER_PARAMETERS,
)
from flair.models import SequenceTagger, TextClassifier
from flair.trainers import ModelTrainer
Expand Down Expand Up @@ -97,13 +96,8 @@ def _objective(self, params: dict):
training_params = {
key: params[key] for key in params if key in TRAINING_PARAMETERS
}
model_trainer_parameters = {
key: params[key] for key in params if key in MODEL_TRAINER_PARAMETERS
}

trainer: ModelTrainer = ModelTrainer(
model, self.corpus, **model_trainer_parameters
)
trainer: ModelTrainer = ModelTrainer(model, self.corpus)

result = trainer.train(
self.base_path,
Expand Down Expand Up @@ -226,6 +220,7 @@ class TextClassifierParamSelector(ParamSelector):
def __init__(
self,
corpus: Corpus,
label_type: str,
multi_label: bool,
base_path: Union[str, Path],
document_embedding_type: str,
Expand All @@ -236,6 +231,7 @@ def __init__(
):
"""
:param corpus: the corpus
:param label_type: string to identify the label type ('question_class', 'sentiment', etc.)
:param multi_label: true, if the dataset is multi label, false otherwise
:param base_path: the path to the result folder (results will be written to that folder)
:param document_embedding_type: either 'lstm', 'mean', 'min', or 'max'
Expand All @@ -254,23 +250,25 @@ def __init__(
)

self.multi_label = multi_label
self.label_type = label_type
self.document_embedding_type = document_embedding_type

self.label_dictionary = self.corpus.make_label_dictionary()
self.label_dictionary = self.corpus.make_label_dictionary(self.label_type)

def _set_up_model(self, params: dict):
embdding_params = {
key: params[key] for key in params if key in DOCUMENT_EMBEDDING_PARAMETERS
}

if self.document_embedding_type == "lstm":
document_embedding = DocumentRNNEmbeddings(**embdding_params)
document_embedding = DocumentRNNEmbeddings(rnn_type="LSTM", **embdding_params)
else:
document_embedding = DocumentPoolEmbeddings(**embdding_params)

text_classifier: TextClassifier = TextClassifier(
label_dictionary=self.label_dictionary,
multi_label=self.multi_label,
label_type=self.label_type,
document_embeddings=document_embedding,
)

Expand Down
2 changes: 1 addition & 1 deletion flair/hyperparameter/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Parameter(Enum):
TRAINING_PARAMETERS = [
Parameter.LEARNING_RATE.value,
Parameter.MINI_BATCH_SIZE.value,
Parameter.OPTIMIZER.value,
Parameter.ANNEAL_FACTOR.value,
Parameter.PATIENCE.value,
Parameter.ANNEAL_WITH_RESTARTS.value,
Expand All @@ -52,7 +53,6 @@ class Parameter(Enum):
Parameter.LOCKED_DROPOUT.value,
Parameter.WORD_DROPOUT.value,
]
MODEL_TRAINER_PARAMETERS = [Parameter.OPTIMIZER.value]
DOCUMENT_EMBEDDING_PARAMETERS = [
Parameter.EMBEDDINGS.value,
Parameter.HIDDEN_SIZE.value,
Expand Down

0 comments on commit 2995cc4

Please sign in to comment.