Replaced seed with key in add_noise and noisy_sgd #1138

Closed
wants to merge 8 commits into from


Tomas542

Replaced seed with key in add_noise, in favor of a jax.random-like style.
Added a (duplicated) example for add_noise, taken from noisy_sgd.
Changed seed to key in noisy_sgd. Imported chex in _alias.py for type annotations.


google-cla bot commented Nov 13, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@rdyro
Collaborator

rdyro commented Nov 14, 2024

These changes look good! Can you sign the google CLA to be able to run unit tests?

I like the change of placing the key first in the argument list. Can you take a look at these files as well:

It'd be good to rename the argument to key from seed in all these places and move the key argument to the beginning of the argument list - this might require touching up a couple of tests.

@Tomas542
Author

I signed the CLA yesterday and went to bed.
It's a great idea, but I think I'll do it this weekend.
I also have a question (sorry if it's stupid, I've never done Pull Requests before): if I make a new commit with these changes to my branch - do they automatically go into this PR? Or is there some button I have to press?

@rdyro
Collaborator

rdyro commented Nov 14, 2024

Awesome

if I make a new commit with these changes to my branch - do they automatically go into this PR? Or is there some button I have to press?

Not at all, thanks for the contribution! Exactly, any additional commits in your branch will be added to the PR here. In the end, you can also squash your commits into one before merging with something like:

$ git rebase -i main
$ # change all but one commit to "squash", leaving the first one as "pick"
$ # if everything went well
$ git push -f  # overwrites history, which is necessary when squashing commits

@Tomas542
Author

Ok, and one more question: a lot of things will change. Maybe we could add key as an optional argument and emit a DeprecationWarning if seed is set, something like "Warning: argument seed will be removed (or replaced, in cases where seed is already a PRNGKey) in the next release. Use key instead." But in this case key can no longer be the first argument, and we will have to make it Optional for now. Then, in the major release after next, I would remove/replace seed with key.
Or is it better to change everything now? What do you think?
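
A minimal sketch of the deprecation path described above, assuming a hypothetical helper `resolve_key` that is not part of the library. `make_key` stands in for whatever turns an integer seed into a PRNG key (in JAX code that would be `jax.random.PRNGKey`); it is passed in as a parameter so the sketch runs without JAX installed.

```python
import warnings
from typing import Any, Callable, Optional


def resolve_key(
    key: Optional[Any] = None,
    seed: Optional[int] = None,
    # Placeholder for jax.random.PRNGKey (assumption, keeps the sketch stdlib-only).
    make_key: Callable[[int], Any] = lambda s: ("key", s),
) -> Any:
    """Return the PRNG key to use, still accepting the deprecated `seed`."""
    if key is not None and seed is not None:
        raise ValueError("Pass either `key` or the deprecated `seed`, not both.")
    if seed is not None:
        # Old call sites keep working but are told to migrate.
        warnings.warn(
            "Argument `seed` is deprecated and will be removed in a future "
            "release. Use `key` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return make_key(seed)
    if key is None:
        raise ValueError("A random key is required: pass `key=...`.")
    return key
```

With this shape, `resolve_key(seed=0)` keeps old code running while warning, and `resolve_key(key=k)` is silent. The cost, as noted above, is that `key` has to be optional during the transition.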

Collaborator

@vroulet vroulet left a comment

Just a small comment about the signature of the optimizer. Otherwise I fully agree with Robert that a full pass would be great. Could be in a separate PR if needed.

@@ -1253,10 +1253,10 @@ def lamb(


def noisy_sgd(
key: chex.PRNGKey,

I would still keep learning_rate first as in other optimizers for this function, and I would give a default value for the key for convenience, like key=jax.random.PRNGKey(0).
Sometimes one wants to quickly benchmark optimizers, and that is easier if they all have a similar signature, at least for the first argument.
On the other hand, for the transform (add_noise) and the other functions Robert pointed out, I agree that it would make sense to have the key first.

@Tomas542
Author

Hi, haven't had time to make changes.
I agree with the positions of the arguments.
But I don't agree that the key should be set to the default value for two reasons:

  • At this point, this parameter does not have a default value yet, so leaving it without one will not “slow down” optimizer testing.
  • JAX (as far as I know) never sets default values for keys; that is more typical of libraries such as datasets.

So perhaps “don't set a default value” is the better option?

@rdyro
Collaborator

rdyro commented Nov 18, 2024

A default key value might be confusing, since it gives determinism where people don't expect it, but noisy_sgd is a first-order optimizer and argument order consistency definitely has value. Maybe we can keep key last, but required.

Something like:

def noisy_sgd(
    learning_rate: base.ScalarOrSchedule,
    eta: float = 0.01,
    gamma: float = 0.55,
    key: chex.PRNGKey | None = None,
) -> base.GradientTransformation:
  if key is None:
    raise ValueError("noisy_sgd optimizer requires specifying random key: noisy_sgd(..., key=random.key(0))")

@Tomas542 @vroulet what do you think?

@vroulet
Collaborator

vroulet commented Nov 18, 2024

I like this idea. It may also smooth out backward compatibility issues with clear raised errors.

@rdyro
Collaborator

rdyro commented Nov 20, 2024

@Tomas542 could you add this change?

@Tomas542
Author

Tomas542 commented Nov 22, 2024

Yep, I can do this. I will also try to update the tests.

To summarize: we decided to replace every seed with key, give key a default value of None, and raise an error if it is None. As for the order of function arguments, we put key in the last position for optimizers and in the first position in some other functions, such as add_noise?

UPD: Also, I would create another PR after we finish this discussion. And the annotation would be Optional, because the a | b syntax isn't supported in Python 3.9.
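
The agreed signature could look roughly like the sketch below. It is written in plain Python so it runs without JAX or optax: `PRNGKey` is a stand-in alias for `chex.PRNGKey`, and `noisy_sgd_sketch` is a hypothetical name that returns its settings as a dict rather than a real `GradientTransformation`. `Optional[...]` is used instead of `... | None` so the annotation also evaluates on Python 3.9.

```python
from typing import Any, Dict, Optional

PRNGKey = Any  # stand-in for chex.PRNGKey (assumption, keeps the sketch dependency-free)


def noisy_sgd_sketch(
    learning_rate: float,
    eta: float = 0.01,
    gamma: float = 0.55,
    key: Optional[PRNGKey] = None,
) -> Dict[str, Any]:
    # `key` comes last so `learning_rate` stays first, as in the other optimizers.
    # It defaults to None only so that a missing key fails with a clear message.
    if key is None:
        raise ValueError(
            "noisy_sgd optimizer requires specifying random key: "
            "noisy_sgd(..., key=random.key(0))"
        )
    # Placeholder return value; the real optimizer would build a
    # gradient transformation from these settings.
    return {"learning_rate": learning_rate, "eta": eta, "gamma": gamma, "key": key}
```

Calling `noisy_sgd_sketch(0.1)` raises the `ValueError`, while `noisy_sgd_sketch(0.1, key=some_key)` succeeds, so the key is effectively required even though it has a default.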

@vroulet
Collaborator

vroulet commented Nov 22, 2024

Yes, good point for the annotation. And yes for the summary!

@Tomas542 Tomas542 closed this Nov 30, 2024
@Tomas542
Author

Now we can move to #1145
