Skip to content
This repository has been archived by the owner on Dec 7, 2021. It is now read-only.

Commit

Permalink
Fix qgan.py where batch size is greater than data size to raise error (
Browse files Browse the repository at this point in the history
…#1115)

* Update qgan.py

Raise error if the number of the training data items is smaller than the batch size, e.g. due to truncation to bounds.

Co-authored-by: Steve Wood <40241007+woodsp-ibm@users.noreply.github.com>
  • Loading branch information
Zoufalc and woodsp-ibm authored Jul 16, 2020
1 parent d79ef30 commit 655bdf3
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions qiskit/aqua/algorithms/distribution_learners/qgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,9 @@ def _store_params(self, e, d_loss, g_loss, rel_entr):
def train(self):
"""
Train the qGAN
Raises:
AquaError: Batch size bigger than the number of items in the truncated data set
"""
if self._snapshot_dir is not None:
with open(os.path.join(self._snapshot_dir, 'output.csv'), mode='w') as csv_file:
Expand All @@ -278,6 +281,11 @@ def train(self):
writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
writer.writeheader()

if len(self._data) < self._batch_size:
raise AquaError(
'The batch size needs to be less than the '
'truncated data size of {}'.format(len(self._data)))

for e in range(self._num_epochs):
aqua_globals.random.shuffle(self._data)
index = 0
Expand Down

0 comments on commit 655bdf3

Please sign in to comment.