diff --git a/CHANGELOG.md b/CHANGELOG.md
index a2488e1b8..279a25120 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 #### Changed defaults / behaviours
 
-#### New Features & Functionality 
+#### New Features & Functionality
 
 - Streamlit component and server
 - Graceful updates when making incremental changes
diff --git a/plugins/ibis/superduper_ibis/data_backend.py b/plugins/ibis/superduper_ibis/data_backend.py
index 5dcb49eb6..c798ffa92 100644
--- a/plugins/ibis/superduper_ibis/data_backend.py
+++ b/plugins/ibis/superduper_ibis/data_backend.py
@@ -69,6 +69,9 @@ def __init__(self, uri: str, flavour: t.Optional[str] = None):
         self.overwrite = False
         self._setup(conn)
 
+        if uri.startswith('snowflake://') or uri.startswith('sqlite://'):
+            self.bytes_encoding = 'base64'
+
         self.datatype_presets = {'vector': 'superduper.ext.numpy.encoder.Array'}
 
     def _setup(self, conn):
diff --git a/plugins/ibis/superduper_ibis/query.py b/plugins/ibis/superduper_ibis/query.py
index ca137cf79..6ec315581 100644
--- a/plugins/ibis/superduper_ibis/query.py
+++ b/plugins/ibis/superduper_ibis/query.py
@@ -161,7 +161,7 @@ def _get_schema(self):
         )
         fields.update(to_update)
 
-        return Schema(f"_tmp:{self.table}", fields=fields)
+        return Schema(f"_tmp:{self.table}", fields=fields, db=self.db)
 
     def renamings(self, r={}):
         """Return the renamings.
diff --git a/plugins/ibis/superduper_ibis/utils.py b/plugins/ibis/superduper_ibis/utils.py
index e33efd263..1aeeb87b3 100644
--- a/plugins/ibis/superduper_ibis/utils.py
+++ b/plugins/ibis/superduper_ibis/utils.py
@@ -1,5 +1,4 @@
 from ibis.expr.datatypes import dtype
-from superduper import CFG
 from superduper.components.datatype import (
     Artifact,
     BaseDataType,
@@ -42,9 +41,13 @@ def convert_schema_to_fields(schema: Schema):
         elif not isinstance(v, BaseDataType):
             fields[k] = v.identifier
         else:
-            if v.encodable_cls in SPECIAL_ENCODABLES_FIELDS:
-                fields[k] = dtype(SPECIAL_ENCODABLES_FIELDS[v.encodable_cls])
+            if v.encodable == 'encodable':
+                fields[k] = dtype(
+                    'str'
+                    if schema.db.databackend.bytes_encoding == 'base64'
+                    else 'bytes'
+                )
             else:
-                fields[k] = CFG.bytes_encoding
+                fields[k] = dtype('str')
 
     return fields
diff --git a/plugins/mongodb/superduper_mongodb/query.py b/plugins/mongodb/superduper_mongodb/query.py
index d88f6e2b2..efa0277e6 100644
--- a/plugins/mongodb/superduper_mongodb/query.py
+++ b/plugins/mongodb/superduper_mongodb/query.py
@@ -676,7 +676,7 @@ def _get_schema(self):
 
         fields = {k: v for k, v in fields.items() if isinstance(v, BaseDataType)}
 
-        return Schema(f"_tmp:{self.table}", fields=fields)
+        return Schema(f"_tmp:{self.table}", fields=fields, db=self.db)
 
     def _get_project(self):
         find_params, _ = self._get_method_parameters('find')
diff --git a/plugins/openai/plugin_test/test_model_openai.py b/plugins/openai/plugin_test/test_model_openai.py
index 0d10a91e1..c07add5e6 100644
--- a/plugins/openai/plugin_test/test_model_openai.py
+++ b/plugins/openai/plugin_test/test_model_openai.py
@@ -93,7 +93,6 @@ def test_embed():
     e = OpenAIEmbedding(identifier='text-embedding-ada-002')
     resp = e.predict('Hello, world!')
 
-    assert len(resp) == e.shape[0]
    assert str(resp.dtype) == 'float32'
 
 
@@ -103,7 +102,6 @@ def test_batch_embed():
     resp = e.predict_batches(['Hello', 'world!'])
 
     assert len(resp) == 2
-    assert all(len(x) == e.shape[0] for x in resp)
     assert all(str(x.dtype) == 'float32' for x in resp)
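Note on the `bytes_encoding` switch introduced in `superduper_ibis/data_backend.py` above: backends whose drivers cannot hold raw bytes in a column (here Snowflake and SQLite) are flipped to base64, while the base-class default (added further below in `superduper/backends/base/data_backend.py`) stays `'bytes'`. A minimal standalone sketch of the intended round trip; the helper name here is illustrative, not part of the codebase:

    import base64

    def choose_bytes_encoding(uri: str) -> str:
        # mirrors the condition added in the ibis data backend
        if uri.startswith('snowflake://') or uri.startswith('sqlite://'):
            return 'base64'
        return 'bytes'

    raw = b'\x00\x01 binary payload'
    if choose_bytes_encoding('sqlite://test.db') == 'base64':
        stored = base64.b64encode(raw).decode('utf-8')  # str value in the table
        assert base64.b64decode(stored) == raw          # lossless round trip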
diff --git a/plugins/openai/superduper_openai/model.py b/plugins/openai/superduper_openai/model.py
index 6ff2d130f..6e6fead09 100644
--- a/plugins/openai/superduper_openai/model.py
+++ b/plugins/openai/superduper_openai/model.py
@@ -18,7 +18,6 @@
 from superduper.backends.query_dataset import QueryDataset
 from superduper.base.datalayer import Datalayer
 from superduper.components.model import APIBaseModel, Inputs
-from superduper.components.vector_index import vector
 from superduper.misc.compat import cache
 from superduper.misc.retry import Retry
 
@@ -107,9 +106,6 @@ class OpenAIEmbedding(_OpenAI):
 
     """
 
-    shapes: t.ClassVar[t.Dict] = {'text-embedding-ada-002': (1536,)}
-
-    shape: t.Optional[t.Sequence[int]] = None
     signature: str = 'singleton'
     batch_size: int = 100
 
@@ -118,20 +114,6 @@ def inputs(self):
         """The inputs of the model."""
         return Inputs(['input'])
 
-    def __post_init__(self, db, artifacts, example):
-        super().__post_init__(db, artifacts, example)
-        if self.shape is None:
-            self.shape = self.shapes[self.model]
-
-    def _pre_create(self, db: Datalayer) -> None:
-        """Pre creates the model.
-
-        the datatype is set to ``vector``.
-
-        :param db: The datalayer instance.
-        """
-        self.datatype = self.datatype or vector(shape=self.shape)
-
     @retry
     def predict(self, X: str):
         """Generates embeddings from text.
diff --git a/superduper/backends/base/data_backend.py b/superduper/backends/base/data_backend.py
index 74fcadd49..30ad4ca77 100644
--- a/superduper/backends/base/data_backend.py
+++ b/superduper/backends/base/data_backend.py
@@ -27,6 +27,7 @@ def __init__(self, uri: str, flavour: t.Optional[str] = None):
         self.in_memory_tables: t.Dict = {}
         self._datalayer = None
         self.uri = uri
+        self.bytes_encoding = 'bytes'
 
     @property
     def type(self):
diff --git a/superduper/base/config.py b/superduper/base/config.py
index 3a8ba0a79..7d39b748e 100644
--- a/superduper/base/config.py
+++ b/superduper/base/config.py
@@ -161,11 +161,11 @@ class Config(BaseConfig):
     :param logging_type: The type of logging to use
     :param force_apply: Whether to force apply the configuration
     :param datatype_presets: Presets to be applied for default types of data
-    :param bytes_encoding: The encoding of bytes in the data backend
     :param auto_schema: Whether to automatically create the schema. If True, the
         schema will be created if it does not exist.
     :param json_native: Whether the databackend supports json natively or not.
     :param log_colorize: Whether to colorize the logs
+    :param bytes_encoding: (Deprecated)
     :param output_prefix: The prefix for the output table and output field key
     :param vector_search_kwargs: The keyword arguments to pass to the vector search
     :param rest: Settings for rest server.
@@ -190,12 +190,12 @@ class Config(BaseConfig):
     log_level: LogLevel = LogLevel.INFO
     logging_type: LogType = LogType.SYSTEM
     log_colorize: bool = True
+    bytes_encoding: str = 'bytes'
     force_apply: bool = False
     datatype_presets: DataTypePresets = dc.field(default_factory=DataTypePresets)
-    bytes_encoding: BytesEncoding = BytesEncoding.BYTES
     auto_schema: bool = True
     json_native: bool = True
     output_prefix: str = "_outputs__"
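With `shape`, `shapes` and `_pre_create` removed from `OpenAIEmbedding`, the output datatype is no longer a fixed-shape `vector`; it is resolved through `datatype_presets` (the ibis backend above presets `'vector'` to `superduper.ext.numpy.encoder.Array`). A usage sketch under that assumption (requires OpenAI credentials to actually run):

    from superduper_openai import OpenAIEmbedding

    # no shape or datatype argument any more; the active backend's
    # datatype preset for 'vector' decides how embeddings are stored
    e = OpenAIEmbedding(identifier='text-embedding', model='text-embedding-ada-002')
    resp = e.predict('Hello, world!')
    assert str(resp.dtype) == 'float32'  # as checked in the plugin tests above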
diff --git a/superduper/base/document.py b/superduper/base/document.py
index e51d4d6f0..1b9b71345 100644
--- a/superduper/base/document.py
+++ b/superduper/base/document.py
@@ -14,7 +14,6 @@
 from superduper.base.variables import _replace_variables
 from superduper.components.component import Component
 from superduper.components.datatype import (
-    BaseDataType,
     Blob,
     Encodable,
     FileItem,
@@ -52,7 +51,7 @@ def _build_blob_getter(base_getter):
     return partial(_blob_getter, getter=base_getter)
 
 
-class _Getters:
+class Getters:
     """A class to manage getters for decoding documents.
 
     We will have a list of getters for each type of reference.
@@ -258,7 +257,7 @@ def decode(
         r,
         schema: t.Optional['Schema'] = None,
         db: t.Optional['Datalayer'] = None,
-        getters: t.Union[_Getters, t.Dict[str, t.Callable], None] = None,
+        getters: t.Union[Getters, t.Dict[str, t.Callable], None] = None,
     ):
         """Converts any dictionary into a Document or a Leaf.
 
@@ -291,9 +290,9 @@ def decode(
                 identifier = k
                 builds[k]['identifier'] = identifier
 
-        if not isinstance(getters, _Getters):
-            getters = _Getters(getters)
-        assert isinstance(getters, _Getters)
+        if not isinstance(getters, Getters):
+            getters = Getters(getters)
+        assert isinstance(getters, Getters)
 
         # Prioritize using the local artifact storage getter,
         # and then use the DB read getter.
@@ -314,7 +313,7 @@ def decode(
 
         if schema is not None:
             schema.init()
-            r = _schema_decode(schema, r, getters)
+            r = schema.decode_data(r, getters)
 
         r = _deep_flat_decode(
             {k: v for k, v in r.items() if k not in (KEY_BUILDS, KEY_BLOBS, KEY_FILES)},
@@ -586,42 +585,6 @@ def _deep_flat_encode(
     return r
 
 
-def _schema_decode(
-    schema, data: dict[str, t.Any], getters: _Getters
-) -> dict[str, t.Any]:
-    """Decode data using the schema's encoders.
-
-    :param data: Data to decode.
-    """
-    if schema.trivial:
-        return data
-    decoded = {}
-    for k, value in data.items():
-        field = schema.fields.get(k)
-        if not isinstance(field, BaseDataType):
-            decoded[k] = value
-            continue
-
-        value = data[k]
-        if reference := parse_reference(value):
-            value = getters.run(reference.name, reference.path)
-            if reference.name == 'blob':
-                kwargs = {'blob': value}
-            elif reference.name == 'file':
-                kwargs = {'x': value}
-            else:
-                assert False, f'Unknown reference type {reference.name}'
-            encodable = field.encodable_cls(datatype=field, **kwargs)
-            if not field.encodable_cls.lazy:
-                encodable = encodable.unpack()
-            decoded[k] = encodable
-        else:
-            decoded[k] = field.decode_data(data[k])
-
-    decoded.pop(KEY_SCHEMA, None)
-    return decoded
-
-
 def _get_leaf_from_cache(k, builds, getters, db: t.Optional['Datalayer'] = None):
     if reference := parse_reference(f'?{k}'):
         if reference.name in getters:
@@ -672,7 +635,7 @@ def _get_leaf_from_cache(k, builds, getters, db: t.Optional['Datalayer'] = None)
     return leaf
 
 
-def _deep_flat_decode(r, builds, getters: _Getters, db: t.Optional['Datalayer'] = None):
+def _deep_flat_decode(r, builds, getters: Getters, db: t.Optional['Datalayer'] = None):
     if isinstance(r, Leaf):
         return r
     if isinstance(r, (list, tuple)):
@@ -722,7 +685,8 @@ def _deep_flat_decode(r, builds, getters: _Getters, db: t.Optional['Datalayer']
     if isinstance(r, str) and r.startswith('&'):
         assert getters is not None
         reference = parse_reference(r)
-        return getters.run(reference.name, reference.path)
+        out = getters.run(reference.name, reference.path)
+        return out
     return r
 
@@ -741,7 +705,8 @@ def _get_component(db, path):
         return db.load(type_id=parts[0], identifier=parts[1])
     if len(parts) == 3:
         if not _check_if_version(parts[2]):
-            return db.load(uuid=parts[2])
+            out = db.load(uuid=parts[2])
+            return out
         return db.load(type_id=parts[0], identifier=parts[1], version=parts[2])
     raise ValueError(f'Invalid component reference: {path}')
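Two related moves above: `_Getters` becomes public as `Getters`, and the module-level `_schema_decode` is deleted in favour of the new `Schema.decode_data` (added in `superduper/components/schema.py` below). Callers can keep passing a plain dict, which `Document.decode` wraps itself; a hedged sketch, where the getter and the inputs `r` and `schema` are placeholders rather than library fixtures:

    from superduper.base.document import Document

    blob_store = {'1234abcd': b'raw-bytes'}  # hypothetical local store

    def blob_getter(path: str) -> bytes:
        # resolves '&:blob:<identifier>' style references
        return blob_store[path]

    # `r` and `schema` come from the surrounding application
    decoded = Document.decode(r, schema=schema, getters={'blob': blob_getter})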
- """ - if schema.trivial: - return data - decoded = {} - for k, value in data.items(): - field = schema.fields.get(k) - if not isinstance(field, BaseDataType): - decoded[k] = value - continue - - value = data[k] - if reference := parse_reference(value): - value = getters.run(reference.name, reference.path) - if reference.name == 'blob': - kwargs = {'blob': value} - elif reference.name == 'file': - kwargs = {'x': value} - else: - assert False, f'Unknown reference type {reference.name}' - encodable = field.encodable_cls(datatype=field, **kwargs) - if not field.encodable_cls.lazy: - encodable = encodable.unpack() - decoded[k] = encodable - else: - decoded[k] = field.decode_data(data[k]) - - decoded.pop(KEY_SCHEMA, None) - return decoded - - def _get_leaf_from_cache(k, builds, getters, db: t.Optional['Datalayer'] = None): if reference := parse_reference(f'?{k}'): if reference.name in getters: @@ -672,7 +635,7 @@ def _get_leaf_from_cache(k, builds, getters, db: t.Optional['Datalayer'] = None) return leaf -def _deep_flat_decode(r, builds, getters: _Getters, db: t.Optional['Datalayer'] = None): +def _deep_flat_decode(r, builds, getters: Getters, db: t.Optional['Datalayer'] = None): if isinstance(r, Leaf): return r if isinstance(r, (list, tuple)): @@ -722,7 +685,8 @@ def _deep_flat_decode(r, builds, getters: _Getters, db: t.Optional['Datalayer'] if isinstance(r, str) and r.startswith('&'): assert getters is not None reference = parse_reference(r) - return getters.run(reference.name, reference.path) + out = getters.run(reference.name, reference.path) + return out return r @@ -741,7 +705,8 @@ def _get_component(db, path): return db.load(type_id=parts[0], identifier=parts[1]) if len(parts) == 3: if not _check_if_version(parts[2]): - return db.load(uuid=parts[2]) + out = db.load(uuid=parts[2]) + return out return db.load(type_id=parts[0], identifier=parts[1], version=parts[2]) raise ValueError(f'Invalid component reference: {path}') diff --git a/superduper/components/component.py b/superduper/components/component.py index 8dfb43cce..a9729ee0c 100644 --- a/superduper/components/component.py +++ b/superduper/components/component.py @@ -819,6 +819,7 @@ def _convert_components_to_refs(r): r['status'] = str(self.status) return Document(r) + # TODO needed? looks to have legacy "_content" @classmethod def decode(cls, r, db: t.Optional[t.Any] = None, reference: bool = False): """Decodes a dictionary component into a `Component` instance. diff --git a/superduper/components/datatype.py b/superduper/components/datatype.py index 09a39ff57..6eb199ee1 100644 --- a/superduper/components/datatype.py +++ b/superduper/components/datatype.py @@ -219,6 +219,7 @@ class NativeVector(BaseDataType): :param dtype: Datatype of array to encode. """ + encodable: t.ClassVar[str] = 'native' dtype: str = 'float64' def __post_init__(self, db, artifacts): @@ -234,6 +235,22 @@ def decode_data(self, item, info=None): return numpy.array(item).astype(self.dtype) +class Json2Str(BaseDataType): + """Datatype for encoding vectors which are supported natively by databackend.""" + + encodable: t.ClassVar[str] = 'native' + + def __post_init__(self, db, artifacts): + # self.encodable_cls = Native + return super().__post_init__(db, artifacts) + + def encode_data(self, item, info=None): + return json.dumps(item) + + def decode_data(self, item, info=None): + return json.loads(item) + + class DataType(BaseDataType): """A data type component that defines how data is encoded and decoded. 
diff --git a/superduper/components/schema.py b/superduper/components/schema.py
index ac1c22292..d0b9a89e5 100644
--- a/superduper/components/schema.py
+++ b/superduper/components/schema.py
@@ -1,12 +1,32 @@
+import base64
+import hashlib
 import typing as t
 from functools import cached_property
 
+from superduper.base.constant import KEY_SCHEMA
 from superduper.base.leaf import Leaf
 from superduper.components.component import Component
 from superduper.components.datatype import BaseDataType, DataType
 from superduper.misc.reference import parse_reference
 from superduper.misc.special_dicts import SuperDuperFlatEncode
 
+if t.TYPE_CHECKING:
+    from superduper.base.document import Getters
+
+
+def get_hash(data):
+    """Get the hash of the given data.
+
+    :param data: Data to hash.
+    """
+    if isinstance(data, str):
+        bytes_ = data.encode()
+    elif isinstance(data, bytes):
+        bytes_ = data
+    else:
+        bytes_ = str(id(data)).encode()
+    return hashlib.sha1(bytes_).hexdigest()
+
 
 class FieldType(Leaf):
     """Field type to represent the type of a field in a table.
@@ -39,6 +59,14 @@ def __post_init__(self, db):
 ID = FieldType(identifier='ID')
 
 
+def _convert_base64_to_bytes(str_: str) -> bytes:
+    return base64.b64decode(str_)
+
+
+def _convert_bytes_to_base64(bytes_: bytes) -> str:
+    return base64.b64encode(bytes_).decode('utf-8')
+
+
 class Schema(Component):
     """A component carrying the `DataType` of columns.
 
@@ -90,6 +118,47 @@ def fields_set(self):
                 fields.add((k, v.identifier))
         return fields
 
+    def decode_data(
+        self, data: dict[str, t.Any], getters: 'Getters'
+    ) -> dict[str, t.Any]:
+        """Decode data using the schema's encoders.
+
+        :param data: Data to decode.
+        """
+        if self.trivial:
+            return data
+        decoded = {}
+        for k, value in data.items():
+            field = self.fields.get(k)
+            if not isinstance(field, BaseDataType):
+                decoded[k] = value
+                continue
+
+            value = data[k]
+            if reference := parse_reference(value):
+                value = getters.run(reference.name, reference.path)
+                if reference.name == 'blob':
+                    kwargs = {'blob': value}
+                elif reference.name == 'file':
+                    kwargs = {'x': value}
+                else:
+                    assert False, f'Unknown reference type {reference.name}'
+                encodable = field.encodable_cls(datatype=field, **kwargs)
+                if not field.encodable_cls.lazy:
+                    encodable = encodable.unpack()
+                decoded[k] = encodable
+            else:
+                b = data[k]
+                if (
+                    field.encodable == 'encodable'
+                    and self.db.databackend.bytes_encoding == 'base64'
+                ):
+                    b = _convert_base64_to_bytes(b)
+                decoded[k] = field.decode_data(b)
+
+        decoded.pop(KEY_SCHEMA, None)
+        return decoded
+
     def encode_data(self, out, builds, blobs, files, leaves_to_keep=()):
         """Encode data using the schema's encoders.
 
@@ -108,8 +177,19 @@ def encode_data(self, out, builds, blobs, files, leaves_to_keep=()):
             if isinstance(out[k], leaves_to_keep):
                 continue
 
-            data, identifier = field.encode_data_with_identifier(out[k])
-            if field.encodable_cls.artifact:
+            # data, identifier = field.encode_data_with_identifier(out[k])
+            data = field.encode_data(out[k])
+
+            identifier = get_hash(data)
+
+            if (
+                field.encodable == 'encodable'
+                and self.db.databackend.bytes_encoding == 'base64'
+            ):
+                assert isinstance(data, bytes)
+                data = _convert_bytes_to_base64(data)
+
+            if field.encodable in {'artifact', 'lazy_artifact'}:
                 reference = field.encodable_cls.build_reference(identifier, data)
                 ref_obj = parse_reference(reference)
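The new `get_hash` supplies `Schema.encode_data` with a deterministic identifier for str/bytes payloads; for any other object it falls back to `str(id(data))`, which is only stable within a single process. Together with the two base64 helpers, the standalone behaviour is:

    import base64
    import hashlib

    payload = b'encoded-field-bytes'

    identifier = hashlib.sha1(payload).hexdigest()      # get_hash(payload)

    as_str = base64.b64encode(payload).decode('utf-8')  # _convert_bytes_to_base64
    assert base64.b64decode(as_str) == payload          # _convert_base64_to_bytes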
+ """ + if self.trivial: + return data + decoded = {} + for k, value in data.items(): + field = self.fields.get(k) + if not isinstance(field, BaseDataType): + decoded[k] = value + continue + + value = data[k] + if reference := parse_reference(value): + value = getters.run(reference.name, reference.path) + if reference.name == 'blob': + kwargs = {'blob': value} + elif reference.name == 'file': + kwargs = {'x': value} + else: + assert False, f'Unknown reference type {reference.name}' + encodable = field.encodable_cls(datatype=field, **kwargs) + if not field.encodable_cls.lazy: + encodable = encodable.unpack() + decoded[k] = encodable + else: + b = data[k] + if ( + field.encodable == 'encodable' + and self.db.databackend.bytes_encoding == 'base64' + ): + b = _convert_base64_to_bytes(b) + decoded[k] = field.decode_data(b) + + decoded.pop(KEY_SCHEMA, None) + return decoded + def encode_data(self, out, builds, blobs, files, leaves_to_keep=()): """Encode data using the schema's encoders. @@ -108,8 +177,19 @@ def encode_data(self, out, builds, blobs, files, leaves_to_keep=()): if isinstance(out[k], leaves_to_keep): continue - data, identifier = field.encode_data_with_identifier(out[k]) - if field.encodable_cls.artifact: + # data, identifier = field.encode_data_with_identifier(out[k]) + data = field.encode_data(out[k]) + + identifier = get_hash(data) + + if ( + field.encodable == 'encodable' + and self.db.databackend.bytes_encoding == 'base64' + ): + assert isinstance(data, bytes) + data = _convert_bytes_to_base64(data) + + if field.encodable in {'artifact', 'lazy_artifact'}: reference = field.encodable_cls.build_reference(identifier, data) ref_obj = parse_reference(reference) diff --git a/superduper/components/table.py b/superduper/components/table.py index 19193b23a..d5947aa8d 100644 --- a/superduper/components/table.py +++ b/superduper/components/table.py @@ -41,6 +41,7 @@ def __post_init__(self, db, artifacts): self.schema = Schema( self.schema.identifier, fields={**fields}, + db=db, ) def on_create(self, db: 'Datalayer'): diff --git a/superduper/ext/numpy/encoder.py b/superduper/ext/numpy/encoder.py index cc1a62bd9..5548ecad3 100644 --- a/superduper/ext/numpy/encoder.py +++ b/superduper/ext/numpy/encoder.py @@ -62,6 +62,7 @@ class Array(BaseDataType): def __post_init__(self, db, artifacts): self.encodable_cls = Encodable + self.encodable = 'encodable' return super().__post_init__(db, artifacts) def encode_data(self, item, info=None): diff --git a/templates/simple_rag/build.ipynb b/templates/simple_rag/build.ipynb index 25be39cb7..88289701c 100644 --- a/templates/simple_rag/build.ipynb +++ b/templates/simple_rag/build.ipynb @@ -42,7 +42,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "id": "3ef70f6d-a189-460a-8864-241a689624e2", "metadata": { "editable": true, @@ -55,19 +55,17 @@ }, "outputs": [], "source": [ - "APPLY = False\n", + "APPLY = True\n", "SAMPLE_COLLECTION_NAME = 'sample_simple_rag'\n", "COLLECTION_NAME = '' if not APPLY else 'docs'\n", "ID_FIELD = '' if not APPLY else 'id'\n", "OUTPUT_PREFIX = '_outputs__'\n", - "BASE_URL = None\n", - "API_KEY = None\n", - "EAGER = False" + "EAGER = True" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "id": "cb029a5e-fedf-4f07-8a31-d220cfbfbb3d", "metadata": { "editable": true, @@ -76,35 +74,19 @@ }, "tags": [] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[32m2024-Nov-18 11:25:01.22\u001b[0m| \u001b[1mINFO \u001b[0m | 
diff --git a/templates/simple_rag/build.ipynb b/templates/simple_rag/build.ipynb
index 25be39cb7..88289701c 100644
--- a/templates/simple_rag/build.ipynb
+++ b/templates/simple_rag/build.ipynb
@@ -42,7 +42,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": null,
    "id": "3ef70f6d-a189-460a-8864-241a689624e2",
    "metadata": {
     "editable": true,
@@ -55,19 +55,17 @@
    },
    "outputs": [],
    "source": [
-    "APPLY = False\n",
+    "APPLY = True\n",
     "SAMPLE_COLLECTION_NAME = 'sample_simple_rag'\n",
     "COLLECTION_NAME = '' if not APPLY else 'docs'\n",
     "ID_FIELD = '' if not APPLY else 'id'\n",
     "OUTPUT_PREFIX = '_outputs__'\n",
-    "BASE_URL = None\n",
-    "API_KEY = None\n",
-    "EAGER = False"
+    "EAGER = True"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": null,
    "id": "cb029a5e-fedf-4f07-8a31-d220cfbfbb3d",
    "metadata": {
     "editable": true,
     "slideshow": {
      "slide_type": ""
     },
     "tags": []
    },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\u001b[32m2024-Nov-18 11:25:01.22\u001b[0m| \u001b[1mINFO \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.misc.plugins\u001b[0m:\u001b[36m13 \u001b[0m | \u001b[1mLoading plugin: mongodb\u001b[0m\n",
-      "\u001b[32m2024-Nov-18 11:25:01.22\u001b[0m| \u001b[1mINFO \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.base.datalayer\u001b[0m:\u001b[36m68 \u001b[0m | \u001b[1mBuilding Data Layer\u001b[0m\n",
-      "\u001b[32m2024-Nov-18 11:25:01.22\u001b[0m| \u001b[1mINFO \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.base.build\u001b[0m:\u001b[36m184 \u001b[0m | \u001b[1mConfiguration: \n",
-      " +---------------+--------------+\n",
-      "| Configuration | Value |\n",
-      "+---------------+--------------+\n",
-      "| Data Backend | mongomock:// |\n",
-      "+---------------+--------------+\u001b[0m\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "from superduper import superduper, CFG\n",
     "\n",
     "CFG.output_prefix = OUTPUT_PREFIX\n",
     "CFG.bytes_encoding = 'str'\n",
-    "CFG.native_json = False\n",
     "\n",
     "db = superduper('mongomock://')"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": null,
    "id": "4e7902bd",
    "metadata": {
     "editable": true,
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": null,
    "id": "1ef8dd07-1b47-4dce-84dd-a081d1f5ee9d",
    "metadata": {},
    "outputs": [],
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": null,
    "id": "c5965fdf",
    "metadata": {},
    "outputs": [],
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": null,
    "id": "2d20eaa0-a416-4483-938e-23f79845739a",
    "metadata": {},
    "outputs": [],
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": null,
    "id": "93d21872-d4dc-40dc-abab-fb07ba102ea3",
    "metadata": {},
    "outputs": [],
   },
   {
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": null,
    "id": "31900eec-b516-4bef-939e-2e8f46252b12",
    "metadata": {},
    "outputs": [],
   },
   {
    "cell_type": "code",
-   "execution_count": 28,
+   "execution_count": null,
    "id": "a9b1f538-65ca-499e-b6d0-2dd733f81723",
    "metadata": {},
    "outputs": [],
    "source": [
     "import os\n",
     "\n",
     "from superduper_openai import OpenAIEmbedding\n",
     "\n",
     "openai_embedding = OpenAIEmbedding(\n",
     "    identifier='text-embedding',\n",
     "    model='text-embedding-ada-002',\n",
-    "    datatype=sqlvector(shape=(1536,)),\n",
-    "    client_kwargs={'base_url': BASE_URL, 'api_key': API_KEY},\n",
     ")"
    ]
   },
"execution_count": null, "id": "e6787c78-4b14-4a72-818b-450408a74331", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[32m2024-Nov-18 11:25:01.28\u001b[0m| \u001b[1mINFO \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.components.application\u001b[0m:\u001b[36m39 \u001b[0m | \u001b[1mResorting components based on topological order.\u001b[0m\n", - "\u001b[32m2024-Nov-18 11:25:01.28\u001b[0m| \u001b[1mINFO \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.components.application\u001b[0m:\u001b[36m56 \u001b[0m | \u001b[1mNew order of components: ['listener:chunker:e23e13cb64ae4975', 'vector_index:vectorindex:59324a8cd6574ee3', 'model:simple_rag:1c08b0b496914eec']\u001b[0m\n" - ] - } - ], + "outputs": [], "source": [ "from superduper import Application\n", "\n", @@ -451,7 +421,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": null, "id": "e7c16557-af76-4e70-83d9-2984e19a9554", "metadata": {}, "outputs": [], @@ -462,7 +432,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": null, "id": "2a82ea22-9694-4c65-b72f-c89ae49d1ab2", "metadata": {}, "outputs": [], @@ -489,24 +459,10 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": null, "id": "2e850c03-33c6-4c88-95d3-d14146a6a0af", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[32m2024-Nov-18 11:25:01.29\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.components.listener\u001b[0m:\u001b[36m74 \u001b[0m | \u001b[33m\u001b[1moutput_table not found in listener.dict()\u001b[0m\n", - "\u001b[32m2024-Nov-18 11:25:01.30\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.components.listener\u001b[0m:\u001b[36m74 \u001b[0m | \u001b[33m\u001b[1moutput_table not found in listener.dict()\u001b[0m\n", - "\u001b[32m2024-Nov-18 11:25:01.30\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.base.document\u001b[0m:\u001b[36m556 \u001b[0m | \u001b[33m\u001b[1mLeaf listener:chunker already exists\u001b[0m\n", - "\u001b[32m2024-Nov-18 11:25:01.30\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.components.listener\u001b[0m:\u001b[36m74 \u001b[0m | \u001b[33m\u001b[1moutput_table not found in listener.dict()\u001b[0m\n", - "\u001b[32m2024-Nov-18 11:25:01.30\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.base.document\u001b[0m:\u001b[36m556 \u001b[0m | \u001b[33m\u001b[1mLeaf model:chunker already exists\u001b[0m\n", - "\u001b[32m2024-Nov-18 11:25:01.30\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.base.document\u001b[0m:\u001b[36m556 \u001b[0m | \u001b[33m\u001b[1mLeaf datatype:dill already exists\u001b[0m\n", - "\u001b[32m2024-Nov-18 11:25:01.30\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.base.document\u001b[0m:\u001b[36m556 \u001b[0m | \u001b[33m\u001b[1mLeaf var-table-name-select-var-id-field-x already exists\u001b[0m\n" - ] - } - ], + "outputs": [], "source": [ "from superduper import Template, Table, Schema\n", "from superduper.components.dataset import RemoteData\n", @@ -575,18 +531,10 @@ }, 
{ "cell_type": "code", - "execution_count": 38, + "execution_count": null, "id": "8924ba0d-7c01-4d6c-87fb-245531db7506", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[32m2024-Nov-18 11:25:01.31\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.base.document\u001b[0m:\u001b[36m556 \u001b[0m | \u001b[33m\u001b[1mLeaf str already exists\u001b[0m\n" - ] - } - ], + "outputs": [], "source": [ "template.export('.')" ] diff --git a/test/unittest/component/datatype/test_bytes_encoding.py b/test/unittest/component/datatype/test_bytes_encoding.py deleted file mode 100644 index 892f517a7..000000000 --- a/test/unittest/component/datatype/test_bytes_encoding.py +++ /dev/null @@ -1,10 +0,0 @@ -from superduper.components.vector_index import sqlvector - - -def test_bytes_encoding_str(): - dt = sqlvector(shape=(3,)) - - dt.intermediate_type = 'bytes' - dt.bytes_encoding = 'str' - encoded = dt.encode_data([1.1, 2.2, 3.3]) - assert isinstance(encoded, str) diff --git a/test/unittest/component/datatype/test_file.py b/test/unittest/component/datatype/test_file.py index c3367db16..659c16f55 100644 --- a/test/unittest/component/datatype/test_file.py +++ b/test/unittest/component/datatype/test_file.py @@ -31,8 +31,8 @@ def dt_file_lazy(): @pytest.mark.parametrize("datatype", datatypes) -def test_data_with_schema(datatype: DataType, random_data): - datatype_utils.check_data_with_schema(random_data, datatype) +def test_data_with_schema(db, datatype: DataType, random_data): + datatype_utils.check_data_with_schema(random_data, datatype, db=db) @pytest.mark.parametrize("datatype", datatypes) diff --git a/test/unittest/component/datatype/test_pickle.py b/test/unittest/component/datatype/test_pickle.py index c50685352..f152aa706 100644 --- a/test/unittest/component/datatype/test_pickle.py +++ b/test/unittest/component/datatype/test_pickle.py @@ -27,8 +27,8 @@ def random_data(): @pytest.mark.parametrize("datatype", datatypes) -def test_data_with_schema(datatype: DataType, random_data: pd.DataFrame): - datatype_utils.check_data_with_schema(random_data, datatype) +def test_data_with_schema(db, datatype: DataType, random_data: pd.DataFrame): + datatype_utils.check_data_with_schema(random_data, datatype, db) @pytest.mark.parametrize("datatype", datatypes) diff --git a/test/unittest/component/test_schema.py b/test/unittest/component/test_schema.py new file mode 100644 index 000000000..a4893eb48 --- /dev/null +++ b/test/unittest/component/test_schema.py @@ -0,0 +1,30 @@ +from superduper import Schema, Table +from superduper.components.datatype import pickle_encoder + + +def test_schema_with_bytes_encoding(db): + db.apply( + Table( + 'documents', + schema=Schema('_schema/documents', fields={'txt': pickle_encoder}), + ) + ) + + t = db.load('table', 'documents') + + assert t.schema.db is not None + + db.databackend.bytes_encoding = 'base64' + + db['documents'].insert([{'txt': 'testing 123'}]).execute() + + try: + r = db.databackend.db['documents'].find_one() + except Exception: + return + + print(r) + + assert isinstance(r['txt'], str) + + r = db['documents'].find_one() diff --git a/test/utils/component/datatype.py b/test/utils/component/datatype.py index 62800cdf8..779e606d3 100644 --- a/test/utils/component/datatype.py +++ b/test/utils/component/datatype.py @@ -45,10 +45,10 @@ def print_sep(): print("\n", "-" * 80, "\n") -def check_data_with_schema(data, datatype: DataType): +def check_data_with_schema(data, 
diff --git a/test/utils/component/datatype.py b/test/utils/component/datatype.py
index 62800cdf8..779e606d3 100644
--- a/test/utils/component/datatype.py
+++ b/test/utils/component/datatype.py
@@ -45,10 +45,10 @@ def print_sep():
     print("\n", "-" * 80, "\n")
 
 
-def check_data_with_schema(data, datatype: DataType):
+def check_data_with_schema(data, datatype: DataType, db):
     print("datatype", datatype)
     print_sep()
 
-    schema = Schema(identifier="schema", fields={"x": datatype, "y": int})
+    schema = Schema(identifier="schema", fields={"x": datatype, "y": int}, db=db)
     document = Document({"x": data, "y": 1})
     print(document)
 
@@ -59,7 +59,7 @@ def check_data_with_schema(data, datatype: DataType):
     print_sep()
 
     decoded = Document.decode(encoded, schema=schema)
-    if datatype.encodable_cls.lazy:
+    if datatype.encodable == 'lazy_artifact':
         assert isinstance(decoded["x"], datatype.encodable_cls)
         assert isinstance(decoded["x"].x, type(data))
         decoded = Document(decoded.unpack())
@@ -96,7 +96,7 @@ def check_data_with_schema_and_db(data, datatype: DataType, db: Datalayer):
 
     decoded = list(db["documents"].select().execute())[0]
 
-    if datatype.encodable_cls.lazy:
+    if datatype.encodable == 'lazy_artifact':
         assert isinstance(decoded["x"], datatype.encodable_cls)
         assert isinstance(decoded["x"].x, Empty)
         decoded = Document(decoded.unpack())