Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/vision and search interface #20

Merged
merged 6 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ruff==0.1.7
pytest==7.4.3
twine==4.0.2
twine==5.1.1
wheel==0.42.0
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
version=about["__version__"],
author=about["__author__"],
author_email=about["__email__"],
license=about["__license__"],
description="VideoDB Python SDK",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
3 changes: 2 additions & 1 deletion videodb/__about__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
""" About information for videodb sdk"""


__version__ = "0.2.0"
__version__ = "0.2.1"
__title__ = "videodb"
__author__ = "videodb"
__email__ = "contact@videodb.io"
__url__ = "https://github.com/video-db/videodb-python"
__license__ = "Apache License 2.0"
2 changes: 2 additions & 0 deletions videodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from videodb._utils._video import play_stream
from videodb._constants import (
VIDEO_DB_API,
IndexType,
SceneExtractionType,
MediaType,
SearchType,
Expand All @@ -30,6 +31,7 @@
"VideodbError",
"AuthenticationError",
"InvalidRequestError",
"IndexType",
"SearchError",
"play_stream",
"MediaType",
Expand Down
6 changes: 4 additions & 2 deletions videodb/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ class SearchType:


class IndexType:
semantic = "semantic"
spoken_word = "spoken_word"
scene = "scene"


class SceneExtractionType:
scene_based = "scene"
shot_based = "shot"
time_based = "time"


Expand Down Expand Up @@ -58,6 +58,8 @@ class ApiPath:
invoices = "invoices"
scenes = "scenes"
scene = "scene"
frame = "frame"
describe = "describe"


class Status:
Expand Down
4 changes: 4 additions & 0 deletions videodb/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from videodb._constants import (
ApiPath,
IndexType,
SearchType,
)
from videodb.video import Video
Expand Down Expand Up @@ -100,6 +101,7 @@ def search(
self,
query: str,
search_type: Optional[str] = SearchType.semantic,
index_type: Optional[str] = IndexType.spoken_word,
result_threshold: Optional[int] = None,
score_threshold: Optional[float] = None,
dynamic_score_percentage: Optional[float] = None,
Expand All @@ -108,6 +110,8 @@ def search(
return search.search_inside_collection(
collection_id=self.id,
query=query,
search_type=search_type,
index_type=index_type,
result_threshold=result_threshold,
score_threshold=score_threshold,
dynamic_score_percentage=dynamic_score_percentage,
Expand Down
8 changes: 8 additions & 0 deletions videodb/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,11 @@ def to_json(self):
"frame_time": self.frame_time,
"description": self.description,
}

def describe(self, prompt: str = None, model_name=None):
    """Request a description for this frame and cache it on the instance.

    :param prompt: Value sent as ``prompt`` in the request payload.
    :param model_name: Value sent as ``model_name`` in the request payload.
    :return: The ``description`` field of the response, or ``None`` when
        the response has no such field. The same value is stored in
        ``self.description``.
    """
    endpoint = f"{ApiPath.video}/{self.video_id}/{ApiPath.frame}/{self.id}/{ApiPath.describe}"
    payload = {"prompt": prompt, "model_name": model_name}
    response = self._connection.post(path=endpoint, data=payload)
    # Overwrite any previous description, even with None when absent.
    self.description = response.get("description")
    return self.description
12 changes: 12 additions & 0 deletions videodb/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ def __init__(
description: str,
id: str = None,
frames: List[Frame] = [],
connection=None,
):
self.id = id
self.video_id = video_id
self.start = start
self.end = end
self.frames: List[Frame] = frames
self.description = description
self._connection = connection

def __repr__(self) -> str:
return (
Expand All @@ -43,6 +45,16 @@ def to_json(self):
"description": self.description,
}

def describe(self, prompt: str = None, model_name=None):
    """Request a description for this scene and store it on the instance.

    The previous signature was annotated ``-> None`` although the method
    returns the description; the misleading annotation has been removed.

    :param prompt: Value sent as ``prompt`` in the request payload.
    :param model_name: Value sent as ``model_name`` in the request payload.
    :raises ValueError: If the scene was created without a connection.
    :return: The ``description`` field of the response, or ``None`` when
        the response has no such field. The same value is stored in
        ``self.description``.
    """
    if self._connection is None:
        raise ValueError("Connection is required to describe a scene")
    description_data = self._connection.post(
        path=f"{ApiPath.video}/{self.video_id}/{ApiPath.scene}/{self.id}/{ApiPath.describe}",
        data={"prompt": prompt, "model_name": model_name},
    )
    # Overwrite any previous description, even with None when absent.
    self.description = description_data.get("description", None)
    return self.description


class SceneCollection:
def __init__(
Expand Down
21 changes: 17 additions & 4 deletions videodb/search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from videodb._utils._video import play_stream
from videodb._constants import (
IndexType,
SearchType,
ApiPath,
SemanticSearchDefaultValues,
Expand Down Expand Up @@ -109,6 +110,8 @@ def search_inside_video(
self,
video_id: str,
query: str,
search_type: str,
index_type: str,
result_threshold: Optional[int] = None,
score_threshold: Optional[float] = None,
dynamic_score_percentage: Optional[float] = None,
Expand All @@ -117,7 +120,8 @@ def search_inside_video(
search_data = self._connection.post(
path=f"{ApiPath.video}/{video_id}/{ApiPath.search}",
data={
"index_type": SearchType.semantic,
"search_type": search_type,
"index_type": index_type,
"query": query,
"score_threshold": score_threshold
or SemanticSearchDefaultValues.score_threshold,
Expand All @@ -133,6 +137,8 @@ def search_inside_collection(
self,
collection_id: str,
query: str,
search_type: str,
index_type: str,
result_threshold: Optional[int] = None,
score_threshold: Optional[float] = None,
dynamic_score_percentage: Optional[float] = None,
Expand All @@ -141,7 +147,8 @@ def search_inside_collection(
search_data = self._connection.post(
path=f"{ApiPath.collection}/{collection_id}/{ApiPath.search}",
data={
"index_type": SearchType.semantic,
"search_type": search_type,
"index_type": index_type,
"query": query,
"score_threshold": score_threshold
or SemanticSearchDefaultValues.score_threshold,
Expand All @@ -162,6 +169,8 @@ def search_inside_video(
self,
video_id: str,
query: str,
search_type: str,
index_type: str,
result_threshold: Optional[int] = None,
score_threshold: Optional[float] = None,
dynamic_score_percentage: Optional[float] = None,
Expand All @@ -170,7 +179,8 @@ def search_inside_video(
search_data = self._connection.post(
path=f"{ApiPath.video}/{video_id}/{ApiPath.search}",
data={
"index_type": SearchType.keyword,
"search_type": search_type,
"index_type": index_type,
"query": query,
"score_threshold": score_threshold,
"result_threshold": result_threshold,
Expand All @@ -190,6 +200,8 @@ def search_inside_video(
self,
video_id: str,
query: str,
search_type: str,
index_type: str,
result_threshold: Optional[int] = None,
score_threshold: Optional[float] = None,
dynamic_score_percentage: Optional[float] = None,
Expand All @@ -198,7 +210,8 @@ def search_inside_video(
search_data = self._connection.post(
path=f"{ApiPath.video}/{video_id}/{ApiPath.search}",
data={
"index_type": SearchType.scene,
"search_type": search_type,
"index_type": IndexType.scene,
"query": query,
"score_threshold": score_threshold,
"result_threshold": result_threshold,
Expand Down
30 changes: 21 additions & 9 deletions videodb/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def search(
self,
query: str,
search_type: Optional[str] = SearchType.semantic,
index_type: Optional[str] = IndexType.spoken_word,
result_threshold: Optional[int] = None,
score_threshold: Optional[float] = None,
dynamic_score_percentage: Optional[float] = None,
Expand All @@ -58,6 +59,8 @@ def search(
return search.search_inside_video(
video_id=self.id,
query=query,
search_type=search_type,
index_type=index_type,
result_threshold=result_threshold,
score_threshold=score_threshold,
dynamic_score_percentage=dynamic_score_percentage,
Expand Down Expand Up @@ -152,7 +155,7 @@ def index_spoken_words(
self._connection.post(
path=f"{ApiPath.video}/{self.id}/{ApiPath.index}",
data={
"index_type": IndexType.semantic,
"index_type": IndexType.spoken_word,
"language_code": language_code,
"force": force,
"callback_url": callback_url,
Expand Down Expand Up @@ -194,6 +197,7 @@ def _format_scene_collection(self, scene_collection_data: dict) -> SceneCollecti
description=scene.get("description"),
id=scene.get("scene_id"),
frames=frames,
connection=self._connection,
)
scenes.append(scene)

Expand All @@ -207,11 +211,11 @@ def _format_scene_collection(self, scene_collection_data: dict) -> SceneCollecti

def extract_scenes(
self,
extraction_type: SceneExtractionType = SceneExtractionType.scene_based,
extraction_type: SceneExtractionType = SceneExtractionType.shot_based,
extraction_config: dict = {},
force: bool = False,
callback_url: str = None,
):
) -> Optional[SceneCollection]:
scenes_data = self._connection.post(
path=f"{ApiPath.video}/{self.id}/{ApiPath.scenes}",
data={
Expand All @@ -225,10 +229,14 @@ def extract_scenes(
return None
return self._format_scene_collection(scenes_data.get("scene_collection"))

def get_scene_collection(self, collection_id: str):
def get_scene_collection(self, collection_id: str) -> Optional[SceneCollection]:
if not collection_id:
raise ValueError("collection_id is required")
scenes_data = self._connection.get(
path=f"{ApiPath.video}/{self.id}/{ApiPath.scenes}/{collection_id}"
)
if not scenes_data:
return None
return self._format_scene_collection(scenes_data.get("scene_collection"))

def list_scene_collection(self):
Expand All @@ -238,29 +246,31 @@ def list_scene_collection(self):
return scene_collections_data.get("scene_collections", [])

def delete_scene_collection(self, collection_id: str) -> None:
    """Delete one scene collection belonging to this video.

    :param collection_id: Identifier of the collection; must be truthy.
    :raises ValueError: If ``collection_id`` is empty or ``None``.
    """
    if collection_id:
        target = f"{ApiPath.video}/{self.id}/{ApiPath.scenes}/{collection_id}"
        self._connection.delete(path=target)
        return
    # Refuse to issue a DELETE against a path with a blank identifier.
    raise ValueError("collection_id is required")

def index_scenes(
self,
extraction_type: SceneExtractionType = SceneExtractionType.scene_based,
extraction_type: SceneExtractionType = SceneExtractionType.shot_based,
extraction_config: Dict = {},
prompt: Optional[str] = None,
model: Optional[str] = None,
model_name: Optional[str] = None,
model_config: Optional[Dict] = None,
name: Optional[str] = None,
scenes: Optional[List[Scene]] = None,
force: Optional[bool] = False,
callback_url: Optional[str] = None,
) -> Optional[List]:
) -> Optional[str]:
scenes_data = self._connection.post(
path=f"{ApiPath.video}/{self.id}/{ApiPath.index}/{ApiPath.scene}",
data={
"extraction_type": extraction_type,
"extraction_config": extraction_config,
"prompt": prompt,
"model": model,
"model_name": model_name,
"model_config": model_config,
"name": name,
"force": force,
Expand All @@ -270,7 +280,7 @@ def index_scenes(
)
if not scenes_data:
return None
return scenes_data.get("scene_index_records", [])
return scenes_data.get("scene_index_id")

def list_scene_index(self) -> List:
index_data = self._connection.get(
Expand All @@ -287,6 +297,8 @@ def get_scene_index(self, scene_index_id: str) -> Optional[List]:
return index_data.get("scene_index_records", [])

def delete_scene_index(self, scene_index_id: str) -> None:
    """Delete one scene index belonging to this video.

    :param scene_index_id: Identifier of the scene index; must be truthy.
    :raises ValueError: If ``scene_index_id`` is empty or ``None``.
    """
    if scene_index_id:
        target = f"{ApiPath.video}/{self.id}/{ApiPath.index}/{ApiPath.scene}/{scene_index_id}"
        self._connection.delete(path=target)
        return
    # Refuse to issue a DELETE against a path with a blank identifier.
    raise ValueError("scene_index_id is required")
Expand Down
Loading