Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add select missing ids in mongodb plugin
Browse files Browse the repository at this point in the history
kartik4949 committed Nov 28, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent a392ee2 commit 9fdd567
Showing 3 changed files with 51 additions and 1 deletion.
17 changes: 17 additions & 0 deletions plugins/mongodb/plugin_test/test_queries.py
Original file line number Diff line number Diff line change
@@ -269,3 +269,20 @@ def test_replace_one(db):
doc = db.execute(collection.find_one({'_id': r['_id']}))
print(doc['x'])
assert doc.unpack()['x'].tolist() == new_x.tolist()


def test_select_missing_ids(db):
    """Check that a source id whose output row was deleted is reported missing.

    The test fills the database using the module's setup helpers, removes
    one row from the output table of the ``vector-x`` listener, and then
    expects ``select_ids_of_missing_outputs`` to return exactly the id of the
    source document that row pointed to.

    :param db: The datalayer fixture under test.
    """
    db.cfg.auto_schema = True

    # Setup helpers are defined elsewhere in this test module.
    add_random_data(db, n=5)
    add_models(db)
    add_vector_index(db)

    outputs_table = db.load('listener', 'vector-x').outputs
    first_output = list(db[outputs_table].select().execute())[0]
    source_id = first_output['_source']

    # Delete through the raw databackend handle so that exactly one source
    # document is left without a matching output row.
    db.databackend._db[outputs_table].delete_one({'_source': source_id})

    # The predict id is whatever follows the output prefix in the table name.
    predict_id = outputs_table.split('_outputs__')[-1]
    missing_query = db['documents'].select_ids_of_missing_outputs(predict_id)
    missing = list(missing_query.execute())

    assert len(missing) == 1
    assert source_id == missing[0]['_id']
1 change: 1 addition & 0 deletions plugins/mongodb/plugin_test/test_query.py
Original file line number Diff line number Diff line change
@@ -72,6 +72,7 @@ def test_select_missing_outputs(db):
)
)
select = MongoQuery(table='documents').find({}, {'_id': 1})
select.db = db
modified_select = select.select_ids_of_missing_outputs('x::test_model_output::0::0')
out = list(db.execute(modified_select))
assert len(out) == (len(docs) - len(ids))
34 changes: 33 additions & 1 deletion plugins/mongodb/superduper_mongodb/query.py
Original file line number Diff line number Diff line change
@@ -103,6 +103,7 @@ class MongoQuery(Query):
'post_like': r'^.*\.(find|select)\(.*\)\.like(.*)$',
'bulk_write': r'^.*\.bulk_write\(.*\)$',
'outputs': r'^.*\.outputs\(.*\)',
'missing_outputs': r'^.*\.missing_outputs\(.*\)$',
'find_one': r'^.*\.find_one\(.*\)',
'find': r'^.*\.find\(.*\)',
'select': r'^.*\.select\(.*\)$',
@@ -416,7 +417,38 @@ def select_ids_of_missing_outputs(self, predict_id: str):
:param predict_id: The id of the prediction.
"""
return self.select_ids
return self.missing_outputs(predict_id=predict_id, ids_only=1)

def _execute_missing_outputs(self, parent):
    """Select the documents that are missing the given output.

    Builds a two-stage aggregation pipeline and runs it on ``parent``:

    1. ``$lookup`` joins each document to the output collection
       (``CFG.output_prefix + predict_id``) where the output's ``_source``
       equals the document's ``_id``.
    2. ``$match`` keeps only documents whose joined array is empty, i.e.
       documents with no output for ``predict_id``.

    The arguments are read from the third element of the last entry in
    ``self.parts`` (accessed as a mapping; presumably the keyword arguments
    of the ``missing_outputs(...)`` call -- confirm against ``Query``).

    :param parent: Object exposing ``aggregate``; the pipeline is run on it.
    :raises ValueError: If the argument mapping is empty.
    :raises KeyError: If the mapping is non-empty but has no ``predict_id``.
    :return: A ``SuperDuperCursor`` over the aggregation result. With a
        truthy ``ids_only`` each result is reduced to ``{"_id": ...}``;
        otherwise results go through ``self._postprocess_result``.
    """
    kwargs = self.parts[-1][2]
    if len(kwargs) == 0:
        raise ValueError("Predict id is required")
    predict_id = kwargs["predict_id"]
    ids_only = kwargs.get('ids_only', False)

    key = f'{CFG.output_prefix}{predict_id}'

    # Each aggregation stage must be its own single-key document; putting
    # "$lookup" and "$match" in the same dict is rejected by MongoDB.
    pipeline = [
        {
            "$lookup": {
                "from": key,
                "localField": "_id",
                "foreignField": "_source",
                "as": key,
            }
        },
        # An empty joined array means no output references this document.
        {"$match": {key: {"$size": 0}}},
    ]

    raw_cursor = parent.aggregate(pipeline)

    def get_ids(result):
        return {"_id": result["_id"]}

    return SuperDuperCursor(
        raw_cursor=raw_cursor,
        db=self.db,
        id_field='_id',
        process_func=get_ids if ids_only else self._postprocess_result,
        schema=self._get_schema(),
    )

@property
@applies_to('find')

0 comments on commit 9fdd567

Please sign in to comment.