Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the bug where shared artifacts are deleted #2446

Merged
merged 2 commits into the base branch from the source branch
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Update job_id after job submission
- Fixed default event.uuid
- Fixed atlas vector search
- Fix the bug where shared artifacts are deleted when removing a component.

#### New Features & Functionality

Expand Down
6 changes: 5 additions & 1 deletion plugins/mongodb/plugin_test/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from superduper_mongodb.metadata import MongoMetaDataStore

DATABASE_URL = CFG.metadata_store or "mongomock://test_db"
DATABASE_URL = CFG.metadata_store or CFG.data_backend or "mongomock://test_db"


@pytest.fixture
Expand All @@ -25,3 +25,7 @@ def test_parent_child(metadata):

def test_job(metadata):
    """Delegate to the shared ``metadata_utils`` job test suite."""
    metadata_utils.test_job(metadata)


def test_artifact_relation(metadata):
    """Delegate to the shared ``metadata_utils`` artifact-relation test."""
    metadata_utils.test_artifact_relation(metadata)
2 changes: 1 addition & 1 deletion plugins/mongodb/superduper_mongodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .metadata import MongoMetaDataStore as MetaDataStore
from .query import MongoQuery

__version__ = "0.0.3"
__version__ = "0.0.4"

__all__ = [
"ArtifactStore",
Expand Down
14 changes: 14 additions & 0 deletions plugins/mongodb/superduper_mongodb/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def _setup(self):
self.component_collection = self.db['_objects']
self.job_collection = self.db['_jobs']
self.parent_child_mappings = self.db['_parent_child_mappings']
self.artifact_relations = self.db['_artifact_relations']

def reconnect(self):
"""Reconnect to metdata store."""
Expand Down Expand Up @@ -69,6 +70,7 @@ def drop(self, force: bool = False):
self.db.drop_collection(self.component_collection.name)
self.db.drop_collection(self.job_collection.name)
self.db.drop_collection(self.parent_child_mappings.name)
self.db.drop_collection(self.artifact_relations.name)

def delete_parent_child(self, parent: str, child: str) -> None:
"""
Expand Down Expand Up @@ -97,6 +99,18 @@ def create_parent_child(self, parent: str, child: str) -> None:
}
)

def _create_data(self, table_name, datas):
    """Insert rows into an internal MongoDB collection.

    :param table_name: name of the collection to write to
    :param datas: list of document dicts to insert
    """
    # Guard: pymongo's ``insert_many`` raises ``InvalidOperation`` when
    # given an empty document list, so skip the call for empty input.
    if not datas:
        return
    collection = self.db[table_name]
    collection.insert_many(datas)

def _delete_data(self, table_name, filter):
    """Remove every document in ``table_name`` matching ``filter``."""
    self.db[table_name].delete_many(filter)

def _get_data(self, table_name, filter):
    """Return all documents in ``table_name`` matching ``filter``."""
    cursor = self.db[table_name].find(filter)
    return [document for document in cursor]

def create_component(self, info: t.Dict) -> InsertOneResult:
"""Create a component in the metadata store.

Expand Down
4 changes: 4 additions & 0 deletions plugins/sqlalchemy/plugin_test/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ def test_parent_child(metadata):

def test_job(metadata):
    """Delegate to the shared ``metadata_utils`` job test suite."""
    metadata_utils.test_job(metadata)


def test_artifact_relation(metadata):
    """Delegate to the shared ``metadata_utils`` artifact-relation test."""
    metadata_utils.test_artifact_relation(metadata)
2 changes: 1 addition & 1 deletion plugins/sqlalchemy/superduper_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .metadata import SQLAlchemyMetadata as MetaDataStore

__version__ = "0.0.2"
__version__ = "0.0.3"

__all__ = ['MetaDataStore']
42 changes: 42 additions & 0 deletions plugins/sqlalchemy/superduper_sqlalchemy/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,44 @@ def _init_tables(self):
*component_table_args,
)

self.artifact_table = Table(
'ARTIFACT_RELATIONS',
metadata,
Column('uuid', type_string),
Column('artifact_id', type_integer),
*component_table_args,
)

self._table_mapping = {
'_artifact_relations': self.artifact_table,
}

metadata.create_all(self.conn)

def _create_data(self, table_name, datas):
    """Insert rows into an internal metadata table.

    :param table_name: logical table name (key of ``self._table_mapping``)
    :param datas: list of row dicts to insert
    """
    if not datas:
        return  # nothing to insert; avoid emitting an empty statement
    table = self._table_mapping[table_name]
    with self.session_context() as session:
        # Single executemany-style bulk INSERT instead of one
        # statement per row.
        session.execute(insert(table), datas)

