diff --git a/backend/deepchecks_monitoring/bgtasks/scheduler.py b/backend/deepchecks_monitoring/bgtasks/scheduler.py
index d8034ae1..2d858a69 100644
--- a/backend/deepchecks_monitoring/bgtasks/scheduler.py
+++ b/backend/deepchecks_monitoring/bgtasks/scheduler.py
@@ -15,6 +15,7 @@
 import logging.handlers
 import typing as t
 from collections import defaultdict
+from contextlib import asynccontextmanager
 from time import perf_counter
 
 import anyio
@@ -53,10 +54,10 @@ class AlertsScheduler:
     """Alerts scheduler."""
 
     def __init__(
-        self,
-        engine: AsyncEngine,
-        sleep_seconds: int = TimeUnit.MINUTE * 5,
-        logger: t.Optional[logging.Logger] = None,
+            self,
+            engine: AsyncEngine,
+            sleep_seconds: int = TimeUnit.MINUTE * 5,
+            logger: t.Optional[logging.Logger] = None,
     ):
         self.engine = engine
         self.async_session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
@@ -83,6 +84,15 @@ async def run(self):
                 self.logger.warning('Scheduler interrupted')
                 raise
 
+    @asynccontextmanager
+    async def organization_error_handler(self, org_id, task_name):
+        try:
+            yield
+        except SerializationError:  # We use 'Repeatable Read Isolation Level'.
+            self.logger.warning({'task': task_name, 'org_id': org_id})
+        except:  # noqa: E722
+            self.logger.exception({'task': task_name, 'org_id': org_id})
+
     async def run_all_organizations(self):
         """Enqueue tasks for execution."""
         async with self.async_session_factory() as session:
@@ -95,19 +105,15 @@ async def run_all_organizations(self):
                 return
 
             for org in organizations:
-                try:
+                async with self.organization_error_handler(org.id, 'run_organization'):
                     await self.run_organization(org)
-                except:  # noqa: E722
-                    self.logger.exception({'task': 'run_organization', 'org_id': org.id})
-                try:
+
+                async with self.organization_error_handler(org.id, 'run_organization_data_ingestion_alert'):
                     await self.run_organization_data_ingestion_alert(org)
-                except:  # noqa: E722
-                    self.logger.exception({'task': 'run_organization_data_ingestion_alert', 'org_id': org.id})
+
                 if with_ee:
-                    try:
+                    async with self.organization_error_handler(org.id, 'run_object_storage_ingestion'):
                         await self.run_object_storage_ingestion(org)
-                    except:  # noqa: E722
-                        self.logger.exception({'task': 'run_organization_data_ingestion_alert', 'org_id': org.id})
 
     async def run_organization(self, organization):
         """Try enqueue monitor execution tasks."""
@@ -160,8 +166,8 @@
 
                 # IMPORTANT NOTE: Forwarding the schedule only if the rule is passing for ALL the model versions.
                 while (
-                    schedule_time <= model.end_time
-                    and rules_pass(versions_windows, monitor, schedule_time, model)
+                        schedule_time <= model.end_time
+                        and rules_pass(versions_windows, monitor, schedule_time, model)
                 ):
                     schedules.append(schedule_time)
                     schedule_time = schedule_time + frequency
@@ -243,10 +249,10 @@ async def run_object_storage_ingestion(self, organization):
 
 
 async def get_versions_hour_windows(
-    model: Model,
-    versions: t.List[ModelVersion],
-    session: AsyncSession,
-    minimum_time: 'PendulumDateTime'
+        model: Model,
+        versions: t.List[ModelVersion],
+        session: AsyncSession,
+        minimum_time: 'PendulumDateTime'
 ) -> t.List[t.Dict[int, t.Dict]]:
     """Get windows data for all given versions starting from minimum time.
 
@@ -276,8 +282,8 @@ async def get_versions_hour_windows(
             func.count(mon_table.c[SAMPLE_PRED_COL]).label('count_predictions'),
             func.max(mon_table.c[SAMPLE_LOGGED_TIME_COL]).label('max_logged_timestamp'),
             func.count(labels_table.c[SAMPLE_LABEL_COL]).label('count_labels')
-        ).join(labels_table, mon_table.c[SAMPLE_ID_COL] == labels_table.c[SAMPLE_ID_COL], isouter=True)\
-            .where(hour_window > minimum_time)\
+        ).join(labels_table, mon_table.c[SAMPLE_ID_COL] == labels_table.c[SAMPLE_ID_COL], isouter=True) \
+            .where(hour_window > minimum_time) \
             .group_by(hour_window)
         records = (await session.execute(query)).all()
 
@@ -286,10 +292,10 @@
 
 
 def rules_pass(
-    versions_windows: t.List[t.Dict[int, t.Dict]],
-    monitor: Monitor,
-    schedule_time: pdl.DateTime,
-    model: Model
+        versions_windows: t.List[t.Dict[int, t.Dict]],
+        monitor: Monitor,
+        schedule_time: pdl.DateTime,
+        model: Model
 ):
     """Check the versions windows for given schedule time. If in all versions at least one of the alerts delay rules \
     passes, return True. Otherwise, return False."""
@@ -323,8 +329,9 @@ def rules_pass(
         labels_percent = total_label_count / total_preds_count
         # Test the rules. If both rules don't pass, return False.
         if (
-            labels_percent < model.alerts_delay_labels_ratio
-            and max_timestamp and pdl.instance(max_timestamp).add(seconds=model.alerts_delay_seconds) > pdl.now()
+                labels_percent < model.alerts_delay_labels_ratio
+                and max_timestamp
+                and pdl.instance(max_timestamp).add(seconds=model.alerts_delay_seconds) > pdl.now()
         ):
             return False
     # In all versions at least one of the rules passed, return True
@@ -403,6 +410,7 @@ class SchedulerSettings(BaseSchedulerSettings):
 
 def execute_alerts_scheduler(scheduler_implementation: t.Type[AlertsScheduler]):
     """Execute alrets scheduler."""
+
     async def main():
         settings = SchedulerSettings()  # type: ignore
         service_name = 'alerts-scheduler'
@@ -453,4 +461,5 @@ async def main():
 
     # we need to reimport AlertsScheduler type
     # from deepchecks_monitoring.bgtasks import scheduler
+
     execute_alerts_scheduler(scheduler.AlertsScheduler)
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 2b1841ce..9cefc786 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -11,7 +11,7 @@ alembic==1.8.0
 click==8.1.3
 psycopg2==2.9.3
 orjson~=3.9.15
-python-multipart~=0.0.9
+python-multipart==0.0.12  # bugs with pypi version
 jinja2==3.1.3
 aiokafka==0.11.0
 confluent-kafka==2.3.0
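
Reviewer note: the scheduler change above replaces three copy-pasted try/except blocks with one reusable @asynccontextmanager, which also fixes the old bug where a failure in run_object_storage_ingestion was logged under the task name 'run_organization_data_ingestion_alert'. Below is a minimal, self-contained sketch of the pattern, illustrative only: the stand-in SerializationError class, the module-level logger, and the demo main() are assumptions; in the patch the handler is a method using self.logger and the real driver exception.

    import asyncio
    import logging
    from contextlib import asynccontextmanager

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger('alerts-scheduler')


    class SerializationError(Exception):
        """Stand-in for the DB serialization failure raised under 'Repeatable Read' isolation."""


    @asynccontextmanager
    async def organization_error_handler(org_id, task_name):
        try:
            yield
        except SerializationError:
            # Expected occasionally under 'Repeatable Read'; a warning is enough.
            logger.warning({'task': task_name, 'org_id': org_id})
        except:  # noqa: E722
            # Anything else is logged with a traceback but still swallowed,
            # so one organization's failure cannot abort the whole loop.
            logger.exception({'task': task_name, 'org_id': org_id})


    async def main():
        # The first block raises; the exception is caught inside the context
        # manager and not re-raised, so the second block still runs.
        async with organization_error_handler(org_id=1, task_name='run_organization'):
            raise SerializationError('concurrent update')
        async with organization_error_handler(org_id=2, task_name='run_organization'):
            logger.info('next organization still scheduled')


    asyncio.run(main())

Because the generator catches the exception thrown into it at the yield point and returns instead of re-raising, contextlib treats the exception as suppressed, matching the behavior of the removed try/except blocks while keeping the per-task names in one place.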