forked from shengliu66/ELR

Official Implementation of Early-Learning Regularization Prevents Memorization of Noisy Labels


RShevcheCape/ELR

 
 


Early-Learning Regularization Prevents Memorization of Noisy Labels

This repository is the official implementation of Early-Learning Regularization Prevents Memorization of Noisy Labels (NeurIPS 2020).

We propose a novel framework to perform classification via deep learning in the presence of noisy annotations. When trained on noisy labels, deep neural networks have been observed to first fit the training data with clean labels during an early learning phase, before eventually memorizing the examples with false labels. Our technique exploits the progress of the early learning phase via regularization to perform classification from noisy labels. There are two key elements to our approach. First, we leverage semi-supervised learning techniques to produce target probabilities based on the model outputs. Second, we design a regularization term that steers the model towards these targets, implicitly preventing memorization of the false labels. The resulting framework is shown to provide robustness to noisy annotations on several standard benchmarks and real-world datasets, where it achieves results comparable to the state of the art.

These graphs show the results of training a ResNet-34 with a traditional cross entropy loss (top row) and our proposed method (bottom row) to perform classification on the CIFAR-10 dataset where 40% of the labels are flipped at random. The left column shows the fraction of examples with clean labels that are predicted correctly (green) and incorrectly (blue). The right column shows the fraction of examples with wrong labels that are predicted correctly (green), memorized (the prediction equals the wrong label, shown in red), and incorrectly predicted as neither the true nor the labeled class (blue). The model trained with cross entropy begins by learning to predict the true labels, even for many of the examples with wrong labels, but eventually memorizes the wrong labels. Our proposed method based on early-learning regularization prevents memorization, allowing the model to continue learning on the examples with clean labels to attain high accuracy on examples with both clean and wrong labels.
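The quantities plotted in these graphs can be computed from three per-example arrays: the model's predictions, the (possibly noisy) training labels, and the true labels. The following is a minimal sketch of that bookkeeping, not code from this repository; the function name `memorization_fractions` and the dictionary keys are hypothetical.

```python
def memorization_fractions(pred, noisy_label, true_label):
    """Return the fractions described above for one evaluation pass.

    All three arguments are equal-length sequences of class indices.
    Function name and keys are illustrative, not part of this repository.
    """
    clean = [n == t for n, t in zip(noisy_label, true_label)]
    n_clean = sum(clean)
    n_wrong = len(clean) - n_clean
    counts = {"clean_correct": 0, "clean_incorrect": 0,
              "wrong_correct": 0, "wrong_memorized": 0, "wrong_other": 0}
    for p, n, t, is_clean in zip(pred, noisy_label, true_label, clean):
        if is_clean:
            counts["clean_correct" if p == t else "clean_incorrect"] += 1
        elif p == t:
            counts["wrong_correct"] += 1      # true class recovered
        elif p == n:
            counts["wrong_memorized"] += 1    # prediction equals the wrong label
        else:
            counts["wrong_other"] += 1        # neither true nor labeled class
    return {
        k: v / (n_clean if k.startswith("clean") else n_wrong)
        if (n_clean if k.startswith("clean") else n_wrong) else float("nan")
        for k, v in counts.items()
    }


# Toy data: the first three labels are clean, the last three are flipped.
fractions = memorization_fractions(
    pred=[0, 1, 0, 0, 2, 1],
    noisy_label=[0, 1, 2, 1, 2, 0],
    true_label=[0, 1, 2, 0, 1, 2],
)
```

Tracking these fractions once per epoch gives the curves in the left column (clean labels) and the right column (wrong labels).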

Requirements

  • This codebase is written for Python 3.
  • To install the required Python packages, run pip install -r requirements.txt.

Training

Basics

  • ELR loss is implemented in the file loss.py
  • All functions used for training the basic version of our technique (ELR) can be found in the ELR folder.
  • All functions used for training the more advanced version (ELR+) can be found in the ELR_plus folder.
  • Experiment settings and configurations used for different datasets are in the corresponding config json files.

Data

  • Please download the data before running the code, then add the path to the downloaded data to data_loader.args.data_dir in the corresponding config file.
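The data path can be edited by hand, or set programmatically. The sketch below assumes only what is stated above, namely that the config is a JSON file with a data_loader.args.data_dir entry; the helper names and the example paths are hypothetical.

```python
import copy
import json


def with_data_dir(config, data_dir):
    """Return a copy of a loaded config dict with data_loader.args.data_dir set."""
    updated = copy.deepcopy(config)
    updated.setdefault("data_loader", {}).setdefault("args", {})["data_dir"] = data_dir
    return updated


def set_data_dir(config_path, data_dir):
    """Rewrite the JSON config file at `config_path` in place."""
    with open(config_path) as f:
        config = json.load(f)
    with open(config_path, "w") as f:
        json.dump(with_data_dir(config, data_dir), f, indent=4)


# In-memory example; "/path/to/data" is a placeholder, not a real location.
example = {"data_loader": {"args": {"data_dir": "", "batch_size": 128}}}
patched = with_data_dir(example, "/path/to/data")
```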

Training

  • Code for training ELR is in train.py in the ELR folder; code for training ELR+ is in train.py in the ELR_plus folder.
usage: train.py [-c] [-r] [-d] [--lr learning_rate] [--bs batch_size] [--beta beta] [--lambda lambda] [--malpha mixup_alpha]
                [--percent percent] [--asym asym] [--ealpha ema_alpha]  [--name exp_name] 

  arguments:
    -c, --config                  config file path (default: None)
    -r, --resume                  path to latest checkpoint (default: None)
    -d, --device                  indices of GPUs to enable (default: all)     
  
  options:
    --lr learning_rate            learning rate (default value is the value in the config file)
    --bs batch_size               batch size (default value is the value in the config file)
    --beta beta                   temporal ensembling momentum beta for target estimation
    --lambda lambda               regularization coefficient
    --malpha mixup_alpha          mixup parameter alpha
    --percent percent             noise level (e.g. 0.4 for 40%)
    --asym asym                   asymmetric noise is used when set to True
    --ealpha ema_alpha            weight averaging momentum for target estimation
    --name exp_name               experiment name

A configuration file must be specified. Any option that is not set on the command line defaults to the value in the configuration file. Examples for ELR and ELR+ are shown in the readme.md of the ELR and ELR_plus subfolders, respectively.
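As an illustration of how the options combine, the sketch below assembles a command line from the flags listed in the usage text. It is a convenience helper, not part of this repository: `build_train_command` and the config file name `config_cifar10.json` are hypothetical, and the command is assumed to be run from the folder that contains the relevant train.py.

```python
def build_train_command(config, **options):
    """Build an argv list for train.py from the documented long options."""
    # Long flags copied from the usage text; anything else is rejected.
    known = {"lr", "bs", "beta", "lambda", "malpha",
             "percent", "asym", "ealpha", "name"}
    cmd = ["python", "train.py", "-c", config]
    for key, value in options.items():
        if key not in known:
            raise ValueError(f"unknown option: {key}")
        cmd += [f"--{key}", str(value)]
    return cmd


# `lambda` is a Python keyword, so it is passed through a dict here.
cmd = build_train_command(
    "config_cifar10.json",  # hypothetical config file name
    **{"lambda": 3, "beta": 0.7, "percent": 0.4},
)
print(" ".join(cmd))
```

The resulting list can be passed to `subprocess.run` or pasted into a shell after joining.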

Example

To use our proposed early-learning regularization (ELR), you can simply replace your loss function with the following loss function. The coefficient lambda, which controls the strength of the regularization term, usually needs to be tuned carefully, whereas results are often quite robust to the value of beta (e.g. 0.7, 0.9, or 0.99).
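In symbols, the loss function given in this section can be read as follows. This is a transcription of the code rather than a quotation from the paper: z_i denotes the logits for example i, y_i its (possibly noisy) label, B the mini-batch, and t_i the running target stored for example i, initialized to zero.

```latex
\begin{align*}
\hat{p}_i &= \operatorname{clamp}\bigl(\operatorname{softmax}(z_i),\; 10^{-4},\; 1 - 10^{-4}\bigr) \\
t_i &\leftarrow \beta\, t_i + (1 - \beta)\, \frac{\hat{p}_i}{\lVert \hat{p}_i \rVert_1}
  \qquad \text{(no gradient flows through } t_i\text{)} \\
\mathcal{L} &= \frac{1}{|B|} \sum_{i \in B} \mathrm{CE}(z_i, y_i)
  \;+\; \lambda \cdot \frac{1}{|B|} \sum_{i \in B} \log\bigl(1 - \langle t_i, \hat{p}_i \rangle\bigr)
\end{align*}
```

Minimizing the second term pushes the inner product between the prediction and its running target towards 1, which is what steers the model towards the targets.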

import torch
import torch.nn as nn
import torch.nn.functional as F


class elr_loss(nn.Module):
    def __init__(self, num_examp, num_classes=10, lambda_=3, beta=0.7):
        r"""Early Learning Regularization.
        Parameters
        * `num_examp` Total number of training examples.
        * `num_classes` Number of classes in the classification problem.
        * `lambda_` Regularization strength; must be a positive float, controlling the strength of the ELR.
          (Named `lambda_` because `lambda` is a reserved word in Python.)
        * `beta` Temporal ensembling momentum for target estimation.
        """
        super(elr_loss, self).__init__()
        self.num_classes = num_classes
        self.USE_CUDA = torch.cuda.is_available()
        # One running target distribution per training example, initialized to zero.
        self.target = torch.zeros(num_examp, self.num_classes).cuda() if self.USE_CUDA else torch.zeros(num_examp, self.num_classes)
        self.beta = beta
        self.lambda_ = lambda_

    def forward(self, index, output, label):
        r"""Early Learning Regularization.
        Args
        * `index` Training sample index; because the training set is shuffled, the index is used to track training examples across iterations.
        * `output` Model's logits, same as PyTorch provided loss functions.
        * `label` Labels, same as PyTorch provided loss functions.
        """
        y_pred = F.softmax(output, dim=1)
        y_pred = torch.clamp(y_pred, 1e-4, 1.0 - 1e-4)
        # Detach so that no gradient flows through the target update.
        y_pred_ = y_pred.detach()
        # Temporal ensembling: moving average of the renormalized predictions.
        self.target[index] = self.beta * self.target[index] + (1 - self.beta) * (y_pred_ / y_pred_.sum(dim=1, keepdim=True))
        ce_loss = F.cross_entropy(output, label)
        # Regularizer: log(1 - <target, prediction>), averaged over the batch.
        elr_reg = ((1 - (self.target[index] * y_pred).sum(dim=1)).log()).mean()
        final_loss = ce_loss + self.lambda_ * elr_reg
        return final_loss
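Because the loss takes the training sample index as its first argument, the data loader has to yield that index together with each input and label. The sketch below shows one way to wire this up on synthetic data. It is an assumption-laden illustration, not the training code from this repository: `IndexedDataset`, `ELRLoss` (a compact restatement of the loss in this section, with `lam` standing in for lambda), and `train_one_epoch` are hypothetical names, and everything runs on the CPU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset


class IndexedDataset(Dataset):
    """Hypothetical wrapper that returns (index, input, label)."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        x, y = self.base[i]
        return i, x, y


class ELRLoss(nn.Module):
    """Compact restatement of the ELR loss; `lam` stands in for lambda."""

    def __init__(self, num_examp, num_classes, lam=3.0, beta=0.7):
        super().__init__()
        # Buffer so the targets follow the module across devices.
        self.register_buffer("target", torch.zeros(num_examp, num_classes))
        self.lam, self.beta = lam, beta

    def forward(self, index, output, label):
        y_pred = torch.clamp(F.softmax(output, dim=1), 1e-4, 1.0 - 1e-4)
        y_det = y_pred.detach()
        self.target[index] = (self.beta * self.target[index]
                              + (1 - self.beta) * y_det / y_det.sum(dim=1, keepdim=True))
        reg = (1 - (self.target[index] * y_pred).sum(dim=1)).log().mean()
        return F.cross_entropy(output, label) + self.lam * reg


def train_one_epoch(model, loader, criterion, optimizer):
    """Run one pass over the data and return the mean batch loss."""
    model.train()
    total = 0.0
    for index, x, y in loader:
        optimizer.zero_grad()
        loss = criterion(index, model(x), y)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)


# Synthetic stand-in for a real dataset: 64 examples, 8 features, 4 classes.
torch.manual_seed(0)
data = TensorDataset(torch.randn(64, 8), torch.randint(0, 4, (64,)))
loader = DataLoader(IndexedDataset(data), batch_size=16, shuffle=True)
model = nn.Linear(8, 4)
criterion = ELRLoss(num_examp=len(data), num_classes=4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
avg_loss = train_one_epoch(model, loader, criterion, optimizer)
```

The index must refer to a fixed position in the training set, so the wrapper is applied to the dataset before shuffling is handled by the loader. The loss object has to be created once and reused across epochs, since it holds the running targets.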

License and Contributing

  • This README is formatted based on paperswithcode.
  • Feel free to post issues via GitHub.

Reference

For technical details and full experimental results, please check our paper.

@article{liu2020early,
  title={Early-Learning Regularization Prevents Memorization of Noisy Labels},
  author={Liu, Sheng and Niles-Weed, Jonathan and Razavian, Narges and Fernandez-Granda, Carlos},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  year={2020}
}

Contact

Please contact shengliu@nyu.edu if you have any questions about the code.
