Conversation

ShuxiangCao
Contributor

No description provided.

@googlebot

Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

📝 Please visit https://cla.developers.google.com/ to sign.

Once you've signed (or fixed any issues), please reply here with @googlebot I signed it! and we'll verify it.



@ShuxiangCao
Contributor Author

@googlebot I signed it!


@googlebot

CLAs look good, thanks!


Collaborator
@MichaelBroughton MichaelBroughton left a comment

Great first pass. Very well documented! Couple of comments and questions about your implementation:

  1. You've implemented the __init__ logic of exposing rotosolve_minimize but there is no test for ensuring the API exposure is correct. Can you add a line in import_test.py to verify you are exposing this optimizer correctly?

  2. We have rotosolve_one_parameter_once that gets called by rotosolve_all_parameters_once. Do we need the functionality of rotosolve_one_parameter_once, or can it be removed and simplified?
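For point 1, a minimal sketch of the kind of exposure check being requested (hypothetical; the real import_test.py may organize its checks differently):

```python
import tensorflow_quantum as tfq

# Hypothetical check: the optimizer should be reachable at the public API path
# discussed later in this thread.
assert callable(tfq.optimizers.rotosolve_minimize)
```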

minimize as rotosolve_minimize)

# Utils for optimizers.
from tensorflow_quantum.python.optimizers.utils import (function_factory)
Collaborator

Is function factory something you want a user to use ?

Contributor Author
@ShuxiangCao ShuxiangCao May 31, 2020

Yes, the user needs to wrap their model with this function_factory to expose the trainable parameters.

# limitations under the License.
# ==============================================================================
"""The rotosolve minimization algorithm
"""
Collaborator

Nit: One line module description.
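In other words, presumably something like:

```python
"""The rotosolve minimization algorithm."""
```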

Contributor Author
@ShuxiangCao ShuxiangCao May 31, 2020

Oops. Fixed in a8ea8a7

Comment on lines 52 to 70
'converged', # Scalar boolean tensor indicating whether the minimum
# was found within tolerance.
'num_iterations', # The number of iterations of the BFGS update.
# The total number of objective evaluations performed.
'num_objective_evaluations',
'position', # A tensor containing the last argument value found
# during the search. If the search converged, then
# this value is the argmin of the objective function.
# A tensor containing the value of the objective from previous iteration
'last_objective_value',
'objective_value', # A tensor containing the value of the objective
# function at the `position`. If the search
# converged, then this is the (local) minimum of
# the objective function.
'tolerance', # Define the stop criteria. Iteration will stop when the
# objective value difference between two iterations is smaller than
# tolerance
'solve_param_i', # The parameter index where rotosolve is currently
# modifying. Reserved for internal use.
Collaborator

Formatting is a little hard to follow here. Could you make sure that the comments describing each element are always above the variable they are describing instead of sometimes being to the right and sometimes on top.

Member

Yes, please use the same comment convention (including indentation) as that of TFP (here).
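For reference, the TFP results namedtuples put each description in a trailing comment whose continuation lines are indented to align with it; a rough sketch using the fields above (not the final code in this PR):

```python
import collections

# Sketch of the requested TFP-style layout; field descriptions are taken from
# the snippet above.
RotosolveOptimizerResults = collections.namedtuple(
    'RotosolveOptimizerResults', [
        'converged',  # Scalar boolean tensor indicating whether the minimum
                      # was found within tolerance.
        'num_iterations',  # The number of iterations of the rotosolve update.
        'num_objective_evaluations',  # The total number of objective
                                      # evaluations performed.
    ])
```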

Contributor Author

Fixed in 37e926f

Contributor Author

Just realized the format checker isn't happy with indentation like this. Should we change the configuration? @MichaelBroughton @jaeyoo

The following example demonstrates the Rotosolve optimizer attempting
to find the minimum for two qubit ansatz expectation value.

```python
Collaborator

This is a really good usage example. It is a little too complex for something that we want on the API docs. Do you think we could shorten things up a little bit?

Also, can you use >>> and ... like in https://github.com/tensorflow/quantum/blob/master/tensorflow_quantum/python/layers/high_level/pqc.py?

Contributor Author

Updated in 41b3161. Any changes suggested?


Args:
expectation_value_function: A Python callable that accepts
a point as a real `Tensor` and returns a tuple of `Tensor`s
Collaborator

Nit: tf.Tensor. Can you also outline the shape requirements for the input as well as the returned shape you get from the output ?

Contributor Author

Fixed in f2b15d4

tolerance: Scalar `Tensor` of real dtype. Specifies the gradient
tolerance for the procedure. If the supremum norm of the gradient
vector is below this number, the algorithm is stopped.
name: (Optional) Python str. The name prefixed to the ops created
Collaborator

Nit: Python str

Contributor Author

Fixed in 25a75df

Member
@jaeyoo jaeyoo left a comment

Thank you for adding the first TFQ optimizer! I left some comments.

Comment on lines 26 to 32
"""Return static shape of tensor `x` if available,
else `tf.shape(x)`.
Args:
x: `Tensor` (already converted).
Returns:
Numpy array (if static shape is obtainable), else `Tensor`.
"""
Member

We are using "four space" indent in comments, with a one-line description header ending with a period.
e.g.

"""Static shape of a given tensor `x`.

Return static shape of tensor `x` if available, else `tf.shape(x)`.

Args:
    x: A `tf.Tensor` of which we want to know the shape.
Returns:
    A numpy array (if static shape is obtainable), else `Tensor`
    containing a shape of `x`.
"""

Contributor Author

Thanks! Fixed in 112400e

Comment on lines 37 to 42
"""Return static value of tensor `x` if available, else `x`.
Args:
x: `Tensor` (already converted).
Returns:
Numpy array (if static value is obtainable), else `Tensor`.
"""
Member

ditto.

Contributor Author

Fixed in 112400e

Comment on lines 159 to 161
Args:
expectation_value_function: A Python callable that accepts
a point as a real `Tensor` and returns a tuple of `Tensor`s
Member

ditto.

Contributor Author

Fixed in 112400e

point, or points when using batching dimensions, of the search
procedure. At these points the function value and the gradient
norm should be finite.
tolerance: Scalar `Tensor` of real dtype. Specifies the gradient
Member

tf.Tensor

Contributor Author

Fixed in 25a75df

Comment on lines 198 to 201
def _rotosolve_one_parameter_once(state):
"""
Rotosolve a single parameter.
"""
Member

  1. This internal function is hard to test. Could you bring it out and write a test for it?
  2. One-line description + arguments description, please.

Contributor Author

Well, this function is written here purely to help people read it, and it's a bit difficult to extract since too many variables are used in this closure. Also, the rest of the code does nothing but iterate over the parameters and call this function. Do you think it is okay to leave it as an internal function?

Contributor Author
@ShuxiangCao ShuxiangCao May 31, 2020

The docstring is fixed in 64f76c5

Comment on lines 28 to 34
"""Tests for optimizer utils"""

def test_function_factory(self):
"""Test the function_factory"""

class LinearModel(object):
""" A simple tensorflow linear model"""
Member

Period, please, and please trim the meaningless whitespace in front of the sentence.
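That is, roughly:

```python
"""A simple tensorflow linear model."""
```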

Contributor Author

Fixed in 564256f


result = rotosolve_minimizer.minimize(
rotosolve_minimizer.function_factory(
model, hinge_loss, x_circ, Y), np.random.rand([2])
Member

I am a little bit afraid of this kind of new and independent optimizer design. I understand that this design is conventional in TF Probability, but TensorFlow Keras provides its own Optimizer class to be inherited (e.g. see AdamOptimizer here).

To me, TFP's optimizer design seems like Keras' optimizer + a custom trainer. I think it can't be run with Keras model fit. I feel it's OK that it works well for saving weights and reloading the model after training, but I am not sure about saving models with the Keras save function in the future (it can save the optimizer together).

What do you think, @MichaelBroughton?

Collaborator

Making a new Keras Optimizer requires that optimizer to be gradient based. This new approach lets us depart from those standards. Rotosolve doesn't calculate gradients, just like Nelder-Mead in TFP doesn't use gradients.

You're right that it can't be run with Keras model fit, but this seems to be the best workaround. I could ask the TFP people about this too. What do you think?

Member
@jaeyoo jaeyoo May 30, 2020

Yes, I've checked gradient-free optimizers across several libraries, and it is natural to implement them with an independent design. I thought the optimizer design was flexible enough to support both, but I was wrong. So I agree with TFP's approach for a seamless ecosystem :) Thank you for your idea!
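For context, a rough sketch of the functional, TFP-style interface being settled on here: the objective is a plain callable and the optimizer is invoked directly rather than through model.fit(). Argument and field names follow the docstring quoted in this thread; the toy objective and starting point are assumptions.

```python
import numpy as np
import tensorflow as tf
import tensorflow_quantum as tfq

# Toy objective: a linear combination of sinusoids, as in the PR's docstring
# example. Rotosolve treats it as the expectation value to minimize.
coefficient = tf.constant([1.0, 2.0, 3.0])
expectation = lambda x: tf.math.reduce_sum(tf.sin(x) * coefficient)

result = tfq.optimizers.rotosolve_minimize(
    expectation,                            # expectation_value_function
    np.random.rand(3).astype(np.float32))   # initial_position, shape [n]

best_params = result.position if result.converged else None
```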

is also of shape `[..., n]` like the input value to the function.
This must be a linear combination of quantum measurement
expectation value, otherwise this algorithm cannot work.
initial_position: Real `Tensor` of shape `[..., n]`. The starting
Member

tf.Tensor

Contributor Author

Fixed in 25a75df




def _get_initial_state(initial_position, tolerance, expectation_value_function):
"""Create RotosolveOptimizerResults with initial state of search ."""
Member

Please remove the white space between the last character and the period.

Contributor Author

Fixed in 564256f

Member
@jaeyoo jaeyoo left a comment

Sorry for the confusion, I'm requesting changes :)

@ShuxiangCao
Contributor Author

Thank you very much. I'll fix these issues soon.


@ShuxiangCao
Contributor Author

ShuxiangCao commented May 31, 2020

Great first pass. Very well documented! Couple of comments and questions about your implementation:

  1. You've implemented the __init__ logic of exposing rotosolve_minimize but there is no test for ensuring the API exposure is correct. Can you add a line in import_test.py to verify you are exposing this optimizer correctly?
  2. We have rotosolve_one_parameter_once that gets called by rotosolve_all_parameters_once. Do we need the functionality of rotosolve_one_parameter_once, or can it be removed and simplified?

Thanks for the comment again.

  1. Done!
  2. I abstracted rotosolve_one_parameter_once purely to make it easier to read. Rotosolve simply iterates over the parameters and calls this function many times. I feel it's better to leave it as an internal function; however, I am happy to inline it if people feel that is necessary.

@ShuxiangCao
Contributor Author

Thanks for the review, I have made most of the changes required. Could you give a bit more detail on how best to implement point 3?

  1. More sanity check tests to ensure that the number of function evaluations == max_iterations when the function terminates without converging, or that if we set the tolerance very high we do indeed terminate with function evaluations < max_iterations, etc.

Also, I found the tests don't seem to be working properly. I got errors when I executed the tests manually, but when I ran them with bazel they simply passed. Can you help me check whether I made a mistake that causes the tests to be ignored?

@MichaelBroughton
Collaborator

More sanity check tests to ensure that the number of function evaluations == max_iterations when the function terminates without converging, or that if we set the tolerance very high we do indeed terminate with function evaluations < max_iterations, etc.

This would mean adding at least two more tests:

  1. Running a test where the rotosolve procedure terminates because evaluations == max_iterations, then checking that these two quantities are indeed equal.

  2. Writing a test with a simple optimization problem and a large tolerance like 1e-2, then ensuring that rotosolve can terminate before reaching max_iterations.
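A hypothetical sketch of the second test (the import path and the tolerance/max_iterations keyword names are assumptions based on this thread, and the toy objective mirrors the docstring example):

```python
import numpy as np
import tensorflow as tf
from tensorflow_quantum.python.optimizers import rotosolve_minimizer


class RotosolveToleranceTest(tf.test.TestCase):
    """With a loose tolerance, rotosolve should stop well before max_iterations."""

    def test_terminates_before_max_iterations(self):
        coefficient = tf.constant([1.0, 2.0, 3.0])
        func = lambda x: tf.math.reduce_sum(tf.sin(x) * coefficient)
        result = rotosolve_minimizer.minimize(
            func, np.random.rand(3).astype(np.float32),
            tolerance=1e-2, max_iterations=50)
        self.assertTrue(bool(result.converged))
        self.assertLess(int(result.num_iterations), 50)


if __name__ == "__main__":
    tf.test.main()
```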

Regarding the tests not working, can you post a small snippet or add in the changes with a comment to indicate which ones are not working ?

@ShuxiangCao
Contributor Author

ShuxiangCao commented Sep 7, 2020

Regarding the tests not working, can you post a small snippet or add in the changes with a comment to indicate which ones are not working ?

Please find all the changes in 25ecc76. Before this commit the code should have raised errors, since I called a function that does not even exist (tf.sum) and the syntax was wrong for scatter_update. Once I executed the examples manually, these problems were all exposed. However, when I ran them with bazel they were not spotted.

@ShuxiangCao
Contributor Author

ShuxiangCao commented Sep 7, 2020

  1. Running a test where the rotosolve procedure terminates because evaluations == max_iterations, then checking that these two quantities are indeed equal.

A new test called test_nonlinear_function_optimization did this check.

  1. Writing a test with a simple optimization problem and a large tolerance like 1e-2, then ensuring that rotosolve can terminate before reaching max_iterations.

I have changed the existing test_function_optimization to do this check. In fact, it takes only 3 or 4 iterations to converge for the examples shown here.

@MichaelBroughton
Collaborator

I'm guessing your tests aren't working because you don't have this at the bottom of your test module:

if __name__ == "__main__":
    tf.test.main()

Could you give that a try ?

@ShuxiangCao
Contributor Author

I'm guessing your tests aren't working because you don't have this at the bottom of your test module:

if __name__ == "__main__":
    tf.test.main()

Could you give that a try ?

Thanks for the help. I have fixed all the tests.

Collaborator
@MichaelBroughton MichaelBroughton left a comment

Alrighty, this is ready to merge! Once we get this last round of minor nits changed, I can go ahead and merge this in :D

Comment on lines 16 to 18
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
Collaborator

Are these needed ?

Contributor Author

Removed in baf5e99


arXiv:1905.09692

### Usage:
Collaborator

We still have ### Usage:. We can just get rid of that altogether.

Comment on lines 117 to 121
The following example demonstrates the Rotosolve optimizer attempting
to find the minimum for two qubit ansatz expectation value.

Here is an example of optimize a function which consists summation of
a few sinusoids.
Collaborator

Do we want both of these descriptions or just one ?

Contributor Author

Removed the first line in e9da877

Comment on lines 123 to 125
>>> import tensorflow_quantum as tfq
>>> import numpy as np
>>> import tensorflow as tf
Collaborator

Don't need these in the snippet, we assume people have done this already :)

Contributor Author

Fixed in e9da877

>>> min_value = -tf.math.reduce_sum(tf.abs(coefficient))
>>> func = lambda x:tf.math.reduce_sum(tf.sin(x) * coefficient)
>>> # Optimize the function with rotosolve, start with random parameters
>>> result = tfq.optimizers.rotosolve_minimizer.minimize( \
Collaborator

Don't need the \; you can just do result = tfq.optimizers.rotosolve_minimizer.minimize(

Collaborator

Also shouldn't the function call be tfq.optimizers.rotosolve_minimize( and not tfq.optimizers.rotosolve_minimizer.minimize( ?

Contributor Author

Fixed both in e9da877
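With both fixes applied, the docstring call presumably reads roughly like this (a sketch; `func` is the objective defined a few lines earlier in the docstring, and the random starting point shown here is an assumption):

```python
>>> result = tfq.optimizers.rotosolve_minimize(
...     func, np.random.random(2))
```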

Comment on lines 133 to 136
>>> print(result.converged)
tf.Tensor(True, shape=(), dtype=bool)
>>> print(result.objective_value)
tf.Tensor(-4.7045116, shape=(), dtype=float32)
Collaborator

In the doctest snippets we prefer just giving one output and not using the print statement unless we really need to. Would you mind changing these lines to just:

>>> result.objective_value
tf.Tensor(-4.7045116, shape=(), dtype=float32)

Contributor Author

Fixed in e9da877

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
Collaborator

math.pi is used here but np.pi is used in the tests. Could we change this module to also use np.pi just for consistency ?

Contributor Author

Fixed in e9da877

@ShuxiangCao
Contributor Author

Alrighty, this is ready to merge! Once we get this last round of minor nits changed, I can go ahead and merge this in :D

Thanks for the review! I have fixed all these problems.

Collaborator
@MichaelBroughton MichaelBroughton left a comment

Ready to merge :D

@MichaelBroughton MichaelBroughton merged commit 5e5ed34 into tensorflow:master Sep 14, 2020
therooler pushed a commit to therooler/quantum that referenced this pull request Oct 6, 2020
* Implement rotosolve optimizers

* Fix import issue

* Fix comments layout

* Modify rotosolve example docstring

* Remove meaningless whitespaces

* Correct types in docstring

* Change docstring indents

* Trying to fix pytest faliure

* Remove extra spances and change docstring layout

* Fix typo in utils_test.py

* Check the input / output shape requirement for rotosolve expectation value function

* Add docstring for internal functions

* Fix formats

* Fix dependency

* Fix import

* Modify docstring layout

* Fix format

* Remove function factory (tensorflow#1)

Remove function factory.

* Fix format

* Fix lint

* Fix tests

* Remove utils

* Remove utils

* Fix format and typos

* Update examples

* Fix lint

* Fix typo

* Switch to scatter_nd

* Update rotosolve test

* Reformat code

* Update tests and comments

* Update test

* Fix lint

* fix lint

* Modify build docs, modify comments, change test to use built-in hinge loss

* Fix format of build_docs.py

* Change name of object value from previous iteration

* Fix comments examples

* Fix lint

* Fix lint

* Test tolerance criteria

* Fix format

* Fix tests

* Fix format

* Fix format

* Change test

* Fix tests

* Fix typo in comment

* Fix minor problems

* Fix lint

* Remove unused input

* Update comment
jaeyoo pushed a commit to jaeyoo/quantum that referenced this pull request Mar 30, 2023