diff --git a/pypgstac/pypgstac/load.py b/pypgstac/pypgstac/load.py index 58830bdc..f7123044 100644 --- a/pypgstac/pypgstac/load.py +++ b/pypgstac/pypgstac/load.py @@ -24,6 +24,7 @@ import psycopg from orjson import JSONDecodeError from plpygis.geometry import Geometry +from pkg_resources import parse_version as V from psycopg import sql from psycopg.types.range import Range from smart_open import open @@ -161,10 +162,19 @@ def __init__(self, db: PgstacDB): self._partition_cache: Dict[str, Partition] = {} def check_version(self) -> None: - if self.db.version != __version__: + db_version = self.db.version + if db_version is None: + raise Exception("Failed to detect the target database version.") + + v1 = V(db_version) + v2 = V(__version__) + if (v1.major, v1.minor) != ( + v2.major, + v2.minor, + ): raise Exception( f"pypgstac version {__version__} is not compatible with the target" - f" database version {self.db.version}." + f" database version {db_version}." ) @lru_cache(maxsize=128) diff --git a/pypgstac/setup.py b/pypgstac/setup.py index 9bbf663f..3e701dc6 100644 --- a/pypgstac/setup.py +++ b/pypgstac/setup.py @@ -6,7 +6,7 @@ desc = f.read() install_requires = [ - "smart-open[html]>=4.2,<7.0", + "smart-open>=4.2,<7.0", "orjson>=3.5.2", "python-dateutil==2.8.*", "fire==0.4.*", @@ -22,6 +22,7 @@ "black>=21.7b0", "mypy>=0.910", "types-orjson==0.1.1", + "types-pkg-resources", "pystac[validation]==1.*" ], "psycopg": [ diff --git a/pypgstac/tests/test_load.py b/pypgstac/tests/test_load.py index 204e97f1..bb486053 100644 --- a/pypgstac/tests/test_load.py +++ b/pypgstac/tests/test_load.py @@ -2,7 +2,8 @@ import json from pathlib import Path from unittest import mock -from pypgstac.load import Methods, Loader, read_json +from pkg_resources import parse_version as V +from pypgstac.load import Methods, Loader, read_json, __version__ from psycopg.errors import UniqueViolation import pytest import pystac @@ -28,6 +29,11 @@ ) +def version_increment(source_version: str) -> str: + version = V(source_version) + return ".".join(map(str, [version.major, version.minor, version.micro + 1])) + + def test_load_collections_succeeds(loader: Loader) -> None: """Test pypgstac collections loader.""" loader.load_collections( @@ -357,6 +363,10 @@ def test_load_collections_incompatible_version(loader: Loader) -> None: def test_load_items_incompatible_version(loader: Loader) -> None: """Test pypgstac items loader raises an exception for incompatible version.""" + loader.load_collections( + str(TEST_COLLECTIONS_JSON), + insert_mode=Methods.insert, + ) with mock.patch( "pypgstac.db.PgstacDB.version", new_callable=mock.PropertyMock ) as mock_version: @@ -366,3 +376,19 @@ def test_load_items_incompatible_version(loader: Loader) -> None: str(TEST_ITEMS), insert_mode=Methods.insert, ) + + +def test_load_compatible_major_minor_version(loader: Loader) -> None: + """Test pypgstac loader doesn't raise an exception.""" + with mock.patch( + "pypgstac.load.__version__", version_increment(__version__) + ) as mock_version: + loader.load_collections( + str(TEST_COLLECTIONS_JSON), + insert_mode=Methods.insert, + ) + loader.load_items( + str(TEST_ITEMS), + insert_mode=Methods.insert, + ) + assert mock_version != loader.db.version