Skip to content

Commit

Permalink
Fixes for _pre_create
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Oct 10, 2024
1 parent 733e112 commit 796f95f
Show file tree
Hide file tree
Showing 17 changed files with 886 additions and 1,090 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Updated CONTRIBUTING.md
- Add README.md files for the plugins.
- Add templates to project
- Add frontend to project
- Add frontend to project

#### New Features & Functionality

Expand Down
15 changes: 2 additions & 13 deletions plugins/anthropic/superduper_anthropic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import anthropic
from anthropic import APIConnectionError, APIError, APIStatusError, APITimeoutError
from superduper.backends.query_dataset import QueryDataset
from superduper.base.datalayer import Datalayer
from superduper.components.model import APIBaseModel
from superduper.ext.utils import format_prompt, get_key
from superduper.misc.retry import Retry
Expand All @@ -24,9 +23,9 @@ class Anthropic(APIBaseModel):

client_kwargs: t.Dict[str, t.Any] = dc.field(default_factory=dict)

def __post_init__(self, db, artifacts):
def __post_init__(self, db, artifacts, example):
self.model = self.model or self.identifier
super().__post_init__(db, artifacts)
super().__post_init__(db, artifacts, example=example)
self.client = anthropic.Anthropic(
api_key=get_key(KEY_NAME), **self.client_kwargs
)
Expand All @@ -51,16 +50,6 @@ class AnthropicCompletions(Anthropic):

prompt: str = ''

def pre_create(self, db: Datalayer) -> None:
"""Pre create method for the model.
If the datalayer is Ibis, the datatype will be set to the appropriate
SQL datatype.
:param db: The datalayer to use for the model.
"""
super().pre_create(db)

@retry
def predict(
self,
Expand Down
22 changes: 5 additions & 17 deletions plugins/cohere/superduper_cohere/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import tqdm
from cohere.error import CohereAPIError, CohereConnectionError
from superduper.backends.query_dataset import QueryDataset
from superduper.base.datalayer import Datalayer
from superduper.components.model import APIBaseModel
from superduper.components.vector_index import vector
from superduper.ext.utils import format_prompt, get_key
Expand All @@ -24,8 +23,8 @@ class Cohere(APIBaseModel):

client_kwargs: t.Dict[str, t.Any] = dc.field(default_factory=dict)

def __post_init__(self, db, artifacts):
super().__post_init__(db, artifacts)
def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example=example)
self.identifier = self.identifier or self.model


Expand All @@ -48,20 +47,19 @@ class CohereEmbed(Cohere):
batch_size: int = 100
signature: str = 'singleton'

def __post_init__(self, db, artifacts):
super().__post_init__(db, artifacts)
def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example=example)
if self.shape is None:
self.shape = self.shapes[self.identifier]

def pre_create(self, db):
def _pre_create(self, db):
"""Pre create method for the model.
If the datalayer is Ibis, the datatype will be set to the appropriate
SQL datatype.
:param db: The datalayer to use for the model.
"""
super().pre_create(db)
if self.datatype is None:
self.datatype = vector(shape=self.shape)

Expand Down Expand Up @@ -114,16 +112,6 @@ class CohereGenerate(Cohere):
takes_context: bool = True
prompt: str = ''

def pre_create(self, db: Datalayer) -> None:
"""Pre create method for the model.
If the datalayer is Ibis, the datatype will be set to the appropriate
SQL datatype.
:param db: The datalayer to use for the model.
"""
super().pre_create(db)

