diff --git a/release/BUILD b/release/BUILD index 72dcc1421..103820b4e 100644 --- a/release/BUILD +++ b/release/BUILD @@ -22,6 +22,7 @@ sh_binary( "//tensorflow_quantum/python/layers/circuit_construction:__init__.py", "//tensorflow_quantum/python/layers/circuit_executors:__init__.py", "//tensorflow_quantum/python/layers/high_level:__init__.py", + "//tensorflow_quantum/python/optimizers:__init__.py", # Datasets module. "//tensorflow_quantum/datasets:__init__.py", @@ -53,5 +54,6 @@ sh_binary( "//tensorflow_quantum/python/layers/high_level:pqc", "//tensorflow_quantum/python:quantum_context", "//tensorflow_quantum/python:util", + "//tensorflow_quantum/python/optimizers:rotosolve_minimizer", ], ) diff --git a/scripts/build_docs.py b/scripts/build_docs.py index 770697330..2a4cdf9c3 100644 --- a/scripts/build_docs.py +++ b/scripts/build_docs.py @@ -67,6 +67,7 @@ def main(unused_argv): "parameter_shift_util", "adjoint" ], "tfq.datasets": ["cluster_state"], + "tfq.optimizers": ["rotosolve_minimizer"], "tfq.util": [ "from_tensor", "convert_to_tensor", "exp_identity", "check_commutability", "kwargs_cartesian_product", diff --git a/scripts/import_test.py b/scripts/import_test.py index c0f3e8d6b..581af3add 100644 --- a/scripts/import_test.py +++ b/scripts/import_test.py @@ -63,6 +63,9 @@ def test_imports(): _ = tfq.datasets.excited_cluster_states _ = tfq.datasets.tfi_chain + #Optimizers + _ = tfq.optimizers.rotosolve_minimize + if __name__ == "__main__": test_imports() diff --git a/tensorflow_quantum/__init__.py b/tensorflow_quantum/__init__.py index 5b6858187..614e75eb8 100644 --- a/tensorflow_quantum/__init__.py +++ b/tensorflow_quantum/__init__.py @@ -47,6 +47,9 @@ # Import differentiators. import tensorflow_quantum.python.differentiators as differentiators +# Import optimizers. +import tensorflow_quantum.python.optimizers as optimizers + # Python adds these symbols for resolution of above imports to # work. We get rid of them so that we don't have two paths to # things. For example: tfq.layers and tfq.python.layers diff --git a/tensorflow_quantum/python/optimizers/BUILD b/tensorflow_quantum/python/optimizers/BUILD new file mode 100755 index 000000000..c6dca3ce3 --- /dev/null +++ b/tensorflow_quantum/python/optimizers/BUILD @@ -0,0 +1,22 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +# Export for the PIP package. +exports_files(["__init__.py"]) + +py_library( + name = "rotosolve_minimizer", + srcs = ["rotosolve_minimizer.py"], +) + +py_test( + name = "rotosolve_minimizer_test", + srcs = ["rotosolve_minimizer_test.py"], + python_version = "PY3", + deps = [ + ":rotosolve_minimizer", + "//tensorflow_quantum/python/layers/high_level:pqc", + "//tensorflow_quantum/core/ops:tfq_ps_util_ops_py" + ], +) diff --git a/tensorflow_quantum/python/optimizers/__init__.py b/tensorflow_quantum/python/optimizers/__init__.py new file mode 100755 index 000000000..21a9d08fc --- /dev/null +++ b/tensorflow_quantum/python/optimizers/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020 The TensorFlow Quantum Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module definitions for tensorflow_quantum.python.optimizers.*""" + +# Quantum circuit specific optimizers. +from tensorflow_quantum.python.optimizers.rotosolve_minimizer import ( + minimize as rotosolve_minimize) diff --git a/tensorflow_quantum/python/optimizers/rotosolve_minimizer.py b/tensorflow_quantum/python/optimizers/rotosolve_minimizer.py new file mode 100755 index 000000000..bd1270068 --- /dev/null +++ b/tensorflow_quantum/python/optimizers/rotosolve_minimizer.py @@ -0,0 +1,263 @@ +# Copyright 2020 The TensorFlow Quantum Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The rotosolve minimization algorithm""" +import collections +import numpy as np +import tensorflow as tf + + +def prefer_static_shape(x): + """Return static shape of tensor `x` if available, + + else `tf.shape(x)`. + + Args: + x: `tf.Tensor` (already converted). + Returns: + Numpy array (if static shape is obtainable), else `tf.Tensor`. + """ + return prefer_static_value(tf.shape(x)) + + +def prefer_static_value(x): + """Return static value of tensor `x` if available, else `x`. + + Args: + x: `tf.Tensor` (already converted). + Returns: + Numpy array (if static value is obtainable), else `tf.Tensor`. + """ + static_x = tf.get_static_value(x) + if static_x is not None: + return static_x + return x + + +RotosolveOptimizerResults = collections.namedtuple( + 'RotosolveOptimizerResults', + [ + 'converged', + # Scalar boolean tensor indicating whether the minimum + # was found within tolerance. + 'num_iterations', + # The number of iterations of the rotosolve update. + 'num_objective_evaluations', + # The total number of objective + # evaluations performed. + 'position', + # A tensor containing the last argument value found + # during the search. If the search converged, then + # this value is the argmin of the objective function. + # A tensor containing the value of the objective from + # previous iteration + 'objective_value_previous_iteration', + # Save the evaluated value of the objective function + # from the previous iteration + 'objective_value', + # A tensor containing the value of the objective + # function at the `position`. If the search + # converged, then this is the (local) minimum of + # the objective function. + 'tolerance', + # Define the stop criteria. Iteration will stop when the + # objective value difference between two iterations is + # smaller than tolerance + 'solve_param_i', + # The parameter index where rotosolve is currently + # modifying. Reserved for internal use. + ]) + + +def _get_initial_state(initial_position, tolerance, expectation_value_function): + """Create RotosolveOptimizerResults with initial state of search.""" + init_args = { + "converged": tf.Variable(False), + "num_iterations": tf.Variable(0), + "num_objective_evaluations": tf.Variable(0), + "position": tf.Variable(initial_position), + "objective_value": tf.Variable(0.), + "objective_value_previous_iteration": tf.Variable(0.), + "tolerance": tolerance, + "solve_param_i": tf.Variable(0) + } + return RotosolveOptimizerResults(**init_args) + + +def minimize(expectation_value_function, + initial_position, + tolerance=1e-5, + max_iterations=50, + name=None): + """Applies the rotosolve algorithm. + + The rotosolve algorithm can be used to minimize a linear combination + + of quantum measurement expectation values. See the following paper: + + [arXiv:1903.12166](https://arxiv.org/abs/1903.12166), Ken M. Nakanishi. + [arXiv:1905.09692](https://arxiv.org/abs/1905.09692), Mateusz Ostaszewski. + + Usage: + + Here is an example of optimize a function which consists summation of + a few sinusoids. + + >>> n = 10 # Number of sinusoids + >>> coefficient = tf.random.uniform(shape=[n]) + >>> min_value = -tf.math.reduce_sum(tf.abs(coefficient)) + >>> func = lambda x:tf.math.reduce_sum(tf.sin(x) * coefficient) + >>> # Optimize the function with rotosolve, start with random parameters + >>> result = tfq.optimizers.rotosolve_minimize(func, np.random.random(n)) + >>> result.converged + tf.Tensor(True, shape=(), dtype=bool) + >>> result.objective_value + tf.Tensor(-4.7045116, shape=(), dtype=float32) + + Args: + expectation_value_function: A Python callable that accepts + a point as a real `tf.Tensor` and returns a `tf.Tensor`s + of real dtype containing the value of the function. + The function to be minimized. The input is of shape `[n]`, + where `n` is the size of the trainable parameters. + The return value is a real `tf.Tensor` Scalar (matching shape + `[1]`). This must be a linear combination of quantum + measurement expectation value, otherwise this algorithm cannot + work. + initial_position: Real `tf.Tensor` of shape `[n]`. The starting + point, or points when using batching dimensions, of the search + procedure. At these points the function value and the gradient + norm should be finite. + tolerance: Scalar `tf.Tensor` of real dtype. Specifies the tolerance + for the procedure. If the supremum norm between two iteration + vector is below this number, the algorithm is stopped. + name: (Optional) Python `str`. The name prefixed to the ops created + by this function. If not supplied, the default name 'minimize' + is used. + + Returns: + optimizer_results: A RotosolveOptimizerResults object contains the + result of the optimization process. + """ + + with tf.name_scope(name or 'minimize'): + initial_position = tf.convert_to_tensor(initial_position, + name='initial_position', + dtype='float32') + dtype = initial_position.dtype.base_dtype + tolerance = tf.convert_to_tensor(tolerance, + dtype=dtype, + name='grad_tolerance') + max_iterations = tf.convert_to_tensor(max_iterations, + name='max_iterations') + + def _rotosolve_one_parameter_once(state): + """Rotosolve a single parameter once. + + Args: + state: A RotosolveOptimizerResults object stores the + current state of the minimizer. + + Returns: + states: A list which the first element is the new state + """ + delta_shift = tf.scatter_nd([[state.solve_param_i]], + [tf.constant(np.pi / 2, dtype=dtype)], + prefer_static_shape(state.position)) + + # Evaluate three different point for curve fitting + v_l, v_n, v_r = expectation_value_function( + state.position - delta_shift), \ + state.objective_value, \ + expectation_value_function(state.position + delta_shift) + + # Use the analytical solution to find the optimized position + delta_update = -np.pi / 2 - \ + tf.math.atan2(2 * v_n - v_l - v_r, v_r - v_l) + + delta_update_tensor = tf.scatter_nd( + [[state.solve_param_i]], [delta_update], + prefer_static_shape(state.position)) + + state.solve_param_i.assign_add(1) + state.position.assign( + tf.math.floormod(state.position + delta_update_tensor, + np.pi * 2)) + + state.objective_value_previous_iteration.assign( + state.objective_value) + state.objective_value.assign( + expectation_value_function(state.position)) + + return [state] + + def _rotosolve_all_parameters_once(state): + """Iterate over all parameters and rotosolve each single + + of them once. + + Args: + state: A RotosolveOptimizerResults object stores the + current state of the minimizer. + + Returns: + states: A list which the first element is the new state + """ + + def _cond_internal(state_cond): + return state_cond.solve_param_i < \ + prefer_static_shape(state_cond.position)[0] + + state.num_objective_evaluations.assign_add(1) + + return tf.while_loop( + cond=_cond_internal, + body=_rotosolve_one_parameter_once, + loop_vars=[state], + parallel_iterations=1, + ) + + # The `state` here is a `RotosolveOptimizerResults` tuple with + # values for the current state of the algorithm computation. + def _cond(state): + """Continue if iterations remain and stopping condition + is not met.""" + return (state.num_iterations < max_iterations) \ + and (not state.converged) + + def _body(state): + """Main optimization loop.""" + + state.solve_param_i.assign(0) + + _rotosolve_all_parameters_once(state) + + state.num_iterations.assign_add(1) + state.converged.assign( + tf.abs(state.objective_value - + state.objective_value_previous_iteration) < + state.tolerance) + + return [state] + + initial_state = _get_initial_state(initial_position, tolerance, + expectation_value_function) + + initial_state.objective_value.assign( + expectation_value_function(initial_state.position)) + + return tf.while_loop(cond=_cond, + body=_body, + loop_vars=[initial_state], + parallel_iterations=1)[0] diff --git a/tensorflow_quantum/python/optimizers/rotosolve_minimizer_test.py b/tensorflow_quantum/python/optimizers/rotosolve_minimizer_test.py new file mode 100755 index 000000000..2631f9606 --- /dev/null +++ b/tensorflow_quantum/python/optimizers/rotosolve_minimizer_test.py @@ -0,0 +1,167 @@ +# Copyright 2020 The TensorFlow Quantum Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test module for tfq.python.optimizers.rotosolve_minimizer optimizer.""" +from operator import mul +from functools import reduce +import numpy as np +import tensorflow as tf +from absl.testing import parameterized +import cirq +import sympy +from tensorflow_quantum.python.layers.high_level import pqc +from tensorflow_quantum.python import util +from tensorflow_quantum.python.optimizers import rotosolve_minimizer + + +def loss_function_with_model_parameters(model, loss, train_x, train_y): + """Create a new function that assign the model parameter to the model + and evaluate its value. + + Args: + model : an instance of `tf.keras.Model` or its subclasses. + loss : a function with signature loss_value = loss(pred_y, true_y). + train_x : the input part of training data. + train_y : the output part of training data. + + Returns: + A function that has a signature of: + loss_value = f(model_parameters). + """ + + # obtain the shapes of all trainable parameters in the model + shapes = tf.shape_n(model.trainable_variables) + count = 0 + sizes = [] + + # Record the shape of each parameter + for shape in shapes: + n = reduce(mul, shape) + sizes.append(n) + count += n + + # Function accept the parameter and evaluate model + @tf.function + def func(params): + """A function that can be used by tfq.optimizer.rotosolve_minimize. + + Args: + params [in]: a 1D tf.Tensor. + + Returns: + Loss function value + """ + + # update the parameters of the model + start = 0 + for i, size in enumerate(sizes): + model.trainable_variables[i].assign( + tf.reshape(params[start:start + size], shape)) + start += size + + # evaluate the loss + loss_value = loss(model(train_x, training=True), train_y) + return loss_value + + return func + + +class RotosolveMinimizerTest(tf.test.TestCase, parameterized.TestCase): + """Tests for the rotosolve optimization algorithm.""" + + def test_function_optimization(self): + """Optimize a simple sinusoid function.""" + + n = 10 # Number of parameters to be optimized + coefficient = tf.random.uniform(shape=[n]) + min_value = -tf.math.reduce_sum(tf.abs(coefficient)) + + func = lambda x: tf.math.reduce_sum(tf.sin(x) * coefficient) + + result = rotosolve_minimizer.minimize(func, np.random.random(n)) + + self.assertAlmostEqual(func(result.position), min_value) + self.assertAlmostEqual(result.objective_value, min_value) + self.assertTrue(result.converged) + self.assertLess(result.num_iterations, + 50) # 50 is the default max iteration + + def test_nonlinear_function_optimization(self): + """Test to optimize a non-linear function. + A non-linear function cannot be optimized by rotosolve, + therefore the optimization must never converge. + """ + func = lambda x: x[0]**2 + x[1]**2 + + result = rotosolve_minimizer.minimize(func, + tf.random.uniform(shape=[2])) + + self.assertFalse(result.converged) + self.assertEqual(result.num_iterations, + 50) # 50 is the default max iteration + + def test_keras_model_optimization(self): + """Optimizate a PQC based keras model.""" + + x = np.asarray([ + [0, 0], + [0, 1], + [1, 0], + [1, 1], + ], dtype=float) + + y = np.asarray([[-1], [1], [1], [-1]], dtype=np.float32) + + def convert_to_circuit(input_data): + """Encode into quantum datapoint.""" + values = np.ndarray.flatten(input_data) + qubits = cirq.GridQubit.rect(1, 2) + circuit = cirq.Circuit() + for i, value in enumerate(values): + if value: + circuit.append(cirq.X(qubits[i])) + return circuit + + x_circ = util.convert_to_tensor([convert_to_circuit(x) for x in x]) + + # Create two qubits + q0, q1 = cirq.GridQubit.rect(1, 2) + + # Create an anzatz on these qubits. + a, b = sympy.symbols('a b') # parameters for the circuit + circuit = cirq.Circuit( + cirq.rx(a).on(q0), + cirq.ry(b).on(q1), cirq.CNOT(control=q0, target=q1)) + + # Build the Keras model. + model = tf.keras.Sequential([ + # The input is the data-circuit, encoded as a tf.string + tf.keras.layers.Input(shape=(), dtype=tf.string), + # The PQC layer returns the expected value of the + # readout gate, range [-1,1]. + pqc.PQC(circuit, cirq.Z(q1)), + ]) + + # Initial guess of the parameter from random number + result = rotosolve_minimizer.minimize( + loss_function_with_model_parameters(model, tf.keras.losses.Hinge(), + x_circ, y), + tf.random.uniform(shape=[2]) * 2 * np.pi) + + self.assertAlmostEqual(result.objective_value, 0) + self.assertTrue(result.converged) + + +if __name__ == "__main__": + tf.test.main()