Skip to content

Commit

Permalink
address some PR review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
hunterjm committed Jun 16, 2024
1 parent 9ec66fc commit 9cddb34
Show file tree
Hide file tree
Showing 12 changed files with 145 additions and 117 deletions.
2 changes: 1 addition & 1 deletion docker/main/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ARG SLIM_BASE=debian:11-slim

FROM ${BASE_IMAGE} AS base

FROM --platform=${BUILDPLATFORM} ${BASE_IMAGE} AS base_host
FROM --platform=${BUILDPLATFORM} debian:11 AS base_host

FROM ${SLIM_BASE} AS slim-base

Expand Down
13 changes: 9 additions & 4 deletions docker/main/build_pysqlite3.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

set -euxo pipefail

SQLITE3_VERSION="96c92aba00c8375bc32fafcdf12429c58bd8aabfcadab6683e35bbb9cdebf19e" # 3.46.0
PYSQLITE3_VERSION="0.5.3"

# Fetch the source code for the latest release of Sqlite.
if [[ ! -d "sqlite" ]]; then
wget https://www.sqlite.org/src/tarball/sqlite.tar.gz?r=release -O sqlite.tar.gz
wget https://www.sqlite.org/src/tarball/sqlite.tar.gz?r=${SQLITE3_VERSION} -O sqlite.tar.gz
tar xzf sqlite.tar.gz
cd sqlite/
LIBS="-lm" ./configure --disable-tcl --enable-tempstore=always
Expand All @@ -18,13 +21,15 @@ if [[ ! -d "./pysqlite3" ]]; then
git clone https://github.com/coleifer/pysqlite3.git
fi

cd pysqlite3/
git checkout ${PYSQLITE3_VERSION}

# Copy the sqlite3 source amalgamation into the pysqlite3 directory so we can
# create a self-contained extension module.
cp "sqlite/sqlite3.c" pysqlite3/
cp "sqlite/sqlite3.h" pysqlite3/
cp "../sqlite/sqlite3.c" ./
cp "../sqlite/sqlite3.h" ./

# Create the wheel and put it in the /wheels dir.
cd pysqlite3/
sed -i "s|name='pysqlite3-binary'|name=PACKAGE_NAME|g" setup.py
python3 setup.py build_static
pip3 wheel . -w /wheels
4 changes: 2 additions & 2 deletions frigate/comms/events_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class EventEndPublisher(Publisher):
topic_base = "event/"

def __init__(self) -> None:
super().__init__("ended")
super().__init__("finalized")

def publish(
self, payload: tuple[EventTypeEnum, EventStateEnum, str, dict[str, any]]
Expand All @@ -48,4 +48,4 @@ class EventEndSubscriber(Subscriber):
topic_base = "event/"

def __init__(self) -> None:
super().__init__("ended")
super().__init__("finalized")
3 changes: 3 additions & 0 deletions frigate/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,9 @@ class GenAIConfig(FrigateBaseModel):
title="Default caption prompt.",
)
object_prompts: Dict[str, str] = Field(default={}, title="Object specific prompts.")
reindex: Optional[bool] = Field(
default=False, title="Reindex all detections on startup."
)


class GenAICameraConfig(FrigateBaseModel):
Expand Down
3 changes: 1 addition & 2 deletions frigate/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,8 @@ def receiveSignal(signalNumber: int, frame: Optional[FrameType]) -> None:
embeddings = Embeddings()

# Check if we need to re-index events
if os.path.exists(f"{CONFIG_DIR}/.reindex"):
if config.genai.reindex:
embeddings.reindex()
os.remove(f"{CONFIG_DIR}/.reindex")

maintainer = EmbeddingMaintainer(
config,
Expand Down
4 changes: 0 additions & 4 deletions frigate/embeddings/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import base64
import io
import logging
import os
import time

import numpy as np
Expand All @@ -13,7 +12,6 @@
from PIL import Image
from playhouse.shortcuts import model_to_dict

from frigate.const import CONFIG_DIR
from frigate.models import Event

from .functions.clip import ClipEmbedding
Expand Down Expand Up @@ -112,5 +110,3 @@ def reindex(self) -> None:
len(descriptions["ids"]),
time.time() - st,
)

