diff --git a/CHANGELOG.md b/CHANGELOG.md index 45972066f..2d9a45d4d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Update job_id after job submission - Fixed default event.uuid - Fixed atlas vector search +- Fix the bug where shared artifacts are deleted when removing a component. #### New Features & Functionality diff --git a/plugins/mongodb/plugin_test/test_metadata.py b/plugins/mongodb/plugin_test/test_metadata.py index 4f29e7ff8..26adc4e9a 100644 --- a/plugins/mongodb/plugin_test/test_metadata.py +++ b/plugins/mongodb/plugin_test/test_metadata.py @@ -5,7 +5,7 @@ from superduper_mongodb.metadata import MongoMetaDataStore -DATABASE_URL = CFG.metadata_store or "mongomock://test_db" +DATABASE_URL = CFG.metadata_store or CFG.data_backend or "mongomock://test_db" @pytest.fixture @@ -25,3 +25,7 @@ def test_parent_child(metadata): def test_job(metadata): metadata_utils.test_job(metadata) + + +def test_artifact_relation(metadata): + metadata_utils.test_artifact_relation(metadata) diff --git a/plugins/mongodb/superduper_mongodb/__init__.py b/plugins/mongodb/superduper_mongodb/__init__.py index 19c658151..1b91265e4 100644 --- a/plugins/mongodb/superduper_mongodb/__init__.py +++ b/plugins/mongodb/superduper_mongodb/__init__.py @@ -3,7 +3,7 @@ from .metadata import MongoMetaDataStore as MetaDataStore from .query import MongoQuery -__version__ = "0.0.3" +__version__ = "0.0.4" __all__ = [ "ArtifactStore", diff --git a/plugins/mongodb/superduper_mongodb/metadata.py b/plugins/mongodb/superduper_mongodb/metadata.py index 4ba5aa25b..f16f91858 100644 --- a/plugins/mongodb/superduper_mongodb/metadata.py +++ b/plugins/mongodb/superduper_mongodb/metadata.py @@ -41,6 +41,7 @@ def _setup(self): self.component_collection = self.db['_objects'] self.job_collection = self.db['_jobs'] self.parent_child_mappings = self.db['_parent_child_mappings'] + self.artifact_relations = 
self.db['_artifact_relations'] def reconnect(self): """Reconnect to metdata store.""" @@ -69,6 +70,7 @@ def drop(self, force: bool = False): self.db.drop_collection(self.component_collection.name) self.db.drop_collection(self.job_collection.name) self.db.drop_collection(self.parent_child_mappings.name) + self.db.drop_collection(self.artifact_relations.name) def delete_parent_child(self, parent: str, child: str) -> None: """ @@ -97,6 +99,18 @@ def create_parent_child(self, parent: str, child: str) -> None: } ) + def _create_data(self, table_name, datas): + collection = self.db[table_name] + collection.insert_many(datas) + + def _delete_data(self, table_name, filter): + collection = self.db[table_name] + collection.delete_many(filter) + + def _get_data(self, table_name, filter): + collection = self.db[table_name] + return list(collection.find(filter)) + def create_component(self, info: t.Dict) -> InsertOneResult: """Create a component in the metadata store. diff --git a/plugins/sqlalchemy/plugin_test/test_metadata.py b/plugins/sqlalchemy/plugin_test/test_metadata.py index d60f27bb9..3f49dd529 100644 --- a/plugins/sqlalchemy/plugin_test/test_metadata.py +++ b/plugins/sqlalchemy/plugin_test/test_metadata.py @@ -25,3 +25,7 @@ def test_parent_child(metadata): def test_job(metadata): metadata_utils.test_job(metadata) + + +def test_artifact_relation(metadata): + metadata_utils.test_artifact_relation(metadata) diff --git a/plugins/sqlalchemy/superduper_sqlalchemy/__init__.py b/plugins/sqlalchemy/superduper_sqlalchemy/__init__.py index 22bd035e4..3a6d6a67c 100644 --- a/plugins/sqlalchemy/superduper_sqlalchemy/__init__.py +++ b/plugins/sqlalchemy/superduper_sqlalchemy/__init__.py @@ -1,5 +1,5 @@ from .metadata import SQLAlchemyMetadata as MetaDataStore -__version__ = "0.0.2" +__version__ = "0.0.3" __all__ = ['MetaDataStore'] diff --git a/plugins/sqlalchemy/superduper_sqlalchemy/metadata.py b/plugins/sqlalchemy/superduper_sqlalchemy/metadata.py index 93c186506..0cb58bc34 
100644 --- a/plugins/sqlalchemy/superduper_sqlalchemy/metadata.py +++ b/plugins/sqlalchemy/superduper_sqlalchemy/metadata.py @@ -123,8 +123,44 @@ def _init_tables(self): *component_table_args, ) + self.artifact_table = Table( + 'ARTIFACT_RELATIONS', + metadata, + Column('uuid', type_string), + Column('artifact_id', type_integer), + *component_table_args, + ) + + self._table_mapping = { + '_artifact_relations': self.artifact_table, + } + metadata.create_all(self.conn) + def _create_data(self, table_name, datas): + table = self._table_mapping[table_name] + with self.session_context() as session: + for data in datas: + stmt = insert(table).values(**data) + session.execute(stmt) + + def _delete_data(self, table_name, filter): + table = self._table_mapping[table_name] + + with self.session_context() as session: + conditions = [getattr(table.c, k) == v for k, v in filter.items()] + stmt = delete(table).where(*conditions) + session.execute(stmt) + + def _get_data(self, table_name, filter): + table = self._table_mapping[table_name] + + with self.session_context() as session: + conditions = [getattr(table.c, k) == v for k, v in filter.items()] + stmt = select(table).where(*conditions) + res = self.query_results(table, stmt, session) + return res + def url(self): """Return the URL of the metadata store.""" return self.conn.url + self.name @@ -142,6 +178,7 @@ def drop(self, force: bool = False): default=False, ): logging.warn('Aborting...') + try: self.job_table.drop(self.conn) except ProgrammingError as e: @@ -157,6 +194,11 @@ def drop(self, force: bool = False): except ProgrammingError as e: logging.warn(f'Error dropping component table {e}') + try: + self.artifact_table.drop(self.conn) + except ProgrammingError as e: + logging.warn(f'Error dropping artifact table {e}') + @contextmanager def session_context(self): """Provide a transactional scope around a series of operations.""" diff --git a/superduper/backends/base/artifacts.py b/superduper/backends/base/artifacts.py 
index fed6380a2..87721ca6b 100644 --- a/superduper/backends/base/artifacts.py +++ b/superduper/backends/base/artifacts.py @@ -126,34 +126,16 @@ def save_artifact(self, r: t.Dict): return r - def delete_artifact(self, r: t.Dict): + def delete_artifact(self, artifact_ids: t.List[str]): """Delete artifact from artifact store. - :param r: dictionary with mandatory fields + :param artifact_ids: list of artifact ids to delete """ - from superduper.misc.special_dicts import recursive_find - - # find all blobs with `&:blob:` prefix, - blobs = recursive_find( - r, lambda v: isinstance(v, str) and v.startswith('&:blob:') - ) - - for blob in blobs: - try: - self._delete_bytes(blob.split(':')[-1]) - except FileNotFoundError: - logging.warn(f'Blob {blob} not found in artifact store') - - # find all files with `&:file:` prefix - files = recursive_find( - r, lambda v: isinstance(v, str) and v.startswith('&:file:') - ) - for file_path in files: - # file: &:file:file_id + for artifact_id in artifact_ids: try: - self._delete_bytes(file_path.split(':')[-1]) + self._delete_bytes(artifact_id) except FileNotFoundError: - logging.warn(f'File {file_path} not found in artifact store') + logging.warn(f'Blob {artifact_id} not found in artifact store') @abstractmethod def get_bytes(self, file_id: str) -> bytes: diff --git a/superduper/backends/base/metadata.py b/superduper/backends/base/metadata.py index 27503cc2b..5dc6fef5d 100644 --- a/superduper/backends/base/metadata.py +++ b/superduper/backends/base/metadata.py @@ -76,6 +76,77 @@ def create_parent_child(self, parent: str, child: str): """ pass + def create_artifact_relation(self, uuid, artifact_ids): + """ + Create a relation between an artifact and a component version. 
+ + :param uuid: UUID of component version + :param artifact_ids: artifact ids + """ + artifact_ids = ( + [artifact_ids] if not isinstance(artifact_ids, list) else artifact_ids + ) + data = [] + for artifact_id in artifact_ids: + data.append({'uuid': uuid, 'artifact_id': artifact_id}) + + if data: + self._create_data('_artifact_relations', data) + + def delete_artifact_relation(self, uuid, artifact_ids): + """ + Delete a relation between an artifact and a component version. + + :param uuid: UUID of component version + :param artifact_ids: artifact ids + """ + artifact_ids = ( + [artifact_ids] if not isinstance(artifact_ids, list) else artifact_ids + ) + for artifact_id in artifact_ids: + self._delete_data( + '_artifact_relations', + { + 'uuid': uuid, + 'artifact_id': artifact_id, + }, + ) + + def get_artifact_relations(self, uuid=None, artifact_id=None): + """ + Get all relations between an artifact and a component version. + + :param artifact_id: artifact id + """ + if uuid is None and artifact_id is None: + raise ValueError('Either `uuid` or `artifact_id` must be provided') + elif uuid: + relations = self._get_data( + '_artifact_relations', + {'uuid': uuid}, + ) + ids = [relation['artifact_id'] for relation in relations] + else: + relations = self._get_data( + '_artifact_relations', + {'artifact_id': artifact_id}, + ) + ids = [relation['uuid'] for relation in relations] + return ids + + # TODO: Refactor to use _create_data, _delete_data, _get_data + @abstractmethod + def _create_data(self, table_name, datas): + pass + + @abstractmethod + def _delete_data(self, table_name, filter): + pass + + @abstractmethod + def _get_data(self, table_name, filter): + pass + @abstractmethod def drop(self, force: bool = False): """ diff --git a/superduper/base/datalayer.py b/superduper/base/datalayer.py index ac7b39fe4..e89c88b1f 100644 --- a/superduper/base/datalayer.py +++ b/superduper/base/datalayer.py @@ -699,7 +699,7 @@ def _apply( if children: serialized = 
self._change_component_reference_prefix(serialized) - serialized = self.artifact_store.save_artifact(serialized) + serialized = self._save_artifact(object.uuid, serialized) if artifacts: for file_id, bytes in artifacts.items(): self.artifact_store.put_bytes(bytes, file_id) @@ -786,7 +786,7 @@ def _remove_component_version( except KeyError: pass - self.artifact_store.delete_artifact(info) + self._delete_artifacts(r['uuid'], info) self.metadata.delete_component_version(type_id, identifier, version=version) def _get_content_for_filter(self, filter) -> Document: @@ -848,9 +848,9 @@ def replace( if children: serialized = self._change_component_reference_prefix(serialized) - self.artifact_store.delete_artifact(info) + self._delete_artifacts(object.uuid, info) - serialized = self.artifact_store.save_artifact(serialized) + serialized = self._save_artifact(object.uuid, serialized) self.metadata.replace_object( serialized, @@ -859,6 +859,45 @@ def replace( version=object.version, ) + def _save_artifact(self, uuid, info: t.Dict): + """ + Save an artifact to the artifact store. + + :param info: The component info to save. 
+ + artifact_ids, _ = self._find_artifacts(info) + self.metadata.create_artifact_relation(uuid, artifact_ids) + return self.artifact_store.save_artifact(info) + + def _delete_artifacts(self, uuid, info: t.Dict): + artifact_ids, artifacts = self._find_artifacts(info) + for artifact_id in artifact_ids: + relation_uuids = self.metadata.get_artifact_relations( + artifact_id=artifact_id + ) + if len(relation_uuids) == 1 and relation_uuids[0] == uuid: + self.artifact_store.delete_artifact([artifact_id]) + self.metadata.delete_artifact_relation( + uuid=uuid, artifact_ids=artifact_id + ) + + def _find_artifacts(self, info: t.Dict): + from superduper.misc.special_dicts import recursive_find + + # find all blobs with `&:blob:` prefix, + blobs = recursive_find( + info, lambda v: isinstance(v, str) and v.startswith('&:blob:') + ) + + # find all files with `&:file:` prefix + files = recursive_find( + info, lambda v: isinstance(v, str) and v.startswith('&:file:') + ) + artifact_ids: list[str] = [] + artifact_ids.extend(a.split(":")[-1] for a in blobs) + artifact_ids.extend(a.split(":")[-1] for a in files) + return artifact_ids, {'blobs': blobs, 'files': files} + def select_nearest( self, like: t.Union[t.Dict, Document], diff --git a/test/unittest/backends/base/test_metadata.py b/test/unittest/backends/base/test_metadata.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/unittest/base/test_datalayer.py b/test/unittest/base/test_datalayer.py index a9b90b512..4664fa3c3 100644 --- a/test/unittest/base/test_datalayer.py +++ b/test/unittest/base/test_datalayer.py @@ -605,6 +605,28 @@ def test_dataset(db): assert len(dataset.data) == len(list(db.execute(dataset.select))) +def test_delete_component_with_same_artifact(db): + from superduper import ObjectModel + + model1 = ObjectModel( + object=lambda x: x + 1, + identifier='model1', + ) + + model2 = ObjectModel( + object=model1.object, + identifier='model2', + ) + + db.apply(model1) + db.apply(model2) + + 
db.remove('model', 'model1', force=True) + model2 = db.load('model', 'model2') + model2.init() + assert model2.predict(1) == 2 + + def test_retry_on_token_expiry(db): # Mock the methods db.retry = 1 diff --git a/test/utils/database/metadata.py b/test/utils/database/metadata.py index cbd75c938..ebd67bf80 100644 --- a/test/utils/database/metadata.py +++ b/test/utils/database/metadata.py @@ -90,6 +90,29 @@ def test_job(metadata: MetaDataStore): assert job_get["status"] == "running" +def test_artifact_relation(metadata: MetaDataStore): + uuid_1 = str(uuid.uuid4()) + artifact_ids_1 = ["artifact-1", "artifact-2"] + uuid_2 = str(uuid.uuid4()) + artifact_ids_2 = ["artifact-3", "artifact-4"] + + metadata.create_artifact_relation(uuid_1, artifact_ids_1) + metadata.create_artifact_relation(uuid_2, artifact_ids_2) + + assert metadata.get_artifact_relations(uuid=uuid_1) == artifact_ids_1 + assert metadata.get_artifact_relations(artifact_id="artifact-1") == [uuid_1] + assert metadata.get_artifact_relations(artifact_id="artifact-2") == [uuid_1] + + assert metadata.get_artifact_relations(uuid=uuid_2) == artifact_ids_2 + assert metadata.get_artifact_relations(artifact_id="artifact-3") == [uuid_2] + assert metadata.get_artifact_relations(artifact_id="artifact-4") == [uuid_2] + + metadata.delete_artifact_relation(uuid_1, artifact_ids_1[0]) + assert metadata.get_artifact_relations(uuid=uuid_1) == [artifact_ids_1[1]] + + assert metadata.get_artifact_relations(uuid=uuid_2) == artifact_ids_2 + + def _create_components(type_ids, identifiers, versions, metadata): versions = versions or [0] uuid2component = {}