This repository contains code that reproduces the experiments from our paper MLPs Learn In-Context.
We use Python 3.10.12.
To install the requisite Python packages, run
pip install -r requirements.txt
By default, this will install CPU-only JAX. For GPU-enabled JAX, run
pip install -U "jax[cuda12]"
Depending on your machine, you may need additional or alternative packages for GPU-enabled JAX. Please consult the JAX installation instructions for details.
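After installing, one quick way to confirm that the GPU build took effect is to ask JAX which backend it selected. The helper below is a minimal sketch: the name jax_backend_summary is hypothetical and not part of this repository, and it relies only on jax.default_backend() and jax.devices(). With a working CUDA install the backend should typically read "gpu"; a CPU-only install reports "cpu".

```python
import importlib.util


def jax_backend_summary():
    """Return (backend, devices) as reported by JAX, or None if JAX is absent.

    Hypothetical helper for checking an install; not part of this repository.
    """
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    # default_backend() names the platform JAX chose, e.g. "cpu" or "gpu".
    backend = jax.default_backend()
    # devices() lists the devices available on that default backend.
    devices = [str(d) for d in jax.devices()]
    return backend, devices


if __name__ == "__main__":
    summary = jax_backend_summary()
    if summary is None:
        print("JAX is not installed; run pip install -r requirements.txt first.")
    else:
        backend, devices = summary
        print(f"backend={backend} devices={devices}")
```

If the output still says "cpu" after installing the CUDA wheels, the JAX installation instructions cover driver and CUDA version requirements.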
Code organization:
experiment/: experiments and plotting
model/: model implementations
task/: task implementations
train: training routines
Experiments are organized per file. For example, 12_icl_clean.py plots the results for ICL regression and classification. Each experiment file expects its results to be present in experiment/remote, which contains the scripts that generate those results; these scripts are intended to be run on a compute cluster. Running the run.py script for each sub-experiment generates the corresponding results.
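To see which sub-experiments exist before submitting jobs, a small helper can list every run.py under experiment/remote. This is a minimal sketch assuming each sub-experiment keeps its own run.py somewhere beneath that directory; find_run_scripts is a hypothetical name, not part of the repository, and how you actually submit each script depends on your cluster's scheduler.

```python
from pathlib import Path


def find_run_scripts(remote_dir="experiment/remote"):
    """Return every run.py beneath remote_dir, sorted for a stable order.

    Hypothetical helper: assumes one run.py per sub-experiment directory.
    """
    root = Path(remote_dir)
    if not root.is_dir():
        # Nothing to list if the directory is missing (e.g. wrong cwd).
        return []
    return sorted(root.rglob("run.py"))


if __name__ == "__main__":
    # Only lists the scripts; on a cluster each would become its own job.
    for script in find_run_scripts():
        print(script)
```

Listing rather than launching keeps the sketch scheduler-agnostic: the printed paths can be fed to whatever submission command your cluster uses.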
Experiment files are formatted as Jupyter code cells. One way of running these files interactively is through the Jupyter extension in VS Code. However, because the cells are delimited by ordinary comments, these files can also be run as standard Python scripts.
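For readers unfamiliar with this format, here is a toy file (hypothetical, not one of the repository's experiments) assuming the usual "# %%" cell marker recognized by the VS Code Jupyter extension. Each marker starts a cell that can be run interactively, while running the file with python executes it top to bottom, since the markers are just comments.

```python
# %%
# First cell: imports.
import math

# %%
# Second cell: compute something. Interactively, this cell can be re-run on
# its own; as a plain script, it simply runs after the cell above.
xs = [i / 10 for i in range(11)]
ys = [math.sin(x) for x in xs]

# %%
# Third cell: report the result.
print(f"max sin over grid: {max(ys):.4f}")  # prints "max sin over grid: 0.8415"
```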
If you notice any errors or issues, we welcome your pull requests!