
Understanding the interplay between memorization and generalization in neural networks, featuring MAT, a learning algorithm to enhance robustness by mitigating spurious correlations.


facebookresearch/Pitfalls-of-Memorization


The Pitfalls of Memorization: When Memorization Hurts Generalization


This repository contains the code associated with the paper:
The Pitfalls of Memorization: When Memorization Hurts Generalization
Authors: Reza Bayat*, Mohammad Pezeshki*, Elvis Dohmatob, David Lopez-Paz, Pascal Vincent

We explore the interplay between memorization and generalization in neural networks. The repository includes Memorization-Aware Training (MAT), a novel framework for mitigating the adverse effects of memorization and spurious correlations, together with the theoretical insights, algorithms, and experiments that deepen our understanding of how memorization affects generalization under distribution shifts.
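This README does not spell out MAT's training objective. As a rough, hypothetical illustration of how held-out predictions could be used to shift a model's logits, the sketch below adds tau * log(p_heldout) to each logit before a standard softmax cross-entropy, in the style of logit adjustment. The function names, the tau knob, and the exact form of the shift are assumptions for illustration, not the repository's implementation; the paper defines MAT's actual loss.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax for a single example."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def shifted_cross_entropy(
    logits: Sequence[float],
    heldout_probs: Sequence[float],
    label: int,
    tau: float = 1.0,
    eps: float = 1e-12,
) -> float:
    """Cross-entropy on logits shifted by tau * log(held-out class probabilities).

    HYPOTHETICAL: a logit-adjustment-style illustration, not MAT's exact loss.
    `heldout_probs` stands in for phase-1 held-out predictions; `tau` is an
    assumed scaling knob.
    """
    shifted = [z + tau * math.log(p + eps) for z, p in zip(logits, heldout_probs)]
    return -log_softmax(shifted)[label]


# Toy usage: an untrained model (all-zero logits) sees two examples of class 0.
logits = [0.0, 0.0]
easy = shifted_cross_entropy(logits, heldout_probs=[0.9, 0.1], label=0)  # held-out model was right
hard = shifted_cross_entropy(logits, heldout_probs=[0.1, 0.9], label=0)  # held-out model was wrong
plain = shifted_cross_entropy(logits, heldout_probs=[0.9, 0.1], label=0, tau=0.0)  # no shift
print(f"easy={easy:.3f} hard={hard:.3f} plain={plain:.3f}")  # easy=0.105 hard=2.303 plain=0.693
```

Under this assumed form, examples that the held-out model already classifies correctly contribute little loss and gradient, so training emphasizes the examples whose held-out predictions are wrong. Setting tau=0 recovers ordinary cross-entropy.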

The Interpretable Experiment (Figure 1)

python interpretable_experiment.py

Memorization: The Good, the Bad, and the Ugly (Figure 3)

python good_bad_ugly_memorization.py

Subpopulation Shift Experiments (Table 1)

Install the required packages and download the datasets:

pip install -r requirements.txt
python download.py --download --data_path ./data waterbirds celeba civilcomments multinli
export PYTHONPATH=$PYTHONPATH:./XRM

We first run XRM and store its held-out predictions for the training set, along with the inferred group labels for the validation set. For more details, check out the instructions in the XRM repo. As an example, here is how to do this for the Waterbirds dataset:

python main.py --phase 1 --datasets Waterbirds --group_labels no --algorithm XRM --out_dir ./phase_1_results --num_hparams_combs 10 --num_seeds 1 --slurm_partition <your_slurm_partition>
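XRM's own procedure, described in its repo, trains networks on random halves of the training data. The idea of a held-out prediction can be sketched in a simplified form: every training example is predicted by a model that never saw it. The toy below assumes 2-fold cross-fitting with a nearest-centroid classifier on 1-D features standing in for a network, and a deterministic even/odd split instead of a random one so the output is reproducible. The helper names and the grouping rule (class label paired with whether the held-out prediction was correct) are hypothetical illustrations, not XRM's code.

```python
from typing import Dict, List, Sequence, Tuple


def fit_centroids(xs: Sequence[float], ys: Sequence[int]) -> Dict[int, float]:
    """Toy classifier: mean feature value per class (stand-in for a network)."""
    sums: Dict[int, float] = {}
    counts: Dict[int, int] = {}
    for x, y in zip(xs, ys):
        sums[y] = sums.get(y, 0.0) + x
        counts[y] = counts.get(y, 0) + 1
    return {y: sums[y] / counts[y] for y in sums}


def predict(centroids: Dict[int, float], x: float) -> int:
    # Predict the class whose centroid is closest to x.
    return min(centroids, key=lambda y: abs(x - centroids[y]))


def heldout_predictions(xs: Sequence[float], ys: Sequence[int]) -> List[int]:
    """2-fold cross-fitting: each example is predicted by a model trained only on the other fold."""
    folds = [[i for i in range(len(xs)) if i % 2 == k] for k in (0, 1)]
    preds = [0] * len(xs)
    for train_fold, eval_fold in ((folds[0], folds[1]), (folds[1], folds[0])):
        model = fit_centroids([xs[i] for i in train_fold], [ys[i] for i in train_fold])
        for i in eval_fold:
            preds[i] = predict(model, xs[i])
    return preds


def infer_groups(ys: Sequence[int], preds: Sequence[int]) -> List[Tuple[int, bool]]:
    """HYPOTHETICAL grouping rule: (class label, was the held-out prediction correct?)."""
    return [(y, p == y) for y, p in zip(ys, preds)]


# Toy data: x mostly tracks the label; the last two examples are a minority
# whose feature value sits with the other class.
xs = [0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 0.15, 1.15]
ys = [0, 0, 0, 0, 1, 1, 1, 1, 1, 0]
preds = heldout_predictions(xs, ys)
groups = infer_groups(ys, preds)
print(preds)        # [0, 0, 0, 0, 1, 1, 1, 1, 0, 1]
print(groups[-2:])  # [(1, False), (0, False)]
```

Only the two minority examples are misclassified by their held-out model, so this rule places them in groups of their own without any ground-truth group annotation.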

Next, run the MAT algorithm (phase 2), pointing --phase_1_dir at the phase 1 output directory:

python main.py --phase 2 --datasets Waterbirds --group_labels yes --algorithm MAT --out_dir ./phase_2_results --phase_1_dir ./phase_1_results --num_hparams_combs 10 --num_seeds 1 --slurm_partition <your_slurm_partition>
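Because phase 2 reads the outputs of phase 1, the two commands must agree on the dataset and on the phase 1 directory. A small, hypothetical helper that builds both command lines for one dataset, using only the flags shown above and reproducing their order, might look like this; the function name and the placeholder partition are assumptions for illustration.

```python
import shlex
from typing import Tuple


def two_phase_commands(
    dataset: str,
    slurm_partition: str,
    phase_1_dir: str = "./phase_1_results",
    phase_2_dir: str = "./phase_2_results",
    num_hparams_combs: int = 10,
    num_seeds: int = 1,
) -> Tuple[str, str]:
    """Build the phase-1 (XRM) and phase-2 (MAT) command lines for one dataset.

    HYPOTHETICAL helper: it only assembles strings from the flags in this README.
    """
    shared_tail = [
        "--num_hparams_combs", str(num_hparams_combs),
        "--num_seeds", str(num_seeds),
        "--slurm_partition", slurm_partition,
    ]
    phase_1 = [
        "python", "main.py", "--phase", "1", "--datasets", dataset,
        "--group_labels", "no", "--algorithm", "XRM", "--out_dir", phase_1_dir,
    ] + shared_tail
    # Phase 2 must read phase 1's outputs, so both commands share phase_1_dir.
    phase_2 = [
        "python", "main.py", "--phase", "2", "--datasets", dataset,
        "--group_labels", "yes", "--algorithm", "MAT", "--out_dir", phase_2_dir,
        "--phase_1_dir", phase_1_dir,
    ] + shared_tail
    return shlex.join(phase_1), shlex.join(phase_2)


if __name__ == "__main__":
    # "my_partition" is a placeholder for <your_slurm_partition>.
    for cmd in two_phase_commands("Waterbirds", "my_partition"):
        print(cmd)
```

The returned strings can be pasted into a shell or run with subprocess.run(shlex.split(cmd), check=True). Which dataset names other than Waterbirds the --datasets flag accepts is not listed here; the XRM repo documents them.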

To read the results:

  • Model selection using the best 'va_wga', i.e., the validation worst-group accuracy computed with ground-truth group annotations:
python XRM/read_results.py --dir phase_2_results --datasets Waterbirds --algorithms MAT --group_labels yes --selection_criterion va_wga
  • Model selection using the best 'va_gi_wga', i.e., the validation worst-group accuracy computed with XRM-inferred group annotations:
python XRM/read_results.py --dir phase_2_results --datasets Waterbirds --algorithms MAT --group_labels yes --selection_criterion va_gi_wga
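What the two selection criteria compute can be sketched as follows: worst-group accuracy is the minimum of the per-group accuracies, and model selection keeps the hyperparameter run with the highest value of the chosen criterion. Passing ground-truth groups corresponds to 'va_wga', and passing inferred groups to 'va_gi_wga'. This is a minimal sketch assuming each run's validation metrics are available as a dict; it does not reproduce XRM/read_results.py, whose internals and file formats are not described here, and the helper names are hypothetical.

```python
from collections import defaultdict
from typing import Any, Dict, Hashable, List, Sequence


def worst_group_accuracy(
    preds: Sequence[int], labels: Sequence[int], groups: Sequence[Hashable]
) -> float:
    """Accuracy of the worst-performing group (min over per-group accuracies)."""
    correct: Dict[Hashable, int] = defaultdict(int)
    total: Dict[Hashable, int] = defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        total[g] += 1
        correct[g] += int(p == y)
    return min(correct[g] / total[g] for g in total)


def select_best_run(runs: List[Dict[str, Any]], criterion: str) -> Dict[str, Any]:
    """Return the run whose `criterion` metric (e.g. 'va_wga' or 'va_gi_wga') is highest."""
    return max(runs, key=lambda run: run[criterion])


preds = [0, 0, 1, 1, 1, 0]
labels = [0, 0, 1, 1, 0, 0]
groups = ["a", "a", "b", "b", "c", "c"]
print(worst_group_accuracy(preds, labels, groups))  # 0.5 (group "c"), although average accuracy is 5/6

# Made-up metrics for two hyperparameter settings (not results from the paper).
runs = [
    {"hparam_id": 0, "va_wga": 0.71, "va_gi_wga": 0.80},
    {"hparam_id": 1, "va_wga": 0.78, "va_gi_wga": 0.74},
]
print(select_best_run(runs, "va_wga")["hparam_id"])     # 1
print(select_best_run(runs, "va_gi_wga")["hparam_id"])  # 0
```

The toy runs show that the two criteria can select different hyperparameters; 'va_gi_wga' is the option available when no ground-truth group labels exist for validation.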

License

This source code is released under the CC-BY-NC license, a copy of which is included in this repository.

Citation

If you make use of our work or code, please cite this work :)

@article{bayat2024pitfalls,
  title={The Pitfalls of Memorization: When Memorization Hurts Generalization},
  author={Bayat, Reza and Pezeshki, Mohammad and Dohmatob, Elvis and Lopez-Paz, David and Vincent, Pascal},
  journal={arXiv preprint arXiv:2412.07684},
  year={2024}
}
