wtong98/mlp-icl

MLPs Learn In-Context

This repository contains code that reproduces the experiments in our paper, MLPs Learn In-Context.

Installation

We use Python 3.10.12.
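
Because the results were produced with Python 3.10.12, a quick interpreter check can flag a mismatch before you install packages. The sketch below is not part of the repository, and the helper name `check_python_version` is hypothetical. It compares only the major and minor version, on the assumption (not stated by the authors) that other 3.10.x patch releases behave the same.

```python
import sys
import warnings

# Version the authors report using. Treating any 3.10.x as compatible is an
# assumption of this sketch, not a statement from the repository.
EXPECTED = (3, 10, 12)


def check_python_version(expected=EXPECTED, actual=None) -> bool:
    """Return True if major.minor matches `expected`; otherwise warn and return False."""
    actual = tuple(actual if actual is not None else sys.version_info[:3])
    if actual[:2] != expected[:2]:
        warnings.warn(
            f"Running Python {'.'.join(map(str, actual))}; "
            f"the experiments were run with {'.'.join(map(str, expected))}."
        )
        return False
    return True


if __name__ == "__main__":
    check_python_version()
```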

To install the requisite Python packages, run

pip install -r requirements.txt

By default, this installs CPU-only JAX. For GPU-enabled JAX (CUDA 12), run

pip install -U "jax[cuda12]"

Depending on your machine, you may need additional or alternative packages for GPU-enabled JAX. Please consult the JAX installation instructions for details.
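
After installation, one way to confirm which backend JAX actually uses is to call `jax.default_backend()`, which returns a platform name such as `"cpu"` or `"gpu"`. The sketch below wraps that call in a hypothetical helper, `jax_backend`, which also reports when JAX is not installed at all.

```python
import importlib.util


def jax_backend() -> str:
    """Return JAX's default backend (e.g. "cpu", "gpu", "tpu"), or "not installed"."""
    if importlib.util.find_spec("jax") is None:
        return "not installed"
    import jax  # imported lazily so the check still runs without JAX present

    return jax.default_backend()


if __name__ == "__main__":
    print("JAX backend:", jax_backend())
```

If this prints `cpu` after you install `jax[cuda12]`, JAX is not using the GPU. The JAX installation instructions cover the driver and CUDA requirements.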

Running the code

Code organization:

  • experiment/: experiments and plotting
  • model/: model implementations
  • task/: task implementations
  • train: training routines

Each experiment is organized in its own file. For example, 12_icl_clean.py plots the results for ICL regression and classification. Each experiment file expects its results to be present in experiment/remote, which contains the scripts that generate those results; these scripts are intended to be run on a compute cluster. Running the corresponding run.py script for each sub-experiment generates its results.
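
To see which sub-experiments exist before launching anything, a small driver like the one below can collect every `run.py` under `experiment/remote`. By default it performs a dry run and returns the commands without executing them. It assumes, without verification, that each script can be started as `python run.py` from its own directory with no arguments. Because the scripts target a compute cluster, they may in practice expect scheduler-specific settings, so check each one before running it locally. `find_run_scripts` and `run_all` are hypothetical helpers and are not part of the repository.

```python
import subprocess
import sys
from pathlib import Path


def find_run_scripts(root="experiment/remote"):
    """Return every run.py below `root`, sorted for a deterministic order."""
    return sorted(Path(root).rglob("run.py"))


def run_all(root="experiment/remote", dry_run=True):
    """Run each run.py from its own directory; with dry_run, only collect the commands.

    ASSUMPTION: each script takes no command-line arguments.
    """
    commands = []
    for script in find_run_scripts(root):
        cmd = [sys.executable, script.name]
        commands.append((str(script.parent), cmd))
        if not dry_run:
            # check=True stops at the first script that exits with an error.
            subprocess.run(cmd, cwd=script.parent, check=True)
    return commands


if __name__ == "__main__":
    for cwd, cmd in run_all():
        print(f"(cd {cwd} && {' '.join(cmd)})")
```

Making dry run the default lets you review the command list before committing compute to it.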

Experiment files are formatted as Jupyter code cells. One way to run these files interactively is through the Jupyter extension for VSCode. Because the cells are delimited by comments, however, the same files can also be run as standard Python scripts.
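
VSCode's Jupyter extension conventionally starts a new cell at each `# %%` comment line. To the Python interpreter those lines are ordinary comments, so the file still runs from top to bottom. The example below assumes the repository's files follow this convention. The file contents in `EXAMPLE` and the `split_cells` helper are hypothetical and only illustrate where the cell boundaries fall.

```python
import re

# Lines such as "# %%" or "# %% Report" open a new cell in VSCode.
CELL_MARKER = re.compile(r"^# %%")


def split_cells(source: str) -> list[str]:
    """Split a cell-delimited script into cells, starting a new one at each '# %%' line."""
    cells, current = [], []
    for line in source.splitlines():
        if CELL_MARKER.match(line) and current:
            cells.append("\n".join(current))
            current = []
        current.append(line)
    if current:
        cells.append("\n".join(current))
    return cells


EXAMPLE = """\
# %%
import math
x = math.sqrt(16)
# %% Report
print(x)
"""

if __name__ == "__main__":
    for i, cell in enumerate(split_cells(EXAMPLE)):
        print(f"--- cell {i} ---\n{cell}")
    exec(EXAMPLE)  # the same text also runs as plain Python and prints 4.0
```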

If you notice any errors or issues, we welcome your pull requests!
