Skip to content

Commit

Permalink
Refactor types to use simpler design
Browse files: browse the repository at this point in the history
  • Loading branch information
blythed committed Nov 26, 2024
1 parent a392ee2 commit ac606d6
Show file tree
Hide file tree
Showing 71 changed files with 1,345 additions and 1,733 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

#### Changed defaults / behaviours

- Deprecate vanilla `DataType`
- Remove `_Encodable` from project

#### New Features & Functionality

- Streamlit component and server
Expand Down
4 changes: 2 additions & 2 deletions plugins/anthropic/superduper_anthropic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ class Anthropic(APIBaseModel):

client_kwargs: t.Dict[str, t.Any] = dc.field(default_factory=dict)

def __post_init__(self, db, artifacts, example):
def __post_init__(self, db, example):
self.model = self.model or self.identifier
super().__post_init__(db, artifacts, example=example)
super().__post_init__(db, example=example)

def init(self, db=None):
"""Initialize the model.
Expand Down
8 changes: 4 additions & 4 deletions plugins/cohere/superduper_cohere/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class Cohere(APIBaseModel):

client_kwargs: t.Dict[str, t.Any] = dc.field(default_factory=dict)

def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example=example)
def __post_init__(self, db, example):
super().__post_init__(db, example=example)
self.identifier = self.identifier or self.model


Expand All @@ -47,8 +47,8 @@ class CohereEmbed(Cohere):
batch_size: int = 100
signature: str = 'singleton'

def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example=example)
def __post_init__(self, db, example):
super().__post_init__(db, example=example)
if self.shape is None:
self.shape = self.shapes[self.identifier]

Expand Down
6 changes: 3 additions & 3 deletions plugins/ibis/superduper_ibis/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from superduper.backends.local.artifacts import FileSystemArtifactStore
from superduper.base import exceptions
from superduper.base.enums import DBType
from superduper.components.datatype import DataType
from superduper.components.datatype import BaseDataType
from superduper.components.schema import Schema
from superduper.components.table import Table

Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(self, uri: str, flavour: t.Optional[str] = None):
self.overwrite = False
self._setup(conn)

if uri.startswith('snowflake://') or uri.startswith('sqlite://'):
if uri.startswith('snowflake://'):
self.bytes_encoding = 'base64'

self.datatype_presets = {'vector': 'superduper.ext.numpy.encoder.Array'}
Expand Down Expand Up @@ -190,7 +190,7 @@ def drop_table_or_collection(self, name: str):
def create_output_dest(
self,
predict_id: str,
datatype: t.Union[FieldType, DataType],
datatype: t.Union[FieldType, BaseDataType],
flatten: bool = False,
):
"""Create a table for the output of the model.
Expand Down
4 changes: 2 additions & 2 deletions plugins/ibis/superduper_ibis/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from superduper.base.cursor import SuperDuperCursor
from superduper.base.exceptions import DatabackendException
from superduper.components.datatype import Encodable
from superduper.components.datatype import _Encodable
from superduper.components.schema import Schema
from superduper.misc.special_dicts import SuperDuperFlatEncode

Expand Down Expand Up @@ -81,7 +81,7 @@ def _model_update_impl(
d = {
"_source": str(source_id),
f"{CFG.output_prefix}{predict_id}": output.x
if isinstance(output, Encodable)
if isinstance(output, _Encodable)
else output,
"id": str(uuid.uuid4()),
}
Expand Down
8 changes: 0 additions & 8 deletions plugins/ibis/superduper_ibis/utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
from ibis.expr.datatypes import dtype
from superduper.components.datatype import (
Artifact,
BaseDataType,
File,
LazyArtifact,
LazyFile,
Native,
)
from superduper.components.schema import ID, FieldType, Schema

SPECIAL_ENCODABLES_FIELDS = {
File: "str",
LazyFile: "str",
Artifact: "str",
LazyArtifact: "str",
Native: "json",
}


Expand Down
8 changes: 4 additions & 4 deletions plugins/jina/superduper_jina/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class Jina(APIBaseModel):

api_key: t.Optional[str] = None

def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example=example)
def __post_init__(self, db, example):
super().__post_init__(db, example=example)
self.identifier = self.identifier or self.model
self.client = JinaAPIClient(model_name=self.identifier, api_key=self.api_key)

