Implementing transformers

This is a learning repo where I take notes on and implement LLMs from scratch. It collates notes, resources, and code implementations for learning about LLMs.

Code repo structure

  • ./apps:
    • Examples using the package
  • ./notes:
    • Unstructured markdown notes on various topics
  • ./min_llm:
    • Contains the packaged code
  • ./nbs:
    • Contains sandbox notebooks for learning concepts
  • ./playground:
    • Contains sandbox code for learning concepts

Goal

Implement Llama 3 in JAX and train a mini version of it. Implement all the fancy techniques discussed in the linked paper. Also make a package of useful LLM utilities like the one here.
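
As a taste of one such technique, here is a minimal sketch of RMSNorm, the normalization used in Llama, written in plain jax.numpy. It is illustrative only (the function name and signature are my own), not this repo's implementation:

```python
import jax.numpy as jnp

def rms_norm(x, weight, eps=1e-6):
    # Scale x by the reciprocal of its root mean square over the last axis,
    # then apply a learned per-feature gain (weight).
    rms = jnp.sqrt(jnp.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return (x / rms) * weight
```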

Resources

Blog posts

Papers

Chronologically ordered

More:

Books

Posts

Other

  • GPU MODE YouTube channel

Why?

JAX (eventually):

  • JAX has better parallelization primitives, which are useful for training large models (see the sketch after this list)
  • JAX is lower level and closer to NumPy, which forces you to dive deeper into the concepts
  • Developing at a lower level will make it easier to implement custom add-ons, like speeding up inference with CUDA kernels or porting the inference module to C/Rust
  • In the meantime, we will use PyTorch's distributed tooling
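
To make the first bullet concrete, here is a minimal sketch (the toy loss function and shapes are made up for illustration, not repo code) of how NumPy-style JAX code composes with grad and pmap for simple data parallelism:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x):
    # Toy linear model: params is a (features,) weight vector.
    preds = x @ params
    return jnp.mean(preds ** 2)

# grad composes over ordinary NumPy-style code.
grad_fn = jax.grad(loss_fn)

# pmap replicates the computation across local devices (data parallelism).
# On a single-device machine this still runs, with a leading axis of size 1.
n_dev = jax.local_device_count()
params = jnp.ones((4,))
x = jnp.ones((n_dev, 8, 4))  # one batch shard per device
per_device_grads = jax.pmap(grad_fn, in_axes=(None, 0))(params, x)  # shape (n_dev, 4)
```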

Things to learn/implement

  • Architectures:
    • GPT-2 (starter)
    • Llama (main goal)
    • VLM
    • Pixtral/multimodal VLM
    • Efficient VLA for robotics
    • Mixture of experts
    • SSMs: start with language and then move to speech
    • Implement mini versions of architectures like Gemini and Mixtral
  • MLSys:
    • CUDA kernels
    • Triton kernels
    • ThunderKittens kernels
    • Quantization (see the sketch after this list)
    • pybind to integrate custom-written kernels into a PyTorch framework
    • Model and data parallelism across GPUs: tensor parallelism, column parallelism, pipeline parallelism, and data parallelism (including fully-sharded data parallelism)
  • General learning:
    • Optimizing for both memory-bound and compute-bound operations
    • Understanding GPU memory hierarchy and computation capabilities
    • Efficient attention algorithms
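
As a concrete example of one MLSys item above (quantization), here is a minimal sketch of symmetric per-tensor int8 weight quantization, written against plain JAX rather than any code in this repo:

```python
import jax
import jax.numpy as jnp

def quantize_int8(w):
    # Symmetric per-tensor quantization: w ≈ q * scale with q in [-127, 127].
    scale = jnp.max(jnp.abs(w)) / 127.0
    q = jnp.clip(jnp.round(w / scale), -127, 127).astype(jnp.int8)
    return q, scale

def dequantize_int8(q, scale):
    return q.astype(jnp.float32) * scale

w = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
q, scale = quantize_int8(w)
max_err = jnp.max(jnp.abs(w - dequantize_int8(q, scale)))  # at worst about scale / 2
```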

Installation and usage

Some parts of this repo implement a package that you can install and use. The design is inspired by Meta's Lingua.

$ pip install git+https://github.com/rosikand/min-llm.git

Usage:

import min_llm