Skip to content

Commit

Permalink
Prevent downcasting bands to fit previous band dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
aazuspan committed Aug 5, 2024
1 parent c5f86cb commit 4206d09
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/sknnr_spatial/datasets/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ def _load_rasters_to_dataset(


def _load_rasters_to_array(file_paths: list[Path]) -> NDArray:
"""Load a list of rasters as a numpy array."""
"""Load single-band rasters as a multi-band numpy array of shape (band, y, x)."""
arr = None
for i, path in enumerate(file_paths):
for path in file_paths:
with rasterio.open(path) as src:
band = src.read(1)
if arr is None:
arr = np.empty((len(file_paths), *band.shape), dtype=band.dtype)
# Add a band dimension to the array to allow concatenation
band = band[np.newaxis, ...]

arr[i] = band
arr = band if arr is None else np.concatenate((arr, band), axis=0)

return arr

Expand Down
33 changes: 33 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,20 @@

import pickle
import sys
import tempfile
import warnings
from dataclasses import dataclass
from pathlib import Path
from unittest import mock

import numpy as np
import pytest
import rasterio
from numpy.testing import assert_array_almost_equal
from typing_extensions import Any

from sknnr_spatial.datasets import load_swo_ecoplot
from sknnr_spatial.datasets._base import _load_rasters_to_array


@dataclass
Expand Down Expand Up @@ -98,3 +105,29 @@ def test_load_dataset_missing_imports(missing_import):

with pytest.raises(ImportError, match=msg):
from sknnr_spatial.datasets import load_swo_ecoplot # noqa: F401


def test_load_rasters_promotes_dtype():
    """Check that rasters with mixed dtypes are loaded using the promoted dtype."""
    uint8_band = np.random.randint(0, 255, size=(10, 10), dtype=np.uint8)
    float32_band = np.random.rand(10, 10).astype(np.float32)
    stacked_bands = np.stack([uint8_band, float32_band])

    # (file name, pixel values, on-disk dtype) for each single-band raster
    raster_specs = [
        ("int.tif", uint8_band, np.uint8),
        ("float.tif", float32_band, np.float32),
    ]

    with tempfile.TemporaryDirectory() as tmpdir, warnings.catch_warnings():
        # The test rasters carry no georeferencing, so silence rasterio's
        # warning about the missing geotransform.
        warnings.filterwarnings("ignore", message="Dataset has no geotransform")

        raster_paths = []
        for file_name, pixels, disk_dtype in raster_specs:
            raster_path = Path(tmpdir) / file_name
            with rasterio.open(
                raster_path, "w", dtype=disk_dtype, height=10, width=10, count=1
            ) as dst:
                dst.write(pixels, 1)
            raster_paths.append(raster_path)

        # Load while the temporary directory still exists
        loaded = _load_rasters_to_array(raster_paths)

    assert loaded.dtype == np.float32
    # Tolerate tiny floating point differences from the write/read round trip
    assert_array_almost_equal(loaded, stacked_bands)

0 comments on commit 4206d09

Please sign in to comment.