Create generate with obs dist function #49
base: main
Conversation
Pull Request Overview
This PR introduces a new `generate_with_obs_dist` method to both HMM implementations and adds tests to verify its output shapes.
- Implements `generate_with_obs_dist` in `GeneralizedHiddenMarkovModel` for batch-wise sequence generation with returned observation probabilities.
- Adds shape-based tests for `generate_with_obs_dist` in both `test_hidden_markov_model.py` and `test_generalized_hidden_markov_model.py`.
- Extends the existing HMM tests to ensure both state and observation outputs match expected dimensions.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| tests/generative_processes/test_hidden_markov_model.py | Added test_generate_with_obs_dist and related shape assertions |
| tests/generative_processes/test_generalized_hidden_markov_model.py | Parametrized test_generate_with_obs_dist for both model fixtures |
| simplexity/generative_processes/generalized_hidden_markov_model.py | Introduced generate_with_obs_dist method with vmapped vectorization |
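For orientation, here is a minimal sketch of the vmapped pattern described above, written for a plain HMM. The `transition_matrix` and `emission_matrix` parameters and the exact belief update are assumptions for illustration, not the actual simplexity implementation:

```python
import jax
import jax.numpy as jnp

def generate_with_obs_dist_single(transition_matrix, emission_matrix, state, key, sequence_len):
    """Generate one sequence from a per-sample state (num_states,) and key (2,)."""
    def step(belief, step_key):
        obs_probs = belief @ emission_matrix                 # (vocab_size,)
        obs = jax.random.categorical(step_key, jnp.log(obs_probs))
        posterior = belief * emission_matrix[:, obs]         # condition on the sampled obs
        next_belief = (posterior / posterior.sum()) @ transition_matrix
        return next_belief, (next_belief, obs, obs_probs)

    keys = jax.random.split(key, sequence_len)
    _, outputs = jax.lax.scan(step, state, keys)
    return outputs  # beliefs, observations, obs_probs; each with a leading time axis

# Vectorize the per-sample generator across the batch; model parameters are shared.
generate_with_obs_dist = jax.vmap(
    generate_with_obs_dist_single, in_axes=(None, None, 0, 0, None)
)
```

Callers of the vmapped wrapper then pass a `(batch_size, num_states)` array of states and a `(batch_size, 2)` array of keys, matching the shapes the new tests assert.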
Comments suppressed due to low confidence (1)
tests/generative_processes/test_hidden_markov_model.py:127
- Consider adding an assertion to verify that `intermediate_obs_probs` sums to 1 along the vocabulary axis (e.g., `jnp.allclose(intermediate_obs_probs.sum(-1), 1.0)`), ensuring they form valid probability distributions.
assert intermediate_obs_probs.shape == (batch_size, sequence_len, z1r.vocab_size)
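To make the suggested check concrete, here is a tiny standalone sketch; the array is a stand-in for the method's real output, and the names mirror the test's variables:

```python
import jax.numpy as jnp

# Stand-in for the method's output, shaped (batch, sequence, vocab).
batch_size, sequence_len, vocab_size = 4, 10, 2
intermediate_obs_probs = jnp.full((batch_size, sequence_len, vocab_size), 0.5)

# Existing shape assertion plus the suggested normalization check:
assert intermediate_obs_probs.shape == (batch_size, sequence_len, vocab_size)
assert jnp.allclose(intermediate_obs_probs.sum(-1), 1.0)
```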
| """Generate a batch of sequences of observations from the generative process. | ||
| Inputs: | ||
| state: (batch_size, num_states) | ||
| key: (batch_size, 2) | ||
| Returns: tuple of (belief states, observations, observation probabilities) where: |
Copilot AI commented on May 23, 2025:
[nitpick] The docstring describes batch input shapes, but this method is vmapped over a per-sample state and key. Clarify that inputs are single-example (no batch dim) and outputs are vectorized across the batch.
| """Generate a batch of sequences of observations from the generative process. | |
| Inputs: | |
| state: (batch_size, num_states) | |
| key: (batch_size, 2) | |
| Returns: tuple of (belief states, observations, observation probabilities) where: | |
| """Generate sequences of observations from the generative process. | |
| Inputs (per-sample, no batch dimension): | |
| state: (num_states,) | |
| key: (2,) | |
| Returns (vectorized across the batch): |
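To make the per-sample vs. batched distinction concrete, a small shape sketch (the names are illustrative, not the simplexity API):

```python
import jax
import jax.numpy as jnp

batch_size, num_states = 8, 3

# The inner (vmapped) function sees one state of shape (num_states,) and one
# key of shape (2,); callers of the batched wrapper pass one of each per sample.
states = jnp.broadcast_to(jnp.array([1.0, 0.0, 0.0]), (batch_size, num_states))
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)

print(states.shape, keys.shape)  # (8, 3) (8, 2)
```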
hrbigelow left a comment:
I did a basic wall-clock profiling comparison between `generate_with_obs_dist` and my original implementation, which doesn't use the simplexity API functions. I get roughly a 2.5x speedup: 18 seconds for my implementation versus 45 seconds for this one.

I believe this is mostly due to redundant calculations in your `transition_states` and `observation_probability_distribution` functions.

As we discussed, I'm okay with keeping this as a reference implementation for the sake of code clarity, but in that case I will probably stick with my side implementation.
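For illustration only (this is not hrbigelow's implementation or the simplexity code), here is a sketch of the kind of redundancy the comment points at: if the observation distribution and the state transition are computed by separate functions, each redoes the same contraction of the state against the transition tensor, whereas a fused step does it once.

```python
import jax.numpy as jnp

# Illustrative GHMM step. `transition_tensor` has shape
# (vocab_size, num_states, num_states); `state` has shape (num_states,).

def step_unfused(transition_tensor, state, obs):
    # Each function redoes the same state-times-tensor contraction.
    obs_probs = jnp.einsum("s,vst->v", state, transition_tensor)
    next_state = state @ transition_tensor[obs]
    return next_state / next_state.sum(), obs_probs

def step_fused(transition_tensor, state, obs):
    # Contract once, reuse for both the distribution and the update.
    propagated = jnp.einsum("s,vst->vt", state, transition_tensor)  # (vocab, states)
    obs_probs = propagated.sum(-1)
    next_state = propagated[obs]
    return next_state / next_state.sum(), obs_probs
```

In the fused version, `propagated` is computed once and shared between both outputs, which is one plausible source of the speedup described above.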
