Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 108 additions & 10 deletions simplexity/generative_processes/arithmetic_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,116 @@ def child_sub_equation(self, sub_equation: jax.Array) -> tuple[int, jax.Array]:

def full_equation(self, sub_equation: jax.Array, n: int, sequence_len: int) -> jax.Array:
"""Produce a random RPN sequence."""
# Maximum number of iterations we expect
max_iterations = 10

# Precompute all child sub-equations using jax.lax.scan
def scan_fn(carry, _):
current_sub_equation, current_n = carry
should_compute = current_n > 1

# Use jax.lax.cond for conditional execution
next_n, next_sub_equation = jax.lax.cond(
should_compute,
lambda: self.child_sub_equation(current_sub_equation),
lambda: (jnp.int32(1), jnp.full_like(current_sub_equation, self.tokens[SpecialTokens.PAD.value]))
)

return (next_sub_equation, next_n), (next_sub_equation, next_n)

# Initialize scan
init_carry = (sub_equation, n)
_, (sub_equations_scan, lengths_scan) = jax.lax.scan(
scan_fn, init_carry, None, length=max_iterations-1
)

# Combine initial with scanned results
sub_equations = jnp.concatenate([
sub_equation[None, :],
sub_equations_scan
], axis=0)
lengths = jnp.concatenate([
jnp.array([n]),
lengths_scan
])

# Build equation using vectorized operations
equation = jnp.full(sequence_len, self.tokens[SpecialTokens.PAD.value])
equation = equation.at[0].set(self.tokens[SpecialTokens.BOE.value])
i = 1
equation = equation.at[i : i + n].set(sub_equation)
i += n
while n > 1:
equation = equation.at[i].set(self.tokens[SpecialTokens.EQL.value])
i += 1
n, sub_equation = self.child_sub_equation(sub_equation)
equation = equation.at[i : i + n].set(sub_equation[:n])
i += n
equation = equation.at[i].set(self.tokens[SpecialTokens.EOE.value])

# Calculate positions using scan
def position_scan_fn(prev_pos, i):
# Each position = previous position + previous length + 1 (for equals sign if needed)
prev_len = lengths[i-1]
prev_valid = lengths[i-1] > 0
current_valid = lengths[i] > 0
# Only add equals sign space if both previous and current are valid
should_add_equals = prev_valid & current_valid
new_pos = jax.lax.select(should_add_equals, prev_pos + prev_len + 1, prev_pos + prev_len)
return new_pos, new_pos

init_pos = 1 # Start after BOE token
_, positions_scan = jax.lax.scan(
position_scan_fn, init_pos, jnp.arange(1, max_iterations)
)
positions = jnp.concatenate([jnp.array([init_pos]), positions_scan])

# Place sub-equations using vectorized operations
def place_subequation(equation, i):
pos = positions[i]
length = lengths[i]
valid = length > 0
sub_eq = sub_equations[i]

# Create a mask for which elements of the sub-equation to place
indices = jnp.arange(sub_equation.shape[0])
element_mask = (indices < length) & valid

# Update equation element by element
def update_single_element(j, eq):
should_update = element_mask[j] & (pos + j < sequence_len)
target_idx = pos + j
return jax.lax.select(
should_update,
eq.at[target_idx].set(sub_eq[j]),
eq
)

equation = jax.lax.fori_loop(0, sub_equation.shape[0], update_single_element, equation)

# Place equals sign
next_valid = (i + 1 < max_iterations) & (lengths[i + 1] > 0)
eq_pos = pos + length
should_place_equals = valid & next_valid & (eq_pos < sequence_len)

equation = jax.lax.select(
should_place_equals,
equation.at[eq_pos].set(self.tokens[SpecialTokens.EQL.value]),
equation
)

return equation

# Apply placement for all sub-equations
equation = jax.lax.fori_loop(
0, max_iterations,
lambda i, eq: place_subequation(eq, i),
equation
)

# Place EOE token
valid_mask = lengths > 0
last_valid_idx = jnp.sum(valid_mask) - 1
last_pos = positions[last_valid_idx]
last_len = lengths[last_valid_idx]
eoe_pos = last_pos + last_len

equation = jax.lax.select(
eoe_pos < sequence_len,
equation.at[eoe_pos].set(self.tokens[SpecialTokens.EOE.value]),
equation
)

return equation

def random_equation(self, key: chex.PRNGKey, k: int, sequence_len: int) -> jax.Array:
Expand Down
Loading