Commit

fix ingestion for large sample data (#308)
JKL98ISR authored Nov 27, 2023
1 parent 809bc32 commit 392f1cb
Showing 2 changed files with 16 additions and 9 deletions.
@@ -147,22 +147,22 @@ async def run(self, task: 'Task', session: AsyncSession, resources_provider: Res
                 for prefix in version_prefixes:
                     for df, time in self.ingest_prefix(s3, bucket, f'{version_path}/{prefix}', version.latest_file_time,
                                                        errors, version.model_id, version.id):
+                        # For each file, set lock expiry to 240 seconds from now
+                        await lock.extend(240, replace_ttl=True)
                         await self.ingestion_backend.log_samples(version, df, session, organization_id, new_scan_time)
                         version.latest_file_time = max(version.latest_file_time or
                                                        pdl.datetime(year=1970, month=1, day=1), time)
-                        # For each file, set lock expiry to 120 seconds from now
-                        await lock.extend(120, replace_ttl=True)
 
             # Ingest labels
             for prefix in model_prefixes:
                 labels_path = f'{model_path}/labels/{prefix}'
                 for df, time in self.ingest_prefix(s3, bucket, labels_path, model.latest_labels_file_time,
                                                    errors, model_id):
+                    # For each file, set lock expiry to 240 seconds from now
+                    await lock.extend(240, replace_ttl=True)
                     await self.ingestion_backend.log_labels(model, df, session, organization_id)
                     model.latest_labels_file_time = max(model.latest_labels_file_time
                                                         or pdl.datetime(year=1970, month=1, day=1), time)
-                    # For each file, set lock expiry to 120 seconds from now
-                    await lock.extend(120, replace_ttl=True)
 
             model.obj_store_last_scan_time = new_scan_time
         except Exception:  # pylint: disable=broad-except
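The hunk above moves the lock extension ahead of the slow per-file ingestion call and lengthens it from 120 to 240 seconds, so the lock cannot lapse while a large file is still being logged. Below is a minimal standalone sketch of that pattern, assuming a redis-py asyncio lock whose extend(additional_time, replace_ttl=True) resets the TTL, as in the diff; the lock name and the process_file coroutine are hypothetical placeholders, not part of the commit.

import asyncio

import redis.asyncio as redis


async def ingest_files(files):
    client = redis.Redis()
    # Acquire a distributed lock for the whole scan; timeout is the initial TTL.
    async with client.lock("object-storage-ingestion", timeout=240) as lock:
        for file in files:
            # Reset the TTL to 240 seconds *before* the slow step, so the lock
            # cannot expire mid-ingestion even if a single file takes minutes.
            await lock.extend(240, replace_ttl=True)
            await process_file(file)  # hypothetical slow ingestion step


async def process_file(file):
    await asyncio.sleep(1)  # stand-in for reading the file and logging its samples


# asyncio.run(ingest_files(["batch_1.parquet", "batch_2.parquet"]))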
17 changes: 12 additions & 5 deletions backend/deepchecks_monitoring/logic/data_ingestion.py

@@ -44,6 +44,9 @@
 __all__ = ["DataIngestionBackend", "log_data", "log_labels", "save_failures"]
 
 
+QUERY_PARAM_LIMIT = 32765
+
+
 async def log_data(
     model_version: ModelVersion,
     data: t.List[t.Dict[t.Any, t.Any]],
@@ -113,17 +116,21 @@ async def log_data(
     # Starting by adding to the version map
     versions_map = model.get_samples_versions_map_table(session)
     ids_to_log = [{SAMPLE_ID_COL: sample_id, "version_id": model_version.id} for sample_id in valid_data]
-    statement = (postgresql.insert(versions_map).values(ids_to_log)
-                 .on_conflict_do_nothing(index_elements=versions_map.primary_key.columns)
-                 .returning(versions_map.c[SAMPLE_ID_COL]))
-    ids_not_existing = set((await session.execute(statement)).scalars())
+    ids_not_existing = set()
+    max_messages_per_insert = QUERY_PARAM_LIMIT // 5
+    for start_index in range(0, len(ids_to_log), max_messages_per_insert):
+        statement = (postgresql.insert(versions_map)
+                     .values(ids_to_log[start_index:start_index + max_messages_per_insert])
+                     .on_conflict_do_nothing(index_elements=versions_map.primary_key.columns)
+                     .returning(versions_map.c[SAMPLE_ID_COL]))
+        ids_not_existing.update((await session.execute(statement)).scalars())
     # Filter from the data ids which weren't logged to the versions table
     data_list = [sample for id, sample in valid_data.items() if id in ids_not_existing]
     if data_list:
         # Postgres driver has a limit of 32767 query params, which for 1000 messages, limits us to 32 columns. In
         # order to solve that we can either pre-compile the statement with bind literals, or separate to batches
         num_columns = len(data_list[0])
-        max_messages_per_insert = 32767 // num_columns
+        max_messages_per_insert = QUERY_PARAM_LIMIT // num_columns
         monitor_table = model_version.get_monitor_table(session)
         for start_index in range(0, len(data_list), max_messages_per_insert):
             batch = data_list[start_index:start_index + max_messages_per_insert]
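The second change works around the Postgres driver's limit of 32767 query parameters per statement noted in the in-code comment: a multi-row INSERT of R rows with C columns binds R x C parameters, so the versions-map insert (previously sent in one statement) is now chunked as well, conservatively dividing the limit by 5, and both inserts reuse a shared QUERY_PARAM_LIMIT kept just under the cap. Below is a minimal sketch of the same batching idea, assuming a SQLAlchemy AsyncSession and a generic table; the helper name and the "sample_id" return column are illustrative, not the repository's API.

import typing as t

from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.asyncio import AsyncSession

# Kept slightly under the driver's 32767 bind-parameter cap, as in the commit.
QUERY_PARAM_LIMIT = 32765


async def insert_in_batches(session: AsyncSession, table, rows: t.List[dict]) -> set:
    """Insert rows in chunks small enough to respect the parameter limit.

    Returns the key values of rows that were actually inserted; rows that
    already existed are skipped by ON CONFLICT DO NOTHING.
    """
    if not rows:
        return set()
    num_columns = len(rows[0])
    max_rows_per_insert = QUERY_PARAM_LIMIT // num_columns
    inserted_ids: set = set()
    for start in range(0, len(rows), max_rows_per_insert):
        batch = rows[start:start + max_rows_per_insert]
        statement = (postgresql.insert(table)
                     .values(batch)
                     .on_conflict_do_nothing(index_elements=table.primary_key.columns)
                     .returning(table.c["sample_id"]))  # column name assumed for illustration
        inserted_ids.update((await session.execute(statement)).scalars())
    return inserted_ids

For example, rows with 32 columns would be sent in batches of 32765 // 32 = 1023, consistent with the "1000 messages, 32 columns" figure in the original comment.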
