Add scoring mode to MistralCausalLM #1521

Merged
merged 5 commits into from
Mar 25, 2024

@RyanMullins RyanMullins commented Mar 22, 2024

Adds the .score() function introduced with Gemma (#1448) to the MistralCausalLM model class. As with Gemma, this function supports a variety of interpretability use cases with Mistral by providing an API that scores generated sequences (returning logits or loss) with gradient tracking enabled. Use cases include salience maps, patching, and training data attribution.

This is a direct port of the implementation and tests from Gemma, so hopefully that helps ease the review process.
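
For readers unfamiliar with the two scoring outputs mentioned above (logits or loss), here is a minimal NumPy sketch of the difference. It is not the code from this PR and does not call keras_nlp. The function name `score_sketch`, its parameters, and the toy shapes are all hypothetical, chosen only for illustration. NumPy also does not track gradients, so the gradient-tracking aspect is not shown.

```python
import numpy as np


def score_sketch(logits, target_ids=None, scoring_mode="logits"):
    """Hypothetical stand-in for a scoring function (not the keras_nlp API).

    logits: array of shape (batch, seq_len, vocab_size).
    target_ids: integer array of shape (batch, seq_len), needed for "loss".
    """
    logits = np.asarray(logits, dtype=np.float64)
    if scoring_mode == "logits":
        # "logits" mode: return the raw per-token scores over the vocabulary.
        return logits
    if scoring_mode != "loss":
        raise ValueError(f"Unknown scoring_mode: {scoring_mode!r}")
    if target_ids is None:
        raise ValueError("target_ids is required when scoring_mode='loss'")

    target_ids = np.asarray(target_ids)
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # "loss" mode: per-token cross-entropy, i.e. the negative log-probability
    # assigned to each target token.
    picked = np.take_along_axis(log_probs, target_ids[..., None], axis=-1)
    return -picked[..., 0]


# Toy input: batch of 1, sequence of 3 tokens, vocabulary of 4 (assumed sizes).
logits = np.zeros((1, 3, 4))
logits[0, 0, 2] = 5.0  # the first position strongly favors token 2
target_ids = np.array([[2, 1, 0]])

per_token_logits = score_sketch(logits)
per_token_loss = score_sketch(logits, target_ids, scoring_mode="loss")
```

In this sketch the loss output has one value per token, which is the kind of per-position signal that salience maps and attribution methods differentiate with respect to the inputs.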

@github-actions github-actions bot added the Gemma Gemma model specific issues label Mar 22, 2024
@RyanMullins RyanMullins marked this pull request as ready for review March 22, 2024 19:48
@mattdangerw mattdangerw added the kokoro:force-run Runs Tests on GPU label Mar 23, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Mar 23, 2024

@mattdangerw mattdangerw left a comment

LGTM thanks! Just see one tiny nit.

keras_nlp/models/mistral/mistral_causal_lm.py (outdated, resolved)
@mattdangerw

Thank you!

@mattdangerw mattdangerw merged commit 8c189ce into keras-team:master Mar 25, 2024
7 checks passed
abuelnasr0 pushed a commit to abuelnasr0/keras-nlp that referenced this pull request Apr 2, 2024
* Add scoring mode to MistralCausalLM

* Fixing names in Docstring

* Fix padding mask arg name

* Fix embedded shape in test

* Remove errant underscore in Docstring