Reintroduce hyperparameter tuning #2633

Merged

Changes from 7 commits
21 changes: 21 additions & 0 deletions flair/hyperparameter/__init__.py
@@ -0,0 +1,21 @@
from .param_selection import (
    SearchSpace,
    SequenceTaggerParamSelector,
    TextClassifierParamSelector,
)
from .parameter import (
    DOCUMENT_EMBEDDING_PARAMETERS,
    SEQUENCE_TAGGER_PARAMETERS,
    TRAINING_PARAMETERS,
    Parameter,
)

__all__ = [
    "Parameter",
    "SEQUENCE_TAGGER_PARAMETERS",
    "TRAINING_PARAMETERS",
    "DOCUMENT_EMBEDDING_PARAMETERS",
    "SequenceTaggerParamSelector",
    "TextClassifierParamSelector",
    "SearchSpace",
]
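For reviewers who want to try this out: a minimal end-to-end sketch of the re-introduced API (the corpus, embeddings, search values, and output path below are illustrative, not part of this diff):

from hyperopt import hp

from flair.datasets import UD_ENGLISH
from flair.embeddings import StackedEmbeddings, WordEmbeddings
from flair.hyperparameter import Parameter, SearchSpace, SequenceTaggerParamSelector

# illustrative corpus; any flair Corpus with a dev split works
corpus = UD_ENGLISH().downsample(0.1)

# each Parameter is mapped to a hyperopt expression via SearchSpace.add
search_space = SearchSpace()
search_space.add(
    Parameter.EMBEDDINGS,
    hp.choice,
    options=[StackedEmbeddings([WordEmbeddings("glove")])],
)
search_space.add(Parameter.HIDDEN_SIZE, hp.choice, options=[32, 64, 128])
search_space.add(Parameter.LEARNING_RATE, hp.uniform, low=0.01, high=0.2)
search_space.add(Parameter.MINI_BATCH_SIZE, hp.choice, options=[8, 16, 32])

# small budgets here so the sketch runs quickly
param_selector = SequenceTaggerParamSelector(corpus, "upos", "resources/results/ud_upos", max_epochs=3)
param_selector.optimize(search_space, max_evals=10)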
260 changes: 260 additions & 0 deletions flair/hyperparameter/param_selection.py
@@ -0,0 +1,260 @@
import logging
from abc import abstractmethod
from enum import Enum
from pathlib import Path
from typing import Union

import numpy as np
from hyperopt import fmin, hp, tpe

import flair.nn
from flair.data import Corpus
from flair.embeddings import DocumentPoolEmbeddings, DocumentRNNEmbeddings
from flair.hyperparameter.parameter import (
    DOCUMENT_EMBEDDING_PARAMETERS,
    SEQUENCE_TAGGER_PARAMETERS,
    TRAINING_PARAMETERS,
    Parameter,
)
from flair.models import SequenceTagger, TextClassifier
from flair.trainers import ModelTrainer
from flair.training_utils import EvaluationMetric, init_output_file, log_line

log = logging.getLogger("flair")


class OptimizationValue(Enum):
    DEV_LOSS = "loss"
    DEV_SCORE = "score"


class SearchSpace(object):
    def __init__(self):
        self.search_space = {}

    def add(self, parameter: Parameter, func, **kwargs):
        self.search_space[parameter.value] = func(parameter.value, **kwargs)

    def get_search_space(self):
        return hp.choice("parameters", [self.search_space])


class ParamSelector(object):
    def __init__(
        self,
        corpus: Corpus,
        base_path: Union[str, Path],
        max_epochs: int,
        evaluation_metric: EvaluationMetric,
        training_runs: int,
        optimization_value: OptimizationValue,
    ):
        if type(base_path) is str:
            base_path = Path(base_path)

        self.corpus = corpus
        self.max_epochs = max_epochs
        self.base_path = base_path
        self.evaluation_metric = evaluation_metric
        self.run = 1
        self.training_runs = training_runs
        self.optimization_value = optimization_value

        self.param_selection_file = init_output_file(base_path, "param_selection.txt")

    @abstractmethod
    def _set_up_model(self, params: dict) -> flair.nn.Model:
        pass

    def _objective(self, params: dict):
        log_line(log)
        log.info(f"Evaluation run: {self.run}")
        log.info("Evaluating parameter combination:")
        for k, v in params.items():
            if isinstance(v, tuple):
                v = ",".join([str(x) for x in v])
            log.info(f"\t{k}: {str(v)}")
        log_line(log)

        scores = []
        vars = []

        for i in range(0, self.training_runs):
            log_line(log)
            log.info(f"Training run: {i + 1}")

            for sent in self.corpus.get_all_sentences():  # type: ignore
                sent.clear_embeddings()

            model = self._set_up_model(params)

            training_params = {key: params[key] for key in params if key in TRAINING_PARAMETERS}

            trainer: ModelTrainer = ModelTrainer(model, self.corpus)

            result = trainer.train(
                self.base_path,
                max_epochs=self.max_epochs,
                param_selection_mode=True,
                **training_params,
            )

            # take the average over the last three scores of training
            if self.optimization_value == OptimizationValue.DEV_LOSS:
                curr_scores = result["dev_loss_history"][-3:]
            else:
                curr_scores = list(map(lambda s: 1 - s, result["dev_score_history"][-3:]))

            score = sum(curr_scores) / float(len(curr_scores))
            var = np.var(curr_scores)
            scores.append(score)
            vars.append(var)

        # take average over the scores from the different training runs
        final_score = sum(scores) / float(len(scores))
        final_var = sum(vars) / float(len(vars))

        test_score = result["test_score"]
        log_line(log)
        log.info("Done evaluating parameter combination:")
        for k, v in params.items():
            if isinstance(v, tuple):
                v = ",".join([str(x) for x in v])
            log.info(f"\t{k}: {v}")
        log.info(f"{self.optimization_value.value}: {final_score}")
        log.info(f"variance: {final_var}")
        log.info(f"test_score: {test_score}\n")
        log_line(log)

        with open(self.param_selection_file, "a") as f:
            f.write(f"evaluation run {self.run}\n")
            for k, v in params.items():
                if isinstance(v, tuple):
                    v = ",".join([str(x) for x in v])
                f.write(f"\t{k}: {str(v)}\n")
            f.write(f"{self.optimization_value.value}: {final_score}\n")
            f.write(f"variance: {final_var}\n")
            f.write(f"test_score: {test_score}\n")
            f.write("-" * 100 + "\n")

        self.run += 1

        return {"status": "ok", "loss": final_score, "loss_variance": final_var}

    def optimize(self, space: SearchSpace, max_evals=100):
        search_space = space.search_space
        best = fmin(self._objective, search_space, algo=tpe.suggest, max_evals=max_evals)

        log_line(log)
        log.info("Optimizing parameter configuration done.")
        log.info("Best parameter configuration found:")
        for k, v in best.items():
            log.info(f"\t{k}: {v}")
        log_line(log)

        with open(self.param_selection_file, "a") as f:
            f.write("best parameter combination\n")
            for k, v in best.items():
                if isinstance(v, tuple):
                    v = ",".join([str(x) for x in v])
                f.write(f"\t{k}: {str(v)}\n")

class SequenceTaggerParamSelector(ParamSelector):
    def __init__(
        self,
        corpus: Corpus,
        tag_type: str,
        base_path: Union[str, Path],
        max_epochs: int = 50,
        evaluation_metric: EvaluationMetric = EvaluationMetric.MICRO_F1_SCORE,
        training_runs: int = 1,
        optimization_value: OptimizationValue = OptimizationValue.DEV_LOSS,
    ):
        """
        :param corpus: the corpus
        :param tag_type: tag type to use
        :param base_path: the path to the result folder (results will be written to that folder)
        :param max_epochs: number of epochs to perform on every evaluation run
        :param evaluation_metric: evaluation metric used during training
        :param training_runs: number of training runs per evaluation run
        :param optimization_value: value to optimize
        """
        super().__init__(
            corpus,
            base_path,
            max_epochs,
            evaluation_metric,
            training_runs,
            optimization_value,
        )

        self.tag_type = tag_type
        self.tag_dictionary = self.corpus.make_label_dictionary(self.tag_type)

    def _set_up_model(self, params: dict):
        sequence_tagger_params = {key: params[key] for key in params if key in SEQUENCE_TAGGER_PARAMETERS}

        tagger: SequenceTagger = SequenceTagger(
            tag_dictionary=self.tag_dictionary,
            tag_type=self.tag_type,
            **sequence_tagger_params,
        )
        return tagger


class TextClassifierParamSelector(ParamSelector):
    def __init__(
        self,
        corpus: Corpus,
        label_type: str,
        multi_label: bool,
        base_path: Union[str, Path],
        document_embedding_type: str,
        max_epochs: int = 50,
        evaluation_metric: EvaluationMetric = EvaluationMetric.MICRO_F1_SCORE,
        training_runs: int = 1,
        optimization_value: OptimizationValue = OptimizationValue.DEV_LOSS,
    ):
        """
        :param corpus: the corpus
        :param label_type: string to identify the label type ('question_class', 'sentiment', etc.)
        :param multi_label: true if the dataset is multi-label, false otherwise
        :param base_path: the path to the result folder (results will be written to that folder)
        :param document_embedding_type: either 'lstm', 'mean', 'min', or 'max'
        :param max_epochs: number of epochs to perform on every evaluation run
        :param evaluation_metric: evaluation metric used during training
        :param training_runs: number of training runs per evaluation run
        :param optimization_value: value to optimize
        """
        super().__init__(
            corpus,
            base_path,
            max_epochs,
            evaluation_metric,
            training_runs,
            optimization_value,
        )

        self.multi_label = multi_label
        self.label_type = label_type
        self.document_embedding_type = document_embedding_type

        self.label_dictionary = self.corpus.make_label_dictionary(self.label_type)

    def _set_up_model(self, params: dict):
        embedding_params = {key: params[key] for key in params if key in DOCUMENT_EMBEDDING_PARAMETERS}

        if self.document_embedding_type == "lstm":
Collaborator: it seems the only options are RNN or pooled word embeddings, but the standard approach today would be to simply use a transformer, which far outperforms RNNs trained over word embeddings. Not necessarily something that needs to be fixed in this PR, but it limits the usefulness of our parameter selection for text classifiers at the moment.

Collaborator (Author): As per your suggestion (which makes perfect sense - it felt really weird re-adding this outdated RNN code), I refactored TextClassifierParamSelector so that it uses TransformerDocumentEmbeddings now. I documented it to a reasonable extent in the tutorial.
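Roughly, the direction looks like this - a sketch only, assuming the constructor drops document_embedding_type and the parameter list is renamed as discussed further down (the exact code lands in the follow-up commits, not in this diff):

from flair.embeddings import TransformerDocumentEmbeddings

def _set_up_model(self, params: dict):
    # keep only the embedding-related parameters sampled for this run
    embedding_params = {key: params[key] for key in params if key in TEXT_CLASSIFICATION_PARAMETERS}

    # TransformerDocumentEmbeddings replaces the RNN/pooling branches entirely
    document_embedding = TransformerDocumentEmbeddings(**embedding_params)

    return TextClassifier(
        label_dictionary=self.label_dictionary,
        multi_label=self.multi_label,
        label_type=self.label_type,
        document_embeddings=document_embedding,
    )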

            document_embedding = DocumentRNNEmbeddings(rnn_type="LSTM", **embedding_params)
        else:
            document_embedding = DocumentPoolEmbeddings(**embedding_params)  # type: ignore

        text_classifier: TextClassifier = TextClassifier(
            label_dictionary=self.label_dictionary,
            multi_label=self.multi_label,
            label_type=self.label_type,
            document_embeddings=document_embedding,
        )

        return text_classifier
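And the analogous usage sketch for text classification as the API stands in these 7 commits, i.e. before the transformer refactor discussed above (corpus, label type, and search values again illustrative):

from hyperopt import hp

from flair.datasets import TREC_6
from flair.embeddings import WordEmbeddings
from flair.hyperparameter import Parameter, SearchSpace, TextClassifierParamSelector

corpus = TREC_6()  # illustrative single-label classification corpus

search_space = SearchSpace()
# DocumentRNNEmbeddings/DocumentPoolEmbeddings take a list of token embeddings
search_space.add(Parameter.EMBEDDINGS, hp.choice, options=[[WordEmbeddings("glove")]])
search_space.add(Parameter.HIDDEN_SIZE, hp.choice, options=[32, 64])
search_space.add(Parameter.LEARNING_RATE, hp.choice, options=[0.05, 0.1])

param_selector = TextClassifierParamSelector(
    corpus,
    label_type="question_class",
    multi_label=False,
    base_path="resources/results/trec",
    document_embedding_type="lstm",
    max_epochs=3,
)
param_selector.optimize(search_space, max_evals=5)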
66 changes: 66 additions & 0 deletions flair/hyperparameter/parameter.py
@@ -0,0 +1,66 @@
from enum import Enum


class Parameter(Enum):
    EMBEDDINGS = "embeddings"
    HIDDEN_SIZE = "hidden_size"
    USE_CRF = "use_crf"
    USE_RNN = "use_rnn"
    RNN_LAYERS = "rnn_layers"
    DROPOUT = "dropout"
    WORD_DROPOUT = "word_dropout"
    LOCKED_DROPOUT = "locked_dropout"
    LEARNING_RATE = "learning_rate"
    MINI_BATCH_SIZE = "mini_batch_size"
    ANNEAL_FACTOR = "anneal_factor"
    ANNEAL_WITH_RESTARTS = "anneal_with_restarts"
    PATIENCE = "patience"
    REPROJECT_WORDS = "reproject_words"
    REPROJECT_WORD_DIMENSION = "reproject_words_dimension"
    BIDIRECTIONAL = "bidirectional"
    OPTIMIZER = "optimizer"
    MOMENTUM = "momentum"
    DAMPENING = "dampening"
    WEIGHT_DECAY = "weight_decay"
    NESTEROV = "nesterov"
    AMSGRAD = "amsgrad"
    BETAS = "betas"
    EPS = "eps"


TRAINING_PARAMETERS = [
    Parameter.LEARNING_RATE.value,
    Parameter.MINI_BATCH_SIZE.value,
    Parameter.OPTIMIZER.value,
    Parameter.ANNEAL_FACTOR.value,
    Parameter.PATIENCE.value,
    Parameter.ANNEAL_WITH_RESTARTS.value,
    Parameter.MOMENTUM.value,
    Parameter.DAMPENING.value,
    Parameter.WEIGHT_DECAY.value,
    Parameter.NESTEROV.value,
    Parameter.AMSGRAD.value,
    Parameter.BETAS.value,
    Parameter.EPS.value,
]
SEQUENCE_TAGGER_PARAMETERS = [
    Parameter.EMBEDDINGS.value,
    Parameter.HIDDEN_SIZE.value,
    Parameter.RNN_LAYERS.value,
    Parameter.USE_CRF.value,
    Parameter.USE_RNN.value,
    Parameter.DROPOUT.value,
    Parameter.LOCKED_DROPOUT.value,
    Parameter.WORD_DROPOUT.value,
]
DOCUMENT_EMBEDDING_PARAMETERS = [
    Parameter.EMBEDDINGS.value,
Collaborator: one idea for the problem above would be to support param selection only for TransformerDocumentEmbeddings, define some parameters one would like to search over (which transformer model, which layers, fine-tune yes/no), and remove the RNN-specific ones.

Collaborator (Author): Did just that and refactored TextClassifierParamSelector so that it uses TransformerDocumentEmbeddings. Also renamed DOCUMENT_EMBEDDING_PARAMETERS to TEXT_CLASSIFICATION_PARAMETERS.
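Sketched out, the renamed list could look like the following (the concrete parameter names here are illustrative assumptions, not the exact follow-up commit):

class Parameter(Enum):
    # ... existing members ...
    TRANSFORMER_MODEL = "model"  # hypothetical: which transformer checkpoint to use
    LAYERS = "layers"  # hypothetical: which layers to pool
    FINE_TUNE = "fine_tune"  # hypothetical: fine-tune yes/no


TEXT_CLASSIFICATION_PARAMETERS = [
    Parameter.TRANSFORMER_MODEL.value,
    Parameter.LAYERS.value,
    Parameter.FINE_TUNE.value,
]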

    Parameter.HIDDEN_SIZE.value,
    Parameter.RNN_LAYERS.value,
    Parameter.REPROJECT_WORDS.value,
    Parameter.REPROJECT_WORD_DIMENSION.value,
    Parameter.BIDIRECTIONAL.value,
    Parameter.DROPOUT.value,
    Parameter.LOCKED_DROPOUT.value,
    Parameter.WORD_DROPOUT.value,
]
8 changes: 6 additions & 2 deletions pyproject.toml
@@ -15,7 +15,11 @@ exclude = '''
[tool.pytest.ini_options]
flake8-max-line-length = 210
flake8-ignore = ["E203", "W503"] # See https://github.com/PyCQA/pycodestyle/issues/373
addopts = "-W error --flake8 --mypy --ignore flair/data_fetcher.py --ignore flair/embeddings/legacy.py --isort"
addopts = "--flake8 --mypy --ignore flair/data_fetcher.py --ignore flair/embeddings/legacy.py --isort"
filterwarnings = [
"error", # Convert all warnings to errors
"ignore:the imp module is deprecated:DeprecationWarning:past" # ignore DeprecationWarning from hyperopt dependency
]
markers = [
"integration",
]
@@ -24,4 +28,4 @@ exclude = "flair/data_fetcher.py|flair/embeddings/legacy.py"
ignore_missing_imports = true

[tool.isort]
profile = "black"
profile = "black"
1 change: 1 addition & 0 deletions requirements.txt
@@ -8,6 +8,7 @@ mpld3==0.3
scikit-learn>=0.21.3
sqlitedict>=1.6.0
deprecated>=1.2.4
hyperopt>=0.2.7
transformers>=4.0.0
bpemb>=0.3.2
regex