Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for _pre_create #2517

Merged
merged 1 commit into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Updated CONTRIBUTING.md
- Add README.md files for the plugins.
- Add templates to project
- Add frontend to project
- Add frontend to project

#### New Features & Functionality

Expand Down
15 changes: 2 additions & 13 deletions plugins/anthropic/superduper_anthropic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import anthropic
from anthropic import APIConnectionError, APIError, APIStatusError, APITimeoutError
from superduper.backends.query_dataset import QueryDataset
from superduper.base.datalayer import Datalayer
from superduper.components.model import APIBaseModel
from superduper.ext.utils import format_prompt, get_key
from superduper.misc.retry import Retry
Expand All @@ -24,9 +23,9 @@ class Anthropic(APIBaseModel):

client_kwargs: t.Dict[str, t.Any] = dc.field(default_factory=dict)

def __post_init__(self, db, artifacts):
def __post_init__(self, db, artifacts, example):
self.model = self.model or self.identifier
super().__post_init__(db, artifacts)
super().__post_init__(db, artifacts, example=example)
self.client = anthropic.Anthropic(
api_key=get_key(KEY_NAME), **self.client_kwargs
)
Expand All @@ -51,16 +50,6 @@ class AnthropicCompletions(Anthropic):

prompt: str = ''

def pre_create(self, db: Datalayer) -> None:
"""Pre create method for the model.

If the datalayer is Ibis, the datatype will be set to the appropriate
SQL datatype.

:param db: The datalayer to use for the model.
"""
super().pre_create(db)

@retry
def predict(
self,
Expand Down
22 changes: 5 additions & 17 deletions plugins/cohere/superduper_cohere/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import tqdm
from cohere.error import CohereAPIError, CohereConnectionError
from superduper.backends.query_dataset import QueryDataset
from superduper.base.datalayer import Datalayer
from superduper.components.model import APIBaseModel
from superduper.components.vector_index import vector
from superduper.ext.utils import format_prompt, get_key
Expand All @@ -24,8 +23,8 @@ class Cohere(APIBaseModel):

client_kwargs: t.Dict[str, t.Any] = dc.field(default_factory=dict)

def __post_init__(self, db, artifacts):
super().__post_init__(db, artifacts)
def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example=example)
self.identifier = self.identifier or self.model


Expand All @@ -48,20 +47,19 @@ class CohereEmbed(Cohere):
batch_size: int = 100
signature: str = 'singleton'

def __post_init__(self, db, artifacts):
super().__post_init__(db, artifacts)
def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example=example)
if self.shape is None:
self.shape = self.shapes[self.identifier]

def pre_create(self, db):
def _pre_create(self, db):
"""Pre create method for the model.

If the datalayer is Ibis, the datatype will be set to the appropriate
SQL datatype.

:param db: The datalayer to use for the model.
"""
super().pre_create(db)
if self.datatype is None:
self.datatype = vector(shape=self.shape)

Expand Down Expand Up @@ -114,16 +112,6 @@ class CohereGenerate(Cohere):
takes_context: bool = True
prompt: str = ''

def pre_create(self, db: Datalayer) -> None:
"""Pre create method for the model.

If the datalayer is Ibis, the datatype will be set to the appropriate
SQL datatype.

:param db: The datalayer to use for the model.
"""
super().pre_create(db)

@retry
def predict(self, prompt: str, context: t.Optional[t.List[str]] = None):
"""Predict the generation of a single prompt.
Expand Down
11 changes: 5 additions & 6 deletions plugins/jina/superduper_jina/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class Jina(APIBaseModel):

api_key: t.Optional[str] = None

def __post_init__(self, db, artifacts):
super().__post_init__(db, artifacts)
def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example=example)
self.identifier = self.identifier or self.model
self.client = JinaAPIClient(model_name=self.identifier, api_key=self.api_key)

Expand All @@ -41,20 +41,19 @@ class JinaEmbedding(Jina):
shape: t.Optional[t.Sequence[int]] = None
signature: str = 'singleton'

def __post_init__(self, db, artifacts):
super().__post_init__(db, artifacts)
def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example)
if self.shape is None:
self.shape = (len(self.client.encode_batch(['shape'])[0]),)

def pre_create(self, db):
def _pre_create(self, db):
"""Pre create method for the model.

If the datalayer is Ibis, the datatype will be set to the appropriate
SQL datatype.

