Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Allow for changing the spatial filter with kart checkout #465

Merged
merged 2 commits into from
Aug 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 40 additions & 10 deletions kart/checkout.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@

from .exceptions import DbConnectionError
from .key_filters import RepoKeyFilter
from .output_util import InputMode, get_input_mode
from .spatial_filters import SpatialFilterString, spatial_filter_help_text
from .structs import CommitWithReference
from .working_copy import WorkingCopyStatus
from .output_util import InputMode, get_input_mode


_DISCARD_CHANGES_HELP_MESSAGE = (
"Commit these changes first (`kart commit`) or"
Expand All @@ -38,6 +40,17 @@ def reset_wc_if_needed(repo, target_tree_or_commit, *, discard_changes=False):
working_copy.create_and_initialise()
datasets = list(repo.datasets(target_tree_or_commit))
working_copy.write_full(target_tree_or_commit, *datasets)
return

spatial_filter_matches = repo.spatial_filter.matches_working_copy(repo)
if not spatial_filter_matches:
# TODO - support spatial filter changes without doing full rewrites.
click.echo(f"Updating {working_copy} with new spatial filter...")
datasets = list(repo.datasets(target_tree_or_commit))
working_copy.rewrite_full(
target_tree_or_commit, *datasets, force=discard_changes
)
return

db_tree_matches = (
working_copy.get_db_tree() == target_tree_or_commit.peel(pygit2.Tree).hex
Expand Down Expand Up @@ -68,8 +81,16 @@ def reset_wc_if_needed(repo, target_tree_or_commit, *, discard_changes=False):
help="If a local branch of given name doesn't exist, but a remote does, "
"this option guesses that the user wants to create a local to track the remote",
)
@click.option(
"--spatial-filter",
"spatial_filter_spec",
type=SpatialFilterString(encoding="utf-8"),
help=spatial_filter_help_text(),
)
@click.argument("refish", default=None, required=False)
def checkout(ctx, new_branch, force, discard_changes, do_guess, refish):
def checkout(
ctx, new_branch, force, discard_changes, do_guess, spatial_filter_spec, refish
):
""" Switch branches or restore working tree files """
repo = ctx.obj.repo

Expand Down Expand Up @@ -101,10 +122,16 @@ def checkout(ctx, new_branch, force, discard_changes, do_guess, refish):

commit = resolved.commit
head_ref = resolved.reference.name if resolved.reference else commit.id
same_commit = repo.head_commit == commit
do_switch_commit = repo.head_commit != commit

do_switch_spatial_filter = False
if spatial_filter_spec is not None:
do_switch_spatial_filter = not spatial_filter_spec.resolve(
repo
).matches_working_copy(repo)

force = force or discard_changes
if not same_commit and not force:
if (do_switch_commit or do_switch_spatial_filter) and not force:
ctx.obj.check_not_dirty(help_message=_DISCARD_CHANGES_HELP_MESSAGE)

if new_branch:
Expand All @@ -128,6 +155,9 @@ def checkout(ctx, new_branch, force, discard_changes, do_guess, refish):

from kart.working_copy.base import BaseWorkingCopy

if spatial_filter_spec is not None:
spatial_filter_spec.write_config(repo)

BaseWorkingCopy.ensure_config_exists(repo)
reset_wc_if_needed(repo, commit, discard_changes=discard_changes)

Expand Down Expand Up @@ -191,8 +221,8 @@ def switch(ctx, create, force_create, discard_changes, do_guess, refish):
resolved = CommitWithReference.resolve(repo, "HEAD")
commit = resolved.commit

same_commit = repo.head_commit == commit
if not discard_changes and not same_commit:
do_switch_commit = repo.head_commit != commit
if do_switch_commit and not discard_changes:
ctx.obj.check_not_dirty(_DISCARD_CHANGES_HELP_MESSAGE)

if new_branch in repo.branches and not force_create:
Expand Down Expand Up @@ -243,8 +273,8 @@ def switch(ctx, create, force_create, discard_changes, do_guess, refish):
raise NotFound(f"Branch '{refish}' not found.", NO_BRANCH)

commit = existing_branch.peel(pygit2.Commit)
same_commit = repo.head_commit == commit
if not discard_changes and not same_commit:
do_switch_commit = repo.head_commit != commit
if do_switch_commit and not discard_changes:
ctx.obj.check_not_dirty(_DISCARD_CHANGES_HELP_MESSAGE)

if existing_branch.shorthand in repo.branches.local:
Expand Down Expand Up @@ -347,8 +377,8 @@ def reset(ctx, discard_changes, refish):
except (KeyError, pygit2.InvalidSpecError):
raise NotFound(f"{refish} is not a commit", exit_code=NO_COMMIT)

same_commit = repo.head_commit == commit
if not discard_changes and not same_commit:
do_switch_commit = repo.head_commit != commit
if do_switch_commit and not discard_changes:
ctx.obj.check_not_dirty(_DISCARD_CHANGES_HELP_MESSAGE)

head_branch = repo.head_branch
Expand Down
2 changes: 1 addition & 1 deletion kart/fsck.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def _fsck_reset(repo, working_copy, dataset_paths):
commit = repo.head_commit
datasets = [repo.datasets()[p] for p in dataset_paths]

working_copy.drop_table(commit, *datasets)
working_copy.drop_tables(commit, *datasets)
working_copy.write_full(commit, *datasets)


Expand Down
5 changes: 3 additions & 2 deletions kart/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ def _add_datasets_to_working_copy(repo, *datasets, replace_existing=False):
click.echo(f"Updating {wc} ...")

if replace_existing:
wc.drop_table(commit, *datasets)
wc.write_full(commit, *datasets)
wc.rewrite_full(commit, *datasets, force=True)
else:
wc.write_full(commit, *datasets)


class GenerateIDsFromFile(StringFromFile):
Expand Down
125 changes: 102 additions & 23 deletions kart/spatial_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .crs_util import make_crs
from .exceptions import CrsError, GeometryError, NotFound, NO_SPATIAL_FILTER
from .geometry import geometry_from_string, GeometryType
from .serialise_util import hexhash


L = logging.getLogger("kart.spatial_filters")
Expand Down Expand Up @@ -96,6 +97,14 @@ def __init__(self):
self.REF_KEY = KartConfigKeys.KART_SPATIALFILTER_REFERENCE
self.OID_KEY = KartConfigKeys.KART_SPATIALFILTER_OBJECTID

def resolve(self, repo):
    """
    Returns an equivalent ResolvedSpatialFilterSpec that directly contains the geometry and CRS
    (as opposed to a ReferenceSpatialFilterSpec that contains a reference to some other object
    that in turn contains the geometry and CRS).
    """
    # FIX: both concrete overrides take (self, repo) and the caller invokes
    # `spatial_filter_spec.resolve(repo)`, so the abstract base must accept
    # `repo` too — otherwise a call through the base raises TypeError
    # instead of the intended NotImplementedError.
    raise NotImplementedError()


class ResolvedSpatialFilterSpec(SpatialFilterSpec):
"""A user-provided specification for a spatial filter where the user has supplied the values directly."""
Expand All @@ -114,6 +123,9 @@ def __init__(self, crs_spec, geometry_spec, match_all=False):
context="spatial filter",
)

