Skip to content

Commit

Permalink
Made executors optional in all Estimators (#78)
Browse files Browse the repository at this point in the history
  • Loading branch information
integraledelebesgue authored Dec 8, 2024
1 parent 9360383 commit 7e37a25
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 36 deletions.
1 change: 0 additions & 1 deletion child_lab_framework/_procedure/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def run(
)

visualizer = Visualizer(
None, # type: ignore
properties=video_properties,
configuration=Configuration(),
)
Expand Down
15 changes: 1 addition & 14 deletions child_lab_framework/_procedure/demo_sequential.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from concurrent.futures import ThreadPoolExecutor
from itertools import repeat
from pathlib import Path

Expand Down Expand Up @@ -33,8 +32,6 @@ def main(
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

executor = ThreadPoolExecutor(max_workers=8)

ceiling, window_left, window_right = inputs

ceiling_reader = Reader(
Expand All @@ -61,36 +58,32 @@ def main(
)
window_right_properties = window_right_reader.properties

depth_estimator = depth.Estimator(executor, device, input=ceiling_properties)
depth_estimator = depth.Estimator(device, input=ceiling_properties)

transformation_buffer = transformation_buffer or transformation.Buffer()

window_left_to_ceiling_transformation_estimator = heuristic_transformation.Estimator(
executor,
transformation_buffer,
window_left_properties,
ceiling_properties,
keypoint_threshold=0.35,
)

window_right_to_ceiling_transformation_estimator = heuristic_transformation.Estimator(
executor,
transformation_buffer,
window_right_properties,
ceiling_properties,
keypoint_threshold=0.35,
)

pose_estimator = pose.Estimator(
executor,
device,
input=ceiling_properties,
max_detections=2,
threshold=0.5,
)

face_estimator = face.Estimator(
executor,
# A workaround to use the model efficiently on both desktop and server.
# TODO: remove this as soon as it's possible to specify device per component via CLI/config file.
device if device == torch.device('cuda') else torch.device('cpu'),
Expand All @@ -100,17 +93,14 @@ def main(
)

window_left_gaze_estimator = gaze.Estimator(
executor,
input=window_left_properties,
)

window_right_gaze_estimator = gaze.Estimator(
executor,
input=window_right_properties,
)

ceiling_gaze_estimator = gaze.ceiling_projection.Estimator(
executor,
transformation_buffer,
ceiling_properties,
window_left_properties,
Expand All @@ -121,19 +111,16 @@ def main(
# social_distance_logger = social_distance.FileLogger('dev/output/distance.csv')

ceiling_visualizer = Visualizer(
executor,
properties=ceiling_properties,
configuration=VisualizationConfiguration(),
)

window_left_visualizer = Visualizer(
executor,
properties=window_left_properties,
configuration=VisualizationConfiguration(),
)

window_right_visualizer = Visualizer(
executor,
properties=window_right_properties,
configuration=VisualizationConfiguration(),
)
Expand Down
1 change: 0 additions & 1 deletion child_lab_framework/_procedure/estimate_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def run(
)

visualizer = visualization.Visualizer(
None, # type: ignore
properties=readers[0].properties,
configuration=visualization.Configuration(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


class Estimator:
executor: ThreadPoolExecutor
executor: ThreadPoolExecutor | None
transformation_buffer: Buffer[str]

from_view: Properties
Expand All @@ -29,12 +29,12 @@ class Estimator:

def __init__(
self,
executor: ThreadPoolExecutor,
transformation_buffer: Buffer[str],
from_view: Properties,
to_view: Properties,
*,
keypoint_threshold: float = 0.25,
executor: ThreadPoolExecutor | None = None,
) -> None:
self.executor = executor

Expand Down Expand Up @@ -149,8 +149,13 @@ def __predict_safe(
async def stream(
self,
) -> Fiber[Input | None, list[Transformation | None] | None]:
loop = asyncio.get_running_loop()
executor = self.executor
if executor is None:
raise RuntimeError(
'Processing in the stream mode requires the Estimator to have an executor. Please pass an "executor" argument to the estimator constructor'
)

loop = asyncio.get_running_loop()

results: list[Transformation | None] | None = None

Expand Down
15 changes: 12 additions & 3 deletions child_lab_framework/task/depth/depth.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def to_frame(depth_map: FloatArray2) -> Frame:
class Estimator:
MODEL_PATH = MODELS_DIR / 'depth_pro.pt'

executor: ThreadPoolExecutor
executor: ThreadPoolExecutor | None
device: torch.device

model: DepthPro
Expand All @@ -34,7 +34,11 @@ class Estimator:
from_model: Compose

def __init__(
self, executor: ThreadPoolExecutor, device: torch.device, *, input: Properties
self,
device: torch.device,
*,
input: Properties,
executor: ThreadPoolExecutor | None = None,
) -> None:
self.executor = executor
self.device = device
Expand Down Expand Up @@ -82,8 +86,13 @@ def predict(self, frame: Frame, properties: Properties) -> FloatArray2:
return depth

async def stream(self) -> Fiber[list[Frame] | None, list[FloatArray2] | None]:
loop = asyncio.get_running_loop()
executor = self.executor
if executor is None:
raise RuntimeError(
'Processing in the stream mode requires the Estimator to have an executor. Please pass an "executor" argument to the estimator constructor'
)

loop = asyncio.get_running_loop()

results: list[FloatArray2] | None = None

Expand Down
9 changes: 7 additions & 2 deletions child_lab_framework/task/face/face.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def visualize(


class Estimator:
executor: ThreadPoolExecutor
executor: ThreadPoolExecutor | None
device: torch.device

input: Properties
Expand All @@ -58,12 +58,12 @@ class Estimator:

def __init__(
self,
executor: ThreadPoolExecutor,
device: torch.device,
*,
input: Properties,
confidence_threshold: float,
suppression_threshold: float,
executor: ThreadPoolExecutor | None = None,
) -> None:
self.executor = executor
self.device = device
Expand Down Expand Up @@ -171,6 +171,11 @@ def __match_faces_with_actors(

async def stream(self) -> Fiber[Input | None, list[Result | None] | None]:
executor = self.executor
if executor is None:
raise RuntimeError(
'Processing in the stream mode requires the Estimator to have an executor. Please pass an "executor" argument to the estimator constructor'
)

loop = asyncio.get_running_loop()

results: list[Result | None] | None = None
Expand Down
10 changes: 8 additions & 2 deletions child_lab_framework/task/gaze/ceiling_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def visualize(


class Estimator:
executor: ThreadPoolExecutor
executor: ThreadPoolExecutor | None

transformation_buffer: transformation.Buffer[str]

Expand All @@ -76,11 +76,12 @@ class Estimator:

def __init__(
self,
executor: ThreadPoolExecutor,
transformation_buffer: transformation.Buffer[str],
ceiling_properties: Properties,
window_left_properties: Properties,
window_right_properties: Properties,
*,
executor: ThreadPoolExecutor | None = None,
) -> None:
self.executor = executor
self.transformation_buffer = transformation_buffer
Expand Down Expand Up @@ -204,6 +205,11 @@ def __predict_safe(
# NOTE: heuristic idea: actors seen from right and left are in reversed lexicographic order
async def stream(self) -> Fiber[Input | None, list[Result | None] | None]:
executor = self.executor
if executor is None:
raise RuntimeError(
'Processing in the stream mode requires the Estimator to have an executor. Please pass an "executor" argument to the estimator constructor'
)

loop = asyncio.get_running_loop()

results: list[Result | None] | None = None
Expand Down
9 changes: 7 additions & 2 deletions child_lab_framework/task/gaze/gaze.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,16 +298,16 @@ class Estimator:

extractor: mf.gaze.Extractor

executor: ThreadPoolExecutor
executor: ThreadPoolExecutor | None

def __init__(
self,
executor: ThreadPoolExecutor,
*,
input: Properties,
wild: bool = False,
multiple_views: bool = False,
limit_angles: bool = False,
executor: ThreadPoolExecutor | None = None,
) -> None:
self.executor = executor

Expand Down Expand Up @@ -375,6 +375,11 @@ def __predict_safe(self, frame: Frame, faces: face.Result | None) -> Result3d |

async def stream(self) -> Fiber[Input | None, list[Result3d | None] | None]:
executor = self.executor
if executor is None:
raise RuntimeError(
'Processing in the stream mode requires the Estimator to have an executor. Please pass an "executor" argument to the estimator constructor'
)

loop = asyncio.get_running_loop()

results: list[Result3d | None] | None = None
Expand Down
12 changes: 9 additions & 3 deletions child_lab_framework/task/pose/pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass
from enum import IntEnum, auto
from functools import cached_property
from pathlib import Path

import cv2
import numpy as np
Expand Down Expand Up @@ -245,9 +246,9 @@ def transform(self, transformation: Transformation) -> 'Result3d':


class Estimator:
MODEL_PATH: str = str(MODELS_DIR / 'yolov11x-pose.pt')
MODEL_PATH: Path = MODELS_DIR / 'yolov11x-pose.pt'

executor: ThreadPoolExecutor
executor: ThreadPoolExecutor | None
device: torch.device

model: ultralytics.YOLO
Expand All @@ -260,12 +261,12 @@ class Estimator:

def __init__(
self,
executor: ThreadPoolExecutor,
device: torch.device,
*,
input: Properties,
max_detections: int,
threshold: float,
executor: ThreadPoolExecutor | None = None,
) -> None:
self.executor = executor
self.device = device
Expand Down Expand Up @@ -355,6 +356,11 @@ def __interpret(self, detections: yolo.Results) -> Result | None:

async def stream(self) -> Fiber[list[Frame] | None, list[Result] | None]:
executor = self.executor
if executor is None:
raise RuntimeError(
'Processing in the stream mode requires the Estimator to have an executor. Please pass an "executor" argument to the estimator constructor'
)

loop = asyncio.get_running_loop()
device = self.device

Expand Down
5 changes: 0 additions & 5 deletions child_lab_framework/task/visualization/visualization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import typing
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from functools import reduce
from itertools import repeat, starmap

Expand All @@ -19,19 +18,15 @@ def visualize(


class Visualizer[T: Configuration]:
executor: ThreadPoolExecutor

properties: Properties
configuration: T

def __init__(
self,
executor: ThreadPoolExecutor,
*,
properties: Properties,
configuration: T,
) -> None:
self.executor = executor
self.properties = properties
self.configuration = configuration

Expand Down

0 comments on commit 7e37a25

Please sign in to comment.