diff --git a/simplexity/generative_processes/arithmetic_process.py b/simplexity/generative_processes/arithmetic_process.py index c6c2dd49..7f3a3c30 100644 --- a/simplexity/generative_processes/arithmetic_process.py +++ b/simplexity/generative_processes/arithmetic_process.py @@ -92,18 +92,116 @@ def child_sub_equation(self, sub_equation: jax.Array) -> tuple[int, jax.Array]: def full_equation(self, sub_equation: jax.Array, n: int, sequence_len: int) -> jax.Array: """Produce a random RPN sequence.""" + # Maximum number of iterations we expect + max_iterations = 10 + + # Precompute all child sub-equations using jax.lax.scan + def scan_fn(carry, _): + current_sub_equation, current_n = carry + should_compute = current_n > 1 + + # Use jax.lax.cond for conditional execution + next_n, next_sub_equation = jax.lax.cond( + should_compute, + lambda: self.child_sub_equation(current_sub_equation), + lambda: (jnp.int32(1), jnp.full_like(current_sub_equation, self.tokens[SpecialTokens.PAD.value])) + ) + + return (next_sub_equation, next_n), (next_sub_equation, next_n) + + # Initialize scan + init_carry = (sub_equation, n) + _, (sub_equations_scan, lengths_scan) = jax.lax.scan( + scan_fn, init_carry, None, length=max_iterations-1 + ) + + # Combine initial with scanned results + sub_equations = jnp.concatenate([ + sub_equation[None, :], + sub_equations_scan + ], axis=0) + lengths = jnp.concatenate([ + jnp.array([n]), + lengths_scan + ]) + + # Build equation using vectorized operations equation = jnp.full(sequence_len, self.tokens[SpecialTokens.PAD.value]) equation = equation.at[0].set(self.tokens[SpecialTokens.BOE.value]) - i = 1 - equation = equation.at[i : i + n].set(sub_equation) - i += n - while n > 1: - equation = equation.at[i].set(self.tokens[SpecialTokens.EQL.value]) - i += 1 - n, sub_equation = self.child_sub_equation(sub_equation) - equation = equation.at[i : i + n].set(sub_equation[:n]) - i += n - equation = equation.at[i].set(self.tokens[SpecialTokens.EOE.value]) + + # Calculate positions using scan + def position_scan_fn(prev_pos, i): + # Each position = previous position + previous length + 1 (for equals sign if needed) + prev_len = lengths[i-1] + prev_valid = lengths[i-1] > 0 + current_valid = lengths[i] > 0 + # Only add equals sign space if both previous and current are valid + should_add_equals = prev_valid & current_valid + new_pos = jax.lax.select(should_add_equals, prev_pos + prev_len + 1, prev_pos + prev_len) + return new_pos, new_pos + + init_pos = 1 # Start after BOE token + _, positions_scan = jax.lax.scan( + position_scan_fn, init_pos, jnp.arange(1, max_iterations) + ) + positions = jnp.concatenate([jnp.array([init_pos]), positions_scan]) + + # Place sub-equations using vectorized operations + def place_subequation(equation, i): + pos = positions[i] + length = lengths[i] + valid = length > 0 + sub_eq = sub_equations[i] + + # Create a mask for which elements of the sub-equation to place + indices = jnp.arange(sub_equation.shape[0]) + element_mask = (indices < length) & valid + + # Update equation element by element + def update_single_element(j, eq): + should_update = element_mask[j] & (pos + j < sequence_len) + target_idx = pos + j + return jax.lax.select( + should_update, + eq.at[target_idx].set(sub_eq[j]), + eq + ) + + equation = jax.lax.fori_loop(0, sub_equation.shape[0], update_single_element, equation) + + # Place equals sign + next_valid = (i + 1 < max_iterations) & (lengths[i + 1] > 0) + eq_pos = pos + length + should_place_equals = valid & next_valid & (eq_pos < sequence_len) + + equation = jax.lax.select( + should_place_equals, + equation.at[eq_pos].set(self.tokens[SpecialTokens.EQL.value]), + equation + ) + + return equation + + # Apply placement for all sub-equations + equation = jax.lax.fori_loop( + 0, max_iterations, + lambda i, eq: place_subequation(eq, i), + equation + ) + + # Place EOE token + valid_mask = lengths > 0 + last_valid_idx = jnp.sum(valid_mask) - 1 + last_pos = positions[last_valid_idx] + last_len = lengths[last_valid_idx] + eoe_pos = last_pos + last_len + + equation = jax.lax.select( + eoe_pos < sequence_len, + equation.at[eoe_pos].set(self.tokens[SpecialTokens.EOE.value]), + equation + ) + return equation def random_equation(self, key: chex.PRNGKey, k: int, sequence_len: int) -> jax.Array: