Skip to content

Commit

Permalink
Merge branch 'main' into template
Browse files Browse the repository at this point in the history
  • Loading branch information
12rambau authored Jan 10, 2025
2 parents ca22924 + aa03ffa commit 07a65a5
Show file tree
Hide file tree
Showing 17 changed files with 585 additions and 20 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies = [
"deprecated>=1.2.14",
"earthengine-api>=0.1.397", # new ee.data.createFolder method
"pytest",
"pytest-regressions",
"pytest-regressions>=2.7.0", # get the fullpath parameter in the Imageregression
"geopandas",
"pillow",
]
Expand Down
50 changes: 46 additions & 4 deletions pytest_gee/dictionary_regression.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Implementation of the ``dictionary_regression`` fixture."""

import os
from contextlib import suppress
from typing import Optional

import ee
from pytest_regressions.data_regression import DataRegressionFixture

from .utils import round_data
from .utils import build_fullpath, check_serialized, round_data


class DictionaryFixture(DataRegressionFixture):
Expand All @@ -27,6 +28,47 @@ def check(
fullpath: complete path to use as a reference file. This option will ignore ``datadir`` fixture when reading *expected* files but will still use it to write *obtained* files. Useful if a reference file is located in the session data dir for example.
precision: The number of decimal places to round to when comparing floats.
"""
# round any float value before serving the data to the check function
data_dict = round_data(data_dict.getInfo(), prescision)
super().check(data_dict, basename=basename, fullpath=fullpath)
# build the different filename to be consistent between our 3 checks
data_name = build_fullpath(
datadir=self.original_datadir,
request=self.request,
extension=".yml",
basename=basename,
fullpath=fullpath,
with_test_class_names=self.with_test_class_names,
)

# check the previously registered serialized call from GEE. If it matches the current call,
# we don't need to check the data
with suppress(BaseException):
check_serialized(
object=ee.Dictionary(data_dict),
path=data_name,
datadir=self.datadir,
original_datadir=self.original_datadir,
request=self.request,
with_test_class_names=self.with_test_class_names,
)
return

# if it needs to be checked, we need to round the float values to the same precision as the
# reference file
data = round_data(data_dict.getInfo(), prescision)
try:
super().check(data, fullpath=data_name)

# IF we are here it means the data has been modified so we edit the API call accordingly
# to make sure next run will not be forced to call the API for a response.
with suppress(BaseException):
check_serialized(
object=data_dict,
path=data_name,
datadir=self.datadir,
original_datadir=self.original_datadir,
request=self.request,
with_test_class_names=self.with_test_class_names,
force_regen=True,
)

except BaseException as e:
raise e
47 changes: 45 additions & 2 deletions pytest_gee/feature_collection_regression.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Implementation of the ``feature_collection_regression`` fixture."""

import os
from contextlib import suppress
from typing import Optional

import ee
import geopandas as gpd
from pytest_regressions.data_regression import DataRegressionFixture

from .utils import round_data
from .utils import build_fullpath, check_serialized, round_data


class FeatureCollectionFixture(DataRegressionFixture):
Expand All @@ -33,6 +34,29 @@ def check(
if drop_index is True:
data_fc = data_fc.map(lambda f: f.select(f.propertyNames().remove("system:index")))

# build the different filename to be consistent between our 3 checks
data_name = build_fullpath(
datadir=self.original_datadir,
request=self.request,
extension=".yml",
basename=basename,
fullpath=fullpath,
with_test_class_names=self.with_test_class_names,
)

# check the previously registered serialized call from GEE. If it matches the current call,
# we don't need to check the data
with suppress(BaseException):
check_serialized(
object=data_fc,
path=data_name,
datadir=self.datadir,
original_datadir=self.original_datadir,
request=self.request,
with_test_class_names=self.with_test_class_names,
)
return

# round the geometry using geopandas to make sre with use the specific number of decimal places
gdf = gpd.GeoDataFrame.from_features(data_fc.getInfo())
gdf.geometry = gdf.set_precision(grid_size=10 ** (-prescision)).remove_repeated_points()
Expand All @@ -41,4 +65,23 @@ def check(
data = gdf.to_geo_dict()
data = round_data(data, prescision)

super().check(data, basename=basename, fullpath=fullpath)
# if it needs to be checked, we need to round the float values to the same precision as the
# reference file
try:
super().check(data, fullpath=data_name)

# IF we are here it means the data has been modified so we edit the API call accordingly
# to make sure next run will not be forced to call the API for a response.
with suppress(BaseException):
check_serialized(
object=data_fc,
path=data_name,
datadir=self.datadir,
original_datadir=self.original_datadir,
request=self.request,
with_test_class_names=self.with_test_class_names,
force_regen=True,
)

except BaseException as e:
raise e
64 changes: 55 additions & 9 deletions pytest_gee/image_regression.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""implementation of the ``image_regression`` fixture."""

import os
from contextlib import suppress
from typing import Optional

import ee
import requests
from pytest_regressions.image_regression import ImageRegressionFixture

from .utils import build_fullpath, check_serialized


