Fix the bug where shared artifacts are deleted (#2446)
jieguangzhou authored Sep 11, 2024
1 parent 502694a commit 51bdf94
Showing 13 changed files with 231 additions and 29 deletions.
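
In short, the fix makes component removal reference-aware: an artifact shared by several components is no longer deleted when just one of them is removed. A rough sketch of the failing scenario, mirroring the new unit test in test/unittest/base/test_datalayer.py below (the `mongomock://test` URI is only an illustrative connection string):

```python
from superduper import ObjectModel, superduper

db = superduper('mongomock://test')  # illustrative connection string

fn = lambda x: x + 1
model1 = ObjectModel(object=fn, identifier='model1')
model2 = ObjectModel(object=fn, identifier='model2')  # reuses the same pickled object

db.apply(model1)
db.apply(model2)

# Before this fix, removing model1 also deleted the blob shared with model2.
db.remove('model', 'model1', force=True)

model2 = db.load('model', 'model2')
model2.init()
assert model2.predict(1) == 2  # the shared artifact is still in the artifact store
```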
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Update job_id after job submission
- Fixed default event.uuid
- Fixed atlas vector search
- Fixed the bug where shared artifacts were deleted when removing a component.

#### New Features & Functionality

6 changes: 5 additions & 1 deletion plugins/mongodb/plugin_test/test_metadata.py
@@ -5,7 +5,7 @@

from superduper_mongodb.metadata import MongoMetaDataStore

-DATABASE_URL = CFG.metadata_store or "mongomock://test_db"
+DATABASE_URL = CFG.metadata_store or CFG.data_backend or "mongomock://test_db"


@pytest.fixture
@@ -25,3 +25,7 @@ def test_parent_child(metadata):

def test_job(metadata):
metadata_utils.test_job(metadata)


def test_artifact_relation(metadata):
metadata_utils.test_artifact_relation(metadata)
2 changes: 1 addition & 1 deletion plugins/mongodb/superduper_mongodb/__init__.py
@@ -3,7 +3,7 @@
from .metadata import MongoMetaDataStore as MetaDataStore
from .query import MongoQuery

-__version__ = "0.0.3"
+__version__ = "0.0.4"

__all__ = [
"ArtifactStore",
14 changes: 14 additions & 0 deletions plugins/mongodb/superduper_mongodb/metadata.py
@@ -41,6 +41,7 @@ def _setup(self):
self.component_collection = self.db['_objects']
self.job_collection = self.db['_jobs']
self.parent_child_mappings = self.db['_parent_child_mappings']
self.artifact_relations = self.db['_artifact_relations']

def reconnect(self):
"""Reconnect to metdata store."""
@@ -69,6 +70,7 @@ def drop(self, force: bool = False):
self.db.drop_collection(self.component_collection.name)
self.db.drop_collection(self.job_collection.name)
self.db.drop_collection(self.parent_child_mappings.name)
self.db.drop_collection(self.artifact_relations.name)

def delete_parent_child(self, parent: str, child: str) -> None:
"""
@@ -97,6 +99,18 @@ def create_parent_child(self, parent: str, child: str) -> None:
}
)

def _create_data(self, table_name, datas):
collection = self.db[table_name]
collection.insert_many(datas)

def _delete_data(self, table_name, filter):
collection = self.db[table_name]
collection.delete_many(filter)

def _get_data(self, table_name, filter):
collection = self.db[table_name]
return list(collection.find(filter))

def create_component(self, info: t.Dict) -> InsertOneResult:
"""Create a component in the metadata store.
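
For orientation, each relation created by the base class's `create_artifact_relation` lands in the new `_artifact_relations` collection as one small document per (component version, artifact) pair. Roughly, with made-up ids and `metadata` assumed to be a connected `MongoMetaDataStore`:

```python
# Illustrative only: what create_artifact_relation('model1-uuid', ['1a2b3c'])
# ends up inserting through _create_data.
metadata._create_data(
    '_artifact_relations',
    [{'uuid': 'model1-uuid', 'artifact_id': '1a2b3c'}],
)
```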
4 changes: 4 additions & 0 deletions plugins/sqlalchemy/plugin_test/test_metadata.py
@@ -25,3 +25,7 @@ def test_parent_child(metadata):

def test_job(metadata):
metadata_utils.test_job(metadata)


def test_artifact_relation(metadata):
metadata_utils.test_artifact_relation(metadata)
2 changes: 1 addition & 1 deletion plugins/sqlalchemy/superduper_sqlalchemy/__init__.py
@@ -1,5 +1,5 @@
from .metadata import SQLAlchemyMetadata as MetaDataStore

-__version__ = "0.0.2"
+__version__ = "0.0.3"

__all__ = ['MetaDataStore']
42 changes: 42 additions & 0 deletions plugins/sqlalchemy/superduper_sqlalchemy/metadata.py
@@ -123,8 +123,44 @@ def _init_tables(self):
*component_table_args,
)

self.artifact_table = Table(
'ARTIFACT_RELATIONS',
metadata,
Column('uuid', type_string),
Column('artifact_id', type_integer),
*component_table_args,
)

self._table_mapping = {
'_artifact_relations': self.artifact_table,
}

metadata.create_all(self.conn)

def _create_data(self, table_name, datas):
table = self._table_mapping[table_name]
with self.session_context() as session:
for data in datas:
stmt = insert(table).values(**data)
session.execute(stmt)

def _delete_data(self, table_name, filter):
table = self._table_mapping[table_name]

with self.session_context() as session:
conditions = [getattr(table.c, k) == v for k, v in filter.items()]
stmt = delete(table).where(*conditions)
session.execute(stmt)

def _get_data(self, table_name, filter):
table = self._table_mapping[table_name]

with self.session_context() as session:
conditions = [getattr(table.c, k) == v for k, v in filter.items()]
stmt = select(table).where(*conditions)
res = self.query_results(table, stmt, session)
return res

def url(self):
"""Return the URL of the metadata store."""
return self.conn.url + self.name
@@ -142,6 +178,7 @@ def drop(self, force: bool = False):
default=False,
):
logging.warn('Aborting...')

try:
self.job_table.drop(self.conn)
except ProgrammingError as e:
@@ -157,6 +194,11 @@ def drop(self, force: bool = False):
except ProgrammingError as e:
logging.warn(f'Error dropping component table {e}')

try:
self.artifact_table.drop(self.conn)
except ProgrammingError as e:
logging.warn(f'Error dropping artifact table {e}')

@contextmanager
def session_context(self):
"""Provide a transactional scope around a series of operations."""
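
A rough usage sketch of the new generic helpers against the `ARTIFACT_RELATIONS` table; `store` is assumed to be an initialised `SQLAlchemyMetadata`, and in practice these private methods are only reached through the base-class `*_artifact_relation` methods:

```python
# Insert, query and delete relation rows through the '_artifact_relations' mapping.
store._create_data('_artifact_relations', [{'uuid': 'u1', 'artifact_id': 42}])
rows = store._get_data('_artifact_relations', {'uuid': 'u1'})   # rows where uuid == 'u1'
store._delete_data('_artifact_relations', {'uuid': 'u1', 'artifact_id': 42})
```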
26 changes: 4 additions & 22 deletions superduper/backends/base/artifacts.py
@@ -126,34 +126,16 @@ def save_artifact(self, r: t.Dict):

return r

-    def delete_artifact(self, r: t.Dict):
+    def delete_artifact(self, artifact_ids: t.List[str]):
        """Delete artifact from artifact store.
-        :param r: dictionary with mandatory fields
+        :param artifact_ids: list of artifact ids to delete.
        """
-        from superduper.misc.special_dicts import recursive_find
-
-        # find all blobs with `&:blob:` prefix,
-        blobs = recursive_find(
-            r, lambda v: isinstance(v, str) and v.startswith('&:blob:')
-        )
-
-        for blob in blobs:
-            try:
-                self._delete_bytes(blob.split(':')[-1])
-            except FileNotFoundError:
-                logging.warn(f'Blob {blob} not found in artifact store')
-
-        # find all files with `&:file:` prefix
-        files = recursive_find(
-            r, lambda v: isinstance(v, str) and v.startswith('&:file:')
-        )
-        for file_path in files:
-            # file: &:file:file_id
+        for artifact_id in artifact_ids:
            try:
-                self._delete_bytes(file_path.split(':')[-1])
+                self._delete_bytes(artifact_id)
            except FileNotFoundError:
-                logging.warn(f'File {file_path} not found in artifact store')
+                logging.warn(f'Blob {artifact_id} not found in artifact store')

@abstractmethod
def get_bytes(self, file_id: str) -> bytes:
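
With the new signature, callers hand `delete_artifact` a list of already-extracted artifact ids rather than a raw component document. A minimal call might look like this (the ids are hypothetical):

```python
# Each id is the part after the '&:blob:' or '&:file:' prefix in the serialized component.
db.artifact_store.delete_artifact(['1a2b3c', 'weights-file-id'])
```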
71 changes: 71 additions & 0 deletions superduper/backends/base/metadata.py
@@ -76,6 +76,77 @@ def create_parent_child(self, parent: str, child: str):
"""
pass

def create_artifact_relation(self, uuid, artifact_ids):
"""
Create a relation between an artifact and a component version.
:param uuid: UUID of component version
        :param artifact_ids: artifact id, or list of artifact ids.
"""
artifact_ids = (
[artifact_ids] if not isinstance(artifact_ids, list) else artifact_ids
)
data = []
for artifact_id in artifact_ids:
data.append({'uuid': uuid, 'artifact_id': artifact_id})

if data:
self._create_data('_artifact_relations', data)

def delete_artifact_relation(self, uuid, artifact_ids):
"""
Delete a relation between an artifact and a component version.
:param uuid: UUID of component version
        :param artifact_ids: artifact id, or list of artifact ids.
"""
artifact_ids = (
[artifact_ids] if not isinstance(artifact_ids, list) else artifact_ids
)
for artifact_id in artifact_ids:
self._delete_data(
'_artifact_relations',
{
'uuid': uuid,
'artifact_id': artifact_id,
},
)

def get_artifact_relations(self, uuid=None, artifact_id=None):
"""
        Get all relations between an artifact and a component version.
        :param uuid: UUID of the component version.
        :param artifact_id: artifact id.
"""
if uuid is None and artifact_id is None:
raise ValueError('Either `uuid` or `artifact_id` must be provided')
elif uuid:
relations = self._get_data(
'_artifact_relations',
{'uuid': uuid},
)
ids = [relation['artifact_id'] for relation in relations]
else:
relations = self._get_data(
'_artifact_relations',
{'artifact_id': artifact_id},
)
ids = [relation['uuid'] for relation in relations]
return ids

# TODO: Refactor to use _create_data, _delete_data, _get_data
@abstractmethod
def _create_data(self, table_name, datas):
pass

@abstractmethod
def _delete_data(self, table_name, filter):
pass

@abstractmethod
def _get_data(self, table_name, filter):
pass

@abstractmethod
def drop(self, force: bool = False):
"""
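
A minimal in-memory sketch of what the three new abstract hooks are expected to do (not part of the commit; each "table" is just a list of dicts and filters are exact matches):

```python
class InMemoryRelationStore:
    """Toy backend illustrating the _create_data/_delete_data/_get_data contract."""

    def __init__(self):
        self._tables = {'_artifact_relations': []}

    def _create_data(self, table_name, datas):
        # Append one row per dict, exactly as create_artifact_relation builds them.
        self._tables[table_name].extend(datas)

    def _delete_data(self, table_name, filter):
        # Drop every row matching all key/value pairs in `filter`.
        self._tables[table_name] = [
            row for row in self._tables[table_name]
            if not all(row.get(k) == v for k, v in filter.items())
        ]

    def _get_data(self, table_name, filter):
        # Return every row matching all key/value pairs in `filter`.
        return [
            row for row in self._tables[table_name]
            if all(row.get(k) == v for k, v in filter.items())
        ]
```

With such a backend mixed into the metadata store, `create_artifact_relation('u1', ['a1'])` stores `{'uuid': 'u1', 'artifact_id': 'a1'}` and `get_artifact_relations(artifact_id='a1')` returns `['u1']`.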
47 changes: 43 additions & 4 deletions superduper/base/datalayer.py
@@ -699,7 +699,7 @@ def _apply(
if children:
serialized = self._change_component_reference_prefix(serialized)

-        serialized = self.artifact_store.save_artifact(serialized)
+        serialized = self._save_artifact(object.uuid, serialized)
if artifacts:
for file_id, bytes in artifacts.items():
self.artifact_store.put_bytes(bytes, file_id)
@@ -786,7 +786,7 @@ def _remove_component_version(
except KeyError:
pass

-        self.artifact_store.delete_artifact(info)
+        self._delete_artifacts(r['uuid'], info)
self.metadata.delete_component_version(type_id, identifier, version=version)

def _get_content_for_filter(self, filter) -> Document:
@@ -848,9 +848,9 @@ def replace(
if children:
serialized = self._change_component_reference_prefix(serialized)

-        self.artifact_store.delete_artifact(info)
+        self._delete_artifacts(object.uuid, info)

-        serialized = self.artifact_store.save_artifact(serialized)
+        serialized = self._save_artifact(object.uuid, serialized)

self.metadata.replace_object(
serialized,
@@ -859,6 +859,45 @@
version=object.version,
)

def _save_artifact(self, uuid, info: t.Dict):
"""
        Save an artifact to the artifact store and record its relation to the component version.
        :param uuid: UUID of the component version.
        :param info: serialized component dictionary containing artifact references.
"""
artifact_ids, _ = self._find_artifacts(info)
self.metadata.create_artifact_relation(uuid, artifact_ids)
return self.artifact_store.save_artifact(info)

def _delete_artifacts(self, uuid, info: t.Dict):
artifact_ids, artifacts = self._find_artifacts(info)
for artifact_id in artifact_ids:
relation_uuids = self.metadata.get_artifact_relations(
artifact_id=artifact_id
)
if len(relation_uuids) == 1 and relation_uuids[0] == uuid:
self.artifact_store.delete_artifact([artifact_id])
self.metadata.delete_artifact_relation(
uuid=uuid, artifact_ids=artifact_id
)

def _find_artifacts(self, info: t.Dict):
from superduper.misc.special_dicts import recursive_find

# find all blobs with `&:blob:` prefix,
blobs = recursive_find(
info, lambda v: isinstance(v, str) and v.startswith('&:blob:')
)

# find all files with `&:file:` prefix
files = recursive_find(
info, lambda v: isinstance(v, str) and v.startswith('&:file:')
)
artifact_ids: list[str] = []
artifact_ids.extend(a.split(":")[-1] for a in blobs)
artifact_ids.extend(a.split(":")[-1] for a in files)
return artifact_ids, {'blobs': blobs, 'files': files}

def select_nearest(
self,
like: t.Union[t.Dict, Document],
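
The heart of the fix is the ownership check in `_delete_artifacts`: an artifact is removed from the artifact store only when the component being deleted is the last one related to it, while the relation row for that component is removed in every case. The condition, condensed into a standalone helper for illustration (the helper name is assumed, not part of the commit):

```python
def is_last_owner(relation_uuids, uuid):
    """True only if `uuid` is the sole component version still referencing the artifact."""
    return len(relation_uuids) == 1 and relation_uuids[0] == uuid

assert is_last_owner(['u1'], 'u1')            # last owner -> safe to delete the blob
assert not is_last_owner(['u1', 'u2'], 'u1')  # shared with another component -> keep it
assert not is_last_owner([], 'u1')            # no recorded relation -> keep it
```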
22 changes: 22 additions & 0 deletions test/unittest/base/test_datalayer.py
@@ -605,6 +605,28 @@ def test_dataset(db):
assert len(dataset.data) == len(list(db.execute(dataset.select)))


def test_delete_component_with_same_artifact(db):
from superduper import ObjectModel

model1 = ObjectModel(
object=lambda x: x + 1,
identifier='model1',
)

model2 = ObjectModel(
object=model1.object,
identifier='model2',
)

db.apply(model1)
db.apply(model2)

db.remove('model', 'model1', force=True)
model2 = db.load('model', 'model2')
model2.init()
assert model2.predict(1) == 2


def test_retry_on_token_expiry(db):
# Mock the methods
db.retry = 1