:param db: The datalayer to use for the model.
"""
super().pre_create(db)
if self.datatype is None:
self.datatype = vector(shape=self.shape)

Expand Down
18 changes: 6 additions & 12 deletions plugins/openai/superduper_openai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,13 @@ def __post_init__(self, db, artifacts, example):
if self.shape is None:
self.shape = self.shapes[self.model]

def pre_create(self, db: Datalayer, startup_cache={}) -> None:
def _pre_create(self, db: Datalayer) -> None:
"""Pre creates the model.

the datatype is set to ``vector``.

:param db: The datalayer instance.
"""
super().pre_create(db, startup_cache=startup_cache)
self.datatype = self.datatype or vector(shape=self.shape)

@retry
Expand Down Expand Up @@ -177,12 +176,11 @@ def _format_prompt(self, context, X):
prompt = self.prompt.format(context='\n'.join(context))
return prompt + X

def pre_create(self, db: Datalayer, startup_cache={}) -> None:
def _pre_create(self, db: Datalayer) -> None:
"""Pre creates the model.

:param db: The datalayer instance.
"""
super().pre_create(db, startup_cache=startup_cache)
self.datatype = self.datatype or 'str'

@retry
Expand Down Expand Up @@ -246,12 +244,11 @@ class OpenAIImageCreation(_OpenAI):
n: int = 1
response_format: str = 'b64_json'

def pre_create(self, db: Datalayer, startup_cache={}) -> None:
def _pre_create(self, db: Datalayer):
"""Pre creates the model.

:param db: The datalayer instance.
"""
super().pre_create(db, startup_cache=startup_cache)
self.datatype = self.datatype or 'bytes'

def _format_prompt(self, context, X):
Expand Down Expand Up @@ -332,12 +329,11 @@ def _format_prompt(self, context):
prompt = self.prompt.format(context='\n'.join(context))
return prompt

def pre_create(self, db: Datalayer, startup_cache={}) -> None:
def _pre_create(self, db: Datalayer):
"""Pre creates the model.

:param db: The datalayer instance.
"""
super().pre_create(db, startup_cache=startup_cache)
self.datatype = self.datatype or 'bytes'

@retry
Expand Down Expand Up @@ -428,12 +424,11 @@ class OpenAIAudioTranscription(_OpenAI):
takes_context: bool = True
prompt: str = ''

def pre_create(self, db: Datalayer, startup_cache={}) -> None:
def _pre_create(self, db: Datalayer):
"""Pre creates the model.

:param db: The datalayer instance.
"""
super().pre_create(db, startup_cache=startup_cache)
self.datatype = self.datatype or 'str'

@retry
Expand Down Expand Up @@ -496,12 +491,11 @@ class OpenAIAudioTranslation(_OpenAI):
prompt: str = ''
batch_size: int = 1

def pre_create(self, db: Datalayer, startup_cache={}) -> None:
def _pre_create(self, db: Datalayer):
"""Translates a file-like Audio recording to English.

:param db: The datalayer to use for the model.
"""
super().pre_create(db, startup_cache=startup_cache)
self.datatype = self.datatype or 'str'

@retry
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ class SentenceTransformer(Model, _DeviceManaged):
postprocess: t.Union[None, t.Callable] = None
signature: Signature = 'singleton'

def __post_init__(self, db, artifacts):
super().__post_init__(db, artifacts)
def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example=example)

if self.model is None:
self.model = self.identifier
Expand Down Expand Up @@ -118,15 +118,14 @@ def predict_batches(self, dataset: t.Union[t.List, QueryDataset]) -> t.List:
results = self.postprocess(results)
return results

def pre_create(self, db):
def _pre_create(self, db):
"""Pre creates the model.

If the datatype is not set and the datalayer is an IbisDataBackend,
the datatype is set to ``sqlvector`` or ``vector``.

