Skip to content
This repository has been archived by the owner on Dec 7, 2021. It is now read-only.

Fix qgan.py where batch size is greater than data size to raise error #1115

Merged
merged 4 commits into from
Jul 16, 2020
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions qiskit/aqua/algorithms/distribution_learners/qgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,9 @@ def _store_params(self, e, d_loss, g_loss, rel_entr):
def train(self):
"""
Train the qGAN

Raises:
AquaError: Batch size bigger than the number of items in the truncated data set
"""
if self._snapshot_dir is not None:
with open(os.path.join(self._snapshot_dir, 'output.csv'), mode='w') as csv_file:
Expand All @@ -278,6 +281,9 @@ def train(self):
writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
writer.writeheader()

if len(self._data) < self._batch_size:
raise AquaError('Please reduce the batch size.')
woodsp-ibm marked this conversation as resolved.
Show resolved Hide resolved

for e in range(self._num_epochs):
aqua_globals.random.shuffle(self._data)
index = 0
Expand Down