MLPs Learn In-Context

This repository contains code that reproduces the experiments from our paper, "MLPs Learn In-Context".

Installation

We use Python 3.10.12.

To install the requisite Python packages, run

pip install -r requirements.txt

By default, this installs CPU-only JAX. For GPU-enabled JAX, run

pip install -U "jax[cuda12]"

Depending on your machine, you may need additional or alternative packages for GPU-enabled JAX. Please consult the JAX installation instructions for details.
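
To confirm which backend the installation picked up, a minimal check along the following lines may help. This is a sketch and not part of the repository: the helper name `jax_backend_summary` is hypothetical, and it relies only on `jax.default_backend()` and `jax.devices()`.

```python
import importlib.util


def jax_backend_summary() -> str:
    """Return a one-line description of the JAX backend, if JAX is installed.

    Hypothetical helper for illustration; it is not part of this repository.
    """
    # Avoid an ImportError if the requirements have not been installed yet.
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"

    import jax

    # default_backend() names the platform JAX will use by default,
    # and devices() lists the devices available on that backend.
    platforms = [device.platform for device in jax.devices()]
    return f"backend={jax.default_backend()}, devices={platforms}"


if __name__ == "__main__":
    print(jax_backend_summary())
```

On a CPU-only install the backend is reported as `cpu`; a working GPU install should report a GPU backend instead.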

Running the code

Code organization:

  • experiment/: experiments and plotting
  • model/: model implementations
  • task/: task implementations
  • train: training routines

Experiments are organized one per file. For example, 12_icl_clean.py plots results for ICL regression and classification. Each experiment file expects its results to be present in experiment/remote. That directory contains the scripts that generate the results; they are intended to be run on a compute cluster. Running the run.py script for each sub-experiment generates the corresponding results.

Experiment files are formatted as Jupyter code cells. One way to run them interactively is through the Jupyter extension in VSCode. Because the cells are delimited by comments, the files can also be run as standard Python scripts.
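
As an illustration of the cell format, the sketch below assumes the `# %%` marker that the VSCode Jupyter extension recognizes as a cell boundary. The contents are hypothetical and are not taken from any experiment file in this repository.

```python
# %% Imports (each "# %%" comment starts a new cell in VSCode)
import math

# %% Compute some values
# Hypothetical stand-in for loading and processing experiment results.
values = [math.sqrt(x) for x in range(5)]

# %% Report
# Executed as a plain script, the cells simply run top to bottom.
print(values)
```

Because the delimiters are ordinary comments, the Python interpreter ignores them, so the same file works both interactively and from the command line.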

If you notice any errors or issues, we welcome your pull requests!