Skip to content

Commit

Permalink
batches now supports cols argument to make non-square batch plots
Browse files Browse the repository at this point in the history
  • Loading branch information
theRealSuperMario committed Jul 4, 2019
1 parent cdfcc47 commit 36e369a
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions edflow/iterators/batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def tile(X, rows, cols):
return tiling


def plot_batch(X, out_path):
def plot_batch(X, out_path, cols=None):
"""Save batch of images tiled."""
canvas = batch_to_canvas(X)
canvas = batch_to_canvas(X, cols)
save_image(canvas, out_path)


def batch_to_canvas(X):
def batch_to_canvas(X, cols=None):
"""convert batch of images to canvas"""
if len(X.shape) == 5:
# tile
Expand All @@ -75,7 +75,11 @@ def batch_to_canvas(X):
if n_channels == 1:
X = np.tile(X, [1, 1, 1, 3])
rc = math.sqrt(X.shape[0])
rows = cols = math.ceil(rc)
if cols is None:
rows = cols = math.ceil(rc)
else:
cols = max(1, cols)
rows = math.ceil(X.shape[0] / cols)
canvas = tile(X, rows, cols)
return canvas

Expand Down

0 comments on commit 36e369a

Please sign in to comment.