Skip to content

Commit

Permalink
Fix rest API with templates
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Jun 21, 2024
1 parent bf06094 commit c95da41
Show file tree
Hide file tree
Showing 12 changed files with 77 additions and 28 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Make output of `Document.encode()` a bit more minimalistic
- Increment minimum supported ibis version to 9.0.0
- Make database connections reconnection on token expiry
- Prototype the cron job service
- Prototype the cron job services

#### New Features & Functionality

Expand Down
2 changes: 1 addition & 1 deletion docs/content/tutorials/rag.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"id": "2ca02914-dac0-42ac-ac90-d1ebe87e6774",
"metadata": {},
"source": [
"# Basic RAG tutorial with templates"
"# Basic RAG tutorial"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion docs/content/tutorials/rag.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

# Basic RAG tutorial with templates
# Basic RAG tutorial

:::info
In this tutorial we show you how to do retrieval augmented generation (RAG) with `superduperdb`.
Expand Down
8 changes: 7 additions & 1 deletion superduperdb/backends/base/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,15 @@ def __post_init__(self, db: t.Optional['Datalayer'] = None):

def _build_hr_identifier(self):
identifier = str(self).split('\n')[-1]
identifier = re.sub(r'[^a-zA-Z0-9]', '-', identifier)
variables = re.findall(r'(<var:[a-zA-Z0-9]+>)', identifier)
variables = sorted(list(set(variables)))
for i, v in enumerate(variables):
identifier = identifier.replace(v, f'#{i}')
identifier = re.sub(r'[^a-zA-Z0-9#]', '-', identifier)
identifier = re.sub('[-]+$', '', identifier)
identifier = re.sub('[-]+', '-', identifier)
for i, v in enumerate(variables):
identifier = identifier.replace(f'#{i}', v)
return identifier

def __getitem__(self, item):
Expand Down
2 changes: 1 addition & 1 deletion superduperdb/backends/ray/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def submit(self, identifier, dependencies=(), compute_kwargs={}):
job_string += ")"

entrypoint = (
f"python -c 'from superduperdb.jobs.job import remote_job; {job_string}'"
f"python3 -c 'from superduperdb.jobs.job import remote_job; {job_string}'"
)

runtime_env = {}
Expand Down
2 changes: 1 addition & 1 deletion superduperdb/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def variables(self) -> t.List[str]:
"""Return a list of variables in the object."""
from superduperdb.base.variables import _find_variables

return _find_variables(self)
return sorted(list(set(_find_variables(self))))

def set_variables(self, **kwargs) -> 'Document':
"""Set free variables of self.
Expand Down
5 changes: 4 additions & 1 deletion superduperdb/base/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def _replace_variables(x, **kwargs):
_replace_variables(k, **kwargs): _replace_variables(v, **kwargs)
for k, v in x.items()
}
if isinstance(x, str) and re.match(r'^<var:(.*?)>$', x) is not None:
return kwargs.get(x[5:-1], x)
if isinstance(x, str):
variables = re.findall(r'<var:(.*?)>', x)
variables = list(map(lambda v: v.strip(), variables))
Expand All @@ -31,7 +33,8 @@ def _replace_variables(x, **kwargs):
if isinstance(v, str):
x = x.replace(f'<var:{k}>', v)
else:
x = v
x = re.sub('[<>:]', '-', x)
x = re.sub('[-]+', '-', x)
return x
if isinstance(x, (list, tuple)):
return [_replace_variables(v, **kwargs) for v in x]
Expand Down
25 changes: 25 additions & 0 deletions superduperdb/components/application.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import typing as t

from superduperdb.base.datalayer import Datalayer

from .component import Component


Expand All @@ -8,8 +10,31 @@ class Application(Component):
A placeholder to hold list of components with associated funcionality.
:param components: List of components to group together and apply to `superduperdb`.
:param namespace: List of tuples with type_id and identifier of components to
assist in managing application.
"""

literals: t.ClassVar[t.Sequence[str]] = ('template',)
type_id: t.ClassVar[str] = 'application'
components: t.Sequence[Component]
namespace: t.Optional[t.Sequence[t.Tuple[str, str]]] = None

def pre_create(self, db: Datalayer):
"""Pre-create hook.
:param db: Datalayer instance
"""
self.namespace = [
{'type_id': c.type_id, 'identifier': c.identifier} for c in self.children
]
return super().pre_create(db)

def cleanup(self, db: Datalayer):
"""Cleanup hook.
:param db: Datalayer instance
"""
if self.namespace is not None:
for type_id, identifier in self.namespace:
db.remove(type_id=type_id, identifier=identifier, force=True)
return super().cleanup(db)
15 changes: 10 additions & 5 deletions superduperdb/components/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,26 @@ def export(
Save `self` to a directory using super-duper protocol.
:param path: Path to the directory to save the component.
:param format: Format to save the component in.
:param zip: Whether to zip the directory.
:param defaults: Whether to save default values.
:param metadata: Whether to save metadata.
Created directory structure:
```
|_component.json/yaml
|_component.(json|yaml)
|_blobs/*
|_files/*
```
"""
assert self.db is not None
assert self.identifier in self.db.show('template')
if self.blobs is not None and self.blobs:
assert self.db is not None
assert self.identifier in self.db.show('template')
if path is None:
path = './' + self.identifier
super().export(path, format, zip=False, defaults=defaults, metadata=metadata)
os.makedirs(os.path.join(path, 'blobs'), exist_ok=True)
if self.blobs is not None:
if self.blobs is not None and self.blobs:
os.makedirs(os.path.join(path, 'blobs'), exist_ok=True)
for identifier in self.blobs:
blob = self.db.artifact_store.get_bytes(identifier)
with open(path + f'/blobs/{identifier}', 'wb') as f:
Expand Down
3 changes: 1 addition & 2 deletions superduperdb/ext/llm/prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ class RetrievalPrompt(QueryModel):
join: str = "\n---\n"

def __post_init__(self, db, artifacts):
assert len(self.select.variables) == 1
assert next(iter(self.select.variables)) == 'prompt'
assert 'prompt' in self.select.variables
return super().__post_init__(db, artifacts)

@property
Expand Down
2 changes: 1 addition & 1 deletion superduperdb/misc/special_dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def merge(self, d, inplace=False):
@property
def variables(self):
"""List of variables in the object."""
return _find_variables(self)
return sorted(list(set(_find_variables(self))))

def info(self):
"""Print the serialized object."""
Expand Down
37 changes: 24 additions & 13 deletions superduperdb/rest/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ def db_artifact_store_get_bytes(file_id: str):

@app.add('/db/apply', method='post')
def db_apply(info: t.Dict):
if 'variables' in info:
assert {'variables', 'template_body', 'identifier'}.issubset(info.keys())
component = Component.from_template(
identifier=info['identifier'],
template_body=info['template_body'],
**info['variables'],
)
app.db.apply(component)
for k in info['variables']:
assert '<' not in info['variables'][k]
assert '>' not in info['variables'][k]
assert ' ' not in info['variables'][k]
return {'status': 'ok'}
component = Document.decode(info).unpack()
app.db.apply(component)
return {'status': 'ok'}
Expand All @@ -74,19 +87,17 @@ def db_show(
type_id: t.Optional[str] = None,
identifier: t.Optional[str] = None,
version: t.Optional[int] = None,
application: t.Optional[str] = None,
):
return app.db.show(
type_id=type_id,
identifier=identifier,
version=version,
)

@app.add('/db/apply_template', method='post')
def db_apply_template(info: t.Dict):
assert {'variables', 'template_body', 'identifier'}.issubset(info.keys())
component = Component.from_template(**info)
app.db.apply(component)
return {'status': 'ok'}
if application is not None:
r = app.db.metadata.get_component('application', application)
return r['namespace']
else:
return app.db.show(
type_id=type_id,
identifier=identifier,
version=version,
)

@app.add('/db/remove', method='post')
def db_remove(type_id: str, identifier: str):
Expand All @@ -101,7 +112,7 @@ def db_show_template(identifier: str):
return {
'identifier': '<Please enter a unique name for this application>',
'variables': {k: '<Please enter a value>' for k in template['variables']},
'template_body': template['template_body'],
'template_body': template['template'],
}

@app.add('/db/metadata/show_jobs', method='get')
Expand Down

0 comments on commit c95da41

Please sign in to comment.