Fix Mistral memory consumption with JAX and default dtype bug #1460

Merged · 3 commits · Feb 27, 2024
keras_nlp/models/mistral/mistral_causal_lm.py (1 addition, 0 deletions)
@@ -190,6 +190,7 @@ def next(prompt, cache, index):
                 mask=padding_mask,
                 end_token_id=end_token_id,
                 hidden_states=hidden_states,
+                model=self,
             )
 
         # Compute an output padding mask with the token ids we updated.
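A note on the one-line change above: it hands the sampler a reference to the model itself. Given the PR title, the relevant JAX behavior is that arrays a jitted function merely closes over are traced as compile-time constants and embedded in the compiled executable, while arrays passed as explicit inputs stay ordinary device buffers shared across calls. The sketch below illustrates that pitfall in plain JAX; params is a stand-in for the backbone weights, and this is an illustration of the mechanism suggested by the PR title, not the actual keras_nlp sampler plumbing.

import jax
import jax.numpy as jnp

params = jnp.ones((4096, 4096))  # stand-in for large model weights

# Captured weights are traced as compile-time constants, so they get
# baked into the compiled executable (and re-embedded on retrace).
@jax.jit
def step_captured(x):
    return x @ params

# Weights passed explicitly remain runtime inputs; the same device
# buffer is reused by every call.
@jax.jit
def step_explicit(params, x):
    return x @ params

x = jnp.ones((8, 4096))
y1 = step_captured(x)
y2 = step_explicit(params, x)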
keras_nlp/models/mistral/mistral_presets.py (2 additions, 2 deletions)
@@ -23,7 +23,7 @@
"path": "mistral",
"model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
},
"kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/3",
"kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/6",
},
"mistral_instruct_7b_en": {
"metadata": {
@@ -33,6 +33,6 @@
"path": "mistral",
"model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
},
"kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/3",
"kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/6",
},
}
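The two handle bumps above repoint both presets at version 6 of the Kaggle artifacts, so the re-converted weights reach users through the normal loading path. A quick smoke test, assuming a keras_nlp build that ships these preset handles:

import keras_nlp

# Downloads the re-converted weights behind the updated handle
# kaggle://keras/mistral/keras/mistral_7b_en/6.
mistral_lm = keras_nlp.models.MistralCausalLM.from_preset("mistral_7b_en")
print(mistral_lm.generate("What is Keras?", max_length=64))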
tools/checkpoint_conversion/convert_mistral_checkpoints.py (69 additions, 121 deletions)
@@ -11,14 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import datetime
 import gc
 import json
 import os
-import pathlib
+import shutil
+import tempfile
-import traceback
 
-import keras
 import numpy as np
 import requests
 from absl import app
@@ -27,10 +25,10 @@
 from transformers import AutoTokenizer
 from transformers import MistralForCausalLM
 
-import keras_nlp
 from keras_nlp.models import MistralBackbone
 from keras_nlp.models import MistralCausalLMPreprocessor
 from keras_nlp.models import MistralTokenizer
+from keras_nlp.utils.preset_utils import save_to_preset
 
 PRESET_MAP = {
     "mistral_7b_en": "mistralai/Mistral-7B-v0.1",
@@ -227,124 +225,74 @@ def main(_):
     preset = FLAGS.preset
     hf_preset = PRESET_MAP[preset]
 
-    # === Create the save directories ===
-    model_dir = pathlib.Path(__file__).parent / f"{preset}"
-    tokenizer_dir = model_dir / "assets" / "tokenizer"
-    if not model_dir.exists():
-        os.makedirs(model_dir)
-    if not tokenizer_dir.exists():
-        os.makedirs(tokenizer_dir)
+    # === Create the temporary save directories ===
+    temp_dir = tempfile.mkdtemp()

-    # === Load the Huggingface model ===
-    hf_model = MistralForCausalLM.from_pretrained(hf_preset)
-    hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset)
-    hf_model.eval()
-    print("\n-> Huggingface model and tokenizer loaded")
-
-    # === Load the KerasNLP model ===
-    keras_nlp_config = dict(
-        vocabulary_size=hf_model.config.vocab_size,
-        hidden_dim=hf_model.config.hidden_size,
-        num_layers=hf_model.config.num_hidden_layers,
-        num_query_heads=hf_model.config.num_attention_heads,
-        num_key_value_heads=hf_model.config.num_key_value_heads,
-        intermediate_dim=hf_model.config.intermediate_size,
-        sliding_window=hf_model.config.sliding_window,
-        layer_norm_epsilon=hf_model.config.rms_norm_eps,
-        rope_max_wavelength=hf_model.config.rope_theta,
-        dtype="float32",
-    )
-    keras_nlp_model = MistralBackbone(**keras_nlp_config)
-
-    # === Download the tokenizer from Huggingface model card ===
-    spm_path = (
-        f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model"
-    )
-    response = requests.get(spm_path)
-    if not response.ok:
-        raise ValueError(f"Couldn't fetch {preset}'s tokenizer.")
-    tokenizer_path = tokenizer_dir / "vocabulary.spm"
-    with open(tokenizer_path, "wb") as tokenizer_file:
-        tokenizer_file.write(response.content)
-    keras_nlp_tokenizer = MistralTokenizer(str(tokenizer_path.absolute()))
-    print("\n-> Keras 3 model and tokenizer loaded.")
-
-    # === Port the weights ===
-    convert_checkpoints(keras_nlp_model, hf_model)
-    print("\n-> Weight transfer done.")
-
-    # === Check that the models and tokenizers outputs match ===
-    test_tokenizer(keras_nlp_tokenizer, hf_tokenizer)
-    test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer)
-    print("\n-> Tests passed!")
-
-    # === Save the model weights in float32 format ===
-    keras_nlp_model.save_weights(
-        str((model_dir / "model.weights.h5").absolute())
-    )
-    print("\n-> Saved the model weights in float16")
-
-    del keras_nlp_model, hf_model
-    gc.collect()
-
-    keras_nlp_config["dtype"] = "float16"
-
-    # === Save the weights again in float16 ===
-    keras_nlp_model = MistralBackbone(**keras_nlp_config)
-    keras_nlp_model.load_weights(
-        str((model_dir / "model.weights.h5").absolute())
-    )
-    keras_nlp_model.save_weights(
-        str((model_dir / "model.weights.h5").absolute())
-    )
-    print("-> Saved the model weights in float16")
-
-    # === Save the model config ===
-    keras_nlp_config["dtype"] = "bfloat16"
-    model_config = {
-        "module": "keras_nlp.src.models.mistral.mistral_backbone",
-        "class_name": "MistralBackbone",
-        "config": {**keras_nlp_config},
-        "registered_name": "keras_nlp>MistralBackbone",
-        "assets": [],
-        "weights": "model.weights.h5",
-    }
-    model_config_json = json.dumps(model_config)
-    with open(model_dir / "config.json", "w") as model_config_file:
-        model_config_file.write(model_config_json)
-    print("\n-> Saved model config")
-
-    # === Save the tokenizer config ===
-    tokenizer_config = {
-        "module": "keras_nlp.src.models.mistral.Mistral_tokenizer",
-        "class_name": "MistralTokenizer",
-        "config": {
-            "name": "mistral_tokenizer",
-            "trainable": True,
-            "dtype": "int32",
-            "proto": None,
-            "sequence_length": None,
-        },
-        "registered_name": "keras_nlp>MistralTokenizer",
-        "assets": ["assets/tokenizer/vocabulary.spm"],
-        "weights": None,
-    }
-    tokenizer_config_json = json.dumps(tokenizer_config)
-    with open(model_dir / "tokenizer.json", "w") as tokenizer_config_file:
-        tokenizer_config_file.write(tokenizer_config_json)
-    print("\n-> Saved tokenizer config")
+    try:
+        # === Load the Huggingface model ===
+        hf_model = MistralForCausalLM.from_pretrained(hf_preset)
+        hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset)
+        hf_model.eval()
+        print("\n-> Huggingface model and tokenizer loaded")
+
+        # === Load the KerasNLP model ===
+        backbone_kwargs = dict(
+            vocabulary_size=hf_model.config.vocab_size,
+            hidden_dim=hf_model.config.hidden_size,
+            num_layers=hf_model.config.num_hidden_layers,
+            num_query_heads=hf_model.config.num_attention_heads,
+            num_key_value_heads=hf_model.config.num_key_value_heads,
+            intermediate_dim=hf_model.config.intermediate_size,
+            sliding_window=hf_model.config.sliding_window,
+            layer_norm_epsilon=hf_model.config.rms_norm_eps,
+            rope_max_wavelength=hf_model.config.rope_theta,
+            dtype="float32",
+        )
+        keras_nlp_model = MistralBackbone(**backbone_kwargs)
+
-    # === Save metadata ===
-    metadata_config = {
-        "keras_version": keras.__version__,
-        "keras_nlp_version": keras_nlp.__version__,
-        "parameter_count": keras_nlp_model.count_params(),
-        "date_saved": datetime.datetime.utcnow().strftime("%Y-%m-%d@%H:%M:%S"),
-    }
-    metadata_config_json = json.dumps(metadata_config)
-    with open(model_dir / "metadata.json", "w") as metadata_config_file:
-        metadata_config_file.write(metadata_config_json)
-    print("\n-> Saved metadata")
+        # === Download the tokenizer from Huggingface model card ===
+        spm_path = (
+            f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model"
+        )
+        response = requests.get(spm_path)
+        if not response.ok:
+            raise ValueError(f"Couldn't fetch {preset}'s tokenizer.")
+        tokenizer_path = os.path.join(temp_dir, "vocabulary.spm")
+        with open(tokenizer_path, "wb") as tokenizer_file:
+            tokenizer_file.write(response.content)
+        keras_nlp_tokenizer = MistralTokenizer(tokenizer_path)
+        print("\n-> Keras 3 model and tokenizer loaded.")
+
+        # === Port the weights ===
+        convert_checkpoints(keras_nlp_model, hf_model)
+        print("\n-> Weight transfer done.")
+
+        # === Check that the models and tokenizers outputs match ===
+        test_tokenizer(keras_nlp_tokenizer, hf_tokenizer)
+        test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer)
+        print("\n-> Tests passed!")
+
+        # === Save the model weights in float32 format ===
+        keras_nlp_model.save_weights(os.path.join(temp_dir, "model.weights.h5"))
+        print("\n-> Saved the model weights in float32")
+
+        del keras_nlp_model, hf_model
+        gc.collect()
+
+        # === Save the weights again in float16 ===
+        backbone_kwargs["dtype"] = "float16"
+        keras_nlp_model = MistralBackbone(**backbone_kwargs)
+        keras_nlp_model.load_weights(os.path.join(temp_dir, "model.weights.h5"))
+        save_to_preset(keras_nlp_model, preset)
+        print("\n-> Saved the model preset in float16")
+
+        # === Save the tokenizer ===
+        save_to_preset(
+            keras_nlp_tokenizer, preset, config_filename="tokenizer.json"
+        )
+        print("\n-> Saved the tokenizer")
+    finally:
+        shutil.rmtree(temp_dir)


if __name__ == "__main__":
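The reworked script converts inside a throwaway temp directory and publishes through save_to_preset, and its dtype handling is a two-pass round trip: build the backbone in float32 so the weight transfer is faithful, then rebuild the same architecture in float16 and reload before saving (the script itself is driven by a --preset flag, per FLAGS.preset above). A standalone sketch of that round trip on a toy Keras model; the model and file names here are illustrative, not part of the script:

import keras

def build_model(dtype):
    # Same architecture both times; only the dtype policy differs.
    model = keras.Sequential([keras.layers.Dense(2, dtype=dtype)])
    model.build((None, 4))
    return model

full = build_model("float32")
full.save_weights("toy.weights.h5")

# load_weights assigns into float16 variables, casting the float32
# values on the way in, so the half-precision preset can be saved
# without keeping a float32 copy of the model resident.
half = build_model("float16")
half.load_weights("toy.weights.h5")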