
Conversation

@ealt ealt (Collaborator) commented May 23, 2025

No description provided.

@ealt ealt requested a review from hrbigelow May 23, 2025 23:37
@ealt ealt marked this pull request as ready for review May 23, 2025 23:39
Copilot AI review requested due to automatic review settings May 23, 2025 23:39

Copilot AI left a comment


Pull Request Overview

This PR introduces a new generate_with_obs_dist method to both HMM implementations and adds tests to verify its output shapes.

  • Implements generate_with_obs_dist in GeneralizedHiddenMarkovModel for batch-wise sequence generation with returned observation probabilities.
  • Adds shape-based tests for generate_with_obs_dist in both test_hidden_markov_model.py and test_generalized_hidden_markov_model.py.
  • Extends the existing HMM tests to ensure both state and observation outputs match expected dimensions.
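The PR's actual implementation is not shown in this thread; as a rough illustration of the pattern the overview describes (per-sample generation vmapped across a batch, returning per-step observation distributions), a sketch under assumed names and shapes might look like:

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch -- names, signatures, and shapes are assumptions,
# not the PR's actual code.
def generate_with_obs_dist(transition_matrices, state, key, sequence_len):
    """Generate one sequence plus its per-step observation distributions.

    transition_matrices: (vocab_size, num_states, num_states); summing over
        the vocab axis yields a row-stochastic transition matrix.
    state: (num_states,) belief state for a single sample (no batch dim).
    key: PRNG key for a single sample.
    """
    def step(carry, _):
        state, key = carry
        key, subkey = jax.random.split(key)
        # Observation distribution implied by the current belief state.
        obs_probs = jnp.einsum("onm,n->o", transition_matrices, state)
        obs_probs = obs_probs / obs_probs.sum()
        obs = jax.random.categorical(subkey, jnp.log(obs_probs))
        # Belief-state update given the sampled observation.
        new_state = state @ transition_matrices[obs]
        new_state = new_state / new_state.sum()
        return (new_state, key), (new_state, obs, obs_probs)

    _, (states, obs, obs_probs) = jax.lax.scan(
        step, (state, key), None, length=sequence_len
    )
    return states, obs, obs_probs

# vmap turns the per-sample function into a batched one.
batched_generate = jax.vmap(generate_with_obs_dist, in_axes=(None, 0, 0, None))

# Tiny demo: vocab_size=2, num_states=2, batch of 3, sequences of length 5.
T = jnp.array([[[0.5, 0.0], [0.0, 0.5]],
               [[0.0, 0.5], [0.5, 0.0]]])
init_states = jnp.full((3, 2), 0.5)
keys = jax.random.split(jax.random.PRNGKey(0), 3)
states, obs, obs_probs = batched_generate(T, init_states, keys, 5)
```

The shape-based tests described above would then assert that `states`, `obs`, and `obs_probs` come back as (batch_size, sequence_len, …) arrays.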

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

  • tests/generative_processes/test_hidden_markov_model.py — Added test_generate_with_obs_dist and related shape assertions
  • tests/generative_processes/test_generalized_hidden_markov_model.py — Parametrized test_generate_with_obs_dist for both model fixtures
  • simplexity/generative_processes/generalized_hidden_markov_model.py — Introduced generate_with_obs_dist method with vmapped vectorization
Comments suppressed due to low confidence (1)

tests/generative_processes/test_hidden_markov_model.py:127

  • Consider adding an assertion to verify that intermediate_obs_probs sums to 1 along the vocabulary axis (e.g., jnp.allclose(intermediate_obs_probs.sum(-1), 1.0)), ensuring they form valid probability distributions.
assert intermediate_obs_probs.shape == (batch_size, sequence_len, z1r.vocab_size)
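The suggested normalization check, sketched with hypothetical NumPy data standing in for the real model output (Dirichlet samples are valid probability distributions by construction):

```python
import numpy as np

# Hypothetical stand-in for intermediate_obs_probs; shapes are assumptions.
batch_size, sequence_len, vocab_size = 2, 4, 3
rng = np.random.default_rng(0)
intermediate_obs_probs = rng.dirichlet(
    np.ones(vocab_size), size=(batch_size, sequence_len)
)

assert intermediate_obs_probs.shape == (batch_size, sequence_len, vocab_size)
# The suggested check: each per-step distribution sums to 1 over the vocab axis.
assert np.allclose(intermediate_obs_probs.sum(-1), 1.0)
```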

Comment on lines +75 to +80
"""Generate a batch of sequences of observations from the generative process.
Inputs:
state: (batch_size, num_states)
key: (batch_size, 2)
Returns: tuple of (belief states, observations, observation probabilities) where:

Copilot AI May 23, 2025


[nitpick] The docstring describes batch input shapes, but this method is vmapped over a per-sample state and key. Clarify that inputs are single-example (no batch dim) and outputs are vectorized across the batch.

Suggested change
"""Generate a batch of sequences of observations from the generative process.
Inputs:
state: (batch_size, num_states)
key: (batch_size, 2)
Returns: tuple of (belief states, observations, observation probabilities) where:
"""Generate sequences of observations from the generative process.
Inputs (per-sample, no batch dimension):
state: (num_states,)
key: (2,)
Returns (vectorized across the batch):

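The distinction the reviewer is drawing — per-sample inputs, batched outputs — is the standard jax.vmap contract. A toy illustration (all names here are hypothetical, not the PR's code):

```python
import jax
import jax.numpy as jnp

# The wrapped function sees a single sample (no batch dim);
# vmap supplies the batch dimension.
def per_sample(state, key):
    # state: (num_states,), key: a single PRNG key
    return state * jax.random.uniform(key)

batched = jax.vmap(per_sample)
states = jnp.ones((8, 4))                          # (batch_size, num_states)
keys = jax.random.split(jax.random.PRNGKey(0), 8)  # (batch_size, 2)
out = batched(states, keys)                        # (batch_size, num_states)
```

This is why the docstring should describe the per-sample shapes even though callers see batched arrays.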
@hrbigelow hrbigelow (Collaborator) left a comment


I did a basic wall-clock profiling comparison between generate_with_obs_dist and my original implementation, which does not use the simplexity API functions. I see roughly a 2.5x speedup: 18 seconds for my implementation versus 45 seconds for this one.

I believe this is mostly due to the redundant calculations in your transition_states and observation_probability_distribution functions.

As we discussed, I'm okay if you want to keep this as a reference implementation for the sake of code clarity, but in that case I will probably stick with my side implementation.

https://github.com/Astera-org/simplex-research/blob/henry/mods/in-context-learning/tests/process_gen_profile.py

[attached image: profiling comparison results]
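The actual benchmark is the linked process_gen_profile.py; a generic best-of-N wall-clock pattern for this kind of comparison might look like the sketch below (note the caveat about JAX's asynchronous dispatch):

```python
import time

def wall_clock(fn, *args, repeats=3):
    """Best-of-N wall-clock time for fn(*args), in seconds.

    For JAX code, fn should call .block_until_ready() on its outputs;
    otherwise asynchronous dispatch makes the timing meaningless.
    """
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - start)
    return best

# Trivial demo with a stdlib call in place of the real generation functions.
elapsed = wall_clock(sum, range(100_000))
```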
