Skip to content

Commit

Permalink
Semantic Search API (#12105)
Browse files Browse the repository at this point in the history
* initial event search api implementation

* fix lint

* fix tests

* move chromadb imports and pysqlite hotswap to fix tests

* remove unused import

* switch default limit to 50

* fix events accidentally pulling inside chroma results loop
  • Loading branch information
hunterjm authored and NickM-27 committed Aug 9, 2024
1 parent ac895b6 commit 034ff99
Show file tree
Hide file tree
Showing 8 changed files with 359 additions and 23 deletions.
7 changes: 1 addition & 6 deletions frigate/__main__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import faulthandler
import sys
import threading

from flask import cli

# Hot-swap the sqlite3 module for Chroma compatibility
__import__("pysqlite3")
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
from frigate.app import FrigateApp

faulthandler.enable()

Expand All @@ -15,8 +12,6 @@
cli.show_server_banner = lambda *x: None

if __name__ == "__main__":
from frigate.app import FrigateApp

frigate_app = FrigateApp()

frigate_app.start()
3 changes: 3 additions & 0 deletions frigate/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from frigate.api.review import ReviewBp
from frigate.config import FrigateConfig
from frigate.const import CONFIG_DIR
from frigate.embeddings import EmbeddingsContext
from frigate.events.external import ExternalEventProcessor
from frigate.models import Event, Timeline
from frigate.plus import PlusApi
Expand Down Expand Up @@ -52,6 +53,7 @@
def create_app(
frigate_config,
database: SqliteQueueDatabase,
embeddings: EmbeddingsContext,
detected_frames_processor,
storage_maintainer: StorageMaintainer,
onvif: OnvifController,
Expand Down Expand Up @@ -79,6 +81,7 @@ def _db_close(exc):
database.close()

app.frigate_config = frigate_config
app.embeddings = embeddings
app.detected_frames_processor = detected_frames_processor
app.storage_maintainer = storage_maintainer
app.onvif = onvif
Expand Down
244 changes: 242 additions & 2 deletions frigate/api/event.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Event apis."""

import base64
import io
import logging
import os
from datetime import datetime
Expand All @@ -8,20 +10,24 @@
from urllib.parse import unquote

import cv2
import numpy as np
from flask import (
Blueprint,
current_app,
jsonify,
make_response,
request,
)
from peewee import DoesNotExist, fn, operator
from peewee import JOIN, DoesNotExist, fn, operator
from PIL import Image
from playhouse.shortcuts import model_to_dict

from frigate.const import (
CLIPS_DIR,
)
from frigate.models import Event, Timeline
from frigate.embeddings import EmbeddingsContext
from frigate.embeddings.embeddings import get_metadata
from frigate.models import Event, ReviewSegment, Timeline
from frigate.object_processing import TrackedObject
from frigate.util.builtin import get_tz_modifiers

Expand Down Expand Up @@ -245,6 +251,189 @@ def events():
return jsonify(list(events))


def _build_embeddings_where(cameras, labels, zones, after, before):
    """Translate the request filters into a ``where`` clause for the embeddings query.

    Args:
        cameras: Comma-separated camera names, or ``"all"`` for no camera filter.
        labels: Comma-separated labels, or ``"all"`` for no label filter.
        zones: Comma-separated zone names, or ``"all"`` for no zone filter.
        after: Exclusive lower bound on ``start_time``, or ``None``.
        before: Exclusive upper bound on ``start_time``, or ``None``.

    Returns:
        The filter dict, or ``None`` when no filter was requested.
    """
    filters = []

    if cameras != "all":
        filters.append({"camera": {"$in": cameras.split(",")}})

    if labels != "all":
        filters.append({"label": {"$in": labels.split(",")}})

    if zones != "all":
        zone_filters = [
            {f"zones_{zone}": {"$eq": True}} for zone in zones.split(",")
        ]
        # A single zone expression is passed bare; "$or" is only used to
        # combine two or more expressions.
        if len(zone_filters) > 1:
            filters.append({"$or": zone_filters})
        else:
            filters.append(zone_filters[0])

    # Compare against None so that an explicit timestamp of 0 is still applied.
    if after is not None:
        filters.append({"start_time": {"$gt": after}})

    if before is not None:
        filters.append({"start_time": {"$lt": before}})

    if not filters:
        return None

    # Same rule as for zones: only wrap in "$and" when combining several filters.
    if len(filters) > 1:
        return {"$and": filters}
    return filters[0]


def _merge_search_distances(thumb_ids, desc_ids):
    """Combine thumbnail and description matches into one result per event.

    Args:
        thumb_ids: Mapping of event id to distance from the thumbnail index.
        desc_ids: Mapping of event id to distance from the description index.

    Returns:
        Mapping of event id to ``{"distance": ..., "source": ...}`` where the
        distance is the smaller of the two and ``source`` names the index it
        came from (``"thumbnail"`` wins ties).
    """
    results = {}
    for event_id in set(thumb_ids) | set(desc_ids):
        thumb_distance = thumb_ids.get(event_id)
        min_distance = min(
            distance
            for distance in (thumb_distance, desc_ids.get(event_id))
            if distance is not None
        )
        results[event_id] = {
            "distance": min_distance,
            "source": "thumbnail"
            if min_distance == thumb_distance
            else "description",
        }
    return results


@EventBp.route("/events/search")
def events_search():
    """Search events through the thumbnail and description embeddings.

    Query parameters:
        query: Search text, or an event id when ``search_type`` is ``thumbnail``.
        search_type: ``"thumbnail"`` to search by an existing event's thumbnail;
            any other value (default ``"text"``) searches by text.
        include_thumbnails: Non-zero (default) to include ``Event.thumbnail``.
        limit: Maximum number of results (default 50); must be positive.
        cameras, labels, zones: Comma-separated filters, ``"all"`` by default.
        after, before: Optional exclusive bounds on the event ``start_time``.

    Returns:
        A JSON list of events sorted by ascending ``search_distance``, or a
        JSON error with status 400 (bad request / search disabled) or 404
        (thumbnail search for an unknown event id).
    """
    query = request.args.get("query", type=str)
    search_type = request.args.get("search_type", "text", type=str)
    include_thumbnails = request.args.get("include_thumbnails", default=1, type=int)
    limit = request.args.get("limit", 50, type=int)

    # Filters
    cameras = request.args.get("cameras", "all", type=str)
    labels = request.args.get("labels", "all", type=str)
    zones = request.args.get("zones", "all", type=str)
    after = request.args.get("after", type=float)
    before = request.args.get("before", type=float)

    if not query:
        return make_response(
            jsonify(
                {
                    "success": False,
                    "message": "A search query must be supplied",
                }
            ),
            400,
        )

    # The limit is forwarded as n_results to the embeddings query; reject a
    # non-positive value here instead of failing inside the query call.
    if limit < 1:
        return make_response(
            jsonify(
                {
                    "success": False,
                    "message": "limit must be a positive integer",
                }
            ),
            400,
        )

    if not current_app.frigate_config.semantic_search.enabled:
        return make_response(
            jsonify(
                {
                    "success": False,
                    "message": "Semantic search is not enabled",
                }
            ),
            400,
        )

    context: EmbeddingsContext = current_app.embeddings

    selected_columns = [
        Event.id,
        Event.camera,
        Event.label,
        Event.sub_label,
        Event.zones,
        Event.start_time,
        Event.end_time,
        Event.data,
        ReviewSegment.thumb_path,
    ]

    if include_thumbnails:
        selected_columns.append(Event.thumbnail)

    # Build the where clause for the embeddings query
    where = _build_embeddings_where(cameras, labels, zones, after, before)

    thumb_ids = {}
    desc_ids = {}

    if search_type == "thumbnail":
        # Grab the ids of events that match the thumbnail image embeddings
        try:
            search_event: Event = Event.get(Event.id == query)
        except DoesNotExist:
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": "Event not found",
                    }
                ),
                404,
            )
        # The stored thumbnail is base64 text; decode it into an RGB array.
        thumbnail = base64.b64decode(search_event.thumbnail)
        img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
        thumb_result = context.embeddings.thumbnail.query(
            query_images=[img],
            n_results=limit,
            where=where,
        )
        thumb_ids = dict(zip(thumb_result["ids"][0], thumb_result["distances"][0]))
    else:
        thumb_result = context.embeddings.thumbnail.query(
            query_texts=[query],
            n_results=limit,
            where=where,
        )
        # Do a rudimentary normalization of the difference in distances returned by CLIP and MiniLM.
        thumb_ids = dict(
            zip(
                thumb_result["ids"][0],
                context.thumb_stats.normalize(thumb_result["distances"][0]),
            )
        )
        desc_result = context.embeddings.description.query(
            query_texts=[query],
            n_results=limit,
            where=where,
        )
        desc_ids = dict(
            zip(
                desc_result["ids"][0],
                context.desc_stats.normalize(desc_result["distances"][0]),
            )
        )

    results = _merge_search_distances(thumb_ids, desc_ids)

    if not results:
        return jsonify([])

    # Get the event data
    rows = (
        Event.select(*selected_columns)
        .join(
            ReviewSegment,
            JOIN.LEFT_OUTER,
            on=(fn.json_extract(ReviewSegment.data, "$.detections").contains(Event.id)),
        )
        .where(Event.id << list(results.keys()))
        .dicts()
        .iterator()
    )

    # The outer join yields one row per matching review segment, so an event
    # can show up several times. Keep the first row for each event id so
    # duplicates do not crowd distinct events out of the limited result.
    unique_events = {}
    for row in rows:
        unique_events.setdefault(row["id"], row)

    events = [
        {k: v for k, v in event.items() if k != "data"}
        | {
            k: v
            for k, v in event["data"].items()
            if k in ["type", "score", "top_score", "description"]
        }
        | {
            "search_distance": results[event["id"]]["distance"],
            "search_source": results[event["id"]]["source"],
        }
        for event in unique_events.values()
    ]
    events = sorted(events, key=lambda x: x["search_distance"])[:limit]

    return jsonify(events)


@EventBp.route("/events/summary")
def events_summary():
tz_name = request.args.get("timezone", default="utc", type=str)
Expand Down Expand Up @@ -604,6 +793,52 @@ def set_sub_label(id):
)


@EventBp.route("/events/<id>/description", methods=("POST",))
def set_description(id):
    """Set the description of an event and refresh its description embedding.

    Expects a JSON object body with a non-empty string ``description``.

    Args:
        id: The event id taken from the URL.

    Returns:
        A JSON response: 404 when the event does not exist, 400 when the
        description is missing, empty or not a string, 200 on success.
    """
    try:
        event: Event = Event.get(Event.id == id)
    except DoesNotExist:
        return make_response(
            jsonify({"success": False, "message": f"Event {id} not found"}), 404
        )

    # get_json(silent=True) yields None on a missing/invalid body and may
    # yield a non-dict (list, string, number) for valid non-object JSON.
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        body = {}
    new_description = body.get("description")

    # A non-string value would otherwise fail in the emptiness check or when
    # building the success message below.
    if new_description is not None and not isinstance(new_description, str):
        return make_response(
            jsonify(
                {
                    "success": False,
                    "message": "description must be a string",
                }
            ),
            400,
        )

    if not new_description:
        return make_response(
            jsonify(
                {
                    "success": False,
                    "message": "description cannot be empty",
                }
            ),
            400,
        )

    event.data["description"] = new_description
    event.save()

    # If semantic search is enabled, update the index
    if current_app.frigate_config.semantic_search.enabled:
        context: EmbeddingsContext = current_app.embeddings
        context.embeddings.description.upsert(
            documents=[new_description],
            metadatas=[get_metadata(event)],
            ids=[id],
        )

    return make_response(
        jsonify(
            {
                "success": True,
                "message": f"Event {id} description set to {new_description}",
            }
        ),
        200,
    )


@EventBp.route("/events/<id>", methods=("DELETE",))
def delete_event(id):
try:
Expand All @@ -625,6 +860,11 @@ def delete_event(id):

event.delete_instance()
Timeline.delete().where(Timeline.source_id == id).execute()
# If semantic search is enabled, update the index
if current_app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = current_app.embeddings
context.embeddings.thumbnail.delete(ids=[id])
context.embeddings.description.delete(ids=[id])
return make_response(
jsonify({"success": True, "message": "Event " + id + " deleted"}), 200
)
Expand Down
9 changes: 6 additions & 3 deletions frigate/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@
MODEL_CACHE_DIR,
RECORD_DIR,
)
from frigate.embeddings import manage_embeddings
from frigate.embeddings.embeddings import Embeddings
from frigate.embeddings import EmbeddingsContext, manage_embeddings
from frigate.events.audio import listen_to_audio
from frigate.events.cleanup import EventCleanup
from frigate.events.external import ExternalEventProcessor
Expand Down Expand Up @@ -322,7 +321,7 @@ def init_review_segment_manager(self) -> None:

def init_embeddings_manager(self) -> None:
# Create a client for other processes to use
self.embeddings = Embeddings()
self.embeddings = EmbeddingsContext()
embedding_process = mp.Process(
target=manage_embeddings,
name="embeddings_manager",
Expand Down Expand Up @@ -384,6 +383,7 @@ def init_web_server(self) -> None:
self.flask_app = create_app(
self.config,
self.db,
self.embeddings,
self.detected_frames_processor,
self.storage_maintainer,
self.onvif_controller,
Expand Down Expand Up @@ -811,6 +811,9 @@ def stop(self) -> None:
self.frigate_watchdog.join()
self.db.stop()

# Save embeddings stats to disk
self.embeddings.save_stats()

# Stop Communicators
self.inter_process_communicator.stop()
self.inter_config_updater.stop()
Expand Down
Loading

0 comments on commit 034ff99

Please sign in to comment.