@retry
def predict(self, prompt: str, context: t.Optional[t.List[str]] = None):
"""Predict the generation of a single prompt.
Expand Down
11 changes: 5 additions & 6 deletions plugins/jina/superduper_jina/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class Jina(APIBaseModel):

api_key: t.Optional[str] = None

def __post_init__(self, db, artifacts):
super().__post_init__(db, artifacts)
def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example=example)
self.identifier = self.identifier or self.model
self.client = JinaAPIClient(model_name=self.identifier, api_key=self.api_key)

Expand All @@ -41,20 +41,19 @@ class JinaEmbedding(Jina):
shape: t.Optional[t.Sequence[int]] = None
signature: str = 'singleton'

def __post_init__(self, db, artifacts):
super().__post_init__(db, artifacts)
def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example)
if self.shape is None:
self.shape = (len(self.client.encode_batch(['shape'])[0]),)

def pre_create(self, db):
def _pre_create(self, db):
"""Pre create method for the model.
If the datalayer is Ibis, the datatype will be set to the appropriate
SQL datatype.
:param db: The datalayer to use for the model.
"""
super().pre_create(db)
if self.datatype is None:
self.datatype = vector(shape=self.shape)

Expand Down
18 changes: 6 additions & 12 deletions plugins/openai/superduper_openai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,13 @@ def __post_init__(self, db, artifacts, example):
if self.shape is None:
self.shape = self.shapes[self.model]

def pre_create(self, db: Datalayer, startup_cache={}) -> None:
def _pre_create(self, db: Datalayer) -> None:
"""Pre creates the model.
the datatype is set to ``vector``.
:param db: The datalayer instance.
"""
super().pre_create(db, startup_cache=startup_cache)
self.datatype = self.datatype or vector(shape=self.shape)

@retry
Expand Down Expand Up @@ -177,12 +176,11 @@ def _format_prompt(self, context, X):
prompt = self.prompt.format(context='\n'.join(context))
return prompt + X

def pre_create(self, db: Datalayer, startup_cache={}) -> None:
def _pre_create(self, db: Datalayer) -> None:
"""Pre creates the model.
:param db: The datalayer instance.
"""
super().pre_create(db, startup_cache=startup_cache)
self.datatype = self.datatype or 'str'

@retry
Expand Down Expand Up @@ -246,12 +244,11 @@ class OpenAIImageCreation(_OpenAI):
n: int = 1
response_format: str = 'b64_json'

def pre_create(self, db: Datalayer, startup_cache={}) -> None:
def _pre_create(self, db: Datalayer):
"""Pre creates the model.
:param db: The datalayer instance.
"""
super().pre_create(db, startup_cache=startup_cache)
self.datatype = self.datatype or 'bytes'

def _format_prompt(self, context, X):
Expand Down Expand Up @@ -332,12 +329,11 @@ def _format_prompt(self, context):
prompt = self.prompt.format(context='\n'.join(context))
return prompt

def pre_create(self, db: Datalayer, startup_cache={}) -> None:
def _pre_create(self, db: Datalayer):
"""Pre creates the model.
:param db: The datalayer instance.
"""
super().pre_create(db, startup_cache=startup_cache)
self.datatype = self.datatype or 'bytes'

@retry
Expand Down Expand Up @@ -428,12 +424,11 @@ class OpenAIAudioTranscription(_OpenAI):
takes_context: bool = True
prompt: str = ''

def pre_create(self, db: Datalayer, startup_cache={}) -> None:
def _pre_create(self, db: Datalayer):
"""Pre creates the model.
:param db: The datalayer instance.
"""
super().pre_create(db, startup_cache=startup_cache)
self.datatype = self.datatype or 'str'

@retry
Expand Down Expand Up @@ -496,12 +491,11 @@ class OpenAIAudioTranslation(_OpenAI):
prompt: str = ''
batch_size: int = 1

def pre_create(self, db: Datalayer, startup_cache={}) -> None:
def _pre_create(self, db: Datalayer):
"""Translates a file-like Audio recording to English.
:param db: The datalayer to use for the model.
"""
super().pre_create(db, startup_cache=startup_cache)
self.datatype = self.datatype or 'str'

@retry
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ class SentenceTransformer(Model, _DeviceManaged):
postprocess: t.Union[None, t.Callable] = None
signature: Signature = 'singleton'

def __post_init__(self, db, artifacts):
super().__post_init__(db, artifacts)
def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example=example)

if self.model is None:
self.model = self.identifier
Expand Down Expand Up @@ -118,15 +118,14 @@ def predict_batches(self, dataset: t.Union[t.List, QueryDataset]) -> t.List:
results = self.postprocess(results)
return results

def pre_create(self, db):
def _pre_create(self, db):
"""Pre creates the model.
If the datatype is not set and the datalayer is an IbisDataBackend,
the datatype is set to ``sqlvector`` or ``vector``.
:param db: The datalayer instance.
"""
super().pre_create(db)
if self.datatype is not None:
return

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ ignore = [
"D212", # Multi-line docstring summary should start at the first line
"D213", # Multi-line docstring summary should start at the second line
"D401",
"D102",
"E402",
]

Expand Down
33 changes: 33 additions & 0 deletions superduper/components/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import requests
import tqdm
from overrides import override

from superduper import CFG, logging
from superduper.backends.base.query import Query
Expand Down Expand Up @@ -1260,3 +1261,35 @@ def predict_batches(self, dataset: t.Union[t.List, QueryDataset]) -> t.List:
else:
out = p.predict_batches(out)
return out


class ModelRouter(Model):
"""ModelRouter component which routes the model to the correct model.
:param models: A dictionary of models to use
:param model: The model to use
"""

models: t.Dict[str, Model]
model: str

def _pre_create(self, db):
self.datatype = self.models[self.model].datatype

@override
def predict(self, *args, **kwargs) -> t.Any:
logging.info(f'Predicting with model {self.model}')
return self.models[self.model].predict(*args, **kwargs)

@override
def predict_batches(self, dataset) -> t.List:
logging.info(f'Predicting with model {self.model}')
return self.models[self.model].predict_batches(dataset)

@override
def init(self, db):
if hasattr(self.models[self.model], 'shape'):
self.shape = getattr(self.models[self.model], 'shape')
self.example = self.models[self.model].example
self.signature = self.models[self.model].signature
self.models[self.model].init()
24 changes: 21 additions & 3 deletions superduper/components/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import os
import typing as t

from superduper import logging
from superduper.base.constant import KEY_BLOBS, KEY_FILES
from superduper.base.datalayer import Datalayer
from superduper.base.document import Document, QueryUpdateDocument
from superduper.base.leaf import Leaf
from superduper.base.variables import _replace_variables
from superduper.components.component import Component, _build_info_from_path
from superduper.components.datatype import pickle_serializer
from superduper.misc.special_dicts import SuperDuperFlatEncode

from .component import ensure_initialized
Expand All @@ -19,7 +21,7 @@ class _BaseTemplate(Component):
:param template: Template component with variables.
:param template_variables: Variables to be set.
:param info: Additional information.
:param types: Additional information about types of variables.
:param blobs: Blob identifiers in `Template.component`.
:param files: File identifiers in `Template.component`.
:param substitutions: Substitutions to be made to create variables.
Expand All @@ -30,7 +32,7 @@ class _BaseTemplate(Component):

template: t.Union[t.Dict, Component]
template_variables: t.Optional[t.List[str]] = None
info: t.Optional[t.Dict] = dc.field(default_factory=dict)
types: t.Optional[t.Dict] = dc.field(default_factory=dict)
blobs: t.Optional[t.List[str]] = None
files: t.Optional[t.List[str]] = None
substitutions: dc.InitVar[t.Optional[t.Dict]] = None
Expand Down Expand Up @@ -75,17 +77,33 @@ def form_template(self):


class Template(_BaseTemplate):
"""Application template component."""
"""Application template component.
:param data: Sample data to test the template.
"""

_artifacts: t.ClassVar[t.Tuple[str]] = (('data', pickle_serializer),)

type_id: t.ClassVar[str] = "template"

data: t.List[t.Dict] | None = None

def pre_create(self, db: Datalayer) -> None:
"""Run before the object is created."""
super().pre_create(db)
assert isinstance(self.template, dict)
self.blobs = list(self.template.get(KEY_BLOBS, {}).keys())
self.files = list(self.template.get(KEY_FILES, {}).keys())
db.artifact_store.save_artifact(self.template)
if self.data is not None:
if not db.cfg.auto_schema:
logging.warn('Auto schema is disabled. Skipping data insertion.')
return
db[self.default_table].insert(self.data).execute()

@property
def default_table(self):
return f'_sample_{self.identifier}'

def export(
self,
Expand Down
Loading

0 comments on commit 796f95f

Please sign in to comment.