Expand All @@ -41,8 +41,8 @@ class JinaEmbedding(Jina):
shape: t.Optional[t.Sequence[int]] = None
signature: str = 'singleton'

def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example)
def __post_init__(self, db, example):
super().__post_init__(db, example)
if self.shape is None:
self.shape = (len(self.client.encode_batch(['shape'])[0]),)

Expand Down
4 changes: 2 additions & 2 deletions plugins/mongodb/superduper_mongodb/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from superduper.backends.base.data_backend import BaseDataBackend
from superduper.backends.base.metadata import MetaDataStoreProxy
from superduper.base.enums import DBType
from superduper.components.datatype import DataType
from superduper.components.datatype import BaseDataType
from superduper.components.schema import Schema
from superduper.misc.colors import Colors

Expand Down Expand Up @@ -140,7 +140,7 @@ def disconnect(self):
def create_output_dest(
self,
predict_id: str,
datatype: t.Union[str, DataType],
datatype: t.Union[str, BaseDataType],
flatten: bool = False,
):
"""Create an output collection for a component.
Expand Down
1 change: 1 addition & 0 deletions plugins/mongodb/superduper_mongodb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ def process_find_part(part):
method, args, kwargs = part
# args: (filter, projection, *args)
filter = copy.deepcopy(args[0]) if len(args) > 0 else {}
filter = dict(filter)
filter.update(self._get_filter_conditions())
args = tuple((filter, *args[1:]))

Expand Down
8 changes: 4 additions & 4 deletions plugins/openai/superduper_openai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ class _OpenAI(APIBaseModel):
openai_api_base: t.Optional[str] = None
client_kwargs: t.Optional[dict] = dc.field(default_factory=dict)

def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example)
def __post_init__(self, db, example):
super().__post_init__(db, example)

assert isinstance(self.client_kwargs, dict)

Expand Down Expand Up @@ -151,8 +151,8 @@ class OpenAIChatCompletion(_OpenAI):
batch_size: int = 1
prompt: str = ''

def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example)
def __post_init__(self, db, example):
super().__post_init__(db, example)
self.takes_context = True

def _format_prompt(self, context, X):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from superduper.backends.query_dataset import QueryDataset
from superduper.base.enums import DBType
from superduper.components.component import ensure_initialized
from superduper.components.datatype import DataType, dill_lazy
from superduper.components.datatype import dill_serializer
from superduper.components.model import Model, Signature, _DeviceManaged

DEFAULT_PREDICT_KWARGS = {
Expand Down Expand Up @@ -39,9 +39,7 @@ class SentenceTransformer(Model, _DeviceManaged):
"""

_artifacts: t.ClassVar[t.Sequence[t.Tuple[str, 'DataType']]] = (
('object', dill_lazy),
)
_fields = {'object': dill_serializer}

object: t.Optional[_SentenceTransformer] = None
model: t.Optional[str] = None
Expand All @@ -50,8 +48,8 @@ class SentenceTransformer(Model, _DeviceManaged):
postprocess: t.Union[None, t.Callable] = None
signature: Signature = 'singleton'

def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example=example)
def __post_init__(self, db, example):
super().__post_init__(db, example=example)

if self.model is None:
self.model = self.identifier
Expand Down
2 changes: 1 addition & 1 deletion plugins/sklearn/plugin_test/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_sklearn(db):
identifier='test',
object=SVC(),
)
assert 'object' in m.artifact_schema.fields
assert 'object' in m.class_schema.fields
db.apply(m, force=True)
assert db.show('model') == ['test']

Expand Down
4 changes: 2 additions & 2 deletions plugins/torch/superduper_torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ class TorchModel(Model, _DeviceManaged):
optimizer_state: t.Optional[t.Any] = None
loader_kwargs: t.Dict = dc.field(default_factory=lambda: {})

def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts=artifacts, example=example)
def __post_init__(self, db, example):
super().__post_init__(db, example=example)

if self.optimizer_state is not None:
self.optimizer.load_state_dict(self.optimizer_state)
Expand Down
12 changes: 6 additions & 6 deletions plugins/transformers/superduper_transformers/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ class TransformersTrainer(TrainingArguments, Trainer):
t.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
] = None

def __post_init__(self, db, artifacts):
def __post_init__(self, db):
assert self.output_dir == '' or self.output_dir == self.identifier
self.output_dir = self.identifier
TrainingArguments.__post_init__(self)
return Trainer.__post_init__(self, db, artifacts)
return Trainer.__post_init__(self, db)

