Fix bytesencoding for base64 #2645

Merged: 1 commit, Nov 22, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion CHANGELOG.md

@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 #### Changed defaults / behaviours

-#### New Features & Functionality
+#### New Features & Functionality

 - Streamlit component and server
 - Graceful updates when making incremental changes
3 changes: 3 additions & 0 deletions plugins/ibis/superduper_ibis/data_backend.py

@@ -69,6 +69,9 @@ def __init__(self, uri: str, flavour: t.Optional[str] = None):
         self.overwrite = False
         self._setup(conn)

+        if uri.startswith('snowflake://') or uri.startswith('sqlite://'):
+            self.bytes_encoding = 'base64'
+
         self.datatype_presets = {'vector': 'superduper.ext.numpy.encoder.Array'}

     def _setup(self, conn):
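Note: Snowflake and SQLite connections now switch the backend to base64, since those stores do not handle raw-bytes columns cleanly. Below is a minimal sketch of what the 'base64' setting implies for stored values; the helper names are illustrative, not the plugin's API, which applies this conversion inside its datatype encode/decode path.

```python
import base64
import typing as t

def to_storage(blob: bytes, bytes_encoding: str) -> t.Union[str, bytes]:
    """Encode raw bytes for the backend; 'base64' yields a plain string."""
    if bytes_encoding == 'base64':
        return base64.b64encode(blob).decode('utf-8')
    return blob

def from_storage(stored: t.Union[str, bytes], bytes_encoding: str) -> bytes:
    """Invert to_storage."""
    if bytes_encoding == 'base64':
        assert isinstance(stored, str)
        return base64.b64decode(stored)
    assert isinstance(stored, bytes)
    return stored

# Round trip: arbitrary bytes survive the string-typed column.
assert from_storage(to_storage(b'\x00\xff', 'base64'), 'base64') == b'\x00\xff'
```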
2 changes: 1 addition & 1 deletion plugins/ibis/superduper_ibis/query.py

@@ -161,7 +161,7 @@ def _get_schema(self):
         )
         fields.update(to_update)

-        return Schema(f"_tmp:{self.table}", fields=fields)
+        return Schema(f"_tmp:{self.table}", fields=fields, db=self.db)

     def renamings(self, r={}):
         """Return the renamings.
11 changes: 7 additions & 4 deletions plugins/ibis/superduper_ibis/utils.py

@@ -1,5 +1,4 @@
 from ibis.expr.datatypes import dtype
-from superduper import CFG
 from superduper.components.datatype import (
     Artifact,
     BaseDataType,

@@ -42,9 +41,13 @@ def convert_schema_to_fields(schema: Schema):
         elif not isinstance(v, BaseDataType):
             fields[k] = v.identifier
         else:
-            if v.encodable_cls in SPECIAL_ENCODABLES_FIELDS:
-                fields[k] = dtype(SPECIAL_ENCODABLES_FIELDS[v.encodable_cls])
+            if v.encodable == 'encodable':
+                fields[k] = dtype(
+                    'str'
+                    if schema.db.databackend.bytes_encoding == 'base64'
+                    else 'bytes'
+                )
             else:
-                fields[k] = CFG.bytes_encoding
+                fields[k] = dtype('str')

     return fields
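The schema now asks the data backend (reachable through the `db` handle that `Schema` carries after this PR) which column type to use, rather than reading the global `CFG.bytes_encoding`. The decision reduces to the small function sketched below; the function name is hypothetical, and the "inline payload vs. reference" split is assumed from the surrounding code.

```python
def column_type(encodable: str, bytes_encoding: str) -> str:
    """Sketch of the column-type decision in convert_schema_to_fields."""
    if encodable == 'encodable':
        # Inline payload: a base64-encoding backend needs a string column.
        return 'str' if bytes_encoding == 'base64' else 'bytes'
    # Anything else (artifacts, files) is stored as a reference string.
    return 'str'

assert column_type('encodable', 'base64') == 'str'
assert column_type('encodable', 'bytes') == 'bytes'
assert column_type('artifact', 'bytes') == 'str'
```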
2 changes: 1 addition & 1 deletion plugins/mongodb/superduper_mongodb/query.py

@@ -676,7 +676,7 @@ def _get_schema(self):

         fields = {k: v for k, v in fields.items() if isinstance(v, BaseDataType)}

-        return Schema(f"_tmp:{self.table}", fields=fields)
+        return Schema(f"_tmp:{self.table}", fields=fields, db=self.db)

     def _get_project(self):
         find_params, _ = self._get_method_parameters('find')
2 changes: 0 additions & 2 deletions plugins/openai/plugin_test/test_model_openai.py

@@ -93,7 +93,6 @@ def test_embed():
     e = OpenAIEmbedding(identifier='text-embedding-ada-002')
     resp = e.predict('Hello, world!')

-    assert len(resp) == e.shape[0]
     assert str(resp.dtype) == 'float32'


@@ -103,7 +102,6 @@ def test_batch_embed():
     resp = e.predict_batches(['Hello', 'world!'])

     assert len(resp) == 2
-    assert all(len(x) == e.shape[0] for x in resp)
     assert all(str(x.dtype) == 'float32' for x in resp)
18 changes: 0 additions & 18 deletions plugins/openai/superduper_openai/model.py

@@ -18,7 +18,6 @@
 from superduper.backends.query_dataset import QueryDataset
 from superduper.base.datalayer import Datalayer
 from superduper.components.model import APIBaseModel, Inputs
-from superduper.components.vector_index import vector
 from superduper.misc.compat import cache
 from superduper.misc.retry import Retry

@@ -107,9 +106,6 @@ class OpenAIEmbedding(_OpenAI):

     """

-    shapes: t.ClassVar[t.Dict] = {'text-embedding-ada-002': (1536,)}
-
-    shape: t.Optional[t.Sequence[int]] = None
     signature: str = 'singleton'
     batch_size: int = 100

@@ -118,20 +114,6 @@ def inputs(self):
         """The inputs of the model."""
         return Inputs(['input'])

-    def __post_init__(self, db, artifacts, example):
-        super().__post_init__(db, artifacts, example)
-        if self.shape is None:
-            self.shape = self.shapes[self.model]
-
-    def _pre_create(self, db: Datalayer) -> None:
-        """Pre creates the model.
-
-        the datatype is set to ``vector``.
-
-        :param db: The datalayer instance.
-        """
-        self.datatype = self.datatype or vector(shape=self.shape)
-
     @retry
     def predict(self, X: str):
         """Generates embeddings from text.
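With the hard-coded `shapes` table and `_pre_create` hook removed, `OpenAIEmbedding` no longer pins `text-embedding-ada-002` to `(1536,)`; the vector datatype is instead expected to come from datatype presets, which the ibis backend above registers on the connection. A sketch of the configuration side, where assigning the preset at runtime is an assumption for illustration:

```python
# Sketch only: shows where the vector datatype now comes from. The preset
# value is taken from the ibis backend change above.
from superduper import CFG

CFG.datatype_presets.vector = 'superduper.ext.numpy.encoder.Array'
print(CFG.datatype_presets.vector)
```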
1 change: 1 addition & 0 deletions superduper/backends/base/data_backend.py

@@ -27,6 +27,7 @@ def __init__(self, uri: str, flavour: t.Optional[str] = None):
         self.in_memory_tables: t.Dict = {}
         self._datalayer = None
         self.uri = uri
+        self.bytes_encoding = 'bytes'

     @property
     def type(self):
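Every backend now owns a `bytes_encoding` attribute defaulting to raw `'bytes'`, and concrete backends override it per connection, as the ibis plugin does above. A self-contained sketch of the pattern, with hypothetical class names:

```python
class BaseDataBackendSketch:
    """Hypothetical mini-version of the base class."""

    def __init__(self, uri: str):
        self.uri = uri
        self.bytes_encoding = 'bytes'  # default introduced in this PR

class SQLBackendSketch(BaseDataBackendSketch):
    """Concrete backend opting into base64 per URI scheme."""

    def __init__(self, uri: str):
        super().__init__(uri)
        if uri.startswith('snowflake://') or uri.startswith('sqlite://'):
            self.bytes_encoding = 'base64'

assert SQLBackendSketch('sqlite://test.db').bytes_encoding == 'base64'
assert SQLBackendSketch('postgres://host/db').bytes_encoding == 'bytes'
```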
4 changes: 2 additions & 2 deletions superduper/base/config.py

@@ -161,11 +161,11 @@ class Config(BaseConfig):
     :param logging_type: The type of logging to use
     :param force_apply: Whether to force apply the configuration
     :param datatype_presets: Presets to be applied for default types of data
-    :param bytes_encoding: The encoding of bytes in the data backend
     :param auto_schema: Whether to automatically create the schema.
         If True, the schema will be created if it does not exist.
     :param json_native: Whether the databackend supports json natively or not.
     :param log_colorize: Whether to colorize the logs
+    :param bytes_encoding: (Deprecated)
     :param output_prefix: The prefix for the output table and output field key
     :param vector_search_kwargs: The keyword arguments to pass to the vector search
     :param rest: Settings for rest server.

@@ -190,12 +190,12 @@ class Config(BaseConfig):
     log_level: LogLevel = LogLevel.INFO
     logging_type: LogType = LogType.SYSTEM
     log_colorize: bool = True
+    bytes_encoding: str = 'bytes'

     force_apply: bool = False

     datatype_presets: DataTypePresets = dc.field(default_factory=DataTypePresets)

-    bytes_encoding: BytesEncoding = BytesEncoding.BYTES
     auto_schema: bool = True
     json_native: bool = True
     output_prefix: str = "_outputs__"
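The global config setting survives only as a deprecated plain string; the enum-typed field is removed, and the effective value now lives on the backend instance. A sketch of reading both, where treating these attribute reads as stable API is an assumption:

```python
# CFG itself exists (it was previously imported in superduper_ibis/utils.py).
from superduper import CFG

print(CFG.bytes_encoding)  # 'bytes': deprecated, no longer consulted by plugins
# The per-connection truth is now on the backend:
# db.databackend.bytes_encoding -> 'base64' for snowflake:// and sqlite:// URIs
```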
57 changes: 11 additions & 46 deletions superduper/base/document.py

@@ -14,7 +14,6 @@
 from superduper.base.variables import _replace_variables
 from superduper.components.component import Component
 from superduper.components.datatype import (
-    BaseDataType,
     Blob,
     Encodable,
     FileItem,

@@ -52,7 +51,7 @@ def _build_blob_getter(base_getter):
     return partial(_blob_getter, getter=base_getter)


-class _Getters:
+class Getters:
     """A class to manage getters for decoding documents.

     We have will have a list of getters for each type of reference.

@@ -258,7 +257,7 @@ def decode(
     r,
     schema: t.Optional['Schema'] = None,
     db: t.Optional['Datalayer'] = None,
-    getters: t.Union[_Getters, t.Dict[str, t.Callable], None] = None,
+    getters: t.Union[Getters, t.Dict[str, t.Callable], None] = None,
 ):
     """Converts any dictionary into a Document or a Leaf.

@@ -291,9 +290,9 @@ def decode(
             identifier = k
             builds[k]['identifier'] = identifier

-    if not isinstance(getters, _Getters):
-        getters = _Getters(getters)
-    assert isinstance(getters, _Getters)
+    if not isinstance(getters, Getters):
+        getters = Getters(getters)
+    assert isinstance(getters, Getters)

     # Prioritize using the local artifact storage getter,
     # and then use the DB read getter.

@@ -314,7 +313,7 @@

     if schema is not None:
         schema.init()
-        r = _schema_decode(schema, r, getters)
+        r = schema.decode_data(r, getters)

     r = _deep_flat_decode(
         {k: v for k, v in r.items() if k not in (KEY_BUILDS, KEY_BLOBS, KEY_FILES)},

@@ -586,42 +585,6 @@ def _deep_flat_encode(
     return r


-def _schema_decode(
-    schema, data: dict[str, t.Any], getters: _Getters
-) -> dict[str, t.Any]:
-    """Decode data using the schema's encoders.
-
-    :param data: Data to decode.
-    """
-    if schema.trivial:
-        return data
-    decoded = {}
-    for k, value in data.items():
-        field = schema.fields.get(k)
-        if not isinstance(field, BaseDataType):
-            decoded[k] = value
-            continue
-
-        value = data[k]
-        if reference := parse_reference(value):
-            value = getters.run(reference.name, reference.path)
-            if reference.name == 'blob':
-                kwargs = {'blob': value}
-            elif reference.name == 'file':
-                kwargs = {'x': value}
-            else:
-                assert False, f'Unknown reference type {reference.name}'
-            encodable = field.encodable_cls(datatype=field, **kwargs)
-            if not field.encodable_cls.lazy:
-                encodable = encodable.unpack()
-            decoded[k] = encodable
-        else:
-            decoded[k] = field.decode_data(data[k])
-
-    decoded.pop(KEY_SCHEMA, None)
-    return decoded
-
-
 def _get_leaf_from_cache(k, builds, getters, db: t.Optional['Datalayer'] = None):
     if reference := parse_reference(f'?{k}'):
         if reference.name in getters:

@@ -672,7 +635,7 @@ def _get_leaf_from_cache(k, builds, getters, db: t.Optional['Datalayer'] = None)
     return leaf


-def _deep_flat_decode(r, builds, getters: _Getters, db: t.Optional['Datalayer'] = None):
+def _deep_flat_decode(r, builds, getters: Getters, db: t.Optional['Datalayer'] = None):
     if isinstance(r, Leaf):
         return r
     if isinstance(r, (list, tuple)):

@@ -722,7 +685,8 @@ def _deep_flat_decode(r, builds, getters: _Getters, db: t.Optional['Datalayer']
     if isinstance(r, str) and r.startswith('&'):
         assert getters is not None
         reference = parse_reference(r)
-        return getters.run(reference.name, reference.path)
+        out = getters.run(reference.name, reference.path)
+        return out
     return r


@@ -741,7 +705,8 @@ def _get_component(db, path):
         return db.load(type_id=parts[0], identifier=parts[1])
     if len(parts) == 3:
         if not _check_if_version(parts[2]):
-            return db.load(uuid=parts[2])
+            out = db.load(uuid=parts[2])
+            return out
         return db.load(type_id=parts[0], identifier=parts[1], version=parts[2])
     raise ValueError(f'Invalid component reference: {path}')
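`_Getters` becomes the public `Getters`, and per-schema decoding moves out of `document.py` into `Schema.decode_data`. The `Getters` contract, as inferred from the call sites in `decode()`, is sketched standalone below; beyond "construct from an optional dict, resolve with `run(name, path)`", the internals are assumptions.

```python
import typing as t

class GettersSketch:
    """Standalone sketch of the Getters contract used by Document.decode."""

    def __init__(self, getters: t.Optional[t.Dict[str, t.Callable]] = None):
        # One list of getters per reference type ('blob', 'file', ...).
        self._getters: t.Dict[str, t.List[t.Callable]] = {}
        for name, fn in (getters or {}).items():
            self._getters.setdefault(name, []).append(fn)

    def run(self, name: str, path: str):
        # Try each registered getter for this reference type in order.
        for fn in self._getters.get(name, []):
            value = fn(path)
            if value is not None:
                return value
        return None

g = GettersSketch({'blob': lambda path: b'payload for ' + path.encode()})
assert g.run('blob', 'my-blob') == b'payload for my-blob'
```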
1 change: 1 addition & 0 deletions superduper/components/component.py

@@ -819,6 +819,7 @@ def _convert_components_to_refs(r):
         r['status'] = str(self.status)
         return Document(r)

+    # TODO needed? looks to have legacy "_content"
     @classmethod
     def decode(cls, r, db: t.Optional[t.Any] = None, reference: bool = False):
         """Decodes a dictionary component into a `Component` instance.
34 changes: 24 additions & 10 deletions superduper/components/datatype.py

@@ -219,6 +219,7 @@ class NativeVector(BaseDataType):
     :param dtype: Datatype of array to encode.
     """

+    encodable: t.ClassVar[str] = 'native'
     dtype: str = 'float64'

     def __post_init__(self, db, artifacts):

@@ -234,6 +235,22 @@ def decode_data(self, item, info=None):
         return numpy.array(item).astype(self.dtype)


+class Json2Str(BaseDataType):
+    """Datatype for encoding JSON-able data as strings supported natively by the databackend."""
+
+    encodable: t.ClassVar[str] = 'native'
+
+    def __post_init__(self, db, artifacts):
+        # self.encodable_cls = Native
+        return super().__post_init__(db, artifacts)
+
+    def encode_data(self, item, info=None):
+        return json.dumps(item)
+
+    def decode_data(self, item, info=None):
+        return json.loads(item)
+
+
 class DataType(BaseDataType):
     """A data type component that defines how data is encoded and decoded.

@@ -324,7 +341,7 @@ def encode_data(self, item, info: t.Optional[t.Dict] = None):
         """
         info = info or {}
         data = self.encoder(item, info) if self.encoder else item
-        data = self.bytes_encoding_after_encode(data)
+        # data = self.bytes_encoding_after_encode(data)
         return data

     @ensure_initialized

@@ -335,7 +352,7 @@ def decode_data(self, item, info: t.Optional[t.Dict] = None):
         :param info: The optional information dictionary.
         """
         info = info or {}
-        item = self.bytes_encoding_before_decode(item)
+        # item = self.bytes_encoding_before_decode(item)
         return self.decoder(item, info=info) if self.decoder else item

     def bytes_encoding_after_encode(self, data):

@@ -789,14 +806,7 @@ def get_serializer(
     )


-json_serializer = DataType(
-    'json',
-    encoder=json_encode,
-    decoder=json_decode,
-    encodable='encodable',
-    bytes_encoding=BytesEncoding.BASE64,
-    intermediate_type=IntermediateType.STRING,
-)
+json_serializer = Json2Str('json')


 pickle_encoder = get_serializer(

@@ -876,6 +886,10 @@ def __post_init__(self, db, artifacts):
     def encodable_cls(self):
         return self.datatype_impl.encodable_cls

+    @property
+    def encodable(self):
+        return self.datatype_impl.encodable
+
     @cached_property
     def datatype_impl(self):
         if isinstance(CFG.datatype_presets.vector, str):