-
Couldn't load subscription status.
- Fork 628
Implement rotosolve optimizers #247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
64 commits
Select commit
Hold shift + click to select a range
b67c7f8
Implement rotosolve optimizers
ShuxiangCao a8ea8a7
Fix import issue
ShuxiangCao 37e926f
Fix comments layout
ShuxiangCao 41b3161
Modify rotosolve example docstring
ShuxiangCao 564256f
Remove meaningless whitespaces
ShuxiangCao 25a75df
Correct types in docstring
ShuxiangCao 112400e
Change docstring indents
ShuxiangCao 26403bd
Trying to fix pytest faliure
ShuxiangCao 7f8bfb4
Remove extra spances and change docstring layout
ShuxiangCao 490331f
Fix typo in utils_test.py
ShuxiangCao f2b15d4
Check the input / output shape requirement for rotosolve expectation …
ShuxiangCao 64f76c5
Add docstring for internal functions
ShuxiangCao f8ec7fe
Fix formats
ShuxiangCao 5f5508c
Fix dependency
ShuxiangCao 62105dc
Fix import
ShuxiangCao aa74043
Modify docstring layout
ShuxiangCao 056c72c
Fix format
ShuxiangCao 3c58174
Merge branch 'master' into master
ShuxiangCao feef330
Merge branch 'master' into master
ShuxiangCao 21345ed
Remove function factory (#1)
ShuxiangCao 90a5b93
Fix format
ShuxiangCao 0e51e5c
Fix lint
ShuxiangCao a19595f
Fix tests
ShuxiangCao c0e3ab6
Remove utils
ShuxiangCao 150e6b3
Remove utils
ShuxiangCao b63b14b
Merge branch 'master' into master
ShuxiangCao 2622d64
Merge branch 'master' into master
ShuxiangCao 5d0d1f4
Fix format and typos
ShuxiangCao ca22374
Update examples
ShuxiangCao 8467538
Fix lint
ShuxiangCao 1807087
Fix typo
ShuxiangCao c0aadd3
Merge branch 'master' into master
ShuxiangCao f28afca
Merge branch 'master' into master
ShuxiangCao 8aee2e6
Switch to scatter_nd
ShuxiangCao a5f0b76
Update rotosolve test
ShuxiangCao 93ffa53
Reformat code
ShuxiangCao a559b54
Update tests and comments
ShuxiangCao c0ab8e8
Update test
ShuxiangCao 2abe7b5
Merge branch 'master' into master
ShuxiangCao 223d597
Fix lint
ShuxiangCao 163a156
fix lint
ShuxiangCao ca95830
Merge branch 'master' into master
ShuxiangCao 4053d74
Modify build docs, modify comments, change test to use built-in hinge…
ShuxiangCao 01714d3
Fix format of build_docs.py
ShuxiangCao 10cee94
Change name of object value from previous iteration
ShuxiangCao 25ecc76
Fix comments examples
ShuxiangCao 8813c35
Fix lint
ShuxiangCao e4ea224
Fix lint
ShuxiangCao 9c4fe84
Merge branch 'master' into master
ShuxiangCao 5535067
Test tolerance criteria
ShuxiangCao 3ba151c
Fix format
ShuxiangCao 72b43bc
Merge branch 'master' into master
ShuxiangCao 3bc4bf2
Fix tests
ShuxiangCao 5004e92
Fix format
ShuxiangCao 52984ca
Fix format
ShuxiangCao 8e1d701
Merge branch 'master' of github.com:SachinCompton/quantum
ShuxiangCao f4ddd7d
Change test
ShuxiangCao 25c1f65
Fix tests
ShuxiangCao 2a425e7
Fix typo in comment
ShuxiangCao 1a40bf9
Merge branch 'master' into master
ShuxiangCao e9da877
Fix minor problems
ShuxiangCao fef587c
Fix lint
ShuxiangCao baf5e99
Remove unused input
ShuxiangCao 5217868
Update comment
ShuxiangCao File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| package(default_visibility = ["//visibility:public"]) | ||
|
|
||
| licenses(["notice"]) | ||
|
|
||
| # Export for the PIP package. | ||
| exports_files(["__init__.py"]) | ||
|
|
||
| py_library( | ||
| name = "rotosolve_minimizer", | ||
| srcs = ["rotosolve_minimizer.py"], | ||
| ) | ||
|
|
||
| py_test( | ||
| name = "rotosolve_minimizer_test", | ||
| srcs = ["rotosolve_minimizer_test.py"], | ||
| python_version = "PY3", | ||
| deps = [ | ||
| ":rotosolve_minimizer", | ||
| "//tensorflow_quantum/python/layers/high_level:pqc", | ||
| "//tensorflow_quantum/core/ops:tfq_ps_util_ops_py" | ||
| ], | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| # Copyright 2020 The TensorFlow Quantum Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
| """Module definitions for tensorflow_quantum.python.optimizers.*""" | ||
|
|
||
| # Quantum circuit specific optimizers. | ||
| from tensorflow_quantum.python.optimizers.rotosolve_minimizer import ( | ||
| minimize as rotosolve_minimize) |
263 changes: 263 additions & 0 deletions
263
tensorflow_quantum/python/optimizers/rotosolve_minimizer.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,263 @@ | ||
| # Copyright 2020 The TensorFlow Quantum Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
| """The rotosolve minimization algorithm""" | ||
| import collections | ||
| import numpy as np | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| def prefer_static_shape(x): | ||
| """Return static shape of tensor `x` if available, | ||
|
|
||
| else `tf.shape(x)`. | ||
|
|
||
| Args: | ||
| x: `tf.Tensor` (already converted). | ||
| Returns: | ||
| Numpy array (if static shape is obtainable), else `tf.Tensor`. | ||
| """ | ||
| return prefer_static_value(tf.shape(x)) | ||
|
|
||
|
|
||
| def prefer_static_value(x): | ||
| """Return static value of tensor `x` if available, else `x`. | ||
|
|
||
| Args: | ||
| x: `tf.Tensor` (already converted). | ||
| Returns: | ||
| Numpy array (if static value is obtainable), else `tf.Tensor`. | ||
| """ | ||
| static_x = tf.get_static_value(x) | ||
| if static_x is not None: | ||
| return static_x | ||
| return x | ||
|
|
||
|
|
||
| RotosolveOptimizerResults = collections.namedtuple( | ||
| 'RotosolveOptimizerResults', | ||
| [ | ||
| 'converged', | ||
| # Scalar boolean tensor indicating whether the minimum | ||
| # was found within tolerance. | ||
| 'num_iterations', | ||
| # The number of iterations of the rotosolve update. | ||
| 'num_objective_evaluations', | ||
| # The total number of objective | ||
| # evaluations performed. | ||
| 'position', | ||
| # A tensor containing the last argument value found | ||
| # during the search. If the search converged, then | ||
| # this value is the argmin of the objective function. | ||
| # A tensor containing the value of the objective from | ||
| # previous iteration | ||
| 'objective_value_previous_iteration', | ||
| # Save the evaluated value of the objective function | ||
| # from the previous iteration | ||
| 'objective_value', | ||
| # A tensor containing the value of the objective | ||
| # function at the `position`. If the search | ||
| # converged, then this is the (local) minimum of | ||
| # the objective function. | ||
| 'tolerance', | ||
| # Define the stop criteria. Iteration will stop when the | ||
| # objective value difference between two iterations is | ||
| # smaller than tolerance | ||
| 'solve_param_i', | ||
| # The parameter index where rotosolve is currently | ||
| # modifying. Reserved for internal use. | ||
| ]) | ||
|
|
||
|
|
||
| def _get_initial_state(initial_position, tolerance, expectation_value_function): | ||
| """Create RotosolveOptimizerResults with initial state of search.""" | ||
| init_args = { | ||
| "converged": tf.Variable(False), | ||
| "num_iterations": tf.Variable(0), | ||
| "num_objective_evaluations": tf.Variable(0), | ||
| "position": tf.Variable(initial_position), | ||
| "objective_value": tf.Variable(0.), | ||
| "objective_value_previous_iteration": tf.Variable(0.), | ||
| "tolerance": tolerance, | ||
| "solve_param_i": tf.Variable(0) | ||
| } | ||
| return RotosolveOptimizerResults(**init_args) | ||
|
|
||
|
|
||
| def minimize(expectation_value_function, | ||
| initial_position, | ||
| tolerance=1e-5, | ||
| max_iterations=50, | ||
| name=None): | ||
| """Applies the rotosolve algorithm. | ||
|
|
||
| The rotosolve algorithm can be used to minimize a linear combination | ||
|
|
||
| of quantum measurement expectation values. See the following paper: | ||
|
|
||
| [arXiv:1903.12166](https://arxiv.org/abs/1903.12166), Ken M. Nakanishi. | ||
| [arXiv:1905.09692](https://arxiv.org/abs/1905.09692), Mateusz Ostaszewski. | ||
|
|
||
| Usage: | ||
|
|
||
| Here is an example of optimize a function which consists summation of | ||
| a few sinusoids. | ||
|
|
||
| >>> n = 10 # Number of sinusoids | ||
| >>> coefficient = tf.random.uniform(shape=[n]) | ||
| >>> min_value = -tf.math.reduce_sum(tf.abs(coefficient)) | ||
| >>> func = lambda x:tf.math.reduce_sum(tf.sin(x) * coefficient) | ||
| >>> # Optimize the function with rotosolve, start with random parameters | ||
| >>> result = tfq.optimizers.rotosolve_minimize(func, np.random.random(n)) | ||
| >>> result.converged | ||
| tf.Tensor(True, shape=(), dtype=bool) | ||
| >>> result.objective_value | ||
| tf.Tensor(-4.7045116, shape=(), dtype=float32) | ||
|
|
||
| Args: | ||
| expectation_value_function: A Python callable that accepts | ||
| a point as a real `tf.Tensor` and returns a `tf.Tensor`s | ||
| of real dtype containing the value of the function. | ||
| The function to be minimized. The input is of shape `[n]`, | ||
| where `n` is the size of the trainable parameters. | ||
| The return value is a real `tf.Tensor` Scalar (matching shape | ||
| `[1]`). This must be a linear combination of quantum | ||
| measurement expectation value, otherwise this algorithm cannot | ||
| work. | ||
| initial_position: Real `tf.Tensor` of shape `[n]`. The starting | ||
| point, or points when using batching dimensions, of the search | ||
| procedure. At these points the function value and the gradient | ||
| norm should be finite. | ||
| tolerance: Scalar `tf.Tensor` of real dtype. Specifies the tolerance | ||
| for the procedure. If the supremum norm between two iteration | ||
| vector is below this number, the algorithm is stopped. | ||
| name: (Optional) Python `str`. The name prefixed to the ops created | ||
| by this function. If not supplied, the default name 'minimize' | ||
| is used. | ||
|
|
||
| Returns: | ||
| optimizer_results: A RotosolveOptimizerResults object contains the | ||
| result of the optimization process. | ||
| """ | ||
|
|
||
| with tf.name_scope(name or 'minimize'): | ||
| initial_position = tf.convert_to_tensor(initial_position, | ||
| name='initial_position', | ||
| dtype='float32') | ||
| dtype = initial_position.dtype.base_dtype | ||
| tolerance = tf.convert_to_tensor(tolerance, | ||
| dtype=dtype, | ||
| name='grad_tolerance') | ||
| max_iterations = tf.convert_to_tensor(max_iterations, | ||
| name='max_iterations') | ||
|
|
||
| def _rotosolve_one_parameter_once(state): | ||
| """Rotosolve a single parameter once. | ||
|
|
||
| Args: | ||
| state: A RotosolveOptimizerResults object stores the | ||
| current state of the minimizer. | ||
|
|
||
| Returns: | ||
| states: A list which the first element is the new state | ||
| """ | ||
| delta_shift = tf.scatter_nd([[state.solve_param_i]], | ||
| [tf.constant(np.pi / 2, dtype=dtype)], | ||
| prefer_static_shape(state.position)) | ||
|
|
||
| # Evaluate three different point for curve fitting | ||
| v_l, v_n, v_r = expectation_value_function( | ||
| state.position - delta_shift), \ | ||
| state.objective_value, \ | ||
| expectation_value_function(state.position + delta_shift) | ||
|
|
||
| # Use the analytical solution to find the optimized position | ||
| delta_update = -np.pi / 2 - \ | ||
| tf.math.atan2(2 * v_n - v_l - v_r, v_r - v_l) | ||
|
|
||
| delta_update_tensor = tf.scatter_nd( | ||
| [[state.solve_param_i]], [delta_update], | ||
| prefer_static_shape(state.position)) | ||
|
|
||
| state.solve_param_i.assign_add(1) | ||
| state.position.assign( | ||
| tf.math.floormod(state.position + delta_update_tensor, | ||
| np.pi * 2)) | ||
|
|
||
| state.objective_value_previous_iteration.assign( | ||
| state.objective_value) | ||
| state.objective_value.assign( | ||
| expectation_value_function(state.position)) | ||
|
|
||
| return [state] | ||
|
|
||
| def _rotosolve_all_parameters_once(state): | ||
| """Iterate over all parameters and rotosolve each single | ||
|
|
||
| of them once. | ||
|
|
||
| Args: | ||
| state: A RotosolveOptimizerResults object stores the | ||
| current state of the minimizer. | ||
|
|
||
| Returns: | ||
| states: A list which the first element is the new state | ||
| """ | ||
|
|
||
| def _cond_internal(state_cond): | ||
| return state_cond.solve_param_i < \ | ||
| prefer_static_shape(state_cond.position)[0] | ||
|
|
||
| state.num_objective_evaluations.assign_add(1) | ||
|
|
||
| return tf.while_loop( | ||
| cond=_cond_internal, | ||
| body=_rotosolve_one_parameter_once, | ||
| loop_vars=[state], | ||
| parallel_iterations=1, | ||
| ) | ||
|
|
||
| # The `state` here is a `RotosolveOptimizerResults` tuple with | ||
| # values for the current state of the algorithm computation. | ||
| def _cond(state): | ||
| """Continue if iterations remain and stopping condition | ||
| is not met.""" | ||
| return (state.num_iterations < max_iterations) \ | ||
| and (not state.converged) | ||
|
|
||
| def _body(state): | ||
| """Main optimization loop.""" | ||
|
|
||
| state.solve_param_i.assign(0) | ||
|
|
||
| _rotosolve_all_parameters_once(state) | ||
|
|
||
| state.num_iterations.assign_add(1) | ||
| state.converged.assign( | ||
| tf.abs(state.objective_value - | ||
| state.objective_value_previous_iteration) < | ||
| state.tolerance) | ||
|
|
||
| return [state] | ||
|
|
||
| initial_state = _get_initial_state(initial_position, tolerance, | ||
| expectation_value_function) | ||
|
|
||
| initial_state.objective_value.assign( | ||
| expectation_value_function(initial_state.position)) | ||
|
|
||
| return tf.while_loop(cond=_cond, | ||
| body=_body, | ||
| loop_vars=[initial_state], | ||
| parallel_iterations=1)[0] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: space.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in 5d0d1f4