@property
def native_arguments(self):
Expand Down Expand Up @@ -214,10 +214,10 @@ def _build_pipeline(self):
model=self.model_cls.from_pretrained(self.model_name),
)

def __post_init__(self, db, artifacts, example):
def __post_init__(self, db, example):
if self.pipeline is None:
self._build_pipeline()
super().__post_init__(db, artifacts, example)
super().__post_init__(db, example)

def predict(self, text: str):
"""Predict the class of a single text.
Expand Down Expand Up @@ -284,12 +284,12 @@ class LLM(BaseLLM):
("tokenizer_kwargs", dill_serializer),
)

def __post_init__(self, db, artifacts, example):
def __post_init__(self, db, example):
if not self.identifier:
self.identifier = self.adapter_id or self.model_name_or_path

# TODO: Kept for compatibility with the artifact sha1 equality bug; to be deleted
super().__post_init__(db, artifacts, example)
super().__post_init__(db, example)

@classmethod
def from_pretrained(
Expand Down
4 changes: 2 additions & 2 deletions plugins/transformers/superduper_transformers/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,11 @@ class LLMTrainer(TrainingArguments, SuperDuperTrainer):
num_gpus: t.Optional[int] = None
ray_configs: t.Optional[dict] = None

def __post_init__(self, db, artifacts):
def __post_init__(self, db):
self.output_dir = self.output_dir or os.path.join("output", self.identifier)
if self.num_gpus and 'num_gpus' not in self.compute_kwargs:
self.compute_kwargs['num_gpus'] = self.num_gpus
return SuperDuperTrainer.__post_init__(self, db, artifacts)
return SuperDuperTrainer.__post_init__(self, db)

def build(self):
"""Build the training arguments."""
Expand Down
2 changes: 1 addition & 1 deletion plugins/vllm/superduper_vllm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class _VLLMCore(Model):

vllm_params: dict = dc.field(default_factory=dict)

def __post_init__(self, db, artifacts, example):
def __post_init__(self, db, example):
    """Finish initialisation of the vLLM core model.

    Forwards ``db`` and ``example`` to the parent ``__post_init__``, checks
    that ``vllm_params`` contains a ``"model"`` entry, and starts with no
    async LLM handle set.

    :param db: Passed through unchanged to the parent ``__post_init__``.
    :param example: Passed through unchanged to the parent ``__post_init__``.
    """
    # ``artifacts`` is no longer a parameter of this method, so it must not
    # be forwarded: referencing it here raised ``NameError`` on construction.
    super().__post_init__(db, example)
    assert "model" in self.vllm_params, "model is required in vllm_params"
    self._async_llm = None
Expand Down
7 changes: 3 additions & 4 deletions superduper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .components.application import Application
from .components.component import Component
from .components.dataset import Dataset
from .components.datatype import DataType, dill_serializer, pickle_serializer
from .components.datatype import BaseDataType, dill_serializer, pickle_serializer
from .components.listener import Listener
from .components.metric import Metric
from .components.model import (
Expand All @@ -39,7 +39,7 @@
from .components.streamlit import Streamlit
from .components.table import Table
from .components.template import QueryTemplate, Template
from .components.vector_index import VectorIndex, vector
from .components.vector_index import VectorIndex

REQUIRES = [
'superduper=={}'.format(__version__),
Expand All @@ -52,7 +52,7 @@
'config',
'logging',
'superduper',
'DataType',
'BaseDataType',
'Document',
'code',
'ObjectModel',
Expand All @@ -62,7 +62,6 @@
'model',
'Listener',
'VectorIndex',
'vector',
'Dataset',
'Metric',
'Plugin',
Expand Down
4 changes: 2 additions & 2 deletions superduper/backends/base/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from superduper import logging
from superduper.backends.base.query import Query
from superduper.components.datatype import DataType
from superduper.components.datatype import BaseDataType

if t.TYPE_CHECKING:
from superduper.components.schema import Schema
Expand Down Expand Up @@ -75,7 +75,7 @@ def build_artifact_store(self):
def create_output_dest(
self,
predict_id: str,
datatype: t.Union[str, DataType],
datatype: t.Union[str, BaseDataType],
flatten: bool = False,
):
"""Create an output destination for the database.
Expand Down
Loading

0 comments on commit ac606d6

Please sign in to comment.