Commit

add @lee2024b grokfast
syrkis committed Jun 16, 2024
1 parent 3532cc2 commit ba885df
Showing 4 changed files with 45 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -1,5 +1,5 @@
 # Description: Dockerfile for JAX with CUDA support
-FROM 12.2.2-cudnn8-devel-ubuntu20.04
+FROM nvidia/cuda:12.0.0-cudnn8-devel-ubuntu20.04

 # Set the working directory
 WORKDIR /workspace
4 changes: 2 additions & 2 deletions config.yaml
@@ -1,8 +1,8 @@
-base: 16
+base: 10
 n: 2048 # 16384
 emb: 128
 lr: 0.001
-depth: 2
+depth: 3
 heads: 4
 epochs: 2000
 block: vaswani
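Note: config.yaml is read into the attribute-style cfg object that main.py and miiii/train.py consult (cfg.epochs, cfg.dropout, ...). A minimal loading sketch, assuming a plain PyYAML load into a namespace; the repository's actual loader, and fields hidden behind the truncation above (e.g. dropout), may differ:

# Hypothetical loader sketch -- the repo may construct cfg differently.
import yaml  # assumes PyYAML is available
from types import SimpleNamespace

with open("config.yaml") as f:
    cfg = SimpleNamespace(**yaml.safe_load(f))

print(cfg.base, cfg.depth, cfg.epochs)  # -> 10 3 2000 after this commit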
4 changes: 2 additions & 2 deletions main.py
@@ -36,8 +36,8 @@ def main():

     # train
     apply_fn = miiii.make_apply_fn(miiii.vaswani_fn)
-    train_fn, opt_state = miiii.init_train(apply_fn, params, cfg, *data)
-    (params, opt_state), metrics = train_fn(cfg.epochs, rng, (params, opt_state))
+    train_fn, state = miiii.init_train(apply_fn, params, cfg, *data)
+    state, metrics = train_fn(cfg.epochs, rng, state)

     # evaluate
     log_run(cfg, metrics)  # log run
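Note: the call-site change reflects the new training API: init_train now returns a TrainState (defined in miiii/train.py below) instead of a bare opt_state, and train_fn threads that state through every step. A sketch of the pattern, assuming make_train_fn scans step_fn over per-epoch PRNG keys (the lax.scan body is an assumption; only the signatures are from this diff):

# Hypothetical reconstruction of the unchanged make_train_fn helper.
from jax import lax, random

def make_train_fn(step_fn):
    def train_fn(epochs, rng, state):
        keys = random.split(rng, epochs)                 # one key per epoch
        state, metrics = lax.scan(step_fn, state, keys)  # step_fn: (state, key) -> (state, metrics)
        return state, metrics
    return train_fn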
51 changes: 40 additions & 11 deletions miiii/train.py
@@ -6,13 +6,15 @@
 import jax
 from jax import random, grad, jit, value_and_grad
 import optax
+from jax import tree_util
 from tqdm import tqdm
 import jax.numpy as jnp
 from typing import List, Set, Tuple
 from functools import partial
 from oeis import oeis
 from einops import rearrange
 import seaborn as sns
+from typing import NamedTuple
 import matplotlib.pyplot as plt

 if __name__ == "__main__":
@@ -21,6 +23,13 @@
 from .utils import alpha_fn


+# constants
+class TrainState(NamedTuple):
+    params: dict
+    opt_state: dict
+    ema_grads: dict
+
+
 # functions
 @jit
 def loss_fn(logits, y):  # cross entropy loss
@@ -40,28 +49,45 @@ def update_fn(params, grads, opt_state):

 def make_grad_fn(loss_fn, apply_fn, cfg):
     @jit
-    def grad_fn(params, rng, x, y):  # maybe add allow_int flag below
+    def grad_fn(state, rng, x, y):  # maybe add allow_int flag below
         def loss_and_logits(params):
             logits = apply_fn(params, rng, x, cfg.dropout)
             loss = loss_fn(logits, y)
             return loss, logits

-        (loss, logits), grads = value_and_grad(
-            loss_and_logits, allow_int=True, has_aux=True
-        )(params)
-        return loss, grads, logits
+        (loss, logits), grads = value_and_grad(loss_and_logits, has_aux=True)(
+            state.params
+        )
+        grads, state = gradfilter_ema(grads, state)  # @lee2024b (grokfast)
+        return loss, grads, logits, state

     return grad_fn


+@partial(jit, static_argnums=(2,))
+def gradfilter_ema(grads, state, alpha=0.98, lamb=2.0):
+    def _update_ema(prev_ema, grad):
+        return prev_ema * alpha + grad * (1 - alpha)
+
+    def _apply_ema(grad, ema):
+        return grad + ema * lamb
+
+    ema_grads = jax.tree_map(_update_ema, state.ema_grads, grads)
+    filtered_grads = jax.tree_map(_apply_ema, grads, ema_grads)
+    state = state._replace(ema_grads=ema_grads)
+    return filtered_grads, state
+
+
 def make_step_fn(grad_fn, update_fn, train_data, eval_fn):
     @jit
-    def step_fn(carry, rng):
-        (params, opt_state), (rng, key) = carry, random.split(rng)
-        train_loss, grads, train_logits = grad_fn(params, key, *train_data)
+    def step_fn(state, rng):
+        params, opt_state = state.params, state.opt_state
+        rng, key = random.split(rng)
+        loss, grads, logits, state = grad_fn(state, key, *train_data)
         params, opt_state = update_fn(params, grads, opt_state)
-        metrics = eval_fn(params, rng, train_loss, train_logits)
-        return (params, opt_state), metrics
+        metrics = eval_fn(params, rng, loss, logits)
+        state = state._replace(params=params, opt_state=opt_state)
+        return state, metrics

     return step_fn

@@ -118,11 +144,14 @@ def init_train(apply_fn, params, cfg, train_data, valid_data):
     update_fn = make_update_fn(opt)
     grad_fn = make_grad_fn(loss_fn, apply_fn, cfg)

+    ema_grads = jax.tree_map(jnp.zeros_like, params)
+    state = TrainState(params=params, opt_state=opt_state, ema_grads=ema_grads)
+
     eval_fn = make_eval_fn(apply_fn, loss_fn, train_data, valid_data)
     step_fn = make_step_fn(grad_fn, update_fn, train_data, eval_fn)

     train_fn = make_train_fn(step_fn)
-    return train_fn, opt_state
+    return train_fn, state


# testing
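Note on the new gradfilter_ema: this is the Grokfast filter of @lee2024b. Each raw gradient is augmented with lamb times an exponential moving average (EMA) of past gradients, amplifying the slow-varying gradient component that the paper links to faster grokking; carrying ema_grads inside TrainState keeps the filter purely functional and jit-friendly. A self-contained sketch of the same update rule on a toy pytree (the parameter shapes and loss here are illustrative only):

# Standalone sketch of the EMA gradient filter added in miiii/train.py.
import jax
import jax.numpy as jnp

def gradfilter_ema(grads, ema, alpha=0.98, lamb=2.0):
    # ema <- alpha * ema + (1 - alpha) * grads; filtered <- grads + lamb * ema
    ema = jax.tree_util.tree_map(lambda e, g: e * alpha + g * (1 - alpha), ema, grads)
    filtered = jax.tree_util.tree_map(lambda g, e: g + e * lamb, grads, ema)
    return filtered, ema

params = {"w": jnp.ones(3), "b": jnp.zeros(())}
ema = jax.tree_util.tree_map(jnp.zeros_like, params)  # zero-initialised, as in init_train
grads = jax.grad(lambda p: jnp.sum(p["w"] ** 2) + p["b"])(params)
filtered, ema = gradfilter_ema(grads, ema)
# First step from a zero EMA: ema == 0.02 * grads, so filtered == 1.04 * grads.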