os.remove(f"{CONFIG_DIR}/.reindex_events")
144 changes: 75 additions & 69 deletions frigate/embeddings/maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,82 +46,88 @@ def __init__(
def run(self) -> None:
"""Maintain a Chroma vector database for semantic search."""
while not self.stop_event.is_set():
update = self.event_subscriber.check_for_update()
self._process_updates()
self._process_finalized()

if update is None:
continue
def _process_updates(self) -> None:
"""Process event updates"""
update = self.event_subscriber.check_for_update()

source_type, _, camera, data = update
if update is None:
return

source_type, _, camera, data = update

if not camera or source_type != EventTypeEnum.tracked_object:
return

camera_config = self.config.cameras[camera]
if data["id"] not in self.tracked_events:
self.tracked_events[data["id"]] = []

if camera and source_type == EventTypeEnum.tracked_object:
camera_config = self.config.cameras[camera]
if data["id"] not in self.tracked_events:
self.tracked_events[data["id"]] = []
# Create our own thumbnail based on the bounding box and the frame time
try:
frame_id = f"{camera}{data['frame_time']}"
yuv_frame = self.frame_manager.get(frame_id, camera_config.frame_shape_yuv)
data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])
self.tracked_events[data["id"]].append(data)
self.frame_manager.close(frame_id)
except FileNotFoundError:
pass

# Create our own thumbnail based on the bounding box and the frame time
def _process_finalized(self) -> None:
"""Process the end of an event."""
while True:
ended = self.event_end_subscriber.check_for_update()

if ended == None:
break

event_id, camera, updated_db = ended
camera_config = self.config.cameras[camera]

if updated_db:
try:
frame_id = f"{camera}{data['frame_time']}"
yuv_frame = self.frame_manager.get(
frame_id, camera_config.frame_shape_yuv
)
data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])
self.tracked_events[data["id"]].append(data)
self.frame_manager.close(frame_id)
except FileNotFoundError:
event: Event = Event.get(Event.id == event_id)
except DoesNotExist:
continue

# Skip the event if not an object
if event.data.get("type") != "object":
continue

# Embed thumbnails when an event ends
while True:
ended = self.event_end_subscriber.check_for_update()

if ended == None:
break

event_id, camera, updated_db = ended
camera_config = self.config.cameras[camera]

if updated_db:
try:
event: Event = Event.get(Event.id == event_id)
except DoesNotExist:
continue

# Skip the event if not an object
if event.data.get("type") != "object":
continue

# Extract valid event metadata
metadata = get_metadata(event)
thumbnail = base64.b64decode(event.thumbnail)

# Embed the thumbnail
self._embed_thumbnail(event_id, thumbnail, metadata)

if (
camera_config.genai.enabled
and self.genai_client is not None
and event.data.get("description") is None
):
# Generate the description. Call happens in a thread since it is network bound.
threading.Thread(
target=self._embed_description,
name=f"_embed_description_{event.id}",
daemon=True,
args=(
event,
[
data["thumbnail"]
for data in self.tracked_events.get(
event_id, [{"thumbnail": thumbnail}]
)
],
metadata,
),
).start()

# Delete tracked events based on the event_id
if event_id in self.tracked_events:
del self.tracked_events[event_id]
# Extract valid event metadata
metadata = get_metadata(event)
thumbnail = base64.b64decode(event.thumbnail)

# Embed the thumbnail
self._embed_thumbnail(event_id, thumbnail, metadata)

