Skip to content

Commit

Permalink
Fix datatype swap in MongoDB
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Oct 17, 2024
1 parent 8ebd5b6 commit c81d9fe
Show file tree
Hide file tree
Showing 16 changed files with 129 additions and 101 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

**Before you create a Pull Request, remember to update the Changelog with your changes.**

## Changes Since Last Release
## Changes Since Last Release

#### Changed defaults / behaviours

Expand Down Expand Up @@ -43,6 +43,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Change compute init order in cluster initialize
- Add table error exception and sql table length fallback.
- Permissions of artifacts increased
- Make JSON-able a configuration depending on the databackend

#### New Features & Functionality

Expand Down
1 change: 1 addition & 0 deletions plugins/ibis/plugin_test/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ artifact_store: null
data_backend: sqlite://
auto_schema: false
force_apply: true
json_native: false
15 changes: 13 additions & 2 deletions plugins/ibis/superduper_ibis/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,18 @@ def _model_update_impl(
class IbisQuery(Query):
"""A query that can be executed on an Ibis database."""

def __post_init__(self, db=None):
super().__post_init__(db)
self._primary_id = None
self._base_table = None

@property
def base_table(self):
"""Return the base table."""
if self._base_table is None:
self._base_table = self.db.load('table', self.table)
return self._base_table

flavours: t.ClassVar[t.Dict[str, str]] = {
"pre_like": r"^.*\.like\(.*\)\.select",
"post_like": r"^.*\.([a-z]+)\(.*\)\.like(.*)$",
Expand Down Expand Up @@ -234,8 +246,7 @@ def type(self):
@property
def primary_id(self):
"""Return the primary id."""
table = self.db.load('table', self.table)
return table.primary_id
return self.base_table.primary_id

def model_update(
self,
Expand Down
6 changes: 2 additions & 4 deletions plugins/mongodb/superduper_mongodb/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from superduper.backends.base.metadata import MetaDataStoreProxy
from superduper.base.enums import DBType
from superduper.components.datatype import DataType
from superduper.components.schema import FieldType, Schema
from superduper.components.schema import Schema
from superduper.misc.colors import Colors

from superduper_mongodb.artifacts import MongoArtifactStore
Expand Down Expand Up @@ -248,6 +248,4 @@ def create_table_and_schema(self, identifier: str, schema: Schema):
"""
# If the data can be converted to JSON,
# then save it as native data in MongoDB.
for key, datatype in schema.fields.items():
if isinstance(datatype, DataType) and datatype.identifier == "json":
schema.fields[key] = FieldType(identifier="json")
pass
62 changes: 34 additions & 28 deletions superduper/backends/local/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,46 @@ def __init__(self, init_cache: bool = True):
super().__init__()
self.init_cache = init_cache
self._cache: t.Dict = {}
self._cache_uuid: t.Dict = {}
self._component_to_uuid: t.Dict = {}
self._db = None

def list_components(self):
"""List components by (type_id, identifier) in the cache."""
return list(self._cache.keys())
return list(self._component_to_uuid.keys())

def list_uuids(self):
"""List UUIDs in the cache."""
return list(self._cache_uuid.keys())
return list(self._cache.keys())

# TODO which of these is the correct one?
# def __getitem__(self, *item):
# return self._cache[item]
def __getitem__(self, item):
if isinstance(item, tuple):
# (type_id, identifier)
item = self._component_to_uuid[item[0], item[1]]
return self._cache[item]

def _put(self, component: Component):
"""Put a component in the cache."""
self._cache[component.type_id, component.identifier] = component
self._cache_uuid[component.uuid] = component

def __delitem__(self, name: str):
del self._cache[name]
self._cache[component.uuid] = component
if (component.type_id, component.identifier) in self._component_to_uuid:
current = self._component_to_uuid[component.type_id, component.identifier]
current_version = self._cache[current].version
if current_version < component.version:
self._component_to_uuid[
component.type_id, component.identifier
] = component.uuid
else:
self._component_to_uuid[
component.type_id, component.identifier
] = component.uuid

def __delitem__(self, item):
if isinstance(item, tuple):
item = self._component_to_uuid[item[0], item[1]]
tuples = [k for k, v in self._component_to_uuid.items() if v == item]
if tuples:
for type_id, identifier in tuples:
del self._component_to_uuid[type_id, identifier]
del self._cache[item]

def initialize(self):
"""Initialize the cache."""
Expand All @@ -52,6 +70,7 @@ def initialize(self):
def drop(self):
"""Drop the cache."""
self._cache = {}
self._component_to_uuid = {}

@property
def db(self):
Expand All @@ -67,29 +86,16 @@ def db(self, value):
self._db = value
self.initialize()

# def init(self):
# """Initialize the cache."""
# if not self.init_cache:
# return
# for component in self.db.show():
# if 'version' not in component:
# component['version'] = -1

# show = self.db.show(**component)
# uuid = show.get('uuid')
# if show.get('cache', False):
# self._cache[uuid] = self.db.load(uuid=uuid)

def __getitem__(self, uuid: str):
"""Get a component from the cache."""
return self._cache_uuid[uuid]

def __iter__(self):
return iter(self._cache.keys())

def expire(self, item):
"""Expire an item from the cache."""
try:
del self._cache[item]
for (t, i), uuid in self._component_to_uuid.items():
if uuid == item:
del self._component_to_uuid[t, i]
break
except KeyError:
pass
2 changes: 2 additions & 0 deletions superduper/base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class Config(BaseConfig):
:param bytes_encoding: The encoding of bytes in the data backend
:param auto_schema: Whether to automatically create the schema.
If True, the schema will be created if it does not exist.
:param json_native: Whether the databackend supports json natively or not.
:param log_colorize: Whether to colorize the logs
:param output_prefix: The prefix for the output table and output field key
:param vector_search_kwargs: The keyword arguments to pass to the vector search
Expand Down Expand Up @@ -181,6 +182,7 @@ class Config(BaseConfig):

bytes_encoding: BytesEncoding = BytesEncoding.BYTES
auto_schema: bool = True
json_native: bool = True
output_prefix: str = "_outputs__"
vector_search_kwargs: t.Dict = dc.field(default_factory=dict)
rest: RestConfig = dc.field(default_factory=RestConfig)
Expand Down
88 changes: 39 additions & 49 deletions superduper/base/datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,10 +580,10 @@ def load(
type_id: t.Optional[str] = None,
identifier: t.Optional[str] = None,
version: t.Optional[int] = None,
allow_hidden: bool = False,
uuid: t.Optional[str] = None,
huuid: t.Optional[str] = None,
on_load: bool = True,
allow_hidden: bool = False,
) -> Component:
"""
Load a component using uniquely identifying information.
Expand All @@ -600,63 +600,52 @@ def load(
of deprecated components.
:param uuid: [Optional] UUID of the component to load.
"""
if type_id == 'encoder':
logging.warn(
'"encoder" has moved to "datatype" this functionality will not work'
' after version 0.2.0'
)
type_id = 'datatype'
if uuid is None and huuid is None:
if type_id is None or identifier is None:
raise ValueError(
'Must specify `type_id` and `identifier` to load a component '
'when `uuid` is not provided.'
)

if version is not None:
assert type_id is not None
assert identifier is not None
info = self.metadata.get_component(
type_id=type_id,
identifier=identifier,
version=version,
allow_hidden=allow_hidden,
)
else:
if huuid:
uuid = huuid.split(':')[-1]
uuid = info['uuid']

if huuid is not None:
uuid = huuid.split(':')[-1]

if uuid is not None:
try:
assert uuid is not None
uuid = uuid.split('.')[0]
return self.cluster.cache[uuid]
except KeyError:
logging.info(f'Component {uuid} not found in cache, loading from db')
info = self.metadata.get_component_by_uuid(
uuid=uuid,
allow_hidden=allow_hidden,
uuid=uuid, allow_hidden=allow_hidden
)
except FileNotFoundError as e:
if huuid is not None:
raise FileNotFoundError(
f'Could not find {huuid} in metadata.'
) from e
raise e

assert info is not None
type_id = info['type_id']

if info.get('cache', False):
c = Document.decode(info, db=self)
c.db = self
else:
try:
return self.cluster.cache[info['type_id'], info['identifier']]
return self.cluster.cache[type_id, identifier]
except KeyError:
logging.info(
f'Component {info["uuid"]} not found in cache, loading from db'
logging.warn(
f'Component ({type_id}, {identifier}) not found in cache, '
'loading from db'
)
assert type_id is not None
assert identifier is not None
info = self.metadata.get_component(
type_id=type_id,
identifier=identifier,
allow_hidden=allow_hidden,
)
c = Document.decode(info, db=self)
c.db = self

m = Document.decode(info, db=self)
m.db = self
if on_load:
m.on_load(self)

assert type_id is not None
if m.cache:
logging.info(f'Adding component {info["uuid"]} to cache.')
self.cluster.cache.put(m)
return m
if c.cache:
logging.info(f'Adding {c.huuid} to cache')
self.cluster.cache.put(c)
return c

def _add_child_components(self, components, parent, job_events, context):
# TODO this is a bit of a mess
Expand Down Expand Up @@ -707,7 +696,6 @@ def _apply(
self._update_component(object, parent=parent)
return [], []

# object.pre_create(self)
assert hasattr(object, 'identifier')
assert hasattr(object, 'version')

Expand Down Expand Up @@ -810,7 +798,9 @@ def _remove_component_version(
):
# TODO - make this less I/O intensive
component = self.load(
type_id, identifier, version=version, allow_hidden=force
type_id,
identifier,
version=version,
)
info = self.metadata.get_component(
type_id, identifier, version=version, allow_hidden=force
Expand Down Expand Up @@ -893,11 +883,10 @@ def replace(
type_id=object.type_id,
version=object.version,
)
self.expire(object.uuid)
self.expire(old_uuid)

def expire(self, uuid):
"""Expire a component from the cache."""
parents = True
self.cluster.cache.expire(uuid)
parents = self.metadata.get_component_version_parents(uuid)
while parents:
Expand Down Expand Up @@ -969,6 +958,7 @@ def select_nearest(
assert isinstance(like, dict)
like = Document(like)
like = self._get_content_for_filter(like)
logging.info('Getting vector-index')
vi = self.load('vector_index', vector_index)
if outputs is None:
outs: t.Dict = {}
Expand Down
6 changes: 3 additions & 3 deletions superduper/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class Component(Leaf, metaclass=ComponentMeta):
upstream: t.Optional[t.List["Component"]] = None
plugins: t.Optional[t.List["Plugin"]] = None
artifacts: dc.InitVar[t.Optional[t.Dict]] = None
cache: t.Optional[bool] = False
cache: t.Optional[bool] = True
status: t.Optional[Status] = None

@property
Expand Down Expand Up @@ -557,7 +557,7 @@ def declare_component(self, cluster):
:param cluster: The cluster to declare the component to.
"""
if self.cache:
logging.debug(f'Declaring {self.type_id}: {self.identifier} to cache')
logging.info(f'Adding {self.type_id}: {self.identifier} to cache')
cluster.cache.put(self)
cluster.compute.put(self)

Expand Down Expand Up @@ -796,7 +796,7 @@ def decode(cls, r, db: t.Optional[t.Any] = None, reference: bool = False):
assert db is not None
r = r['_content']
assert r['version'] is not None
return db.load(r['type_id'], r['identifier'], r['version'], allow_hidden=True)
return db.load(r['type_id'], r['identifier'], r['version'])

def __setattr__(self, k, v):
if k in dc.fields(self):
Expand Down
1 change: 1 addition & 0 deletions superduper/components/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class FieldType(Leaf):

def __post_init__(self, db):
super().__post_init__(db)

if isinstance(self.identifier, DataType):
self.identifier = self.identifier.name

Expand Down
Loading

0 comments on commit c81d9fe

Please sign in to comment.