
Commit

test to ensure adagrad etc. behaves similarly to torch
dlwh committed Sep 7, 2022
1 parent 4c96044 commit f80c46f
Showing 1 changed file with 92 additions and 0 deletions.
92 changes: 92 additions & 0 deletions tests/test_hf_gpt2_serialize.py
@@ -1,5 +1,6 @@
import tempfile

import equinox
import jax
import jax.numpy as jnp
import jax.random as jrandom
@@ -10,6 +11,8 @@
from transformers import GPT2Config as HfGpt2Config
from transformers import GPT2LMHeadModel as HfGpt2LMHeadModel

from levanter.config import TrainerConfig


def has_torch():
    try:
@@ -82,3 +85,92 @@ def compute(input):
    torch_out2 = torch_out2.logits[0].detach().cpu().numpy()
    torch_out2 = jax.nn.softmax(torch_out2, axis=-1)
    assert onp.isclose(torch_out2, onp.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out2} != {jax_out}"


# Gradient tests


@pytest.mark.skipif(not has_torch(), reason="torch not installed")
def test_hf_gradient():
    _compare_gpt2_checkpoint_gradients("gpt2", None)


def _compare_gpt2_checkpoint_gradients(model_id, revision):
    import torch

    from levanter.compat.torch_checkpoints import load_hf_gpt2_checkpoint, load_hf_model_checkpoint

    config, data = load_hf_model_checkpoint(model_id, revision=revision)
    config = HfGpt2Config.from_dict(config)
    torch_model: HfGpt2LMHeadModel = AutoModelForCausalLM.from_pretrained(model_id, config=config, revision=revision)
    torch_model.eval()

    model = load_hf_gpt2_checkpoint(model_id, revision=revision)

    input = _rand_input(PRNGKey(0), 1, config.n_positions, config.vocab_size)

    def torch_loss(model, input_ids) -> torch.Tensor:
        return model(input_ids, labels=input_ids)[0]

    torch_out = torch_loss(torch_model, torch.from_numpy(onp.array(input)).to(torch.int64))

    def compute_loss(model, input_ids):
        pred_y = model(input_ids, key=None)
        token_loss = jnp.mean(
            optax.softmax_cross_entropy(
                pred_y[:-1],
                jax.nn.one_hot(input_ids[1:], num_classes=model.vocab_size),
            )
        )

        return token_loss

    jax_compute_grad = jax.value_and_grad(compute_loss)
    jax_loss, jax_grad = jax_compute_grad(model, input[0])

    # gradients are kind of a pain to get at in torch, but we do it anyway
    torch_out.backward()
    torch_dict = torch_model.transformer.state_dict(keep_vars=True)
    torch_dict = {k: v.grad for k, v in torch_dict.items()}

    # compare each jax gradient leaf against the corresponding torch gradient
    jax_grads = jax.tree_util.tree_leaves(jax_grad)
    jax_grad_keys = jax_grad.torch_key_leaves()

    for jax_g, jax_key in zip(jax_grads, jax_grad_keys):
        torch_g = torch_dict[jax_key]
        assert onp.isclose(jax_g, torch_g.detach().cpu().numpy(), rtol=1e-2, atol=1e-2).all(), f"{jax_g} != {torch_g}"

    # now we also want to check that the optimizers do similar things
    trainer_config = TrainerConfig(weight_decay=0.0, learning_rate=1e-3, warmup_ratio=0.0)

    if trainer_config.max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(torch_model.parameters(), trainer_config.max_grad_norm)
    torch_optimizer = torch.optim.AdamW(
        torch_model.parameters(),
        lr=trainer_config.learning_rate,
        weight_decay=trainer_config.weight_decay,
        betas=(trainer_config.beta1, trainer_config.beta2),
        eps=trainer_config.epsilon,
    )

    # one optimizer step on the torch side
    torch_optimizer.step()

    # one equivalent optimizer step on the jax side
    jax_optimizer = trainer_config.optimizer()
    state = jax_optimizer.init(model)
    updates, state = jax_optimizer.update(updates=jax_grad, state=state, params=model)
    new_model = equinox.apply_updates(model, updates)

    new_leaves = jax.tree_util.tree_leaves(new_model)
    # re-read the torch parameters now that the optimizer step has been applied
    torch_dict = torch_model.transformer.state_dict(keep_vars=True)
    old_leaves = jax.tree_util.tree_leaves(model)

    # now compare new params
    for leaf, key, old_leaf in zip(new_leaves, jax_grad_keys, old_leaves):
        print(key)
        torch_leaf = torch_dict[key]
        # print the distance between the two
        print(onp.linalg.norm(leaf - torch_leaf.detach().cpu().numpy(), ord=onp.inf))
        print(onp.linalg.norm(old_leaf - torch_leaf.detach().cpu().numpy(), ord=onp.inf))
        assert onp.isclose(
            leaf, torch_leaf.detach().cpu().numpy(), rtol=1e-2, atol=1e-2
        ).all(), f"{key}: {leaf} != {torch_leaf}"