:param db: The datalayer instance.
"""
super().pre_create(db)
if self.datatype is not None:
return

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ ignore = [
"D212", # Multi-line docstring summary should start at the first line
"D213", # Multi-line docstring summary should start at the second line
"D401",
"D102",
"E402",
]

Expand Down
33 changes: 33 additions & 0 deletions superduper/components/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import requests
import tqdm
from overrides import override

from superduper import CFG, logging
from superduper.backends.base.query import Query
Expand Down Expand Up @@ -1260,3 +1261,35 @@ def predict_batches(self, dataset: t.Union[t.List, QueryDataset]) -> t.List:
else:
out = p.predict_batches(out)
return out


class ModelRouter(Model):
    """Model that delegates its work to one entry of ``models``.

    The entry is selected with the ``model`` key. Prediction calls are
    forwarded to the selected model, and its ``datatype``, ``example``,
    ``signature`` and (when present) ``shape`` are copied onto the router.

    :param models: A dictionary of models to use
    :param model: The model to use
    """

    models: t.Dict[str, Model]
    model: str

    def _pre_create(self, db):
        """Copy the datatype of the selected model onto the router.

        :param db: The datalayer instance (not used here).
        """
        self.datatype = self.models[self.model].datatype

    @override
    def predict(self, *args, **kwargs) -> t.Any:
        """Forward a single prediction to the selected model.

        :param args: Positional arguments passed through unchanged.
        :param kwargs: Keyword arguments passed through unchanged.
        """
        logging.info(f'Predicting with model {self.model}')
        return self.models[self.model].predict(*args, **kwargs)

    @override
    def predict_batches(self, dataset) -> t.List:
        """Forward a batch prediction to the selected model.

        :param dataset: The dataset passed through unchanged.
        """
        logging.info(f'Predicting with model {self.model}')
        return self.models[self.model].predict_batches(dataset)

    @override
    def init(self, db=None):
        """Mirror the selected model's attributes and initialize it.

        ``db`` is optional so that the router can be initialized the same way
        it initializes the selected model, i.e. with a bare ``init()`` call;
        this also allows a router to be nested inside another router.

        :param db: The datalayer instance (not used here).
        """
        selected = self.models[self.model]
        # Not every model defines ``shape``; copy it only when available.
        if hasattr(selected, 'shape'):
            self.shape = selected.shape
        self.example = selected.example
        self.signature = selected.signature
        selected.init()
24 changes: 21 additions & 3 deletions superduper/components/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import os
import typing as t

from superduper import logging
from superduper.base.constant import KEY_BLOBS, KEY_FILES
from superduper.base.datalayer import Datalayer
from superduper.base.document import Document, QueryUpdateDocument
from superduper.base.leaf import Leaf
from superduper.base.variables import _replace_variables
from superduper.components.component import Component, _build_info_from_path
from superduper.components.datatype import pickle_serializer
from superduper.misc.special_dicts import SuperDuperFlatEncode

from .component import ensure_initialized
Expand All @@ -19,7 +21,7 @@ class _BaseTemplate(Component):

:param template: Template component with variables.
:param template_variables: Variables to be set.
:param info: Additional information.
:param types: Additional information about types of variables.
:param blobs: Blob identifiers in `Template.component`.
:param files: File identifiers in `Template.component`.
:param substitutions: Substitutions to be made to create variables.
Expand All @@ -30,7 +32,7 @@ class _BaseTemplate(Component):

template: t.Union[t.Dict, Component]
template_variables: t.Optional[t.List[str]] = None
info: t.Optional[t.Dict] = dc.field(default_factory=dict)
types: t.Optional[t.Dict] = dc.field(default_factory=dict)
blobs: t.Optional[t.List[str]] = None
files: t.Optional[t.List[str]] = None
substitutions: dc.InitVar[t.Optional[t.Dict]] = None
Expand Down Expand Up @@ -75,17 +77,33 @@ def form_template(self):


class Template(_BaseTemplate):
"""Application template component."""
"""Application template component.

:param data: Sample data to test the template.
"""

_artifacts: t.ClassVar[t.Tuple[str]] = (('data', pickle_serializer),)

type_id: t.ClassVar[str] = "template"

data: t.List[t.Dict] | None = None

def pre_create(self, db: Datalayer) -> None:
    """Prepare the template before it is created.

    Records the blob and file identifiers found in the template, hands the
    template to the artifact store and, when sample data is attached, inserts
    it into the default table unless auto schema is disabled.

    :param db: The datalayer instance.
    """
    super().pre_create(db)
    assert isinstance(self.template, dict)
    template = self.template
    self.blobs = list(template.get(KEY_BLOBS, {}).keys())
    self.files = list(template.get(KEY_FILES, {}).keys())
    db.artifact_store.save_artifact(template)

    # No sample data attached: nothing to insert.
    if self.data is None:
        return

    if db.cfg.auto_schema:
        db[self.default_table].insert(self.data).execute()
    else:
        logging.warn('Auto schema is disabled. Skipping data insertion.')

@property
def default_table(self):
    """Return the name of the table that receives the sample data."""
    table_name = f'_sample_{self.identifier}'
    return table_name

def export(
self,
Expand Down
Loading
Loading