Skip to content

Official implementation of Auxiliary Learning by Implicit Differentiation [ICLR 2021]

License

Notifications You must be signed in to change notification settings

AvivNavon/AuxiLearn

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

30 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

AuxiLearn - Auxiliary Learning by Implicit Differentiation

This repository contains the source code to support the paper Auxiliary Learning by Implicit Differentiation, by Aviv Navon*, Idan Achituve*, Haggai Maron, Gal Chechik and Ethan Fetaya, ICLR 2021.


Links

  1. Paper
  2. Project page

Installation

Please note: We encountered some issues and drops in performance while working with different PyTorch versions. Please install AuxiLearn on a clean virtual environment!

python3 -m venv <venv>
source <venv>/bin/activate

On a clean virtual environment clone the repo and install:

git clone https://github.com/AvivNavon/AuxiLearn.git
cd AuxiLearn
pip install .

Usage

Given a bi-level optimization problem in which the upper-level parameters (i.e., auxiliary parameters) are only implicitly affecting the upper-level objective, you can use auxilearn to compute the upper-level gradients through implicit differentiation.

The main code component you will need to use is auxilearn.optim.MetaOptimizer. It is a wrapper over PyTorch optimizers that updates its parameters through implicit differentiation.

Code example

We assume two models, primary_model and auxiliary_model, and two dataloaders. The primary_model is optimized using the train data in the train_loader, and the auxiliary_model is optimized using the auxiliary set in the aux_loader. We assume a loss_fuction that return the train loss if train=True, or auxiliary set loss if train=False. Also, we assume the training loss is a function of both the primary parameters and the auxiliary parameters, and that the loss on the auxiliary set (or validation set) is a function of the primary parameters only. In Auxiliary Learning, the auxiliary set loss is the loss on the main task (see paper for more details).

from auxilearn.optim import MetaOptimizer

primary_model = MyModel()
auxiliary_model = MyAuxiliaryModel()
# optimizers
primary_optimizer = torch.optim.Adam(primary_model.parameters())

aux_lr = 1e-4
aux_base_optimizer = torch.optim.Adam(auxiliary_model.parameters(), lr=aux_lr)
aux_optimizer = MetaOptimizer(aux_base_optimizer, hpo_lr=aux_lr)

# training loop
step = 0
for epoch in range(epochs):
    for batch in train_loder:
        step += 1
        # calculate batch loss using 'primary_model' and 'auxiliary_model'
        primary_optimizer.zero_grad()
        loss = loss_func(train=True)
        # update primary parameters
        loss.backward()
        primary_optimizer.step()
        
        # condition for updating auxiliary parameters
        if step % aux_params_update_every == 0:
            # calc current train loss
            train_set_loss = loss_func(train=True)
            # calc current auxiliary set loss - this is the loss over the main task
            auxiliary_set_loss = loss_func(train=False) 
            
            # update auxiliary parameters - no need to call loss.backwards() or aux_optimizer.zero_grad()
            aux_optimizer.step(
                val_loss=auxiliary_set_loss,
                train_loss=train_set_loss,
                aux_params=auxiliary_model.parameters(),
                parameters=primary_model.parameters(),
            )

Citation

If you find auxilearn to be useful in your own research, please consider citing the following paper:

@inproceedings{
navon2021auxiliary,
title={Auxiliary Learning by Implicit Differentiation},
author={Aviv Navon and Idan Achituve and Haggai Maron and Gal Chechik and Ethan Fetaya},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=n7wIfYPdVet}
}

Releases

No releases published

Packages

No packages published

Contributors 3

  •  
  •  
  •  

Languages