Fix bytes encoding for base64
blythed committed Nov 22, 2024
1 parent 9f20b61 commit a392ee2
Showing 21 changed files with 199 additions and 182 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 #### Changed defaults / behaviours
 
-#### New Features & Functionality
+#### New Features & Functionality
 
 - Streamlit component and server
 - Graceful updates when making incremental changes
3 changes: 3 additions & 0 deletions plugins/ibis/superduper_ibis/data_backend.py
@@ -69,6 +69,9 @@ def __init__(self, uri: str, flavour: t.Optional[str] = None):
         self.overwrite = False
         self._setup(conn)
 
+        if uri.startswith('snowflake://') or uri.startswith('sqlite://'):
+            self.bytes_encoding = 'base64'
+
         self.datatype_presets = {'vector': 'superduper.ext.numpy.encoder.Array'}
 
     def _setup(self, conn):
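The new branch switches Snowflake and SQLite to base64, treating both as string-only stores for inline binary data. A minimal sketch of the round trip this implies, using only the standard library (illustrative; not the exact superduper code path):

```python
import base64

def encode_bytes(raw: bytes) -> str:
    # bytes -> ASCII-safe string that fits in a str column
    return base64.b64encode(raw).decode('utf-8')

def decode_bytes(stored: str) -> bytes:
    # str column value -> original bytes
    return base64.b64decode(stored.encode('utf-8'))

blob = b'\x00\x01\xffbinary-payload'
assert decode_bytes(encode_bytes(blob)) == blob
```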
2 changes: 1 addition & 1 deletion plugins/ibis/superduper_ibis/query.py
@@ -161,7 +161,7 @@ def _get_schema(self):
         )
         fields.update(to_update)
 
-        return Schema(f"_tmp:{self.table}", fields=fields)
+        return Schema(f"_tmp:{self.table}", fields=fields, db=self.db)
 
     def renamings(self, r={}):
         """Return the renamings.
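Threading `db` into the temporary schema matters because the schema is later asked about its backend; a sketch of the failure this avoids (attribute chain taken from `utils.py` below, class name hypothetical):

```python
# Without db, the new lookup in convert_schema_to_fields would evaluate
# schema.db.databackend.bytes_encoding with schema.db = None:
class SchemaWithoutDb:
    db = None

try:
    SchemaWithoutDb().db.databackend  # AttributeError on NoneType
except AttributeError as exc:
    print(f'schema decoding would fail early: {exc}')
```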
11 changes: 7 additions & 4 deletions plugins/ibis/superduper_ibis/utils.py
@@ -1,5 +1,4 @@
 from ibis.expr.datatypes import dtype
-from superduper import CFG
 from superduper.components.datatype import (
     Artifact,
     BaseDataType,
@@ -42,9 +41,13 @@ def convert_schema_to_fields(schema: Schema):
         elif not isinstance(v, BaseDataType):
             fields[k] = v.identifier
         else:
-            if v.encodable_cls in SPECIAL_ENCODABLES_FIELDS:
-                fields[k] = dtype(SPECIAL_ENCODABLES_FIELDS[v.encodable_cls])
+            if v.encodable == 'encodable':
+                fields[k] = dtype(
+                    'str'
+                    if schema.db.databackend.bytes_encoding == 'base64'
+                    else 'bytes'
+                )
             else:
-                fields[k] = CFG.bytes_encoding
+                fields[k] = dtype('str')
 
     return fields
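Together with the `db=self.db` change above, `convert_schema_to_fields` can now choose a column type per backend. A hedged sketch of the selection rule (the helper name is hypothetical, written only to mirror the branch above):

```python
from ibis.expr.datatypes import dtype

def column_dtype_for(encodable: str, bytes_encoding: str):
    # 'encodable' values are stored inline: as a base64 string on
    # string-only backends, as raw bytes everywhere else.
    if encodable == 'encodable':
        return dtype('str' if bytes_encoding == 'base64' else 'bytes')
    # everything else is stored as a string reference
    return dtype('str')

assert column_dtype_for('encodable', 'base64') == dtype('str')
assert column_dtype_for('encodable', 'bytes') == dtype('bytes')
assert column_dtype_for('artifact', 'base64') == dtype('str')
```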
2 changes: 1 addition & 1 deletion plugins/mongodb/superduper_mongodb/query.py
@@ -676,7 +676,7 @@ def _get_schema(self):

         fields = {k: v for k, v in fields.items() if isinstance(v, BaseDataType)}
 
-        return Schema(f"_tmp:{self.table}", fields=fields)
+        return Schema(f"_tmp:{self.table}", fields=fields, db=self.db)
 
     def _get_project(self):
         find_params, _ = self._get_method_parameters('find')
2 changes: 0 additions & 2 deletions plugins/openai/plugin_test/test_model_openai.py
@@ -93,7 +93,6 @@ def test_embed():
     e = OpenAIEmbedding(identifier='text-embedding-ada-002')
     resp = e.predict('Hello, world!')
 
-    assert len(resp) == e.shape[0]
     assert str(resp.dtype) == 'float32'
 
 
@@ -103,7 +102,6 @@ def test_batch_embed():
     resp = e.predict_batches(['Hello', 'world!'])
 
     assert len(resp) == 2
-    assert all(len(x) == e.shape[0] for x in resp)
     assert all(str(x.dtype) == 'float32' for x in resp)
 
 
18 changes: 0 additions & 18 deletions plugins/openai/superduper_openai/model.py
@@ -18,7 +18,6 @@
 from superduper.backends.query_dataset import QueryDataset
 from superduper.base.datalayer import Datalayer
 from superduper.components.model import APIBaseModel, Inputs
-from superduper.components.vector_index import vector
 from superduper.misc.compat import cache
 from superduper.misc.retry import Retry
 
@@ -107,9 +106,6 @@ class OpenAIEmbedding(_OpenAI):
"""

shapes: t.ClassVar[t.Dict] = {'text-embedding-ada-002': (1536,)}

shape: t.Optional[t.Sequence[int]] = None
signature: str = 'singleton'
batch_size: int = 100

@@ -118,20 +114,6 @@ def inputs(self):
"""The inputs of the model."""
return Inputs(['input'])

def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example)
if self.shape is None:
self.shape = self.shapes[self.model]

def _pre_create(self, db: Datalayer) -> None:
"""Pre creates the model.
the datatype is set to ``vector``.
:param db: The datalayer instance.
"""
self.datatype = self.datatype or vector(shape=self.shape)

@retry
def predict(self, X: str):
"""Generates embeddings from text.
1 change: 1 addition & 0 deletions superduper/backends/base/data_backend.py
@@ -27,6 +27,7 @@ def __init__(self, uri: str, flavour: t.Optional[str] = None):
         self.in_memory_tables: t.Dict = {}
         self._datalayer = None
         self.uri = uri
+        self.bytes_encoding = 'bytes'
 
     @property
     def type(self):
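With this default in place, the effective `bytes_encoding` lives on the data backend instance rather than on the global `Config` (whose field is deprecated in the next file). A toy illustration of the resulting behaviour, assuming subclasses only override the attribute as the ibis backend does above:

```python
class ToyBackend:
    """Stand-in for BaseDataBackend; illustrative only."""

    def __init__(self, uri: str):
        self.uri = uri
        self.bytes_encoding = 'bytes'  # default, as in __init__ above
        # mirrors the ibis backend override earlier in this commit
        if uri.startswith('snowflake://') or uri.startswith('sqlite://'):
            self.bytes_encoding = 'base64'

assert ToyBackend('sqlite://test.db').bytes_encoding == 'base64'
assert ToyBackend('mongodb://localhost:27017').bytes_encoding == 'bytes'
```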
4 changes: 2 additions & 2 deletions superduper/base/config.py
@@ -161,11 +161,11 @@ class Config(BaseConfig):
     :param logging_type: The type of logging to use
     :param force_apply: Whether to force apply the configuration
     :param datatype_presets: Presets to be applied for default types of data
-    :param bytes_encoding: The encoding of bytes in the data backend
     :param auto_schema: Whether to automatically create the schema.
         If True, the schema will be created if it does not exist.
     :param json_native: Whether the databackend supports json natively or not.
     :param log_colorize: Whether to colorize the logs
+    :param bytes_encoding: (Deprecated)
     :param output_prefix: The prefix for the output table and output field key
     :param vector_search_kwargs: The keyword arguments to pass to the vector search
     :param rest: Settings for rest server.
@@ -190,12 +190,12 @@
     log_level: LogLevel = LogLevel.INFO
     logging_type: LogType = LogType.SYSTEM
     log_colorize: bool = True
+    bytes_encoding: str = 'bytes'
 
     force_apply: bool = False
 
     datatype_presets: DataTypePresets = dc.field(default_factory=DataTypePresets)
 
-    bytes_encoding: BytesEncoding = BytesEncoding.BYTES
     auto_schema: bool = True
     json_native: bool = True
     output_prefix: str = "_outputs__"
57 changes: 11 additions & 46 deletions superduper/base/document.py
@@ -14,7 +14,6 @@
 from superduper.base.variables import _replace_variables
 from superduper.components.component import Component
 from superduper.components.datatype import (
-    BaseDataType,
     Blob,
     Encodable,
     FileItem,
@@ -52,7 +51,7 @@ def _build_blob_getter(base_getter):
     return partial(_blob_getter, getter=base_getter)
 
 
-class _Getters:
+class Getters:
     """A class to manage getters for decoding documents.
 
     We will have a list of getters for each type of reference.
@@ -258,7 +257,7 @@ def decode(
         r,
         schema: t.Optional['Schema'] = None,
         db: t.Optional['Datalayer'] = None,
-        getters: t.Union[_Getters, t.Dict[str, t.Callable], None] = None,
+        getters: t.Union[Getters, t.Dict[str, t.Callable], None] = None,
     ):
         """Converts any dictionary into a Document or a Leaf.
@@ -291,9 +290,9 @@ def decode(
             identifier = k
             builds[k]['identifier'] = identifier
 
-        if not isinstance(getters, _Getters):
-            getters = _Getters(getters)
-        assert isinstance(getters, _Getters)
+        if not isinstance(getters, Getters):
+            getters = Getters(getters)
+        assert isinstance(getters, Getters)
 
         # Prioritize using the local artifact storage getter,
         # and then use the DB read getter.
@@ -314,7 +313,7 @@

         if schema is not None:
             schema.init()
-            r = _schema_decode(schema, r, getters)
+            r = schema.decode_data(r, getters)
 
         r = _deep_flat_decode(
             {k: v for k, v in r.items() if k not in (KEY_BUILDS, KEY_BLOBS, KEY_FILES)},
@@ -586,42 +585,6 @@ def _deep_flat_encode(
     return r
 
 
-def _schema_decode(
-    schema, data: dict[str, t.Any], getters: _Getters
-) -> dict[str, t.Any]:
-    """Decode data using the schema's encoders.
-
-    :param data: Data to decode.
-    """
-    if schema.trivial:
-        return data
-    decoded = {}
-    for k, value in data.items():
-        field = schema.fields.get(k)
-        if not isinstance(field, BaseDataType):
-            decoded[k] = value
-            continue
-
-        value = data[k]
-        if reference := parse_reference(value):
-            value = getters.run(reference.name, reference.path)
-            if reference.name == 'blob':
-                kwargs = {'blob': value}
-            elif reference.name == 'file':
-                kwargs = {'x': value}
-            else:
-                assert False, f'Unknown reference type {reference.name}'
-            encodable = field.encodable_cls(datatype=field, **kwargs)
-            if not field.encodable_cls.lazy:
-                encodable = encodable.unpack()
-            decoded[k] = encodable
-        else:
-            decoded[k] = field.decode_data(data[k])
-
-    decoded.pop(KEY_SCHEMA, None)
-    return decoded
-
-
 def _get_leaf_from_cache(k, builds, getters, db: t.Optional['Datalayer'] = None):
     if reference := parse_reference(f'?{k}'):
         if reference.name in getters:
@@ -672,7 +635,7 @@ def _get_leaf_from_cache(k, builds, getters, db: t.Optional['Datalayer'] = None)
     return leaf
 
 
-def _deep_flat_decode(r, builds, getters: _Getters, db: t.Optional['Datalayer'] = None):
+def _deep_flat_decode(r, builds, getters: Getters, db: t.Optional['Datalayer'] = None):
     if isinstance(r, Leaf):
         return r
     if isinstance(r, (list, tuple)):
@@ -722,7 +685,8 @@ def _deep_flat_decode(r, builds, getters: Getters, db: t.Optional['Datalayer']
     if isinstance(r, str) and r.startswith('&'):
         assert getters is not None
         reference = parse_reference(r)
-        return getters.run(reference.name, reference.path)
+        out = getters.run(reference.name, reference.path)
+        return out
     return r


@@ -741,7 +705,8 @@ def _get_component(db, path):
         return db.load(type_id=parts[0], identifier=parts[1])
     if len(parts) == 3:
         if not _check_if_version(parts[2]):
-            return db.load(uuid=parts[2])
+            out = db.load(uuid=parts[2])
+            return out
         return db.load(type_id=parts[0], identifier=parts[1], version=parts[2])
     raise ValueError(f'Invalid component reference: {path}')
 
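For context on the now-public `Getters` class: from the call sites above (`Getters(getters)`, `getters.run(reference.name, reference.path)`), it maps a reference type such as `'blob'` or `'file'` to loader callables. A minimal stand-in, with any behaviour beyond those call sites assumed:

```python
import typing as t

class MiniGetters:
    """Toy version of Getters; the dispatch logic is assumed, not copied."""

    def __init__(self, getters: t.Optional[t.Dict[str, t.Callable]] = None):
        self._getters: t.Dict[str, t.List[t.Callable]] = {}
        for name, getter in (getters or {}).items():
            self.add_getter(name, getter)

    def add_getter(self, name: str, getter: t.Callable):
        self._getters.setdefault(name, []).append(getter)

    def run(self, name: str, data: t.Any):
        # try each registered getter for this reference type in turn
        for getter in self._getters.get(name, []):
            out = getter(data)
            if out is not None:
                return out
        return None

getters = MiniGetters({'blob': lambda key: b'payload'})
assert getters.run('blob', 'some-blob-key') == b'payload'
assert getters.run('file', 'missing-key') is None
```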
1 change: 1 addition & 0 deletions superduper/components/component.py
@@ -819,6 +819,7 @@ def _convert_components_to_refs(r):
         r['status'] = str(self.status)
         return Document(r)
 
+    # TODO needed? looks to have legacy "_content"
     @classmethod
     def decode(cls, r, db: t.Optional[t.Any] = None, reference: bool = False):
         """Decodes a dictionary component into a `Component` instance.
34 changes: 24 additions & 10 deletions superduper/components/datatype.py
@@ -219,6 +219,7 @@ class NativeVector(BaseDataType):
     :param dtype: Datatype of array to encode.
     """
 
+    encodable: t.ClassVar[str] = 'native'
     dtype: str = 'float64'
 
     def __post_init__(self, db, artifacts):
@@ -234,6 +235,22 @@ def decode_data(self, item, info=None):
         return numpy.array(item).astype(self.dtype)
 
 
+class Json2Str(BaseDataType):
+    """Datatype for encoding JSON-serializable data natively as a string."""
+
+    encodable: t.ClassVar[str] = 'native'
+
+    def __post_init__(self, db, artifacts):
+        # self.encodable_cls = Native
+        return super().__post_init__(db, artifacts)
+
+    def encode_data(self, item, info=None):
+        return json.dumps(item)
+
+    def decode_data(self, item, info=None):
+        return json.loads(item)
+
+
 class DataType(BaseDataType):
     """A data type component that defines how data is encoded and decoded.
@@ -324,7 +341,7 @@ def encode_data(self, item, info: t.Optional[t.Dict] = None):
"""
info = info or {}
data = self.encoder(item, info) if self.encoder else item
data = self.bytes_encoding_after_encode(data)
# data = self.bytes_encoding_after_encode(data)
return data

@ensure_initialized
@@ -335,7 +352,7 @@ def decode_data(self, item, info: t.Optional[t.Dict] = None):
:param info: The optional information dictionary.
"""
info = info or {}
item = self.bytes_encoding_before_decode(item)
# item = self.bytes_encoding_before_decode(item)
return self.decoder(item, info=info) if self.decoder else item

def bytes_encoding_after_encode(self, data):
Expand Down Expand Up @@ -789,14 +806,7 @@ def get_serializer(
)


json_serializer = DataType(
'json',
encoder=json_encode,
decoder=json_decode,
encodable='encodable',
bytes_encoding=BytesEncoding.BASE64,
intermediate_type=IntermediateType.STRING,
)
json_serializer = Json2Str('json')


pickle_encoder = get_serializer(
@@ -876,6 +886,10 @@ def __post_init__(self, db, artifacts):
     def encodable_cls(self):
         return self.datatype_impl.encodable_cls
 
+    @property
+    def encodable(self):
+        return self.datatype_impl.encodable
+
     @cached_property
     def datatype_impl(self):
         if isinstance(CFG.datatype_presets.vector, str):
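Usage of the new `Json2Str`, which replaces the old `DataType`-based `json_serializer` with a plain string round trip and no bytes-encoding step (a sketch; constructing the datatype standalone is assumed to work as in the `json_serializer = Json2Str('json')` line above):

```python
from superduper.components.datatype import Json2Str  # module path per this diff

dt = Json2Str('json')

encoded = dt.encode_data({'a': [1, 2, 3]})
assert isinstance(encoded, str)  # stored natively in a string column
assert dt.decode_data(encoded) == {'a': [1, 2, 3]}
```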