From 307524d0026ef93b25e246a70869ee195ef0ead9 Mon Sep 17 00:00:00 2001 From: Peter Beaucage Date: Fri, 10 May 2024 09:36:14 -0400 Subject: [PATCH] Skip duplicate data during copy from one catalog to another (#737) * Add basic duplicate checking to sync.copy * fix lint * Gate error override behind skip_duplicates flag * Change skip_duplicates to on_conflict * Add changelog entry for conflict handling * Add tests for conflict behavior on copy --- CHANGELOG.md | 2 ++ tiled/_tests/test_sync.py | 31 +++++++++++++++++++++++ tiled/client/sync.py | 53 ++++++++++++++++++++++++++++----------- 3 files changed, 71 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d1dcf8b9..cbad59cc3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,8 @@ Write the date in place of the "Unreleased" in the case a new version is release dictionary with 'selected' as the key, to match default type/behavior. - The method `BaseClient.data_sources()` returns dataclass objects instead of raw dict objects. +- `tiled.client.sync` has conflict handling, with initial options of 'error' + (default), 'warn', and 'skip' ### Fixed diff --git a/tiled/_tests/test_sync.py b/tiled/_tests/test_sync.py index 1722e2e69..468054618 100644 --- a/tiled/_tests/test_sync.py +++ b/tiled/_tests/test_sync.py @@ -6,6 +6,7 @@ import h5py import numpy import pandas +import pytest import sparse import tifffile @@ -14,6 +15,7 @@ from tiled.client.register import register from tiled.client.smoke import read from tiled.client.sync import copy +from tiled.client.utils import ClientError from tiled.queries import Key from tiled.server.app import build_app @@ -88,6 +90,35 @@ def test_copy_internal(): read(dest, strict=True) +def test_copy_skip_conflict(): + with client_factory() as dest: + with client_factory() as source: + populate_internal(source) + copy(source, dest) + copy(source, dest, on_conflict="skip") + assert list(source) == list(dest) + assert list(source["c"]) == list(dest["c"]) + read(dest, strict=True) + + +def test_copy_warn_conflict(): + with client_factory() as dest: + with client_factory() as source: + populate_internal(source) + copy(source, dest) + with pytest.warns(UserWarning): + copy(source, dest, on_conflict="warn") + + +def test_copy_error_conflict(): + with client_factory() as dest: + with client_factory() as source: + populate_internal(source) + copy(source, dest) + with pytest.raises(ClientError): + copy(source, dest) + + def test_copy_external(tmp_path): with client_factory(readable_storage=[tmp_path]) as dest: with client_factory() as source: diff --git a/tiled/client/sync.py b/tiled/client/sync.py index 230ea20a9..b595dd2a1 100644 --- a/tiled/client/sync.py +++ b/tiled/client/sync.py @@ -1,13 +1,18 @@ import itertools +import warnings + +import httpx from ..structures.core import StructureFamily from ..structures.data_source import DataSource, Management from .base import BaseClient +from .utils import ClientError def copy( source: BaseClient, dest: BaseClient, + on_conflict: str = "error", ): """ Copy data from one Tiled instance to another. @@ -16,6 +21,7 @@ def copy( ---------- source : tiled node dest : tiled node + on_conflict : str, default 'error', other options 'warn', 'skip' Examples -------- @@ -34,15 +40,20 @@ def copy( >>> copy(a.items().head(), b) >>> copy(a.search(...), b) + Copy and ignore duplicates. + + >>> copy(a, b, on_conflict = 'skip') """ if hasattr(source, "structure_family"): # looks like a client object - _DISPATCH[source.structure_family](source.include_data_sources(), dest) + _DISPATCH[source.structure_family]( + source.include_data_sources(), dest, on_conflict + ) else: - _DISPATCH[StructureFamily.container](dict(source), dest) + _DISPATCH[StructureFamily.container](dict(source), dest, on_conflict) -def _copy_array(source, dest): +def _copy_array(source, dest, on_conflict): num_blocks = (range(len(n)) for n in source.chunks) # Loop over each block index --- e.g. (0, 0), (0, 1), (0, 2) .... for block in itertools.product(*num_blocks): @@ -50,7 +61,7 @@ def _copy_array(source, dest): dest.write_block(array, block) -def _copy_awkward(source, dest): +def _copy_awkward(source, dest, on_conflict): import awkward array = source.read() @@ -58,7 +69,7 @@ def _copy_awkward(source, dest): dest.write(container) -def _copy_sparse(source, dest): +def _copy_sparse(source, dest, on_conflict): num_blocks = (range(len(n)) for n in source.chunks) # Loop over each block index --- e.g. (0, 0), (0, 1), (0, 2) .... for block in itertools.product(*num_blocks): @@ -66,13 +77,13 @@ def _copy_sparse(source, dest): dest.write_block(array.coords, array.data, block) -def _copy_table(source, dest): +def _copy_table(source, dest, on_conflict): for partition in range(source.structure().npartitions): df = source.read_partition(partition) dest.write_partition(df, partition) -def _copy_container(source, dest): +def _copy_container(source, dest, on_conflict): for key, child_node in source.items(): original_data_sources = child_node.include_data_sources().data_sources() num_data_sources = len(original_data_sources) @@ -108,13 +119,23 @@ def _copy_container(source, dest): raise NotImplementedError( "Multiple Data Sources in one Node is not supported." ) - node = dest.new( - key=key, - structure_family=child_node.structure_family, - data_sources=data_sources, - metadata=dict(child_node.metadata), - specs=child_node.specs, - ) + try: + node = dest.new( + key=key, + structure_family=child_node.structure_family, + data_sources=data_sources, + metadata=dict(child_node.metadata), + specs=child_node.specs, + ) + except ClientError as err: + if ( + on_conflict == "skip" or on_conflict == "warn" + ) and err.response.status_code == httpx.codes.CONFLICT: + if on_conflict == "warn": + warnings.warn("Skipped existing entry") + continue + else: + raise err if ( original_data_sources and (original_data_sources[0].management != Management.external) @@ -122,7 +143,9 @@ def _copy_container(source, dest): child_node.structure_family == StructureFamily.container and (not original_data_sources) ): - _DISPATCH[child_node.structure_family](child_node, node) + _DISPATCH[child_node.structure_family]( + child_node, node, on_conflict=on_conflict + ) _DISPATCH = {