diff --git a/.github/workflows/docker.stream_management_api.yml b/.github/workflows/docker.stream_management_api.yml
new file mode 100644
index 000000000..b82abfbbd
--- /dev/null
+++ b/.github/workflows/docker.stream_management_api.yml
@@ -0,0 +1,42 @@
+name: Build and Push Container with Stream Management API
+
+on:
+  release:
+    types: [created]
+  workflow_dispatch:
+
+env:
+  VERSION: '0.0.0' # Default version, will be overwritten
+
+jobs:
+  docker:
+    runs-on: ubuntu-latest
+    steps:
+      -
+        name: Set up QEMU
+        uses: docker/setup-qemu-action@v2
+      -
+        name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v2
+      -
+        name: Login to Docker Hub
+        uses: docker/login-action@v2
+        with:
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+      -
+        name: 🛎️ Checkout
+        uses: actions/checkout@v3
+      -
+        name: Read version from file
+        run: echo "VERSION=$(DISABLE_VERSION_CHECK=true python ./inference/core/version.py)" >> $GITHUB_ENV
+      -
+        name: Build and Push
+        uses: docker/build-push-action@v4
+        with:
+          push: true
+          tags: roboflow/roboflow-inference-stream-management-api:latest,roboflow/roboflow-inference-stream-management-api:${{env.VERSION}}
+          cache-from: type=registry,ref=roboflow/roboflow-inference-stream-management-api:cache
+          cache-to: type=registry,ref=roboflow/roboflow-inference-stream-management-api:cache,mode=max
+          platforms: linux/amd64,linux/arm64
+          file: ./docker/dockerfiles/Dockerfile.stream_management_api
\ No newline at end of file
diff --git a/.github/workflows/docker.stream_manager.cpu.yml b/.github/workflows/docker.stream_manager.cpu.yml
new file mode 100644
index 000000000..64129591c
--- /dev/null
+++ b/.github/workflows/docker.stream_manager.cpu.yml
@@ -0,0 +1,42 @@
+name: Build and Push Container with Stream Manager CPU
+
+on:
+  release:
+    types: [created]
+  workflow_dispatch:
+
+env:
+  VERSION: '0.0.0' # Default version, will be overwritten
+
+jobs:
+  docker:
+    runs-on: ubuntu-latest
+    steps:
+      -
+        name: Set up QEMU
+        uses: docker/setup-qemu-action@v2
+      -
+        name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v2
+      -
+        name: Login to Docker Hub
+        uses: docker/login-action@v2
+        with:
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+      -
+        name: 🛎️ Checkout
+        uses: actions/checkout@v3
+      -
+        name: Read version from file
+        run: echo "VERSION=$(DISABLE_VERSION_CHECK=true python ./inference/core/version.py)" >> $GITHUB_ENV
+      -
+        name: Build and Push
+        uses: docker/build-push-action@v4
+        with:
+          push: true
+          tags: roboflow/roboflow-inference-stream-manager-cpu:latest,roboflow/roboflow-inference-stream-manager-cpu:${{env.VERSION}}
+          cache-from: type=registry,ref=roboflow/roboflow-inference-stream-manager-cpu:cache
+          cache-to: type=registry,ref=roboflow/roboflow-inference-stream-manager-cpu:cache,mode=max
+          platforms: linux/amd64,linux/arm64
+          file: ./docker/dockerfiles/Dockerfile.onnx.cpu.stream_manager
\ No newline at end of file
diff --git a/.github/workflows/docker.stream_manager.gpu.yml b/.github/workflows/docker.stream_manager.gpu.yml
new file mode 100644
index 000000000..588c16755
--- /dev/null
+++ b/.github/workflows/docker.stream_manager.gpu.yml
@@ -0,0 +1,46 @@
+name: Build and Push Container with Stream Manager GPU
+
+on:
+  release:
+    types: [created]
+  workflow_dispatch:
+
+env:
+  VERSION: '0.0.0' # Default version, will be overwritten
+
+jobs:
+  docker:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Remove unnecessary files
+        run: |
+          sudo rm -rf /usr/share/dotnet
+          sudo rm -rf "$AGENT_TOOLSDIRECTORY"
+      -
name: Set up QEMU + uses: docker/setup-qemu-action@v2 + - + name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + - + name: Login to Docker Hub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + - + name: 🛎️ Checkout + uses: actions/checkout@v3 + - + name: Read version from file + run: echo "VERSION=$(DISABLE_VERSION_CHECK=true python ./inference/core/version.py)" >> $GITHUB_ENV + - + name: Build and Push + uses: docker/build-push-action@v4 + with: + push: true + tags: roboflow/roboflow-inference-stream-manager-gpu:latest,roboflow/roboflow-inference-stream-manager-gpu:${{env.VERSION}} + cache-from: type=registry,ref=roboflow/roboflow-inference-stream-manager-gpu:cache + cache-to: type=registry,ref=roboflow/roboflow-inference-stream-manager-gpu,mode=max + platforms: linux/amd64 + file: ./docker/dockerfiles/Dockerfile.onnx.gpu.stream_manager \ No newline at end of file diff --git a/.github/workflows/docker.stream_manager.jetson.5.1.1.yml b/.github/workflows/docker.stream_manager.jetson.5.1.1.yml new file mode 100644 index 000000000..76abbdc50 --- /dev/null +++ b/.github/workflows/docker.stream_manager.jetson.5.1.1.yml @@ -0,0 +1,46 @@ +name: Build and Push Container with Stream Manager Jetson 5.1.1 + +on: + release: + types: [created] + workflow_dispatch: + +env: + VERSION: '0.0.0' # Default version, will be overwritten + +jobs: + docker: + runs-on: ubuntu-latest + steps: + - name: Remove unnecessary files + run: | + sudo rm -rf /usr/share/dotnet + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + - + name: Set up QEMU + uses: docker/setup-qemu-action@v2 + - + name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + - + name: Login to Docker Hub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + - + name: 🛎️ Checkout + uses: actions/checkout@v3 + - + name: Read version from file + run: echo "VERSION=$(DISABLE_VERSION_CHECK=true python ./inference/core/version.py)" >> $GITHUB_ENV + - + name: Build and Push + uses: docker/build-push-action@v4 + with: + push: true + tags: roboflow/roboflow-inference-stream-manager-jetson-5.1.1:latest,roboflow/roboflow-inference-stream-manager-jetson-5.1.1:${{ env.VERSION}} + cache-from: type=registry,ref=roboflow/roboflow-inference-stream-manager-jetson-5.1.1:cache + cache-to: type=registry,ref=roboflow/roboflow-inference-stream-manager-jetson-5.1.1:cache,mode=max + platforms: linux/arm64 + file: ./docker/dockerfiles/Dockerfile.onnx.jetson.5.1.1.stream_manager \ No newline at end of file diff --git a/.gitignore b/.gitignore index 5c5522946..10ec40445 100644 --- a/.gitignore +++ b/.gitignore @@ -153,4 +153,5 @@ inference_cli/version.py inference_sdk/version.py **/.DS_Store -!tests/inference/unit_tests/core/interfaces/assets/*.mp4 \ No newline at end of file +!tests/inference/unit_tests/core/interfaces/assets/*.mp4 +!inference/enterprise/stream_management/assets/*.jpg \ No newline at end of file diff --git a/development/stream_interface/udp_receiver.py b/development/stream_interface/udp_receiver.py index a3b921bfa..9d7609313 100644 --- a/development/stream_interface/udp_receiver.py +++ b/development/stream_interface/udp_receiver.py @@ -2,13 +2,14 @@ import os import socket +HOST = os.getenv("HOST", "127.0.0.1") PORT = int(os.getenv("PORT", "9999")) BUFFER_SIZE = 65535 def main() -> None: udp_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM) - udp_socket.bind(("127.0.0.1", 
PORT)) + udp_socket.bind((HOST, PORT)) try: while True: message, _ = udp_socket.recvfrom(BUFFER_SIZE) diff --git a/docker/dockerfiles/Dockerfile.onnx.cpu.stream_manager b/docker/dockerfiles/Dockerfile.onnx.cpu.stream_manager new file mode 100644 index 000000000..ed883d3e4 --- /dev/null +++ b/docker/dockerfiles/Dockerfile.onnx.cpu.stream_manager @@ -0,0 +1,37 @@ +FROM python:3.9 + +WORKDIR /app + +ARG DEBIAN_FRONTEND=noninteractive + +RUN apt update -y && apt install -y \ + ffmpeg \ + libxext6 \ + libopencv-dev \ + uvicorn \ + python3-pip \ + git \ + libgdal-dev \ + cmake \ + && rm -rf /var/lib/apt/lists/* + +COPY requirements/requirements.cpu.txt \ + requirements/requirements.http.txt \ + requirements/_requirements.txt \ + ./ + +RUN pip3 install --upgrade pip && pip3 install \ + -r _requirements.txt \ + -r requirements.cpu.txt \ + -r requirements.http.txt \ + --upgrade \ + && rm -rf ~/.cache/pip + +COPY inference inference + +ENV VERSION_CHECK_MODE=continuous +ENV PROJECT=roboflow-platform +ENV HOST=0.0.0.0 +ENV PORT=7070 + +ENTRYPOINT ["python", "-m", "inference.enterprise.stream_management.manager.app"] \ No newline at end of file diff --git a/docker/dockerfiles/Dockerfile.onnx.gpu.stream_manager b/docker/dockerfiles/Dockerfile.onnx.gpu.stream_manager new file mode 100644 index 000000000..525830de7 --- /dev/null +++ b/docker/dockerfiles/Dockerfile.onnx.gpu.stream_manager @@ -0,0 +1,35 @@ +FROM nvcr.io/nvidia/cuda:11.4.2-cudnn8-runtime-ubuntu20.04 + +WORKDIR /app + +RUN rm -rf /var/lib/apt/lists/* && apt-get clean && apt-get update -y && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + ffmpeg \ + libxext6 \ + libopencv-dev \ + uvicorn \ + python3-pip \ + git \ + libgdal-dev \ + && rm -rf /var/lib/apt/lists/* + +COPY requirements/requirements.http.txt \ + requirements/requirements.gpu.txt \ + requirements/_requirements.txt \ + ./ + +RUN pip3 install --upgrade pip && pip3 install \ + -r _requirements.txt \ + -r requirements.http.txt \ + -r requirements.gpu.txt \ + --upgrade \ + && rm -rf ~/.cache/pip + +WORKDIR /app/ +COPY inference inference + +ENV VERSION_CHECK_MODE=continuous +ENV PROJECT=roboflow-platform +ENV HOST=0.0.0.0 +ENV PORT=7070 + +ENTRYPOINT ["python3", "-m", "inference.enterprise.stream_management.manager.app"] \ No newline at end of file diff --git a/docker/dockerfiles/Dockerfile.onnx.jetson.5.1.1.stream_manager b/docker/dockerfiles/Dockerfile.onnx.jetson.5.1.1.stream_manager new file mode 100644 index 000000000..d5c1e5ecc --- /dev/null +++ b/docker/dockerfiles/Dockerfile.onnx.jetson.5.1.1.stream_manager @@ -0,0 +1,56 @@ +FROM nvcr.io/nvidia/l4t-ml:r35.2.1-py3 + +ARG DEBIAN_FRONTEND=noninteractive +ENV LANG en_US.UTF-8 + +RUN apt-get update -y && apt-get install -y \ + lshw \ + git \ + python3-pip \ + python3-matplotlib \ + gfortran \ + build-essential \ + libatlas-base-dev \ + ffmpeg \ + libsm6 \ + libxext6 \ + wget \ + python3-shapely \ + gdal-bin \ + libgdal-dev \ + && rm -rf /var/lib/apt/lists/* + +COPY requirements/requirements.clip.txt \ + requirements/requirements.http.txt \ + requirements/_requirements.txt \ + ./ + +RUN pip3 install --ignore-installed PyYAML && rm -rf ~/.cache/pip + +RUN pip3 install --upgrade pip && pip3 install \ + -r _requirements.txt \ + -r requirements.clip.txt \ + -r requirements.http.txt \ + --upgrade \ + && rm -rf ~/.cache/pip + +RUN pip3 uninstall --yes onnxruntime +RUN wget https://nvidia.box.com/shared/static/v59xkrnvederwewo2f1jtv6yurl92xso.whl -O onnxruntime_gpu-1.12.1-cp38-cp38-linux_aarch64.whl +RUN pip3 install 
onnxruntime_gpu-1.12.1-cp38-cp38-linux_aarch64.whl "opencv-python-headless<4.3" \ + && rm -rf ~/.cache/pip \ + && rm onnxruntime_gpu-1.12.1-cp38-cp38-linux_aarch64.whl + +WORKDIR /app/ +COPY inference inference + +ENV ORT_TENSORRT_FP16_ENABLE=1 +ENV ORT_TENSORRT_ENGINE_CACHE_ENABLE=1 +ENV CORE_MODEL_SAM_ENABLED=False +ENV OPENBLAS_CORETYPE=ARMV8 +ENV LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libgomp.so.1:/usr/local/lib/python3.8/dist-packages/torch.libs/libgomp-d22c30c5.so.1.0.0 +ENV VERSION_CHECK_MODE=continuous +ENV PROJECT=roboflow-platform +ENV HOST=0.0.0.0 +ENV PORT=7070 + +ENTRYPOINT ["python3", "-m", "inference.enterprise.stream_management.manager.app"] diff --git a/docker/dockerfiles/Dockerfile.stream_management_api b/docker/dockerfiles/Dockerfile.stream_management_api new file mode 100644 index 000000000..b8c7c5e13 --- /dev/null +++ b/docker/dockerfiles/Dockerfile.stream_management_api @@ -0,0 +1,38 @@ +FROM python:3.9 + +WORKDIR /app + +ARG DEBIAN_FRONTEND=noninteractive + +RUN apt update -y && apt install -y \ + ffmpeg \ + libxext6 \ + libopencv-dev \ + uvicorn \ + python3-pip \ + git \ + libgdal-dev \ + cmake \ + && rm -rf /var/lib/apt/lists/* + +COPY requirements/requirements.http.txt \ + requirements/requirements.cpu.txt \ + requirements/_requirements.txt \ + ./ + +RUN pip3 install --upgrade pip && pip3 install \ + -r _requirements.txt \ + -r requirements.cpu.txt \ + -r requirements.http.txt \ + --upgrade \ + && rm -rf ~/.cache/pip + +COPY inference inference + +ENV VERSION_CHECK_MODE=continuous +ENV STREAM_MANAGEMENT_API_HOST=0.0.0.0 +ENV STREAM_MANAGEMENT_API_PORT=8080 +ENV STREAM_MANAGER_HOST=127.0.0.1 +ENV STREAM_MANAGER_PORT=7070 + +ENTRYPOINT ["python", "-m", "inference.enterprise.stream_management.api.app"] diff --git a/docker/dockerfiles/stream-management-api.compose-cpu.yaml b/docker/dockerfiles/stream-management-api.compose-cpu.yaml new file mode 100644 index 000000000..8dbcba6cc --- /dev/null +++ b/docker/dockerfiles/stream-management-api.compose-cpu.yaml @@ -0,0 +1,21 @@ +services: + management-api: + build: + context: ${PWD} + dockerfile: docker/dockerfiles/Dockerfile.stream_management_api + ports: + - "8080:8080" + environment: + - STREAM_MANAGER_HOST=stream_manager + - STREAM_MANAGER_PORT=7070 + depends_on: + - stream_manager + stream_manager: + build: + context: ${PWD} + dockerfile: docker/dockerfiles/Dockerfile.onnx.cpu.stream_manager + privileged: true + environment: + - STREAM_MANAGER_HOST=0.0.0.0 + ports: + - "7070:7070" \ No newline at end of file diff --git a/docker/dockerfiles/stream-management-api.compose-gpu.yaml b/docker/dockerfiles/stream-management-api.compose-gpu.yaml new file mode 100644 index 000000000..bbe67abe1 --- /dev/null +++ b/docker/dockerfiles/stream-management-api.compose-gpu.yaml @@ -0,0 +1,22 @@ +services: + management-api: + build: + context: ${PWD} + dockerfile: docker/dockerfiles/Dockerfile.stream_management_api + ports: + - "8080:8080" + environment: + - STREAM_MANAGER_HOST=stream_manager + - STREAM_MANAGER_PORT=7070 + depends_on: + - stream_manager + stream_manager: + build: + context: ${PWD} + dockerfile: docker/dockerfiles/Dockerfile.onnx.gpu.stream_manager + privileged: true + runtime: nvidia + environment: + - STREAM_MANAGER_HOST=0.0.0.0 + ports: + - "7070:7070" \ No newline at end of file diff --git a/docker/dockerfiles/stream-management-api.compose-jetson.5.1.1.yaml b/docker/dockerfiles/stream-management-api.compose-jetson.5.1.1.yaml new file mode 100644 index 000000000..f925b6721 --- /dev/null +++ 
b/docker/dockerfiles/stream-management-api.compose-jetson.5.1.1.yaml @@ -0,0 +1,22 @@ +services: + management-api: + build: + context: ${PWD} + dockerfile: docker/dockerfiles/Dockerfile.stream_management_api + ports: + - "8080:8080" + environment: + - STREAM_MANAGER_HOST=stream_manager + - STREAM_MANAGER_PORT=7070 + depends_on: + - stream_manager + stream_manager: + build: + context: ${PWD} + dockerfile: docker/dockerfiles/Dockerfile.onnx.jetson.5.1.1.stream_manager + privileged: true + runtime: nvidia + environment: + - STREAM_MANAGER_HOST=0.0.0.0 + ports: + - "7070:7070" \ No newline at end of file diff --git a/inference/enterprise/stream_management/README.md b/inference/enterprise/stream_management/README.md new file mode 100644 index 000000000..c8c6b0ef5 --- /dev/null +++ b/inference/enterprise/stream_management/README.md @@ -0,0 +1,299 @@ +# Stream Management + +## Overview +This feature is designed to cater to users requiring the execution of inference to generate predictions using +Roboflow object-detection models, particularly when dealing with online video streams. +It enhances the functionalities of the familiar `inference.Stream()` and `InferencePipeline()` interfaces, as found in +the open-source version of the library, by introducing a sophisticated management layer. The inclusion of additional +capabilities empowers users to remotely manage the state of inference pipelines through the HTTP management interface +integrated into this package. + +This functionality proves beneficial in various scenarios, **including but not limited to**: + +* Performing inference across multiple online video streams simultaneously. +* Executing inference on multiple devices that necessitate coordination. +* Establishing a monitoring layer to oversee video processing based on the `inference` package. + + +## Design +![Stream Management - design](./assets/stream_management_api_design.jpg) + + +## Example use-case + +Joe aims to monitor objects within the footage captured by a fleet of IP cameras installed in his factory. After +successfully training an object-detection model on the Roboflow platform, he is now prepared for deployment. With four +cameras in his factory, Joe opts for a model that is sufficiently compact, allowing for over 30 inferences per second +on his Jetson devices. Considering this computational budget per device, Joe determines that he requires two Jetson +devices to efficiently process footage from all cameras, anticipating an inference throughput of approximately +15 frames per second for each video source. + +To streamline the deployment, Joe chooses to deploy Stream Management containers to all available Jetson devices within +his local network. This setup enables him to communicate with each Jetson device via HTTP, facilitating the +orchestration of processing tasks. Joe develops a web app through which he can send commands to the devices and retrieve +metrics regarding the statuses of the video streams. + +Finally, Joe implements a UDP server capable of receiving predictions, leveraging the `supervision` package to +effectively track objects in the footage. This comprehensive approach allows Joe to manage and monitor the +object-detection process seamlessly across his fleet of Jetson devices. + +## How to run? + +### In docker - using `docker compose` +The most prevalent use-cases are conveniently encapsulated with Docker Compose configurations, ensuring readiness for +immediate use. 
Nevertheless, in specific instances where custom configuration adjustments are required within the Docker
+containers, such as passing camera devices, alternative options may prove more suitable.
+
+#### CPU-based devices
+```bash
+repository_root$ docker compose -f ./docker/dockerfiles/stream-management-api.compose-cpu.yaml up
+```
+
+#### GPU-based devices
+```bash
+repository_root$ docker compose -f ./docker/dockerfiles/stream-management-api.compose-gpu.yaml up
+```
+
+#### Jetson devices (`JetPack 5.1.1`)
+```bash
+repository_root$ docker-compose -f ./docker/dockerfiles/stream-management-api.compose-jetson.5.1.1.yaml up
+```
+
+**Disclaimer:** On Jetson devices, some operations (like container bootstrap or model initialisation) take more time
+than on other platforms. In particular, the docker compose definition in its current form does not actively wait for
+the TCP socket port to be opened by the Stream Manager, which means that initial requests to the HTTP API may be
+answered with HTTP 503.
+
+### In docker - running API and stream manager containers separately
+
+#### Run
+
+##### CPU-based devices
+```bash
+docker run -d --name stream_manager --network host roboflow/roboflow-inference-stream-manager-cpu:latest
+docker run -d --name stream_management_api --network host roboflow/roboflow-inference-stream-management-api:latest
+```
+
+##### GPU-based devices
+```bash
+docker run -d --name stream_manager --network host --runtime nvidia roboflow/roboflow-inference-stream-manager-gpu:latest
+docker run -d --name stream_management_api --network host roboflow/roboflow-inference-stream-management-api:latest
+```
+
+##### Jetson devices (`JetPack 5.1.1`)
+```bash
+docker run -d --name stream_manager --network host --runtime nvidia roboflow/roboflow-inference-stream-manager-jetson-5.1.1:latest
+docker run -d --name stream_management_api --network host roboflow/roboflow-inference-stream-management-api:latest
+```
+
+#### Configuration parameters
+
+##### Stream Management API
+* `STREAM_MANAGER_HOST` - hostname of the stream manager container (change it to the container name if
+`--network host` is not used, or to the address of the remote machine being targeted)
+* `STREAM_MANAGER_PORT` - port used to communicate with the stream manager (must match the stream manager container)
+
+##### Stream Manager
+* `PORT` - port at which the server will be running
+* a volume can be mounted under the container's `/tmp/cache` to enable permanent storage of models - for faster
+initialisation of inference pipelines
+* connectivity to the cameras must be enabled at the level of this container - so if devices need to be passed to
+docker, it must happen for this container
+
+#### Build (Optional)
+
+##### Stream Management API
+```bash
+docker build -t roboflow/roboflow-inference-stream-management-api:dev -f docker/dockerfiles/Dockerfile.stream_management_api .
+```
+
+##### Stream Manager
+```bash
+docker build -t roboflow/roboflow-inference-stream-manager-{device}:dev -f docker/dockerfiles/Dockerfile.onnx.{device}.stream_manager .
+```
+
+### Bare-metal deployment
+In some cases, it may be required to deploy the application at the host level. This is possible, although the client
+must set up the environment in the way presented in the Stream Manager and Stream Management API dockerfiles
+appropriate for the specific platform.
Once this is done, the following commands should be run:
+
+```bash
+repository_root$ python -m inference.enterprise.stream_management.manager.app # runs manager
+```
+
+```bash
+repository_root$ python -m inference.enterprise.stream_management.api.app # runs management API
+```
+
+## How to integrate?
+After running the `roboflow-inference-stream-management-api` container, the HTTP API will be available under
+`http://127.0.0.1:8080` (given that the default configuration is used).
+
+One can call `wget http://127.0.0.1:8080/openapi.json` to get the OpenAPI specification of the API, which can be
+rendered [here](https://editor.swagger.io/).
+
+An example Python client is provided below:
+```python
+import requests
+from typing import Optional
+
+URL = "http://127.0.0.1:8080"
+
+
+def list_pipelines() -> dict:
+    response = requests.get(f"{URL}/list_pipelines")
+    return response.json()
+
+
+def get_pipeline_status(pipeline_id: str) -> dict:
+    response = requests.get(f"{URL}/status/{pipeline_id}")
+    return response.json()
+
+
+def pause_pipeline(pipeline_id: str) -> dict:
+    response = requests.post(f"{URL}/pause/{pipeline_id}")
+    return response.json()
+
+
+def resume_pipeline(pipeline_id: str) -> dict:
+    response = requests.post(f"{URL}/resume/{pipeline_id}")
+    return response.json()
+
+
+def terminate_pipeline(pipeline_id: str) -> dict:
+    response = requests.post(f"{URL}/terminate/{pipeline_id}")
+    return response.json()
+
+
+def initialise_pipeline(
+    video_reference: str,
+    model_id: str,
+    api_key: str,
+    sink_host: str,
+    sink_port: int,
+    max_fps: Optional[int] = None,
+) -> dict:
+    response = requests.post(
+        f"{URL}/initialise",
+        json={
+            "type": "init",
+            "sink_configuration": {
+                "type": "udp_sink",
+                "host": sink_host,
+                "port": sink_port,
+            },
+            "video_reference": video_reference,
+            "model_id": model_id,
+            "api_key": api_key,
+            "max_fps": max_fps,
+        },
+    )
+    return response.json()
+```
+
+### Important notes
+* Please remember that `initialise_pipeline()` must be filled with a `video_reference` and a `sink_configuration`
+such that any resource (video file / camera device) or URI (stream reference, sink reference) **is reachable from the
+Stream Manager environment!** For instance, inside docker containers `localhost` will in some cases be bound to the
+**container localhost**, not the localhost of the machine hosting the container.
+
+## Developer notes
+The pivotal element of the implementation is the Stream Manager component, operating as an application in
+single-threaded, TCP-server mode. It systematically processes requests received from a TCP socket,
+taking on the responsibility of spawning and overseeing the processes that run the `InferencePipelineManager`.
+Communication between the `InferencePipelineManager` processes and the main process of the Stream Manager occurs
+through multiprocessing queues. These queues facilitate the exchange of input commands and the retrieval of results.
+
+Requests directed to the Stream Manager are handled sequentially in blocking mode,
+so each request must conclude before the next one is initiated.
+
+### Communication protocol - requests
+The Stream Manager accepts the following binary protocol. Each communication payload contains:
+```
+[HEADER: 4B, big-endian, unsigned int with message size][MESSAGE: utf-8 serialised JSON of the size dictated by the header]
+```
+The message must be valid JSON after decoding and must represent a valid command.
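+
+For illustration, a minimal sketch of framing a command according to this protocol, using only the Python standard
+library (the host and port are assumptions - `7070` is the default Stream Manager port used in the configuration
+above - and any of the commands documented below can be framed the same way):
+
+```python
+import json
+import socket
+
+
+def send_command(host: str, port: int, command: dict) -> None:
+    # serialise the command to UTF-8 JSON and prepend a 4-byte, big-endian, unsigned length header
+    body = json.dumps(command).encode("utf-8")
+    header = len(body).to_bytes(4, byteorder="big", signed=False)
+    with socket.create_connection((host, port)) as client_socket:
+        client_socket.sendall(header + body)
+
+
+send_command("127.0.0.1", 7070, {"type": "list_pipelines"})
+```
+
+After sending a command this way, the manager replies over the same connection using the framing described in the
+responses section further below.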
+
+#### `list_pipelines` command
+```json
+{
+  "type": "list_pipelines"
+}
+```
+
+#### `init` command
+```json
+{
+  "type": "init",
+  "model_id": "some/1",
+  "video_reference": "rtsp://192.168.0.1:554",
+  "sink_configuration": {
+    "type": "udp_sink",
+    "host": "192.168.0.3",
+    "port": 9999
+  },
+  "api_key": "YOUR-API-KEY",
+  "max_fps": 16,
+  "model_configuration": {
+    "type": "object-detection",
+    "class_agnostic_nms": true,
+    "confidence": 0.5,
+    "iou_threshold": 0.4,
+    "max_candidates": 300,
+    "max_detections": 3000
+  }
+}
+```
+
+#### `terminate` command
+```json
+{
+  "type": "terminate",
+  "pipeline_id": "my_pipeline"
+}
+```
+
+#### `pause` command
+```json
+{
+  "type": "mute",
+  "pipeline_id": "my_pipeline"
+}
+```
+
+#### `resume` command
+```json
+{
+  "type": "resume",
+  "pipeline_id": "my_pipeline"
+}
+```
+
+#### `status` command
+```json
+{
+  "type": "status",
+  "pipeline_id": "my_pipeline"
+}
+```
+
+### Communication protocol - responses
+For each request that can be processed (without a timeout or source disconnection), the Stream Manager returns the
+result in the following format:
+```
+[HEADER: 4B, big-endian, unsigned int with result size][RESULT: utf-8 serialised JSON of the size dictated by the header]
+```
+
+Structure of the result:
+* `request_id` - a random string representing the request ID assigned by the Stream Manager - to ease debugging
+* `pipeline_id` - if the command from the request can be associated with a specific pipeline, its ID will be denoted
+in the response
+* `response` - payload of the operation response
+
+Each `response` has a `status` key with two possible values - `success` or `failure` - denoting the operation status.
+Each failed response contains an `error_type` key to dispatch error handling, and optional `error_class` and
+`error_message` fields representing the inner details of the error.
+
+The content of successful responses depends on the type of operation.
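+
+As a complementary sketch (same standard-library assumptions as the request example in the previous section), the
+framed result can be read back from the socket used to send the command and interpreted through the fields described
+above; the `pipelines` key shown here is the payload returned for `list_pipelines`:
+
+```python
+import json
+import socket
+
+
+def read_result(client_socket: socket.socket) -> dict:
+    # read the 4-byte, big-endian, unsigned header carrying the size of the result
+    header = client_socket.recv(4)
+    payload_size = int.from_bytes(header, byteorder="big")
+    # keep reading until the whole UTF-8 serialised JSON result is collected
+    received = b""
+    while len(received) < payload_size:
+        chunk = client_socket.recv(4096)
+        if not chunk:
+            raise ConnectionError("Socket closed before the full result was received")
+        received += chunk
+    return json.loads(received)
+
+
+with socket.create_connection(("127.0.0.1", 7070)) as client_socket:
+    body = json.dumps({"type": "list_pipelines"}).encode("utf-8")
+    client_socket.sendall(len(body).to_bytes(4, byteorder="big") + body)
+    result = read_result(client_socket)
+    # result["response"]["status"] is "success" or "failure"; failures additionally carry "error_type"
+    print(result["request_id"], result["response"]["status"], result["response"].get("pipelines", []))
+```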
+ + +## Future work +* securing API connection layer (to enable safe remote control) +* securing TCP socket of Stream Manager \ No newline at end of file diff --git a/inference/enterprise/stream_management/__init__.py b/inference/enterprise/stream_management/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/inference/enterprise/stream_management/api/__init__.py b/inference/enterprise/stream_management/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/inference/enterprise/stream_management/api/app.py b/inference/enterprise/stream_management/api/app.py new file mode 100644 index 000000000..0b228d7e0 --- /dev/null +++ b/inference/enterprise/stream_management/api/app.py @@ -0,0 +1,178 @@ +import os +from functools import wraps +from typing import Any, Awaitable, Callable + +import uvicorn +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse + +from inference.core import logger +from inference.enterprise.stream_management.api.entities import ( + CommandResponse, + InferencePipelineStatusResponse, + ListPipelinesResponse, + PipelineInitialisationRequest, +) +from inference.enterprise.stream_management.api.errors import ( + ConnectivityError, + ProcessesManagerAuthorisationError, + ProcessesManagerClientError, + ProcessesManagerInvalidPayload, + ProcessesManagerNotFoundError, +) +from inference.enterprise.stream_management.api.stream_manager_client import ( + StreamManagerClient, +) +from inference.enterprise.stream_management.manager.entities import ( + STATUS_KEY, + OperationStatus, +) + +API_HOST = os.getenv("STREAM_MANAGEMENT_API_HOST", "127.0.0.1") +API_PORT = int(os.getenv("STREAM_MANAGEMENT_API_PORT", "8080")) + +OPERATIONS_TIMEOUT = os.getenv("STREAM_MANAGER_OPERATIONS_TIMEOUT") +if OPERATIONS_TIMEOUT is not None: + OPERATIONS_TIMEOUT = float(OPERATIONS_TIMEOUT) + +STREAM_MANAGER_CLIENT = StreamManagerClient.init( + host=os.getenv("STREAM_MANAGER_HOST", "127.0.0.1"), + port=int(os.getenv("STREAM_MANAGER_PORT", "7070")), + operations_timeout=OPERATIONS_TIMEOUT, +) + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +def with_route_exceptions(route: callable) -> Callable[[Any], Awaitable[JSONResponse]]: + @wraps(route) + async def wrapped_route(*args, **kwargs): + try: + return await route(*args, **kwargs) + except ProcessesManagerInvalidPayload as error: + resp = JSONResponse( + status_code=400, + content={STATUS_KEY: OperationStatus.FAILURE, "message": str(error)}, + ) + logger.exception("Processes Manager - invalid payload error") + return resp + except ProcessesManagerAuthorisationError as error: + resp = JSONResponse( + status_code=401, + content={STATUS_KEY: OperationStatus.FAILURE, "message": str(error)}, + ) + logger.exception("Processes Manager - authorisation error") + return resp + except ProcessesManagerNotFoundError as error: + resp = JSONResponse( + status_code=404, + content={STATUS_KEY: OperationStatus.FAILURE, "message": str(error)}, + ) + logger.exception("Processes Manager - not found error") + return resp + except ConnectivityError as error: + resp = JSONResponse( + status_code=503, + content={STATUS_KEY: OperationStatus.FAILURE, "message": str(error)}, + ) + logger.exception("Processes Manager connectivity error occurred") + return resp + except ProcessesManagerClientError as error: + resp = JSONResponse( + status_code=500, + 
content={STATUS_KEY: OperationStatus.FAILURE, "message": str(error)}, + ) + logger.exception("Processes Manager error occurred") + return resp + except Exception: + resp = JSONResponse( + status_code=500, + content={ + STATUS_KEY: OperationStatus.FAILURE, + "message": "Internal error.", + }, + ) + logger.exception("Internal error in API") + return resp + + return wrapped_route + + +@app.get( + "/list_pipelines", + response_model=ListPipelinesResponse, + summary="List active pipelines", + description="Listing all active pipelines in the state of ProcessesManager being queried.", +) +@with_route_exceptions +async def list_pipelines(_: Request) -> ListPipelinesResponse: + return await STREAM_MANAGER_CLIENT.list_pipelines() + + +@app.get( + "/status/{pipeline_id}", + response_model=InferencePipelineStatusResponse, + summary="Get status of pipeline", + description="Returns detailed statis of Inference Pipeline in the state of ProcessesManager being queried.", +) +@with_route_exceptions +async def get_status(pipeline_id: str) -> InferencePipelineStatusResponse: + return await STREAM_MANAGER_CLIENT.get_status(pipeline_id=pipeline_id) + + +@app.post( + "/initialise", + response_model=CommandResponse, + summary="Initialise the pipeline", + description="Starts new Inference Pipeline within the state of ProcessesManager being queried.", +) +@with_route_exceptions +async def initialise(request: PipelineInitialisationRequest) -> CommandResponse: + return await STREAM_MANAGER_CLIENT.initialise_pipeline( + initialisation_request=request + ) + + +@app.post( + "/pause/{pipeline_id}", + response_model=CommandResponse, + summary="Pauses the pipeline processing", + description="Mutes the VideoSource of Inference Pipeline within the state of ProcessesManager being queried.", +) +@with_route_exceptions +async def pause(pipeline_id: str) -> CommandResponse: + return await STREAM_MANAGER_CLIENT.pause_pipeline(pipeline_id=pipeline_id) + + +@app.post( + "/resume/{pipeline_id}", + response_model=CommandResponse, + summary="Resumes the pipeline processing", + description="Resumes the VideoSource of Inference Pipeline within the state of ProcessesManager being queried.", +) +@with_route_exceptions +async def resume(pipeline_id: str) -> CommandResponse: + return await STREAM_MANAGER_CLIENT.resume_pipeline(pipeline_id=pipeline_id) + + +@app.post( + "/terminate/{pipeline_id}", + response_model=CommandResponse, + summary="Terminates the pipeline processing", + description="Terminates the VideoSource of Inference Pipeline within the state of ProcessesManager being queried.", +) +@with_route_exceptions +async def terminate(pipeline_id: str) -> CommandResponse: + return await STREAM_MANAGER_CLIENT.terminate_pipeline(pipeline_id=pipeline_id) + + +if __name__ == "__main__": + uvicorn.run(app, host=API_HOST, port=API_PORT) diff --git a/inference/enterprise/stream_management/api/entities.py b/inference/enterprise/stream_management/api/entities.py new file mode 100644 index 000000000..fe77b9b39 --- /dev/null +++ b/inference/enterprise/stream_management/api/entities.py @@ -0,0 +1,91 @@ +from typing import List, Optional, Union + +from pydantic import BaseModel, Field + +from inference.core.interfaces.camera.video_source import ( + BufferConsumptionStrategy, + BufferFillingStrategy, +) + + +class UDPSinkConfiguration(BaseModel): + type: str = Field( + description="Type identifier field. 
Must be `udp_sink`", default="udp_sink" + ) + host: str = Field(description="Host of UDP sink.") + port: int = Field(description="Port of UDP sink.") + + +class ObjectDetectionModelConfiguration(BaseModel): + type: str = Field( + description="Type identifier field. Must be `object-detection`", + default="object-detection", + ) + class_agnostic_nms: Optional[bool] = Field( + description="Flag to decide if class agnostic NMS to be applied. If not given, default or InferencePipeline host env will be used.", + default=None, + ) + confidence: Optional[float] = Field( + description="Confidence threshold for predictions. If not given, default or InferencePipeline host env will be used.", + default=None, + ) + iou_threshold: Optional[float] = Field( + description="IoU threshold of post-processing. If not given, default or InferencePipeline host env will be used.", + default=None, + ) + max_candidates: Optional[int] = Field( + description="Max candidates in post-processing. If not given, default or InferencePipeline host env will be used.", + default=None, + ) + max_detections: Optional[int] = Field( + description="Max detections in post-processing. If not given, default or InferencePipeline host env will be used.", + default=None, + ) + + +class PipelineInitialisationRequest(BaseModel): + model_id: str = Field(description="Roboflow model id") + video_reference: Union[str, int] = Field( + description="Reference to video source - either stream, video file or device. It must be accessible from the host running inference stream" + ) + sink_configuration: UDPSinkConfiguration = Field( + description="Configuration of the sink." + ) + api_key: str = Field(description="Roboflow API key") + max_fps: Optional[Union[float, int]] = Field( + description="Limit of FPS in video processing.", default=None + ) + source_buffer_filling_strategy: Optional[str] = Field( + description=f"`source_buffer_filling_strategy` parameter of Inference Pipeline (see docs). One of {[e.value for e in BufferFillingStrategy]}", + default=None, + ) + source_buffer_consumption_strategy: Optional[str] = Field( + description=f"`source_buffer_consumption_strategy` parameter of Inference Pipeline (see docs). 
One of {[e.value for e in BufferConsumptionStrategy]}", + default=None, + ) + model_configuration: ObjectDetectionModelConfiguration = Field( + description="Configuration of the model", + default_factory=ObjectDetectionModelConfiguration, + ) + + +class CommandContext(BaseModel): + request_id: Optional[str] = Field( + description="Server-side request ID", default=None + ) + pipeline_id: Optional[str] = Field( + description="Identifier of pipeline connected to operation", default=None + ) + + +class CommandResponse(BaseModel): + status: str = Field(description="Operation status") + context: CommandContext = Field(description="Context of the command.") + + +class InferencePipelineStatusResponse(CommandResponse): + report: dict + + +class ListPipelinesResponse(CommandResponse): + pipelines: List[str] = Field(description="List IDs of active pipelines") diff --git a/inference/enterprise/stream_management/api/errors.py b/inference/enterprise/stream_management/api/errors.py new file mode 100644 index 000000000..3669bac2a --- /dev/null +++ b/inference/enterprise/stream_management/api/errors.py @@ -0,0 +1,26 @@ +class ProcessesManagerClientError(Exception): + pass + + +class ConnectivityError(ProcessesManagerClientError): + pass + + +class ProcessesManagerInternalError(ProcessesManagerClientError): + pass + + +class ProcessesManagerOperationError(ProcessesManagerClientError): + pass + + +class ProcessesManagerInvalidPayload(ProcessesManagerClientError): + pass + + +class ProcessesManagerNotFoundError(ProcessesManagerClientError): + pass + + +class ProcessesManagerAuthorisationError(ProcessesManagerClientError): + pass diff --git a/inference/enterprise/stream_management/api/stream_manager_client.py b/inference/enterprise/stream_management/api/stream_manager_client.py new file mode 100644 index 000000000..e89e7e914 --- /dev/null +++ b/inference/enterprise/stream_management/api/stream_manager_client.py @@ -0,0 +1,288 @@ +import asyncio +import json +from asyncio import StreamReader, StreamWriter +from json import JSONDecodeError +from typing import Optional, Tuple + +from inference.core import logger +from inference.enterprise.stream_management.api.entities import ( + CommandContext, + CommandResponse, + InferencePipelineStatusResponse, + ListPipelinesResponse, + PipelineInitialisationRequest, +) +from inference.enterprise.stream_management.api.errors import ( + ConnectivityError, + ProcessesManagerAuthorisationError, + ProcessesManagerClientError, + ProcessesManagerInternalError, + ProcessesManagerInvalidPayload, + ProcessesManagerNotFoundError, + ProcessesManagerOperationError, +) +from inference.enterprise.stream_management.manager.entities import ( + ERROR_TYPE_KEY, + PIPELINE_ID_KEY, + REQUEST_ID_KEY, + RESPONSE_KEY, + STATUS_KEY, + TYPE_KEY, + CommandType, + ErrorType, + OperationStatus, +) +from inference.enterprise.stream_management.manager.errors import ( + CommunicationProtocolError, + MalformedHeaderError, + MalformedPayloadError, + MessageToBigError, + TransmissionChannelClosed, +) + +BUFFER_SIZE = 16384 +HEADER_SIZE = 4 + +ERRORS_MAPPING = { + ErrorType.INTERNAL_ERROR.value: ProcessesManagerInternalError, + ErrorType.INVALID_PAYLOAD.value: ProcessesManagerInvalidPayload, + ErrorType.NOT_FOUND.value: ProcessesManagerNotFoundError, + ErrorType.OPERATION_ERROR.value: ProcessesManagerOperationError, + ErrorType.AUTHORISATION_ERROR.value: ProcessesManagerAuthorisationError, +} + + +class StreamManagerClient: + @classmethod + def init( + cls, + host: str, + port: int, + operations_timeout: 
Optional[float] = None, + header_size: int = HEADER_SIZE, + buffer_size: int = BUFFER_SIZE, + ) -> "StreamManagerClient": + return cls( + host=host, + port=port, + operations_timeout=operations_timeout, + header_size=header_size, + buffer_size=buffer_size, + ) + + def __init__( + self, + host: str, + port: int, + operations_timeout: Optional[float], + header_size: int, + buffer_size: int, + ): + self._host = host + self._port = port + self._operations_timeout = operations_timeout + self._header_size = header_size + self._buffer_size = buffer_size + + async def list_pipelines(self) -> ListPipelinesResponse: + command = { + TYPE_KEY: CommandType.LIST_PIPELINES, + } + response = await self._handle_command(command=command) + status = response[RESPONSE_KEY][STATUS_KEY] + context = CommandContext( + request_id=response.get(REQUEST_ID_KEY), + pipeline_id=response.get(PIPELINE_ID_KEY), + ) + pipelines = response[RESPONSE_KEY]["pipelines"] + return ListPipelinesResponse( + status=status, + context=context, + pipelines=pipelines, + ) + + async def initialise_pipeline( + self, initialisation_request: PipelineInitialisationRequest + ) -> CommandResponse: + command = initialisation_request.dict(exclude_none=True) + command[TYPE_KEY] = CommandType.INIT + response = await self._handle_command(command=command) + return build_response(response=response) + + async def terminate_pipeline(self, pipeline_id: str) -> CommandResponse: + command = { + TYPE_KEY: CommandType.TERMINATE, + PIPELINE_ID_KEY: pipeline_id, + } + response = await self._handle_command(command=command) + return build_response(response=response) + + async def pause_pipeline(self, pipeline_id: str) -> CommandResponse: + command = { + TYPE_KEY: CommandType.MUTE, + PIPELINE_ID_KEY: pipeline_id, + } + response = await self._handle_command(command=command) + return build_response(response=response) + + async def resume_pipeline(self, pipeline_id: str) -> CommandResponse: + command = { + TYPE_KEY: CommandType.RESUME, + PIPELINE_ID_KEY: pipeline_id, + } + response = await self._handle_command(command=command) + return build_response(response=response) + + async def get_status(self, pipeline_id: str) -> InferencePipelineStatusResponse: + command = { + TYPE_KEY: CommandType.STATUS, + PIPELINE_ID_KEY: pipeline_id, + } + response = await self._handle_command(command=command) + status = response[RESPONSE_KEY][STATUS_KEY] + context = CommandContext( + request_id=response.get(REQUEST_ID_KEY), + pipeline_id=response.get(PIPELINE_ID_KEY), + ) + report = response[RESPONSE_KEY]["report"] + return InferencePipelineStatusResponse( + status=status, + context=context, + report=report, + ) + + async def _handle_command(self, command: dict) -> dict: + response = await send_command( + host=self._host, + port=self._port, + command=command, + header_size=self._header_size, + buffer_size=self._buffer_size, + timeout=self._operations_timeout, + ) + if is_request_unsuccessful(response=response): + dispatch_error(error_response=response) + return response + + +async def send_command( + host: str, + port: int, + command: dict, + header_size: int, + buffer_size: int, + timeout: Optional[float] = None, +) -> dict: + try: + reader, writer = await establish_socket_connection( + host=host, port=port, timeout=timeout + ) + await send_message( + writer=writer, message=command, header_size=header_size, timeout=timeout + ) + data = await receive_message( + reader, header_size=header_size, buffer_size=buffer_size, timeout=timeout + ) + writer.close() + await writer.wait_closed() + 
return json.loads(data) + except JSONDecodeError as error: + raise MalformedPayloadError( + f"Could not decode response. Cause: {error}" + ) from error + except (OSError, asyncio.TimeoutError) as errors: + raise ConnectivityError( + f"Could not communicate with Process Manager" + ) from errors + + +async def establish_socket_connection( + host: str, port: int, timeout: Optional[float] = None +) -> Tuple[StreamReader, StreamWriter]: + return await asyncio.wait_for(asyncio.open_connection(host, port), timeout=timeout) + + +async def send_message( + writer: StreamWriter, + message: dict, + header_size: int, + timeout: Optional[float] = None, +) -> None: + try: + body = json.dumps(message).encode("utf-8") + header = len(body).to_bytes(length=header_size, byteorder="big") + payload = header + body + writer.write(payload) + await asyncio.wait_for(writer.drain(), timeout=timeout) + except TypeError as error: + raise MalformedPayloadError(f"Could not serialise message. Details: {error}") + except OverflowError as error: + raise MessageToBigError( + f"Could not send message due to size overflow. Details: {error}" + ) + except asyncio.TimeoutError as error: + raise ConnectivityError( + f"Could not communicate with Process Manager" + ) from error + except Exception as error: + raise CommunicationProtocolError( + f"Could not send message. Cause: {error}" + ) from error + + +async def receive_message( + reader: StreamReader, + header_size: int, + buffer_size: int, + timeout: Optional[float] = None, +) -> bytes: + header = await asyncio.wait_for(reader.read(header_size), timeout=timeout) + if len(header) != header_size: + raise MalformedHeaderError("Header size missmatch") + payload_size = int.from_bytes(bytes=header, byteorder="big") + received = b"" + while len(received) < payload_size: + chunk = await asyncio.wait_for(reader.read(buffer_size), timeout=timeout) + if len(chunk) == 0: + raise TransmissionChannelClosed( + "Socket was closed to read before payload was decoded." + ) + received += chunk + return received + + +def is_request_unsuccessful(response: dict) -> bool: + return ( + response.get(RESPONSE_KEY, {}).get(STATUS_KEY, OperationStatus.FAILURE.value) + != OperationStatus.SUCCESS.value + ) + + +def dispatch_error(error_response: dict) -> None: + response_payload = error_response.get(RESPONSE_KEY, {}) + error_type = response_payload.get(ERROR_TYPE_KEY) + error_class = response_payload.get("error_class", "N/A") + error_message = response_payload.get("error_message", "N/A") + logger.error( + f"Error in ProcessesManagerClient. error_type={error_type} error_class={error_class} " + f"error_message={error_message}" + ) + if error_type in ERRORS_MAPPING: + raise ERRORS_MAPPING[error_type]( + f"Error in ProcessesManagerClient. Error type: {error_type}. Details: {error_message}" + ) + raise ProcessesManagerClientError( + f"Error in ProcessesManagerClient. Error type: {error_type}. 
Details: {error_message}" + ) + + +def build_response(response: dict) -> CommandResponse: + status = response[RESPONSE_KEY][STATUS_KEY] + context = CommandContext( + request_id=response.get(REQUEST_ID_KEY), + pipeline_id=response.get(PIPELINE_ID_KEY), + ) + return CommandResponse( + status=status, + context=context, + ) diff --git a/inference/enterprise/stream_management/assets/stream_management_api_design.jpg b/inference/enterprise/stream_management/assets/stream_management_api_design.jpg new file mode 100644 index 000000000..2afea9d06 Binary files /dev/null and b/inference/enterprise/stream_management/assets/stream_management_api_design.jpg differ diff --git a/inference/enterprise/stream_management/manager/__init__.py b/inference/enterprise/stream_management/manager/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/inference/enterprise/stream_management/manager/app.py b/inference/enterprise/stream_management/manager/app.py new file mode 100644 index 000000000..7c5f836ba --- /dev/null +++ b/inference/enterprise/stream_management/manager/app.py @@ -0,0 +1,273 @@ +import os +import signal +import socket +import sys +from functools import partial +from multiprocessing import Process, Queue +from socketserver import BaseRequestHandler, BaseServer +from types import FrameType +from typing import Any, Dict, Optional, Tuple +from uuid import uuid4 + +from inference.core import logger +from inference.enterprise.stream_management.manager.communication import ( + receive_socket_data, + send_data_trough_socket, +) +from inference.enterprise.stream_management.manager.entities import ( + PIPELINE_ID_KEY, + STATUS_KEY, + TYPE_KEY, + CommandType, + ErrorType, + OperationStatus, +) +from inference.enterprise.stream_management.manager.errors import MalformedPayloadError +from inference.enterprise.stream_management.manager.inference_pipeline_manager import ( + InferencePipelineManager, +) +from inference.enterprise.stream_management.manager.serialisation import ( + describe_error, + prepare_error_response, + prepare_response, +) +from inference.enterprise.stream_management.manager.tcp_server import RoboflowTCPServer + +PROCESSES_TABLE: Dict[str, Tuple[Process, Queue, Queue]] = {} +HEADER_SIZE = 4 +SOCKET_BUFFER_SIZE = 16384 +HOST = os.getenv("STREAM_MANAGER_HOST", "127.0.0.1") +PORT = int(os.getenv("STREAM_MANAGER_PORT", "7070")) +SOCKET_TIMEOUT = float(os.getenv("STREAM_MANAGER_SOCKET_TIMEOUT", "5.0")) + + +class InferencePipelinesManagerHandler(BaseRequestHandler): + def __init__( + self, + request: socket.socket, + client_address: Any, + server: BaseServer, + processes_table: Dict[str, Tuple[Process, Queue, Queue]], + ): + self._processes_table = processes_table # in this case it's required to set the state of class before superclass init - as it invokes handle() + super().__init__(request, client_address, server) + + def handle(self) -> None: + pipeline_id: Optional[str] = None + request_id = str(uuid4()) + try: + data = receive_socket_data( + source=self.request, + header_size=HEADER_SIZE, + buffer_size=SOCKET_BUFFER_SIZE, + ) + data[TYPE_KEY] = CommandType(data[TYPE_KEY]) + if data[TYPE_KEY] is CommandType.LIST_PIPELINES: + return self._list_pipelines(request_id=request_id) + if data[TYPE_KEY] is CommandType.INIT: + return self._initialise_pipeline(request_id=request_id, command=data) + pipeline_id = data[PIPELINE_ID_KEY] + if data[TYPE_KEY] is CommandType.TERMINATE: + self._terminate_pipeline( + request_id=request_id, pipeline_id=pipeline_id, command=data + ) + else: + 
response = handle_command( + processes_table=self._processes_table, + request_id=request_id, + pipeline_id=pipeline_id, + command=data, + ) + serialised_response = prepare_response( + request_id=request_id, response=response, pipeline_id=pipeline_id + ) + send_data_trough_socket( + target=self.request, + header_size=HEADER_SIZE, + data=serialised_response, + request_id=request_id, + pipeline_id=pipeline_id, + ) + except (KeyError, ValueError, MalformedPayloadError) as error: + logger.error( + f"Invalid payload in processes manager. error={error} request_id={request_id}..." + ) + payload = prepare_error_response( + request_id=request_id, + error=error, + error_type=ErrorType.INVALID_PAYLOAD, + pipeline_id=pipeline_id, + ) + send_data_trough_socket( + target=self.request, + header_size=HEADER_SIZE, + data=payload, + request_id=request_id, + pipeline_id=pipeline_id, + ) + except Exception as error: + logger.error( + f"Internal error in processes manager. error={error} request_id={request_id}..." + ) + payload = prepare_error_response( + request_id=request_id, + error=error, + error_type=ErrorType.INTERNAL_ERROR, + pipeline_id=pipeline_id, + ) + send_data_trough_socket( + target=self.request, + header_size=HEADER_SIZE, + data=payload, + request_id=request_id, + pipeline_id=pipeline_id, + ) + + def _list_pipelines(self, request_id: str) -> None: + serialised_response = prepare_response( + request_id=request_id, + response={ + "pipelines": list(self._processes_table.keys()), + STATUS_KEY: OperationStatus.SUCCESS, + }, + pipeline_id=None, + ) + send_data_trough_socket( + target=self.request, + header_size=HEADER_SIZE, + data=serialised_response, + request_id=request_id, + ) + + def _initialise_pipeline(self, request_id: str, command: dict) -> None: + pipeline_id = str(uuid4()) + command_queue = Queue() + responses_queue = Queue() + inference_pipeline_manager = InferencePipelineManager.init( + command_queue=command_queue, + responses_queue=responses_queue, + ) + inference_pipeline_manager.start() + self._processes_table[pipeline_id] = ( + inference_pipeline_manager, + command_queue, + responses_queue, + ) + command_queue.put((request_id, command)) + response = get_response_ignoring_thrash( + responses_queue=responses_queue, matching_request_id=request_id + ) + serialised_response = prepare_response( + request_id=request_id, response=response, pipeline_id=pipeline_id + ) + send_data_trough_socket( + target=self.request, + header_size=HEADER_SIZE, + data=serialised_response, + request_id=request_id, + pipeline_id=pipeline_id, + ) + + def _terminate_pipeline( + self, request_id: str, pipeline_id: str, command: dict + ) -> None: + response = handle_command( + processes_table=self._processes_table, + request_id=request_id, + pipeline_id=pipeline_id, + command=command, + ) + if response[STATUS_KEY] is OperationStatus.SUCCESS: + logger.info( + f"Joining inference pipeline. pipeline_id={pipeline_id} request_id={request_id}" + ) + join_inference_pipeline( + processes_table=self._processes_table, pipeline_id=pipeline_id + ) + logger.info( + f"Joined inference pipeline. 
pipeline_id={pipeline_id} request_id={request_id}" + ) + serialised_response = prepare_response( + request_id=request_id, response=response, pipeline_id=pipeline_id + ) + send_data_trough_socket( + target=self.request, + header_size=HEADER_SIZE, + data=serialised_response, + request_id=request_id, + pipeline_id=pipeline_id, + ) + + +def handle_command( + processes_table: Dict[str, Tuple[Process, Queue, Queue]], + request_id: str, + pipeline_id: str, + command: dict, +) -> dict: + if pipeline_id not in processes_table: + return describe_error(exception=None, error_type=ErrorType.NOT_FOUND) + _, command_queue, responses_queue = processes_table[pipeline_id] + command_queue.put((request_id, command)) + return get_response_ignoring_thrash( + responses_queue=responses_queue, matching_request_id=request_id + ) + + +def get_response_ignoring_thrash( + responses_queue: Queue, matching_request_id: str +) -> dict: + while True: + response = responses_queue.get() + if response[0] == matching_request_id: + return response[1] + logger.warning( + f"Dropping response for request_id={response[0]} with payload={response[1]}" + ) + + +def execute_termination( + signal_number: int, + frame: FrameType, + processes_table: Dict[str, Tuple[Process, Queue, Queue]], +) -> None: + pipeline_ids = list(processes_table.keys()) + for pipeline_id in pipeline_ids: + logger.info(f"Terminating pipeline: {pipeline_id}") + processes_table[pipeline_id][0].terminate() + logger.info(f"Pipeline: {pipeline_id} terminated.") + logger.info(f"Joining pipeline: {pipeline_id}") + processes_table[pipeline_id][0].join() + logger.info(f"Pipeline: {pipeline_id} joined.") + logger.info(f"Termination handler completed.") + sys.exit(0) + + +def join_inference_pipeline( + processes_table: Dict[str, Tuple[Process, Queue, Queue]], pipeline_id: str +) -> None: + inference_pipeline_manager, command_queue, responses_queue = processes_table[ + pipeline_id + ] + inference_pipeline_manager.join() + del processes_table[pipeline_id] + + +if __name__ == "__main__": + signal.signal( + signal.SIGINT, partial(execute_termination, processes_table=PROCESSES_TABLE) + ) + signal.signal( + signal.SIGTERM, partial(execute_termination, processes_table=PROCESSES_TABLE) + ) + with RoboflowTCPServer( + server_address=(HOST, PORT), + handler_class=partial( + InferencePipelinesManagerHandler, processes_table=PROCESSES_TABLE + ), + socket_operations_timeout=SOCKET_TIMEOUT, + ) as tcp_server: + logger.info( + f"Inference Pipeline Processes Manager is ready to accept connections at {(HOST, PORT)}" + ) + tcp_server.serve_forever() diff --git a/inference/enterprise/stream_management/manager/communication.py b/inference/enterprise/stream_management/manager/communication.py new file mode 100644 index 000000000..0c5f88849 --- /dev/null +++ b/inference/enterprise/stream_management/manager/communication.py @@ -0,0 +1,76 @@ +import json +import socket +from typing import Optional + +from inference.core import logger +from inference.enterprise.stream_management.manager.entities import ErrorType +from inference.enterprise.stream_management.manager.errors import ( + MalformedHeaderError, + MalformedPayloadError, + TransmissionChannelClosed, +) +from inference.enterprise.stream_management.manager.serialisation import ( + prepare_error_response, +) + + +def receive_socket_data( + source: socket.socket, header_size: int, buffer_size: int +) -> dict: + header = source.recv(header_size) + if len(header) != header_size: + raise MalformedHeaderError( + f"Expected header size: 
{header_size}, received: {header}" + ) + payload_size = int.from_bytes(bytes=header, byteorder="big") + if payload_size <= 0: + raise MalformedHeaderError( + f"Header is indicating non positive payload size: {payload_size}" + ) + received = b"" + while len(received) < payload_size: + chunk = source.recv(buffer_size) + if len(chunk) == 0: + raise TransmissionChannelClosed( + "Socket was closed to read before payload was decoded." + ) + received += chunk + try: + return json.loads(received) + except ValueError: + raise MalformedPayloadError("Received payload that is not in a JSON format") + + +def send_data_trough_socket( + target: socket.socket, + header_size: int, + data: bytes, + request_id: str, + recover_from_overflow: bool = True, + pipeline_id: Optional[str] = None, +) -> None: + try: + data_size = len(data) + header = data_size.to_bytes(length=header_size, byteorder="big") + payload = header + data + target.sendall(payload) + except OverflowError as error: + if not recover_from_overflow: + logger.error(f"OverflowError was suppressed. {error}") + return None + error_response = prepare_error_response( + request_id=request_id, + error=error, + error_type=ErrorType.INTERNAL_ERROR, + pipeline_id=pipeline_id, + ) + send_data_trough_socket( + target=target, + header_size=header_size, + data=error_response, + request_id=request_id, + recover_from_overflow=False, + pipeline_id=pipeline_id, + ) + except Exception as error: + logger.error(f"Could not send the response through socket. Error: {error}") diff --git a/inference/enterprise/stream_management/manager/entities.py b/inference/enterprise/stream_management/manager/entities.py new file mode 100644 index 000000000..cd02ee66c --- /dev/null +++ b/inference/enterprise/stream_management/manager/entities.py @@ -0,0 +1,32 @@ +from enum import Enum + +STATUS_KEY = "status" +TYPE_KEY = "type" +ERROR_TYPE_KEY = "error_type" +REQUEST_ID_KEY = "request_id" +PIPELINE_ID_KEY = "pipeline_id" +COMMAND_KEY = "command" +RESPONSE_KEY = "response" +ENCODING = "utf-8" + + +class OperationStatus(str, Enum): + SUCCESS = "success" + FAILURE = "failure" + + +class ErrorType(str, Enum): + INTERNAL_ERROR = "internal_error" + INVALID_PAYLOAD = "invalid_payload" + NOT_FOUND = "not_found" + OPERATION_ERROR = "operation_error" + AUTHORISATION_ERROR = "authorisation_error" + + +class CommandType(str, Enum): + INIT = "init" + MUTE = "mute" + RESUME = "resume" + STATUS = "status" + TERMINATE = "terminate" + LIST_PIPELINES = "list_pipelines" diff --git a/inference/enterprise/stream_management/manager/errors.py b/inference/enterprise/stream_management/manager/errors.py new file mode 100644 index 000000000..3a3df9a14 --- /dev/null +++ b/inference/enterprise/stream_management/manager/errors.py @@ -0,0 +1,18 @@ +class CommunicationProtocolError(Exception): + pass + + +class MessageToBigError(CommunicationProtocolError): + pass + + +class MalformedHeaderError(CommunicationProtocolError): + pass + + +class TransmissionChannelClosed(CommunicationProtocolError): + pass + + +class MalformedPayloadError(CommunicationProtocolError): + pass diff --git a/inference/enterprise/stream_management/manager/inference_pipeline_manager.py b/inference/enterprise/stream_management/manager/inference_pipeline_manager.py new file mode 100644 index 000000000..d5e4d6cb9 --- /dev/null +++ b/inference/enterprise/stream_management/manager/inference_pipeline_manager.py @@ -0,0 +1,257 @@ +import os +import signal +from dataclasses import asdict +from multiprocessing import Process, Queue +from types import 
FrameType +from typing import Callable, Optional, Tuple + +from inference.core import logger +from inference.core.exceptions import ( + MissingApiKeyError, + RoboflowAPINotAuthorizedError, + RoboflowAPINotNotFoundError, +) +from inference.core.interfaces.camera.entities import VideoFrame +from inference.core.interfaces.camera.exceptions import StreamOperationNotAllowedError +from inference.core.interfaces.camera.video_source import ( + BufferConsumptionStrategy, + BufferFillingStrategy, +) +from inference.core.interfaces.stream.entities import ObjectDetectionPrediction +from inference.core.interfaces.stream.inference_pipeline import InferencePipeline +from inference.core.interfaces.stream.sinks import UDPSink +from inference.core.interfaces.stream.watchdog import ( + BasePipelineWatchDog, + PipelineWatchDog, +) +from inference.enterprise.stream_management.manager.entities import ( + STATUS_KEY, + TYPE_KEY, + CommandType, + ErrorType, + OperationStatus, +) +from inference.enterprise.stream_management.manager.serialisation import describe_error + + +def ignore_signal(signal_number: int, frame: FrameType) -> None: + pid = os.getpid() + logger.info( + f"Ignoring signal {signal_number} in InferencePipelineManager in process:{pid}" + ) + + +class InferencePipelineManager(Process): + @classmethod + def init( + cls, command_queue: Queue, responses_queue: Queue + ) -> "InferencePipelineManager": + return cls(command_queue=command_queue, responses_queue=responses_queue) + + def __init__(self, command_queue: Queue, responses_queue: Queue): + super().__init__() + self._command_queue = command_queue + self._responses_queue = responses_queue + self._inference_pipeline: Optional[InferencePipeline] = None + self._watchdog: Optional[PipelineWatchDog] = None + self._stop = False + + def run(self) -> None: + signal.signal(signal.SIGINT, ignore_signal) + signal.signal(signal.SIGTERM, self._handle_termination_signal) + while not self._stop: + command: Optional[Tuple[str, dict]] = self._command_queue.get() + if command is None: + break + request_id, payload = command + self._handle_command(request_id=request_id, payload=payload) + + def _handle_command(self, request_id: str, payload: dict) -> None: + try: + logger.info(f"Processing request={request_id}...") + command_type = payload[TYPE_KEY] + if command_type is CommandType.INIT: + return self._initialise_pipeline(request_id=request_id, payload=payload) + if command_type is CommandType.TERMINATE: + return self._terminate_pipeline(request_id=request_id) + if command_type is CommandType.MUTE: + return self._mute_pipeline(request_id=request_id) + if command_type is CommandType.RESUME: + return self._resume_pipeline(request_id=request_id) + if command_type is CommandType.STATUS: + return self._get_pipeline_status(request_id=request_id) + raise NotImplementedError( + f"Command type `{command_type}` cannot be handled" + ) + except (KeyError, NotImplementedError) as error: + self._handle_error( + request_id=request_id, error=error, error_type=ErrorType.INVALID_PAYLOAD + ) + except Exception as error: + self._handle_error( + request_id=request_id, error=error, error_type=ErrorType.INTERNAL_ERROR + ) + + def _initialise_pipeline(self, request_id: str, payload: dict) -> None: + try: + watchdog = BasePipelineWatchDog() + sink = assembly_pipeline_sink(sink_config=payload["sink_configuration"]) + source_buffer_filling_strategy, source_buffer_consumption_strategy = ( + None, + None, + ) + if "source_buffer_filling_strategy" in payload: + source_buffer_filling_strategy = 
BufferFillingStrategy( + payload["source_buffer_filling_strategy"].upper() + ) + if "source_buffer_consumption_strategy" in payload: + source_buffer_consumption_strategy = BufferConsumptionStrategy( + payload["source_buffer_consumption_strategy"].upper() + ) + model_configuration = payload["model_configuration"] + if model_configuration["type"] != "object-detection": + raise NotImplementedError("Only object-detection models are supported") + self._inference_pipeline = InferencePipeline.init( + model_id=payload["model_id"], + video_reference=payload["video_reference"], + on_prediction=sink, + api_key=payload["api_key"], + max_fps=payload.get("max_fps"), + watchdog=watchdog, + source_buffer_filling_strategy=source_buffer_filling_strategy, + source_buffer_consumption_strategy=source_buffer_consumption_strategy, + class_agnostic_nms=model_configuration.get("class_agnostic_nms"), + confidence=model_configuration.get("confidence"), + iou_threshold=model_configuration.get("iou_threshold"), + max_candidates=model_configuration.get("max_candidates"), + max_detections=model_configuration.get("max_detections"), + ) + self._watchdog = watchdog + self._inference_pipeline.start(use_main_thread=False) + self._responses_queue.put( + (request_id, {STATUS_KEY: OperationStatus.SUCCESS}) + ) + logger.info(f"Pipeline initialised. request_id={request_id}...") + except (MissingApiKeyError, KeyError, NotImplementedError) as error: + self._handle_error( + request_id=request_id, error=error, error_type=ErrorType.INVALID_PAYLOAD + ) + except RoboflowAPINotAuthorizedError as error: + self._handle_error( + request_id=request_id, + error=error, + error_type=ErrorType.AUTHORISATION_ERROR, + ) + except RoboflowAPINotNotFoundError as error: + self._handle_error( + request_id=request_id, error=error, error_type=ErrorType.NOT_FOUND + ) + + def _terminate_pipeline(self, request_id: str) -> None: + if self._inference_pipeline is None: + self._responses_queue.put( + (request_id, {STATUS_KEY: OperationStatus.SUCCESS}) + ) + self._stop = True + return None + try: + self._execute_termination() + logger.info(f"Pipeline terminated. request_id={request_id}...") + self._responses_queue.put( + (request_id, {STATUS_KEY: OperationStatus.SUCCESS}) + ) + except StreamOperationNotAllowedError as error: + self._handle_error( + request_id=request_id, error=error, error_type=ErrorType.OPERATION_ERROR + ) + + def _handle_termination_signal(self, signal_number: int, frame: FrameType) -> None: + try: + pid = os.getpid() + logger.info(f"Terminating pipeline in process:{pid}...") + if self._inference_pipeline is not None: + self._execute_termination() + self._command_queue.put(None) + logger.info(f"Termination successful in process:{pid}...") + except Exception as error: + logger.warning(f"Could not terminate pipeline gracefully. Error: {error}") + + def _execute_termination(self) -> None: + self._inference_pipeline.terminate() + self._inference_pipeline.join() + self._stop = True + + def _mute_pipeline(self, request_id: str) -> None: + if self._inference_pipeline is None: + return self._handle_error( + request_id=request_id, error_type=ErrorType.OPERATION_ERROR + ) + try: + self._inference_pipeline.mute_stream() + logger.info(f"Pipeline muted. 
request_id={request_id}...") + self._responses_queue.put( + (request_id, {STATUS_KEY: OperationStatus.SUCCESS}) + ) + except StreamOperationNotAllowedError as error: + self._handle_error( + request_id=request_id, error=error, error_type=ErrorType.OPERATION_ERROR + ) + + def _resume_pipeline(self, request_id: str) -> None: + if self._inference_pipeline is None: + return self._handle_error( + request_id=request_id, error_type=ErrorType.OPERATION_ERROR + ) + try: + self._inference_pipeline.resume_stream() + logger.info(f"Pipeline resumed. request_id={request_id}...") + self._responses_queue.put( + (request_id, {STATUS_KEY: OperationStatus.SUCCESS}) + ) + except StreamOperationNotAllowedError as error: + self._handle_error( + request_id=request_id, error=error, error_type=ErrorType.OPERATION_ERROR + ) + + def _get_pipeline_status(self, request_id: str) -> None: + if self._watchdog is None: + return self._handle_error( + request_id=request_id, error_type=ErrorType.OPERATION_ERROR + ) + try: + report = self._watchdog.get_report() + if report is None: + return self._handle_error( + request_id=request_id, error_type=ErrorType.OPERATION_ERROR + ) + response_payload = { + STATUS_KEY: OperationStatus.SUCCESS, + "report": asdict(report), + } + self._responses_queue.put((request_id, response_payload)) + logger.info(f"Pipeline status returned. request_id={request_id}...") + except StreamOperationNotAllowedError as error: + self._handle_error( + request_id=request_id, error=error, error_type=ErrorType.OPERATION_ERROR + ) + + def _handle_error( + self, + request_id: str, + error: Optional[Exception] = None, + error_type: ErrorType = ErrorType.INTERNAL_ERROR, + ): + logger.error( + f"Could not handle Command. request_id={request_id}, error={error}, error_type={error_type}" + ) + response_payload = describe_error(error, error_type=error_type) + self._responses_queue.put((request_id, response_payload)) + + +def assembly_pipeline_sink( + sink_config: dict, +) -> Callable[[ObjectDetectionPrediction, VideoFrame], None]: + if sink_config["type"] != "udp_sink": + raise NotImplementedError("Only `udp_socket` sink type is supported") + sink = UDPSink.init(ip_address=sink_config["host"], port=sink_config["port"]) + return sink.send_predictions diff --git a/inference/enterprise/stream_management/manager/serialisation.py b/inference/enterprise/stream_management/manager/serialisation.py new file mode 100644 index 000000000..69a2c8787 --- /dev/null +++ b/inference/enterprise/stream_management/manager/serialisation.py @@ -0,0 +1,60 @@ +import json +from datetime import date, datetime +from enum import Enum +from typing import Any, Optional + +from inference.enterprise.stream_management.manager.entities import ( + ENCODING, + ERROR_TYPE_KEY, + PIPELINE_ID_KEY, + REQUEST_ID_KEY, + RESPONSE_KEY, + STATUS_KEY, + ErrorType, + OperationStatus, +) + + +def serialise_to_json(obj: Any) -> Any: + if isinstance(obj, (datetime, date)): + return obj.isoformat() + if issubclass(type(obj), Enum): + return obj.value + raise TypeError(f"Type {type(obj)} not serializable") + + +def describe_error( + exception: Optional[Exception] = None, + error_type: ErrorType = ErrorType.INTERNAL_ERROR, +) -> dict: + payload = { + STATUS_KEY: OperationStatus.FAILURE, + ERROR_TYPE_KEY: error_type, + } + if exception is not None: + payload["error_class"] = exception.__class__.__name__ + payload["error_message"] = str(exception) + return payload + + +def prepare_error_response( + request_id: str, error: Exception, error_type: ErrorType, pipeline_id: 
Optional[str] +) -> bytes: + error_description = describe_error(exception=error, error_type=error_type) + return prepare_response( + request_id=request_id, response=error_description, pipeline_id=pipeline_id + ) + + +def prepare_response( + request_id: str, response: dict, pipeline_id: Optional[str] +) -> bytes: + payload = json.dumps( + { + REQUEST_ID_KEY: request_id, + RESPONSE_KEY: response, + PIPELINE_ID_KEY: pipeline_id, + }, + default=serialise_to_json, + ) + return payload.encode(ENCODING) diff --git a/inference/enterprise/stream_management/manager/tcp_server.py b/inference/enterprise/stream_management/manager/tcp_server.py new file mode 100644 index 000000000..5f951b1b6 --- /dev/null +++ b/inference/enterprise/stream_management/manager/tcp_server.py @@ -0,0 +1,19 @@ +import socket +from socketserver import BaseRequestHandler, TCPServer +from typing import Any, Optional, Tuple, Type + + +class RoboflowTCPServer(TCPServer): + def __init__( + self, + server_address: Tuple[str, int], + handler_class: Type[BaseRequestHandler], + socket_operations_timeout: Optional[float] = None, + ): + TCPServer.__init__(self, server_address, handler_class) + self._socket_operations_timeout = socket_operations_timeout + + def get_request(self) -> Tuple[socket.socket, Any]: + connection, address = self.socket.accept() + connection.settimeout(self._socket_operations_timeout) + return connection, address diff --git a/requirements/requirements.test.unit.txt b/requirements/requirements.test.unit.txt index c1fe79f73..d24435d78 100644 --- a/requirements/requirements.test.unit.txt +++ b/requirements/requirements.test.unit.txt @@ -10,3 +10,5 @@ flake8 rich pytest-asyncio<=0.21.1 pytest-timeout>=2.2.0 +httpx +uvicorn>=0.24.0 \ No newline at end of file diff --git a/tests/inference/unit_tests/enterprise/__init__.py b/tests/inference/unit_tests/enterprise/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/inference/unit_tests/enterprise/stream_management/__init__.py b/tests/inference/unit_tests/enterprise/stream_management/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/inference/unit_tests/enterprise/stream_management/api/__init__.py b/tests/inference/unit_tests/enterprise/stream_management/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/inference/unit_tests/enterprise/stream_management/api/test_app.py b/tests/inference/unit_tests/enterprise/stream_management/api/test_app.py new file mode 100644 index 000000000..37500bb56 --- /dev/null +++ b/tests/inference/unit_tests/enterprise/stream_management/api/test_app.py @@ -0,0 +1,251 @@ +from unittest import mock +from unittest.mock import AsyncMock + +from fastapi.testclient import TestClient + +from inference.enterprise.stream_management.api import app +from inference.enterprise.stream_management.api.entities import ( + CommandContext, + CommandResponse, + InferencePipelineStatusResponse, + ListPipelinesResponse, +) +from inference.enterprise.stream_management.api.errors import ( + ConnectivityError, + ProcessesManagerNotFoundError, +) + + +@mock.patch.object(app, "STREAM_MANAGER_CLIENT", new_callable=AsyncMock) +def test_list_pipelines_when_communication_with_stream_manager_abused( + stream_manager_client: AsyncMock, +) -> None: + # given + client = TestClient(app.app) + stream_manager_client.list_pipelines.side_effect = ConnectivityError( + "Could not connect" + ) + + # when + response = client.get("/list_pipelines") + + # then + assert ( + response.status_code == 503 + ), 
"Status code when connectivity error occurs should be 503" + assert ( + response.json()["status"] == "failure" + ), "Failure must be denoted in response payload" + assert ( + len(response.json()["message"]) > 0 + ), "Message must be denoted in response payload" + + +@mock.patch.object(app, "STREAM_MANAGER_CLIENT", new_callable=AsyncMock) +def test_list_pipelines_when_communication_with_stream_manager_possible( + stream_manager_client: AsyncMock, +) -> None: + # given + client = TestClient(app.app) + stream_manager_client.list_pipelines.return_value = ListPipelinesResponse( + status="success", + context=CommandContext(request_id="my_request", pipeline_id=None), + pipelines=["a", "b", "c"], + ) + + # when + response = client.get("/list_pipelines") + + # then + assert response.status_code == 200, "Status code for success must be 200" + assert response.json() == { + "status": "success", + "context": { + "request_id": "my_request", + "pipeline_id": None, + }, + "pipelines": ["a", "b", "c"], + }, "ListPipelinesResponse must be serialised directly to JSON response" + + +@mock.patch.object(app, "STREAM_MANAGER_CLIENT", new_callable=AsyncMock) +def test_get_pipeline_status_when_pipeline_found( + stream_manager_client: AsyncMock, +) -> None: + # given + client = TestClient(app.app) + stream_manager_client.get_status.return_value = InferencePipelineStatusResponse( + status="success", + context=CommandContext(request_id="my_request", pipeline_id="my_pipeline"), + report={"my": "report"}, # this is mock data + ) + + # when + response = client.get("/status/my_pipeline") + + # then + assert response.status_code == 200, "Status code for success must be 200" + assert response.json() == { + "status": "success", + "context": { + "request_id": "my_request", + "pipeline_id": "my_pipeline", + }, + "report": {"my": "report"}, + }, "InferencePipelineStatusResponse must be serialised directly to JSON response" + + +@mock.patch.object(app, "STREAM_MANAGER_CLIENT", new_callable=AsyncMock) +def test_get_pipeline_status_when_pipeline_not_found( + stream_manager_client: AsyncMock, +) -> None: + # given + client = TestClient(app.app) + stream_manager_client.get_status.side_effect = ProcessesManagerNotFoundError( + "Pipeline not found" + ) + + # when + response = client.get("/status/my_pipeline") + + # then + assert response.status_code == 404, "Status code for not found must be 404" + assert ( + response.json()["status"] == "failure" + ), "Failure must be denoted in response payload" + assert ( + len(response.json()["message"]) > 0 + ), "Message must be denoted in response payload" + + +@mock.patch.object(app, "STREAM_MANAGER_CLIENT", new_callable=AsyncMock) +def test_initialise_pipeline_when_invalid_payload_given( + stream_manager_client: AsyncMock, +) -> None: + # given + client = TestClient(app.app) + stream_manager_client.initialise_pipeline.return_value = CommandResponse( + status="success", + context=CommandContext(request_id="my_request", pipeline_id="my_pipeline"), + ) + + # when + response = client.post("/initialise") + + # then + assert ( + response.status_code == 422 + ), "Status code for invalid input entity must be 422" + + +@mock.patch.object(app, "STREAM_MANAGER_CLIENT", new_callable=AsyncMock) +def test_initialise_pipeline_when_valid_payload_given( + stream_manager_client: AsyncMock, +) -> None: + # given + client = TestClient(app.app) + stream_manager_client.initialise_pipeline.return_value = CommandResponse( + status="success", + context=CommandContext(request_id="my_request", 
pipeline_id="my_pipeline"), + ) + + # when + response = client.post( + "/initialise", + json={ + "model_id": "some/1", + "video_reference": "rtsp://some:543", + "sink_configuration": { + "type": "udp_sink", + "host": "127.0.0.1", + "port": 9090, + }, + "api_key": "my_api_key", + "model_configuration": {"type": "object_detection"}, + }, + ) + + # then + assert response.status_code == 200, "Status code for success must be 200" + assert response.json() == { + "status": "success", + "context": { + "request_id": "my_request", + "pipeline_id": "my_pipeline", + }, + }, "CommandResponse must be serialised directly to JSON response" + + +@mock.patch.object(app, "STREAM_MANAGER_CLIENT", new_callable=AsyncMock) +def test_pause_pipeline_when_successful_response_expected( + stream_manager_client: AsyncMock, +) -> None: + # given + client = TestClient(app.app) + stream_manager_client.pause_pipeline.return_value = CommandResponse( + status="success", + context=CommandContext(request_id="my_request", pipeline_id="my_pipeline"), + ) + + # when + response = client.post("/pause/my_pipeline") + + # then + assert response.status_code == 200, "Status code for success must be 200" + assert response.json() == { + "status": "success", + "context": { + "request_id": "my_request", + "pipeline_id": "my_pipeline", + }, + }, "CommandResponse must be serialised directly to JSON response" + + +@mock.patch.object(app, "STREAM_MANAGER_CLIENT", new_callable=AsyncMock) +def test_resume_pipeline_when_successful_response_expected( + stream_manager_client: AsyncMock, +) -> None: + # given + client = TestClient(app.app) + stream_manager_client.resume_pipeline.return_value = CommandResponse( + status="success", + context=CommandContext(request_id="my_request", pipeline_id="my_pipeline"), + ) + + # when + response = client.post("/resume/my_pipeline") + + # then + assert response.status_code == 200, "Status code for success must be 200" + assert response.json() == { + "status": "success", + "context": { + "request_id": "my_request", + "pipeline_id": "my_pipeline", + }, + }, "CommandResponse must be serialised directly to JSON response" + + +@mock.patch.object(app, "STREAM_MANAGER_CLIENT", new_callable=AsyncMock) +def test_terminate_pipeline_when_successful_response_expected( + stream_manager_client: AsyncMock, +) -> None: + # given + client = TestClient(app.app) + stream_manager_client.terminate_pipeline.return_value = CommandResponse( + status="success", + context=CommandContext(request_id="my_request", pipeline_id="my_pipeline"), + ) + + # when + response = client.post("/terminate/my_pipeline") + + # then + assert response.status_code == 200, "Status code for success must be 200" + assert response.json() == { + "status": "success", + "context": { + "request_id": "my_request", + "pipeline_id": "my_pipeline", + }, + }, "CommandResponse must be serialised directly to JSON response" diff --git a/tests/inference/unit_tests/enterprise/stream_management/api/test_stream_manager_client.py b/tests/inference/unit_tests/enterprise/stream_management/api/test_stream_manager_client.py new file mode 100644 index 000000000..13db0670f --- /dev/null +++ b/tests/inference/unit_tests/enterprise/stream_management/api/test_stream_manager_client.py @@ -0,0 +1,729 @@ +import asyncio +import json +from typing import Type +from unittest import mock +from unittest.mock import AsyncMock + +import pytest + +from inference.enterprise.stream_management.api import stream_manager_client +from inference.enterprise.stream_management.api.entities import ( + 
CommandContext, + CommandResponse, + InferencePipelineStatusResponse, + ListPipelinesResponse, + ObjectDetectionModelConfiguration, + PipelineInitialisationRequest, + UDPSinkConfiguration, +) +from inference.enterprise.stream_management.api.errors import ( + ConnectivityError, + ProcessesManagerAuthorisationError, + ProcessesManagerClientError, + ProcessesManagerInternalError, + ProcessesManagerInvalidPayload, + ProcessesManagerNotFoundError, + ProcessesManagerOperationError, +) +from inference.enterprise.stream_management.api.stream_manager_client import ( + StreamManagerClient, + build_response, + dispatch_error, + is_request_unsuccessful, + receive_message, + send_command, + send_message, +) +from inference.enterprise.stream_management.manager.entities import CommandType +from inference.enterprise.stream_management.manager.errors import ( + CommunicationProtocolError, + MalformedHeaderError, + MalformedPayloadError, + MessageToBigError, + TransmissionChannelClosed, +) + + +def test_build_response_when_all_optional_fields_are_filled() -> None: + # given + response = { + "response": {"status": "failure"}, + "request_id": "my_request", + "pipeline_id": "my_pipeline", + } + + # when + result = build_response(response=response) + + # then + assert result == CommandResponse( + status="failure", + context=CommandContext(request_id="my_request", pipeline_id="my_pipeline"), + ), "Assembled response must indicate failure and context with request id and pipeline id denoted" + + +def test_build_response_when_all_optional_fields_are_missing() -> None: + # given + response = { + "response": {"status": "failure"}, + } + + # when + result = build_response(response=response) + + # then + assert result == CommandResponse( + status="failure", + context=CommandContext(request_id=None, pipeline_id=None), + ), "Assembled response must indicate failure and empty context" + + +@pytest.mark.parametrize( + "error_type, expected_error", + [ + ("internal_error", ProcessesManagerInternalError), + ("invalid_payload", ProcessesManagerInvalidPayload), + ("not_found", ProcessesManagerNotFoundError), + ("operation_error", ProcessesManagerOperationError), + ("authorisation_error", ProcessesManagerAuthorisationError), + ], +) +def test_dispatch_error_when_known_error_is_detected( + error_type: str, expected_error: Type[Exception] +) -> None: + # given + error_response = { + "response": { + "status": "failure", + "error_type": error_type, + } + } + + # when + with pytest.raises(expected_error): + dispatch_error(error_response=error_response) + + +def test_dispatch_error_when_unknown_error_is_detected() -> None: + # given + error_response = { + "response": { + "status": "failure", + "error_type": "unknown", + } + } + + # when + with pytest.raises(ProcessesManagerClientError): + dispatch_error(error_response=error_response) + + +def test_dispatch_error_when_malformed_payload_is_detected() -> None: + # given + error_response = {"response": {"status": "failure"}} + + # when + with pytest.raises(ProcessesManagerClientError): + dispatch_error(error_response=error_response) + + +def test_is_request_unsuccessful_when_successful_response_given() -> None: + # given + response = {"response": {"status": "success"}} + + # when + result = is_request_unsuccessful(response=response) + + # then + assert ( + result is False + ), "Success status denoted should be assumed as sign of request success" + + +def test_is_request_unsuccessful_when_unsuccessful_response_given() -> None: + # given + error_response = { + "response": { + "status": 
"failure", + "error_type": "not_found", + } + } + + # when + result = is_request_unsuccessful(response=error_response) + + # then + assert result is True, "Explicitly failed response is indication of failed response" + + +def test_is_request_unsuccessful_when_malformed_response_given() -> None: + # given + response = {"response": {"some": "data"}} + + # when + result = is_request_unsuccessful(response=response) + + # then + assert ( + result is True + ), "When success is not clearly demonstrated - failure is to be assumed" + + +class DummyStreamReader: + def __init__(self, read_buffer_content: bytes): + self._read_buffer_content = read_buffer_content + + async def read(self, n: int = -1) -> bytes: + if n == -1: + n = len(self._read_buffer_content) + to_return = self._read_buffer_content[:n] + self._read_buffer_content = self._read_buffer_content[n:] + return to_return + + +@pytest.mark.asyncio +async def test_receive_message_when_malformed_header_sent() -> None: + # given + header = 3 + reader = DummyStreamReader( + read_buffer_content=header.to_bytes(length=1, byteorder="big") + ) + + # when + with pytest.raises(MalformedHeaderError): + _ = await receive_message(reader=reader, header_size=4, buffer_size=512) + + +@pytest.mark.asyncio +async def test_receive_message_when_payload_to_be_read_in_single_piece() -> None: + # given + data = b"DO OR NOT DO, THERE IS NO TRY" + payload = len(data).to_bytes(length=4, byteorder="big") + data + reader = DummyStreamReader(read_buffer_content=payload) + + # when + result = await receive_message( + reader=reader, header_size=4, buffer_size=len(payload) + ) + + # then + assert ( + result == b"DO OR NOT DO, THERE IS NO TRY" + ), "Result must be exact to the data in payload" + + +@pytest.mark.asyncio +async def test_receive_message_when_payload_to_be_read_in_multiple_pieces() -> None: + # given + data = b"DO OR NOT DO, THERE IS NO TRY" + payload = len(data).to_bytes(length=4, byteorder="big") + data + reader = DummyStreamReader(read_buffer_content=payload) + + # when + result = await receive_message(reader=reader, header_size=4, buffer_size=1) + + # then + assert ( + result == b"DO OR NOT DO, THERE IS NO TRY" + ), "Result must be exact to the data in payload" + + +@pytest.mark.asyncio +async def test_receive_message_when_not_all_declared_bytes_received() -> None: + # given + data = b"DO OR NOT DO, THERE IS NO TRY" + payload = len(data).to_bytes(length=4, byteorder="big") + data[:5] + reader = DummyStreamReader(read_buffer_content=payload) + + # when + with pytest.raises(TransmissionChannelClosed): + _ = await receive_message(reader=reader, header_size=4, buffer_size=1) + + +@pytest.mark.asyncio +async def test_send_message_when_content_cannot_be_serialised() -> None: + # given + writer = AsyncMock() + + # when + with pytest.raises(MalformedPayloadError): + await send_message(writer=writer, message=set([1, 2, 3]), header_size=4) + + +@pytest.mark.asyncio +async def test_send_message_when_message_is_to_long_up_to_header_length() -> None: + # given + writer = AsyncMock() + message = {"data": [i for i in range(1024)]} + + # when + with pytest.raises(MessageToBigError): + await send_message(writer=writer, message=message, header_size=1) + + +@pytest.mark.asyncio +async def test_send_message_when_communication_problem_arises() -> None: + # given + writer = AsyncMock() + writer.drain.side_effect = IOError() + message = {"data": "some"} + + # when + with pytest.raises(CommunicationProtocolError): + await send_message(writer=writer, message=message, header_size=4) + 
+ +@pytest.mark.asyncio +async def test_send_message_when_communication_succeeds() -> None: + # given + writer = AsyncMock() + message = {"data": "some"} + serialised_message = json.dumps(message).encode("utf-8") + expected_payload = ( + len(serialised_message).to_bytes(length=4, byteorder="big") + serialised_message + ) + + # when + await send_message(writer=writer, message=message, header_size=4) + + # then + writer.write.assert_called_once_with(expected_payload) + + +class DummyStreamWriter: + def __init__(self, operation_delay: float = 0.0): + self._write_buffer_content = b"" + self._operation_delay = operation_delay + + def get_content(self) -> bytes: + return self._write_buffer_content + + def write(self, payload: bytes) -> None: + self._write_buffer_content += payload + + async def drain(self) -> None: + await asyncio.sleep(self._operation_delay) + + def close(self) -> None: + pass + + async def wait_closed(self) -> None: + await asyncio.sleep(self._operation_delay) + + +@pytest.mark.asyncio +@mock.patch.object(stream_manager_client, "establish_socket_connection") +async def test_send_command_when_connectivity_problem_arises( + establish_socket_connection_mock: AsyncMock, +) -> None: + # given + establish_socket_connection_mock.side_effect = ConnectionError() + + # when + with pytest.raises(ConnectivityError): + _ = await send_command( + host="127.0.0.1", + port=7070, + command={}, + header_size=4, + buffer_size=16438, + timeout=0.1, + ) + + +@pytest.mark.asyncio +@pytest.mark.timeout(30) +@mock.patch.object(stream_manager_client, "establish_socket_connection") +async def test_send_command_when_timeout_is_raised( + establish_socket_connection_mock: AsyncMock, +) -> None: + # given + reader = DummyStreamReader(read_buffer_content=b"") + establish_socket_connection_mock.return_value = ( + reader, + DummyStreamWriter(operation_delay=1.0), + ) + + # when + with pytest.raises(ConnectivityError): + _ = await send_command( + host="127.0.0.1", + port=7070, + command={}, + header_size=4, + buffer_size=16438, + timeout=0.1, + ) + + +@pytest.mark.asyncio +@mock.patch.object(stream_manager_client, "establish_socket_connection") +async def test_send_command_when_communication_successful( + establish_socket_connection_mock: AsyncMock, +) -> None: + # given + reader = assembly_socket_reader( + message={"response": {"status": "success"}}, header_size=4 + ) + writer = DummyStreamWriter() + establish_socket_connection_mock.return_value = (reader, writer) + command = { + "type": CommandType.TERMINATE, + "pipeline_id": "my_pipeline", + } + + # when + result = await send_command( + host="127.0.0.1", port=7070, command=command, header_size=4, buffer_size=16438 + ) + + # then + assert result == {"response": {"status": "success"}} + assert_correct_command_sent( + writer=writer, + command=command, + header_size=4, + message="Expected to send termination command successfully", + ) + + +@pytest.mark.asyncio +@mock.patch.object(stream_manager_client, "establish_socket_connection") +async def test_send_command_when_response_payload_could_not_be_decoded( + establish_socket_connection_mock: AsyncMock, +) -> None: + # given + response_message = b"FOR SURE NOT A JSON" + response_payload = ( + len(response_message).to_bytes(length=4, byteorder="big") + response_message + ) + reader = DummyStreamReader(read_buffer_content=response_payload) + establish_socket_connection_mock.return_value = ( + reader, + DummyStreamWriter(operation_delay=1.0), + ) + command = { + "type": CommandType.TERMINATE, + "pipeline_id": 
"my_pipeline", + } + + # when + with pytest.raises(MalformedPayloadError): + _ = await send_command( + host="127.0.0.1", + port=7070, + command=command, + header_size=4, + buffer_size=16438, + ) + + +@pytest.mark.asyncio +@mock.patch.object(stream_manager_client, "establish_socket_connection") +async def test_stream_manager_client_can_successfully_list_pipelines( + establish_socket_connection_mock: AsyncMock, +) -> None: + # given + reader = assembly_socket_reader( + message={ + "request_id": "my_request", + "response": {"status": "success", "pipelines": ["a", "b", "c"]}, + }, + header_size=4, + ) + writer = DummyStreamWriter() + establish_socket_connection_mock.return_value = (reader, writer) + expected_command = {"type": CommandType.LIST_PIPELINES} + client = StreamManagerClient.init( + host="127.0.0.1", + port=7070, + operations_timeout=1.0, + header_size=4, + buffer_size=16438, + ) + + # when + result = await client.list_pipelines() + + # then + assert result == ListPipelinesResponse( + status="success", + context=CommandContext(request_id="my_request", pipeline_id=None), + pipelines=["a", "b", "c"], + ) + assert_correct_command_sent( + writer=writer, + command=expected_command, + header_size=4, + message="Expected list pipelines command to be sent", + ) + + +@pytest.mark.asyncio +@mock.patch.object(stream_manager_client, "establish_socket_connection") +async def test_stream_manager_client_can_successfully_initialise_pipeline( + establish_socket_connection_mock: AsyncMock, +) -> None: + # given + reader = assembly_socket_reader( + message={ + "request_id": "my_request", + "pipeline_id": "new_pipeline", + "response": {"status": "success"}, + }, + header_size=4, + ) + writer = DummyStreamWriter() + establish_socket_connection_mock.return_value = (reader, writer) + initialisation_request = PipelineInitialisationRequest( + model_id="some/1", + video_reference="rtsp://some:543", + sink_configuration=UDPSinkConfiguration( + type="udp_sink", + host="127.0.0.1", + port=9090, + ), + api_key="my_api_key", + model_configuration=ObjectDetectionModelConfiguration(type="object_detection"), + ) + client = StreamManagerClient.init( + host="127.0.0.1", + port=7070, + operations_timeout=1.0, + header_size=4, + buffer_size=16438, + ) + + # when + result = await client.initialise_pipeline( + initialisation_request=initialisation_request + ) + + # then + assert result == CommandResponse( + status="success", + context=CommandContext(request_id="my_request", pipeline_id="new_pipeline"), + ) + + +@pytest.mark.asyncio +@mock.patch.object(stream_manager_client, "establish_socket_connection") +async def test_stream_manager_client_can_successfully_terminate_pipeline( + establish_socket_connection_mock: AsyncMock, +) -> None: + # given + reader = assembly_socket_reader( + message={ + "request_id": "my_request", + "pipeline_id": "my_pipeline", + "response": {"status": "success"}, + }, + header_size=4, + ) + writer = DummyStreamWriter() + establish_socket_connection_mock.return_value = (reader, writer) + expected_command = {"type": CommandType.TERMINATE, "pipeline_id": "my_pipeline"} + client = StreamManagerClient.init( + host="127.0.0.1", + port=7070, + operations_timeout=1.0, + header_size=4, + buffer_size=16438, + ) + + # when + result = await client.terminate_pipeline(pipeline_id="my_pipeline") + + # then + assert result == CommandResponse( + status="success", + context=CommandContext(request_id="my_request", pipeline_id="my_pipeline"), + ) + assert_correct_command_sent( + writer=writer, + 
command=expected_command, + header_size=4, + message="Expected termination command to be sent", + ) + + +@pytest.mark.asyncio +@mock.patch.object(stream_manager_client, "establish_socket_connection") +async def test_stream_manager_client_can_successfully_pause_pipeline( + establish_socket_connection_mock: AsyncMock, +) -> None: + # given + reader = assembly_socket_reader( + message={ + "request_id": "my_request", + "pipeline_id": "my_pipeline", + "response": {"status": "success"}, + }, + header_size=4, + ) + writer = DummyStreamWriter() + establish_socket_connection_mock.return_value = (reader, writer) + expected_command = {"type": CommandType.MUTE, "pipeline_id": "my_pipeline"} + client = StreamManagerClient.init( + host="127.0.0.1", + port=7070, + operations_timeout=1.0, + header_size=4, + buffer_size=16438, + ) + + # when + result = await client.pause_pipeline(pipeline_id="my_pipeline") + + # then + assert result == CommandResponse( + status="success", + context=CommandContext(request_id="my_request", pipeline_id="my_pipeline"), + ) + assert_correct_command_sent( + writer=writer, + command=expected_command, + header_size=4, + message="Expected pause command to be sent", + ) + + +@pytest.mark.asyncio +@mock.patch.object(stream_manager_client, "establish_socket_connection") +async def test_stream_manager_client_can_successfully_resume_pipeline( + establish_socket_connection_mock: AsyncMock, +) -> None: + # given + reader = assembly_socket_reader( + message={ + "request_id": "my_request", + "pipeline_id": "my_pipeline", + "response": {"status": "success"}, + }, + header_size=4, + ) + writer = DummyStreamWriter() + establish_socket_connection_mock.return_value = (reader, writer) + expected_command = {"type": CommandType.RESUME, "pipeline_id": "my_pipeline"} + client = StreamManagerClient.init( + host="127.0.0.1", + port=7070, + operations_timeout=1.0, + header_size=4, + buffer_size=16438, + ) + + # when + result = await client.resume_pipeline(pipeline_id="my_pipeline") + + # then + assert result == CommandResponse( + status="success", + context=CommandContext(request_id="my_request", pipeline_id="my_pipeline"), + ) + assert_correct_command_sent( + writer=writer, + command=expected_command, + header_size=4, + message="Expected resume command to be sent", + ) + + +@pytest.mark.asyncio +@mock.patch.object(stream_manager_client, "establish_socket_connection") +async def test_stream_manager_client_can_successfully_get_pipeline_status( + establish_socket_connection_mock: AsyncMock, +) -> None: + # given + reader = assembly_socket_reader( + message={ + "request_id": "my_request", + "pipeline_id": "my_pipeline", + "response": {"status": "success", "report": {"my": "report"}}, + }, + header_size=4, + ) + writer = DummyStreamWriter() + establish_socket_connection_mock.return_value = (reader, writer) + expected_command = {"type": CommandType.STATUS, "pipeline_id": "my_pipeline"} + client = StreamManagerClient.init( + host="127.0.0.1", + port=7070, + operations_timeout=1.0, + header_size=4, + buffer_size=16438, + ) + + # when + result = await client.get_status(pipeline_id="my_pipeline") + + # then + assert result == InferencePipelineStatusResponse( + status="success", + context=CommandContext(request_id="my_request", pipeline_id="my_pipeline"), + report={"my": "report"}, # this is mock data + ) + assert_correct_command_sent( + writer=writer, + command=expected_command, + header_size=4, + message="Expected get info command to be sent", + ) + + +@pytest.mark.asyncio 
+@mock.patch.object(stream_manager_client, "establish_socket_connection") +async def test_stream_manager_client_can_dispatch_error_response( + establish_socket_connection_mock: AsyncMock, +) -> None: + # given + reader = assembly_socket_reader( + message={ + "request_id": "my_request", + "pipeline_id": "my_pipeline", + "response": {"status": "failure", "error_type": "not_found"}, + }, + header_size=4, + ) + writer = DummyStreamWriter() + establish_socket_connection_mock.return_value = (reader, writer) + expected_command = {"type": CommandType.RESUME, "pipeline_id": "my_pipeline"} + client = StreamManagerClient.init( + host="127.0.0.1", + port=7070, + operations_timeout=1.0, + header_size=4, + buffer_size=16438, + ) + + # when + with pytest.raises(ProcessesManagerNotFoundError): + _ = await client.resume_pipeline(pipeline_id="my_pipeline") + + # then + + assert_correct_command_sent( + writer=writer, + command=expected_command, + header_size=4, + message="Expected resume command to be sent", + ) + + +def assembly_socket_reader(message: dict, header_size: int) -> DummyStreamReader: + serialised = json.dumps(message).encode("utf-8") + response_payload = ( + len(serialised).to_bytes(length=header_size, byteorder="big") + serialised + ) + return DummyStreamReader(read_buffer_content=response_payload) + + +def assert_correct_command_sent( + writer: DummyStreamWriter, command: dict, header_size: int, message: str +) -> None: + serialised_command = json.dumps(command).encode("utf-8") + payload = ( + len(serialised_command).to_bytes(length=header_size, byteorder="big") + + serialised_command + ) + assert writer.get_content() == payload, message diff --git a/tests/inference/unit_tests/enterprise/stream_management/manager/__init__.py b/tests/inference/unit_tests/enterprise/stream_management/manager/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/inference/unit_tests/enterprise/stream_management/manager/test_app.py b/tests/inference/unit_tests/enterprise/stream_management/manager/test_app.py new file mode 100644 index 000000000..54c85046a --- /dev/null +++ b/tests/inference/unit_tests/enterprise/stream_management/manager/test_app.py @@ -0,0 +1,441 @@ +""" +Unit tests in this module are realised using `InferencePipelineManager` mock - and within single process, submitting +command queues upfront, and then handling one-by-one in the same process. 
+""" +import json +from multiprocessing import Process, Queue +from unittest import mock +from unittest.mock import MagicMock + +import pytest + +from inference.enterprise.stream_management.manager import app +from inference.enterprise.stream_management.manager.app import ( + InferencePipelinesManagerHandler, + execute_termination, + get_response_ignoring_thrash, + handle_command, + join_inference_pipeline, +) +from inference.enterprise.stream_management.manager.entities import ( + CommandType, + ErrorType, + OperationStatus, +) +from inference.enterprise.stream_management.manager.inference_pipeline_manager import ( + InferencePipelineManager, +) + + +def test_get_response_ignoring_thrash_when_nothing_is_to_ignore() -> None: + # given + responses_queue = Queue() + responses_queue.put(("my_request", {"some": "data"})) + + # when + result = get_response_ignoring_thrash( + responses_queue=responses_queue, matching_request_id="my_request" + ) + + # then + assert result == {"some": "data"} + + +def test_get_response_ignoring_thrash_when_trash_message_is_to_be_ignored() -> None: + # given + responses_queue = Queue() + responses_queue.put(("thrash", {"other": "data"})) + responses_queue.put(("my_request", {"some": "data"})) + + # when + result = get_response_ignoring_thrash( + responses_queue=responses_queue, matching_request_id="my_request" + ) + + # then + assert result == {"some": "data"} + + +def test_handle_command_when_pipeline_id_is_not_registered_in_the_table() -> None: + # when + result = handle_command( + processes_table={}, + request_id="my_request", + pipeline_id="unknown", + command={"type": CommandType.RESUME}, + ) + + # then + assert result == { + "status": OperationStatus.FAILURE, + "error_type": ErrorType.NOT_FOUND, + } + + +class DummyPipelineManager(Process): + def __init__(self, input_queue: Queue, output_queue: Queue): + super().__init__() + self._input_queue = input_queue + self._output_queue = output_queue + + def run(self) -> None: + input_data = self._input_queue.get() + self._output_queue.put(input_data) + + +@pytest.mark.timeout(30) +@pytest.mark.slow +def test_handle_command_when_pipeline_id_is_registered_in_the_table() -> None: + # given + input_queue, output_queue = Queue(), Queue() + pipeline_manager = DummyPipelineManager( + input_queue=input_queue, output_queue=output_queue + ) + pipeline_manager.start() + processes_table = {"my_pipeline": (pipeline_manager, input_queue, output_queue)} + + try: + # when + result = handle_command( + processes_table=processes_table, + request_id="my_request", + pipeline_id="my_pipeline", + command={"type": CommandType.RESUME}, + ) + + # then + assert result == {"type": CommandType.RESUME} + finally: + pipeline_manager.join(timeout=1.0) + + +@pytest.mark.timeout(30) +@pytest.mark.slow +def test_join_inference_pipeline() -> None: + # given + input_queue, output_queue = Queue(), Queue() + pipeline_manager = DummyPipelineManager( + input_queue=input_queue, output_queue=output_queue + ) + pipeline_manager.start() + processes_table = {"my_pipeline": (pipeline_manager, input_queue, output_queue)} + + # when + input_queue.put(None) + _ = output_queue.get() + join_inference_pipeline(processes_table=processes_table, pipeline_id="my_pipeline") + + # then + assert "my_pipeline" not in processes_table + assert pipeline_manager.is_alive() is False + + +@pytest.mark.timeout(30) +@pytest.mark.slow +@mock.patch.object(app.sys, "exit") +def test_execute_termination(exit_mock: MagicMock) -> None: + # given + command_queue, responses_queue = Queue(), 
Queue() + inference_pipeline_manager = InferencePipelineManager( + command_queue=command_queue, + responses_queue=responses_queue, + ) + inference_pipeline_manager.start() + processes_table = { + "my_pipeline": (inference_pipeline_manager, command_queue, responses_queue) + } + + # when + execute_termination(9, MagicMock(), processes_table=processes_table) + + # then + exit_mock.assert_called_once_with(0) + + +class DummySocket: + def __init__(self): + self._buffer = b"" + self._sent = b"" + + def get_data_that_was_sent(self) -> bytes: + return self._sent + + def fill(self, data: bytes) -> None: + self._buffer = data + + def recv(self, __bufsize: int) -> bytes: + chunk = self._buffer[:__bufsize] + self._buffer = self._buffer[__bufsize:] + return chunk + + def sendall(self, __data: bytes) -> None: + self._sent += __data + + +@pytest.mark.timeout(30) +def test_pipeline_manager_handler_when_wrong_input_format_is_sent() -> None: + # given + socket = DummySocket() + payload = "FOR SURE NOT A JSON".encode("utf-8") + header = len(payload).to_bytes(length=4, byteorder="big") + socket.fill(header + payload) + + # when + _ = InferencePipelinesManagerHandler( + request=socket, + client_address=MagicMock(), + server=MagicMock(), + processes_table={}, + ) + response = json.loads(socket.get_data_that_was_sent()[4:].decode("utf-8")) + + # then + assert ( + response["pipeline_id"] is None + ), "Pipeline ID cannot be associated to this request" + assert response["response"]["status"] == "failure", "Operation should failed" + assert ( + response["response"]["error_type"] == "invalid_payload" + ), "Wrong payload should be denoted as error cause" + + +@pytest.mark.timeout(30) +def test_pipeline_manager_handler_when_malformed_input_is_sent() -> None: + # given + socket = DummySocket() + payload = json.dumps({"invalid": "data"}).encode("utf-8") + header = len(payload).to_bytes(length=4, byteorder="big") + socket.fill(header + payload) + + # when + _ = InferencePipelinesManagerHandler( + request=socket, + client_address=MagicMock(), + server=MagicMock(), + processes_table={}, + ) + response = json.loads(socket.get_data_that_was_sent()[4:].decode("utf-8")) + + # then + assert ( + response["pipeline_id"] is None + ), "Pipeline ID cannot be associated to this request" + assert response["response"]["status"] == "failure", "Operation should failed" + assert ( + response["response"]["error_type"] == "invalid_payload" + ), "Wrong payload should be denoted as error cause" + + +@pytest.mark.timeout(30) +def test_pipeline_manager_handler_when_unknown_command_is_sent() -> None: + # given + socket = DummySocket() + payload = json.dumps({"type": "unknown"}).encode("utf-8") + header = len(payload).to_bytes(length=4, byteorder="big") + socket.fill(header + payload) + + # when + _ = InferencePipelinesManagerHandler( + request=socket, + client_address=MagicMock(), + server=MagicMock(), + processes_table={}, + ) + response = json.loads(socket.get_data_that_was_sent()[4:].decode("utf-8")) + + # then + assert ( + response["pipeline_id"] is None + ), "Pipeline ID cannot be associated to this request" + assert response["response"]["status"] == "failure", "Operation should failed" + assert ( + response["response"]["error_type"] == "invalid_payload" + ), "Wrong payload should be denoted as error cause" + + +@pytest.mark.timeout(30) +def test_pipeline_manager_handler_when_command_requested_for_unknown_pipeline() -> None: + # given + socket = DummySocket() + payload = json.dumps({"type": "terminate", "pipeline_id": "unknown"}).encode( + 
"utf-8" + ) + header = len(payload).to_bytes(length=4, byteorder="big") + socket.fill(header + payload) + + # when + _ = InferencePipelinesManagerHandler( + request=socket, + client_address=MagicMock(), + server=MagicMock(), + processes_table={}, + ) + response = json.loads(socket.get_data_that_was_sent()[4:].decode("utf-8")) + + # then + assert ( + response["pipeline_id"] == "unknown" + ), "Pipeline ID must be assigned to request" + assert response["response"]["status"] == "failure", "Operation should failed" + assert ( + response["response"]["error_type"] == "not_found" + ), "Pipeline not found should be denoted as error cause" + + +@pytest.mark.timeout(30) +@pytest.mark.slow +def test_pipeline_manager_handler_when_pipeline_initialisation_triggered_with_malformed_payload() -> ( + None +): + # given + socket = DummySocket() + payload = json.dumps({"type": "init"}).encode("utf-8") + header = len(payload).to_bytes(length=4, byteorder="big") + socket.fill(header + payload) + processes_table = {} + + try: + # when + _ = InferencePipelinesManagerHandler( + request=socket, + client_address=MagicMock(), + server=MagicMock(), + processes_table=processes_table, + ) + response = json.loads(socket.get_data_that_was_sent()[4:].decode("utf-8")) + + # then + assert ( + len(processes_table) == 1 + ), "Pipeline table should be filled with manager process" + assert ( + type(response["pipeline_id"]) is str + ), "Pipeline ID must be set to random string" + assert response["response"]["status"] == "failure", "Operation should failed" + assert ( + response["response"]["error_type"] == "invalid_payload" + ), "Pipeline could not be initialised due to invalid payload" + finally: + process = processes_table[list(processes_table.keys())[0]] + process[0].terminate() + + +@pytest.mark.timeout(30) +@pytest.mark.slow +def test_pipeline_manager_handler_when_termination_requested_after_failed_initialisation() -> ( + None +): + # given + socket = DummySocket() + payload = json.dumps({"type": "init"}).encode("utf-8") + header = len(payload).to_bytes(length=4, byteorder="big") + socket.fill(header + payload) + processes_table = {} + + try: + # when + _ = InferencePipelinesManagerHandler( + request=socket, + client_address=MagicMock(), + server=MagicMock(), + processes_table=processes_table, + ) + init_response = json.loads(socket.get_data_that_was_sent()[4:].decode("utf-8")) + socket = DummySocket() + payload = json.dumps( + {"type": "terminate", "pipeline_id": init_response["pipeline_id"]} + ).encode("utf-8") + header = len(payload).to_bytes(length=4, byteorder="big") + socket.fill(header + payload) + _ = InferencePipelinesManagerHandler( + request=socket, + client_address=MagicMock(), + server=MagicMock(), + processes_table=processes_table, + ) + terminate_response = json.loads( + socket.get_data_that_was_sent()[4:].decode("utf-8") + ) + + # then + assert ( + len(processes_table) == 0 + ), "Pipeline should be removed from table after termination" + assert ( + type(init_response["pipeline_id"]) is str + ), "Pipeline ID must be set to random string" + assert ( + init_response["response"]["status"] == "failure" + ), "Operation should failed" + assert ( + init_response["response"]["error_type"] == "invalid_payload" + ), "Pipeline could not be initialised due to invalid payload" + assert ( + terminate_response["pipeline_id"] == init_response["pipeline_id"] + ), "Terminate request must refer the same pipeline that was created" + assert ( + terminate_response["response"]["status"] == "success" + ), "Termination operation should 
succeed" + finally: + if len(processes_table) > 0: + process = processes_table[list(processes_table.keys())] + process[0].terminate() + + +@pytest.mark.timeout(30) +@pytest.mark.slow +def test_pipeline_manager_handler_when_list_of_pipelines_requested_after_unsuccessful_initialisation() -> ( + None +): + # given + socket = DummySocket() + payload = json.dumps({"type": "init"}).encode("utf-8") + header = len(payload).to_bytes(length=4, byteorder="big") + socket.fill(header + payload) + processes_table = {} + + try: + # when + _ = InferencePipelinesManagerHandler( + request=socket, + client_address=MagicMock(), + server=MagicMock(), + processes_table=processes_table, + ) + init_response = json.loads(socket.get_data_that_was_sent()[4:].decode("utf-8")) + socket = DummySocket() + payload = json.dumps({"type": "list_pipelines"}).encode("utf-8") + header = len(payload).to_bytes(length=4, byteorder="big") + socket.fill(header + payload) + _ = InferencePipelinesManagerHandler( + request=socket, + client_address=MagicMock(), + server=MagicMock(), + processes_table=processes_table, + ) + listing_response = json.loads( + socket.get_data_that_was_sent()[4:].decode("utf-8") + ) + + # then + assert ( + len(processes_table) == 1 + ), "Pipeline table should be filled with manager process" + assert ( + type(init_response["pipeline_id"]) is str + ), "Pipeline ID must be set to random string" + assert ( + init_response["response"]["status"] == "failure" + ), "Operation should failed" + assert ( + init_response["response"]["error_type"] == "invalid_payload" + ), "Pipeline could not be initialised due to invalid payload" + assert ( + listing_response["response"]["status"] == "success" + ), "Listing operation should succeed" + assert listing_response["response"]["pipelines"] == [ + init_response["pipeline_id"] + ] + finally: + process = processes_table[list(processes_table.keys())[0]] + process[0].terminate() diff --git a/tests/inference/unit_tests/enterprise/stream_management/manager/test_communucation.py b/tests/inference/unit_tests/enterprise/stream_management/manager/test_communucation.py new file mode 100644 index 000000000..942fbc6f9 --- /dev/null +++ b/tests/inference/unit_tests/enterprise/stream_management/manager/test_communucation.py @@ -0,0 +1,189 @@ +""" +In this module, tests are written to mock, avoiding issues with trying to find empty socket in tests +""" +import json +import random +from unittest.mock import MagicMock + +import pytest + +from inference.enterprise.stream_management.manager.communication import ( + receive_socket_data, + send_data_trough_socket, +) +from inference.enterprise.stream_management.manager.errors import ( + MalformedHeaderError, + MalformedPayloadError, + TransmissionChannelClosed, +) + + +def test_receive_socket_data_when_header_is_malformed() -> None: + # given + socket = MagicMock() + socket.recv.side_effect = [b"A"] + + # when + with pytest.raises(MalformedHeaderError): + _ = receive_socket_data( + source=socket, + header_size=4, + buffer_size=512, + ) + + +def test_receive_socket_data_when_header_cannot_be_decoded_as_valid_value() -> None: + # given + socket = MagicMock() + zero = 0 + socket.recv.side_effect = [zero.to_bytes(length=4, byteorder="big")] + + # when + with pytest.raises(MalformedHeaderError): + _ = receive_socket_data( + source=socket, + header_size=4, + buffer_size=512, + ) + + +def test_receive_socket_data_when_header_indicated_invalid_payload_length() -> None: + # given + socket = MagicMock() + data = json.dumps({"some": "data"}).encode("utf-8") + 
header = len(data) + 32 + socket.recv.side_effect = [header.to_bytes(length=4, byteorder="big"), data, b""] + + # when + with pytest.raises(TransmissionChannelClosed): + _ = receive_socket_data( + source=socket, + header_size=4, + buffer_size=len(data), + ) + + +def test_receive_socket_data_when_malformed_payload_given() -> None: + # given + socket = MagicMock() + data = "FOR SURE NOT A JSON :)".encode("utf-8") + header = len(data) + socket.recv.side_effect = [header.to_bytes(length=4, byteorder="big"), data] + + # when + with pytest.raises(MalformedPayloadError): + _ = receive_socket_data( + source=socket, + header_size=4, + buffer_size=len(data), + ) + + +def test_receive_socket_data_complete_successfully_despite_fragmented_message() -> None: + # given + socket = MagicMock() + data = json.dumps({"some": "data"}).encode("utf-8") + header = len(data) + socket.recv.side_effect = [ + header.to_bytes(length=4, byteorder="big"), + data[:-3], + data[-3:], + ] + + # when + result = receive_socket_data( + source=socket, + header_size=4, + buffer_size=len(data) - 3, + ) + + # then + assert result == {"some": "data"}, "Decoded date must be equal to input payload" + + +def test_receive_socket_data_when_timeout_error_should_be_reraised() -> None: + # given + socket = MagicMock() + data = json.dumps({"some": "data"}).encode("utf-8") + header = len(data) + socket.recv.side_effect = [header.to_bytes(length=4, byteorder="big"), TimeoutError] + + # when + with pytest.raises(TimeoutError): + _ = receive_socket_data( + source=socket, + header_size=4, + buffer_size=len(data), + ) + + +def test_send_data_trough_socket_when_operation_succeeds() -> None: + # given + socket = MagicMock() + payload = json.dumps({"my": "data"}).encode("utf-8") + + # when + send_data_trough_socket( + target=socket, + header_size=4, + data=payload, + request_id="my_request", + pipeline_id="my_pipeline", + ) + + # then + socket.sendall.assert_called_once_with( + len(payload).to_bytes(length=4, byteorder="big") + payload + ) + + +def test_send_data_trough_socket_when_payload_overflow_happens() -> None: + # given + socket = MagicMock() + payload = json.dumps( + {"my": "data", "list": [random.randint(0, 100) for _ in range(128)]} + ).encode("utf-8") + expected_error_payload = json.dumps( + { + "request_id": "my_request", + "response": { + "status": "failure", + "error_type": "internal_error", + "error_class": "OverflowError", + "error_message": "int too big to convert", + }, + "pipeline_id": "my_pipeline", + } + ).encode("utf-8") + + # when + send_data_trough_socket( + target=socket, + header_size=1, + data=payload, + request_id="my_request", + pipeline_id="my_pipeline", + ) + + # then + socket.sendall.assert_called_once_with( + len(expected_error_payload).to_bytes(length=1, byteorder="big") + + expected_error_payload + ) + + +def test_send_data_trough_socket_when_connection_error_occurs() -> None: + # given + socket = MagicMock() + payload = json.dumps({"my": "data"}).encode("utf-8") + + # when + send_data_trough_socket( + target=socket, + header_size=4, + data=payload, + request_id="my_request", + pipeline_id="my_pipeline", + ) + + # then: Nothing happens - error just logged diff --git a/tests/inference/unit_tests/enterprise/stream_management/manager/test_inference_pipeline_manager.py b/tests/inference/unit_tests/enterprise/stream_management/manager/test_inference_pipeline_manager.py new file mode 100644 index 000000000..6926cd7ef --- /dev/null +++ 
b/tests/inference/unit_tests/enterprise/stream_management/manager/test_inference_pipeline_manager.py @@ -0,0 +1,458 @@ +""" +Unit tests in this module are realised using `InferencePipeline` mock - and within single process, submitting +command queues upfront, and then handling one-by-one in the same process. +""" + +from multiprocessing import Queue +from unittest import mock +from unittest.mock import MagicMock + +import pytest + +from inference.core.exceptions import ( + RoboflowAPINotAuthorizedError, + RoboflowAPINotNotFoundError, +) +from inference.core.interfaces.camera.exceptions import StreamOperationNotAllowedError +from inference.enterprise.stream_management.manager import inference_pipeline_manager +from inference.enterprise.stream_management.manager.entities import ( + CommandType, + ErrorType, + OperationStatus, +) +from inference.enterprise.stream_management.manager.inference_pipeline_manager import ( + InferencePipelineManager, +) + + +@pytest.mark.timeout(30) +@mock.patch.object(inference_pipeline_manager.InferencePipeline, "init") +def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested( + pipeline_init_mock: MagicMock, +) -> None: + # given + pipeline_init_mock.return_value = MagicMock() + command_queue, responses_queue = Queue(), Queue() + manager = InferencePipelineManager( + command_queue=command_queue, responses_queue=responses_queue + ) + init_payload = assembly_valid_init_payload() + + # when + command_queue.put(("1", init_payload)) + command_queue.put(("2", {"type": CommandType.TERMINATE})) + + manager.run() + + status_1 = responses_queue.get() + status_2 = responses_queue.get() + + # then + assert status_1 == ( + "1", + {"status": OperationStatus.SUCCESS}, + ), "Initialisation operation must succeed" + assert status_2 == ( + "2", + {"status": OperationStatus.SUCCESS}, + ), "Termination operation must succeed" + + +@pytest.mark.timeout(30) +@mock.patch.object(inference_pipeline_manager.InferencePipeline, "init") +def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_but_invalid_payload_sent( + pipeline_init_mock: MagicMock, +) -> None: + # given + pipeline_init_mock.return_value = MagicMock() + command_queue, responses_queue = Queue(), Queue() + manager = InferencePipelineManager( + command_queue=command_queue, responses_queue=responses_queue + ) + init_payload = assembly_valid_init_payload() + del init_payload["model_configuration"] + + # when + command_queue.put(("1", init_payload)) + command_queue.put(("2", {"type": CommandType.TERMINATE})) + + manager.run() + + status_1 = responses_queue.get() + status_2 = responses_queue.get() + + # then + assert ( + status_1[0] == "1" + ), "First request should be reported in responses_queue at first" + assert ( + status_1[1]["status"] == OperationStatus.FAILURE + ), "Init operation should fail" + assert ( + status_1[1]["error_type"] == ErrorType.INVALID_PAYLOAD + ), "Invalid Payload error is expected" + assert status_2 == ( + "2", + {"status": OperationStatus.SUCCESS}, + ), "Termination of pipeline must happen" + + +@pytest.mark.timeout(30) +@mock.patch.object(inference_pipeline_manager.InferencePipeline, "init") +def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_but_api_key_not_given( + pipeline_init_mock: MagicMock, +) -> None: + # given + pipeline_init_mock.return_value = MagicMock() + command_queue, responses_queue = Queue(), Queue() + manager = InferencePipelineManager( + command_queue=command_queue, responses_queue=responses_queue + ) + 
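# The manager consumes (request_id, command) tuples from command_queue and reports + # (request_id, result) tuples back on responses_queue; run() returns once TERMINATE is handled. + 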
init_payload = assembly_valid_init_payload() + del init_payload["api_key"] + + # when + command_queue.put(("1", init_payload)) + command_queue.put(("2", {"type": CommandType.TERMINATE})) + + manager.run() + + status_1 = responses_queue.get() + status_2 = responses_queue.get() + + # then + assert ( + status_1[0] == "1" + ), "First request should be reported in responses_queue at first" + assert ( + status_1[1]["status"] == OperationStatus.FAILURE + ), "Init operation should fail" + assert ( + status_1[1]["error_type"] == ErrorType.INVALID_PAYLOAD + ), "Invalid Payload error is expected" + assert status_2 == ( + "2", + {"status": OperationStatus.SUCCESS}, + ), "Termination of pipeline must happen" + + +@pytest.mark.timeout(30) +@mock.patch.object(inference_pipeline_manager.InferencePipeline, "init") +def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_but_roboflow_operation_not_authorised( + pipeline_init_mock: MagicMock, +) -> None: + # given + pipeline_init_mock.side_effect = RoboflowAPINotAuthorizedError() + command_queue, responses_queue = Queue(), Queue() + manager = InferencePipelineManager( + command_queue=command_queue, responses_queue=responses_queue + ) + init_payload = assembly_valid_init_payload() + + # when + command_queue.put(("1", init_payload)) + command_queue.put(("2", {"type": CommandType.TERMINATE})) + + manager.run() + + status_1 = responses_queue.get() + status_2 = responses_queue.get() + + # then + assert ( + status_1[0] == "1" + ), "First request should be reported in responses_queue at first" + assert ( + status_1[1]["status"] == OperationStatus.FAILURE + ), "Init operation should fail" + assert ( + status_1[1]["error_type"] == ErrorType.AUTHORISATION_ERROR + ), "Authorisation error is expected" + assert status_2 == ( + "2", + {"status": OperationStatus.SUCCESS}, + ), "Termination of pipeline must happen" + + +@pytest.mark.timeout(30) +@mock.patch.object(inference_pipeline_manager.InferencePipeline, "init") +def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_but_model_not_found( + pipeline_init_mock: MagicMock, +) -> None: + # given + pipeline_init_mock.side_effect = RoboflowAPINotNotFoundError() + command_queue, responses_queue = Queue(), Queue() + manager = InferencePipelineManager( + command_queue=command_queue, responses_queue=responses_queue + ) + init_payload = assembly_valid_init_payload() + + # when + command_queue.put(("1", init_payload)) + command_queue.put(("2", {"type": CommandType.TERMINATE})) + + manager.run() + + status_1 = responses_queue.get() + status_2 = responses_queue.get() + + # then + assert ( + status_1[0] == "1" + ), "First request should be reported in responses_queue at first" + assert ( + status_1[1]["status"] == OperationStatus.FAILURE + ), "Init operation should fail" + assert ( + status_1[1]["error_type"] == ErrorType.NOT_FOUND + ), "Not found error is expected" + assert status_2 == ( + "2", + {"status": OperationStatus.SUCCESS}, + ), "Termination of pipeline must happen" + + +@pytest.mark.timeout(30) +@mock.patch.object(inference_pipeline_manager.InferencePipeline, "init") +def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_but_unknown_error_appears( + pipeline_init_mock: MagicMock, +) -> None: + # given + pipeline_init_mock.side_effect = Exception() + command_queue, responses_queue = Queue(), Queue() + manager = InferencePipelineManager( + command_queue=command_queue, responses_queue=responses_queue + ) + init_payload = assembly_valid_init_payload() + + # 
when + command_queue.put(("1", init_payload)) + command_queue.put(("2", {"type": CommandType.TERMINATE})) + + manager.run() + + status_1 = responses_queue.get() + status_2 = responses_queue.get() + + # then + assert ( + status_1[0] == "1" + ), "First request should be reported in responses_queue at first" + assert ( + status_1[1]["status"] == OperationStatus.FAILURE + ), "Init operation should fail" + assert ( + status_1[1]["error_type"] == ErrorType.INTERNAL_ERROR + ), "Internal error is expected" + assert status_2 == ( + "2", + {"status": OperationStatus.SUCCESS}, + ), "Termination of pipeline must happen" + + +@pytest.mark.timeout(30) +def test_inference_pipeline_manager_when_attempted_to_get_status_of_not_initialised_pipeline() -> ( + None +): + # given + command_queue, responses_queue = Queue(), Queue() + manager = InferencePipelineManager( + command_queue=command_queue, responses_queue=responses_queue + ) + + # when + command_queue.put(("1", {"type": CommandType.STATUS})) + command_queue.put(("2", {"type": CommandType.TERMINATE})) + + manager.run() + + status_1 = responses_queue.get() + status_2 = responses_queue.get() + + # then + assert ( + status_1[0] == "1" + ), "First request should be reported in responses_queue at first" + assert ( + status_1[1]["status"] == OperationStatus.FAILURE + ), "Init operation should fail" + assert ( + status_1[1]["error_type"] == ErrorType.OPERATION_ERROR + ), "Operation error is expected" + assert status_2 == ( + "2", + {"status": OperationStatus.SUCCESS}, + ), "Termination of pipeline must happen" + + +@pytest.mark.timeout(30) +def test_inference_pipeline_manager_when_attempted_to_pause_of_not_initialised_pipeline() -> ( + None +): + # given + command_queue, responses_queue = Queue(), Queue() + manager = InferencePipelineManager( + command_queue=command_queue, responses_queue=responses_queue + ) + + # when + command_queue.put(("1", {"type": CommandType.MUTE})) + command_queue.put(("2", {"type": CommandType.TERMINATE})) + + manager.run() + + status_1 = responses_queue.get() + status_2 = responses_queue.get() + + # then + assert ( + status_1[0] == "1" + ), "First request should be reported in responses_queue at first" + assert ( + status_1[1]["status"] == OperationStatus.FAILURE + ), "Init operation should fail" + assert ( + status_1[1]["error_type"] == ErrorType.OPERATION_ERROR + ), "Operation error is expected" + assert status_2 == ( + "2", + {"status": OperationStatus.SUCCESS}, + ), "Termination of pipeline must happen" + + +@pytest.mark.timeout(30) +def test_inference_pipeline_manager_when_attempted_to_resume_of_not_initialised_pipeline() -> ( + None +): + # given + command_queue, responses_queue = Queue(), Queue() + manager = InferencePipelineManager( + command_queue=command_queue, responses_queue=responses_queue + ) + + # when + command_queue.put(("1", {"type": CommandType.RESUME})) + command_queue.put(("2", {"type": CommandType.TERMINATE})) + + manager.run() + + status_1 = responses_queue.get() + status_2 = responses_queue.get() + + # then + assert ( + status_1[0] == "1" + ), "First request should be reported in responses_queue at first" + assert ( + status_1[1]["status"] == OperationStatus.FAILURE + ), "Init operation should fail" + assert ( + status_1[1]["error_type"] == ErrorType.OPERATION_ERROR + ), "Operation error is expected" + assert status_2 == ( + "2", + {"status": OperationStatus.SUCCESS}, + ), "Termination of pipeline must happen" + + +@pytest.mark.timeout(30) +@mock.patch.object(inference_pipeline_manager.InferencePipeline, 
"init") +def test_inference_pipeline_manager_when_attempted_to_init_pause_resume_actions_successfully( + pipeline_init_mock: MagicMock, +) -> None: + # given + pipeline_init_mock.return_value = MagicMock() + command_queue, responses_queue = Queue(), Queue() + manager = InferencePipelineManager( + command_queue=command_queue, responses_queue=responses_queue + ) + init_payload = assembly_valid_init_payload() + + # when + command_queue.put(("1", init_payload)) + command_queue.put(("2", {"type": CommandType.MUTE})) + command_queue.put(("3", {"type": CommandType.RESUME})) + command_queue.put(("4", {"type": CommandType.TERMINATE})) + + manager.run() + + status_1 = responses_queue.get() + status_2 = responses_queue.get() + status_3 = responses_queue.get() + status_4 = responses_queue.get() + + # then + assert status_1 == ( + "1", + {"status": OperationStatus.SUCCESS}, + ), "Initialisation operation must succeed" + assert status_2 == ( + "2", + {"status": OperationStatus.SUCCESS}, + ), "Pause of pipeline must happen" + assert status_3 == ( + "3", + {"status": OperationStatus.SUCCESS}, + ), "Resume of pipeline must happen" + assert status_4 == ( + "4", + {"status": OperationStatus.SUCCESS}, + ), "Termination of pipeline must happen" + + +@pytest.mark.timeout(30) +@mock.patch.object(inference_pipeline_manager.InferencePipeline, "init") +def test_inference_pipeline_manager_when_attempted_to_resume_running_sprint_causing_not_allowed_transition( + pipeline_init_mock: MagicMock, +) -> None: + # given + pipeline_init_mock.return_value = MagicMock() + pipeline_init_mock.return_value.resume_stream.side_effect = ( + StreamOperationNotAllowedError() + ) + command_queue, responses_queue = Queue(), Queue() + manager = InferencePipelineManager( + command_queue=command_queue, responses_queue=responses_queue + ) + init_payload = assembly_valid_init_payload() + + # when + command_queue.put(("1", init_payload)) + command_queue.put(("2", {"type": CommandType.RESUME})) + command_queue.put(("3", {"type": CommandType.TERMINATE})) + + manager.run() + + status_1 = responses_queue.get() + status_2 = responses_queue.get() + status_3 = responses_queue.get() + + # then + assert status_1 == ( + "1", + {"status": OperationStatus.SUCCESS}, + ), "Initialisation operation must succeed" + assert status_2[0] == "2", "Second result must refer to request `2`" + assert ( + status_2[1]["status"] is OperationStatus.FAILURE + ), "Second request should fail, as we requested forbidden action" + assert status_2[1]["error_type"] == ErrorType.OPERATION_ERROR + assert status_3 == ( + "3", + {"status": OperationStatus.SUCCESS}, + ), "Termination of pipeline must happen" + + +def assembly_valid_init_payload() -> dict: + return { + "type": CommandType.INIT, + "sink_configuration": { + "type": "udp_sink", + "host": "127.0.0.1", + "port": 6060, + }, + "video_reference": "rtsp://128.0.0.1", + "model_id": "some/1", + "api_key": "my_key", + "model_configuration": {"type": "object-detection"}, + } diff --git a/tests/inference/unit_tests/enterprise/stream_management/manager/test_serialisation.py b/tests/inference/unit_tests/enterprise/stream_management/manager/test_serialisation.py new file mode 100644 index 000000000..3bcb609e2 --- /dev/null +++ b/tests/inference/unit_tests/enterprise/stream_management/manager/test_serialisation.py @@ -0,0 +1,126 @@ +import datetime +import json +from enum import Enum + +from inference.enterprise.stream_management.manager.entities import ( + ErrorType, + OperationStatus, +) +from 
inference.enterprise.stream_management.manager.serialisation import ( + describe_error, + prepare_error_response, + serialise_to_json, +) + + +def test_serialise_to_json_when_datetime_object_given() -> None: + # given + timestamp = datetime.datetime( + year=2020, month=10, day=13, hour=10, minute=30, second=12 + ) + + # when + serialised = json.dumps({"time": timestamp}, default=serialise_to_json) + + # then + + assert ( + "2020-10-13T10:30:12" in serialised + ), "Timestamp in format YYYY-MM-DDTHH:MM:SS must be present in serialised json" + + +def test_serialise_to_json_when_date_object_given() -> None: + # given + timestamp = datetime.date(year=2020, month=10, day=13) + + # when + serialised = json.dumps({"time": timestamp}, default=serialise_to_json) + + # then + + assert ( + "2020-10-13" in serialised + ), "Date in format YYYY-MM-DD must be present in serialised json" + + +class ExampleEnum(Enum): + SOME = "some" + OTHER = "other" + + +def test_serialise_to_json_when_enum_object_given() -> None: + # given + data = ExampleEnum.SOME + + # when + serialised = json.dumps({"payload": data}, default=serialise_to_json) + + # then + + assert "some" in serialised, "Enum value `some` must be present in serialised json" + + +def test_serialise_to_json_when_no_special_content_given() -> None: + # given + data = {"some": 1, "other": True} + + # when + serialised = json.dumps(data, default=serialise_to_json) + result = json.loads(serialised) + + # then + + assert result == data, "After deserialization, data must be recovered 100%" + + +def test_describe_error_when_exception_is_provided_as_context() -> None: + # given + exception = ValueError("Some value error") + + # when + result = describe_error(exception=exception, error_type=ErrorType.INVALID_PAYLOAD) + + # then + assert result == { + "status": OperationStatus.FAILURE, + "error_type": ErrorType.INVALID_PAYLOAD, + "error_class": "ValueError", + "error_message": "Some value error", + } + + +def test_describe_error_when_exception_is_not_provided() -> None: + # when + result = describe_error(exception=None, error_type=ErrorType.INVALID_PAYLOAD) + + # then + assert result == { + "status": OperationStatus.FAILURE, + "error_type": ErrorType.INVALID_PAYLOAD, + } + + +def test_prepare_error_response() -> None: + # given + exception = ValueError("Some value error") + + # when + error_response = prepare_error_response( + request_id="my_request", + error=exception, + error_type=ErrorType.INTERNAL_ERROR, + pipeline_id="my_pipeline", + ) + decoded_response = json.loads(error_response.decode("utf-8")) + + # then + assert decoded_response == { + "request_id": "my_request", + "response": { + "status": "failure", + "error_type": "internal_error", + "error_class": "ValueError", + "error_message": "Some value error", + }, + "pipeline_id": "my_pipeline", + } diff --git a/tests/inference/unit_tests/enterprise/stream_management/manager/test_tcp_server.py b/tests/inference/unit_tests/enterprise/stream_management/manager/test_tcp_server.py new file mode 100644 index 000000000..227c0a565 --- /dev/null +++ b/tests/inference/unit_tests/enterprise/stream_management/manager/test_tcp_server.py @@ -0,0 +1,25 @@ +from unittest.mock import MagicMock + +from inference.enterprise.stream_management.manager.tcp_server import RoboflowTCPServer + + +def test_roboflow_server_applies_connection_timeout() -> None: + # given + server = RoboflowTCPServer( + server_address=("127.0.0.1", 7070), + handler_class=MagicMock, + socket_operations_timeout=1.5, + ) + connection, address = 
MagicMock(), MagicMock() + server.socket = MagicMock() + server.socket.accept.return_value = (connection, address) + + # when + result = server.get_request() + + # then + connection.settimeout.assert_called_once_with(1.5) + assert result == ( + connection, + address, + ), "Method must return accepted connection and address, as per TCPServer interface requirement"
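
Taken together, these tests pin down the wire protocol the stream manager speaks over TCP: every message is a fixed-size big-endian length header followed by a UTF-8 JSON payload, and every response is keyed by request_id, response and pipeline_id. A minimal client sketch of that framing follows; the manager host, port and the 4-byte header size are illustrative assumptions for this sketch, not values fixed by this change.

import json
import socket

HEADER_SIZE = 4  # assumed header size; the helpers under test take header_size as a parameter
MANAGER_ADDRESS = ("127.0.0.1", 7070)  # hypothetical host/port, not defined in this change


def send_command(address: tuple, command: dict) -> dict:
    """Send a single command to the manager and return the decoded JSON response."""
    with socket.create_connection(address) as sock:
        # Frame the request: big-endian length header + UTF-8 JSON payload,
        # mirroring what send_data_trough_socket() / receive_socket_data() expect.
        payload = json.dumps(command).encode("utf-8")
        sock.sendall(len(payload).to_bytes(length=HEADER_SIZE, byteorder="big") + payload)
        # Read the response header, then keep reading until the declared number of bytes arrived.
        declared_length = int.from_bytes(sock.recv(HEADER_SIZE), byteorder="big")
        received = b""
        while len(received) < declared_length:
            chunk = sock.recv(declared_length - len(received))
            if not chunk:
                raise ConnectionError("connection closed before the full payload arrived")
            received += chunk
        return json.loads(received.decode("utf-8"))


if __name__ == "__main__":
    # Example: list the pipelines registered in the manager.
    print(send_command(MANAGER_ADDRESS, {"type": "list_pipelines"}))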