Refactor pulling dataset rows #617

Merged: 11 commits, Dec 19, 2024
88 changes: 50 additions & 38 deletions src/datachain/catalog/catalog.py
@@ -1,7 +1,6 @@
import io
import json
import logging
import math
import os
import os.path
import posixpath
@@ -13,7 +12,6 @@
from copy import copy
from dataclasses import dataclass
from functools import cached_property, reduce
from random import shuffle
from threading import Thread
from typing import (
IO,
@@ -58,11 +56,7 @@
from datachain.nodes_thread_pool import NodesThreadPool
from datachain.remote.studio import StudioClient
from datachain.sql.types import DateTime, SQLType
from datachain.utils import (
DataChainDir,
batched,
datachain_paths_join,
)
from datachain.utils import DataChainDir, datachain_paths_join

from .datasource import DataSource

@@ -90,7 +84,7 @@
QUERY_SCRIPT_CANCELED_EXIT_CODE = 11

# dataset pull
PULL_DATASET_MAX_THREADS = 10
PULL_DATASET_MAX_THREADS = 5
PULL_DATASET_CHUNK_TIMEOUT = 3600
PULL_DATASET_SLEEP_INTERVAL = 0.1 # sleep time while waiting for chunk to be available
PULL_DATASET_CHECK_STATUS_INTERVAL = 20 # interval to check export status in Studio
@@ -130,6 +124,7 @@
local_ds_version: int,
schema: dict[str, Union[SQLType, type[SQLType]]],
max_threads: int = PULL_DATASET_MAX_THREADS,
progress_bar=None,
):
super().__init__(max_threads)
self._check_dependencies()
@@ -142,6 +137,7 @@
self.schema = schema
self.last_status_check: Optional[float] = None
self.studio_client = StudioClient()
self.progress_bar = progress_bar

def done_task(self, done):
for task in done:
@@ -198,6 +194,20 @@
for c in [c for c, t in self.schema.items() if t == DateTime]:
df[c] = pd.to_datetime(df[c], unit="s")

# id will be autogenerated in DB
return df.drop("sys__id", axis=1)

def get_parquet_content(self, url: str):
while True:
if self.should_check_for_status():
self.check_for_status()
r = requests.get(url, timeout=PULL_DATASET_CHUNK_TIMEOUT)
if r.status_code == 404:
Contributor: Sorry if it is a silly question. How likely is it that we get stuck in an infinite loop here if the url is indeed incorrect and keeps returning a 404 response?

Contributor: Nice catch! Do we need a retry counter here?

Contributor (author): Good question, but we don't need a retry counter here. A 404 is expected, since this particular chunk may not have been exported to S3 yet (Studio exports chunks in parallel with this loop). If something actually goes wrong and a chunk can never be exported to S3, so the 404 would persist forever, the export job itself fails in Studio, and inside this loop we also check the status of the whole export job in Studio every 20 seconds. Once we see that the dataset export failed in Studio, we print an error and end the loop.

time.sleep(PULL_DATASET_SLEEP_INTERVAL)
continue

Codecov / codecov/patch warning: added lines src/datachain/catalog/catalog.py#L206-L207 were not covered by tests.
r.raise_for_status()
return r.content

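To make the termination argument from the review thread above concrete, here is a condensed, standalone sketch of the polling loop, assuming a Studio status check that runs every PULL_DATASET_CHECK_STATUS_INTERVAL seconds; the export_status callable is a stand-in for the real should_check_for_status / check_for_status logic, which is not shown in this diff:

import time

import requests

PULL_DATASET_CHUNK_TIMEOUT = 3600
PULL_DATASET_SLEEP_INTERVAL = 0.1
PULL_DATASET_CHECK_STATUS_INTERVAL = 20


def fetch_chunk(url: str, export_status) -> bytes:
    """Poll a signed chunk URL until it exists, aborting if the export job fails."""
    last_status_check = time.monotonic()
    while True:
        # Periodically ask Studio whether the export job is still alive.
        if time.monotonic() - last_status_check > PULL_DATASET_CHECK_STATUS_INTERVAL:
            last_status_check = time.monotonic()
            if export_status() == "failed":
                raise RuntimeError("Dataset export failed in Studio")
        r = requests.get(url, timeout=PULL_DATASET_CHUNK_TIMEOUT)
        if r.status_code == 404:
            # The chunk simply is not in S3 yet; wait and retry.
            time.sleep(PULL_DATASET_SLEEP_INTERVAL)
            continue
        r.raise_for_status()
        return r.content

In other words, a chunk that keeps returning 404 forever cannot hang the worker: the export job itself would be marked failed in Studio, and the periodic status check surfaces that failure and ends the loop.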
def do_task(self, urls):
import lz4.frame
import pandas as pd
@@ -207,31 +217,22 @@
local_ds = metastore.get_dataset(self.local_ds_name)

urls = list(urls)
while urls:
for url in urls:
if self.should_check_for_status():
self.check_for_status()

r = requests.get(url, timeout=PULL_DATASET_CHUNK_TIMEOUT)
if r.status_code == 404:
time.sleep(PULL_DATASET_SLEEP_INTERVAL)
# moving to the next url
continue

r.raise_for_status()

df = pd.read_parquet(io.BytesIO(lz4.frame.decompress(r.content)))

self.fix_columns(df)
for url in urls:
if self.should_check_for_status():
self.check_for_status()

# id will be autogenerated in DB
df = df.drop("sys__id", axis=1)
df = pd.read_parquet(
io.BytesIO(lz4.frame.decompress(self.get_parquet_content(url)))
)
df = self.fix_columns(df)

inserted = warehouse.insert_dataset_rows(
df, local_ds, self.local_ds_version
)
self.increase_counter(inserted) # type: ignore [arg-type]
urls.remove(url)
inserted = warehouse.insert_dataset_rows(
df, local_ds, self.local_ds_version
)
self.increase_counter(inserted) # type: ignore [arg-type]
# sometimes progress bar doesn't get updated so manually updating it
self.update_progress_bar(self.progress_bar)


@dataclass
@@ -1291,7 +1292,7 @@
for source in data_sources: # type: ignore [union-attr]
yield source, source.ls(fields)

def pull_dataset( # noqa: PLR0915
def pull_dataset( # noqa: C901, PLR0915
self,
remote_ds_uri: str,
output: Optional[str] = None,
@@ -1417,12 +1418,26 @@
signed_urls = export_response.data

if signed_urls:
shuffle(signed_urls)

with (
self.metastore.clone() as metastore,
self.warehouse.clone() as warehouse,
):

def batch(urls):
"""
Batching urls in a way that fetching is most efficient as
urls with lower id will be created first. Because that, we
are making sure all threads are pulling most recent urls
from beginning
"""
res = [[] for i in range(PULL_DATASET_MAX_THREADS)]
current_worker = 0
for url in signed_urls:
res[current_worker].append(url)
current_worker = (current_worker + 1) % PULL_DATASET_MAX_THREADS

return res

rows_fetcher = DatasetRowsFetcher(
metastore,
warehouse,
@@ -1431,14 +1446,11 @@
local_ds_name,
local_ds_version,
schema,
progress_bar=dataset_save_progress_bar,
)
try:
rows_fetcher.run(
batched(
signed_urls,
math.ceil(len(signed_urls) / PULL_DATASET_MAX_THREADS),
),
dataset_save_progress_bar,
iter(batch(signed_urls)), dataset_save_progress_bar
)
except:
self.remove_dataset(local_ds_name, local_ds_version)
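For context on the batching change above: instead of shuffling the signed URLs and handing each thread one contiguous slice (the removed shuffle/batched approach), the new batch helper deals URLs out round-robin, so every thread starts on the earliest-exported chunks. Below is a minimal standalone sketch of that distribution; the helper name, worker count, and URLs are illustrative, not part of the PR:

# Illustrative round-robin batching, mirroring the batch() helper in pull_dataset.
# The names and the example URLs are assumptions made for this sketch.
PULL_DATASET_MAX_THREADS = 5


def batch_round_robin(urls: list[str], workers: int = PULL_DATASET_MAX_THREADS) -> list[list[str]]:
    """Deal urls out to workers in round-robin order."""
    batches: list[list[str]] = [[] for _ in range(workers)]
    for i, url in enumerate(urls):
        batches[i % workers].append(url)
    return batches


if __name__ == "__main__":
    signed_urls = [f"https://example.com/chunk_{i}.parquet.lz4" for i in range(7)]
    for worker_id, worker_urls in enumerate(batch_round_robin(signed_urls, workers=3)):
        # With 7 urls and 3 workers: worker 0 gets chunks 0, 3, 6; worker 1 gets 1, 4; worker 2 gets 2, 5.
        print(worker_id, worker_urls)

Since Studio exports lower-id chunks first, this keeps all workers polling near the front of the export order rather than leaving some threads waiting on chunks that have not been written yet, which is the motivation stated in the helper's docstring.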
26 changes: 20 additions & 6 deletions src/datachain/data_storage/sqlite.py
@@ -209,10 +209,12 @@ def execute(

@retry_sqlite_locks
def executemany(
self, query, params, cursor: Optional[sqlite3.Cursor] = None
self, query, params, cursor: Optional[sqlite3.Cursor] = None, conn=None
) -> sqlite3.Cursor:
if cursor:
return cursor.executemany(self.compile(query).string, params)
if conn:
return conn.executemany(self.compile(query).string, params)
return self.db.executemany(self.compile(query).string, params)

@retry_sqlite_locks
@@ -222,7 +224,14 @@ def execute_str(self, sql: str, parameters=None) -> sqlite3.Cursor:
return self.db.execute(sql, parameters)

def insert_dataframe(self, table_name: str, df) -> int:
return df.to_sql(table_name, self.db, if_exists="append", index=False)
return df.to_sql(
table_name,
self.db,
if_exists="append",
index=False,
method="multi",
chunksize=1000,
)

def cursor(self, factory=None):
if factory is None:
@@ -545,10 +554,15 @@ def insert_rows(self, table: Table, rows: Iterable[dict[str, Any]]) -> None:
rows = list(rows)
if not rows:
return
self.db.executemany(
table.insert().values({f: bindparam(f) for f in rows[0]}),
rows,
)

with self.db.transaction() as conn:
# a single transaction speeds up inserts significantly, as no separate
# transaction is created for each inserted row
self.db.executemany(
table.insert().values({f: bindparam(f) for f in rows[0]}),
rows,
conn=conn,
)

def insert_dataset_rows(self, df, dataset: DatasetRecord, version: int) -> int:
dr = self.dataset_rows(dataset, version)
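The comment added to insert_rows above describes standard SQLite behaviour: committing every INSERT on its own is far slower than grouping a batch of inserts into one transaction. Here is a minimal self-contained sketch of the effect, using plain sqlite3 rather than DataChain's database wrapper; the table name and row data are made up for the example:

import sqlite3
import time

rows = [(i, f"row-{i}") for i in range(20_000)]


def load(rows: list[tuple[int, str]], single_transaction: bool) -> float:
    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE items (id INTEGER, name TEXT)")
    start = time.monotonic()
    if single_transaction:
        # All inserts share one transaction and are committed once.
        with con:
            con.executemany("INSERT INTO items VALUES (?, ?)", rows)
    else:
        # One commit per row: every insert pays the transaction overhead.
        for row in rows:
            con.execute("INSERT INTO items VALUES (?, ?)", row)
            con.commit()
    elapsed = time.monotonic() - start
    con.close()
    return elapsed


print("per-row commits:   ", load(rows, single_transaction=False))
print("single transaction:", load(rows, single_transaction=True))

The insert_dataframe change works toward the same goal: with method="multi" and chunksize=1000, pandas to_sql sends multi-row INSERT statements in batches of up to 1000 rows instead of issuing one statement per row.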
6 changes: 6 additions & 0 deletions tests/func/test_pull.py
@@ -190,6 +190,7 @@ def dataset_export_data_chunk(
@pytest.mark.parametrize("instantiate", [True, False])
@skip_if_not_sqlite
def test_pull_dataset_success(
mocker,
cloud_test_catalog,
remote_dataset_info,
remote_dataset_stats,
@@ -201,6 +202,11 @@ def test_pull_dataset_success(
local_ds_version,
instantiate,
):
mocker.patch(
"datachain.catalog.catalog.DatasetRowsFetcher.should_check_for_status",
return_value=True,
)

src_uri = cloud_test_catalog.src_uri
working_dir = cloud_test_catalog.working_dir
catalog = cloud_test_catalog.catalog