Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

alex/llm-4288-serializationerror-could-not-serialize-access-due-to #376

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 36 additions & 27 deletions backend/deepchecks_monitoring/bgtasks/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging.handlers
import typing as t
from collections import defaultdict
from contextlib import asynccontextmanager
from time import perf_counter

import anyio
Expand Down Expand Up @@ -53,10 +54,10 @@ class AlertsScheduler:
"""Alerts scheduler."""

def __init__(
self,
engine: AsyncEngine,
sleep_seconds: int = TimeUnit.MINUTE * 5,
logger: t.Optional[logging.Logger] = None,
self,
engine: AsyncEngine,
sleep_seconds: int = TimeUnit.MINUTE * 5,
logger: t.Optional[logging.Logger] = None,
):
self.engine = engine
self.async_session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
Expand All @@ -83,6 +84,15 @@ async def run(self):
self.logger.warning('Scheduler interrupted')
raise

@asynccontextmanager
async def organization_error_handler(self, org_id, task_name):
    """Log and suppress per-organization task failures.

    Wraps a single organization's scheduling task so that one failing
    organization cannot abort the loop over the remaining organizations.

    Args:
        org_id: identifier of the organization the wrapped task runs for.
        task_name: label recorded in the structured log entry.
    """
    try:
        yield
    except SerializationError:  # We use 'Repeatable Read Isolation Level'.
        # Expected under concurrent schedulers racing on the same rows;
        # the next scheduling cycle retries, so a warning is enough.
        self.logger.warning({'task': task_name, 'org_id': org_id})
    except Exception:  # noqa: BLE001 — boundary handler: log and continue with the next org.
        # Narrowed from a bare `except:` so that cancellation
        # (asyncio/anyio CancelledError, a BaseException on modern Python)
        # propagates instead of being swallowed, letting the scheduler
        # shut down cleanly when interrupted.
        self.logger.exception({'task': task_name, 'org_id': org_id})

async def run_all_organizations(self):
"""Enqueue tasks for execution."""
async with self.async_session_factory() as session:
Expand All @@ -95,19 +105,15 @@ async def run_all_organizations(self):
return

for org in organizations:
try:
async with self.organization_error_handler(org.id, 'run_organization'):
await self.run_organization(org)
except: # noqa: E722
self.logger.exception({'task': 'run_organization', 'org_id': org.id})
try:

async with self.organization_error_handler(org.id, 'run_organization_data_ingestion_alert'):
await self.run_organization_data_ingestion_alert(org)
except: # noqa: E722
self.logger.exception({'task': 'run_organization_data_ingestion_alert', 'org_id': org.id})

if with_ee:
try:
async with self.organization_error_handler(org.id, 'run_object_storage_ingestion'):
await self.run_object_storage_ingestion(org)
except: # noqa: E722
self.logger.exception({'task': 'run_organization_data_ingestion_alert', 'org_id': org.id})

async def run_organization(self, organization):
"""Try enqueue monitor execution tasks."""
Expand Down Expand Up @@ -160,8 +166,8 @@ async def run_organization(self, organization):

# IMPORTANT NOTE: Forwarding the schedule only if the rule is passing for ALL the model versions.
while (
schedule_time <= model.end_time
and rules_pass(versions_windows, monitor, schedule_time, model)
schedule_time <= model.end_time
and rules_pass(versions_windows, monitor, schedule_time, model)
):
schedules.append(schedule_time)
schedule_time = schedule_time + frequency
Expand Down Expand Up @@ -243,10 +249,10 @@ async def run_object_storage_ingestion(self, organization):


async def get_versions_hour_windows(
model: Model,
versions: t.List[ModelVersion],
session: AsyncSession,
minimum_time: 'PendulumDateTime'
model: Model,
versions: t.List[ModelVersion],
session: AsyncSession,
minimum_time: 'PendulumDateTime'
) -> t.List[t.Dict[int, t.Dict]]:
"""Get windows data for all given versions starting from minimum time.

Expand Down Expand Up @@ -276,8 +282,8 @@ async def get_versions_hour_windows(
func.count(mon_table.c[SAMPLE_PRED_COL]).label('count_predictions'),
func.max(mon_table.c[SAMPLE_LOGGED_TIME_COL]).label('max_logged_timestamp'),
func.count(labels_table.c[SAMPLE_LABEL_COL]).label('count_labels')
).join(labels_table, mon_table.c[SAMPLE_ID_COL] == labels_table.c[SAMPLE_ID_COL], isouter=True)\
.where(hour_window > minimum_time)\
).join(labels_table, mon_table.c[SAMPLE_ID_COL] == labels_table.c[SAMPLE_ID_COL], isouter=True) \
.where(hour_window > minimum_time) \
.group_by(hour_window)

records = (await session.execute(query)).all()
Expand All @@ -286,10 +292,10 @@ async def get_versions_hour_windows(


def rules_pass(
versions_windows: t.List[t.Dict[int, t.Dict]],
monitor: Monitor,
schedule_time: pdl.DateTime,
model: Model
versions_windows: t.List[t.Dict[int, t.Dict]],
monitor: Monitor,
schedule_time: pdl.DateTime,
model: Model
):
"""Check the versions windows for given schedule time. If in all versions at least one of the alerts delay rules \
passes, return True. Otherwise, return False."""
Expand Down Expand Up @@ -323,8 +329,9 @@ def rules_pass(
labels_percent = total_label_count / total_preds_count
# Test the rules. If both rules don't pass, return False.
if (
labels_percent < model.alerts_delay_labels_ratio
and max_timestamp and pdl.instance(max_timestamp).add(seconds=model.alerts_delay_seconds) > pdl.now()
labels_percent < model.alerts_delay_labels_ratio
and max_timestamp
and pdl.instance(max_timestamp).add(seconds=model.alerts_delay_seconds) > pdl.now()
):
return False
# In all versions at least one of the rules passed, return True
Expand Down Expand Up @@ -403,6 +410,7 @@ class SchedulerSettings(BaseSchedulerSettings):

def execute_alerts_scheduler(scheduler_implementation: t.Type[AlertsScheduler]):
    """Execute alerts scheduler."""

async def main():
settings = SchedulerSettings() # type: ignore
service_name = 'alerts-scheduler'
Expand Down Expand Up @@ -453,4 +461,5 @@ async def main():
# we need to reimport AlertsScheduler type
#
from deepchecks_monitoring.bgtasks import scheduler

execute_alerts_scheduler(scheduler.AlertsScheduler)
2 changes: 1 addition & 1 deletion backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ alembic==1.8.0
click==8.1.3
psycopg2==2.9.3
orjson~=3.9.15
python-multipart~=0.0.9
python-multipart==0.0.12 # bugs with pypi version
jinja2==3.1.3
aiokafka==0.11.0
confluent-kafka==2.3.0
Expand Down
Loading