Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep object cache in memory throughout tests #54

Merged
merged 1 commit into from
Jan 23, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 70 additions & 35 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,67 @@
import json
import warnings
from collections import UserDict
from pathlib import Path

import ee
import pytest

import eerepr

CACHE_DIR = Path(__file__).parent / "data/"
CACHE_PATH = CACHE_DIR / ".cache.json"

class ObjectCache(UserDict):
"""
A cache that reads and writes to a local JSON file.
"""

CACHE_DIR = Path(__file__).parent / "data/"
CACHE_PATH = CACHE_DIR / ".cache.json"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.modified = False

def __setitem__(self, key, value):
super().__setitem__(key, value)
self.modified = True

def __delitem__(self, key):
super().__delitem__(key)
self.modified = True

@classmethod
def load_from_disk(cls):
if not cls.CACHE_PATH.exists():
return ObjectCache()

with open(cls.CACHE_PATH) as src:
try:
return ObjectCache(json.load(src))
except json.JSONDecodeError:
warnings.warn("Failed to load cache file.", stacklevel=2)
return ObjectCache()

def write_to_disk(self):
if not self.CACHE_PATH.exists():
self.CACHE_DIR.mkdir(parents=True, exist_ok=True)
warnings.warn(
f"Creating new cache file at {self.CACHE_PATH}.", stacklevel=2
)

with open(self.CACHE_PATH, "w") as dst:
warnings.warn("Saving modified cache to disk.", stacklevel=2)
json.dump(self.data, dst)


@pytest.fixture(scope="session")
def object_cache():
"""Fixture to access the in-memory object cache for the session."""
session_cache = ObjectCache.load_from_disk()

yield session_cache

if session_cache.modified:
session_cache.write_to_disk()


def pytest_sessionstart(session):
Expand All @@ -29,7 +82,7 @@ def original_datadir():


@pytest.fixture(autouse=True)
def mock_getinfo(mocker):
def mock_getinfo(mocker, object_cache):
"""
Mock `getInfo` to load Earth Engine object info from a local cache, if available.

Expand All @@ -44,43 +97,25 @@ def mock_getinfo(mocker):
def get_cached_info(obj: ee.ComputedObject) -> dict:
serialized = obj.serialize()

CACHE_DIR.mkdir(parents=True, exist_ok=True)

try:
with open(CACHE_PATH) as src:
existing_data = json.load(src)

# File is missing or unreadable
except (FileNotFoundError, json.JSONDecodeError):
existing_data = {}
with open(CACHE_PATH, "w") as dst:
warnings.warn(f"Creating new cache file at {CACHE_PATH}.", stacklevel=2)
json.dump(existing_data, dst)

# File exists, but info does not
if serialized not in existing_data:
with open(CACHE_PATH, "w") as dst:
warnings.warn(
"Fetching and caching new object from server.", stacklevel=2
)
try:
info = get_server_info(obj)
except Exception as e:
# Errors should be stored to allow raising from the cache without
# needing to re-fetch the object
info = {"error": str(e)}

existing_data[serialized] = info
json.dump(existing_data, dst, indent=2)
if serialized not in object_cache:
warnings.warn("Fetching and caching new object from server.", stacklevel=2)
try:
info = get_server_info(obj)
except Exception as e:
# Errors should be stored to allow raising from the cache without
# needing to re-fetch the object
info = {"error": str(e)}

object_cache[serialized] = info

# Raise on error (including cached) to simulate getInfo behavior
if (
isinstance(existing_data[serialized], dict)
and "error" in existing_data[serialized]
isinstance(object_cache[serialized], dict)
and "error" in object_cache[serialized]
):
raise ee.EEException(existing_data[serialized]["error"])
raise ee.EEException(object_cache[serialized]["error"])

return existing_data[serialized]
return object_cache[serialized]

mocker.patch(
"ee.ComputedObject.getInfo", side_effect=get_cached_info, autospec=True
Expand Down
Loading