
Reintroduce hyperparameter tuning #2633

Merged

Conversation

@tadejmagajna (Collaborator) commented on Feb 14, 2022:

This PR reintroduces hyperparameter tuning in Flair
closes #2632

A number of small changes were required, so I split this PR into appropriately named commits, each an isolated change to the code base. For example, there is a separate commit that reverts the deletion of the hyperparameter tuning code in Flair; this helps us quickly see which code changes are new and which simply reintroduce the old code.

The changes and their corresponding commits include:

  1. Revert "Removes hyperparameter features": reverts the removal of all hyperparameter tuning code.
  2. Updating the param selection docs for the v0.10 syntax: updates the README tutorial in English and Korean (I did my best with the Korean) to fit the new v0.10-compliant syntax.
  3. Adding hyperopt back to requirements.txt: adds the most recent version of hyperopt back to the requirements.
  4. Fixing param selection code to work with changes in Flair v0.10: updates the original hyperparameter code to work with changes introduced in v0.10, such as make_label_dictionary() now requiring the label_type argument (see the first sketch after this list).
  5. Fixing a bug where embeddings got added twice on multiple training runs: provides a temporary solution for #2600 ("StackedEmbeddings modifies original embeddings' name which causes unwanted behaviour and potential crashes"), where embeddings were added to tokens multiple times and caused crashes (reproducible by running the TUTORIAL_8 code). Note that the fix probably slows execution down a bit, but it only affects runs where training_runs > 1.
  6. Enabling and fixing tests for param selection: reintroduces the hyperparameter tests, applies some fixes to them, and enables them again (they were previously disabled for some reason). There was also a deprecation warning related to a dependency of Hyperopt; to fix it, I changed pyproject.toml to swap the -W parameter for the more readable filterwarnings parameter and made it ignore that specific DeprecationWarning (see the second sketch after this list).
  7. Automatic code formatting of the param selection tests.
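To illustrate points 4 and 6, here are two minimal sketches. They are illustrative only; the exact lines in the PR may differ.

The v0.10 API change from point 4, assuming a corpus object is already loaded and using "ner" as an example label type:

    # Flair v0.10+: make_label_dictionary() requires an explicit label_type
    label_dictionary = corpus.make_label_dictionary(label_type="ner")

The pyproject.toml change from point 6 (the module pattern here is a hypothetical example, not necessarily the exact filter used in the PR):

    [tool.pytest.ini_options]
    filterwarnings = [
        # ignore the DeprecationWarning raised by a hyperopt dependency,
        # rather than passing -W on the pytest command line
        "ignore::DeprecationWarning:hyperopt.*",
    ]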

Please do let me know if there are any other parts of the hyperparameter tuning code that I forgot to make compliant with the current Flair syntax.

@tadejmagajna force-pushed the GH-2632-reintroduce-hyperparameter-tuning branch from 3198ed5 to 88f44b2 on February 14, 2022 21:02
@alanakbik (Collaborator) left a comment:

@tadejmagajna Thanks for bringing this back! There are some issues with the examples, though; see the comments below.

def _set_up_model(self, params: dict):
embedding_params = {key: params[key] for key in params if key in DOCUMENT_EMBEDDING_PARAMETERS}

if self.document_embedding_type == "lstm":
@alanakbik (Collaborator) commented:

It seems the only options are RNN or pooled word embeddings, but the standard approach today would be to simply use a transformer, which far outperforms RNNs trained over word embeddings. Not necessarily something that needs to be fixed in this PR, but it limits the usefulness of our parameter selection for text classifiers at the moment.

@tadejmagajna (Collaborator, Author) replied:

As per your suggestion (which makes perfect sense; it felt really weird re-adding this outdated RNN code), I refactored TextClassifierParamSelector so that it uses TransformerDocumentEmbeddings now. I documented it to a reasonable extent in the tutorial.

Parameter.WORD_DROPOUT.value,
]
DOCUMENT_EMBEDDING_PARAMETERS = [
Parameter.EMBEDDINGS.value,
@alanakbik (Collaborator) commented:

One idea for the problem above would be to support param selection only for TransformerDocumentEmbeddings, define some parameters that one would like to search over (which transformer model, which layers, fine-tune yes/no), and remove the RNN-specific ones.

@tadejmagajna (Collaborator, Author) replied:

Did just that and refactored TextClassifierParamSelector so that it uses TransformerDocumentEmbeddings. Also renamed DOCUMENT_EMBEDDING_PARAMETERS to TEXT_CLASSIFICATION_PARAMETERS. A sketch of the resulting search space is below.
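For illustration, a minimal sketch of what a transformer-oriented search space could look like after this refactor. The parameter names (in particular Parameter.TRANSFORMER_MODEL) and option values are assumptions based on this thread, not verbatim from the PR:

    from hyperopt import hp
    from flair.hyperparameter.param_selection import SearchSpace, Parameter

    # search over transformer-specific knobs (which model) plus the usual
    # training parameters (learning rate, batch size)
    search_space = SearchSpace()
    search_space.add(
        Parameter.TRANSFORMER_MODEL,  # assumed parameter name
        hp.choice,
        options=["distilbert-base-uncased", "bert-base-uncased"],
    )
    search_space.add(Parameter.LEARNING_RATE, hp.choice, options=[5e-6, 1e-5, 5e-5])
    search_space.add(Parameter.MINI_BATCH_SIZE, hp.choice, options=[16, 32])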

from flair.datasets import TREC_6

# load your corpus
corpus = TREC_6()
@alanakbik (Collaborator) commented:

The tutorial is a bit counterintuitive since it defines a text classification corpus and then word embeddings. The fact that an LSTM is used is buried in the model initialization and not explained.

Perhaps switch to a sequence labeling example? WNUT_17 for instance could be used as example corpus, and for sequence labeling the current implementation allows some interesting parameter searches.

@tadejmagajna (Collaborator, Author) replied:

I agree; it makes sense to give precedence to sequence labelling in Flair.

As per your suggestion, I added an example to the top that uses WNUT_17 and documents the use of SequenceTaggerParamSelector.

However, I also kept the example that shows how to use TextClassifierParamSelector. I feel that if a feature isn't documented, people are unlikely to dig through the source, and the value of the feature is diminished.

So I created two sections ("Selecting hyperparameters for sequence labelling" and "Selecting hyperparameters for text classification") and documented both use cases accordingly. A rough sketch of the new sequence labelling section follows.
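A sketch of what the sequence labelling section can look like, assuming the import path and SequenceTaggerParamSelector signature as reintroduced by this PR (the exact tutorial wording and option values may differ):

    from hyperopt import hp
    from flair.datasets import WNUT_17
    from flair.embeddings import FlairEmbeddings, WordEmbeddings
    from flair.hyperparameter.param_selection import (
        Parameter,
        SearchSpace,
        SequenceTaggerParamSelector,
    )

    # load the example corpus
    corpus = WNUT_17()

    # define a search space over embeddings and training parameters
    search_space = SearchSpace()
    search_space.add(
        Parameter.EMBEDDINGS,
        hp.choice,
        options=[
            [WordEmbeddings("glove")],
            [FlairEmbeddings("news-forward"), FlairEmbeddings("news-backward")],
        ],
    )
    search_space.add(Parameter.HIDDEN_SIZE, hp.choice, options=[128, 256])
    search_space.add(Parameter.LEARNING_RATE, hp.choice, options=[0.05, 0.1])
    search_space.add(Parameter.MINI_BATCH_SIZE, hp.choice, options=[16, 32])

    # create the selector for NER and run a bounded number of evaluations
    param_selector = SequenceTaggerParamSelector(corpus, "ner", "resources/results", max_epochs=10)
    param_selector.optimize(search_space, max_evals=50)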

tag_type = 'ner'

# 3. make the tag dictionary from the corpus
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
@alanakbik (Collaborator) commented:

should be tag_dictionary = corpus.make_label_dictionary(label_type=tag_type), otherwise the example won't run

trainer: ModelTrainer = ModelTrainer(tagger, corpus)

# 7. find learning rate
learning_rate_tsv = trainer.find_learning_rate('resources/taggers/example-ner', Adam)
@alanakbik (Collaborator) commented:

This function is broken; it should be changed to the code below, ideally as part of this PR.

    def find_learning_rate(
        self,
        base_path: Union[Path, str],
        optimizer,
        file_name: str = "learning_rate.tsv",
        start_learning_rate: float = 1e-7,
        end_learning_rate: float = 10,
        iterations: int = 100,
        mini_batch_size: int = 32,
        stop_early: bool = True,
        smoothing_factor: float = 0.98,
        **kwargs,
    ) -> Path:
        best_loss = None
        moving_avg_loss = 0

        # cast string to Path
        if type(base_path) is str:
            base_path = Path(base_path)
        learning_rate_tsv = init_output_file(base_path, file_name)

        with open(learning_rate_tsv, "a") as f:
            f.write("ITERATION\tTIMESTAMP\tLEARNING_RATE\tTRAIN_LOSS\n")

        optimizer = optimizer(self.model.parameters(), lr=start_learning_rate, **kwargs)

        train_data = self.corpus.train

        scheduler = ExpAnnealLR(optimizer, end_learning_rate, iterations)

        model_state = self.model.state_dict()
        self.model.train()

        step = 0
        while step < iterations:
            batch_loader = DataLoader(train_data, batch_size=mini_batch_size, shuffle=True)
            for batch in batch_loader:
                step += 1

                # forward pass
                loss = self.model.forward_loss(batch)
                if isinstance(loss, tuple):
                    loss = loss[0]

                # update optimizer and scheduler
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
                optimizer.step()
                scheduler.step(step)

                learning_rate = scheduler.get_lr()[0]

                loss_item = loss.item()
                if step == 1:
                    best_loss = loss_item
                else:
                    if smoothing_factor > 0:
                        moving_avg_loss = smoothing_factor * moving_avg_loss + (1 - smoothing_factor) * loss_item
                        loss_item = moving_avg_loss / (1 - smoothing_factor ** (step + 1))
                    if loss_item < best_loss:
                        best_loss = loss_item

                if step > iterations:
                    break

                if stop_early and (loss_item > 4 * best_loss or torch.isnan(loss)):
                    log_line(log)
                    log.info("loss diverged - stopping early!")
                    step = iterations
                    break

                with open(str(learning_rate_tsv), "a") as f:
                    f.write(f"{step}\t{datetime.datetime.now():%H:%M:%S}\t{learning_rate}\t{loss_item}\n")

            self.model.load_state_dict(model_state)
            self.model.to(flair.device)

        log_line(log)
        log.info(f"learning rate finder finished - plot {learning_rate_tsv}")
        log_line(log)

        return Path(learning_rate_tsv)

@tadejmagajna (Collaborator, Author) replied:

I addressed this as part of this PR, but note that I made a tiny adjustment to the source code you pasted above. With the above code, we were getting the following warning from Torch:

UserWarning: The epoch parameter in scheduler.step() was not necessary and is being deprecated where possible. Please use scheduler.step() to step the scheduler.

This warning broke the CI build.

So I replaced the line scheduler.step(step) with scheduler.step(), as shown below. If my understanding is correct, this shouldn't badly affect the execution of the code, but if it does, please do let me know.
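For reference, the one-line change in the loop above:

    # before: passing the step count triggers PyTorch's deprecation warning
    scheduler.step(step)

    # after: rely on the scheduler's internal step counter
    scheduler.step()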

@tadejmagajna (Collaborator, Author) commented:

@alanakbik Thank you for the feedback!

I addressed your PR comments and incorporated all of the suggestions, including refactoring TextClassifierParamSelector to use TransformerDocumentEmbeddings (instead of an RNN) and fixing find_learning_rate(). I left more detailed comments in each individual feedback thread.

@tadejmagajna requested a review from @alanakbik on March 7, 2022 01:36
@tadejmagajna (Collaborator, Author) commented:

@alanakbik I made fine_tune a normal (non-tunable) param now and set it to default to True. Updated all the docs as well.

@alanakbik (Collaborator) commented:

@tadejmagajna thanks a lot - looks great!

@alanakbik merged commit 872e66f into flairNLP:master on Mar 15, 2022.