Refactor pulling dataset rows (#617)
ilongin authored Dec 19, 2024
1 parent b8e9856 commit 983cbd8
Showing 3 changed files with 76 additions and 44 deletions.
88 changes: 50 additions & 38 deletions src/datachain/catalog/catalog.py
@@ -1,7 +1,6 @@
import io
import json
import logging
import math
import os
import os.path
import posixpath
@@ -13,7 +12,6 @@
from copy import copy
from dataclasses import dataclass
from functools import cached_property, reduce
from random import shuffle
from threading import Thread
from typing import (
IO,
@@ -58,11 +56,7 @@
from datachain.nodes_thread_pool import NodesThreadPool
from datachain.remote.studio import StudioClient
from datachain.sql.types import DateTime, SQLType
from datachain.utils import (
DataChainDir,
batched,
datachain_paths_join,
)
from datachain.utils import DataChainDir, datachain_paths_join

from .datasource import DataSource

@@ -90,7 +84,7 @@
QUERY_SCRIPT_CANCELED_EXIT_CODE = 11

# dataset pull
PULL_DATASET_MAX_THREADS = 10
PULL_DATASET_MAX_THREADS = 5
PULL_DATASET_CHUNK_TIMEOUT = 3600
PULL_DATASET_SLEEP_INTERVAL = 0.1 # sleep time while waiting for chunk to be available
PULL_DATASET_CHECK_STATUS_INTERVAL = 20 # interval to check export status in Studio
@@ -130,6 +124,7 @@ def __init__(
local_ds_version: int,
schema: dict[str, Union[SQLType, type[SQLType]]],
max_threads: int = PULL_DATASET_MAX_THREADS,
progress_bar=None,
):
super().__init__(max_threads)
self._check_dependencies()
@@ -142,6 +137,7 @@ def __init__(
self.schema = schema
self.last_status_check: Optional[float] = None
self.studio_client = StudioClient()
self.progress_bar = progress_bar

def done_task(self, done):
for task in done:
@@ -198,6 +194,20 @@ def fix_columns(self, df) -> None:
for c in [c for c, t in self.schema.items() if t == DateTime]:
df[c] = pd.to_datetime(df[c], unit="s")

# id will be autogenerated in DB
return df.drop("sys__id", axis=1)

def get_parquet_content(self, url: str):
while True:
if self.should_check_for_status():
self.check_for_status()
r = requests.get(url, timeout=PULL_DATASET_CHUNK_TIMEOUT)
if r.status_code == 404:
time.sleep(PULL_DATASET_SLEEP_INTERVAL)
continue
r.raise_for_status()
return r.content

def do_task(self, urls):
import lz4.frame
import pandas as pd
@@ -207,31 +217,22 @@ def do_task(self, urls):
local_ds = metastore.get_dataset(self.local_ds_name)

urls = list(urls)
while urls:
for url in urls:
if self.should_check_for_status():
self.check_for_status()

r = requests.get(url, timeout=PULL_DATASET_CHUNK_TIMEOUT)
if r.status_code == 404:
time.sleep(PULL_DATASET_SLEEP_INTERVAL)
# moving to the next url
continue

r.raise_for_status()

df = pd.read_parquet(io.BytesIO(lz4.frame.decompress(r.content)))

self.fix_columns(df)
for url in urls:
if self.should_check_for_status():
self.check_for_status()

# id will be autogenerated in DB
df = df.drop("sys__id", axis=1)
df = pd.read_parquet(
io.BytesIO(lz4.frame.decompress(self.get_parquet_content(url)))
)
df = self.fix_columns(df)

inserted = warehouse.insert_dataset_rows(
df, local_ds, self.local_ds_version
)
self.increase_counter(inserted) # type: ignore [arg-type]
urls.remove(url)
inserted = warehouse.insert_dataset_rows(
df, local_ds, self.local_ds_version
)
self.increase_counter(inserted) # type: ignore [arg-type]
# sometimes the progress bar doesn't get updated on its own, so update it manually
self.update_progress_bar(self.progress_bar)


@dataclass
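
Read together, the new do_task delegates HTTP polling to get_parquet_content and dataframe cleanup to fix_columns. The per-URL flow can be condensed into a small standalone sketch; the fetch_chunk name is illustrative, and the default timeout/sleep values stand in for the PULL_DATASET_* constants above:

```python
import io
import time

import lz4.frame
import pandas as pd
import requests


def fetch_chunk(chunk_url: str, timeout: int = 3600, sleep: float = 0.1) -> pd.DataFrame:
    """Poll a signed URL until the chunk is exported, then decode it."""
    while True:
        r = requests.get(chunk_url, timeout=timeout)
        if r.status_code == 404:  # chunk not exported by Studio yet
            time.sleep(sleep)
            continue
        r.raise_for_status()
        break
    # each chunk is an lz4-compressed parquet file
    df = pd.read_parquet(io.BytesIO(lz4.frame.decompress(r.content)))
    # drop sys__id: ids are autogenerated when the rows are inserted locally
    return df.drop("sys__id", axis=1)
```
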
@@ -1291,7 +1292,7 @@ def ls(
for source in data_sources: # type: ignore [union-attr]
yield source, source.ls(fields)

def pull_dataset( # noqa: PLR0915
def pull_dataset( # noqa: C901, PLR0915
self,
remote_ds_uri: str,
output: Optional[str] = None,
@@ -1417,12 +1418,26 @@ def _instantiate(ds_uri: str) -> None:
signed_urls = export_response.data

if signed_urls:
shuffle(signed_urls)

with (
self.metastore.clone() as metastore,
self.warehouse.clone() as warehouse,
):

def batch(urls):
"""
Batching urls in a way that fetching is most efficient as
urls with lower id will be created first. Because that, we
are making sure all threads are pulling most recent urls
from beginning
"""
res = [[] for i in range(PULL_DATASET_MAX_THREADS)]
current_worker = 0
for url in signed_urls:
res[current_worker].append(url)
current_worker = (current_worker + 1) % PULL_DATASET_MAX_THREADS

return res

rows_fetcher = DatasetRowsFetcher(
metastore,
warehouse,
Expand All @@ -1431,14 +1446,11 @@ def _instantiate(ds_uri: str) -> None:
local_ds_name,
local_ds_version,
schema,
progress_bar=dataset_save_progress_bar,
)
try:
rows_fetcher.run(
batched(
signed_urls,
math.ceil(len(signed_urls) / PULL_DATASET_MAX_THREADS),
),
dataset_save_progress_bar,
iter(batch(signed_urls)), dataset_save_progress_bar
)
except:
self.remove_dataset(local_ds_name, local_ds_version)
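
For reference, the round-robin distribution that the new batch helper performs can be sketched on its own; the round_robin name and the url strings below are placeholders for illustration:

```python
def round_robin(urls: list[str], workers: int = 5) -> list[list[str]]:
    # urls with lower ids are exported first, so dealing them out one by one
    # gives every worker a chunk that is (or will soon be) available
    batches: list[list[str]] = [[] for _ in range(workers)]
    for i, url in enumerate(urls):
        batches[i % workers].append(url)
    return batches


# 5 workers, matching PULL_DATASET_MAX_THREADS above
print(round_robin([f"url-{i}" for i in range(12)]))
```

Compared with the previous contiguous math.ceil-sized batches, this keeps every thread near the front of the export queue instead of parking some threads on chunks that will only be exported last.
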
26 changes: 20 additions & 6 deletions src/datachain/data_storage/sqlite.py
@@ -209,10 +209,12 @@ def execute(

@retry_sqlite_locks
def executemany(
self, query, params, cursor: Optional[sqlite3.Cursor] = None
self, query, params, cursor: Optional[sqlite3.Cursor] = None, conn=None
) -> sqlite3.Cursor:
if cursor:
return cursor.executemany(self.compile(query).string, params)
if conn:
return conn.executemany(self.compile(query).string, params)
return self.db.executemany(self.compile(query).string, params)

@retry_sqlite_locks
@@ -222,7 +224,14 @@ def execute_str(self, sql: str, parameters=None) -> sqlite3.Cursor:
return self.db.execute(sql, parameters)

def insert_dataframe(self, table_name: str, df) -> int:
return df.to_sql(table_name, self.db, if_exists="append", index=False)
return df.to_sql(
table_name,
self.db,
if_exists="append",
index=False,
method="multi",
chunksize=1000,
)

def cursor(self, factory=None):
if factory is None:
@@ -545,10 +554,15 @@ def insert_rows(self, table: Table, rows: Iterable[dict[str, Any]]) -> None:
rows = list(rows)
if not rows:
return
self.db.executemany(
table.insert().values({f: bindparam(f) for f in rows[0]}),
rows,
)

with self.db.transaction() as conn:
# a single transaction speeds up inserts significantly, as no separate
# transaction is created for each inserted row
self.db.executemany(
table.insert().values({f: bindparam(f) for f in rows[0]}),
rows,
conn=conn,
)

def insert_dataset_rows(self, df, dataset: DatasetRecord, version: int) -> int:
dr = self.dataset_rows(dataset, version)
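
The reasoning behind the insert_rows change can be illustrated with the standard-library sqlite3 module; this is only a sketch, and the table and data below are invented for the example:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE rows (name TEXT, size INTEGER)")

data = [(f"file_{i}", i) for i in range(10_000)]

# grouping all inserts into one explicit transaction commits once,
# instead of paying a per-statement commit in autocommit mode
with conn:  # the connection context manager wraps a single transaction
    conn.executemany("INSERT INTO rows (name, size) VALUES (?, ?)", data)

print(conn.execute("SELECT count(*) FROM rows").fetchone()[0])  # 10000
```

The insert_dataframe change is in the same spirit: method="multi" makes pandas emit multi-row INSERT statements, and chunksize=1000 caps how many rows go into each batch.
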
6 changes: 6 additions & 0 deletions tests/func/test_pull.py
@@ -190,6 +190,7 @@ def dataset_export_data_chunk(
@pytest.mark.parametrize("instantiate", [True, False])
@skip_if_not_sqlite
def test_pull_dataset_success(
mocker,
cloud_test_catalog,
remote_dataset_info,
remote_dataset_stats,
@@ -201,6 +202,11 @@ def test_pull_dataset_success(
local_ds_version,
instantiate,
):
mocker.patch(
"datachain.catalog.catalog.DatasetRowsFetcher.should_check_for_status",
return_value=True,
)

src_uri = cloud_test_catalog.src_uri
working_dir = cloud_test_catalog.working_dir
catalog = cloud_test_catalog.catalog
