Skip to content

Commit

Permalink
chore: flatten output object
Browse files Browse the repository at this point in the history
  • Loading branch information
ediardo committed Oct 10, 2024
1 parent 1ec0d46 commit 2364951
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from inference.core.workflows.execution_engine.entities.types import (
DICTIONARY_KIND,
INTEGER_KIND,
FLOAT_KIND,
INSTANCE_SEGMENTATION_PREDICTION_KIND,
INSTANCE_SEGMENTATION_PREDICTION_KIND,
Expand All @@ -25,12 +26,13 @@
WorkflowBlockManifest,
)

SHORT_DESCRIPTION = "Measure the distance between two bounding boxes on a 2D plane with a camera perpendicular to the plane."
SHORT_DESCRIPTION = "Measure the distance between two bounding boxes on a 2D plane using a perpendicular camera and either a reference object or a pixel-to-centimeter ratio for scaling."

LONG_DESCRIPTION = """
Measure the distance between two bounding boxes on a 2D plane with a camera perpendicular to the plane. The method requires footage from this specific perspective and a reference object with known dimensions or a pixel-to-centimeter ratio, placed in the same plane as the bounding boxes."""
Measure the distance between two bounding boxes on a 2D plane using a camera positioned perpendicular to the plane. This method requires footage from this specific perspective, along with either a reference object of known dimensions placed in the same plane as the bounding boxes or a pixel-to-centimeter ratio that defines how many pixels correspond to one centimeter."""

OUTPUT_KEY = "distance_measurement"
OUTPUT_KEY_CENTIMETER = "distance_cm"
OUTPUT_KEY_PIXEL = "distance_pixel"

class BlockManifest(WorkflowBlockManifest):
model_config = ConfigDict(
Expand Down Expand Up @@ -146,9 +148,15 @@ class BlockManifest(WorkflowBlockManifest):
@classmethod
def describe_outputs(cls) -> List[OutputDefinition]:
return [
OutputDefinition(name=OUTPUT_KEY, kind=[DICTIONARY_KIND]),
OutputDefinition(
name=OUTPUT_KEY_CENTIMETER,
kind=[INTEGER_KIND],
),
OutputDefinition(
name=OUTPUT_KEY_PIXEL,
kind=[INTEGER_KIND],
),
]

@classmethod
def get_execution_engine_compatibility(cls) -> Optional[str]:
return ">=1.0.0,<2.0.0"
Expand Down Expand Up @@ -195,8 +203,7 @@ def run(
else:
raise ValueError(f"Invalid calibration type: {calibration_method}")

return distances

return distances

def measure_distance_with_reference_object(
detections: sv.Detections,
Expand All @@ -217,7 +224,7 @@ def measure_distance_with_reference_object(
raise ValueError(f"Reference class '{object_1_class_name}' or '{object_2_class_name}' not found in predictions.")

if has_overlap(reference_bbox_1, reference_bbox_2) or not has_axis_gap(reference_bbox_1, reference_bbox_2, reference_axis):
return {"distance_cm": 0, "distance_pixel": 0}
return {OUTPUT_KEY_CENTIMETER: 0, OUTPUT_KEY_PIXEL: 0}

# get the reference object bounding box
reference_bbox = None
Expand Down Expand Up @@ -250,7 +257,7 @@ def measure_distance_with_reference_object(

distance_cm = distance_pixels / pixel_ratio

return {"distance_cm": distance_cm, "distance_pixel": distance_pixels}
return {OUTPUT_KEY_CENTIMETER: distance_cm, OUTPUT_KEY_PIXEL: distance_pixels}



Expand All @@ -270,7 +277,7 @@ def measure_distance_with_pixel_ratio(
raise ValueError(f"Reference class '{object_1_class_name}' or '{object_2_class_name}' not found in predictions.")

if has_overlap(reference_bbox_1, reference_bbox_2) or not has_axis_gap(reference_bbox_1, reference_bbox_2, reference_axis):
return {"distance_cm": 0, "distance_pixel": 0}
return {OUTPUT_KEY_CENTIMETER: 0, OUTPUT_KEY_PIXEL: 0}

if pixel_ratio is None:
raise ValueError("Pixel-to-centimeter ratio must be provided.")
Expand All @@ -285,7 +292,7 @@ def measure_distance_with_pixel_ratio(

distance_cm = distance_pixels / pixel_ratio

return {"distance_cm": distance_cm, "distance_pixel": distance_pixels}
return {OUTPUT_KEY_CENTIMETER: distance_cm, OUTPUT_KEY_PIXEL: distance_pixels}


def has_overlap(bbox1: Tuple[int, int, int, int], bbox2: Tuple[int, int, int, int]) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import supervision as sv

from inference.core.workflows.core_steps.classical_cv.distance_measurement.v1 import (
OUTPUT_KEY_CENTIMETER,
OUTPUT_KEY_PIXEL,
DistanceMeasurementBlockV1
)
from inference.core.workflows.execution_engine.entities.base import Batch
Expand Down Expand Up @@ -50,16 +52,16 @@ def test_distance_measurement_block_pixel_ratio_vertical():


# then
expected_result = {'distance_cm': 9.275, 'distance_pixel': 806}
tolerance = 1e-3 # Define a tolerance level
expected_result = {OUTPUT_KEY_CENTIMETER: 9.275028768, OUTPUT_KEY_PIXEL: 806}
tolerance = 1e-5 # Define a tolerance level

# Compare the float values within the specified tolerance
assert math.isclose(result['distance_cm'], expected_result['distance_cm'], rel_tol=tolerance), \
f"Expected {expected_result['distance_cm']} cm, but got {result['distance_cm']} cm"
assert math.isclose(result[OUTPUT_KEY_CENTIMETER], expected_result[OUTPUT_KEY_CENTIMETER], rel_tol=tolerance), \
f"Expected {expected_result[OUTPUT_KEY_CENTIMETER]} cm, but got {result[OUTPUT_KEY_CENTIMETER]} cm"

# Compare the integer values directly
assert result['distance_pixel'] == expected_result['distance_pixel'], \
f"Expected {expected_result['distance_pixel']} pixels, but got {result['distance_pixel']} pixels"
assert result[OUTPUT_KEY_PIXEL] == expected_result[OUTPUT_KEY_PIXEL], \
f"Expected {expected_result[OUTPUT_KEY_PIXEL]} pixels, but got {result[OUTPUT_KEY_PIXEL]} pixels"

def test_distance_measurement_block_pixel_ratio_horizontal():
# given
Expand Down Expand Up @@ -481,5 +483,5 @@ def test_distance_measurement_block_with_vertical_overlapping_target_objects():
)

# then
expected_result = {'distance_cm': 0, 'distance_pixel': 0}
expected_result = {OUTPUT_KEY_CENTIMETER: 0, OUTPUT_KEY_PIXEL: 0}
assert result == expected_result, f"Expected {expected_result}, but got {result}"

0 comments on commit 2364951

Please sign in to comment.