Skip to content
This repository was archived by the owner on Jul 3, 2024. It is now read-only.

Commit

Permalink
feat: init result processor
Browse files Browse the repository at this point in the history
  • Loading branch information
Mouwrice committed Apr 20, 2024
1 parent 746f0d3 commit a27bfe8
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 11 deletions.
8 changes: 5 additions & 3 deletions drumpy/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
log_file: Optional[str] = None,
landmark_type: LandmarkType = LandmarkType.WORLD_LANDMARKS,
camera_index: int = 0,
disable_drum: bool = False, # noqa: FBT001, FBT002
) -> None:
"""
Initialize the application
Expand All @@ -46,7 +47,9 @@ def __init__(
self.window_surface = pygame.display.set_mode(initial_window_size)
self.manager = UIManager(initial_window_size)

self.drum_trackers = DrumTrackers()
self.drum_trackers: Optional[DrumTrackers] = None
if not disable_drum:
self.drum_trackers = DrumTrackers()

self.media_pipe_pose = MediaPipePose(
running_mode=running_mode, # type: ignore
Expand Down Expand Up @@ -79,11 +82,9 @@ def __init__(
)

def start(self: Self) -> None:
frame = 0
clock = pygame.time.Clock()
running = True
while running and not self.video_source.stopped:
frame += 1
time_delta_ms = clock.tick(self.fps)
for event in pygame.event.get():
if event.type == pygame.QUIT:
Expand All @@ -110,6 +111,7 @@ def main() -> None:
delegate=BaseOptions.Delegate.CPU, # type: ignore
file_path="../data/Recordings/multicam_asil_01_front.mkv",
# log_file="test.csv",
disable_drum=True,
)
app.start()

Expand Down
21 changes: 14 additions & 7 deletions drumpy/mediapipe_pose/mediapipe_pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

import numpy as np
import numpy.typing as npt
from drumpy.mediapipe_pose.landmark_type import LandmarkType
from drumpy.mediapipe_pose.landmarker_model import LandmarkerModel
from drumpy.tracking.drum_trackers import DrumTrackers
from drumpy.trajectory_file import TrajectoryFile
from mediapipe import Image, ImageFormat
from mediapipe.framework.formats import landmark_pb2
from mediapipe.python import solutions
Expand All @@ -12,11 +16,7 @@
PoseLandmarkerResult,
RunningMode,
)

from drumpy.mediapipe_pose.landmark_type import LandmarkType
from drumpy.mediapipe_pose.landmarker_model import LandmarkerModel
from drumpy.tracking.drum_trackers import DrumTrackers
from drumpy.trajectory_file import TrajectoryFile
from drumpy.mediapipe_pose.process_result import ResultProcessor


def visualize_landmarks(
Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(
self: Self,
running_mode: RunningMode,
landmark_type: LandmarkType,
drum_trackers: DrumTrackers,
drum_trackers: Optional[DrumTrackers] = None,
model: LandmarkerModel = LandmarkerModel.FULL,
delegate: BaseOptions.Delegate = BaseOptions.Delegate.GPU,
log_file: Optional[str] = None,
Expand Down Expand Up @@ -105,6 +105,8 @@ def __init__(

self.drum_trackers = drum_trackers

self.result_processor = ResultProcessor()

def result_callback(
self: Self, result: PoseLandmarkerResult, image: Image, timestamp_ms: int
) -> None:
Expand All @@ -115,18 +117,23 @@ def result_callback(
:param timestamp_ms: The timestamp of the frame
:return:
"""
result = self.result_processor.process_result(result, timestamp_ms)
self.detection_result = result
self.latency = timestamp_ms - self.latest_timestamp
self.latest_timestamp = timestamp_ms
self.drum_trackers.drum.check_calibrations()
if self.drum_trackers is not None:
self.drum_trackers.drum.check_calibrations()
if (
result.pose_world_landmarks is not None
and len(result.pose_world_landmarks) > 0
and self.drum_trackers is not None
):
self.drum_trackers.update(result.pose_world_landmarks[0])

self.visualisation = visualize_landmarks(
image.numpy_view(), self.detection_result
)

self.frame_count += 1
if self.csv_writer is not None:
self.write_landmarks(result, timestamp_ms)
Expand Down
115 changes: 115 additions & 0 deletions drumpy/mediapipe_pose/process_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from mediapipe.tasks.python.components.containers.landmark import NormalizedLandmark
from mediapipe.tasks.python.vision import PoseLandmarkerResult


