
Training small GPT-2 style models using Kolmogorov-Arnold networks.


CG80499/KAN-GPT-2


Training small GPT-2 style models using KANs instead of MLPs in JAX

This repository compares transformers whose feed-forward layers are multilayer perceptrons (MLPs) with transformers that use Kolmogorov-Arnold network (KAN) layers instead.

Key points:

  • Uses Kolmogorov-Arnold networks with Chebyshev polynomials as the basis functions instead of B-splines (inspired by this repo).
  • The tanh function keeps the values fed into the Chebyshev polynomials within [-1, 1], their natural domain, rather than relying on grids that are updated during training.
  • Both models are trained on 134M tokens of TinyStories.
  • Apart from the KAN layers, both models use the standard GPT-2 architecture.
  • The MLP model has 3.3M non-embedding weights and the KAN model has 2.5M (~25% fewer).
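A Chebyshev KAN layer of the kind described in the key points can be sketched in JAX as follows. This is a minimal illustration, not the repository's code: the function names (`init_cheby_kan`, `chebyshev_basis`, `cheby_kan_layer`, `kan_ffn`), the initialization scale, the polynomial degree, and the hidden width are all assumptions. Each input is squashed with tanh, expanded into Chebyshev polynomials T_0 through T_D using the recurrence T_{n+1}(x) = 2x T_n(x) - T_{n-1}(x), and each output is a learned linear combination of those terms across all inputs.

```python
import jax
import jax.numpy as jnp


def init_cheby_kan(key, in_dim, out_dim, degree):
    """Hypothetical initializer: one coefficient per (input, output, polynomial order).

    Parameters per layer: in_dim * out_dim * (degree + 1).
    """
    std = 1.0 / (in_dim * (degree + 1)) ** 0.5  # assumed scale, not from the repo
    return jax.random.normal(key, (in_dim, out_dim, degree + 1)) * std


def chebyshev_basis(x, degree):
    """Stack T_0(x) .. T_degree(x) on a new last axis via the three-term recurrence."""
    terms = [jnp.ones_like(x), x]
    for _ in range(2, degree + 1):
        terms.append(2 * x * terms[-1] - terms[-2])
    return jnp.stack(terms[: degree + 1], axis=-1)


def cheby_kan_layer(coeffs, x):
    """Map x of shape (..., in_dim) to (..., out_dim)."""
    degree = coeffs.shape[-1] - 1
    x = jnp.tanh(x)  # squash into [-1, 1], where |T_n(x)| <= 1
    basis = chebyshev_basis(x, degree)  # (..., in_dim, degree + 1)
    # Sum over inputs i and polynomial orders d for every output o.
    return jnp.einsum("...id,iod->...o", basis, coeffs)


def kan_ffn(params, x):
    """Hypothetical stand-in for the GPT-2 MLP sublayer: two stacked KAN layers."""
    h = cheby_kan_layer(params["in"], x)
    return cheby_kan_layer(params["out"], h)


# Illustrative shapes: d_model and seq_len match the README; d_hidden and degree are assumptions.
d_model, d_hidden, degree = 128, 128, 3
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "in": init_cheby_kan(k1, d_model, d_hidden, degree),
    "out": init_cheby_kan(k2, d_hidden, d_model, degree),
}
x = jax.random.normal(k3, (16, 64, d_model))  # (batch, seq_len, d_model)
y = kan_ffn(params, x)
print(y.shape)  # (16, 64, 128)
```

Because the learnable coefficients act directly on polynomial features, this sketch needs no separate activation function between the two KAN layers; the tanh inside each layer only rescales inputs onto the polynomials' domain.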

Results:

Both models reach a final loss of ~2.46, even though the KAN model has ~25% fewer non-embedding parameters.

Hyperparameters:

  • d_model: 128
  • d_mlp: 768 (when applicable)
  • n_heads: 8
  • n_layers: 16
  • learning_rate: 1e-5
  • batch_size: 16
  • weight_decay: 0.001
  • optimizer: adamw
  • seq_len: 64
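The hyperparameters above can be collected into a small config object, which also makes a few derived quantities explicit. This is a hypothetical sketch (the `TrainConfig` name and its properties are not from the repository), and the step count assumes the 134M tokens are seen in a single pass with no gradient accumulation, which the README does not state.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
    # Values copied from the hyperparameter list.
    d_model: int = 128
    d_mlp: int = 768  # used only by the MLP variant
    n_heads: int = 8
    n_layers: int = 16
    learning_rate: float = 1e-5
    batch_size: int = 16
    weight_decay: float = 0.001
    optimizer: str = "adamw"
    seq_len: int = 64

    @property
    def d_head(self) -> int:
        # GPT-2 attention splits d_model evenly across heads.
        assert self.d_model % self.n_heads == 0
        return self.d_model // self.n_heads

    @property
    def tokens_per_step(self) -> int:
        # Assumes no gradient accumulation.
        return self.batch_size * self.seq_len


cfg = TrainConfig()
total_tokens = 134_000_000  # approximate "134M tokens" from the key points
steps = total_tokens // cfg.tokens_per_step  # assumes a single pass over the data
print(cfg.d_head, cfg.tokens_per_step, steps)  # 16 1024 130859
```

With these settings each attention head is 16-dimensional, each step processes 1,024 tokens, and one pass over 134M tokens takes roughly 131k steps. If the training loop used Optax, the optimizer might be built with `optax.adamw(learning_rate=cfg.learning_rate, weight_decay=cfg.weight_decay)`, though the README does not say which optimizer library is used.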

Hardware: a single NVIDIA GTX 1080 Ti GPU

Wandb: link.