def resolve(self, repo):
    """Return self: this spec already holds the geometry and CRS directly.

    `repo` is accepted only for interface parity with
    ReferenceSpatialFilterSpec, which needs it to look up the referenced
    object.
    """
    return self

def write_config(self, repo):
if self.match_all:
self.delete_all_config(repo)
Expand All @@ -123,6 +135,23 @@ def write_config(self, repo):
repo.del_config(self.REF_KEY)
repo.del_config(self.OID_KEY)

def delete_all_config(self, repo):
    """Remove every spatial-filter key (geometry, CRS, reference, object-id) from the repo config."""
    all_keys = (self.GEOM_KEY, self.CRS_KEY, self.REF_KEY, self.OID_KEY)
    for config_key in all_keys:
        repo.del_config(config_key)

def matches_working_copy(self, repo):
    """True if the repo's working copy was written using this spatial filter.

    A repo with no working copy trivially matches any filter.
    """
    working_copy = repo.working_copy
    if working_copy is None:
        return True
    return working_copy.get_spatial_filter_hash() == self.hexhash

@property
def hexhash(self):
    """Hex digest identifying this filter's contents, or None for a match-all filter."""
    if self.match_all:
        return None
    # The hash covers the normalised (stripped) CRS spec plus the geometry's WKB form.
    normalised_crs = self.crs_spec.strip()
    return hexhash(normalised_crs, self.geometry.to_wkb())


