Notes, resources, packaged code, and minimal viable implementations of transformers and large language models.

Implementing transformers

This is a learning repo where I implement LLMs from scratch and take notes along the way. It collates notes, resources, and code implementations for learning about LLMs.

Code repo structure

  • ./apps:
    • Examples using the package
  • ./notes:
    • Unstructured markdown notes on various topics
  • ./min_llm:
    • Contains the packaged code
  • ./nbs:
    • Contains sandbox notebooks for learning concepts
  • ./playground:
    • Contains sandbox code for learning concepts

Goal

Implement Llama 3 in JAX and train a mini version of it. Implement all of the fancy techniques discussed in the linked paper. Also build a package of useful LLM utilities, like here.
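
As a taste of the target, here is a minimal sketch (not code from this repo) of a Llama-style RMSNorm layer in JAX; the function name and signature are illustrative:

  import jax
  import jax.numpy as jnp

  def rms_norm(x, weight, eps=1e-6):
      # Llama normalizes by the root-mean-square of the activations
      # (no mean subtraction) and applies a learned per-feature scale
      ms = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
      return x * jax.lax.rsqrt(ms + eps) * weight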

Resources

Blog posts

Papers

Chronologically ordered

More:

Books

Posts

Other

  • GPU MODE YouTube channel

Why?

JAX (eventually):

  • JAX has better parallelization primitives, which are useful for training large models (a minimal sketch follows this list)
  • JAX is lower-level and closer to NumPy, which forces you to dive deeper into the concepts
  • Developing at a lower level will make it easier to implement custom add-ons, like speeding up inference with CUDA kernels or porting the inference module to C/Rust
  • In the meantime, we will use PyTorch's distributed training features
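
An illustrative sketch of those primitives (not this repo's code): NumPy-style array math, with jax.pmap replicating an update step across local devices and jax.lax.pmean all-reducing the gradients. The function names here are my own.

  import jax
  import jax.numpy as jnp

  def sgd_step(params, grads, lr=1e-2):
      # NumPy-like: params and grads are just pytrees of arrays
      return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

  def data_parallel_step(params, grads):
      # all-reduce (mean) gradients across devices, then apply the update
      grads = jax.lax.pmean(grads, axis_name="devices")
      return sgd_step(params, grads)

  # replicate the step across all local accelerators (data parallelism)
  parallel_step = jax.pmap(data_parallel_step, axis_name="devices")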

Things to learn/implement

  • Architectures:
    • GPT-2 (starter)
    • Llama (main goal)
    • VLM
    • Pixtral/multimodal VLM
    • Efficient VLA for robotics
    • Mixture of experts
    • SSMs: start with language and then move to speech
    • Implement mini versions of architectures like Gemini, Mixtral
  • MLSys:
    • CUDA kernels
    • Triton kernels (see the sketch after this list)
    • ThunderKittens kernels
    • Quantization
    • pybind11 to integrate custom-written kernels into PyTorch
    • Model and data parallelism across GPUs: tensor parallelism, column parallelism, pipeline parallelism, data parallelism (fully-sharded data parallelism)
  • General learning:
    • Optimizing for both memory-bound and compute-bound operations
    • Understanding GPU memory hierarchy and computation capabilities
    • Efficient attention algorithms
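
For the Triton item above, a reasonable starting point is the canonical vector-add kernel; this is a generic sketch following Triton's tutorial-style API, not code from this repo:

  import torch
  import triton
  import triton.language as tl

  @triton.jit
  def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
      # each program instance handles one contiguous block of elements
      pid = tl.program_id(axis=0)
      offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
      mask = offsets < n_elements
      x = tl.load(x_ptr + offsets, mask=mask)
      y = tl.load(y_ptr + offsets, mask=mask)
      tl.store(out_ptr + offsets, x + y, mask=mask)

  def add(x, y):
      # x and y are CUDA tensors of the same shape
      out = torch.empty_like(x)
      n = x.numel()
      grid = (triton.cdiv(n, 1024),)
      add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
      return out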

Installation and usage

Some parts of this repo are packaged so that you can install and use them directly. This is inspired by Meta's Lingua.

$ pip install git+https://github.com/rosikand/min-llm.git

Usage:

import min_llm