def _delete_data(self, table_name, filter):
    """Delete every row of ``table_name`` matching the equality ``filter``."""
    table = self._table_mapping[table_name]
    with self.session_context() as session:
        stmt = delete(table)
        # Each filter entry becomes a column == value condition.
        for key, value in filter.items():
            stmt = stmt.where(getattr(table.c, key) == value)
        session.execute(stmt)

def _get_data(self, table_name, filter):
    """Fetch the rows of ``table_name`` matching the equality ``filter``."""
    table = self._table_mapping[table_name]
    with self.session_context() as session:
        stmt = select(table)
        # Each filter entry becomes a column == value condition.
        for key, value in filter.items():
            stmt = stmt.where(getattr(table.c, key) == value)
        return self.query_results(table, stmt, session)

def url(self):
    """Return the URL of the metadata store.

    Built by concatenating the backend connection URL with the
    database name.
    """
    return self.conn.url + self.name
Expand All @@ -142,6 +178,7 @@ def drop(self, force: bool = False):
default=False,
):
logging.warn('Aborting...')

try:
self.job_table.drop(self.conn)
except ProgrammingError as e:
Expand All @@ -157,6 +194,11 @@ def drop(self, force: bool = False):
except ProgrammingError as e:
logging.warn(f'Error dropping component table {e}')

try:
self.artifact_table.drop(self.conn)
except ProgrammingError as e:
logging.warn(f'Error dropping artifact table {e}')

@contextmanager
def session_context(self):
"""Provide a transactional scope around a series of operations."""
Expand Down
26 changes: 4 additions & 22 deletions superduper/backends/base/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,34 +126,16 @@ def save_artifact(self, r: t.Dict):

return r

def delete_artifact(self, artifact_ids: t.List[str]):
    """Delete artifacts from the artifact store.

    Missing artifacts are logged and skipped rather than raised, so
    cleanup of a partially-deleted component still completes.

    :param artifact_ids: ids of the artifacts (blobs or files) to delete
    """
    for artifact_id in artifact_ids:
        try:
            self._delete_bytes(artifact_id)
        except FileNotFoundError:
            logging.warn(f'Blob {artifact_id} not found in artifact store')

@abstractmethod
def get_bytes(self, file_id: str) -> bytes:
Expand Down
71 changes: 71 additions & 0 deletions superduper/backends/base/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,77 @@ def create_parent_child(self, parent: str, child: str):
"""
pass

def create_artifact_relation(self, uuid, artifact_ids):
"""
Create a relation between an artifact and a component version.

:param uuid: UUID of component version
:param artifact: artifact
"""
artifact_ids = (
[artifact_ids] if not isinstance(artifact_ids, list) else artifact_ids
)
data = []
for artifact_id in artifact_ids:
data.append({'uuid': uuid, 'artifact_id': artifact_id})

if data:
self._create_data('_artifact_relations', data)

def delete_artifact_relation(self, uuid, artifact_ids):
"""
Delete a relation between an artifact and a component version.

:param uuid: UUID of component version
:param artifact: artifact
"""
artifact_ids = (
[artifact_ids] if not isinstance(artifact_ids, list) else artifact_ids
)
for artifact_id in artifact_ids:
self._delete_data(
'_artifact_relations',
{
'uuid': uuid,
'artifact_id': artifact_id,
},
)

def get_artifact_relations(self, uuid=None, artifact_id=None):
"""
Get all relations between an artifact and a component version.

