diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py
index ab4971eef..254377578 100644
--- a/MaxText/layers/linears.py
+++ b/MaxText/layers/linears.py
@@ -433,7 +433,7 @@ def wrapper(x, logits, w0, w1, wo):
       layer_w0 = gmm(x, w0, group_sizes)
       layer_w0 = checkpoint_name(layer_w0, "mlpwi_0")
       layer_w1 = gmm(x, w1, group_sizes)
-      layer_w1 = checkpoint_name(layer_w0, "mlpwi_1")
+      layer_w1 = checkpoint_name(layer_w1, "mlpwi_1")
       layer_act = _convert_to_activation_function(self.config.mlp_activations[0])(layer_w0)
       intermediate_layer = jnp.multiply(layer_act, layer_w1)
       intermediate_output = gmm(intermediate_layer, wo, group_sizes)
@@ -604,6 +604,7 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel):
         )
         return output, loss
       else:
+        top_k_weights /= top_k_weights.sum(-1, keepdims=True)
         weights = self.reshape_and_update_weights(top_k_weights, top_k_indices)
         inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed"))
         with jax.named_scope("wi_0"):
diff --git a/MaxText/scratch_code/mixtral-numerical-verification.ipynb b/MaxText/scratch_code/mixtral-numerical-verification.ipynb
new file mode 100644
index 000000000..61d05d66e
--- /dev/null
+++ b/MaxText/scratch_code/mixtral-numerical-verification.ipynb
@@ -0,0 +1,248 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bce1951a-8eef-4842-a70f-987b85a3240f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# installation\n",
+    "!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n",
+    "!pip3 install tokenizers -U\n",
+    "!pip3 install transformers -U"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9769e847-d838-473d-8d32-1061b3e0f1c8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# go to maxtext/MaxText for library import\n",
+    "\n",
+    "current_dir = %pwd\n",
+    "working_dir = current_dir.replace(\"scratch_code\", \"\")\n",
+    "%cd $working_dir"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f1c108fc-d739-471d-9c64-c08151845f06",
+   "metadata": {},
+   "source": [
+    "# one layer mixtral model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "cf8eee59-295e-41f4-8c09-d2177b410ddc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pyconfig\n",
+    "from transformers.models.mixtral.configuration_mixtral import MixtralConfig\n",
+    "\n",
+    "pyconfig.initialize(\n",
+    "    [None, \"configs/base.yml\"],\n",
+    "    base_emb_dim=4096,\n",
+    "    base_num_query_heads=32,\n",
+    "    base_num_kv_heads=8,\n",
+    "    base_mlp_dim=14336,\n",
+    "    base_num_decoder_layers=1,  # 1 layer for simplicity\n",
+    "    head_dim=128,\n",
+    "    mlp_activations=[\"silu\",\"linear\"],\n",
+    "    vocab_size=32000,\n",
+    "    enable_dropout=False,\n",
+    "    logits_via_embedding=False,\n",
+    "    normalization_layer_epsilon=1.0e-5,\n",
+    "    num_experts=8,\n",
+    "    num_experts_per_tok=2,\n",
+    "    rope_max_timescale=1_000_000,\n",
+    "    decoder_block=\"mistral\",\n",
+    "    run_name=\"moe_test\",\n",
+    "    enable_checkpointing=False,\n",
+    "    dtype=\"bfloat16\",\n",
+    "    weight_dtype=\"bfloat16\",\n",
+    "    megablox=True,  # or False\n",
+    "    max_target_length=4,\n",
+    "    per_device_batch_size=1,\n",
+    "    capacity_factor=-1,\n",
+    "    scan_layers=False,\n",
+    ")\n",
+    "config_maxtext = pyconfig.config\n",
+    "\n",
+    "config_hf = MixtralConfig(\n",
+    "    vocab_size=config_maxtext.vocab_size,\n",
+    "    hidden_size=config_maxtext.emb_dim,\n",
+    "    intermediate_size=config_maxtext.mlp_dim,\n",
+    "    num_hidden_layers=config_maxtext.num_decoder_layers,\n",
+    "    num_attention_heads=config_maxtext.base_num_query_heads,\n",
+    "    num_key_value_heads=config_maxtext.num_kv_heads,\n",
+    "    rms_norm_eps=config_maxtext.normalization_layer_epsilon,\n",
+    "    rope_theta=config_maxtext.rope_max_timescale,\n",
+    "    attention_dropout=0.0,\n",
+    "    num_experts_per_tok=config_maxtext.num_experts_per_tok,\n",
+    "    num_local_experts=config_maxtext.num_experts,\n",
+    "    tie_word_embeddings=config_maxtext.logits_via_embedding,\n",
+    "    output_router_logits=False,\n",
+    "    router_aux_loss_coef=0.001,\n",
+    "    router_jitter_noise=0.0,\n",
+    "    torch_dtype=\"bfloat16\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c94c857a-2efd-48f3-9669-aef926329cbd",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import AutoModelForCausalLM, set_seed\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "from layers.models import Transformer\n",
+    "import max_utils\n",
+    "from jax.sharding import Mesh\n",
+    "\n",
+    "# ensure the same model initialization\n",
+    "set_seed(0)\n",
+    "\n",
+    "model_hf = AutoModelForCausalLM.from_config(config_hf)\n",
+    "\n",
+    "devices_array = max_utils.create_device_mesh(config_maxtext)\n",
+    "mesh = Mesh(devices_array, config_maxtext.mesh_axes)\n",
+    "prng_key = jax.random.PRNGKey(1234)\n",
+    "model_maxtext = Transformer(config=config_maxtext, mesh=mesh, quant=None)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "707df022-ec37-44b3-b203-5f938151c6ca",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "input_np = {\n",
+    "    'inputs': np.random.randint(0, config_maxtext.vocab_size, size=(int(config_maxtext.per_device_batch_size), config_maxtext.max_target_length)),\n",
+    "    'inputs_position': np.tile(np.arange(config_maxtext.max_target_length), (int(config_maxtext.per_device_batch_size), 1)),\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "baca50fb-28f2-48b1-b4f5-0145ac6cfe38",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "state_maxtext = model_maxtext.init({'params': prng_key, 'dropout': prng_key, 'aqt': prng_key},\n",
+    "                                   jnp.array(input_np['inputs']),\n",
+    "                                   jnp.array(input_np['inputs_position']),\n",
+    "                                   enable_dropout=config_maxtext.enable_dropout,\n",
+    "                                   )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "74e8353b-b87a-4c5e-9a7c-138052249250",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from flax import linen as nn\n",
+    "\n",
+    "state_map = {\n",
+    "    \"['params']['decoder']['decoder_norm']['scale'].value\": (\"model.norm.weight\", lambda x: x),\n",
+    "    \"['params']['decoder']['layers_0']['MoeBlock_0']['gate']['kernel'].value\": (\"model.layers.0.block_sparse_moe.gate.weight\", lambda x: x.T),\n",
+    "    \"['params']['decoder']['layers_0']['MoeBlock_0']['wi_0'].value\": (\"model.layers.0.block_sparse_moe.experts.<exp>.w1.weight\", lambda *x: torch.stack(*x, dim=0).transpose(1,2)),\n",
+    "    \"['params']['decoder']['layers_0']['MoeBlock_0']['wi_1'].value\": (\"model.layers.0.block_sparse_moe.experts.<exp>.w3.weight\", lambda *x: torch.stack(*x, dim=0).transpose(1,2)),\n",
+    "    \"['params']['decoder']['layers_0']['MoeBlock_0']['wo'].value\": (\"model.layers.0.block_sparse_moe.experts.<exp>.w2.weight\", lambda *x: torch.stack(*x, dim=0).transpose(1,2)),\n",
+    "    \"['params']['decoder']['layers_0']['post_self_attention_layer_norm']['scale'].value\": (\"model.layers.0.post_attention_layernorm.weight\", lambda x: x),\n",
+    "    \"['params']['decoder']['layers_0']['pre_self_attention_layer_norm']['scale'].value\": (\"model.layers.0.input_layernorm.weight\", lambda x: x),\n",
+    "    \"['params']['decoder']['layers_0']['self_attention']['key']['kernel'].value\": (\"model.layers.0.self_attn.k_proj.weight\", lambda x: x.T.reshape(config_hf.hidden_size, config_hf.num_key_value_heads, config_maxtext.head_dim)),\n",
+    "    \"['params']['decoder']['layers_0']['self_attention']['out']['kernel'].value\": (\"model.layers.0.self_attn.o_proj.weight\", lambda x: x.T.reshape(config_hf.num_attention_heads, config_maxtext.head_dim, config_hf.hidden_size)),\n",
+    "    \"['params']['decoder']['layers_0']['self_attention']['query']['kernel'].value\": (\"model.layers.0.self_attn.q_proj.weight\", lambda x: x.T.reshape(config_hf.hidden_size, config_hf.num_attention_heads, config_maxtext.head_dim) / np.sqrt(config_maxtext.head_dim)),\n",
+    "    \"['params']['decoder']['layers_0']['self_attention']['value']['kernel'].value\": (\"model.layers.0.self_attn.v_proj.weight\", lambda x: x.T.reshape(config_hf.hidden_size, config_hf.num_key_value_heads, config_maxtext.head_dim)),\n",
+    "    \"['params']['decoder']['logits_dense']['kernel'].value\": (\"lm_head.weight\", lambda x: x.T),\n",
+    "    \"['params']['token_embedder']['embedding'].value\": (\"model.embed_tokens.weight\", lambda x: x),\n",
+    "  }\n",
+    "\n",
+    "state_hf = model_hf.state_dict()\n",
+    "def map_fn(key_path, value):\n",
+    "    key_path_str = jax.tree_util.keystr(key_path)\n",
+    "    torch_key, transform_fn = state_map[key_path_str]\n",
+    "    if \"<exp>\" in torch_key:\n",
+    "        torch_tensors = [state_hf[torch_key.replace(\"<exp>\", str(i))] for i in range(config_hf.num_local_experts)]\n",
+    "    else:\n",
+    "        torch_tensors = state_hf[torch_key]\n",
+    "    \n",
+    "    torch_tensors = transform_fn(torch_tensors)\n",
+    "\n",
+    "    assert value.shape == torch_tensors.shape, f\"{key_path_str}, {value.shape}, {torch_tensors.shape}\"\n",
+    "    new_value = jnp.array(torch_tensors.to(torch.float32).numpy(), dtype=value.dtype)\n",
+    "    if isinstance(value, nn.LogicallyPartitioned):\n",
+    "        new_value = value.replace_boxed(new_value)\n",
+    "    return new_value\n",
+    "\n",
+    "loaded_state_maxtext = jax.tree_util.tree_map_with_path(map_fn, state_maxtext)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d1f88708-c3a6-4b95-bc51-94adfebdf2aa",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "logits_hf = model_hf(torch.from_numpy(input_np['inputs'])).logits.detach()\n",
+    "\n",
+    "logits_maxtext = model_maxtext.apply(\n",
+    "    loaded_state_maxtext,\n",
+    "    input_np['inputs'],\n",
+    "    input_np['inputs_position'],\n",
+    "    enable_dropout=False,\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1207375a-b92c-4a8c-975a-21f2f027d91e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# currently, this test passes with both \"megablox=True\" and \"megablox=False capacity_factor=-1\"\n",
+    "\n",
+    "np.testing.assert_allclose(np.array(logits_maxtext), logits_hf.numpy(), rtol=1e-1, atol=1e-1)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
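
Note (not part of the patch): the `top_k_weights /= top_k_weights.sum(-1, keepdims=True)` line added to `dense_matmul` renormalizes the selected top-k router probabilities so that the expert weights for each token sum to 1, which is how the Hugging Face Mixtral router behaves and why the notebook's logits comparison lines up. A minimal standalone sketch of that behavior, using made-up shapes (batch=2, seq=4, 8 experts, top-2):

    import jax
    import jax.numpy as jnp

    gate_logits = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8))    # [batch, seq, num_experts]
    probs = jax.nn.softmax(gate_logits, axis=-1)                         # router probabilities over all experts
    top_k_weights, top_k_indices = jax.lax.top_k(probs, k=2)             # keep the top-2 experts per token
    top_k_weights = top_k_weights / top_k_weights.sum(-1, keepdims=True) # renormalize so the kept weights sum to 1
    print(top_k_weights.sum(-1))                                         # ~1.0 for every token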