Skip to content

Commit

Permalink
Apply function on pytree directly. (#692)
Browse files Browse the repository at this point in the history
* Apply function on pytree directly.

Avoiding unnecssary unpacking

* Fix kwarg
  • Loading branch information
junpenglao authored Jun 5, 2024
1 parent 83bc3a0 commit a4408d3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 18 deletions.
4 changes: 2 additions & 2 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ def step(iteration_state, weight_and_key):
x = ravel_pytree(state.position)[0]
# update the running average of x, x^2
streaming_avg = streaming_average_update(
expectation=jnp.array([x, jnp.square(x)]),
streaming_avg=streaming_avg,
current_value=jnp.array([x, jnp.square(x)]),
previous_weight_and_average=streaming_avg,
weight=(1 - mask) * success * params.step_size,
zero_prevention=mask,
)
Expand Down
31 changes: 15 additions & 16 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,31 +240,30 @@ def one_step(average_and_state, xs, return_state):


def streaming_average_update(
expectation, streaming_avg, weight=1.0, zero_prevention=0.0
current_value, previous_weight_and_average, weight=1.0, zero_prevention=0.0
):
"""Compute the streaming average of a function O(x) using a weight.
Parameters:
----------
expectation
the value of the expectation at the current timestep
streaming_avg
tuple of (total, average) where total is the sum of weights and average is
the current average
current_value
the current value of the function that we want to take average of
previous_weight_and_average
tuple of (previous_weight, previous_average) where previous_weight is the
sum of weights and average is the current estimated average
weight
weight of the current state
zero_prevention
small value to prevent division by zero
Returns:
----------
new streaming average
new total weight and streaming average
"""

flat_expectation, unravel_fn = ravel_pytree(expectation)
total, average = streaming_avg
flat_average, _ = ravel_pytree(average)
average = (total * flat_average + weight * flat_expectation) / (
total + weight + zero_prevention
previous_weight, previous_average = previous_weight_and_average
current_weight = previous_weight + weight
current_average = jax.tree.map(
lambda x, avg: (previous_weight * avg + weight * x)
/ (current_weight + zero_prevention),
current_value,
previous_average,
)
total += weight
streaming_avg = (total, unravel_fn(average))
return streaming_avg
return current_weight, current_average

0 comments on commit a4408d3

Please sign in to comment.