-
Notifications
You must be signed in to change notification settings - Fork 203
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
replaced seed
with key
#1167
base: main
Are you sure you want to change the base?
replaced seed
with key
#1167
Conversation
@rdyro 3 month later, but I've finished it) |
Also should close issue #1137 after PR got accepted |
@Tomas542 JAX now recommends using |
@carlosgmartin yes, I know, but there are some problems with this in tests where NumPy checks for type compatibility. The new keys are of type P.S. Accidentally, I've made a mistake in |
The changes look great so far! @carlosgmartin is right that we should move to Let me see if we can fix the optax hyperparameter validation before this PR so that we can stick to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey, the changes look great! The PRNGKey fix is in now, so can you rebase and use jax.random.key
in place of PRNGKey
? The exception is chex.PRNGKey
which should stay as is, thanks!
@@ -30,7 +30,7 @@ class HungarianAlgorithmTest(parameterized.TestCase): | |||
m=[0, 1, 2, 4, 8, 16], | |||
) | |||
def test_hungarian_algorithm(self, n, m): | |||
key = jrd.key(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to change to PRNGKey now
@@ -91,7 +91,7 @@ def test_hungarian_algorithm(self, n, m): | |||
m=[0, 1, 2, 4], | |||
) | |||
def test_hungarian_algorithm_vmap(self, k, n, m): | |||
key = jrd.key(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to change to PRNGKey now
@@ -106,7 +106,7 @@ def test_hungarian_algorithm_vmap(self, k, n, m): | |||
assert j.shape == (k, r) | |||
|
|||
def test_hungarian_algorithm_jit(self): | |||
key = jrd.key(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to change to PRNGKey now
optax/contrib/_privacy_test.py
Outdated
l2_norm_clip=jnp.finfo(jnp.float32).max, noise_multiplier=0.0, seed=0 | ||
l2_norm_clip=jnp.finfo(jnp.float32).max, | ||
noise_multiplier=0.0, | ||
key=jrd.PRNGKey(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
jrd.key(0)
please instead of PRNGKey
optax/contrib/_privacy_test.py
Outdated
l2_norm_clip=l2_norm_clip, noise_multiplier=0.0, seed=42 | ||
l2_norm_clip=l2_norm_clip, | ||
noise_multiplier=0.0, | ||
key=jrd.PRNGKey(42) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
jrd.key(0)
please instead of PRNGKey
optax/contrib/_common_test.py
Outdated
'opt_kwargs': {'learning_rate': 1.0, 'eta': 1e-4}, | ||
'opt_kwargs': { | ||
'learning_rate': 1.0, | ||
'key': jax.random.PRNGKey(0), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
jax.random.key(0)
please instead of PRNGKey
) -> base.GradientTransformation: | ||
"""Aggregates gradients based on the DPSGD algorithm. | ||
|
||
Args: | ||
l2_norm_clip: maximum L2 norm of the per-example gradients. | ||
noise_multiplier: ratio of standard deviation to the clipping norm. | ||
seed: initial seed used for the jax.random.PRNGKey |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
jax.random.key(0)
please instead of PRNGKey
optax/contrib/_privacy.py
Outdated
if key is None: | ||
raise ValueError( | ||
"differentially_private_aggregate optimizer requires specifying key: " | ||
"differentially_private_aggregate(..., key=jax.random.PRNGKey(0))" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
jax.random.key(0)
please instead of PRNGKey
optax/contrib/_privacy.py
Outdated
if key is None: | ||
raise ValueError( | ||
"dpsgd optimizer requires specifying key: " | ||
"dpsgd(..., key=jax.random.PRNGKey(0))" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
jax.random.key(0)
please instead of PRNGKey
optax/transforms/_adding_test.py
Outdated
seed = 314 | ||
noise = _adding.add_noise(eta, gamma, seed) | ||
noise_unit = _adding.add_noise(1.0, 0.0, seed) | ||
key = jax.random.PRNGKey(314) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
jax.random.key(0)
please instead of PRNGKey
Hm, could someone help with docs? I can't understand what's the problem with tests. I know that |
Sorry, the doc failures is on our side (our CI is briefly broken :( ), this should be fixed soon and you can rerun your tests |
ok, but what about error in state utils? |
The new JAX key implementation key is not directly comparable, for now, to compare states containing random keys you need this new tree_util in main (you'll need to rebase your branch): then in from optax.tree_utils import _random
...
chex.assert_trees_all_equal(
_random.tree_unwrap_random_key_data(noise_state),
_random.tree_unwrap_random_key_data(expected_result)
)
...
chex.assert_trees_all_equal(
_random.tree_unwrap_random_key_data(new_state),
_random.tree_unwrap_random_key_data(expected_result)
) |
Replaced all seed values with key for uniformity of style with
jax.random