diff --git a/tests/test_hf_gpt2_serialize.py b/tests/test_hf_gpt2_serialize.py
index 78b62f485..929f6fe81 100644
--- a/tests/test_hf_gpt2_serialize.py
+++ b/tests/test_hf_gpt2_serialize.py
@@ -1,5 +1,6 @@
 import tempfile
 
+import equinox
 import jax
 import jax.numpy as jnp
 import jax.random as jrandom
@@ -10,6 +11,8 @@
 from transformers import GPT2Config as HfGpt2Config
 from transformers import GPT2LMHeadModel as HfGpt2LMHeadModel
 
+from levanter.config import TrainerConfig
+
 
 def has_torch():
     try:
@@ -82,3 +85,92 @@ def compute(input):
     torch_out2 = torch_out2.logits[0].detach().cpu().numpy()
     torch_out2 = jax.nn.softmax(torch_out2, axis=-1)
     assert onp.isclose(torch_out2, onp.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out2} != {jax_out}"
+
+
+# Gradient tests
+
+
+@pytest.mark.skipif(not has_torch(), reason="torch not installed")
+def test_hf_gradient():
+    _compare_gpt2_checkpoint_gradients("gpt2", None)
+
+
+def _compare_gpt2_checkpoint_gradients(model_id, revision):
+    import torch
+
+    from levanter.compat.torch_checkpoints import load_hf_gpt2_checkpoint, load_hf_model_checkpoint
+
+    config, data = load_hf_model_checkpoint(model_id, revision=revision)
+    config = HfGpt2Config.from_dict(config)
+    torch_model: HfGpt2LMHeadModel = AutoModelForCausalLM.from_pretrained(model_id, config=config, revision=revision)
+    torch_model.eval()
+
+    model = load_hf_gpt2_checkpoint(model_id, revision=revision)
+
+    input = _rand_input(PRNGKey(0), 1, config.n_positions, config.vocab_size)
+
+    def torch_loss(model, input_ids) -> torch.Tensor:
+        return model(input_ids, labels=input_ids)[0]
+
+    torch_out = torch_loss(torch_model, torch.from_numpy(onp.array(input)).to(torch.int64))
+
+    def compute_loss(model, input_ids):
+        pred_y = model(input_ids, key=None)
+        token_loss = jnp.mean(
+            optax.softmax_cross_entropy(
+                pred_y[:-1],
+                jax.nn.one_hot(input_ids[1:], num_classes=model.vocab_size),
+            )
+        )
+
+        return token_loss
+
+    jax_compute_grad = jax.value_and_grad(compute_loss)
+    jax_loss, jax_grad = jax_compute_grad(model, input[0])
+
+    # gradients are kind of a pain to get at in torch, but we do it anyway
+    torch_out.backward()
+    torch_dict = torch_model.transformer.state_dict(keep_vars=True)
+    torch_dict = {k: v.grad for k, v in torch_dict.items()}
+
+    jax_grads = jax.tree_util.tree_leaves(jax_grad)
+    jax_grad_keys = jax_grad.torch_key_leaves()
+
+    for jax_g, jax_key in zip(jax_grads, jax_grad_keys):
+        torch_g = torch_dict[jax_key]
+        assert onp.isclose(jax_g, torch_g.detach().cpu().numpy(), rtol=1e-2, atol=1e-2).all(), f"{jax_g} != {torch_g}"
+
+    # now we also want to check that the optimizers do similar things
+    trainer_config = TrainerConfig(weight_decay=0.0, learning_rate=1e-3, warmup_ratio=0.0)
+
+    if trainer_config.max_grad_norm is not None:
+        torch.nn.utils.clip_grad_norm_(torch_model.parameters(), trainer_config.max_grad_norm)
+    torch_optimizer = torch.optim.AdamW(
+        torch_model.parameters(),
+        lr=trainer_config.learning_rate,
+        weight_decay=trainer_config.weight_decay,
+        betas=(trainer_config.beta1, trainer_config.beta2),
+        eps=trainer_config.epsilon,
+    )
+
+    torch_optimizer.step()
+
+    jax_optimizer = trainer_config.optimizer()
+    state = jax_optimizer.init(model)
+    updates, state = jax_optimizer.update(updates=jax_grad, state=state, params=model)
+    new_model = equinox.apply_updates(model, updates)
+
+    new_leaves = jax.tree_util.tree_leaves(new_model)
+    torch_dict = torch_model.transformer.state_dict(keep_vars=True)
+    old_leaves = jax.tree_util.tree_leaves(model)
+
+    # now compare new params
+    for leaf, key, old_leaf in zip(new_leaves, jax_grad_keys, old_leaves):
+        print(key)
+        torch_leaf = torch_dict[key]
+        # print the distance between the two
+        print(onp.linalg.norm(leaf - torch_leaf.detach().cpu().numpy(), ord=onp.inf))
+        print(onp.linalg.norm(old_leaf - torch_leaf.detach().cpu().numpy(), ord=onp.inf))
+        assert onp.isclose(
+            leaf, torch_leaf.detach().cpu().numpy(), rtol=1e-2, atol=1e-2
+        ).all(), f"{key}: {leaf} != {torch_leaf}"
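
Note (not part of the diff): the optimizer half of the test assumes that TrainerConfig.optimizer() builds an optax AdamW update rule equivalent to the torch.optim.AdamW call above. A minimal sketch of such a chain, with hypothetical function and argument names (the real builder lives in levanter.config), might look like:

    import optax

    def make_optimizer(learning_rate=1e-3, weight_decay=0.0, beta1=0.9, beta2=0.999,
                       epsilon=1e-8, max_grad_norm=1.0):
        # Mirror the torch side: clip by global norm, Adam scaling,
        # decoupled weight decay, then step against the gradient.
        return optax.chain(
            optax.clip_by_global_norm(max_grad_norm),
            optax.scale_by_adam(b1=beta1, b2=beta2, eps=epsilon),
            optax.add_decayed_weights(weight_decay),
            optax.scale(-learning_rate),
        )

With weight_decay=0.0 and no warmup, a single update from this chain should match one torch AdamW step to within the rtol/atol used in the final assertion.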