Skip to content

Commit

Permalink
Optimized the Eager mode.
Browse files Browse the repository at this point in the history
- Support eager mode condition filter

- Add eager mode tutorial

- Use select.execute instead of select.datas to initiate eager mode.
  • Loading branch information
jieguangzhou committed Jul 25, 2024
1 parent d176929 commit 78f4ede
Show file tree
Hide file tree
Showing 10 changed files with 1,187 additions and 137 deletions.
543 changes: 543 additions & 0 deletions docs/content/tutorials/eager_mode.ipynb

Large diffs are not rendered by default.

334 changes: 334 additions & 0 deletions docs/content/tutorials/eager_mode.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/sidebars.js
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ const sidebars = {
'tutorials/rag',
'tutorials/training',
'tutorials/custom_serialization',
'tutorials/eager_mode',
]
},
{
Expand Down
49 changes: 26 additions & 23 deletions superduper/backends/base/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,27 +176,6 @@ def __getitem__(self, item):
parts = self.parts[item]
return type(self)(db=self.db, table=self.table, parts=parts)

def datas(self, n: int = 100, unpack=True):
    """Materialise up to ``n`` rows of this query as ``SuperDuperData`` nodes.

    :param n: Maximum number of rows to fetch (applied via ``limit``).
    :param unpack: If ``True``, unpack each row into a plain ``Document``
        before wrapping it.
    """
    from superduper.misc.eager import SuperDuperData, SuperDuperDataType

    assert self.db is not None, 'No datalayer (db) provided'

    # A query with no recorded parts is normalised to an explicit select
    # so the wrapped rows reference a runnable query.
    query = self if len(self.parts) else self.select()

    wrapped = []
    for row in query.limit(n).execute():
        payload = Document(row.unpack()) if unpack else row
        wrapped.append(
            SuperDuperData(payload, type=SuperDuperDataType.DATA, query=query)
        )
    return wrapped

# TODO - not necessary: either `Document.decode(r, db=db)`
# or `db['table'].select...`
def set_db(self, db: 'Datalayer'):
Expand Down Expand Up @@ -492,14 +471,38 @@ def _execute(self, parent, method='encode'):
def _create_table_if_not_exists(self):
    """Intentional no-op hook for table creation.

    NOTE(review): presumably backends that must create tables eagerly
    override this — confirm against the concrete query subclasses.
    """
    pass

def execute(self, db=None, eager_mode=False, **kwargs):
    """
    Execute the query.

    :param db: Datalayer instance to run against; falls back to ``self.db``.
    :param eager_mode: If ``True`` and this is a select query, wrap each
        result via ``_convert_eager_mode_results`` so downstream model
        calls can be tracked eagerly.
    :param kwargs: Extra arguments forwarded to ``db.execute``.
    """
    # Diff artifact removed: the pre-change signature/return lines were
    # interleaved here, leaving a duplicated ``def`` and ``return``.
    self.db = db or self.db
    results = self.db.execute(self, **kwargs)
    if eager_mode and self.type == 'select':
        results = self._convert_eager_mode_results(results)
    return results

def _convert_eager_mode_results(self, results):
    """Wrap raw execution results in ``SuperDuperData`` nodes for eager mode.

    :param results: A cursor/list of rows, or a single result document.
    :raises ValueError: If ``results`` is of an unsupported type.
    """
    from superduper.base.cursor import SuperDuperCursor
    from superduper.misc.eager import SuperDuperData, SuperDuperDataType

    # Normalise an empty-parts query to an explicit select so the wrapped
    # data points at a runnable query.
    query = self if len(self.parts) else self.select()

    if isinstance(results, (SuperDuperCursor, list)):
        return [
            SuperDuperData(
                Document(row.unpack()),
                type=SuperDuperDataType.DATA,
                query=query,
            )
            for row in results
        ]

    if isinstance(results, dict):
        return SuperDuperData(results, type=SuperDuperDataType.DATA, query=query)

    raise ValueError(f'Cannot convert {results} to eager mode results')

def do_execute(self, db=None):
"""
Expand Down
114 changes: 93 additions & 21 deletions superduper/backends/mongodb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,29 +544,45 @@ def model_update(
output_query.updated_key = predict_id
return output_query

def _replace_part(self, part_name, replace_function):
parts = copy.deepcopy(self.parts)

for i, part in enumerate(parts):
if part[0] == part_name:
parts[i] = replace_function(part)
break

return type(self)(
db=self.db,
table=self.table,
parts=parts,
)

def filter(self, filter):
    """Return a query that filters the documents.

    :param filter: Mapping of conditions merged into the ``find`` part's
        filter document.
    """

    def _merge_filter(part):
        name, args, kwargs = part
        # A 'find' recorded with no arguments gets an empty filter document.
        params = args if args else ({},)
        assert isinstance(params[0], dict)
        params[0].update(filter)
        return (name, tuple(params), kwargs)

    return self._replace_part('find', _merge_filter)


class MongoOutputs(MongoQuery):
"""A query class for MongoDB outputs.
Use aggregate to implement the outputs query
"""

@applies_to('find')
def outputs(self, *predict_ids):
    """Return a query that selects the outputs of the given predict ids.

    :param predict_ids: The ids of the predictions to select.
    """
    # Copy the existing parts and append an 'outputs' part carrying the ids.
    new_parts = copy.deepcopy(self.parts)
    new_parts.append(('outputs', tuple(predict_ids), {}))
    return MongoOutputs(db=self.db, table=self.table, parts=new_parts)

def _get_method_parameters(self, method):
args, kwargs = (), {}
for part in self.parts:
Expand All @@ -579,7 +595,6 @@ def _get_method_parameters(self, method):

def _execute(self, parent, method='encode'):
find_params, _ = self._get_method_parameters('find')
filter = {"$match": find_params[0]} if find_params else {}
project = copy.deepcopy(find_params[1]) if len(find_params) > 1 else {"*": 1}
project['_schema'] = 1
project['_builds'] = 1
Expand All @@ -589,9 +604,20 @@ def _execute(self, parent, method='encode'):
limit_args, _ = self._get_method_parameters('limit')
limit = {"$limit": limit_args[0]} if limit_args else None

predict_ids = self.parts[-1][1]
outputs_parts = [p for p in self.parts if p[0] == 'outputs']
predict_ids = sum([p[1] for p in outputs_parts], ())

pipeline = []
filter_mapping_base, filter_mapping_outptus = self._get_filter_mapping()
if filter_mapping_base:
pipeline.append({"$match": filter_mapping_base})
project.update({k: 1 for k in filter_mapping_base.keys()})

predict_ids_in_filter = list(filter_mapping_outptus.keys())

predict_ids = list(set(predict_ids).union(predict_ids_in_filter))
# After the join, the complete outputs data can be queried as
# _outputs__{predict_id}._outputs.{predict_id} : result.
for predict_id in predict_ids:
# MongoMock does not support '.' in 'as', so we replace it with '__'
key = f'_outputs.{predict_id}'.replace('.', '__')
Expand All @@ -606,13 +632,17 @@ def _execute(self, parent, method='encode'):

project[key] = 1
pipeline.append(lookup)

if predict_id in filter_mapping_outptus:
filter_key, filter_value = list(
filter_mapping_outptus[predict_id].items()
)[0]
pipeline.append({"$match": {f'{key}.{filter_key}': filter_value}})

pipeline.append(
{"$unwind": {"path": f"${key}", "preserveNullAndEmptyArrays": True}}
)

if filter:
pipeline = [filter] + pipeline

if project:
pipeline.append({"$project": project})

Expand All @@ -626,6 +656,26 @@ def _execute(self, parent, method='encode'):
process_func=self._postprocess_result,
)

def _get_filter_mapping(self):
find_params, _ = self._get_method_parameters('find')
filter = find_params[0] if find_params else {}
if not filter:
return {}, {}

filter_mapping_base = {}
filter_mapping_outputs = defaultdict(dict)

for key, value in filter.items():
if '_outputs.' not in key:
filter_mapping_base[key] = value
continue

if key.startswith('_outputs.'):
predict_id = key.split('.')[1]
filter_mapping_outputs[predict_id] = {key: value}

return filter_mapping_base, filter_mapping_outputs

def _postprocess_result(self, result):
"""Postprocess the result of the query.
Expand Down Expand Up @@ -669,6 +719,28 @@ def _postprocess_result(self, result):
result['_blobs'] = merge_blobs
return result

@property
@applies_to('find')
def select_ids(self):
    """Return a copy of this query that selects only document ids.

    The ``find`` projection is narrowed to ``_id`` and any ``outputs`` part
    is stripped of its predict ids, so only identifiers are fetched.
    """

    def _project_id_only(part):
        find_params = list(part[1])
        # Pad to (filter, projection): a 'find' recorded with only a filter
        # (e.g. produced by .filter(...)) previously raised IndexError on
        # find_params[1].
        while len(find_params) < 2:
            find_params.append({})
        assert isinstance(find_params[1], dict)
        find_params[1]['_id'] = 1
        return (part[0], tuple(find_params), part[2])

    def _drop_predict_ids(part):
        # Keep the 'outputs' part but clear its predict ids.
        return (part[0], (), part[2])

    query = self._replace_part('find', _project_id_only)
    return query._replace_part('outputs', _drop_predict_ids)


def InsertOne(**kwargs):
"""InsertOne operation for MongoDB.
Expand Down
6 changes: 2 additions & 4 deletions superduper/components/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,17 +980,15 @@ def _eager_call__(self, *args, **kwargs):
upstream_var = SuperDuperData(
upstream_var, type=SuperDuperDataType.CONSTANT
)
upstream_mapping[upstream_var.source].append(
TrackData(index, upstream_var.key)
)
upstream_mapping[upstream_var.source].append(TrackData(index, upstream_var))

for k, upstream_var in kwargs.items():
if not isinstance(upstream_var, SuperDuperData):
upstream_var = SuperDuperData(
upstream_var, type=SuperDuperDataType.CONSTANT
)

upstream_mapping[upstream_var.source].append(TrackData(k, upstream_var.key))
upstream_mapping[upstream_var.source].append(TrackData(k, upstream_var))

for track_result in track_results:
for upstream_node, track_datas in upstream_mapping.items():
Expand Down
Loading

0 comments on commit 78f4ede

Please sign in to comment.