class ImageFixture(ImageRegressionFixture):
"""Fixture for regression testing of :py:class:`ee.Image`."""
Expand All @@ -16,6 +19,7 @@ def check(
diff_threshold: float = 0.1,
expect_equal: bool = True,
basename: Optional[str] = None,
fullpath: Optional[os.PathLike] = None,
scale: Optional[int] = 30,
viz_params: Optional[dict] = None,
):
Expand All @@ -32,27 +36,69 @@ def check(
diff_threshold: The threshold for the difference between the expected and obtained images.
expect_equal: If ``True`` the images are expected to be equal, otherwise they are expected to be different.
basename: The basename of the file to test/record. If not given the name of the test is used.
fullpath: complete path to use as a reference file. This option will ignore ``datadir`` fixture when reading *expected* files but will still use it to write *obtained* files. Useful if a reference file is located in the session data dir for example.
scale: The scale to use for the thumbnail.
viz_params: The visualization parameters to use for the thumbnail. If not given, the min and max values of the image will be used.
"""
# grescale the original image
# rescale the original image
geometry = data_image.geometry()
image = data_image.clipToBoundsAndScale(geometry, scale=scale)
data_image = data_image.clipToBoundsAndScale(geometry, scale=scale)

# build the different filename to be consistent between our 3 checks
data_name = build_fullpath(
datadir=self.original_datadir,
request=self.request,
extension=".png",
basename=basename,
fullpath=fullpath,
with_test_class_names=self.with_test_class_names,
)

# check the previously registered serialized call from GEE. If it matches the current call,
# we don't need to check the data
with suppress(BaseException):
check_serialized(
object=data_image,
path=data_name,
datadir=self.datadir,
original_datadir=self.original_datadir,
request=self.request,
with_test_class_names=self.with_test_class_names,
)
return

# extract min and max for visualization
minMax = image.reduceRegion(ee.Reducer.minMax(), geometry, scale)
minMax = data_image.reduceRegion(ee.Reducer.minMax(), geometry, scale)

# create visualization parameters based on the computed minMax values
if viz_params is None:
nbBands = ee.Algorithms.If(image.bandNames().size().gte(3), 3, 1)
bands = image.bandNames().slice(0, ee.Number(nbBands))
nbBands = ee.Algorithms.If(data_image.bandNames().size().gte(3), 3, 1)
bands = data_image.bandNames().slice(0, ee.Number(nbBands))
min = bands.map(lambda b: minMax.get(ee.String(b).cat("_min")))
max = bands.map(lambda b: minMax.get(ee.String(b).cat("_max")))
viz_params = ee.Dictionary({"bands": bands, "min": min, "max": max}).getInfo()

# get the thumbnail image
thumb_url = image.getThumbURL(params=viz_params)
thumb_url = data_image.getThumbURL(params=viz_params)
byte_data = requests.get(thumb_url).content

# call the parent check method
super().check(byte_data, diff_threshold, expect_equal, basename=basename)
# if it needs to be checked, we need to round the float values to the same precision as the
# reference file
try:
super().check(byte_data, diff_threshold, expect_equal, fullpath=data_name)

# IF we are here it means the data has been modified so we edit the API call accordingly
# to make sure next run will not be forced to call the API for a response.
with suppress(BaseException):
check_serialized(
object=data_image,
path=data_name,
datadir=self.datadir,
original_datadir=self.original_datadir,
request=self.request,
with_test_class_names=self.with_test_class_names,
force_regen=True,
)

except BaseException as e:
raise e
50 changes: 46 additions & 4 deletions pytest_gee/list_regression.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Implementation of the ``list_regression`` fixture."""

import os
from contextlib import suppress
from typing import Optional

import ee
from pytest_regressions.data_regression import DataRegressionFixture

from .utils import round_data
from .utils import build_fullpath, check_serialized, round_data


class ListFixture(DataRegressionFixture):
Expand All @@ -27,6 +28,47 @@ def check(
fullpath: complete path to use as a reference file. This option will ignore ``datadir`` fixture when reading *expected* files but will still use it to write *obtained* files. Useful if a reference file is located in the session data dir for example.
precision: The number of decimal places to round to when comparing floats.
"""
# round any float value before serving the data to the check function
data_list = round_data(data_list.getInfo(), prescision)
super().check(data_list, basename=basename, fullpath=fullpath)
# build the different filename to be consistent between our 3 checks
data_name = build_fullpath(
datadir=self.original_datadir,
request=self.request,
extension=".yml",
basename=basename,
fullpath=fullpath,
with_test_class_names=self.with_test_class_names,
)

# check the previously registered serialized call from GEE. If it matches the current call,
# we don't need to check the data
with suppress(BaseException):
check_serialized(
object=data_list,
path=data_name,
datadir=self.datadir,
original_datadir=self.original_datadir,
request=self.request,
with_test_class_names=self.with_test_class_names,
)
return

# if it needs to be checked, we need to round the float values to the same precision as the
# reference file
data = round_data(data_list.getInfo(), prescision)
try:
super().check(data, fullpath=data_name)

# IF we are here it means the data has been modified so we edit the API call accordingly
# to make sure next run will not be forced to call the API for a response.
with suppress(BaseException):
check_serialized(
object=data_list,
path=data_name,
datadir=self.datadir,
original_datadir=self.original_datadir,
request=self.request,
with_test_class_names=self.with_test_class_names,
force_regen=True,
)

except BaseException as e:
raise e
Loading

0 comments on commit 07a65a5

Please sign in to comment.