:param artifact_id: artifact
"""
if uuid is None and artifact_id is None:
raise ValueError('Either `uuid` or `artifact_id` must be provided')
elif uuid:
relations = self._get_data(
'_artifact_relations',
{'uuid': uuid},
)
ids = [relation['artifact_id'] for relation in relations]
else:
relations = self._get_data(
'_artifact_relations',
{'artifact_id': artifact_id},
)
ids = [relation['uuid'] for relation in relations]
return ids

# TODO: Refactor to use _create_data, _delete_data, _get_data
@abstractmethod
def _create_data(self, table_name, datas):
    """Insert ``datas`` (a list of row dicts) into ``table_name``."""
    pass

@abstractmethod
def _delete_data(self, table_name, filter):
    """Delete the rows of ``table_name`` matching the equality ``filter``."""
    pass

@abstractmethod
def _get_data(self, table_name, filter):
    """Return the rows of ``table_name`` matching the equality ``filter``."""
    pass

@abstractmethod
def drop(self, force: bool = False):
"""
Expand Down
47 changes: 43 additions & 4 deletions superduper/base/datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ def _apply(
if children:
serialized = self._change_component_reference_prefix(serialized)

serialized = self.artifact_store.save_artifact(serialized)
serialized = self._save_artifact(object.uuid, serialized)
if artifacts:
for file_id, bytes in artifacts.items():
self.artifact_store.put_bytes(bytes, file_id)
Expand Down Expand Up @@ -786,7 +786,7 @@ def _remove_component_version(
except KeyError:
pass

self.artifact_store.delete_artifact(info)
self._delete_artifacts(r['uuid'], info)
self.metadata.delete_component_version(type_id, identifier, version=version)

def _get_content_for_filter(self, filter) -> Document:
Expand Down Expand Up @@ -848,9 +848,9 @@ def replace(
if children:
serialized = self._change_component_reference_prefix(serialized)

self.artifact_store.delete_artifact(info)
self._delete_artifacts(object.uuid, info)

serialized = self.artifact_store.save_artifact(serialized)
serialized = self._save_artifact(object.uuid, serialized)

self.metadata.replace_object(
serialized,
Expand All @@ -859,6 +859,45 @@ def replace(
version=object.version,
)

def _save_artifact(self, uuid, info: t.Dict):
    """Save the artifacts of ``info`` and record their relations.

    Records a relation row for every ``&:blob:``/``&:file:`` reference
    found in ``info`` before delegating storage to the artifact store.

    :param uuid: UUID of the component version owning the artifacts
    :param info: serialized component dict possibly containing
        artifact references
    :return: the serialized dict returned by the artifact store
    """
    artifact_ids, _ = self._find_artifacts(info)
    self.metadata.create_artifact_relation(uuid, artifact_ids)
    return self.artifact_store.save_artifact(info)

def _delete_artifacts(self, uuid, info: t.Dict):
    """Delete the artifacts of ``info`` that no other component uses.

    For each artifact referenced by ``info``, the stored bytes are
    removed only when ``uuid`` is its sole remaining relation; the
    relation for ``uuid`` is removed in every case, so artifacts shared
    with other component versions survive this removal.

    :param uuid: UUID of the component version being removed
    :param info: serialized component dict containing artifact references
    """
    artifact_ids, artifacts = self._find_artifacts(info)
    for artifact_id in artifact_ids:
        relation_uuids = self.metadata.get_artifact_relations(
            artifact_id=artifact_id
        )
        # Only delete the bytes when this component is the last user.
        if len(relation_uuids) == 1 and relation_uuids[0] == uuid:
            self.artifact_store.delete_artifact([artifact_id])
        self.metadata.delete_artifact_relation(
            uuid=uuid, artifact_ids=artifact_id
        )

def _find_artifacts(self, info: t.Dict):
    """Collect all artifact references contained in ``info``.

    :param info: serialized component dict to scan
    :return: tuple ``(artifact_ids, refs)`` where ``artifact_ids`` is the
        list of bare ids and ``refs`` maps ``'blobs'``/``'files'`` to the
        raw prefixed references found
    """
    from superduper.misc.special_dicts import recursive_find

    def _refs_with_prefix(prefix):
        # All string leaves of ``info`` starting with the given prefix.
        return recursive_find(
            info, lambda v: isinstance(v, str) and v.startswith(prefix)
        )

    blob_refs = _refs_with_prefix('&:blob:')
    file_refs = _refs_with_prefix('&:file:')

    # The artifact id is the final ':'-separated component of each ref.
    artifact_ids: list[str] = [
        ref.split(":")[-1] for ref in (*blob_refs, *file_refs)
    ]
    return artifact_ids, {'blobs': blob_refs, 'files': file_refs}

def select_nearest(
self,
like: t.Union[t.Dict, Document],
Expand Down
Empty file.
22 changes: 22 additions & 0 deletions test/unittest/base/test_datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,28 @@ def test_dataset(db):
assert len(dataset.data) == len(list(db.execute(dataset.select)))


def test_delete_componet_with_same_artifact(db):
    """Removing one model must not delete an artifact shared by another.

    Two ObjectModels share the same underlying callable; after removing
    ``model1``, ``model2`` must still load and predict correctly,
    proving the shared artifact was preserved.
    """
    # NOTE(review): "componet" in the test name is a typo for "component".
    from superduper import ObjectModel

    model1 = ObjectModel(
        object=lambda x: x + 1,
        identifier='model1',
    )

    # model2 reuses model1's callable, so both serialize to the same artifact.
    model2 = ObjectModel(
        object=model1.object,
        identifier='model2',
    )

    db.apply(model1)
    db.apply(model2)

    db.remove('model', 'model1', force=True)
    model2 = db.load('model', 'model2')
    model2.init()
    assert model2.predict(1) == 2


def test_retry_on_token_expiry(db):
# Mock the methods
db.retry = 1
Expand Down
Loading
Loading