Reintroduce hyperparameter tuning #2633
Conversation
This reverts commit 9aff426.
…ple training runs
Force-pushed from 3198ed5 to 88f44b2
@tadejmagajna Thanks for bringing this back! There are some issues regarding the examples though - see comments.
def _set_up_model(self, params: dict):
    embdding_params = {key: params[key] for key in params if key in DOCUMENT_EMBEDDING_PARAMETERS}

    if self.document_embedding_type == "lstm":
it seems the only options are RNN or pooled word embeddings, but the standard approach today would be to simply use a transformer, which far outperforms RNNs trained over word embeddings. Not necessarily something that needs to be fixed in this PR, but it limits the usefulness of our parameter selection for text classifiers at the moment.
As per your suggestion (which makes perfect sense - it felt really weird re-adding this outdated RNN code) I refactored TextClassifierParamSelector so that it uses TransformerDocumentEmbeddings now. I documented it to a reasonable extent in the tutorial.
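For context, a minimal sketch of how the refactored TextClassifierParamSelector might be used - the constructor argument names (label_type, multi_label, and so on) are assumptions based on this thread rather than the exact signature, so the tutorial remains the authoritative reference:

from hyperopt import hp
from flair.datasets import TREC_6
from flair.hyperparameter.param_selection import (
    SearchSpace, Parameter, TextClassifierParamSelector, OptimizationValue,
)

# load a text classification corpus
corpus = TREC_6()

# a minimal search space over training hyperparameters
search_space = SearchSpace()
search_space.add(Parameter.LEARNING_RATE, hp.choice, options=[5e-6, 1e-5, 5e-5])
search_space.add(Parameter.MINI_BATCH_SIZE, hp.choice, options=[16, 32])

# argument names below are assumptions, not the exact signature
param_selector = TextClassifierParamSelector(
    corpus,
    label_type="question_class",  # assumed label type for TREC_6
    multi_label=False,
    base_path="resources/results",
    max_epochs=10,
    optimization_value=OptimizationValue.DEV_SCORE,
)
param_selector.optimize(search_space, max_evals=10)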
flair/hyperparameter/parameter.py (outdated)
    Parameter.WORD_DROPOUT.value,
]

DOCUMENT_EMBEDDING_PARAMETERS = [
    Parameter.EMBEDDINGS.value,
one idea for the problem above would be to support param selection only for TransformerDocumentEmbeddings and define some parameters that one would like to search over (which transformer model, which layers, fine-tune yes/no) and remove the RNN-specific ones.
Did just that and refactored TextClassifierParamSelector so that it uses TransformerDocumentEmbeddings. Also renamed DOCUMENT_EMBEDDING_PARAMETERS to TEXT_CLASSIFICATION_PARAMETERS.
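For illustration, a search space along the lines suggested above might look like this - note that TRANSFORMER_MODEL, LAYERS and FINE_TUNE are hypothetical parameter names used for the sketch, not confirmed members of the Parameter enum:

from hyperopt import hp
from flair.hyperparameter.param_selection import SearchSpace, Parameter

search_space = SearchSpace()
# which transformer model to use (hypothetical parameter name)
search_space.add(Parameter.TRANSFORMER_MODEL, hp.choice,
                 options=["bert-base-uncased", "distilbert-base-uncased"])
# which layers to use for the document embedding (hypothetical parameter name)
search_space.add(Parameter.LAYERS, hp.choice, options=["-1", "-1,-2,-3,-4"])
# whether to fine-tune the transformer weights (hypothetical parameter name)
search_space.add(Parameter.FINE_TUNE, hp.choice, options=[True, False])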
from flair.datasets import TREC_6

# load your corpus
corpus = TREC_6()
The tutorial is a bit counterintuitive since it defines a text classification corpus and then word embeddings. The fact that there's an LSTM used is buried in the model initialization and not explained. Perhaps switch to a sequence labeling example? WNUT_17, for instance, could be used as the example corpus, and for sequence labeling the current implementation allows some interesting parameter searches.
I agree - makes sense to give precedence to sequence labelling in Flair. As per your suggestion I added an example to the top that uses WNUT_17 and documents the use of SequenceTaggerParamSelector. However, I also kept the example that shows how to use TextClassifierParamSelector. I feel that if a feature isn't documented, people are unlikely to dig through the source, and the value of the feature will be diminished. So what I did was create two sections (Selecting hyperparameters for sequence labelling and Selecting hyperparameters for text classification) and document both use cases accordingly.
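A condensed sketch of what the sequence-labelling section might contain, following the pre-existing SequenceTaggerParamSelector API (the values chosen here are illustrative):

from hyperopt import hp
from flair.datasets import WNUT_17
from flair.embeddings import WordEmbeddings
from flair.hyperparameter.param_selection import (
    SearchSpace, Parameter, SequenceTaggerParamSelector, OptimizationValue,
)

# load the corpus suggested above
corpus = WNUT_17()

# define the search space over sequence-labelling hyperparameters
search_space = SearchSpace()
search_space.add(Parameter.EMBEDDINGS, hp.choice, options=[[WordEmbeddings("glove")]])
search_space.add(Parameter.HIDDEN_SIZE, hp.choice, options=[128, 256])
search_space.add(Parameter.DROPOUT, hp.uniform, low=0.0, high=0.5)
search_space.add(Parameter.LEARNING_RATE, hp.choice, options=[0.05, 0.1, 0.2])

# create the selector for NER and run the search
param_selector = SequenceTaggerParamSelector(
    corpus,
    "ner",
    "resources/results",
    max_epochs=10,
    training_runs=1,
    optimization_value=OptimizationValue.DEV_SCORE,
)
param_selector.optimize(search_space, max_evals=10)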
tag_type = 'ner'

# 3. make the tag dictionary from the corpus
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
should be tag_dictionary = corpus.make_label_dictionary(label_type=tag_type), otherwise the example won't run.
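That is, the corrected tutorial step would read:

tag_type = 'ner'

# 3. make the label dictionary from the corpus
tag_dictionary = corpus.make_label_dictionary(label_type=tag_type)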
trainer: ModelTrainer = ModelTrainer(tagger, corpus)

# 7. find learning rate
learning_rate_tsv = trainer.find_learning_rate('resources/taggers/example-ner', Adam)
this function is broken; it should be changed to the code below, ideally as part of this PR:
def find_learning_rate(
    self,
    base_path: Union[Path, str],
    optimizer,
    file_name: str = "learning_rate.tsv",
    start_learning_rate: float = 1e-7,
    end_learning_rate: float = 10,
    iterations: int = 100,
    mini_batch_size: int = 32,
    stop_early: bool = True,
    smoothing_factor: float = 0.98,
    **kwargs,
) -> Path:
    best_loss = None
    moving_avg_loss = 0

    # cast string to Path
    if type(base_path) is str:
        base_path = Path(base_path)
    learning_rate_tsv = init_output_file(base_path, file_name)

    with open(learning_rate_tsv, "a") as f:
        f.write("ITERATION\tTIMESTAMP\tLEARNING_RATE\tTRAIN_LOSS\n")

    optimizer = optimizer(self.model.parameters(), lr=start_learning_rate, **kwargs)

    train_data = self.corpus.train

    scheduler = ExpAnnealLR(optimizer, end_learning_rate, iterations)

    model_state = self.model.state_dict()
    self.model.train()

    step = 0
    while step < iterations:
        batch_loader = DataLoader(train_data, batch_size=mini_batch_size, shuffle=True)
        for batch in batch_loader:
            step += 1

            # forward pass
            loss = self.model.forward_loss(batch)
            if isinstance(loss, tuple):
                loss = loss[0]

            # update optimizer and scheduler
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
            optimizer.step()
            scheduler.step(step)

            learning_rate = scheduler.get_lr()[0]

            # exponentially smooth the loss and track the best value seen so far
            loss_item = loss.item()
            if step == 1:
                best_loss = loss_item
            else:
                if smoothing_factor > 0:
                    moving_avg_loss = smoothing_factor * moving_avg_loss + (1 - smoothing_factor) * loss_item
                    loss_item = moving_avg_loss / (1 - smoothing_factor ** (step + 1))
                if loss_item < best_loss:
                    best_loss = loss_item

            if step > iterations:
                break

            if stop_early and (loss_item > 4 * best_loss or torch.isnan(loss)):
                log_line(log)
                log.info("loss diverged - stopping early!")
                step = iterations
                break

            with open(str(learning_rate_tsv), "a") as f:
                f.write(f"{step}\t{datetime.datetime.now():%H:%M:%S}\t{learning_rate}\t{loss_item}\n")

    # restore the original model weights after the sweep
    self.model.load_state_dict(model_state)
    self.model.to(flair.device)

    log_line(log)
    log.info(f"learning rate finder finished - plot {learning_rate_tsv}")
    log_line(log)

    return Path(learning_rate_tsv)
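As a usage note: the TSV returned by this function is meant to be plotted to pick a learning rate, e.g. with Flair's Plotter - a minimal sketch, assuming the trainer from the tutorial snippet above:

from flair.visual.training_curves import Plotter

# run the learning rate finder and plot loss against learning rate
learning_rate_tsv = trainer.find_learning_rate("resources/taggers/example-ner", Adam)
Plotter().plot_learning_rate(learning_rate_tsv)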
I addressed this as part of this PR, but note that I made a tiny adjustment to the source code you pasted above. With the above code, we were getting the following warning from Torch:

UserWarning: The epoch parameter in scheduler.step() was not necessary and is being deprecated where possible. Please use scheduler.step() to step the scheduler.

This warning broke the CI build. So I replaced the line scheduler.step(step) with scheduler.step(). If my understanding is correct, this shouldn't badly affect the execution of the code. But if it does, please do let me know.
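Concretely, the one-line change relative to the code pasted above:

# before: passing the step index explicitly triggers the PyTorch deprecation warning
scheduler.step(step)

# after: the scheduler advances its own internal step counter
scheduler.step()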
…lector and applying PR suggestions
@alanakbik Thank you for the feedback! I addressed your PR comments and incorporated all of the suggestions, including the refactoring of TextClassifierParamSelector.
@alanakbik I made …
@tadejmagajna thanks a lot - looks great!
This PR reintroduces hyperparameter tuning in Flair.
closes #2632
A number of small changes were required, so I split this PR into appropriately named commits, where each commit is an isolated change to the code base. For example, there is a separate commit that reverts the deletion of hyperparameter tuning in Flair - this helps us quickly see which code changes are new and which simply reintroduce the old code.
The changes and their corresponding commits include:
- make_label_dictionary(label_type) requiring the label_type argument
- training_runs > 1
- pyproject.toml, where I swapped the -W parameter for the more-readable filterwarnings parameter and made it ignore that specific DeprecationWarning

Please do let me know if there are any other parts of the hyperparameter tuning code I forgot to make compliant with the current Flair syntax.
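For illustration, the pyproject.toml change described in the last bullet might look roughly like this - a sketch, not the actual diff, and the warning filter string is an assumption based on the warning quoted earlier in the thread:

[tool.pytest.ini_options]
filterwarnings = [
    # ignore the PyTorch scheduler deprecation warning (pattern is illustrative)
    "ignore:The epoch parameter in:UserWarning",
]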