Skip to content

Commit

Permalink
Merge pull request mmistakes#93 from pesser/bbatches
Browse files Browse the repository at this point in the history
Bbatches
  • Loading branch information
pesser authored Jul 5, 2019
2 parents 03c2d75 + 36e369a commit b760117
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
12 changes: 8 additions & 4 deletions edflow/iterators/batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def tile(X, rows, cols):
return tiling


def plot_batch(X, out_path):
def plot_batch(X, out_path, cols=None):
"""Save batch of images tiled."""
canvas = batch_to_canvas(X)
canvas = batch_to_canvas(X, cols)
save_image(canvas, out_path)


def batch_to_canvas(X):
def batch_to_canvas(X, cols=None):
"""convert batch of images to canvas"""
if len(X.shape) == 5:
# tile
Expand All @@ -75,7 +75,11 @@ def batch_to_canvas(X):
if n_channels == 1:
X = np.tile(X, [1, 1, 1, 3])
rc = math.sqrt(X.shape[0])
rows = cols = math.ceil(rc)
if cols is None:
rows = cols = math.ceil(rc)
else:
cols = max(1, cols)
rows = math.ceil(X.shape[0] / cols)
canvas = tile(X, rows, cols)
return canvas

Expand Down
21 changes: 21 additions & 0 deletions tests/test_iterators/test_batches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest
import numpy as np
from edflow.iterators import batches


def test_batch_to_canvas():
x = np.ones((9, 100, 100, 3))
canvas = batches.batch_to_canvas(x)
assert canvas.shape == (300, 300, 3)

canvas = batches.batch_to_canvas(x, cols=5)
assert canvas.shape == (200, 500, 3)

canvas = batches.batch_to_canvas(x, cols=1)
assert canvas.shape == (900, 100, 3)

canvas = batches.batch_to_canvas(x, cols=0)
assert canvas.shape == (900, 100, 3)

canvas = batches.batch_to_canvas(x, cols=None)
assert canvas.shape == (300, 300, 3)

0 comments on commit b760117

Please sign in to comment.