Skip to content

Commit

Permalink
Add select missing ids in mongodb plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
kartik4949 committed Nov 29, 2024
1 parent 78199ff commit 510226b
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 1 deletion.
17 changes: 17 additions & 0 deletions plugins/mongodb/plugin_test/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,20 @@ def test_replace_one(db):
doc = db.execute(collection.find_one({'_id': r['_id']}))
print(doc['x'])
assert doc.unpack()['x'].tolist() == new_x.tolist()


def test_select_missing_ids(db):
    """Check that a source document whose output row was deleted is reported.

    The test populates the database, removes a single row from the
    listener's output collection, and then expects
    ``select_ids_of_missing_outputs`` to return exactly the ``_id`` of the
    source document that row pointed to.
    """
    db.cfg.auto_schema = True
    add_random_data(db, n=5)
    add_models(db)
    add_vector_index(db)

    # Name of the collection holding the outputs of the 'vector-x' listener.
    outputs_table = db.load('listener', 'vector-x').outputs

    # Pick the first output row and remember which source document it refers to.
    output_rows = list(db[outputs_table].select().execute())
    first_row = output_rows[0]
    source_id = first_row['_source']

    # Delete that output row directly through the underlying database handle.
    db.databackend._db[outputs_table].delete_one({'_source': source_id})

    # The predict id is whatever follows the '_outputs__' marker in the name.
    predict_id = outputs_table.split('_outputs__')[-1]
    missing_query = db['documents'].select_ids_of_missing_outputs(predict_id)
    missing = list(missing_query.execute())

    assert len(missing) == 1
    assert source_id == missing[0]['_id']
1 change: 1 addition & 0 deletions plugins/mongodb/plugin_test/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def test_select_missing_outputs(db):
)
)
select = MongoQuery(table='documents').find({}, {'_id': 1})
select.db = db
modified_select = select.select_ids_of_missing_outputs('x::test_model_output::0::0')
out = list(db.execute(modified_select))
assert len(out) == (len(docs) - len(ids))
Expand Down
37 changes: 36 additions & 1 deletion plugins/mongodb/superduper_mongodb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class MongoQuery(Query):
'post_like': r'^.*\.(find|select)\(.*\)\.like(.*)$',
'bulk_write': r'^.*\.bulk_write\(.*\)$',
'outputs': r'^.*\.outputs\(.*\)',
'missing_outputs': r'^.*\.missing_outputs\(.*\)$',
'find_one': r'^.*\.find_one\(.*\)',
'find': r'^.*\.find\(.*\)',
'select': r'^.*\.select\(.*\)$',
Expand Down Expand Up @@ -416,7 +417,41 @@ def select_ids_of_missing_outputs(self, predict_id: str):
:param predict_id: The id of the prediction.
"""
return self.select_ids
return self.missing_outputs(predict_id=predict_id, ids_only=1)

def _execute_missing_outputs(self, parent):
    """Select the documents that are missing the given output.

    The arguments are read from the keyword arguments of the last query
    part (``self.parts[-1][2]``):

    * ``predict_id`` (required): appended to ``CFG.output_prefix`` to form
      the name of the output collection used in the ``$lookup``.
    * ``ids_only`` (optional, defaults to ``False``): when truthy, every
      result is reduced to ``{"_id": ...}``.

    :param parent: Object whose ``aggregate`` method is called with the
                   pipeline (presumably a MongoDB collection -- confirm
                   against the caller).
    :raises ValueError: If ``predict_id`` was not supplied.
    :return: A ``SuperDuperCursor`` wrapping the aggregation cursor.
    """
    call_kwargs = self.parts[-1][2]
    # Check for the key itself rather than for an empty dict, so that a call
    # carrying only ``ids_only`` fails with the same clear error instead of
    # a bare KeyError.
    if 'predict_id' not in call_kwargs:
        raise ValueError("Predict id is required")
    predict_id = call_kwargs['predict_id']
    ids_only = call_kwargs.get('ids_only', False)

    key = f'{CFG.output_prefix}{predict_id}'

    pipeline = [
        # Join each document with the output rows whose '_source' is its '_id'.
        {
            '$lookup': {
                'from': key,
                'localField': '_id',
                'foreignField': '_source',
                'as': key,
            }
        },
        # Keep only the documents for which the join found nothing.
        {'$match': {key: {'$size': 0}}},
    ]
    if ids_only:
        # Only '_id' is read from each result below, so do not transfer the
        # rest of the document from the server.
        pipeline.append({'$project': {'_id': 1}})

    raw_cursor = parent.aggregate(pipeline)

    def get_ids(result):
        return {"_id": result["_id"]}

    return SuperDuperCursor(
        raw_cursor=raw_cursor,
        db=self.db,
        id_field='_id',
        process_func=get_ids if ids_only else self._postprocess_result,
        schema=self._get_schema(),
    )

@property
@applies_to('find')
Expand Down

0 comments on commit 510226b

Please sign in to comment.