From 98c1f95a8ccc2e0d9de01e1ea98c5cec4be51e0e Mon Sep 17 00:00:00 2001
From: Roy Schossberger <85231212+royischoss@users.noreply.github.com>
Date: Tue, 7 Jan 2025 13:35:26 +0200
Subject: [PATCH] [Model Monitoring] Controller stream, chief worker
 implementation (#7045)

---
 .../schemas/model_monitoring/constants.py   |  19 +
 mlrun/config.py                             |  29 +-
 mlrun/datastore/sources.py                  |   5 +
 mlrun/model_monitoring/controller.py        | 369 +++++++++++++-----
 .../db/tsdb/tdengine/tdengine_connector.py  |   2 +-
 .../db/tsdb/v3io/v3io_connector.py          |   4 +-
 mlrun/model_monitoring/stream_processing.py |  74 +++-
 .../api/crud/model_monitoring/deployment.py |  36 +-
 .../test_stream_processing.py               |   9 +-
 tests/system/model_monitoring/test_app.py   |  11 +-
 10 files changed, 429 insertions(+), 129 deletions(-)

diff --git a/mlrun/common/schemas/model_monitoring/constants.py b/mlrun/common/schemas/model_monitoring/constants.py
index 807c75a32192..7fd8fd86bf20 100644
--- a/mlrun/common/schemas/model_monitoring/constants.py
+++ b/mlrun/common/schemas/model_monitoring/constants.py
@@ -183,6 +183,25 @@ class WriterEventKind(MonitoringStrEnum):
     STATS = "stats"
 
 
+class ControllerEvent(MonitoringStrEnum):
+    KIND = "kind"
+    ENDPOINT_ID = "endpoint_id"
+    ENDPOINT_NAME = "endpoint_name"
+    PROJECT = "project"
+    TIMESTAMP = "timestamp"
+    FIRST_REQUEST = "first_request"
+    FEATURE_SET_URI = "feature_set_uri"
+    ENDPOINT_TYPE = "endpoint_type"
+    ENDPOINT_POLICY = "endpoint_policy"
+    # Note: the endpoint policy is currently a dictionary with the keys
+    # "monitoring_applications" and "base_period"
+
+
+class ControllerEventKind(MonitoringStrEnum):
+    NOP_EVENT = "nop_event"
+    REGULAR_EVENT = "regular_event"
+
+
 class MetricData(MonitoringStrEnum):
     METRIC_NAME = "metric_name"
     METRIC_VALUE = "metric_value"
diff --git a/mlrun/config.py b/mlrun/config.py
index c072cb3fcc26..e969361f862b 100644
--- a/mlrun/config.py
+++ b/mlrun/config.py
@@ -596,6 +596,22 @@
                 "max_replicas": 1,
             },
         },
+        "controller_stream_args": {
+            "v3io": {
+                "shard_count": 10,
+                "retention_period_hours": 24,
+                "num_workers": 10,
+                "min_replicas": 1,
+                "max_replicas": 1,
+            },
+            "kafka": {
+                "partition_count": 10,
+                "replication_factor": 1,
+                "num_workers": 10,
+                "min_replicas": 1,
+                "max_replicas": 1,
+            },
+        },
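
These stream arguments are consumed later in this patch (see `_get_model_monitoring_controller_function` in deployment.py) through `mlrun.mlconf` attribute access. A minimal sketch of reading them back, assuming the default values above:

    import mlrun

    stream_args = mlrun.mlconf.model_endpoint_monitoring.controller_stream_args
    print(stream_args.v3io.shard_count)       # 10 shards on the v3io controller stream
    print(stream_args.kafka.partition_count)  # 10 partitions on the Kafka controller topic
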
"store_prefixes": { @@ -1282,6 +1298,8 @@ def get_model_monitoring_file_target_path( function_name and function_name != mlrun.common.schemas.model_monitoring.constants.MonitoringFunctionNames.STREAM + and function_name + != mlrun.common.schemas.model_monitoring.constants.MonitoringFunctionNames.APPLICATION_CONTROLLER ): return mlrun.mlconf.model_endpoint_monitoring.store_prefixes.user_space.format( project=project, @@ -1289,12 +1307,21 @@ def get_model_monitoring_file_target_path( if function_name is None else f"{kind}-{function_name.lower()}", ) - elif kind == "stream": + elif ( + kind == "stream" + and function_name + != mlrun.common.schemas.model_monitoring.constants.MonitoringFunctionNames.APPLICATION_CONTROLLER + ): return mlrun.mlconf.model_endpoint_monitoring.store_prefixes.user_space.format( project=project, kind=kind, ) else: + if ( + function_name + == mlrun.common.schemas.model_monitoring.constants.MonitoringFunctionNames.APPLICATION_CONTROLLER + ): + kind = function_name return mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format( project=project, kind=kind, diff --git a/mlrun/datastore/sources.py b/mlrun/datastore/sources.py index 784538c2de0e..5d36469d7cc6 100644 --- a/mlrun/datastore/sources.py +++ b/mlrun/datastore/sources.py @@ -1128,8 +1128,13 @@ def add_nuclio_trigger(self, function): extra_attributes["workerAllocationMode"] = extra_attributes.get( "worker_allocation_mode", "static" ) + else: + extra_attributes["workerAllocationMode"] = extra_attributes.get( + "worker_allocation_mode", "pool" + ) trigger_kwargs = {} + if "max_workers" in extra_attributes: trigger_kwargs = {"max_workers": extra_attributes.pop("max_workers")} diff --git a/mlrun/model_monitoring/controller.py b/mlrun/model_monitoring/controller.py index 5188d1808558..bfe3a154fc63 100644 --- a/mlrun/model_monitoring/controller.py +++ b/mlrun/model_monitoring/controller.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import concurrent.futures import datetime import json import os from collections.abc import Iterator from contextlib import AbstractContextManager from types import TracebackType -from typing import NamedTuple, Optional, cast +from typing import Any, NamedTuple, Optional, cast import nuclio_sdk @@ -28,6 +27,10 @@ import mlrun.feature_store as fstore import mlrun.model_monitoring from mlrun.common.schemas import EndpointType +from mlrun.common.schemas.model_monitoring.constants import ( + ControllerEvent, + ControllerEventKind, +) from mlrun.datastore import get_stream_pusher from mlrun.errors import err_to_str from mlrun.model_monitoring.db._schedules import ModelMonitoringSchedulesFile @@ -140,6 +143,7 @@ def __init__(self, project: str, endpoint_id: str, window_length: int) -> None: Initialize a batch window generator object that generates batch window objects for the monitoring functions. """ + self.batch_window: _BatchWindow = None self._project = project self._endpoint_id = endpoint_id self._timedelta = window_length @@ -199,14 +203,14 @@ def get_intervals( `first_request` and `last_request` are the timestamps of the first request and last request to the endpoint, respectively. They are guaranteed to be nonempty at this point. 
""" - batch_window = _BatchWindow( + self.batch_window = _BatchWindow( schedules_file=self._schedules_file, application=application, timedelta_seconds=self._timedelta, last_updated=self._get_last_updated_time(last_request, not_batch_endpoint), first_request=int(first_request.timestamp()), ) - yield from batch_window.get_intervals() + yield from self.batch_window.get_intervals() def _get_window_length() -> int: @@ -237,6 +241,7 @@ def __init__(self) -> None: self._window_length = _get_window_length() self.model_monitoring_access_key = self._get_model_monitoring_access_key() + self.v3io_access_key = mlrun.get_secret_or_env("V3IO_ACCESS_KEY") self.storage_options = None if mlrun.mlconf.artifact_path.startswith("s3://"): self.storage_options = mlrun.mlconf.get_s3_storage_options() @@ -262,112 +267,65 @@ def _should_monitor_endpoint(endpoint: mlrun.common.schemas.ModelEndpoint) -> bo != mm_constants.EndpointType.ROUTER.value ) - def run(self) -> None: + def run(self, event: nuclio_sdk.Event) -> None: """ - Main method for run all the relevant monitoring applications on each endpoint. + Main method for controller chief, runs all the relevant monitoring applications for a single endpoint. + Handles nop events logic. This method handles the following: - 1. List model endpoints - 2. List applications - 3. Check model monitoring windows - 4. Send data to applications - 5. Delete old parquets + 1. Read applications from the event (endpoint_policy) + 2. Check model monitoring windows + 3. Send data to applications + 4. Pushes nop event to main stream if needed """ - logger.info("Start running monitoring controller") + logger.info("Start running monitoring controller worker") try: - applications_names = [] - endpoints_list = mlrun.db.get_run_db().list_model_endpoints( - project=self.project, tsdb_metrics=True - ) - endpoints = endpoints_list.endpoints - if not endpoints: - logger.info("No model endpoints found", project=self.project) - return - monitoring_functions = self.project_obj.list_model_monitoring_functions() - if monitoring_functions: - applications_names = list( - {app.metadata.name for app in monitoring_functions} - ) - # if monitoring_functions: - TODO : ML-7700 - # Gets only application in ready state - # applications_names = list( - # { - # app.metadata.name - # for app in monitoring_functions - # if ( - # app.status.state == "ready" - # # workaround for the default app, as its `status.state` is `None` - # or app.metadata.name - # == mm_constants.HistogramDataDriftApplicationConstants.NAME - # ) - # } - # ) - if not applications_names: - logger.info("No monitoring functions found", project=self.project) - return - logger.info( - "Starting to iterate over the applications", - applications=applications_names, - ) - + body = json.loads(event.body.decode("utf-8")) except Exception as e: logger.error( - "Failed to list endpoints and monitoring applications", + "Failed to decode event", exc=err_to_str(e), ) return - # Initialize a thread pool that will be used to monitor each endpoint on a dedicated thread - with concurrent.futures.ThreadPoolExecutor( - max_workers=min(len(endpoints), 10) - ) as pool: - for endpoint in endpoints: - if self._should_monitor_endpoint(endpoint): - pool.submit( - MonitoringApplicationController.model_endpoint_process, - project=self.project, - endpoint=endpoint, - applications_names=applications_names, - window_length=self._window_length, - model_monitoring_access_key=self.model_monitoring_access_key, - storage_options=self.storage_options, - ) - else: - 
+        # Run single endpoint process
+        self.model_endpoint_process(event=body)
 
-    @classmethod
     def model_endpoint_process(
-        cls,
-        project: str,
-        endpoint: mlrun.common.schemas.ModelEndpoint,
-        applications_names: list[str],
-        window_length: int,
-        model_monitoring_access_key: str,
-        storage_options: Optional[dict] = None,
+        self,
+        event: dict,
     ) -> None:
         """
         Process a model endpoint and trigger the monitoring applications. This function runs in a separate process
-        for each endpoint. In addition, this function will generate a parquet file that includes the relevant data
-        for a specific time range.
-
-        :param endpoint:                    (dict) Model endpoint record.
-        :param applications_names:          (list[str]) List of application names to push results to.
-        :param batch_window_generator:      (_BatchWindowGenerator) An object that generates _BatchWindow objects.
-        :param project:                     (str) Project name.
-        :param model_monitoring_access_key: (str) Access key to apply the model monitoring process.
-        :param storage_options:             (dict) Storage options for reading the infer parquet files.
+        for each endpoint.
+
+        :param event: (dict) Event that triggered the monitoring process.
         """
-        endpoint_id = endpoint.metadata.uid
-        not_batch_endpoint = not (
-            endpoint.metadata.endpoint_type == EndpointType.BATCH_EP
-        )
-        m_fs = fstore.get_feature_set(endpoint.spec.monitoring_feature_set_uri)
+        logger.info("Model endpoint process started", event=event)
+        try:
+            project_name = event[ControllerEvent.PROJECT]
+            endpoint_id = event[ControllerEvent.ENDPOINT_ID]
+            endpoint_name = event[ControllerEvent.ENDPOINT_NAME]
+            applications_names = event[ControllerEvent.ENDPOINT_POLICY][
+                "monitoring_applications"
+            ]
+
+            not_batch_endpoint = (
+                event[ControllerEvent.ENDPOINT_TYPE] != EndpointType.BATCH_EP
+            )
+            m_fs = fstore.get_feature_set(event[ControllerEvent.FEATURE_SET_URI])
+            logger.info(
+                "Starting analysis", timestamp=event[ControllerEvent.TIMESTAMP]
+            )
+            last_stream_timestamp = datetime.datetime.fromisoformat(
+                event[ControllerEvent.TIMESTAMP]
+            )
+            first_request = datetime.datetime.fromisoformat(
+                event[ControllerEvent.FIRST_REQUEST]
+            )
 
         with _BatchWindowGenerator(
-            project=project, endpoint_id=endpoint_id, window_length=window_length
+            project=project_name,
+            endpoint_id=endpoint_id,
+            window_length=self._window_length,
         ) as batch_window_generator:
             for application in applications_names:
                 for (
@@ -375,15 +333,15 @@ def model_endpoint_process(
                     end_infer_time,
                 ) in batch_window_generator.get_intervals(
                     application=application,
-                    first_request=endpoint.status.first_request,
-                    last_request=endpoint.status.last_request,
                     not_batch_endpoint=not_batch_endpoint,
+                    first_request=first_request,
+                    last_request=last_stream_timestamp,
                 ):
                     df = m_fs.to_dataframe(
                         start_time=start_infer_time,
                         end_time=end_infer_time,
                         time_column=mm_constants.EventFieldType.TIMESTAMP,
-                        storage_options=storage_options,
+                        storage_options=self.storage_options,
                     )
                     if len(df) == 0:
                         logger.info(
@@ -399,21 +357,53 @@ def model_endpoint_process(
                             end=end_infer_time,
                             endpoint_id=endpoint_id,
                         )
-                    cls._push_to_applications(
+                    self._push_to_applications(
                         start_infer_time=start_infer_time,
                         end_infer_time=end_infer_time,
                         endpoint_id=endpoint_id,
-                        endpoint_name=endpoint.metadata.name,
-                        project=project,
+                        endpoint_name=endpoint_name,
+                        project=project_name,
                         applications_names=[application],
-                        model_monitoring_access_key=model_monitoring_access_key,
+                        model_monitoring_access_key=self.model_monitoring_access_key,
                     )
-            logger.info("Finished processing endpoint", endpoint_id=endpoint_id)
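+            # Cadence check: if more than base_period minutes have passed since this
+            # window was last analyzed (e.g. base_period=10 -> 600 seconds), push a
+            # NOP event back to the main stream. NOP events themselves are excluded
+            # from this check to avoid a NOP feedback loop.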
+            base_period = event[ControllerEvent.ENDPOINT_POLICY]["base_period"]
+            current_time = mlrun.utils.datetime_now()
+            if (
+                current_time.timestamp()
+                - batch_window_generator.batch_window._get_last_analyzed()
+                >= datetime.timedelta(minutes=base_period).total_seconds()
+                and event[ControllerEvent.KIND] != ControllerEventKind.NOP_EVENT
+            ):
+                nop_event = {
+                    ControllerEvent.KIND: ControllerEventKind.NOP_EVENT,
+                    ControllerEvent.PROJECT: project_name,
+                    ControllerEvent.ENDPOINT_ID: endpoint_id,
+                    ControllerEvent.ENDPOINT_NAME: endpoint_name,
+                    ControllerEvent.TIMESTAMP: current_time.isoformat(
+                        timespec="microseconds"
+                    ),
+                    ControllerEvent.ENDPOINT_POLICY: event[
+                        ControllerEvent.ENDPOINT_POLICY
+                    ],
+                    ControllerEvent.ENDPOINT_TYPE: event[
+                        ControllerEvent.ENDPOINT_TYPE
+                    ],
+                    ControllerEvent.FEATURE_SET_URI: event[
+                        ControllerEvent.FEATURE_SET_URI
+                    ],
+                    ControllerEvent.FIRST_REQUEST: event[
+                        ControllerEvent.FIRST_REQUEST
+                    ],
+                }
+                self._push_to_main_stream(
+                    event=nop_event,
+                    endpoint_id=endpoint_id,
+                )
         except Exception:
             logger.exception(
                 "Encountered an exception",
-                endpoint_id=endpoint.metadata.uid,
+                endpoint_id=event[ControllerEvent.ENDPOINT_ID],
             )
 
     @staticmethod
@@ -465,6 +455,168 @@ def _push_to_applications(
             [data]
         )
 
+    def push_regular_event_to_controller_stream(self, event: nuclio_sdk.Event) -> None:
+        """
+        Pushes a regular event to the controller stream for each monitored model endpoint.
+
+        :param event: The Nuclio trigger event.
+        """
+        logger.info("Starting monitoring controller chief")
+        applications_names = []
+        db = mlrun.get_run_db()
+        endpoints = db.list_model_endpoints(
+            project=self.project, tsdb_metrics=True
+        ).endpoints
+        if not endpoints:
+            logger.info("No model endpoints found", project=self.project)
+            return
+        monitoring_functions = self.project_obj.list_model_monitoring_functions()
+        if monitoring_functions:
+            # if monitoring_functions: - TODO : ML-7700
+            # Gets only application in ready state
+            # applications_names = list(
+            #     {
+            #         app.metadata.name
+            #         for app in monitoring_functions
+            #         if (
+            #             app.status.state == "ready"
+            #             # workaround for the default app, as its `status.state` is `None`
+            #             or app.metadata.name
+            #             == mm_constants.HistogramDataDriftApplicationConstants.NAME
+            #         )
+            #     }
+            # )
+            applications_names = list(
+                {app.metadata.name for app in monitoring_functions}
+            )
+        if not applications_names:
+            logger.info("No monitoring functions found", project=self.project)
+            return
+        policy = {
+            "monitoring_applications": applications_names,
+            "base_period": int(
+                batch_dict2timedelta(
+                    json.loads(
+                        cast(
+                            str,
+                            os.getenv(mm_constants.EventFieldType.BATCH_INTERVALS_DICT),
+                        )
+                    )
+                ).total_seconds()
+                // 60
+            ),
+        }
+        for endpoint in endpoints:
+            if self._should_monitor_endpoint(endpoint):
+                logger.info(
+                    "Pushing a regular event to the controller stream for model endpoint",
+                    endpoint_id=endpoint.metadata.uid,
+                    endpoint_name=endpoint.metadata.name,
+                    timestamp=endpoint.status.last_request.isoformat(
+                        sep=" ", timespec="microseconds"
+                    ),
+                    first_request=endpoint.status.first_request.isoformat(
+                        sep=" ", timespec="microseconds"
+                    ),
+                    endpoint_type=endpoint.metadata.endpoint_type,
+                    feature_set_uri=endpoint.spec.monitoring_feature_set_uri,
+                    endpoint_policy=json.dumps(policy),
+                )
+                self.push_to_controller_stream(
+                    kind=mm_constants.ControllerEventKind.REGULAR_EVENT,
+                    project=self.project,
+                    endpoint_id=endpoint.metadata.uid,
+                    endpoint_name=endpoint.metadata.name,
+                    stream_access_key=self.v3io_access_key,
+                    timestamp=endpoint.status.last_request.isoformat(
+                        sep=" ", timespec="microseconds"
+                    ),
+                    first_request=endpoint.status.first_request.isoformat(
+                        sep=" ", timespec="microseconds"
+                    ),
+                    endpoint_type=endpoint.metadata.endpoint_type,
+                    feature_set_uri=endpoint.spec.monitoring_feature_set_uri,
+                    endpoint_policy=policy,
+                )
+            else:
+                logger.info(
+                    "Endpoint should not be monitored, skipping the regular event",
+                    endpoint_id=endpoint.metadata.uid,
+                    endpoint_name=endpoint.metadata.name,
+                    timestamp=endpoint.status.last_request,
+                    first_request=endpoint.status.first_request,
+                    endpoint_type=endpoint.metadata.endpoint_type,
+                    feature_set_uri=endpoint.spec.monitoring_feature_set_uri,
+                )
+
+    @staticmethod
+    def push_to_controller_stream(
+        kind: str,
+        project: str,
+        endpoint_id: str,
+        endpoint_name: str,
+        stream_access_key: str,
+        timestamp: str,
+        first_request: str,
+        endpoint_type: str,
+        feature_set_uri: str,
+        endpoint_policy: dict[str, Any],
+    ) -> None:
+        """
+        Pushes event data to the controller stream.
+
+        :param kind:              The event kind.
+        :param project:           The project name.
+        :param endpoint_id:       The endpoint ID.
+        :param endpoint_name:     The endpoint name.
+        :param stream_access_key: Access key used for the model monitoring process.
+        :param timestamp:         The event timestamp as an ISO-format string (UTC).
+        :param first_request:     The first request time as an ISO-format string (UTC).
+        :param endpoint_type:     The endpoint type enum value.
+        :param feature_set_uri:   The monitoring feature set URI.
+        :param endpoint_policy:   A dictionary holding the monitoring policy.
+        """
+        stream_uri = get_stream_path(
+            project=project,
+            function_name=mm_constants.MonitoringFunctionNames.APPLICATION_CONTROLLER,
+        )
+        event = {
+            ControllerEvent.KIND.value: kind,
+            ControllerEvent.PROJECT.value: project,
+            ControllerEvent.ENDPOINT_ID.value: endpoint_id,
+            ControllerEvent.ENDPOINT_NAME.value: endpoint_name,
+            ControllerEvent.TIMESTAMP.value: timestamp,
+            ControllerEvent.FIRST_REQUEST.value: first_request,
+            ControllerEvent.ENDPOINT_TYPE.value: endpoint_type,
+            ControllerEvent.FEATURE_SET_URI.value: feature_set_uri,
+            ControllerEvent.ENDPOINT_POLICY.value: endpoint_policy,
+        }
+        logger.info(
+            "Pushing data to controller stream",
+            event=event,
+            endpoint_id=endpoint_id,
+            stream_uri=stream_uri,
+        )
+        get_stream_pusher(stream_uri, access_key=stream_access_key).push(
+            [event], partition_key=endpoint_id
+        )
+
+    def _push_to_main_stream(self, event: dict, endpoint_id: str) -> None:
+        """
+        Pushes the given event to the model monitoring stream.
+
+        :param event:       The event dictionary to push to the stream.
+        :param endpoint_id: The endpoint ID.
+        """
+        stream_uri = get_stream_path(project=event.get(ControllerEvent.PROJECT))
+
+        logger.info(
+            "Pushing data to the main stream, a NOP event has been generated",
+            event=json.dumps(event),
+            endpoint_id=endpoint_id,
+            stream_uri=stream_uri,
+        )
+        get_stream_pusher(stream_uri, access_key=self.model_monitoring_access_key).push(
+            [event], partition_key=endpoint_id
+        )
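
Both pushers resolve their destination with `get_stream_path`: the controller stream is addressed by passing the controller's function name, while the main stream uses the project default. A minimal usage sketch (the project name is hypothetical, and the resulting URIs depend on the configured store prefixes):

    # The main model monitoring stream:
    main_uri = get_stream_path(project="my-project")
    # The dedicated controller stream:
    controller_uri = get_stream_path(
        project="my-project",
        function_name=mm_constants.MonitoringFunctionNames.APPLICATION_CONTROLLER,
    )
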
 
 
 def handler(context: nuclio_sdk.Context, event: nuclio_sdk.Event) -> None:
     """
@@ -473,4 +625,15 @@ def handler(context: nuclio_sdk.Context, event: nuclio_sdk.Event) -> None:
     :param context: the Nuclio context
     :param event:   trigger event
     """
-    MonitoringApplicationController().run()
+    logger.info(
+        "Controller got event",
+        trigger=event.trigger,
+        trigger_kind=event.trigger.kind,
+    )
+
+    if event.trigger.kind == "http":
+        # Runs the controller chief:
+        MonitoringApplicationController().push_regular_event_to_controller_stream(event)
+    else:
+        # Runs a controller worker:
+        MonitoringApplicationController().run(event=event)
diff --git a/mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py b/mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py
index 7480a5c59348..6901b5b1af05 100644
--- a/mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py
+++ b/mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py
@@ -188,7 +188,7 @@ def apply_process_before_tsdb():
         graph.add_step(
             "mlrun.model_monitoring.db.tsdb.tdengine.stream_graph_steps.ProcessBeforeTDEngine",
             name="ProcessBeforeTDEngine",
-            after="MapFeatureNames",
+            after="FilterNOP",
         )
 
     def apply_tdengine_target(name, after):
diff --git a/mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py b/mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py
index eb03cad71844..70ee7e5af15d 100644
--- a/mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py
+++ b/mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py
@@ -204,7 +204,7 @@ def apply_storey_aggregations():
                 }
             ],
             name=EventFieldType.LATENCY,
-            after="MapFeatureNames",
+            after="FilterNOP",
             step_name="Aggregates",
             table=".",
             key_field=EventFieldType.ENDPOINT_ID,
@@ -225,7 +225,7 @@ def apply_storey_aggregations():
         graph.add_step(
             "storey.TSDBTarget",
             name="tsdb_predictions",
-            after="MapFeatureNames",
+            after="FilterNOP",
             path=f"{self.container}/{self.tables[mm_schemas.FileTargetKind.PREDICTIONS]}",
             rate="1/s",
             time_col=mm_schemas.EventFieldType.TIMESTAMP,
diff --git a/mlrun/model_monitoring/stream_processing.py b/mlrun/model_monitoring/stream_processing.py
index 073b91f2fc7a..dc6fc63af9f9 100644
--- a/mlrun/model_monitoring/stream_processing.py
+++ b/mlrun/model_monitoring/stream_processing.py
@@ -29,11 +29,14 @@
 import mlrun.serving.states
 import mlrun.utils
 from mlrun.common.schemas.model_monitoring.constants import (
+    ControllerEvent,
+    ControllerEventKind,
     EndpointType,
     EventFieldType,
     FileTargetKind,
     ProjectSecretKeys,
 )
+from mlrun.datastore import parse_kafka_url
 from mlrun.model_monitoring.db import TSDBConnector
 from mlrun.utils import logger
 
@@ -88,7 +91,9 @@ def _initialize_v3io_configurations(
         self.v3io_framesd = v3io_framesd or mlrun.mlconf.v3io_framesd
         self.v3io_api = v3io_api or mlrun.mlconf.v3io_api
 
-        self.v3io_access_key = v3io_access_key or os.environ.get("V3IO_ACCESS_KEY")
+        self.v3io_access_key = v3io_access_key or mlrun.get_secret_or_env(
+            "V3IO_ACCESS_KEY"
+        )
         self.model_monitoring_access_key = (
             model_monitoring_access_key
             or os.environ.get(ProjectSecretKeys.ACCESS_KEY)
@@ -118,6 +123,7 @@ def apply_monitoring_serving_graph(
         self,
         fn: mlrun.runtimes.ServingRuntime,
         tsdb_connector: TSDBConnector,
+        controller_stream_uri: str,
     ) -> None:
         """
         Apply monitoring serving graph to a given serving function. The following serving graph includes about 4 main
@@ -146,6 +152,8 @@
 
         :param fn:             A serving function.
         :param tsdb_connector: Time series database connector.
+        :param controller_stream_uri: The controller stream URI. This code runs in the API service pod,
+                                      so the URI must be provided as an input.
         """
 
         graph = typing.cast(
@@ -209,6 +217,20 @@ def apply_map_feature_names():
         )
 
         apply_map_feature_names()
+
+        # Split the graph into two branches: regular events continue through FilterNOP,
+        # while NOP events are diverted through ForwardNOP to the controller stream
+        graph.add_step(
+            "storey.Filter",
+            "FilterNOP",
+            after="MapFeatureNames",
+            _fn="(event.get('kind', '') != 'nop_event')",
+        )
+        graph.add_step(
+            "storey.Filter",
+            "ForwardNOP",
+            after="MapFeatureNames",
+            _fn="(event.get('kind', '') == 'nop_event')",
+        )
+
         tsdb_connector.apply_monitoring_stream_steps(
             graph=graph,
             aggregate_windows=self.aggregate_windows,
@@ -221,7 +243,7 @@ def apply_process_before_parquet():
             graph.add_step(
                 "ProcessBeforeParquet",
                 name="ProcessBeforeParquet",
-                after="MapFeatureNames",
+                after="FilterNOP",
                 _fn="(event)",
             )
@@ -248,6 +270,44 @@ def apply_parquet_target():
 
         apply_parquet_target()
 
+        # The controller stream branch
+        def apply_push_controller_stream(stream_uri: str):
+            if stream_uri.startswith("v3io://"):
+                graph.add_step(
+                    ">>",
+                    "controller_stream_v3io",
+                    path=stream_uri,
+                    sharding_func=ControllerEvent.ENDPOINT_ID,
+                    access_key=self.v3io_access_key,
+                    after="ForwardNOP",
+                )
+            elif stream_uri.startswith("kafka://"):
+                topic, brokers = parse_kafka_url(stream_uri)
+                logger.info(
+                    "Controller stream URI for Kafka",
+                    stream_uri=stream_uri,
+                    topic=topic,
+                    brokers=brokers,
+                )
+                if isinstance(brokers, list):
+                    path = f"kafka://{brokers[0]}/{topic}"
+                elif isinstance(brokers, str):
+                    path = f"kafka://{brokers}/{topic}"
+                else:
+                    raise mlrun.errors.MLRunInvalidArgumentError(
+                        "Brokers must be a list or a str; check the controller stream URI"
+                    )
+                graph.add_step(
+                    ">>",
+                    "controller_stream_kafka",
+                    path=path,
+                    kafka_brokers=brokers,
+                    _sharding_func="kafka_sharding_func",  # TODO: remove this when storey handles str keys
+                    after="ForwardNOP",
+                )
+
+        apply_push_controller_stream(controller_stream_uri)
+
 
 class ProcessBeforeParquet(mlrun.feature_store.steps.MapClass):
     def __init__(self, **kwargs):
@@ -321,6 +381,9 @@ def __init__(
 
     def do(self, full_event):
         event = full_event.body
+        if event.get(ControllerEvent.KIND, "") == ControllerEventKind.NOP_EVENT:
+            logger.info("Skipping NOP event in ProcessEndpointEvent", event=event)
+            return storey.Event(body=[event])
         # Getting model version and function uri from event
         # and use them for retrieving the endpoint_id
         function_uri = full_event.body.get(EventFieldType.FUNCTION_URI)
@@ -589,6 +652,9 @@ def _infer_label_columns_from_data(self, event):
         return None
 
     def do(self, event: dict):
+        if event.get(ControllerEvent.KIND, "") == ControllerEventKind.NOP_EVENT:
+            logger.info("Skipping NOP event in MapFeatureNames", event=event)
+            return event
         endpoint_id = event[EventFieldType.ENDPOINT_ID]
 
         feature_values = event[EventFieldType.FEATURES]
@@ -827,3 +893,7 @@ def update_monitoring_feature_set(
     )
 
     monitoring_feature_set.save()
+
+
+def kafka_sharding_func(event):
+    return event.body[ControllerEvent.ENDPOINT_ID].encode("UTF-8")
diff --git a/server/py/services/api/crud/model_monitoring/deployment.py b/server/py/services/api/crud/model_monitoring/deployment.py
index dea9abb0fc65..cacb400e635b 100644
--- a/server/py/services/api/crud/model_monitoring/deployment.py
+++ b/server/py/services/api/crud/model_monitoring/deployment.py
@@ -328,7 +328,8 @@ def apply_and_create_stream_trigger(
             function=function, function_name=function_name
         )
 
-        function.spec.disable_default_http_trigger = True
+        if function_name != mm_constants.MonitoringFunctionNames.APPLICATION_CONTROLLER:
+            function.spec.disable_default_http_trigger = True
 
         return function
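
The next hunk pins the monitoring Kafka trigger to static worker allocation. As a usage sketch, the resulting source definition looks roughly like this (the broker address and topic name are hypothetical; `KafkaSource` and its `attributes` are used exactly this way in the hunk below):

    import mlrun.datastore.sources

    stream_source = mlrun.datastore.sources.KafkaSource(
        brokers="kafka-broker:9092",
        topics=["monitoring-controller-topic"],
        attributes={"max_workers": 10, "worker_allocation_mode": "static"},
    )
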
@@ -346,7 +347,10 @@ def _apply_and_create_kafka_source(
         stream_source = mlrun.datastore.sources.KafkaSource(
             brokers=brokers,
             topics=[topic],
-            attributes={"max_workers": stream_args.kafka.num_workers},
+            attributes={
+                "max_workers": stream_args.kafka.num_workers,
+                "worker_allocation_mode": "static",
+            },
         )
         try:
             stream_source.create_topics(
@@ -375,11 +379,16 @@ def _apply_and_create_v3io_source(
         function_name: str,
         stream_args: mlrun.config.Config,
     ):
-        access_key = self.model_monitoring_access_key
-        kwargs = {"access_key": self.model_monitoring_access_key}
+        access_key = (
+            self.model_monitoring_access_key
+            if function_name
+            != mm_constants.MonitoringFunctionNames.APPLICATION_CONTROLLER
+            else mlrun.mlconf.get_v3io_access_key()
+        )
+        kwargs = {"access_key": access_key}
         if mlrun.mlconf.is_explicit_ack_enabled():
             kwargs["explicit_ack_mode"] = "explicitOnly"
-            kwargs["worker_allocation_mode"] = "static"
+        kwargs["worker_allocation_mode"] = "static"
         kwargs["max_workers"] = stream_args.v3io.num_workers
         services.api.api.endpoints.nuclio.create_model_monitoring_stream(
             project=self.project,
@@ -444,10 +453,15 @@ def _initial_model_monitoring_stream_processing_function(
             project=self.project, secret_provider=self._secret_provider
         )
 
+        controller_stream_uri = mlrun.model_monitoring.get_stream_path(
+            project=self.project,
+            function_name=mm_constants.MonitoringFunctionNames.APPLICATION_CONTROLLER,
+            secret_provider=self._secret_provider,
+        )
+
         # Create monitoring serving graph
         stream_processor.apply_monitoring_serving_graph(
-            function,
-            tsdb_connector,
+            function, tsdb_connector, controller_stream_uri
         )
 
         # Set the project to the serving function
@@ -489,11 +503,17 @@ def _get_model_monitoring_controller_function(self, image: str):
         # Set the project to the job function
         function.metadata.project = self.project
 
+        # Add stream triggers
+        function = self.apply_and_create_stream_trigger(
+            function=function,
+            function_name=mm_constants.MonitoringFunctionNames.APPLICATION_CONTROLLER,
+            stream_args=config.model_endpoint_monitoring.controller_stream_args,
+        )
+
         function = self._apply_access_key_and_mount_function(
             function=function,
             function_name=mm_constants.MonitoringFunctionNames.APPLICATION_CONTROLLER,
         )
-        function.spec.max_replicas = 1
 
         # Enrich runtime with the required configurations
         framework.api.utils.apply_enrichment_and_validation_on_function(
             function, self.auth_info
diff --git a/tests/model_monitoring/test_stream_processing.py b/tests/model_monitoring/test_stream_processing.py
index 2db1c72f7f85..9ad57b2c6e5e 100644
--- a/tests/model_monitoring/test_stream_processing.py
+++ b/tests/model_monitoring/test_stream_processing.py
@@ -20,7 +20,8 @@
 
 @pytest.mark.parametrize("tsdb_connector", ["v3io", "taosws"])
-def test_plot_monitoring_serving_graph(tsdb_connector):
+@pytest.mark.parametrize("stream_path", ["v3io", "kafka://192.168.226.176:9092/topic"])
+def test_plot_monitoring_serving_graph(tsdb_connector, stream_path):
     project_name = "test-stream-processing"
     project = mlrun.get_or_create_project(project_name)
 
@@ -40,11 +41,13 @@ def test_plot_monitoring_serving_graph(tsdb_connector):
         project=project_name, tsdb_connection_string=tsdb_connector
     )
 
-    processor.apply_monitoring_serving_graph(fn, tsdb_connector)
+    processor.apply_monitoring_serving_graph(fn, tsdb_connector, stream_path)
 
     graph = fn.spec.graph.plot(rankdir="TB")
     print()
-    print(f"Graphviz graph definition with tsdb_connector={tsdb_connector}")
+    print(
+        f"Graphviz graph definition with tsdb_connector={tsdb_connector} and stream_path={stream_path}"
+    )
     print("Feed this to graphviz, or to https://dreampuf.github.io/GraphvizOnline")
     print()
     print(graph)
diff --git a/tests/system/model_monitoring/test_app.py b/tests/system/model_monitoring/test_app.py
index e8a46074141a..c248db5b9d2e 100644
--- a/tests/system/model_monitoring/test_app.py
+++ b/tests/system/model_monitoring/test_app.py
@@ -685,23 +685,16 @@ def test_app_flow(self, with_training_set: bool) -> None:
         self._add_error_alert()
 
         time.sleep(5)
-        self._infer(
+        last_request = self._infer(
             serving_fn, num_events=self.num_events, with_training_set=with_training_set
         )
         self._infer_with_error(serving_fn, with_training_set=with_training_set)
         # wait long enough for the first window to be marked as "done"
         time.sleep(
-            self.app_interval_seconds
+            2 * self.app_interval_seconds
             + mlrun.mlconf.model_endpoint_monitoring.parquet_batching_timeout_secs
-            + 2
         )
-        for i in range(10):
-            last_request = self._infer(
-                serving_fn, num_events=1, with_training_set=with_training_set
-            )
-            # wait for the completed window to be processed
-            time.sleep(1.2 * self.app_interval_seconds)
 
         mep = mlrun.db.get_run_db().get_model_endpoint(
             name=f"{self.model_name}_{with_training_set}",
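
Finally, the new trigger-kind dispatch in `handler` can be exercised by hand. A rough sketch, assuming the `nuclio_sdk.Event` and `nuclio_sdk.TriggerInfo` constructors accept the keyword arguments shown and that the controller's environment variables are configured:

    import nuclio_sdk

    from mlrun.model_monitoring.controller import handler

    # An HTTP trigger routes to the chief, which lists the model endpoints itself
    # and fans one regular event per endpoint out to the controller stream;
    # any other trigger kind (e.g. the stream trigger) routes to the worker path.
    event = nuclio_sdk.Event(body=b"", trigger=nuclio_sdk.TriggerInfo(kind="http"))
    handler(context=None, event=event)  # the context is not used by this handler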