
Mistral kills the process by taking too much RAM #1458

Closed
deep-diver opened this issue Feb 22, 2024 · 2 comments · Fixed by #1460
Labels
type:Bug Something isn't working

Comments

@deep-diver

preprocessor = keras_nlp.models.MistralCausalLMPreprocessor.from_preset(
    "mistral_instruct_7b_en",
    sequence_length=128,
)
mistral_lm = keras_nlp.models.MistralCausalLM.from_preset(
    "mistral_instruct_7b_en", preprocessor=preprocessor
)

output = mistral_lm.generate("My trip to Yosemite was", max_length=64)
print("\nMistral output:")
print(output)
[Screenshot: Colab runtime crash after RAM usage spikes during generation]

I was running the Mistral model in a Colab environment with an A100 (40 GB) and 80 GB of RAM. The model loaded successfully. However, when generating text, RAM usage spiked and the runtime restarted.

Is this expected behavior, or could there be a bug?

@deep-diver deep-diver added the type:Bug Something isn't working label Feb 22, 2024
@tirthasheshpatel
Contributor

Is this expected behavior, or could there be a bug?

Mistral loads in bfloat16 by default. I noticed this causes issues in Colab with the JAX backend (TensorFlow runs fine, using about 16.5 GB of RAM).

The workaround is to pass dtype=None to the from_preset method, which makes the model use the dtype set via the keras.mixed_precision module:

mistral_lm = keras_nlp.models.MistralCausalLM.from_preset('mistral_instruct_7b_en', preprocessor=preprocessor, dtype=None)

Let me know if this lowers the RAM usage. This will be fixed in the next release.
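To see why the load dtype matters so much here, a rough weights-only memory estimate for a 7B-parameter model (this ignores the KV cache, activations, and framework overhead, which is where the extra JAX-backend usage comes from):

```python
# Rough weights-only memory estimate for a 7B-parameter model.
# Ignores KV cache, activations, and framework overhead.
PARAMS = 7_000_000_000

bytes_per_param = {"float32": 4, "bfloat16": 2, "float16": 2}
for dtype, nbytes in bytes_per_param.items():
    gib = PARAMS * nbytes / 2**30
    print(f"{dtype}: ~{gib:.1f} GiB")
```

Weights alone come to roughly 26 GiB in float32 versus roughly 13 GiB in bfloat16, so either should fit in 80 GB of host RAM; the crash points to memory consumed beyond the weights themselves.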

@mattdangerw
Member

Thanks for the bug report! Just synced up with @tirthasheshpatel. We want to change two things here:

  1. By default, Mistral should follow the global Keras settings: keras.mixed_precision.set_global_policy("mixed...") -> variables load as float32; keras.config.set_floatx("bfloat16") -> variables load as bfloat16.
  2. There is a bug, on the JAX backend only, where generation with Mistral consumes far too much CPU and GPU memory. It should be a one-line fix on our side.
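Point 1 amounts to a small dtype-resolution rule. A minimal sketch in plain Python (resolve_load_dtype is a hypothetical illustration of the proposed behavior, not a KerasNLP API):

```python
def resolve_load_dtype(explicit_dtype=None, global_policy=None, floatx="float32"):
    """Hypothetical sketch of the proposed default-dtype behavior.

    - An explicit dtype passed to from_preset always wins.
    - Under a mixed-precision policy, variables load as float32.
    - Otherwise, fall back to the global floatx setting.
    """
    if explicit_dtype is not None:
        return explicit_dtype
    if global_policy is not None and global_policy.startswith("mixed_"):
        return "float32"
    return floatx
```

For example, resolve_load_dtype(global_policy="mixed_bfloat16") returns "float32", while resolve_load_dtype(floatx="bfloat16") returns "bfloat16".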

These are both simple but important fixes; we should have a patch release out in a couple of days. Thanks @deep-diver!