if (
camera_config.genai.enabled
and self.genai_client is not None
and event.data.get("description") is None
):
# Generate the description. Call happens in a thread since it is network bound.
threading.Thread(
target=self._embed_description,
name=f"_embed_description_{event.id}",
daemon=True,
args=(
event,
[
data["thumbnail"]
for data in self.tracked_events[event_id]
]
if len(self.tracked_events.get(event_id, [])) > 0
else [thumbnail],
metadata,
),
).start()

# Delete tracked events based on the event_id
if event_id in self.tracked_events:
del self.tracked_events[event_id]

def _create_thumbnail(self, yuv_frame, box, height=500) -> Optional[bytes]:
"""Return jpg thumbnail of a region of the frame."""
Expand Down
13 changes: 8 additions & 5 deletions frigate/events/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,13 @@ def run(self) -> None:
)
events_to_delete = [e.id for e in events]
if len(events_to_delete) > 0:
Event.delete().where(Event.id << events_to_delete).execute()

if self.config.semantic_search.enabled:
self.embeddings.thumbnail.delete(ids=events_to_delete)
self.embeddings.description.delete(ids=events_to_delete)
chunk_size = 50
for i in range(0, len(events_to_delete), chunk_size):
chunk = events_to_delete[i : i + chunk_size]
Event.delete().where(Event.id << chunk).execute()

if self.config.semantic_search.enabled:
self.embeddings.thumbnail.delete(ids=chunk)
self.embeddings.description.delete(ids=chunk)

logger.info("Exiting event cleanup...")
7 changes: 5 additions & 2 deletions frigate/genai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@ def decorator(cls):
class GenAIClient:
"""Generative AI client for Frigate."""

def __init__(self, genai_config: GenAIConfig) -> None:
def __init__(self, genai_config: GenAIConfig, timeout: int = 60) -> None:
self.genai_config: GenAIConfig = genai_config
self.timeout = timeout
self.provider = self._init_provider()

def generate_description(self, thumbnails: list[bytes], metadata: dict[str, any]):
def generate_description(
self, thumbnails: list[bytes], metadata: dict[str, any]
) -> Optional[str]:
"""Generate a description for the frame."""
prompt = self.genai_config.object_prompts.get(
metadata["label"], self.genai_config.prompt
Expand Down
19 changes: 13 additions & 6 deletions frigate/genai/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional

import google.generativeai as genai
from google.api_core.exceptions import DeadlineExceeded

from frigate.config import GenAIProviderEnum
from frigate.genai import GenAIClient, register_genai_provider
Expand All @@ -28,12 +29,18 @@ def _send(self, prompt: str, images: list[bytes]) -> Optional[str]:
}
for img in images
] + [prompt]
response = self.provider.generate_content(
data,
generation_config=genai.types.GenerationConfig(
candidate_count=1,
),
)
try:
response = self.provider.generate_content(
data,
generation_config=genai.types.GenerationConfig(
candidate_count=1,
),
request_options=genai.types.RequestOptions(
timeout=self.timeout,
),
)
except DeadlineExceeded:
return None
try:
description = response.text.strip()
except ValueError:
Expand Down
7 changes: 4 additions & 3 deletions frigate/genai/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from typing import Optional

from httpx import TimeoutException
from ollama import Client as ApiClient
from ollama import ResponseError

Expand All @@ -20,21 +21,21 @@ class OllamaClient(GenAIClient):

def _init_provider(self):
"""Initialize the client."""
client = ApiClient(host=self.genai_config.base_url)
client = ApiClient(host=self.genai_config.base_url, timeout=self.timeout)
response = client.pull(self.genai_config.model)
if response["status"] != "success":
logger.error("Failed to pull %s model from Ollama", self.genai_config.model)
return None
return client

def _send(self, prompt: str, images: list[bytes]) -> Optional[str]:
"""Submit a request to Ollama."""
"""Submit a request to Ollama"""
try:
result = self.provider.generate(
self.genai_config.model,
prompt,
images=images,
)
return result["response"].strip()
except ResponseError:
except (TimeoutException, ResponseError):
return None
Loading

0 comments on commit 9cddb34

Please sign in to comment.