class ReferenceSpatialFilterSpec(SpatialFilterSpec):
"""
Expand All @@ -134,36 +163,71 @@ def __init__(self, ref_or_oid):
super().__init__()
self.ref_or_oid = ref_or_oid

def write_config(self, repo):
def _resolve_object_contents(self, obj):
    """Decode a filter blob (CRS spec, blank line, geometry) into a ResolvedSpatialFilterSpec."""
    text = obj.data.decode("utf-8")
    crs_and_geometry = self.split_file(text)
    return ResolvedSpatialFilterSpec(*crs_and_geometry)

@functools.lru_cache(maxsize=1)
def _resolve_target(self, repo):
"""
Returns a tuple of strings (reference, object_id, ResolvedSpatialFilterSpec).
# Returned reference will be None if ref_or_oid is an object-id.
"""

# TODO - handle missing objects (try to make sure they are fetched from the remote).

obj = None
oid = self.ref_or_oid
try:
obj = repo[self.ref_or_oid]
obj = repo[oid]
except (KeyError, ValueError):
pass

if obj is not None:
return None, oid, self._resolve_object_contents(obj)

ref = self.ref_or_oid
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this function is assuming (AFAICT) that at this point the self.ref_or_oid is the name of a ref rather than an OID. Is that true? It seems like it could be an OID of an unfetched object also. Is there any way to tell?

Copy link
Collaborator Author

@olsen232 olsen232 Aug 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does assume that if it doesn't find the object, it must be a ref - or missing. It could be that it is an unfetched object however. But whether non-existent or unfetched, the end results will be this:
No spatial filter object was found in the repository at 9d5d8e7 or refs/filters/9d5d8e7 - which is true and describes the problem accurately.

I'll add a TODO to handle unfetched objects here somehow, but it's part of a bigger problem - we can't currently tell if objects are non-existent / corrupted or merely unfetched, and we don't specifically handle unfetched objects anywhere.

if not ref.startswith("refs/"):
ref = f"refs/filters/{ref}"

if ref in repo.references:
oid = str(repo.references[ref].resolve().target)
try:
obj = repo[oid]
except (KeyError, ValueError):
pass

if obj is not None:
return ref, oid, self._resolve_object_contents(obj)

ref_desc = " or ".join(set([oid, ref]))
raise NotFound(
f"No spatial filter object was found in the repository at {ref_desc}",
exit_code=NO_SPATIAL_FILTER,
)

def resolve(self, repo):
    """Look up the referenced object in `repo` and return the spec it contains.

    Only the resolved spec is needed here; the ref and oid returned by
    _resolve_target are used by write_config instead.
    """
    _, _, resolved = self._resolve_target(repo)
    return resolved

def write_config(self, repo):
ref, oid, resolved_spatial_filter_spec = self._resolve_target(repo)
if ref is None:
# Found an object - the object is immutable, so no reason to store a pointer to it.
# Just resolve the reference to geometry + CRS and store that.
contents = obj.data.decode("utf-8")
parts = self.split_file(contents)
ResolvedSpatialFilterSpec(*parts).write_config(repo)
resolved_spatial_filter_spec.write_config(repo)

else:
ref = self.ref_or_oid
if not ref.startswith("refs/"):
ref = f"refs/filters/{ref}"
if ref not in repo.references:
ref_desc = " or ".join(set([ref, self.ref_or_oid]))
raise NotFound(
f"No spatial filter object was found in the repository at {ref_desc}",
exit_code=NO_SPATIAL_FILTER,
)
# Found a reference. The reference is mutable, so we store it (and the object it points to).
oid = str(repo.references[ref].resolve().target)
repo.config[self.REF_KEY] = ref
repo.config[self.OID_KEY] = oid
repo.del_config(self.GEOM_KEY)
repo.del_config(self.CRS_KEY)

def matches_working_copy(self, repo):
    """True if the repo's working copy was written using this spatial filter.

    Resolving a reference spec requires the repo (the filter contents live in
    an object or ref inside it), so the resolved spec does the actual check.
    """
    # BUG FIX: resolve() requires the repo argument — calling self.resolve()
    # with no arguments raised TypeError at runtime.
    return self.resolve(repo).matches_working_copy(repo)

@classmethod
def split_file(cls, contents):
parts = re.split(r"\n\r?\n", contents, maxsplit=1)
Expand Down Expand Up @@ -221,13 +285,10 @@ def from_repo_config(cls, repo):
@classmethod
@functools.lru_cache()
def from_spec(cls, crs_spec, geometry_spec):
geometry = geometry_from_string(geometry_spec, context="spatial filter")
crs = make_crs(crs_spec, context="spatial filter")

return OriginalSpatialFilter(geometry.to_ogr(), crs)
return OriginalSpatialFilter(crs_spec, geometry_spec)

def __init__(
self, filter_geometry_ogr, crs, geom_column_name=None, match_all=False
self, crs, filter_geometry_ogr, geom_column_name=None, match_all=False
):
"""
Create a new spatial filter.
Expand All @@ -238,12 +299,12 @@ def __init__(
self.match_all = match_all

if match_all:
self.filter_ogr = self.filter_env = self.crs = None
self.crs = self.filter_ogr = self.filter_env = None
self.geom_column_name = None
else:
self.crs = crs
self.filter_ogr = filter_geometry_ogr
self.filter_env = self.filter_ogr.GetEnvelope()
self.crs = crs
self.geom_column_name = geom_column_name

def matches(self, feature):
Expand Down Expand Up @@ -311,6 +372,17 @@ class OriginalSpatialFilter(SpatialFilter):
That is why only OriginalSpatialFilter supports transformation.
"""

def __init__(self, crs_spec, geometry_spec, match_all=False):
    """Build the original (user-specified, untransformed) spatial filter.

    Parses the raw CRS and geometry specs, delegates to SpatialFilter, and
    records a hexhash of the specs so working copies can detect filter
    changes. A match-all filter has no geometry, CRS, or hash.
    """
    if match_all:
        super().__init__(None, None, match_all=True)
        self.hexhash = None
        return
    context = "spatial filter"
    geometry = geometry_from_string(geometry_spec, context=context)
    crs = make_crs(crs_spec, context=context)
    super().__init__(crs, geometry.to_ogr())
    self.hexhash = hexhash(crs_spec.strip(), geometry.to_wkb())

@property
def is_original(self):
    # Always True for OriginalSpatialFilter: it still holds the user-supplied
    # geometry/CRS, which is why only this class supports transformation to a
    # dataset's CRS (transform_for_schema_and_crs).
    return True
Expand Down Expand Up @@ -361,12 +433,19 @@ def transform_for_schema_and_crs(self, schema, crs, ds_path=None):
transform = osr.CoordinateTransformation(self.crs, crs)
new_filter_ogr = self.filter_ogr.Clone()
new_filter_ogr.Transform(transform)
return SpatialFilter(new_filter_ogr, crs, new_geom_column_name)
return SpatialFilter(crs, new_filter_ogr, new_geom_column_name)

except RuntimeError as e:
crs_desc = f"CRS for {ds_path!r}" if ds_path else f"CRS:\n {crs_spec!r}"
raise CrsError(f"Can't reproject spatial filter into {crs_desc}:\n{e}")

def matches_working_copy(self, repo):
    """True if the repo's working copy was written using this spatial filter.

    A repo without a working copy matches trivially; otherwise the working
    copy's recorded filter hash must equal this filter's hexhash.
    """
    wc = repo.working_copy
    return True if wc is None else wc.get_spatial_filter_hash() == self.hexhash


# A SpatialFilter object that matches everything.
SpatialFilter._MATCH_ALL = SpatialFilter(None, None, match_all=True)
Expand Down
Loading