Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Finalize keypoint detection #174

Merged
merged 6 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/inference_sdk/http_client.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ print(predictions)
## Why client has two modes - `v0` and `v1`?
We are constantly improving our `infrence` package - initial version (`v0`) is compatible with
models deployed at Roboflow platform (task types: `classification`, `object-detection`, `instance-segmentation` and
`keypoints-detection`)
`keypoint-detection`)
are supported. Version `v1` is available in locally hosted Docker images with HTTP API.

Locally hosted `inference` server exposes endpoints for model manipulations, but those endpoints are not available
Expand Down
2 changes: 1 addition & 1 deletion inference/core/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
CLASSIFICATION_TASK = "classification"
OBJECT_DETECTION_TASK = "object-detection"
INSTANCE_SEGMENTATION_TASK = "instance-segmentation"
KEYPOINTS_DETECTION_TASK = "keypoints-detection"
KEYPOINTS_DETECTION_TASK = "keypoint-detection"
2 changes: 1 addition & 1 deletion inference/core/interfaces/http/http_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,7 @@ async def legacy_infer_from_request(
}
elif task_type == "classification":
inference_request_type = ClassificationInferenceRequest
elif task_type == "keypoints-detection":
elif task_type == "keypoint-detection":
inference_request_type = KeypointsDetectionInferenceRequest
args = {"keypoint_confidence": keypoint_confidence}
inference_request = inference_request_type(
Expand Down
27 changes: 26 additions & 1 deletion inference/core/models/keypoints_detection_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from inference.core.models.types import PreprocessReturnMetadata
from inference.core.nms import w_np_non_max_suppression
from inference.core.utils.postprocess import post_process_bboxes, post_process_keypoints
from inference.core.utils.validate import get_num_classes_from_model_prediction_shape

DEFAULT_CONFIDENCE = 0.5
DEFAULT_IOU_THRESH = 0.5
Expand All @@ -28,7 +29,7 @@ class KeypointsDetectionBaseOnnxRoboflowInferenceModel(
):
"""Roboflow ONNX Object detection model. This class implements an object detection specific infer method."""

task_type = "keypoints-detection"
task_type = "keypoint-detection"

def __init__(self, model_id: str, *args, **kwargs):
super().__init__(model_id, *args, **kwargs)
Expand Down Expand Up @@ -188,3 +189,27 @@ def _model_keypoints_to_response(
)
results.append(keypoint)
return results

def superset_keypoints_count(self) -> int:
"""Returns the number of keypoints in the superset."""
if self.keypoints_metadata is None:
raise ModelArtefactError("Keypoints metadata not available.")
max_keypoints = 0
for keypoints in self.keypoints_metadata.values():
if len(keypoints) > max_keypoints:
max_keypoints = len(keypoints)
return max_keypoints

def validate_model_classes(self) -> None:
num_keypoints = self.superset_keypoints_count()
keypoints_shape = num_keypoints * 3
output_shape = self.get_model_output_shape()
num_classes = get_num_classes_from_model_prediction_shape(
len_prediction=output_shape[2], keypoints_shape=keypoints_shape
)
try:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if num_classes != self.num_classes:
    raise ValueError()

assert num_classes == self.num_classes
except AssertionError:
raise ValueError(
f"Number of classes in model ({num_classes}) does not match the number of classes in the environment ({self.num_classes})"
)
2 changes: 1 addition & 1 deletion inference/core/models/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def make_response(


class KeypointsDetectionModelStub(ModelStub):
task_type = "keypoints-detection"
task_type = "keypoint-detection"

def make_response(
self, request: InferenceRequest, prediction: dict, **kwargs
Expand Down
1 change: 1 addition & 0 deletions inference/core/roboflow_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"object-detection": "yolov5v2s",
"instance-segmentation": "yolact",
"classification": "vit",
"keypoint-detection": "yolov8n",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

most changes are keypoints-detection -> keypoint-detection, but this was a necessary addition to make sure model manager accepts type.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

}
PROJECT_TASK_TYPE_KEY = "project_task_type"
MODEL_TYPE_KEY = "model_type"
Expand Down
6 changes: 4 additions & 2 deletions inference/core/utils/validate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
def get_num_classes_from_model_prediction_shape(len_prediction, masks=0):
num_classes = len_prediction - 5 - masks
def get_num_classes_from_model_prediction_shape(
len_prediction, masks=0, keypoints_shape=0
):
num_classes = len_prediction - 5 - masks - keypoints_shape
return num_classes
22 changes: 11 additions & 11 deletions inference/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,17 @@
"instance-segmentation",
"yolov8-seg",
): YOLOv8InstanceSegmentation,
("keypoints-detection", "stub"): KeypointsDetectionModelStub,
("keypoints-detection", "yolov8n"): YOLOv8KeypointsDetection,
("keypoints-detection", "yolov8s"): YOLOv8KeypointsDetection,
("keypoints-detection", "yolov8m"): YOLOv8KeypointsDetection,
("keypoints-detection", "yolov8l"): YOLOv8KeypointsDetection,
("keypoints-detection", "yolov8x"): YOLOv8KeypointsDetection,
("keypoints-detection", "yolov8n-pose"): YOLOv8KeypointsDetection,
("keypoints-detection", "yolov8s-pose"): YOLOv8KeypointsDetection,
("keypoints-detection", "yolov8m-pose"): YOLOv8KeypointsDetection,
("keypoints-detection", "yolov8l-pose"): YOLOv8KeypointsDetection,
("keypoints-detection", "yolov8x-pose"): YOLOv8KeypointsDetection,
("keypoint-detection", "stub"): KeypointsDetectionModelStub,
("keypoint-detection", "yolov8n"): YOLOv8KeypointsDetection,
("keypoint-detection", "yolov8s"): YOLOv8KeypointsDetection,
("keypoint-detection", "yolov8m"): YOLOv8KeypointsDetection,
("keypoint-detection", "yolov8l"): YOLOv8KeypointsDetection,
("keypoint-detection", "yolov8x"): YOLOv8KeypointsDetection,
("keypoint-detection", "yolov8n-pose"): YOLOv8KeypointsDetection,
("keypoint-detection", "yolov8s-pose"): YOLOv8KeypointsDetection,
("keypoint-detection", "yolov8m-pose"): YOLOv8KeypointsDetection,
("keypoint-detection", "yolov8l-pose"): YOLOv8KeypointsDetection,
("keypoint-detection", "yolov8x-pose"): YOLOv8KeypointsDetection,
}

try:
Expand Down
2 changes: 1 addition & 1 deletion inference_sdk/http/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
CLASSIFICATION_TASK = "classification"
OBJECT_DETECTION_TASK = "object-detection"
INSTANCE_SEGMENTATION_TASK = "instance-segmentation"
KEYPOINTS_DETECTION_TASK = "keypoints-detection"
KEYPOINTS_DETECTION_TASK = "keypoint-detection"
DEFAULT_MAX_INPUT_SIZE = 1024


Expand Down
2 changes: 1 addition & 1 deletion tests/inference/unit_tests/core/models/test_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
(ClassificationModelStub, "classification"),
(ObjectDetectionModelStub, "object-detection"),
(InstanceSegmentationModelStub, "instance-segmentation"),
(KeypointsDetectionModelStub, "keypoints-detection"),
(KeypointsDetectionModelStub, "keypoint-detection"),
],
)
def test_model_stub(stub_class: Type[ModelStub], expected_task_type: str) -> None:
Expand Down
Loading