class ResultProcessor:
    """
    Smooth pose estimation results by rejecting implausible jumps.

    A short history of the most recent results is kept. For every incoming
    landmark the expected position is extrapolated linearly from that history
    (average velocity over the buffered window multiplied by the time since
    the last buffered frame). A coordinate that deviates from its expected
    value by more than ``threshold`` is treated as an outlier and replaced by
    the expected value.
    """

    def __init__(self, memory: int = 5, threshold: float = 0.08) -> None:
        """
        :param memory: Maximum number of past results kept for prediction
        :param threshold: Maximum allowed absolute deviation per axis between
            the measured and the predicted coordinate before the measurement
            is replaced (same unit as the landmark coordinates)
        """
        self.memory: int = memory
        self.threshold: float = threshold
        self.results: list[PoseLandmarkerResult] = []
        self.timestamps_ms: list[float] = []  # timestamps of the results, in ms
        # Deltas between consecutive buffered timestamps, in ms. Always holds
        # exactly len(timestamps_ms) - 1 entries so it matches the buffered
        # displacements one-to-one.
        self.time_deltas_ms: list[float] = []
        # Time spanned by the buffered results (sum of time_deltas_ms), in ms
        self.time_duration_ms: float = 0.0

    def process_result(
        self, result: "PoseLandmarkerResult", timestamp_ms: float
    ) -> "PoseLandmarkerResult":
        """
        Correct outliers in the first detected pose and record the result.

        The landmarks of ``result`` are modified in place.

        :param result: The pose estimation result of the current frame
        :param timestamp_ms: The timestamp of the frame, in ms
        :return: The same ``result`` object, with outliers corrected
        """
        # No pose detected in this frame: nothing to correct and nothing
        # useful to remember, so the history is left untouched.
        if not result.pose_landmarks:
            return result

        landmarks = result.pose_landmarks[0]
        for index, landmark in enumerate(landmarks):
            landmarks[index] = self.process_normalized_landmark(
                landmark, index, timestamp_ms
            )

        self.results.append(result)
        self.timestamps_ms.append(timestamp_ms)

        # Results and timestamps are trimmed together so they stay aligned.
        excess = len(self.results) - max(self.memory, 0)
        if excess > 0:
            del self.results[:excess]
            del self.timestamps_ms[:excess]

        # Derive the time bookkeeping from the buffered timestamps so that it
        # always spans exactly the same window as the buffered results.
        self.time_deltas_ms = [
            later - earlier
            for earlier, later in zip(self.timestamps_ms, self.timestamps_ms[1:])
        ]
        self.time_duration_ms = (
            self.timestamps_ms[-1] - self.timestamps_ms[0]
            if self.timestamps_ms
            else 0.0
        )

        return result

    def process_normalized_landmark(
        self, landmark: "NormalizedLandmark", index: int, timestamp_ms: float
    ) -> "NormalizedLandmark":
        """
        Replace outlier coordinates of a single landmark by their prediction.

        :param landmark: The measured landmark of the current frame
        :param index: The index of the landmark within the pose
        :param timestamp_ms: The timestamp of the current frame, in ms
        :return: The same ``landmark`` object, corrected in place if needed
        """
        if len(self.results) < 2:  # noqa: PLR2004
            return landmark

        # Without elapsed time no velocity can be derived; leave the
        # measurement as it is instead of dividing by zero.
        elapsed_ms = self.timestamps_ms[-1] - self.timestamps_ms[0]
        if elapsed_ms <= 0:
            return landmark

        oldest = self.results[0].pose_landmarks[0][index]
        newest = self.results[-1].pose_landmarks[0][index]
        time_delta = timestamp_ms - self.timestamps_ms[-1]

        for axis in ("x", "y", "z"):
            last = getattr(newest, axis)
            # The sum of the frame-to-frame differences telescopes to
            # (newest - oldest); divided by the elapsed time this is the
            # average movement of the landmark per millisecond.
            velocity = (last - getattr(oldest, axis)) / elapsed_ms
            predicted = last + velocity * time_delta

            measured = getattr(landmark, axis)
            if abs(measured - predicted) > self.threshold:
                # Depth corrections are not reported, only x and y are.
                if axis != "z":
                    print(f"Corrected {axis}: {measured} -> {predicted}")
                setattr(landmark, axis, predicted)

        return landmark

    @staticmethod
    def calculate_diff(
        current: "NormalizedLandmark", previous: "NormalizedLandmark"
    ) -> "NormalizedLandmark":
        """
        Calculate the per-axis difference between two positions.

        :param current: The current position
        :param previous: The previous position
        :return: A new landmark whose x, y and z are ``current - previous``
        """
        diff = NormalizedLandmark()
        diff.x = current.x - previous.x
        diff.y = current.y - previous.y
        diff.z = current.z - previous.z
        return diff
2 changes: 1 addition & 1 deletion drumpy/tracking/marker_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def get_velocity(self: Self) -> float:
if len(self.positions) < 2: # noqa: PLR2004
return 0

return self.positions[-1][2] - self.positions[-2][2]
return float(self.positions[-1][2]) - float(self.positions[-2][2])

def is_hit(self: Self) -> bool:
"""
Expand Down

0 comments on commit a27bfe8

Please sign in to comment.