diff --git a/docs/source/tutorials/nnvqe.ipynb b/docs/source/tutorials/nnvqe.ipynb index e0c861cc..ba48f3f3 100644 --- a/docs/source/tutorials/nnvqe.ipynb +++ b/docs/source/tutorials/nnvqe.ipynb @@ -6,7 +6,7 @@ "id": "64ba95d6", "metadata": {}, "source": [ - "#
NN-VQE" + "#
Neural Network encoded Variational Quantum Eigensolver (NN-VQE)" ] }, { @@ -105,7 +105,7 @@ "source": [ "## Ansatz circuit\n", "\n", - "Now we design the circuit. We choose multi-scale entangled renormalization ansatz (MERA) as the ansatz here, $d$ is the circuit depth. (see tutorial of MERA [here](https://tensorcircuit.readthedocs.io/en/latest/tutorials/mera.html))" + "Now we design the circuit. We choose multi-scale entangled renormalization ansatz (MERA) as the ansatz here, $d$ is the circuit depth. (see [MERA tutorial](https://tensorcircuit.readthedocs.io/en/latest/tutorials/mera.html))" ] }, { @@ -2909,7 +2909,7 @@ "source": [ "## NN-VQE\n", "\n", - "Design the NN-VQE. We use a neural network to transform the Hamiltonian parameters to the optimized parameters in the PQC for VQE." + "Design the NN-VQE. We use a neural network to transform the Hamiltonian parameters to the optimized parameters in the parameterized quantum circuit (PQC) for VQE." ] }, { @@ -3062,7 +3062,7 @@ "test_delta = np.linspace(-4.0, 4.0, 201) # test set\n", "test_energies = tf.zeros_like(test_delta).numpy()\n", "m = NN_MERA(n, d, lamb, NN_shape, stddev)\n", - "m.load_weights(\"DNN-MERA_2[20](-3.0,3.0,20)_drop05.weights.h5\")\n", + "m.load_weights(\"NN-VQE.weights.h5\")\n", "for i, de in tqdm(enumerate(test_delta)):\n", " test_energies[i] = m(K.reshape(de, [1]))" ] @@ -3074,7 +3074,7 @@ "source": [ "## Compare\n", "\n", - "We compare the results of NN-VQE with the analytical ones to calculate the ground-state relative error. From the figure, we can see that NN-VQE is able to estimate the ground-state energies of parameterized Hamiltonians with high precision without fine-tuning and has a favorable generalization capability." + "We compare the results of NN-VQE with the analytical ones to calculate the ground-state energy relative error. From the figure, we can see that NN-VQE is able to estimate the ground-state energies of parameterized Hamiltonians with high precision without fine-tuning and has a favorable generalization capability." ] }, { @@ -3982,7 +3982,7 @@ "id": "5f9bda8a", "metadata": {}, "source": [ - "To get more detailed information or further study, please refer to [our paper](https://arxiv.org/abs/2308.01068) and [GitHub](https://github.com/JachyMeow/NN-VQA)." + "To get more detailed information or further study, please refer to our [paper](https://arxiv.org/abs/2308.01068) and [GitHub](https://github.com/JachyMeow/NN-VQA)." ] } ], diff --git a/docs/source/tutorials/nnvqe_cn.ipynb b/docs/source/tutorials/nnvqe_cn.ipynb new file mode 100644 index 00000000..df6e0871 --- /dev/null +++ b/docs/source/tutorials/nnvqe_cn.ipynb @@ -0,0 +1,4013 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "64ba95d6", + "metadata": {}, + "source": [ + "#
神经网络编码的变分量子本征值求解器(NN-VQE)" + ] + }, + { + "cell_type": "markdown", + "id": "b65f64bf", + "metadata": {}, + "source": [ + "## 概述\n", + "\n", + "在本教程中,我们将使用TensorCircuit展示一个量子计算通用框架——神经网络编码的变分量子算法(neural network encoded variational quantum algorithms,NN-VQAs)。NN-VQA将一个给定问题的参量(如哈密顿量的参数)作为神经网络的输入,并使用其输出来参数化标准的变分量子算法(variational quantum algorithms,VQAs)的线路拟设(ansatz circuit)。在本文中,我们以神经网络编码的变分量子本征值求解器(NN-variational quantum eigensolver,NN-VQE)来具体说明。" + ] + }, + { + "cell_type": "markdown", + "id": "831930ae", + "metadata": {}, + "source": [ + "## 设置" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4e1651b9", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import tensorflow as tf\n", + "import tensorcircuit as tc\n", + "import cotengra\n", + "import quimb\n", + "from tqdm.notebook import tqdm\n", + "from functools import partial\n", + "\n", + "optc = cotengra.ReusableHyperOptimizer(\n", + " methods=[\"greedy\"],\n", + " parallel=\"ray\",\n", + " minimize=\"combo\",\n", + " max_time=30,\n", + " max_repeats=1024,\n", + " progbar=True,\n", + ")\n", + "tc.set_contractor(\"custom\", optimizer=optc, preprocessing=True)\n", + "\n", + "K = tc.set_backend(\"tensorflow\")\n", + "tc.set_dtype(\"complex128\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "d78b480b", + "metadata": {}, + "source": [ + "## 能量\n", + "\n", + "本教程所使用的哈密顿量是具有周期性边界条件的一维XXZ伊辛模型。它具有横向场强$\\lambda$和各向异性参数$\\Delta$。我们选择哈密顿量的能量期望函数作为损失函数。\n", + "\n", + "$$ \\hat{H}_{XXZ}=\\sum_{i}{ \\left( X_{i}X_{i+1}+Y_{i}Y_{i+1}+\\Delta Z_{i}Z_{i+1} \\right) } + \\lambda \\sum_{i}{Z_{i}} $$" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "fff67346", + "metadata": {}, + "outputs": [], + "source": [ + "def energy(c: tc.Circuit, lamb: float = 1.0, delta: float = 1.0):\n", + " e = 0.0\n", + " n = c._nqubits\n", + " for i in range(n):\n", + " e += lamb * c.expectation((tc.gates.z(), [i])) # \n", + " for i in range(n):\n", + " e += c.expectation(\n", + " (tc.gates.x(), [i]), (tc.gates.x(), [(i + 1) % n])\n", + " ) # \n", + " e += c.expectation(\n", + " (tc.gates.y(), [i]), (tc.gates.y(), [(i + 1) % n])\n", + " ) # \n", + " e += delta * c.expectation(\n", + " (tc.gates.z(), [i]), (tc.gates.z(), [(i + 1) % n])\n", + " ) # \n", + " return K.real(e)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "0ad6a7a6", + "metadata": {}, + "source": [ + "## 线路拟设\n", + "\n", + "现在我们来设计线路。我们选择多尺度纠缠重整化拟设(multi-scale entangled renormalization ansatz,MERA)作为线路拟设,其中$d$为线路深度。(详见[MERA教程](https://tensorcircuit.readthedocs.io/zh/latest/tutorials/mera_cn.html))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "445b7c86", + "metadata": {}, + "outputs": [], + "source": [ + "def MERA(inp, n, d=1, lamb=1.0, energy_flag=False): # 对于单变量一维XXZ模型,我们固定lamb\n", + " params = K.cast(inp[\"params\"], \"complex128\")\n", + " delta = K.cast(inp[\"delta\"], \"complex128\")\n", + " c = tc.Circuit(n)\n", + "\n", + " idx = 0\n", + "\n", + " for i in range(n):\n", + " c.rx(i, theta=params[3 * i])\n", + " c.rz(i, theta=params[3 * i + 1])\n", + " c.rx(i, theta=params[3 * i + 2])\n", + " idx += 3 * n\n", + "\n", + " for n_layer in range(1, int(np.log2(n)) + 1):\n", + " n_qubit = 2**n_layer # 涉及的量子比特数\n", + " step = int(n / n_qubit)\n", + "\n", + " for _ in range(d): # 线路深度\n", + " # 偶数层\n", + " for i in range(step, n - step, 2 * step):\n", + " c.rxx(i, i + step, theta=params[idx])\n", + " c.rzz(i, i + step, theta=params[idx + 1])\n", + " idx += 2\n", + "\n", + " # 奇数层\n", + " for i in range(0, n, 2 * step):\n", + " c.rxx(i, i + step, theta=params[idx])\n", + " c.rzz(i, i + step, theta=params[idx + 1])\n", + " idx += 2\n", + "\n", + " # 单比特门\n", + " for i in range(0, n, step):\n", + " c.rx(i, theta=params[idx])\n", + " c.rz(i, theta=params[idx + 1])\n", + " idx += 2\n", + "\n", + " if energy_flag:\n", + " return energy(c, lamb, delta) # 返回哈密顿量的能量期望\n", + " else:\n", + " return c, idx # 返回线路&线路参量数" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "6daa2f64", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The number of parameters is 74\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-08-06T16:31:47.918946\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.5.3, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# 线路可视化\n", + "n = 8\n", + "d = 1\n", + "cirq, idx = MERA({\"params\": np.zeros(3000), \"delta\": 0.0}, n, d, 1.0)\n", + "print(\"The number of parameters is\", idx)\n", + "cirq.draw()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "bfe2fbee", + "metadata": {}, + "source": [ + "## NN-VQE\n", + "\n", + "设计NN-VQE。我们使用神经网络将哈密顿量参数转换为VQE的变分量子线路(parameterized quantum ciecuit,PQC)的优化参数。" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "1ac050a8", + "metadata": {}, + "outputs": [], + "source": [ + "def NN_MERA(n, d, lamb, NN_shape, stddev):\n", + " input = tf.keras.layers.Input(shape=[1]) # 输入层\n", + "\n", + " x = tf.keras.layers.Dense(\n", + " units=NN_shape,\n", + " kernel_initializer=tf.keras.initializers.RandomNormal(stddev=stddev),\n", + " activation=\"ReLU\",\n", + " )(\n", + " input\n", + " ) # 隐层\n", + "\n", + " x = tf.keras.layers.Dropout(0.05)(x) # dropout层\n", + "\n", + " _, idx = MERA(\n", + " {\"params\": np.zeros(3000), \"delta\": 0.0}, n, d, 1.0, energy_flag=False\n", + " )\n", + " params = tf.keras.layers.Dense(\n", + " units=idx,\n", + " kernel_initializer=tf.keras.initializers.RandomNormal(stddev=stddev),\n", + " activation=\"sigmoid\",\n", + " )(\n", + " x\n", + " ) # 输出层\n", + "\n", + " qlayer = tc.KerasLayer(partial(MERA, n=n, d=d, lamb=lamb, energy_flag=True)) # PQC\n", + "\n", + " output = qlayer({\"params\": 6.3 * params, \"delta\": input}) # NN-VQE输出\n", + "\n", + " m = tf.keras.Model(inputs=input, outputs=output)\n", + "\n", + " return m" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "e9afb1a5", + "metadata": {}, + "source": [ + "## 训练\n", + "\n", + "现在我们用TensorFlow来训练NN-VQE。" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "873ebd5e", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "def train(n, d, lamb, delta, NN_shape, maxiter=10000, lr=0.005, stddev=1.0):\n", + " exp_lr = tf.keras.optimizers.schedules.ExponentialDecay(\n", + " initial_learning_rate=lr, decay_steps=1000, decay_rate=0.7\n", + " )\n", + " opt = tf.keras.optimizers.Adam(exp_lr) # 优化器\n", + "\n", + " m = NN_MERA(n, d, lamb, NN_shape, stddev)\n", + " for i in range(maxiter):\n", + " with tf.GradientTape() as tape:\n", + " e = tf.zeros([1], dtype=tf.float64)\n", + " for de in delta:\n", + " e += m(K.reshape(de, [1])) # 将所有训练点的能量相加\n", + " grads = tape.gradient(e, m.variables)\n", + " opt.apply_gradients(zip(grads, m.variables))\n", + " if i % 500 == 0:\n", + " print(\"epoch\", i, \":\", e)\n", + "\n", + " m.save_weights(\"NN-VQE.weights.h5\") # 保存已训练的模型" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "e8df3d67", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 0 : tf.Tensor([[117.53523392]], shape=(1, 1), dtype=float64)\n", + "epoch 500 : tf.Tensor([[-361.85937039]], shape=(1, 1), dtype=float64)\n", + "epoch 1000 : tf.Tensor([[-365.35288984]], shape=(1, 1), dtype=float64)\n", + "epoch 1500 : tf.Tensor([[-366.65891358]], shape=(1, 1), dtype=float64)\n", + "epoch 2000 : tf.Tensor([[-366.94258369]], shape=(1, 1), dtype=float64)\n" + ] + } + ], + "source": [ + "n = 8 # 量子比特数\n", + "d = 2 # 线路深度\n", + "lamb = 0.75 # 固定参数\n", + "delta = np.linspace(-3.0, 3.0, 20, dtype=\"complex128\") # 训练集\n", + "NN_shape = 20 # 隐层节点数\n", + "maxiter = 2500 # 最大迭代轮数\n", + "lr = 0.009 # 学习率\n", + "stddev = 0.1 # 神经网络参数初始值的标准差\n", + "\n", + "with tf.device(\"/cpu:0\"):\n", + " train(n, d, lamb, delta, NN_shape=NN_shape, maxiter=maxiter, lr=lr, stddev=stddev)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "c5a7cdb0", + "metadata": {}, + "source": [ + "## 测试\n", + "\n", + "我们使用较训练集更大的测试集来测试NN-VQE的准确性和泛化能力。" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "f15f4f68", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4a3ceb8bdf90463f88050b771fde6925", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "test_delta = np.linspace(-4.0, 4.0, 201) # 测试集\n", + "test_energies = tf.zeros_like(test_delta).numpy()\n", + "m = NN_MERA(n, d, lamb, NN_shape, stddev)\n", + "m.load_weights(\"NN-VQE.weights.h5\")\n", + "for i, de in tqdm(enumerate(test_delta)):\n", + " test_energies[i] = m(K.reshape(de, [1]))" + ] + }, + { + "cell_type": "markdown", + "id": "924027f8", + "metadata": {}, + "source": [ + "## 对比\n", + "\n", + "我们将NN-VQE的结果与解析解相比,计算基态能量相对误差。从图中可以看出,NN-VQE能够在无微调(fine-tuning)的情况下准确估计参数化哈密顿量的基态能量,且具有良好的泛化能力。" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "c8668a13", + "metadata": {}, + "outputs": [], + "source": [ + "analytical_energies = [] # 解析解\n", + "for i in test_delta:\n", + " h = quimb.tensor.tensor_builder.MPO_ham_XXZ(\n", + " n, i * 4, jxy=4.0, bz=2.0 * 0.75, S=0.5, cyclic=True\n", + " )\n", + " h = h.to_dense()\n", + " analytical_energies.append(np.min(quimb.eigvalsh(h)))" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "42799e3e", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-08-06T23:44:26.082098\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.5.3, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# 基态能量相对误差\n", + "plt.plot(\n", + " test_delta,\n", + " (test_energies - analytical_energies) / np.abs(analytical_energies),\n", + " \"-\",\n", + " color=\"b\",\n", + ")\n", + "plt.xlabel(\"Delta\", fontsize=14)\n", + "plt.ylabel(\"GS Relative Error\", fontsize=14)\n", + "plt.axvspan(-3.0, 3.0, color=\"darkgrey\", alpha=0.5) # 训练集区间\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "5f9bda8a", + "metadata": {}, + "source": [ + "想要获得更详细的信息或进一步的研究,请参考我们的[论文](https://arxiv.org/abs/2308.01068)和[GitHub](https://github.com/JachyMeow/NN-VQA)." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + }, + "vscode": { + "interpreter": { + "hash": "18d2a9923f839b0d86cf68fd09770e726264cf9d62311eaf57b1fff0ca4bed8e" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}