From 655bdf3cffdd4dcb7e76266b50f5b9effae50773 Mon Sep 17 00:00:00 2001 From: Zoufalc <40824883+Zoufalc@users.noreply.github.com> Date: Fri, 17 Jul 2020 01:52:07 +0200 Subject: [PATCH] Fix qgan.py where batch size is greater than data size to raise error (#1115) * Update qgan.py Raise error if the number of the training data items is smaller than the batch size, e.g. due to truncation to bounds. Co-authored-by: Steve Wood <40241007+woodsp-ibm@users.noreply.github.com> --- qiskit/aqua/algorithms/distribution_learners/qgan.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/qiskit/aqua/algorithms/distribution_learners/qgan.py b/qiskit/aqua/algorithms/distribution_learners/qgan.py index 245ddf23ab..1f75138894 100644 --- a/qiskit/aqua/algorithms/distribution_learners/qgan.py +++ b/qiskit/aqua/algorithms/distribution_learners/qgan.py @@ -270,6 +270,9 @@ def _store_params(self, e, d_loss, g_loss, rel_entr): def train(self): """ Train the qGAN + + Raises: + AquaError: Batch size bigger than the number of items in the truncated data set """ if self._snapshot_dir is not None: with open(os.path.join(self._snapshot_dir, 'output.csv'), mode='w') as csv_file: @@ -278,6 +281,11 @@ def train(self): writer = csv.DictWriter(csv_file, fieldnames=fieldnames) writer.writeheader() + if len(self._data) < self._batch_size: + raise AquaError( + 'The batch size needs to be less than the ' + 'truncated data size of {}'.format(len(self._data))) + for e in range(self._num_epochs): aqua_globals.random.shuffle(self._data) index = 0