diff --git a/sno/base_dataset.py b/sno/base_dataset.py index 56325600f..0a25f0592 100644 --- a/sno/base_dataset.py +++ b/sno/base_dataset.py @@ -9,6 +9,10 @@ class BaseDataset(ImportSource): """ Common interface for all datasets - mainly Dataset2, but there is also Dataset0 and Dataset1 used by `sno upgrade`. + + A Dataset instance is immutable since it is a view of a particular git tree. + To get a new version of a dataset, commit the desired changes, + then instantiate a new Dataset instance that references the new git tree. """ # Constants that subclasses should generally define. diff --git a/sno/checkout.py b/sno/checkout.py index 2031fc6c2..587a49011 100644 --- a/sno/checkout.py +++ b/sno/checkout.py @@ -20,7 +20,7 @@ def reset_wc_if_needed(repo, target_tree_or_commit, *, discard_changes=False): """Resets the working copy to the target if it does not already match, or if discard_changes is True.""" - working_copy = WorkingCopy.get(repo, allow_uncreated=True) + working_copy = WorkingCopy.get(repo, allow_uncreated=True, allow_invalid_state=True) if working_copy is None: click.echo( "(Bare sno repository - to create a working copy, use `sno create-workingcopy`)" @@ -378,7 +378,7 @@ def create_workingcopy(ctx, discard_changes, wc_path): wc_path = WorkingCopy.default_path(repo.workdir_path) if wc_path != old_wc_path: - WorkingCopy.check_valid_creation_path(repo.workdir_path, wc_path) + WorkingCopy.check_valid_creation_path(wc_path, repo.workdir_path) # Finished sanity checks - start work: if old_wc and wc_path != old_wc_path: diff --git a/sno/clone.py b/sno/clone.py index c598746ad..63574566a 100644 --- a/sno/clone.py +++ b/sno/clone.py @@ -106,7 +106,7 @@ def clone( if repo_path.exists() and any(repo_path.iterdir()): raise InvalidOperation(f'"{repo_path}" isn\'t empty', param_hint="directory") - WorkingCopy.check_valid_creation_path(repo_path, wc_path) + WorkingCopy.check_valid_creation_path(wc_path, repo_path) if not repo_path.exists(): repo_path.mkdir(parents=True) diff --git a/sno/geometry.py b/sno/geometry.py index 6caf3d860..8873c8ddd 100644 --- a/sno/geometry.py +++ b/sno/geometry.py @@ -57,6 +57,16 @@ def with_crs_id(self, crs_id): crs_id_bytes = struct.pack(">> URI_SCHEME::[HOST]/DBNAME/DBSCHEMA + """ + url = urlsplit(db_uri) + + if url.scheme != cls.URI_SCHEME: + raise click.UsageError( + f"Invalid {cls.WORKING_COPY_TYPE_NAME} URI - " + f"Expecting URI in form: {cls.URI_SCHEME}://[HOST]/DBNAME/DBSCHEMA" + ) + + url_path = url.path + path_parts = url_path[1:].split("/", 3) if url_path else [] + + suggestion_message = "" + if len(path_parts) == 1 and workdir_path is not None: + suggested_path = f"/{path_parts[0]}/{cls.default_db_schema(workdir_path)}" + suggested_uri = urlunsplit( + [url.scheme, url.netloc, suggested_path, url.query, ""] + ) + suggestion_message = f"\nFor example: {suggested_uri}" + + if len(path_parts) != 2: + raise click.UsageError( + f"Invalid {cls.WORKING_COPY_TYPE_NAME} URI - URI requires both database name and database schema:\n" + f"Expecting URI in form: {cls.URI_SCHEME}://[HOST]/DBNAME/DBSCHEMA" + + suggestion_message + ) + + @classmethod + def _separate_db_schema(cls, db_uri): + """ + Removes the DBSCHEMA part off the end of a uri in the form URI_SCHEME::[HOST]/DBNAME/DBSCHEMA - + and returns the URI and the DBSCHEMA separately. + Useful since generally, URI_SCHEME::[HOST]/DBNAME is what is needed to connect to the database, + and then DBSCHEMA must be specified in each query. 
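+ For example (illustrative values): given 'postgresql://localhost:5432/mydb/myschema', this returns ('postgresql://localhost:5432/mydb', 'myschema').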
+ """ + url = urlsplit(db_uri) + url_path = url.path + path_parts = url_path[1:].split("/", 3) if url_path else [] + assert len(path_parts) == 2 + url_path = "/" + path_parts[0] + db_schema = path_parts[1] + return urlunsplit([url.scheme, url.netloc, url_path, url.query, ""]), db_schema + + @classmethod + def default_db_schema(cls, workdir_path): + """Returns a suitable default database schema - named after the folder this Sno repo is in.""" + stem = workdir_path.stem + schema = re.sub("[^a-z0-9]+", "_", stem.lower()) + "_sno" + if schema[0].isdigit(): + schema = "_" + schema + return schema + + @property + @functools.lru_cache(maxsize=1) + def DB_SCHEMA(self): + """Escaped, dialect-specific name of the database-schema owned by this working copy (if any).""" + if self.db_schema is None: + raise RuntimeError("No schema to escape.") + return self.preparer.format_schema(self.db_schema) diff --git a/sno/working_copy/gpkg.py b/sno/working_copy/gpkg.py index 3166f68a9..2a0846cd8 100644 --- a/sno/working_copy/gpkg.py +++ b/sno/working_copy/gpkg.py @@ -7,9 +7,10 @@ import click from osgeo import gdal -import sqlalchemy +import sqlalchemy as sa from sqlalchemy.orm import sessionmaker from sqlalchemy.sql.compiler import IdentifierPreparer +from sqlalchemy.types import UserDefinedType from . import gpkg_adapter @@ -25,6 +26,15 @@ class WorkingCopy_GPKG(WorkingCopy): + """ + GPKG working copy implementation. + + Requirements: + 1. Can read and write to the filesystem at the specified path. + """ + + WORKING_COPY_TYPE_NAME = "GPKG" + def __init__(self, repo, path): self.repo = repo self.path = path @@ -36,34 +46,34 @@ def __init__(self, repo, path): self.sno_tables = GpkgSnoTables @classmethod - def check_valid_creation_path(cls, workdir_path, path): - cls.check_valid_path(workdir_path, path) + def check_valid_creation_path(cls, wc_path, workdir_path=None): + cls.check_valid_path(wc_path, workdir_path) - gpkg_path = (workdir_path / path).resolve() + gpkg_path = (workdir_path / wc_path).resolve() if gpkg_path.exists(): desc = "path" if gpkg_path.is_dir() else "GPKG file" raise InvalidOperation( - f"Error creating GPKG working copy at {path} - {desc} already exists" + f"Error creating GPKG working copy at {wc_path} - {desc} already exists" ) @classmethod - def check_valid_path(cls, workdir_path, path): - if not str(path).endswith(".gpkg"): - suggested_path = f"{os.path.splitext(str(path))[0]}.gpkg" + def check_valid_path(cls, wc_path, workdir_path=None): + if not str(wc_path).endswith(".gpkg"): + suggested_path = f"{os.path.splitext(str(wc_path))[0]}.gpkg" raise click.UsageError( f"Invalid GPKG path - expected .gpkg suffix, eg {suggested_path}" ) @classmethod - def normalise_path(cls, repo, path): + def normalise_path(cls, repo, wc_path): """Rewrites a relative path (relative to the current directory) as relative to the repo.workdir_path.""" - path = Path(path) - if not path.is_absolute(): + wc_path = Path(wc_path) + if not wc_path.is_absolute(): try: - return str(path.resolve().relative_to(repo.workdir_path.resolve())) + return str(wc_path.resolve().relative_to(repo.workdir_path.resolve())) except ValueError: pass - return str(path) + return str(wc_path) @property def full_path(self): @@ -75,9 +85,28 @@ def _quoted_trigger_name(self, dataset, trigger_type): # but changing it means migrating working copies, unfortunately. 
return self.quote(f"gpkg_sno_{dataset.table_name}_{trigger_type}") - def insert_or_replace_into_dataset(self, dataset): + def _insert_or_replace_into_dataset(self, dataset): # SQLite optimisation. - return self.table_def_for_dataset(dataset).insert().prefix_with("OR REPLACE") + return self._table_def_for_dataset(dataset).insert().prefix_with("OR REPLACE") + + def _table_def_for_column_schema(self, col, dataset): + if col.data_type == "geometry": + # This user-defined GeometryType normalises GPKG geometry to the Sno V2 GPKG geometry. + return sa.column(col.name, GeometryType) + else: + # Don't need to specify type information for other columns at present, since we just pass through the values. + return sa.column(col.name) + + def _insert_or_replace_state_table_tree(self, sess, tree_id): + r = sess.execute( + self.sno_tables.sno_state.insert().prefix_with("OR REPLACE"), + { + "table_name": "*", + "key": "tree", + "value": tree_id, + }, + ) + return r.rowcount @contextlib.contextmanager def session(self, bulk=0): @@ -99,37 +128,36 @@ def session(self, bulk=0): # - do something consistent and safe from then on. if hasattr(self, "_session"): - # inner - reuse + # Inner call - reuse existing session. L.debug(f"session(bulk={bulk}): existing...") yield self._session L.debug(f"session(bulk={bulk}): existing/done") + return - else: - L.debug(f"session(bulk={bulk}): new...") + # Outer call - create new session: + L.debug(f"session(bulk={bulk}): new...") + self._session = self.sessionmaker() - try: - self._session = self.sessionmaker() - - if bulk: - self._session.execute("PRAGMA synchronous = OFF;") - self._session.execute( - "PRAGMA cache_size = -1048576;" - ) # -KiB => 1GiB - if bulk >= 2: - self._session.execute("PRAGMA journal_mode = MEMORY;") - self._session.execute("PRAGMA locking_mode = EXCLUSIVE;") - - # TODO - use tidier syntax for opening transactions from sqlalchemy. - self._session.execute("BEGIN TRANSACTION;") - yield self._session - self._session.commit() - except Exception: - self._session.rollback() - raise - finally: - self._session.close() - del self._session - L.debug(f"session(bulk={bulk}): new/done") + try: + if bulk: + self._session.execute("PRAGMA synchronous = OFF;") + self._session.execute("PRAGMA cache_size = -1048576;") # -KiB => 1GiB + if bulk >= 2: + self._session.execute("PRAGMA journal_mode = MEMORY;") + self._session.execute("PRAGMA locking_mode = EXCLUSIVE;") + + # TODO - use tidier syntax for opening transactions from sqlalchemy. + self._session.execute("BEGIN TRANSACTION;") + yield self._session + self._session.commit() + + except Exception: + self._session.rollback() + raise + finally: + self._session.close() + del self._session + L.debug(f"session(bulk={bulk}): new/done") def delete(self, keep_db_schema_if_possible=False): """Delete the working copy files.""" @@ -408,12 +436,15 @@ def _delete_meta_metadata(self, sess, table_name): """DELETE FROM gpkg_metadata WHERE id IN :ids;""", ) for sql in sqls: - stmt = sqlalchemy.text(sql).bindparams( - sqlalchemy.bindparam("ids", expanding=True) - ) + stmt = sa.text(sql).bindparams(sa.bindparam("ids", expanding=True)) sess.execute(stmt, {"ids": ids}) - def _create_spatial_index(self, sess, dataset): + def _create_spatial_index_pre(self, sess, dataset): + # Implementing only _create_spatial_index_pre: + # gpkgAddSpatialIndex has to be called before writing any features, + # since it only adds on-write triggers to update the index - it doesn't + # add any pre-existing features to the index. 
+ L = logging.getLogger(f"{self.__class__.__qualname__}._create_spatial_index") geom_col = dataset.geom_column_name @@ -460,56 +491,42 @@ def _create_triggers(self, sess, dataset): table_identifier = self.table_identifier(dataset) pk_column = self.quote(dataset.primary_key) - # SQLite doesn't let you do param substitutions in CREATE TRIGGER: - escaped_table_name = dataset.table_name.replace("'", "''") - - sess.execute( + insert_trigger = sa.text( f""" CREATE TRIGGER {self._quoted_trigger_name(dataset, 'ins')} AFTER INSERT ON {table_identifier} BEGIN - INSERT OR REPLACE INTO {self.SNO_TRACK} - (table_name, pk) - VALUES ('{escaped_table_name}', NEW.{pk_column}); + INSERT OR REPLACE INTO {self.SNO_TRACK} (table_name, pk) + VALUES (:table_name, NEW.{pk_column}); END; """ ) - sess.execute( + update_trigger = sa.text( f""" CREATE TRIGGER {self._quoted_trigger_name(dataset, 'upd')} AFTER UPDATE ON {table_identifier} BEGIN - INSERT OR REPLACE INTO {self.SNO_TRACK} - (table_name, pk) - VALUES - ('{escaped_table_name}', NEW.{pk_column}), - ('{escaped_table_name}', OLD.{pk_column}); + INSERT OR REPLACE INTO {self.SNO_TRACK} (table_name, pk) + VALUES (:table_name, NEW.{pk_column}), (:table_name, OLD.{pk_column}); END; """ ) - sess.execute( + delete_trigger = sa.text( f""" CREATE TRIGGER {self._quoted_trigger_name(dataset, 'del')} AFTER DELETE ON {table_identifier} BEGIN - INSERT OR REPLACE INTO {self.SNO_TRACK} - (table_name, pk) - VALUES - ('{escaped_table_name}', OLD.{pk_column}); + INSERT OR REPLACE INTO {self.SNO_TRACK} (table_name, pk) + VALUES (:table_name, OLD.{pk_column}); END; """ ) - - def _db_geom_to_gpkg_geom(self, g): - # Its possible in GPKG to put arbitrary values in columns, regardless of type. - # We don't try to convert them here - we let the commit validation step report this as an error. - if not isinstance(g, bytes): - return g - # We normalise geometries to avoid spurious diffs - diffs where nothing - # of any consequence has changed (eg, only endianness has changed). - # This includes setting the SRID to zero for each geometry so that we don't store a separate SRID per geometry, - # but only one per column at most. - return normalise_gpkg_geom(g) + for trigger in (insert_trigger, update_trigger, delete_trigger): + # Placeholders not allowed in CREATE TRIGGER - have to use literal_binds. + # See https://docs.sqlalchemy.org/en/13/faq/sqlexpressions.html#faq-sql-expression-string + trigger.bindparams(table_name=dataset.table_name).compile( + sess.connection(), compile_kwargs={"literal_binds": True} + ).execute() def _is_meta_update_supported(self, dataset_version, meta_diff): """ @@ -603,6 +620,21 @@ def _apply_meta_metadata_dataset_json(self, sess, dataset, src_value, dest_value def _update_last_write_time(self, sess, dataset, commit=None): self._update_gpkg_contents(sess, dataset, commit) + def _get_geom_extent(self, sess, dataset, default=None): + """Returns the envelope around the entire dataset as (min_x, min_y, max_x, max_y).""" + # FIXME: Why doesn't Extent(geom) work here as an aggregate? 
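+ # (The SQL Server working copy implements the same _get_geom_extent interface using geometry::EnvelopeAggregate - see sqlserver.py.)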
+ geom_col = dataset.geom_column_name + r = sess.execute( + f""" + WITH _E AS ( + SELECT Extent({self.quote(geom_col)}) AS extent FROM {self.table_identifier(dataset)} + ) + SELECT ST_MinX(extent), ST_MinY(extent), ST_MaxX(extent), ST_MaxY(extent) FROM _E; + """ + ) + result = r.fetchone() + return default if result == (None, None, None, None) else result + def _update_gpkg_contents(self, sess, dataset, commit=None): """ Update the metadata for the given table in gpkg_contents to have the new bounding-box / last-updated timestamp. @@ -614,17 +646,11 @@ def _update_gpkg_contents(self, sess, dataset, commit=None): # GPKG Spec Req. 15: gpkg_change_time = change_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ") - table_identifer = self.table_identifier(dataset) geom_col = dataset.geom_column_name if geom_col is not None: - # FIXME: Why doesn't Extent(geom) work here as an aggregate? - r = sess.execute( - f""" - WITH _E AS (SELECT extent({self.quote(geom_col)}) AS extent FROM {table_identifer}) - SELECT ST_MinX(extent), ST_MinY(extent), ST_MaxX(extent), ST_MaxY(extent) FROM _E - """ + min_x, min_y, max_x, max_y = self._get_geom_extent( + sess, dataset, default=(None, None, None, None) ) - min_x, min_y, max_x, max_y = r.fetchone() rc = sess.execute( """ UPDATE gpkg_contents @@ -646,3 +672,21 @@ def _update_gpkg_contents(self, sess, dataset, commit=None): {"last_change": gpkg_change_time, "table_name": dataset.table_name}, ).rowcount assert rc == 1, f"gpkg_contents update: expected 1Δ, got {rc}" + + +class GeometryType(UserDefinedType): + """UserDefinedType so that GPKG geometry is normalised to V2 format.""" + + def result_processor(self, dialect, coltype): + def process(gpkg_bytes): + # Its possible in GPKG to put arbitrary values in columns, regardless of type. + # We don't try to convert them here - we let the commit validation step report this as an error. + if not isinstance(gpkg_bytes, bytes): + return gpkg_bytes + # We normalise geometries to avoid spurious diffs - diffs where nothing + # of any consequence has changed (eg, only endianness has changed). + # This includes setting the SRID to zero for each geometry so that we don't store a separate SRID per geometry, + # but only one per column at most. + return normalise_gpkg_geom(gpkg_bytes) + + return process diff --git a/sno/working_copy/gpkg_adapter.py b/sno/working_copy/gpkg_adapter.py index 08efdbb1a..f527ec5ca 100644 --- a/sno/working_copy/gpkg_adapter.py +++ b/sno/working_copy/gpkg_adapter.py @@ -377,7 +377,7 @@ def _column_schema_to_gpkg(cid, column_schema, has_geometry): } -# Types that can't be roundtrip perfectly in GPKG, and what they end up as. +# Types that can't be roundtripped perfectly in GPKG, and what they end up as. APPROXIMATED_TYPES = {"interval": "text", "time": "text", "numeric": "text"} diff --git a/sno/working_copy/postgis.py b/sno/working_copy/postgis.py index 3723c4d31..975045fdf 100644 --- a/sno/working_copy/postgis.py +++ b/sno/working_copy/postgis.py @@ -1,36 +1,39 @@ import contextlib import logging -import re import time -from urllib.parse import urlsplit, urlunsplit +from urllib.parse import urlsplit -import click + +from sqlalchemy import Index +from sqlalchemy.dialects.postgresql import insert as postgresql_insert from sqlalchemy.sql.compiler import IdentifierPreparer from sqlalchemy.orm import sessionmaker -from .base import WorkingCopy from . 
import postgis_adapter +from .db_server import DatabaseServer_WorkingCopy from .table_defs import PostgisSnoTables from sno import crs_util -from sno.exceptions import InvalidOperation from sno.schema import Schema from sno.sqlalchemy import postgis_engine -""" -* database needs to exist -* database needs to have postgis enabled -* database user needs to be able to: - 1. create 'sno' schema & tables - 2. create & alter tables in the default (or specified) schema - 3. create triggers -""" +class WorkingCopy_Postgis(DatabaseServer_WorkingCopy): + """ + PostGIS working copy implementation. -L = logging.getLogger("sno.working_copy.postgis") + Requirements: + 1. The database needs to exist + 2. If the dataset has geometry, then PostGIS (https://postgis.net/) v2.4 or newer needs + to be installed into the database and available in the database user's search path + 3. The database user needs to be able to: + - Create the specified schema (unless it already exists). + - Create, delete and alter tables and triggers in the specified schema. + """ + WORKING_COPY_TYPE_NAME = "PostGIS" + URI_SCHEME = "postgresql" -class WorkingCopy_Postgis(WorkingCopy): def __init__(self, repo, uri): """ uri: connection string of the form postgresql://[user[:password]@][netloc][:port][/dbname/schema][?param1=value1&...] @@ -41,81 +44,18 @@ def __init__(self, repo, uri): self.uri = uri self.path = uri - url = urlsplit(uri) - - if url.scheme != "postgresql": - raise ValueError("Expecting postgresql://") - - url_path = url.path - path_parts = url_path[1:].split("/", 3) if url_path else [] - if len(path_parts) != 2: - raise ValueError("Expecting postgresql://[HOST]/DBNAME/SCHEMA") - url_path = f"/{path_parts[0]}" - self.db_schema = path_parts[1] - - url_query = url.query - if "fallback_application_name" not in url_query: - url_query = "&".join( - filter(None, [url_query, "fallback_application_name=sno"]) - ) + self.check_valid_db_uri(uri) + self.db_uri, self.db_schema = self._separate_db_schema(uri) - # Rebuild DB URL suitable for postgres - self.dburl = urlunsplit([url.scheme, url.netloc, url_path, url_query, ""]) - self.engine = postgis_engine(self.dburl) + self.engine = postgis_engine(self.db_uri) self.sessionmaker = sessionmaker(bind=self.engine) self.preparer = IdentifierPreparer(self.engine.dialect) self.sno_tables = PostgisSnoTables(self.db_schema) @classmethod - def check_valid_creation_path(cls, workdir_path, path): - cls.check_valid_path(workdir_path, path) - postgis_wc = cls(None, path) - - # Less strict on Postgis - we are okay with the schema being already created, so long as its empty.
- if postgis_wc.has_data(): - raise InvalidOperation( - f"Error creating Postgis working copy at {path} - non-empty schema already exists" - ) - - @classmethod - def check_valid_path(cls, workdir_path, path): - url = urlsplit(path) - - if url.scheme != "postgresql": - raise click.UsageError( - "Invalid postgres URI - Expecting URI in form: postgresql://[HOST]/DBNAME/SCHEMA" - ) - - url_path = url.path - path_parts = url_path[1:].split("/", 3) if url_path else [] - - suggestion_message = "" - if len(path_parts) == 1 and workdir_path is not None: - suggested_path = f"/{path_parts[0]}/{cls.default_schema(workdir_path)}" - suggested_uri = urlunsplit( - [url.scheme, url.netloc, suggested_path, url.query, ""] - ) - suggestion_message = f"\nFor example: {suggested_uri}" - - if len(path_parts) != 2: - raise click.UsageError( - "Invalid postgres URI - postgis working copy requires both dbname and schema:\n" - "Expecting URI in form: postgresql://[HOST]/DBNAME/SCHEMA" - + suggestion_message - ) - - @classmethod - def normalise_path(cls, repo, path): - return path - - @classmethod - def default_schema(cls, workdir_path): - stem = workdir_path.stem - schema = re.sub("[^a-z0-9]+", "_", stem.lower()) + "_sno" - if schema[0].isdigit(): - schema = "_" + schema - return schema + def check_valid_path(cls, wc_path, workdir_path=None): + cls.check_valid_db_uri(wc_path, workdir_path) def __str__(self): p = urlsplit(self.uri) @@ -129,44 +69,11 @@ def __str__(self): p._replace(netloc=nl) return p.geturl() - @contextlib.contextmanager - def session(self, bulk=0): - """ - Context manager for GeoPackage DB sessions, yields a connection object inside a transaction - - Calling again yields the _same_ connection, the transaction/etc only happen in the outer one. - """ - L = logging.getLogger(f"{self.__class__.__qualname__}.session") - - if hasattr(self, "_session"): - # inner - reuse - L.debug("session: existing...") - yield self._session - L.debug("session: existing/done") - - else: - L.debug("session: new...") - - try: - self._session = self.sessionmaker() - - # TODO - use tidier syntax for opening transactions from sqlalchemy. - self._session.execute("BEGIN TRANSACTION;") - yield self._session - self._session.commit() - except Exception: - self._session.rollback() - raise - finally: - self._session.close() - del self._session - L.debug("session: new/done") - def is_created(self): """ - Returns true if the postgres schema referred to by this working copy exists and + Returns true if the DB schema referred to by this working copy exists and contains at least one table. If it exists but is empty, it is treated as uncreated. - This is so the postgres schema can be created ahead of time before a repo is created + This is so the DB schema can be created ahead of time before a repo is created or configured, without it triggering code that checks for corrupted working copies. Note that it might not be initialised as a working copy - see self.is_initialised. """ @@ -182,7 +89,7 @@ def is_created(self): def is_initialised(self): """ - Returns true if the postgis working copy is initialised - + Returns true if the PostGIS working copy is initialised - the schema exists and has the necessary sno tables, _sno_state and _sno_track. """ with self.session() as sess: @@ -197,7 +104,7 @@ def is_initialised(self): def has_data(self): """ - Returns true if the postgis working copy seems to have user-created content already. + Returns true if the PostGIS working copy seems to have user-created content already. 
""" with self.session() as sess: count = sess.scalar( @@ -287,6 +194,13 @@ def _create_table_for_dataset(self, sess, dataset): f"""CREATE TABLE IF NOT EXISTS {self.table_identifier(dataset)} ({table_spec});""" ) + def _insert_or_replace_into_dataset(self, dataset): + # See https://docs.sqlalchemy.org/en/14/dialects/postgresql.html#insert-on-conflict-upsert + pk_col_names = [c.name for c in dataset.schema.pk_columns] + stmt = postgresql_insert(self._table_def_for_dataset(dataset)) + update_dict = {c.name: c for c in stmt.excluded if c.name not in pk_col_names} + return stmt.on_conflict_do_update(index_elements=pk_col_names, set_=update_dict) + def _write_meta(self, sess, dataset): """Write the title (as a comment) and the CRS. Other metadata is not stored in a PostGIS WC.""" self._write_meta_title(sess, dataset) @@ -325,20 +239,25 @@ def delete_meta(self, dataset): """Delete any metadata that is only needed by this dataset.""" pass # There is no metadata except for the spatial_ref_sys table. - def _create_spatial_index(self, sess, dataset): + def _create_spatial_index_post(self, sess, dataset): + # Only implemented as _create_spatial_index_post: + # It is more efficient to write the features first, then index them all in bulk. L = logging.getLogger(f"{self.__class__.__qualname__}._create_spatial_index") geom_col = dataset.geom_column_name + index_name = f"{dataset.table_name}_idx_{geom_col}" + table = self._table_def_for_dataset(dataset) - # Create the PostGIS Spatial Index L.debug("Creating spatial index for %s.%s", dataset.table_name, geom_col) t0 = time.monotonic() - sess.execute( - f""" - CREATE INDEX "{dataset.table_name}_idx_{geom_col}" - ON {self.table_identifier(dataset)} USING GIST ({self.quote(geom_col)}); - """ + + spatial_index = Index( + index_name, table.columns[geom_col], postgres_using="GIST" ) + spatial_index.table = table + spatial_index.create(sess.connection()) + sess.execute(f"""ANALYZE {self.table_identifier(dataset)};""") + L.info("Created spatial index in %ss", time.monotonic() - t0) def _drop_spatial_index(self, sess, dataset): @@ -460,10 +379,6 @@ def _remove_hidden_meta_diffs(self, dataset, ds_meta_items, wc_meta_items): del wc_meta_items[key] # If either definition is custom, we keep the diff, since it could be important. - def _db_geom_to_gpkg_geom(self, g): - # This is already handled by register_type - return g - def _is_meta_update_supported(self, dataset_version, meta_diff): """ Returns True if the given meta-diff is supported *without* dropping and rewriting the table. 
@@ -492,7 +407,7 @@ def _is_meta_update_supported(self, dataset_version, meta_diff): return sum(dt.values()) == 0 def _apply_meta_title(self, sess, dataset, src_value, dest_value): - db.execute( + sess.execute( f"COMMENT ON TABLE {self._table_identifier(dataset.table_name)} IS :comment", {"comment": dest_value}, ) @@ -503,6 +418,9 @@ def _apply_meta_description(self, sess, dataset, src_value, dest_value): def _apply_meta_metadata_dataset_json(self, sess, dataset, src_value, dest_value): pass # This is a no-op for postgis + def _apply_meta_metadata_xml(self, sess, dataset, src_value, dest_value): + pass # This is a no-op for postgis + def _apply_meta_schema_json(self, sess, dataset, src_value, dest_value): src_schema = Schema.from_column_dicts(src_value) dest_schema = Schema.from_column_dicts(dest_value) diff --git a/sno/working_copy/postgis_adapter.py b/sno/working_copy/postgis_adapter.py index dc56c6fac..5a7c13470 100644 --- a/sno/working_copy/postgis_adapter.py +++ b/sno/working_copy/postgis_adapter.py @@ -58,7 +58,7 @@ def quote(ident): "varchar": "text", } -# Types that can't be roundtrip perfectly in Postgis, and what they end up as. +# Types that can't be roundtripped perfectly in PostGIS, and what they end up as. APPROXIMATED_TYPES = {("integer", 8): ("integer", 16)} diff --git a/sno/working_copy/sqlserver.py b/sno/working_copy/sqlserver.py new file mode 100644 index 000000000..6156a4c5b --- /dev/null +++ b/sno/working_copy/sqlserver.py @@ -0,0 +1,490 @@ +import contextlib +import logging +import time +from urllib.parse import urlsplit + +import sqlalchemy as sa +from sqlalchemy import literal_column +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.orm import sessionmaker +from sqlalchemy.sql import crud, quoted_name +from sqlalchemy.sql.dml import ValuesBase +from sqlalchemy.sql.functions import Function +from sqlalchemy.sql.compiler import IdentifierPreparer +from sqlalchemy.types import UserDefinedType + +from . import sqlserver_adapter +from .db_server import DatabaseServer_WorkingCopy +from .table_defs import SqlServerSnoTables +from sno import crs_util +from sno.geometry import Geometry +from sno.sqlalchemy import sqlserver_engine + + +class WorkingCopy_SqlServer(DatabaseServer_WorkingCopy): + """ + SQL Server working copy implementation. + + Requirements: + 1. The database needs to exist + 2. The database user needs to be able to: + - Create the specified schema (unless it already exists). + - Create, delete and alter tables and triggers in the specified schema. + """ + + WORKING_COPY_TYPE_NAME = "SQL Server" + URI_SCHEME = "mssql" + + def __init__(self, repo, uri): + """ + uri: connection string of the form mssql://[user[:password]@][netloc][:port][/dbname/schema][?param1=value1&...] 
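+ e.g. (illustrative) mssql://sa:PassWord1@localhost:11433/my_database/my_schema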
+ """ + self.L = logging.getLogger(self.__class__.__qualname__) + + self.repo = repo + self.uri = uri + self.path = uri + + self.check_valid_db_uri(uri) + self.db_uri, self.db_schema = self._separate_db_schema(uri) + + self.engine = sqlserver_engine(self.db_uri) + self.sessionmaker = sessionmaker(bind=self.engine) + self.preparer = IdentifierPreparer(self.engine.dialect) + + self.sno_tables = SqlServerSnoTables(self.db_schema) + + def __str__(self): + p = urlsplit(self.uri) + if p.password is not None: + nl = p.hostname + if p.username is not None: + nl = f"{p.username}@{nl}" + if p.port is not None: + nl += f":{p.port}" + + p._replace(netloc=nl) + return p.geturl() + + def is_created(self): + """ + Returns true if the db schema referred to by this working copy exists and + contains at least one table. If it exists but is empty, it is treated as uncreated. + This is so the schema can be created ahead of time before a repo is created + or configured, without it triggering code that checks for corrupted working copies. + Note that it might not be initialised as a working copy - see self.is_initialised. + """ + with self.session() as sess: + count = sess.scalar( + """SELECT COUNT(*) FROM sys.schemas WHERE name=:schema_name;""", + {"schema_name": self.db_schema}, + ) + return count > 0 + + def is_initialised(self): + """ + Returns true if the SQL server working copy is initialised - + the schema exists and has the necessary sno tables, _sno_state and _sno_track. + """ + with self.session() as sess: + count = sess.scalar( + f""" + SELECT COUNT(*) FROM sys.tables + WHERE schema_id = SCHEMA_ID(:schema_name) + AND name IN ('{self.SNO_STATE_NAME}', '{self.SNO_TRACK_NAME}'); + """, + {"schema_name": self.db_schema}, + ) + return count == 2 + + def has_data(self): + """ + Returns true if the SQL server working copy seems to have user-created content already. + """ + with self.session() as sess: + count = sess.scalar( + f""" + SELECT COUNT(*) FROM sys.tables + WHERE schema_id = SCHEMA_ID(:schema_name) + AND name NOT IN ('{self.SNO_STATE_NAME}', '{self.SNO_TRACK_NAME}'); + """, + {"schema_name": self.db_schema}, + ) + return count > 0 + + def create_and_initialise(self): + with self.session() as sess: + if not self.is_created(): + sess.execute(f"CREATE SCHEMA {self.DB_SCHEMA};") + + with self.session() as sess: + self.sno_tables.create_all(sess) + + def delete(self, keep_db_schema_if_possible=False): + """Delete all tables in the schema.""" + with self.session() as sess: + # Drop tables + r = sess.execute( + "SELECT name FROM sys.tables WHERE schema_id=SCHEMA_ID(:schema);", + {"schema": self.db_schema}, + ) + table_identifiers = ", ".join((self.table_identifier(row[0]) for row in r)) + if table_identifiers: + sess.execute(f"DROP TABLE IF EXISTS {table_identifiers};") + + # Drop schema, unless keep_db_schema_if_possible=True + if not keep_db_schema_if_possible: + sess.execute( + f"DROP SCHEMA IF EXISTS {self.DB_SCHEMA};", + ) + + def _create_table_for_dataset(self, sess, dataset): + table_spec = sqlserver_adapter.v2_schema_to_sqlserver_spec( + dataset.schema, dataset + ) + sess.execute( + f"""CREATE TABLE {self.table_identifier(dataset)} ({table_spec});""" + ) + + def _table_def_for_column_schema(self, col, dataset): + if col.data_type == "geometry": + crs_name = col.extra_type_info.get("geometryCRS", None) + crs_id = crs_util.get_identifier_int_from_dataset(dataset, crs_name) or 0 + # This user-defined GeometryType adapts Sno's GPKG geometry to SQL Server's native geometry type. 
+ return sa.column(col.name, GeometryType(crs_id)) + elif col.data_type in ("date", "time", "timestamp"): + return sa.column(col.name, BaseDateOrTimeType) + else: + # Don't need to specify type information for other columns at present, since we just pass through the values. + return sa.column(col.name) + + def _insert_or_replace_into_dataset(self, dataset): + pk_col_names = [c.name for c in dataset.schema.pk_columns] + non_pk_col_names = [ + c.name for c in dataset.schema.columns if c.pk_index is None + ] + return sqlserver_upsert( + self._table_def_for_dataset(dataset), + index_elements=pk_col_names, + set_=non_pk_col_names, + ) + + def _insert_or_replace_state_table_tree(self, sess, tree_id): + r = sess.execute( + f""" + MERGE {self.SNO_STATE} STA + USING (VALUES ('*', 'tree', :value)) AS SRC("table_name", "key", "value") + ON SRC."table_name" = STA."table_name" AND SRC."key" = STA."key" + WHEN MATCHED THEN + UPDATE SET "value" = SRC."value" + WHEN NOT MATCHED THEN + INSERT ("table_name", "key", "value") VALUES (SRC."table_name", SRC."key", SRC."value"); + """, + {"value": tree_id}, + ) + return r.rowcount + + def _write_meta(self, sess, dataset): + """Write the title. Other metadata is not stored in a SQL Server WC.""" + self._write_meta_title(sess, dataset) + + def _write_meta_title(self, sess, dataset): + """Write the dataset title as a comment on the table.""" + # TODO - dataset title is not stored anywhere in SQL server working copy right now. + # We can probably store it using function sp_addextendedproperty to add property 'MS_Description' + pass + + def delete_meta(self, dataset): + """Delete any metadata that is only needed by this dataset.""" + # There is no metadata stored anywhere except the table itself. + pass + + def _get_geom_extent(self, sess, dataset, default=None): + """Returns the envelope around the entire dataset as (min_x, min_y, max_x, max_y).""" + geom_col = dataset.geom_column_name + r = sess.execute( + f""" + WITH _E AS ( + SELECT geometry::EnvelopeAggregate({self.quote(geom_col)}) AS envelope + FROM {self.table_identifier(dataset)} + ) + SELECT + envelope.STPointN(1).STX AS min_x, + envelope.STPointN(1).STY AS min_y, + envelope.STPointN(3).STX AS max_x, + envelope.STPointN(3).STY AS max_y + FROM _E; + """ + ) + result = r.fetchone() + return default if result == (None, None, None, None) else result + + def _grow_rectangle(self, rectangle, scale_factor): + # scale_factor = 1 -> no change, >1 -> grow, <1 -> shrink. + min_x, min_y, max_x, max_y = rectangle + centre_x, centre_y = (min_x + max_x) / 2, (min_y + max_y) / 2 + min_x = (min_x - centre_x) * scale_factor + centre_x + min_y = (min_y - centre_y) * scale_factor + centre_y + max_x = (max_x - centre_x) * scale_factor + centre_x + max_y = (max_y - centre_y) * scale_factor + centre_y + return min_x, min_y, max_x, max_y + + def _create_spatial_index_post(self, sess, dataset): + # Only implementing _create_spatial_index_post: + # We need to know the rough extent of the data to create an index in that area, + # so we create the spatial index once the bulk of the features have been written. + + L = logging.getLogger(f"{self.__class__.__qualname__}._create_spatial_index") + + extent = self._get_geom_extent(sess, dataset) + if not extent: + # Can't create a spatial index if we don't know the rough bounding box we need to index. + return + + # Add 20% room to grow. 
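+ # e.g. an extent of (0, 0, 10, 10) grows to (-1, -1, 11, 11) with GROW_FACTOR = 1.2.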
+ GROW_FACTOR = 1.2 + min_x, min_y, max_x, max_y = self._grow_rectangle(extent, GROW_FACTOR) + + geom_col = dataset.geom_column_name + index_name = f"{dataset.table_name}_idx_{geom_col}" + + L.debug("Creating spatial index for %s.%s", dataset.table_name, geom_col) + t0 = time.monotonic() + + create_index = sa.text( + f""" + CREATE SPATIAL INDEX {self.quote(index_name)} + ON {self.table_identifier(dataset)} ({self.quote(geom_col)}) + WITH (BOUNDING_BOX = (:min_x, :min_y, :max_x, :max_y)) + """ + ).bindparams(min_x=min_x, min_y=min_y, max_x=max_x, max_y=max_y) + # Placeholders not allowed in CREATE SPATIAL INDEX - have to use literal_binds. + # See https://docs.sqlalchemy.org/en/13/faq/sqlexpressions.html#faq-sql-expression-string + create_index.compile( + sess.connection(), compile_kwargs={"literal_binds": True} + ).execute() + + L.info("Created spatial index in %ss", time.monotonic() - t0) + + def _drop_spatial_index(self, sess, dataset): + # SQL server deletes the spatial index automatically when the table is deleted. + pass + + def _quoted_trigger_name(self, dataset): + trigger_name = f"{dataset.table_name}_sno_track" + return f"{self.DB_SCHEMA}.{self.quote(trigger_name)}" + + def _create_triggers(self, sess, dataset): + pk_name = dataset.primary_key + create_trigger = sa.text( + f""" + CREATE TRIGGER {self._quoted_trigger_name(dataset)} ON {self.table_identifier(dataset)} + AFTER INSERT, UPDATE, DELETE AS + BEGIN + MERGE {self.SNO_TRACK} TRA + USING + (SELECT :table_name, {self.quote(pk_name)} FROM inserted + UNION SELECT :table_name, {self.quote(pk_name)} FROM deleted) + AS SRC (table_name, pk) + ON SRC.table_name = TRA.table_name AND SRC.pk = TRA.pk + WHEN NOT MATCHED THEN INSERT (table_name, pk) VALUES (SRC.table_name, SRC.pk); + END; + """ + ).bindparams(table_name=dataset.table_name) + # Placeholders not allowed in CREATE TRIGGER - have to use literal_binds. 
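+ # (The bound :table_name is rendered inline as a quoted string literal when the statement is compiled, so the executed DDL contains no placeholders - the same approach as CREATE SPATIAL INDEX above.)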
+ # See https://docs.sqlalchemy.org/en/13/faq/sqlexpressions.html#faq-sql-expression-string + create_trigger.compile( + sess.connection(), compile_kwargs={"literal_binds": True} + ).execute() + + @contextlib.contextmanager + def _suspend_triggers(self, sess, dataset): + sess.execute( + f"""DISABLE TRIGGER {self._quoted_trigger_name(dataset)} ON {self.table_identifier(dataset)};""" + ) + yield + sess.execute( + f"""ENABLE TRIGGER {self._quoted_trigger_name(dataset)} ON {self.table_identifier(dataset)};""" + ) + + def meta_items(self, dataset): + with self.session() as sess: + table_info_sql = """ + SELECT + C.column_name, C.ordinal_position, C.data_type, + C.character_maximum_length, C.numeric_precision, C.numeric_scale, + KCU.ordinal_position AS pk_ordinal_position + FROM information_schema.columns C + LEFT OUTER JOIN information_schema.key_column_usage KCU + ON (KCU.table_schema = C.table_schema) + AND (KCU.table_name = C.table_name) + AND (KCU.column_name = C.column_name) + WHERE C.table_schema=:table_schema AND C.table_name=:table_name + ORDER BY C.ordinal_position; + """ + r = sess.execute( + table_info_sql, + {"table_schema": self.db_schema, "table_name": dataset.table_name}, + ) + ms_table_info = list(r) + + id_salt = f"{self.db_schema} {dataset.table_name} {self.get_db_tree()}" + schema = sqlserver_adapter.sqlserver_to_v2_schema(ms_table_info, id_salt) + yield "schema.json", schema.to_column_dicts() + + _UNSUPPORTED_META_ITEMS = ( + "title", + "description", + "metadata/dataset.json", + "metadata.xml", + ) + + @classmethod + def try_align_schema_col(cls, old_col_dict, new_col_dict): + old_type = old_col_dict["dataType"] + new_type = new_col_dict["dataType"] + + # Geometry type loses its extra type info when roundtripped through SQL Server. + if new_type == "geometry": + new_col_dict["geometryType"] = old_col_dict.get("geometryType") + new_col_dict["geometryCRS"] = old_col_dict.get("geometryCRS") + + return new_type == old_type + + def _remove_hidden_meta_diffs(self, dataset, ds_meta_items, wc_meta_items): + super()._remove_hidden_meta_diffs(dataset, ds_meta_items, wc_meta_items) + + # Nowhere to put these in SQL Server WC + for key in self._UNSUPPORTED_META_ITEMS: + if key in ds_meta_items: + del ds_meta_items[key] + + # Diffing CRS is not yet supported. + for key in list(ds_meta_items.keys()): + if key.startswith("crs/"): + del ds_meta_items[key] + + def _is_meta_update_supported(self, dataset_version, meta_diff): + """ + Returns True if the given meta-diff is supported *without* dropping and rewriting the table. + (Any meta change is supported if we drop and rewrite the table, but of course it is less efficient). + meta_diff - DeltaDiff object containing the meta changes. + """ + # For now, just always drop and rewrite. + return not meta_diff + + +class InstanceFunction(Function): + """ + An instance function that compiles like this when applied to an element: + >>> element.function() + Unlike a normal sqlalchemy function which would compile as follows: + >>> function(element) + """ + + +@compiles(InstanceFunction) +def compile_instance_function(element, compiler, **kw): + return "(%s).%s()" % (element.clauses, element.name) + + +class GeometryType(UserDefinedType): + """UserDefinedType so that V2 geometry is adapted to MS binary format.""" + + def __init__(self, crs_id): + self.crs_id = crs_id + + def bind_processor(self, dialect): + # 1. Writing - Python layer - convert sno geometry to WKB + return lambda geom: geom.to_wkb() + + def bind_expression(self, bindvalue): + # 2. 
Writing - SQL layer - wrap in call to STGeomFromWKB to convert WKB to MS binary. + return Function( + quoted_name("geometry::STGeomFromWKB", False), + bindvalue, + self.crs_id, + type_=self, + ) + + def column_expression(self, col): + # 3. Reading - SQL layer - append with call to .STAsBinary() to convert MS binary to WKB. + return InstanceFunction("STAsBinary", col, type_=self) + + def result_processor(self, dialect, coltype): + # 4. Reading - Python layer - convert WKB to sno geometry. + return lambda wkb: Geometry.from_wkb(wkb) + + +class BaseDateOrTimeType(UserDefinedType): + """ + UserDefinedType so we read dates, times, and datetimes as text. + They are stored as date / time / datetime in SQL Server, but read back out as text. + """ + + def column_expression(self, col): + # When reading, convert dates and times to strings using style 127: ISO8601 with time zone Z. + # https://docs.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql + return Function( + "CONVERT", + literal_column("NVARCHAR"), + col, + literal_column("127"), + type_=self, + ) + + +def sqlserver_upsert(*args, **kwargs): + return Upsert(*args, **kwargs) + + +class Upsert(ValuesBase): + """A SQL server custom upsert command that compiles to a merge statement.""" + + def __init__( + self, + table, + values=None, + prefixes=None, + index_elements=None, + set_=None, + **dialect_kw, + ): + ValuesBase.__init__(self, table, values, prefixes) + self._validate_dialect_kwargs(dialect_kw) + self.index_elements = index_elements + self.set_ = set_ + self.select = self.select_names = None + self._returning = None + + +@compiles(Upsert) +def compile_upsert(upsert_stmt, compiler, **kw): + preparer = compiler.preparer + + def list_cols(col_names, prefix=""): + return ", ".join([prefix + c for c in col_names]) + + crud_params = crud._setup_crud_params(compiler, upsert_stmt, crud.ISINSERT, **kw) + crud_values = ", ".join([c[1] for c in crud_params]) + + table = preparer.format_table(upsert_stmt.table) + all_columns = [preparer.quote(c[0].name) for c in crud_params] + index_elements = [preparer.quote(c) for c in upsert_stmt.index_elements] + set_ = [preparer.quote(c) for c in upsert_stmt.set_] + + result = f"MERGE {table} TARGET" + result += f" USING (VALUES ({crud_values})) AS SOURCE ({list_cols(all_columns)})" + + result += " ON " + result += " AND ".join([f"SOURCE.{c} = TARGET.{c}" for c in index_elements]) + + result += " WHEN MATCHED THEN UPDATE SET " + result += ", ".join([f"{c} = SOURCE.{c}" for c in set_]) + + result += " WHEN NOT MATCHED THEN INSERT " + result += ( + f"({list_cols(all_columns)}) VALUES ({list_cols(all_columns, 'SOURCE.')});" + ) + + return result diff --git a/sno/working_copy/sqlserver_adapter.py b/sno/working_copy/sqlserver_adapter.py new file mode 100644 index 000000000..86f10b329 --- /dev/null +++ b/sno/working_copy/sqlserver_adapter.py @@ -0,0 +1,160 @@ +from sno.schema import Schema, ColumnSchema + +from sqlalchemy.sql.compiler import IdentifierPreparer +from sqlalchemy.dialects.mssql.base import MSDialect + + +_PREPARER = IdentifierPreparer(MSDialect()) + + +def quote(ident): + return _PREPARER.quote(ident) + + +V2_TYPE_TO_MS_TYPE = { + "boolean": "bit", + "blob": "varbinary", + "date": "date", + "float": {0: "real", 32: "real", 64: "float"}, + "geometry": "geometry", + "integer": { + 0: "int", + 8: "tinyint", + 16: "smallint", + 32: "int", + 64: "bigint", + }, + "interval": "text", + "numeric": "numeric", + "text": "nvarchar", + "time": "time", + "timestamp": "datetimeoffset", +} + 
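+# Illustrative examples (not exhaustive) of how the mapping above is applied by v2_type_to_ms_type() below:
+#   ("integer", size 32) -> "int"    ("integer", size 64) -> "bigint"    ("float", no size) -> "real"
+#   ("text", length 32)  -> "nvarchar(32)"    ("text", no length) -> "nvarchar(max)"
+#   ("numeric", precision 10, scale 2) -> "numeric(10,2)"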
+MS_TYPE_TO_V2_TYPE = { + "bit": "boolean", + "tinyint": ("integer", 8), + "smallint": ("integer", 16), + "int": ("integer", 32), + "bigint": ("integer", 64), + "real": ("float", 32), + "float": ("float", 64), + "binary": "blob", + "char": "text", + "date": "date", + "datetime": "timestamp", + "datetime2": "timestamp", + "datetimeoffset": "timestamp", + "decimal": "numeric", + "geography": "geometry", + "geometry": "geometry", + "nchar": "text", + "numeric": "numeric", + "nvarchar": "text", + "ntext": "text", + "text": "text", + "time": "time", + "varchar": "text", + "varbinary": "blob", +} + +# Types that can't be roundtripped perfectly in SQL Server, and what they end up as. +APPROXIMATED_TYPES = {"interval": "text"} +# Note that although this means that all other V2 types above can be roundtripped, it +# doesn't mean that extra type info is always preserved. Specifically, extra +# geometry type info - the geometry type and CRS - is not roundtripped. + + +def v2_schema_to_sqlserver_spec(schema, v2_obj): + """ + Generate the SQL CREATE TABLE spec from a V2 object eg: + 'fid INTEGER, geom GEOMETRY(POINT,2136), desc VARCHAR(128), PRIMARY KEY(fid)' + """ + result = [f"{quote(col.name)} {v2_type_to_ms_type(col, v2_obj)}" for col in schema] + + if schema.pk_columns: + pk_col_names = ", ".join((quote(col.name) for col in schema.pk_columns)) + result.append(f"PRIMARY KEY({pk_col_names})") + + return ", ".join(result) + + +def v2_type_to_ms_type(column_schema, v2_obj): + """Convert a v2 schema type to a SQL server type.""" + + v2_type = column_schema.data_type + extra_type_info = column_schema.extra_type_info + + ms_type_info = V2_TYPE_TO_MS_TYPE.get(v2_type) + if ms_type_info is None: + raise ValueError(f"Unrecognised data type: {v2_type}") + + if isinstance(ms_type_info, dict): + return ms_type_info.get(extra_type_info.get("size", 0)) + + ms_type = ms_type_info + + if ms_type in ("varchar", "nvarchar", "varbinary"): + length = extra_type_info.get("length", None) + return f"{ms_type}({length})" if length is not None else f"{ms_type}(max)" + + if ms_type == "numeric": + precision = extra_type_info.get("precision", None) + scale = extra_type_info.get("scale", None) + if precision is not None and scale is not None: + return f"numeric({precision},{scale})" + elif precision is not None: + return f"numeric({precision})" + else: + return "numeric" + + return ms_type + + +def sqlserver_to_v2_schema(ms_table_info, id_salt): + """Generate a V2 schema from the given SQL server metadata.""" + return Schema([_sqlserver_to_column_schema(col, id_salt) for col in ms_table_info]) + + +def _sqlserver_to_column_schema(ms_col_info, id_salt): + """ + Given the MS column info for a particular column, converts it to a ColumnSchema. + + Parameters: + ms_col_info - info about a single column from ms_table_info. + id_salt - the UUIDs of the generated ColumnSchema are deterministic and depend on + the name and type of the column, and on this salt. 
+ """ + name = ms_col_info["column_name"] + pk_index = ms_col_info["pk_ordinal_position"] + if pk_index is not None: + pk_index -= 1 + data_type, extra_type_info = _ms_type_to_v2_type(ms_col_info) + + col_id = ColumnSchema.deterministic_id(name, data_type, id_salt) + return ColumnSchema(col_id, name, data_type, pk_index, **extra_type_info) + + +def _ms_type_to_v2_type(ms_col_info): + v2_type_info = MS_TYPE_TO_V2_TYPE.get(ms_col_info["data_type"]) + + if isinstance(v2_type_info, tuple): + v2_type = v2_type_info[0] + extra_type_info = {"size": v2_type_info[1]} + else: + v2_type = v2_type_info + extra_type_info = {} + + if v2_type == "geometry": + return v2_type, extra_type_info + + if v2_type == "text": + length = ms_col_info["character_maximum_length"] or None + if length is not None: + extra_type_info["length"] = length + + if v2_type == "numeric": + extra_type_info["precision"] = ms_col_info["numeric_precision"] or None + extra_type_info["scale"] = ms_col_info["numeric_scale"] or None + + return v2_type, extra_type_info diff --git a/sno/working_copy/table_defs.py b/sno/working_copy/table_defs.py index 6488985d8..ef500911c 100644 --- a/sno/working_copy/table_defs.py +++ b/sno/working_copy/table_defs.py @@ -9,6 +9,8 @@ UniqueConstraint, ) +from sqlalchemy.types import NVARCHAR + class TinyInt(Integer): __visit_name__ = "TINYINT" @@ -100,6 +102,38 @@ def create_all(self, session): return self._SQLALCHEMY_METADATA.create_all(session.connection()) +class SqlServerSnoTables(TableSet): + """ + Tables for sno-specific metadata - SQL Server variant. + Table names have a user-defined schema, and so unlike other table sets, + we need to construct an instance with the appropriate schema. + Primary keys have to be NVARCHAR of a fixed maximum length - + if the total maximum length is too long, SQL Server cannot generate an index. 
+ """ + + def __init__(self, schema=None): + self._SQLALCHEMY_METADATA = MetaData() + + self.sno_state = Table( + "_sno_state", + self._SQLALCHEMY_METADATA, + Column("table_name", NVARCHAR(400), nullable=False, primary_key=True), + Column("key", NVARCHAR(400), nullable=False, primary_key=True), + Column("value", Text, nullable=False), + schema=schema, + ) + self.sno_track = Table( + "_sno_track", + self._SQLALCHEMY_METADATA, + Column("table_name", NVARCHAR(400), nullable=False, primary_key=True), + Column("pk", NVARCHAR(400), nullable=True, primary_key=True), + schema=schema, + ) + + def create_all(self, session): + return self._SQLALCHEMY_METADATA.create_all(session.connection()) + + class GpkgTables(TableSet): """GPKG spec tables - see http://www.geopackage.org/spec/#table_definition_sql""" diff --git a/tests/conftest.py b/tests/conftest.py index d5684481a..47cdf6deb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,7 +24,7 @@ from sno.geometry import Geometry from sno.repo import SnoRepo -from sno.sqlalchemy import gpkg_engine, postgis_engine +from sno.sqlalchemy import gpkg_engine, postgis_engine, sqlserver_engine from sno.working_copy import WorkingCopy @@ -739,24 +739,29 @@ def func(conn, pk, update_str, layer=None, commit=True): return func -def _insert_command(table_name, col_names, schema=None): +def _insert_command(table_name, col_names): return sqlalchemy.table( - table_name, - *[sqlalchemy.column(c) for c in col_names], - schema=schema, + table_name, *[sqlalchemy.column(c) for c in col_names] ).insert() -def _edit_points(conn, schema=None): +def _edit_points(conn, dataset=None, working_copy=None): H = pytest.helpers.helpers() - layer = f'"{schema}"."{H.POINTS.LAYER}"' if schema else H.POINTS.LAYER - r = conn.execute( - _insert_command(H.POINTS.LAYER, H.POINTS.RECORD.keys(), schema=schema), - H.POINTS.RECORD, - ) - assert r.rowcount == 1 + + if working_copy is None: + layer = H.POINTS.LAYER + insert_cmd = _insert_command(H.POINTS.LAYER, H.POINTS.RECORD.keys()) + else: + layer = f"{working_copy.DB_SCHEMA}.{H.POINTS.LAYER}" + insert_cmd = working_copy._insert_into_dataset(dataset) + + # Note - different DB backends support and interpret rowcount differently. + # Sometimes rowcount is not supported for inserts, so it just returns -1. 
+ # Rowcount can be 1 or 2 if 1 row has changed its PK + r = conn.execute(insert_cmd, H.POINTS.RECORD) + assert r.rowcount in (1, -1) r = conn.execute(f"UPDATE {layer} SET fid=9998 WHERE fid=1;") - assert r.rowcount == 1 + assert r.rowcount in (1, 2) r = conn.execute(f"UPDATE {layer} SET name='test' WHERE fid=2;") assert r.rowcount == 1 r = conn.execute(f"DELETE FROM {layer} WHERE fid IN (3,30,31,32,33);") @@ -770,16 +775,20 @@ def edit_points(): return _edit_points -def _edit_polygons(conn, schema=None): +def _edit_polygons(conn, dataset=None, working_copy=None): H = pytest.helpers.helpers() - layer = f'"{schema}"."{H.POLYGONS.LAYER}"' if schema else H.POLYGONS.LAYER - r = conn.execute( - _insert_command(H.POLYGONS.LAYER, H.POLYGONS.RECORD.keys(), schema=schema), - H.POLYGONS.RECORD, - ) - assert r.rowcount == 1 + if working_copy is None: + layer = H.POLYGONS.LAYER + insert_cmd = _insert_command(H.POLYGONS.LAYER, H.POLYGONS.RECORD.keys()) + else: + layer = f"{working_copy.DB_SCHEMA}.{H.POLYGONS.LAYER}" + insert_cmd = working_copy._insert_into_dataset(dataset) + + # See note on rowcount at _edit_points + r = conn.execute(insert_cmd, H.POLYGONS.RECORD) + assert r.rowcount in (1, -1) r = conn.execute(f"UPDATE {layer} SET id=9998 WHERE id=1424927;") - assert r.rowcount == 1 + assert r.rowcount in (1, 2) r = conn.execute(f"UPDATE {layer} SET survey_reference='test' WHERE id=1443053;") assert r.rowcount == 1 r = conn.execute( @@ -795,16 +804,21 @@ def edit_polygons(): return _edit_polygons -def _edit_table(conn, schema=None): +def _edit_table(conn, dataset=None, working_copy=None): H = pytest.helpers.helpers() - layer = f'"{schema}"."{H.TABLE.LAYER}"' if schema else H.TABLE.LAYER - r = conn.execute( - _insert_command(H.TABLE.LAYER, H.TABLE.RECORD.keys(), schema=schema), - H.TABLE.RECORD, - ) - assert r.rowcount == 1 + + if working_copy is None: + layer = H.TABLE.LAYER + insert_cmd = _insert_command(H.TABLE.LAYER, H.TABLE.RECORD.keys()) + else: + layer = f"{working_copy.DB_SCHEMA}.{H.TABLE.LAYER}" + insert_cmd = working_copy._insert_into_dataset(dataset) + + r = conn.execute(insert_cmd, H.TABLE.RECORD) + # rowcount is not actually supported for inserts, but works in certain DB types - otherwise is -1. 
+ assert r.rowcount in (1, -1) r = conn.execute(f"""UPDATE {layer} SET "OBJECTID"=9998 WHERE "OBJECTID"=1;""") - assert r.rowcount == 1 + assert r.rowcount in (1, 2) r = conn.execute(f"""UPDATE {layer} SET "NAME"='test' WHERE "OBJECTID"=2;""") assert r.rowcount == 1 r = conn.execute(f"""DELETE FROM {layer} WHERE "OBJECTID" IN (3,30,31,32,33);""") @@ -869,14 +883,12 @@ def disable_editor(): @pytest.fixture() def postgis_db(): """ - Using docker, you can run a PostGres test - such as test_postgis_import - as follows: + Using docker, you can run a PostGIS test - such as test_postgis_import - as follows: docker run -it --rm -d -p 15432:5432 -e POSTGRES_HOST_AUTH_METHOD=trust kartoza/postgis SNO_POSTGRES_URL='postgresql://docker:docker@localhost:15432/gis' pytest -k postgis --pdb -vvs """ if "SNO_POSTGRES_URL" not in os.environ: - raise pytest.skip( - "Requires postgres - read docstring at sno.test_structure.postgis_db" - ) + raise pytest.skip("Requires PostGIS - read docstring at conftest.postgis_db") engine = postgis_engine(os.environ["SNO_POSTGRES_URL"]) with engine.connect() as conn: # test connection and postgis support @@ -912,3 +924,65 @@ def ctx(create=False): conn.execute(f"""DROP SCHEMA IF EXISTS "{schema}" CASCADE;""") return ctx + + +@pytest.fixture() +def sqlserver_db(): + """ + Using docker, you can run a SQL Server test - such as those in test_working_copy_sqlserver - as follows: + docker run -it --rm -d -p 11433:1433 -e ACCEPT_EULA=Y -e 'SA_PASSWORD=PassWord1' mcr.microsoft.com/mssql/server + SNO_SQLSERVER_URL='mssql://sa:PassWord1@127.0.0.1:11433/master' pytest -k sqlserver --pdb -vvs + """ + if "SNO_SQLSERVER_URL" not in os.environ: + raise pytest.skip( + "Requires SQL Server - read docstring at conftest.sqlserver_db" + ) + engine = sqlserver_engine(os.environ["SNO_SQLSERVER_URL"]) + with engine.connect() as conn: + # Test connection + try: + conn.execute("SELECT @@version;") + except sqlalchemy.exc.DBAPIError: + raise pytest.skip("Requires SQL Server") + yield engine + + +@pytest.fixture() +def new_sqlserver_db_schema(request, sqlserver_db): + @contextlib.contextmanager + def ctx(create=False): + sha = hashlib.sha1(request.node.nodeid.encode("utf8")).hexdigest()[:20] + schema = f"sno_test_{sha}" + with sqlserver_db.connect() as conn: + # Start by deleting in case it is left over from last test-run... + _sqlserver_drop_schema_cascade(conn, schema) + # Actually create only if create=True, otherwise the test will create it + if create: + conn.execute(f"""CREATE SCHEMA "{schema}";""") + try: + url = urlsplit(os.environ["SNO_SQLSERVER_URL"]) + url_path = url.path.rstrip("/") + "/" + schema + new_schema_url = urlunsplit( + [url.scheme, url.netloc, url_path, url.query, ""] + ) + yield new_schema_url, schema + finally: + # Clean up - delete it again if it exists. 
+ with sqlserver_db.connect() as conn: + _sqlserver_drop_schema_cascade(conn, schema) + + return ctx + + +def _sqlserver_drop_schema_cascade(conn, db_schema): + r = conn.execute( + sqlalchemy.text( + "SELECT name FROM sys.tables WHERE schema_id=SCHEMA_ID(:schema);" + ), + {"schema": db_schema}, + ) + table_identifiers = ", ".join([f"{db_schema}.{row[0]}" for row in r]) + if table_identifiers: + conn.execute(f"DROP TABLE IF EXISTS {table_identifiers};") + + conn.execute(f"DROP SCHEMA IF EXISTS {db_schema};") diff --git a/tests/test_working_copy_postgis.py b/tests/test_working_copy_postgis.py index cfbaae24a..0b09efcd8 100644 --- a/tests/test_working_copy_postgis.py +++ b/tests/test_working_copy_postgis.py @@ -151,11 +151,11 @@ def test_commit_edits( with wc.session() as sess: if archive == "points": - edit_points(sess, postgres_schema) + edit_points(sess, repo.datasets()[H.POINTS.LAYER], wc) elif archive == "polygons": - edit_polygons(sess, postgres_schema) + edit_polygons(sess, repo.datasets()[H.POLYGONS.LAYER], wc) elif archive == "table": - edit_table(sess, postgres_schema) + edit_table(sess, repo.datasets()[H.TABLE.LAYER], wc) r = cli_runner.invoke(["status"]) assert r.exit_code == 0, r.stderr diff --git a/tests/test_working_copy_sqlserver.py b/tests/test_working_copy_sqlserver.py new file mode 100644 index 000000000..aa5e5f6a0 --- /dev/null +++ b/tests/test_working_copy_sqlserver.py @@ -0,0 +1,313 @@ +import pytest + +import pygit2 + +from sno.repo import SnoRepo +from sno.working_copy import sqlserver_adapter +from test_working_copy import compute_approximated_types + + +H = pytest.helpers.helpers() + + +@pytest.mark.parametrize( + "existing_schema", + [ + pytest.param(True, id="existing-schema"), + pytest.param(False, id="brand-new-schema"), + ], +) +@pytest.mark.parametrize( + "archive,table,commit_sha", + [ + pytest.param("points", H.POINTS.LAYER, H.POINTS.HEAD_SHA, id="points"), + pytest.param("polygons", H.POLYGONS.LAYER, H.POLYGONS.HEAD_SHA, id="polygons"), + pytest.param("table", H.TABLE.LAYER, H.TABLE.HEAD_SHA, id="table"), + ], +) +def test_checkout_workingcopy( + archive, + table, + commit_sha, + existing_schema, + data_archive, + cli_runner, + new_sqlserver_db_schema, +): + """ Checkout a working copy """ + with data_archive(archive) as repo_path: + repo = SnoRepo(repo_path) + H.clear_working_copy() + + with new_sqlserver_db_schema(create=existing_schema) as ( + sqlserver_url, + sqlserver_schema, + ): + r = cli_runner.invoke(["create-workingcopy", sqlserver_url]) + assert r.exit_code == 0, r.stderr + assert ( + r.stdout.splitlines()[-1] + == f"Creating working copy at {sqlserver_url} ..." + ) + + r = cli_runner.invoke(["status"]) + assert r.exit_code == 0, r.stderr + assert r.stdout.splitlines() == [ + "On branch main", + "", + "Nothing to commit, working copy clean", + ] + + wc = repo.working_copy + assert wc.is_created() + + head_tree_id = repo.head_tree.hex + assert wc.assert_db_tree_match(head_tree_id) + + +@pytest.mark.parametrize( + "existing_schema", + [ + pytest.param(True, id="existing-schema"), + pytest.param(False, id="brand-new-schema"), + ], +) +def test_init_import( + existing_schema, + new_sqlserver_db_schema, + data_archive, + tmp_path, + cli_runner, +): + """ Import the GeoPackage (eg. `kx-foo-layer.gpkg`) into a Sno repository. 
""" + repo_path = tmp_path / "data.sno" + repo_path.mkdir() + + with data_archive("gpkg-points") as data: + with new_sqlserver_db_schema(create=existing_schema) as ( + sqlserver_url, + sqlserver_schema, + ): + r = cli_runner.invoke( + [ + "init", + "--import", + f"gpkg:{data / 'nz-pa-points-topo-150k.gpkg'}", + str(repo_path), + f"--workingcopy-path={sqlserver_url}", + ] + ) + assert r.exit_code == 0, r.stderr + assert (repo_path / ".sno" / "HEAD").exists() + + repo = SnoRepo(repo_path) + wc = repo.working_copy + + assert wc.is_created() + assert wc.is_initialised() + assert wc.has_data() + + assert wc.path == sqlserver_url + + +@pytest.mark.parametrize( + "archive,table,commit_sha", + [ + pytest.param("points", H.POINTS.LAYER, H.POINTS.HEAD_SHA, id="points"), + pytest.param("polygons", H.POLYGONS.LAYER, H.POLYGONS.HEAD_SHA, id="polygons"), + pytest.param("table", H.TABLE.LAYER, H.TABLE.HEAD_SHA, id="table"), + ], +) +def test_commit_edits( + archive, + table, + commit_sha, + data_archive, + cli_runner, + new_sqlserver_db_schema, + edit_points, + edit_polygons, + edit_table, +): + """ Checkout a working copy and make some edits """ + with data_archive(archive) as repo_path: + repo = SnoRepo(repo_path) + H.clear_working_copy() + + with new_sqlserver_db_schema() as (sqlserver_url, sqlserver_schema): + r = cli_runner.invoke(["create-workingcopy", sqlserver_url]) + assert r.exit_code == 0, r.stderr + + r = cli_runner.invoke(["status"]) + assert r.exit_code == 0, r.stderr + assert r.stdout.splitlines() == [ + "On branch main", + "", + "Nothing to commit, working copy clean", + ] + + wc = repo.working_copy + assert wc.is_created() + + with wc.session() as sess: + if archive == "points": + edit_points(sess, repo.datasets()[H.POINTS.LAYER], wc) + elif archive == "polygons": + edit_polygons(sess, repo.datasets()[H.POLYGONS.LAYER], wc) + elif archive == "table": + edit_table(sess, repo.datasets()[H.TABLE.LAYER], wc) + + r = cli_runner.invoke(["status"]) + assert r.exit_code == 0, r.stderr + assert r.stdout.splitlines() == [ + "On branch main", + "", + "Changes in working copy:", + ' (use "sno commit" to commit)', + ' (use "sno reset" to discard changes)', + "", + f" {table}:", + " feature:", + " 1 inserts", + " 2 updates", + " 5 deletes", + ] + orig_head = repo.head.peel(pygit2.Commit).hex + + r = cli_runner.invoke(["commit", "-m", "test_commit"]) + assert r.exit_code == 0, r.stderr + + r = cli_runner.invoke(["status"]) + assert r.exit_code == 0, r.stderr + assert r.stdout.splitlines() == [ + "On branch main", + "", + "Nothing to commit, working copy clean", + ] + + new_head = repo.head.peel(pygit2.Commit).hex + assert new_head != orig_head + + r = cli_runner.invoke(["checkout", "HEAD^"]) + + assert repo.head.peel(pygit2.Commit).hex == orig_head + + +def test_edit_schema(data_archive, cli_runner, new_sqlserver_db_schema): + with data_archive("polygons") as repo_path: + repo = SnoRepo(repo_path) + H.clear_working_copy() + + with new_sqlserver_db_schema() as (sqlserver_url, sqlserver_schema): + r = cli_runner.invoke(["create-workingcopy", sqlserver_url]) + assert r.exit_code == 0, r.stderr + + wc = repo.working_copy + assert wc.is_created() + + r = cli_runner.invoke(["diff", "--output-format=quiet"]) + assert r.exit_code == 0, r.stderr + + with wc.session() as sess: + sess.execute( + f"""ALTER TABLE "{sqlserver_schema}"."{H.POLYGONS.LAYER}" ADD colour NVARCHAR(32);""" + ) + sess.execute( + f"""ALTER TABLE "{sqlserver_schema}"."{H.POLYGONS.LAYER}" DROP COLUMN survey_reference;""" + ) + + r = 
cli_runner.invoke(["diff"]) + assert r.exit_code == 0, r.stderr + diff = r.stdout.splitlines() + + # New column "colour" has an ID is deterministically generated from the commit hash, + # but we don't care exactly what it is. + try: + colour_id_line = diff[-6] + except KeyError: + colour_id_line = "" + + assert diff[-46:] == [ + "--- nz_waca_adjustments:meta:schema.json", + "+++ nz_waca_adjustments:meta:schema.json", + " [", + " {", + ' "id": "79d3c4ca-3abd-0a30-2045-45169357113c",', + ' "name": "id",', + ' "dataType": "integer",', + ' "primaryKeyIndex": 0,', + ' "size": 64', + " },", + " {", + ' "id": "c1d4dea1-c0ad-0255-7857-b5695e3ba2e9",', + ' "name": "geom",', + ' "dataType": "geometry",', + ' "geometryType": "MULTIPOLYGON",', + ' "geometryCRS": "EPSG:4167"', + " },", + " {", + ' "id": "d3d4b64b-d48e-4069-4bb5-dfa943d91e6b",', + ' "name": "date_adjusted",', + ' "dataType": "timestamp"', + " },", + "- {", + '- "id": "dff34196-229d-f0b5-7fd4-b14ecf835b2c",', + '- "name": "survey_reference",', + '- "dataType": "text",', + '- "length": 50', + "- },", + " {", + ' "id": "13dc4918-974e-978f-05ce-3b4321077c50",', + ' "name": "adjusted_nodes",', + ' "dataType": "integer",', + ' "size": 32', + " },", + "+ {", + colour_id_line, + '+ "name": "colour",', + '+ "dataType": "text",', + '+ "length": 32', + "+ },", + " ]", + ] + + orig_head = repo.head.peel(pygit2.Commit).hex + + r = cli_runner.invoke(["commit", "-m", "test_commit"]) + assert r.exit_code == 0, r.stderr + + r = cli_runner.invoke(["status"]) + assert r.exit_code == 0, r.stderr + assert r.stdout.splitlines() == [ + "On branch main", + "", + "Nothing to commit, working copy clean", + ] + + new_head = repo.head.peel(pygit2.Commit).hex + assert new_head != orig_head + + r = cli_runner.invoke(["checkout", "HEAD^"]) + + assert repo.head.peel(pygit2.Commit).hex == orig_head + + +def test_approximated_types(): + assert sqlserver_adapter.APPROXIMATED_TYPES == compute_approximated_types( + sqlserver_adapter.V2_TYPE_TO_MS_TYPE, sqlserver_adapter.MS_TYPE_TO_V2_TYPE + ) + + +def test_types_roundtrip(data_archive, cli_runner, new_sqlserver_db_schema): + with data_archive("types") as repo_path: + repo = SnoRepo(repo_path) + H.clear_working_copy() + + with new_sqlserver_db_schema() as (sqlserver_url, sqlserver_schema): + repo.config["sno.workingcopy.path"] = sqlserver_url + r = cli_runner.invoke(["checkout"]) + + # If type-approximation roundtrip code isn't working, + # we would get spurious diffs on types that SQL server doesn't support. + r = cli_runner.invoke(["diff", "--exit-code"]) + assert r.exit_code == 0, r.stdout