Skip to content

Commit

Permalink
Merge pull request #335 from Living-with-machines/331-hwc-bug
Browse files Browse the repository at this point in the history
Fix patchify for single band images
  • Loading branch information
rwood-97 authored Jan 5, 2024
2 parents 7eb659e + e02f857 commit 2efad70
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 11 deletions.
40 changes: 29 additions & 11 deletions mapreader/load/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from glob import glob
from typing import Literal

import matplotlib.image as mpimg
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -472,9 +471,12 @@ def show_sample(
plt.subplot(num_samples // 3 + 1, 3, i + 1)
img = Image.open(self.images[tree_level][image_id]["image_path"])
plt.title(image_id, size=8)
plt.imshow(
img,
)

# check if grayscale
if len(img.getbands()) == 1:
plt.imshow(img, cmap="gray", vmin=0, vmax=255)
else:
plt.imshow(img)
plt.xticks([])
plt.yticks([])

Expand Down Expand Up @@ -672,10 +674,13 @@ def _add_shape_id(
tree_level = self._get_tree_level(image_id)

try:
myimg = mpimg.imread(self.images[tree_level][image_id]["image_path"])
img = Image.open(self.images[tree_level][image_id]["image_path"])
# shape = (hwc)
myimg_shape = myimg.shape
self.images[tree_level][image_id]["shape"] = myimg_shape
height = img.height
width = img.width
channels = len(img.getbands())

self.images[tree_level][image_id]["shape"] = (height, width, channels)
except OSError:
raise ValueError(
f'[ERROR] Problem with "{image_id}". Please either redownload or remove from list of images to load.'
Expand Down Expand Up @@ -1485,7 +1490,12 @@ def show(

fig = plt.figure(figsize=figsize)
plt.axis("off")
plt.imshow(img, zorder=1)

# check if grayscale
if len(img.getbands()) == 1:
plt.imshow(img, cmap="gray", vmin=0, vmax=255, zorder=1)
else:
plt.imshow(img, zorder=1)

if column_to_plot:
print(
Expand Down Expand Up @@ -1603,7 +1613,11 @@ def show(
parent_path = parent_images[parent_id]["image_path"]
parent_image = Image.open(parent_path)

ax.imshow(parent_image)
# check if grayscale
if len(parent_image.getbands()) == 1:
ax.imshow(parent_image, cmap="gray", vmin=0, vmax=255)
else:
ax.imshow(parent_image)

if save_kml_dir:
os.makedirs(save_kml_dir, exist_ok=True)
Expand Down Expand Up @@ -2322,7 +2336,6 @@ def _save_patch_as_geotiff(

patch_affine = rasterio.transform.from_bounds(*coords, width, height)
patch = Image.open(patch_path)
patch_array = reshape_as_raster(patch)

with rasterio.open(
f"{geotiff_path}",
Expand All @@ -2336,7 +2349,12 @@ def _save_patch_as_geotiff(
nodata=0,
crs=crs,
) as dst:
dst.write(patch_array)
if len(patch.getbands()) == 1:
patch_array = np.array(patch)
dst.write(patch_array, indexes=1)
else:
patch_array = reshape_as_raster(patch)
dst.write(patch_array)

def save_patches_to_geojson(
self,
Expand Down
Binary file added tests/sample_files/cropped_L.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
48 changes: 48 additions & 0 deletions tests/test_load/test_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,16 @@ def test_init_png(sample_dir, image_id):
len(maps)


def test_init_png_grayscale(sample_dir):
image_id = "cropped_L.png"
maps = MapImages(f"{sample_dir}/{image_id}")
assert len(maps.list_parents()) == 1
assert len(maps.list_patches()) == 0
assert isinstance(maps, MapImages)
maps.add_shape()
assert maps.parents[image_id]["shape"] == (9, 9, 1)


def test_init_tiff(sample_dir):
image_id = "cropped_non_geo.tif"
tiffs = MapImages(f"{sample_dir}/{image_id}")
Expand Down Expand Up @@ -429,6 +439,17 @@ def test_patchify_meters(sample_dir, image_id, tmp_path):
assert len(maps.list_patches()) == 25


def test_patchify_grayscale(sample_dir, tmp_path):
image_id = "cropped_L.png"
maps = MapImages(f"{sample_dir}/{image_id}")
maps.patchify_all(patch_size=3, path_save=tmp_path)
parent_list = maps.list_parents()
patch_list = maps.list_patches()
assert len(parent_list) == 1
assert len(patch_list) == 9
assert os.path.isfile(f"{tmp_path}/patch-0-0-3-3-#{image_id}#.png")


def test_patchify_meters_errors(sample_dir, image_id, tmp_path):
maps = MapImages(f"{sample_dir}/{image_id}")
with pytest.raises(ValueError, match="add coordinate information"):
Expand Down Expand Up @@ -569,6 +590,22 @@ def test_add_patch_polygons(init_maps):
def test_save_patches_as_geotiffs(init_maps):
maps, _, _ = init_maps
maps.save_patches_as_geotiffs()
patch_id = maps.list_patches()[0]
assert "geotiff_path" in maps.patches[patch_id].keys()
assert os.path.isfile(maps.patches[patch_id]["geotiff_path"])


def test_save_patches_as_geotiffs_grayscale(sample_dir, tmp_path):
image_id = "cropped_L.png"
maps = MapImages(f"{sample_dir}/{image_id}")
metadata = pd.read_csv(f"{sample_dir}/ts_downloaded_maps.csv", index_col=0)
metadata.loc[0, "name"] = "cropped_L.png"
maps.add_metadata(metadata)
maps.patchify_all(patch_size=3, path_save=tmp_path)
maps.save_patches_as_geotiffs()
patch_id = maps.list_patches()[0]
assert "geotiff_path" in maps.patches[patch_id].keys()
assert os.path.isfile(maps.patches[patch_id]["geotiff_path"])


def test_save_to_geojson(init_maps, tmp_path, capfd):
Expand Down Expand Up @@ -677,3 +714,14 @@ def test_save_parents_as_geotiffs(init_maps, sample_dir, image_id):
maps.save_parents_as_geotiffs()
image_id = image_id.split(".")[0]
assert os.path.isfile(f"{sample_dir}/{image_id}.tif")


def test_save_parents_as_geotiffs_grayscale(sample_dir, tmp_path):
image_id = "cropped_L.png"
maps = MapImages(f"{sample_dir}/{image_id}")
metadata = pd.read_csv(f"{sample_dir}/ts_downloaded_maps.csv", index_col=0)
metadata.loc[0, "name"] = "cropped_L.png"
maps.add_metadata(metadata)
maps.save_parents_as_geotiffs()
assert "geotiff_path" in maps.parents[image_id].keys()
assert os.path.isfile(maps.parents[image_id]["geotiff_path"])

0 comments on commit 2efad70

Please sign in to comment.