diff --git a/docs/inference_sdk/http_client.md b/docs/inference_sdk/http_client.md index 2dbfc351c..eb4333355 100644 --- a/docs/inference_sdk/http_client.md +++ b/docs/inference_sdk/http_client.md @@ -31,7 +31,7 @@ print(predictions) ## Why client has two modes - `v0` and `v1`? We are constantly improving our `infrence` package - initial version (`v0`) is compatible with models deployed at Roboflow platform (task types: `classification`, `object-detection`, `instance-segmentation` and -`keypoints-detection`) +`keypoint-detection`) are supported. Version `v1` is available in locally hosted Docker images with HTTP API. Locally hosted `inference` server exposes endpoints for model manipulations, but those endpoints are not available diff --git a/docs/quickstart/docker.md b/docs/quickstart/docker.md index 6988701fd..a6125b6b3 100644 --- a/docs/quickstart/docker.md +++ b/docs/quickstart/docker.md @@ -148,41 +148,41 @@ Choose a Dockerfile from the following options, depending on the hardware you wa === "x86 CPU" ``` docker build \ - -f dockerfiles/Dockerfile.onnx.cpu \ + -f docker/dockerfiles/Dockerfile.onnx.cpu \ -t roboflow/roboflow-inference-server-cpu . ``` === "arm64 CPU" ``` docker build \ - -f dockerfiles/Dockerfile.onnx.cpu \ + -f docker/dockerfiles/Dockerfile.onnx.cpu \ -t roboflow/roboflow-inference-server-cpu . ``` === "GPU" ``` docker build \ - -f dockerfiles/Dockerfile.onnx.gpu \ + -f docker/dockerfiles/Dockerfile.onnx.gpu \ -t roboflow/roboflow-inference-server-gpu . ``` === "Jetson 4.5.x" ``` docker build \ - -f dockerfiles/Dockerfile.onnx.jetson \ + -f docker/dockerfiles/Dockerfile.onnx.jetson \ -t roboflow/roboflow-inference-server-jetson-4.5.0 . ``` === "Jetson 4.6.x" ``` docker build \ - -f dockerfiles/Dockerfile.onnx.jetson \ + -f docker/dockerfiles/Dockerfile.onnx.jetson \ -t roboflow/roboflow-inference-server-jetson-4.6.1 . ``` === "Jetson 5.x" ``` docker build \ - -f dockerfiles/Dockerfile.onnx.jetson.5.1.1 \ + -f docker/dockerfiles/Dockerfile.onnx.jetson.5.1.1 \ -t roboflow/roboflow-inference-server-jetson-5.1.1 . ``` diff --git a/inference/core/constants.py b/inference/core/constants.py index a17766247..6958784c4 100644 --- a/inference/core/constants.py +++ b/inference/core/constants.py @@ -1,4 +1,4 @@ CLASSIFICATION_TASK = "classification" OBJECT_DETECTION_TASK = "object-detection" INSTANCE_SEGMENTATION_TASK = "instance-segmentation" -KEYPOINTS_DETECTION_TASK = "keypoints-detection" +KEYPOINTS_DETECTION_TASK = "keypoint-detection" diff --git a/inference/core/interfaces/http/http_api.py b/inference/core/interfaces/http/http_api.py index 863f86703..6698f6a79 100644 --- a/inference/core/interfaces/http/http_api.py +++ b/inference/core/interfaces/http/http_api.py @@ -1024,7 +1024,7 @@ async def legacy_infer_from_request( } elif task_type == "classification": inference_request_type = ClassificationInferenceRequest - elif task_type == "keypoints-detection": + elif task_type == "keypoint-detection": inference_request_type = KeypointsDetectionInferenceRequest args = {"keypoint_confidence": keypoint_confidence} inference_request = inference_request_type( diff --git a/inference/core/models/classification_base.py b/inference/core/models/classification_base.py index f72860a2c..eaf6e1f5c 100644 --- a/inference/core/models/classification_base.py +++ b/inference/core/models/classification_base.py @@ -14,8 +14,10 @@ ) from inference.core.models.roboflow import OnnxRoboflowInferenceModel from inference.core.models.types import PreprocessReturnMetadata +from inference.core.models.utils.validate import ( + get_num_classes_from_model_prediction_shape, +) from inference.core.utils.image_utils import load_image_rgb -from inference.core.utils.validate import get_num_classes_from_model_prediction_shape class ClassificationBaseOnnxRoboflowInferenceModel(OnnxRoboflowInferenceModel): diff --git a/inference/core/models/instance_segmentation_base.py b/inference/core/models/instance_segmentation_base.py index cd7718d4a..c1bca195b 100644 --- a/inference/core/models/instance_segmentation_base.py +++ b/inference/core/models/instance_segmentation_base.py @@ -11,6 +11,9 @@ from inference.core.exceptions import InvalidMaskDecodeArgument from inference.core.models.roboflow import OnnxRoboflowInferenceModel from inference.core.models.types import PreprocessReturnMetadata +from inference.core.models.utils.validate import ( + get_num_classes_from_model_prediction_shape, +) from inference.core.nms import w_np_non_max_suppression from inference.core.utils.postprocess import ( masks2poly, @@ -20,7 +23,6 @@ process_mask_fast, process_mask_tradeoff, ) -from inference.core.utils.validate import get_num_classes_from_model_prediction_shape DEFAULT_CONFIDENCE = 0.5 DEFAULT_IOU_THRESH = 0.5 diff --git a/inference/core/models/keypoints_detection_base.py b/inference/core/models/keypoints_detection_base.py index 1f4a32448..125e0878c 100644 --- a/inference/core/models/keypoints_detection_base.py +++ b/inference/core/models/keypoints_detection_base.py @@ -13,6 +13,9 @@ ObjectDetectionBaseOnnxRoboflowInferenceModel, ) from inference.core.models.types import PreprocessReturnMetadata +from inference.core.models.utils.validate import ( + get_num_classes_from_model_prediction_shape, +) from inference.core.nms import w_np_non_max_suppression from inference.core.utils.postprocess import post_process_bboxes, post_process_keypoints @@ -28,7 +31,7 @@ class KeypointsDetectionBaseOnnxRoboflowInferenceModel( ): """Roboflow ONNX Object detection model. This class implements an object detection specific infer method.""" - task_type = "keypoints-detection" + task_type = "keypoint-detection" def __init__(self, model_id: str, *args, **kwargs): super().__init__(model_id, *args, **kwargs) @@ -188,3 +191,17 @@ def _model_keypoints_to_response( ) results.append(keypoint) return results + + def keypoints_count(self) -> int: + raise NotImplementedError + + def validate_model_classes(self) -> None: + num_keypoints = self.keypoints_count() + output_shape = self.get_model_output_shape() + num_classes = get_num_classes_from_model_prediction_shape( + len_prediction=output_shape[2], keypoints=num_keypoints + ) + if num_classes != self.num_classes: + raise ValueError( + f"Number of classes in model ({num_classes}) does not match the number of classes in the environment ({self.num_classes})" + ) diff --git a/inference/core/models/object_detection_base.py b/inference/core/models/object_detection_base.py index a1c0c8887..70597028a 100644 --- a/inference/core/models/object_detection_base.py +++ b/inference/core/models/object_detection_base.py @@ -11,9 +11,11 @@ from inference.core.logger import logger from inference.core.models.roboflow import OnnxRoboflowInferenceModel from inference.core.models.types import PreprocessReturnMetadata +from inference.core.models.utils.validate import ( + get_num_classes_from_model_prediction_shape, +) from inference.core.nms import w_np_non_max_suppression from inference.core.utils.postprocess import post_process_bboxes -from inference.core.utils.validate import get_num_classes_from_model_prediction_shape DEFAULT_CONFIDENCE = 0.5 DEFAULT_IOU_THRESH = 0.5 diff --git a/inference/core/models/stubs.py b/inference/core/models/stubs.py index 76a72a677..9c31f72d1 100644 --- a/inference/core/models/stubs.py +++ b/inference/core/models/stubs.py @@ -117,7 +117,7 @@ def make_response( class KeypointsDetectionModelStub(ModelStub): - task_type = "keypoints-detection" + task_type = "keypoint-detection" def make_response( self, request: InferenceRequest, prediction: dict, **kwargs diff --git a/inference/core/models/utils/keypoints.py b/inference/core/models/utils/keypoints.py new file mode 100644 index 000000000..90e8b02be --- /dev/null +++ b/inference/core/models/utils/keypoints.py @@ -0,0 +1,7 @@ +def superset_keypoints_count(keypoints_metadata={}) -> int: + """Returns the number of keypoints in the superset.""" + max_keypoints = 0 + for keypoints in keypoints_metadata.values(): + if len(keypoints) > max_keypoints: + max_keypoints = len(keypoints) + return max_keypoints diff --git a/inference/core/utils/validate.py b/inference/core/models/utils/validate.py similarity index 50% rename from inference/core/utils/validate.py rename to inference/core/models/utils/validate.py index 3ebbd6c3e..d373decd0 100644 --- a/inference/core/utils/validate.py +++ b/inference/core/models/utils/validate.py @@ -1,3 +1,3 @@ -def get_num_classes_from_model_prediction_shape(len_prediction, masks=0): - num_classes = len_prediction - 5 - masks +def get_num_classes_from_model_prediction_shape(len_prediction, masks=0, keypoints=0): + num_classes = len_prediction - 5 - masks - (keypoints * 3) return num_classes diff --git a/inference/core/roboflow_api.py b/inference/core/roboflow_api.py index 20743e8af..95e06e6b4 100644 --- a/inference/core/roboflow_api.py +++ b/inference/core/roboflow_api.py @@ -35,6 +35,7 @@ "object-detection": "yolov5v2s", "instance-segmentation": "yolact", "classification": "vit", + "keypoint-detection": "yolov8n", } PROJECT_TASK_TYPE_KEY = "project_task_type" MODEL_TYPE_KEY = "model_type" diff --git a/inference/models/utils.py b/inference/models/utils.py index 4c2c80638..a36978b5e 100644 --- a/inference/models/utils.py +++ b/inference/models/utils.py @@ -119,17 +119,17 @@ "instance-segmentation", "yolov8-seg", ): YOLOv8InstanceSegmentation, - ("keypoints-detection", "stub"): KeypointsDetectionModelStub, - ("keypoints-detection", "yolov8n"): YOLOv8KeypointsDetection, - ("keypoints-detection", "yolov8s"): YOLOv8KeypointsDetection, - ("keypoints-detection", "yolov8m"): YOLOv8KeypointsDetection, - ("keypoints-detection", "yolov8l"): YOLOv8KeypointsDetection, - ("keypoints-detection", "yolov8x"): YOLOv8KeypointsDetection, - ("keypoints-detection", "yolov8n-pose"): YOLOv8KeypointsDetection, - ("keypoints-detection", "yolov8s-pose"): YOLOv8KeypointsDetection, - ("keypoints-detection", "yolov8m-pose"): YOLOv8KeypointsDetection, - ("keypoints-detection", "yolov8l-pose"): YOLOv8KeypointsDetection, - ("keypoints-detection", "yolov8x-pose"): YOLOv8KeypointsDetection, + ("keypoint-detection", "stub"): KeypointsDetectionModelStub, + ("keypoint-detection", "yolov8n"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8s"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8m"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8l"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8x"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8n-pose"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8s-pose"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8m-pose"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8l-pose"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8x-pose"): YOLOv8KeypointsDetection, } try: diff --git a/inference/models/yolov8/yolov8_keypoints_detection.py b/inference/models/yolov8/yolov8_keypoints_detection.py index 893c7ca45..3c4daf912 100644 --- a/inference/models/yolov8/yolov8_keypoints_detection.py +++ b/inference/models/yolov8/yolov8_keypoints_detection.py @@ -2,12 +2,11 @@ import numpy as np +from inference.core.exceptions import ModelArtefactError from inference.core.models.keypoints_detection_base import ( KeypointsDetectionBaseOnnxRoboflowInferenceModel, ) -from inference.core.models.object_detection_base import ( - ObjectDetectionBaseOnnxRoboflowInferenceModel, -) +from inference.core.models.utils.keypoints import superset_keypoints_count class YOLOv8KeypointsDetection(KeypointsDetectionBaseOnnxRoboflowInferenceModel): @@ -52,3 +51,9 @@ def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, ...]: [boxes, confs, class_confs, keypoints_detections], axis=2 ) return (bboxes_predictions,) + + def keypoints_count(self) -> int: + """Returns the number of keypoints in the model.""" + if self.keypoints_metadata is None: + raise ModelArtefactError("Keypoints metadata not available.") + return superset_keypoints_count(self.keypoints_metadata) diff --git a/inference_sdk/http/entities.py b/inference_sdk/http/entities.py index af8282fbb..83dc28937 100644 --- a/inference_sdk/http/entities.py +++ b/inference_sdk/http/entities.py @@ -17,7 +17,7 @@ CLASSIFICATION_TASK = "classification" OBJECT_DETECTION_TASK = "object-detection" INSTANCE_SEGMENTATION_TASK = "instance-segmentation" -KEYPOINTS_DETECTION_TASK = "keypoints-detection" +KEYPOINTS_DETECTION_TASK = "keypoint-detection" DEFAULT_MAX_INPUT_SIZE = 1024 diff --git a/tests/inference/unit_tests/core/models/test_stubs.py b/tests/inference/unit_tests/core/models/test_stubs.py index e2657d3e2..5a6f9152a 100644 --- a/tests/inference/unit_tests/core/models/test_stubs.py +++ b/tests/inference/unit_tests/core/models/test_stubs.py @@ -25,7 +25,7 @@ (ClassificationModelStub, "classification"), (ObjectDetectionModelStub, "object-detection"), (InstanceSegmentationModelStub, "instance-segmentation"), - (KeypointsDetectionModelStub, "keypoints-detection"), + (KeypointsDetectionModelStub, "keypoint-detection"), ], ) def test_model_stub(stub_class: Type[ModelStub], expected_task_type: str) -> None: diff --git a/tests/inference/unit_tests/core/models/utils/test_keypoints.py b/tests/inference/unit_tests/core/models/utils/test_keypoints.py new file mode 100644 index 000000000..42f9882a7 --- /dev/null +++ b/tests/inference/unit_tests/core/models/utils/test_keypoints.py @@ -0,0 +1,97 @@ +from inference.core.models.utils.keypoints import ( + superset_keypoints_count, +) + + +def test_superset_keypoints_count() -> None: + # given + keypoints_metadata = { + 0: { + 0: "nose", + 1: "left_eye", + 2: "right_eye", + 3: "left_ear", + 4: "right_ear", + 5: "left_shoulder", + 6: "right_shoulder", + 7: "left_elbow", + 8: "right_elbow", + 9: "left_wrist", + 10: "right_wrist", + 11: "left_hip", + 12: "right_hip", + 13: "left_knee", + 14: "right_knee", + 15: "left_ankle", + 16: "right_ankle", + } + } + # when + keypoints_count = superset_keypoints_count(keypoints_metadata) + # then + assert keypoints_count == 17 + + +def test_superset_keypoints_count_with_two_classes() -> None: + # given + keypoints_metadata = { + 0: { + 0: "nose", + 1: "left_eye", + 2: "right_eye", + 3: "left_ear", + 4: "right_ear", + 5: "left_shoulder", + 6: "right_shoulder", + }, + 1: { + 0: "nose", + 1: "left_eye", + 2: "right_eye", + 3: "left_ear", + 4: "right_ear", + 5: "left_shoulder", + 6: "right_shoulder", + 7: "left_elbow", + 8: "right_elbow", + 9: "left_wrist", + 10: "right_wrist", + 11: "left_hip", + 12: "right_hip", + }, + } + # when + keypoints_count = superset_keypoints_count(keypoints_metadata) + # then + assert keypoints_count == 13 + + +def test_superset_keypoints_count_with_two_non_overlapping_classes() -> None: + # given + keypoints_metadata = { + 0: { + 0: "nose", + 1: "left_eye", + 2: "right_eye", + 3: "left_ear", + 4: "right_ear", + 5: "left_shoulder", + 6: "right_shoulder", + }, + 1: { + 0: "nose1", + 1: "left_eye1", + 2: "right_eye2", + 3: "left_ear3", + 4: "right_ear4", + 5: "left_shoulder5", + 6: "right_shoulder6", + 7: "left_elbow7", + 8: "right_elbow8", + 9: "left_wrist9", + }, + } + # when + keypoints_count = superset_keypoints_count(keypoints_metadata) + # then + assert keypoints_count == 10 diff --git a/tests/inference/unit_tests/core/models/utils/test_validate.py b/tests/inference/unit_tests/core/models/utils/test_validate.py new file mode 100644 index 000000000..324ca5576 --- /dev/null +++ b/tests/inference/unit_tests/core/models/utils/test_validate.py @@ -0,0 +1,48 @@ +from inference.core.models.utils.validate import ( + get_num_classes_from_model_prediction_shape, +) + + +def test_get_num_classes_from_model_prediction_shape() -> None: + # given + prediction_len = 10 + # when + num_classes = get_num_classes_from_model_prediction_shape(prediction_len) + # then + assert num_classes == 5 + + +def test_get_num_classes_from_model_prediction_shape_with_masks() -> None: + # given + prediction_len = 42 + num_masks = 32 + # when + num_classes = get_num_classes_from_model_prediction_shape(prediction_len, num_masks) + # then + assert num_classes == 5 + + +def test_get_num_classes_from_model_prediction_shape_with_keypoints() -> None: + # given + prediction_len = 57 + num_keypoints = 17 + # when + num_classes = get_num_classes_from_model_prediction_shape( + prediction_len, keypoints=num_keypoints + ) + # then + assert num_classes == 1 + + +def test_get_num_classes_from_model_prediction_shape_with_keypoints_more_classes() -> ( + None +): + # given + prediction_len = 46 + num_keypoints = 12 + # when + num_classes = get_num_classes_from_model_prediction_shape( + prediction_len, keypoints=num_keypoints + ) + # then + assert num_classes == 5 diff --git a/tests/inference/unit_tests/core/utils/test_validate.py b/tests/inference/unit_tests/core/utils/test_validate.py deleted file mode 100644 index 68b33596d..000000000 --- a/tests/inference/unit_tests/core/utils/test_validate.py +++ /dev/null @@ -1,14 +0,0 @@ -from inference.core.utils.validate import get_num_classes_from_model_prediction_shape - - -def test_get_num_classes_from_model_prediction_shape(): - prediction_len = 10 - num_classes = get_num_classes_from_model_prediction_shape(prediction_len) - assert num_classes == 5 - - -def test_get_num_classes_from_model_prediction_shape_with_masks(): - prediction_len = 42 - num_masks = 32 - num_classes = get_num_classes_from_model_prediction_shape(prediction_len, num_masks) - assert num_classes == 5