Commit

Fixing too many files open, and adding reconnect (#229)
dtulga authored Aug 6, 2024
1 parent 1e5178b commit 8701ba5
Showing 21 changed files with 616 additions and 391 deletions.
91 changes: 47 additions & 44 deletions src/datachain/catalog/catalog.py
@@ -236,36 +236,36 @@ def do_task(self, urls):
         import lz4.frame
         import pandas as pd

-        metastore = self.metastore.clone()  # metastore is not thread safe
-        warehouse = self.warehouse.clone()  # warehouse is not thread safe
-        dataset = metastore.get_dataset(self.dataset_name)
-
-        urls = list(urls)
-        while urls:
-            for url in urls:
-                if self.should_check_for_status():
-                    self.check_for_status()
-
-                r = requests.get(url, timeout=PULL_DATASET_CHUNK_TIMEOUT)
-                if r.status_code == 404:
-                    time.sleep(PULL_DATASET_SLEEP_INTERVAL)
-                    # moving to the next url
-                    continue
-
-                r.raise_for_status()
-
-                df = pd.read_parquet(io.BytesIO(lz4.frame.decompress(r.content)))
-
-                self.fix_columns(df)
-
-                # id will be autogenerated in DB
-                df = df.drop("sys__id", axis=1)
-
-                inserted = warehouse.insert_dataset_rows(
-                    df, dataset, self.dataset_version
-                )
-                self.increase_counter(inserted)  # type: ignore [arg-type]
-                urls.remove(url)
+        # metastore and warehouse are not thread safe
+        with self.metastore.clone() as metastore, self.warehouse.clone() as warehouse:
+            dataset = metastore.get_dataset(self.dataset_name)
+
+            urls = list(urls)
+            while urls:
+                for url in urls:
+                    if self.should_check_for_status():
+                        self.check_for_status()
+
+                    r = requests.get(url, timeout=PULL_DATASET_CHUNK_TIMEOUT)
+                    if r.status_code == 404:
+                        time.sleep(PULL_DATASET_SLEEP_INTERVAL)
+                        # moving to the next url
+                        continue
+
+                    r.raise_for_status()
+
+                    df = pd.read_parquet(io.BytesIO(lz4.frame.decompress(r.content)))
+
+                    self.fix_columns(df)
+
+                    # id will be autogenerated in DB
+                    df = df.drop("sys__id", axis=1)
+
+                    inserted = warehouse.insert_dataset_rows(
+                        df, dataset, self.dataset_version
+                    )
+                    self.increase_counter(inserted)  # type: ignore [arg-type]
+                    urls.remove(url)


 @dataclass
@@ -720,7 +720,6 @@ def enlist_source(
                 client.uri, posixpath.join(prefix, "")
             )
             source_metastore = self.metastore.clone(client.uri)
-        source_warehouse = self.warehouse.clone()

         columns = [
             Column("vtype", String),
@@ -1835,25 +1834,29 @@ def _instantiate_dataset():
         if signed_urls:
             shuffle(signed_urls)

-        rows_fetcher = DatasetRowsFetcher(
-            self.metastore.clone(),
-            self.warehouse.clone(),
-            remote_config,
-            dataset.name,
-            version,
-            schema,
-        )
-        try:
-            rows_fetcher.run(
-                batched(
-                    signed_urls,
-                    math.ceil(len(signed_urls) / PULL_DATASET_MAX_THREADS),
-                ),
-                dataset_save_progress_bar,
-            )
-        except:
-            self.remove_dataset(dataset.name, version)
-            raise
+        with (
+            self.metastore.clone() as metastore,
+            self.warehouse.clone() as warehouse,
+        ):
+            rows_fetcher = DatasetRowsFetcher(
+                metastore,
+                warehouse,
+                remote_config,
+                dataset.name,
+                version,
+                schema,
+            )
+            try:
+                rows_fetcher.run(
+                    batched(
+                        signed_urls,
+                        math.ceil(len(signed_urls) / PULL_DATASET_MAX_THREADS),
+                    ),
+                    dataset_save_progress_bar,
+                )
+            except:
+                self.remove_dataset(dataset.name, version)
+                raise

         dataset = self.metastore.update_dataset_status(
             dataset,
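The pattern behind this fix: every fetcher task clones the metastore and warehouse (neither is thread safe), and the clones are now context-managed so each task's connections are released as soon as it finishes, instead of accumulating until the process runs out of file descriptors. A minimal runnable sketch of that pattern, using an illustrative Resource stand-in rather than the real datachain classes:

import threading


class Resource:
    """Stand-in for a metastore/warehouse; each clone owns one connection."""

    open_count = 0  # number of connections currently open

    def clone(self) -> "Resource":
        Resource.open_count += 1
        return Resource()

    def __enter__(self) -> "Resource":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()  # runs on success and on exceptions alike

    def close(self) -> None:
        Resource.open_count -= 1


def worker(shared: Resource) -> None:
    # Clone per thread; the `with` guarantees the clone is closed on exit.
    with shared.clone() as local:
        pass  # ... fetch and insert rows using `local` ...


shared = Resource()
threads = [threading.Thread(target=worker, args=(shared,)) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(Resource.open_count)  # 0 -- no leaked connections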
8 changes: 6 additions & 2 deletions src/datachain/data_storage/db_engine.py
@@ -4,7 +4,6 @@
 from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union

 import sqlalchemy as sa
-from attrs import frozen
 from sqlalchemy.sql import FROM_LINTING
 from sqlalchemy.sql.roles import DDLRole

@@ -23,13 +22,18 @@
 SELECT_BATCH_SIZE = 100_000  # number of rows to fetch at a time


-@frozen
 class DatabaseEngine(ABC, Serializable):
     dialect: ClassVar["Dialect"]

     engine: "Engine"
     metadata: "MetaData"

+    def __enter__(self) -> "DatabaseEngine":
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        self.close()
+
     @abstractmethod
     def clone(self) -> "DatabaseEngine":
         """Clones DatabaseEngine implementation."""
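Why @frozen has to go here and in sqlite.py: an attrs-frozen class raises on any attribute assignment, while the reconnect logic added below must rebind engine, metadata, and db in place. A short demonstration of that constraint (assumes the attrs package, which these files already import):

from attrs import frozen


@frozen
class Conn:
    handle: int


c = Conn(handle=1)
try:
    c.handle = 2  # the kind of in-place rebind _reconnect() performs
except AttributeError as err:
    # attrs raises FrozenInstanceError, an AttributeError subclass
    print(f"frozen forbids mutation: {err!r}")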
14 changes: 14 additions & 0 deletions src/datachain/data_storage/id_generator.py
@@ -33,6 +33,16 @@ def init(self) -> None:
     def cleanup_for_tests(self):
         """Cleanup for tests."""

+    def close(self) -> None:
+        """Closes any active database connections."""
+
+    def close_on_exit(self) -> None:
+        """Closes any active database or HTTP connections, called on Session exit or
+        for test cleanup only, as some ID Generator implementations may handle this
+        differently.
+        """
+        self.close()
+
     @abstractmethod
     def init_id(self, uri: str) -> None:
         """Initializes the ID generator for the given URI with zero last_id."""
@@ -83,6 +93,10 @@ def __init__(
     def clone(self) -> "AbstractDBIDGenerator":
         """Clones AbstractIDGenerator implementation."""

+    def close(self) -> None:
+        """Closes any active database connections."""
+        self.db.close()
+
     @property
     def db(self) -> "DatabaseEngine":
         return self._db
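The shape of these additions is a delegation chain: the DB-backed generator's close() closes the DatabaseEngine it wraps, and the base close_on_exit() simply calls close(), so subclasses with shared connections can override either hook independently. A hedged sketch with stand-in classes:

class DBEngine:
    def close(self) -> None:
        print("engine connection closed")


class IDGenerator:
    def __init__(self, db: DBEngine) -> None:
        self._db = db

    @property
    def db(self) -> DBEngine:
        return self._db

    def close(self) -> None:
        self.db.close()  # one call releases the wrapped engine too

    def close_on_exit(self) -> None:
        self.close()  # default; shared-connection subclasses may no-op this


IDGenerator(DBEngine()).close()  # prints "engine connection closed"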
13 changes: 13 additions & 0 deletions src/datachain/data_storage/metastore.py
@@ -78,6 +78,13 @@ def __init__(
         self.uri = uri
         self.partial_id: Optional[int] = partial_id

+    def __enter__(self) -> "AbstractMetastore":
+        """Returns self upon entering context manager."""
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        """Default behavior is to do nothing, as connections may be shared."""
+
     @abstractmethod
     def clone(
         self,
@@ -97,6 +104,12 @@ def init(self, uri: StorageURI) -> None:
     def close(self) -> None:
         """Closes any active database or HTTP connections."""

+    def close_on_exit(self) -> None:
+        """Closes any active database or HTTP connections, called on Session exit or
+        for test cleanup only, as some Metastore implementations may handle this
+        differently."""
+        self.close()
+
     def cleanup_tables(self, temp_table_names: list[str]) -> None:
         """Cleanup temp tables."""
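Note the deliberate asymmetry: __enter__ returns the instance, but the base __exit__ does nothing, because a metastore may share its connection with other objects. Only implementations that own a private connection (the SQLite ones later in this commit) override __exit__ to close it. Stand-in classes to show the two behaviors:

class SharedConnStore:
    """Base-class behavior: exiting the context does not close anything."""

    def __enter__(self) -> "SharedConnStore":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        pass  # connection may be shared; leave it open

    def close(self) -> None:
        print(f"{type(self).__name__}: connection closed")


class OwningStore(SharedConnStore):
    """SQLite-style behavior: the instance owns its connection."""

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()  # safe to close -- nobody else uses it


with SharedConnStore():
    pass  # exiting prints nothing
with OwningStore():
    pass  # exiting prints "OwningStore: connection closed"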
51 changes: 45 additions & 6 deletions src/datachain/data_storage/sqlite.py
@@ -15,7 +15,6 @@
 )

 import sqlalchemy
-from attrs import frozen
 from sqlalchemy import MetaData, Table, UniqueConstraint, exists, select
 from sqlalchemy.dialects import sqlite
 from sqlalchemy.schema import CreateIndex, CreateTable, DropTable
@@ -40,6 +39,7 @@

 if TYPE_CHECKING:
     from sqlalchemy.dialects.sqlite import Insert
+    from sqlalchemy.engine.base import Engine
     from sqlalchemy.schema import SchemaItem
     from sqlalchemy.sql.elements import ColumnClause, ColumnElement, TextClause
     from sqlalchemy.sql.selectable import Select
@@
 RETRY_MAX_TIMES = 10
 RETRY_FACTOR = 2

+DETECT_TYPES = sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
+
 Column = Union[str, "ColumnClause[Any]", "TextClause"]

 datachain.sql.sqlite.setup()
@@ -80,26 +82,41 @@ def wrapper(*args, **kwargs):
     return wrapper


-@frozen
 class SQLiteDatabaseEngine(DatabaseEngine):
     dialect = sqlite_dialect

     db: sqlite3.Connection
     db_file: Optional[str]
+    is_closed: bool
+
+    def __init__(
+        self,
+        engine: "Engine",
+        metadata: "MetaData",
+        db: sqlite3.Connection,
+        db_file: Optional[str] = None,
+    ):
+        self.engine = engine
+        self.metadata = metadata
+        self.db = db
+        self.db_file = db_file
+        self.is_closed = False

     @classmethod
     def from_db_file(cls, db_file: Optional[str] = None) -> "SQLiteDatabaseEngine":
-        detect_types = sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
+        return cls(*cls._connect(db_file=db_file))
+
+    @staticmethod
+    def _connect(db_file: Optional[str] = None):
         try:
             if db_file == ":memory:":
                 # Enable multithreaded usage of the same in-memory db
                 db = sqlite3.connect(
-                    "file::memory:?cache=shared", uri=True, detect_types=detect_types
+                    "file::memory:?cache=shared", uri=True, detect_types=DETECT_TYPES
                 )
             else:
                 db = sqlite3.connect(
-                    db_file or DataChainDir.find().db, detect_types=detect_types
+                    db_file or DataChainDir.find().db, detect_types=DETECT_TYPES
                 )
             create_user_defined_sql_functions(db)
             engine = sqlalchemy.create_engine(
@@ -118,7 +135,7 @@ def from_db_file(cls, db_file: Optional[str] = None) -> "SQLiteDatabaseEngine":

             load_usearch_extension(db)

-            return cls(engine, MetaData(), db, db_file)
+            return engine, MetaData(), db, db_file
         except RuntimeError:
             raise DataChainError("Can't connect to SQLite DB") from None

@@ -138,13 +155,26 @@ def clone_params(self) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
             {},
         )

+    def _reconnect(self) -> None:
+        if not self.is_closed:
+            raise RuntimeError("Cannot reconnect on still-open DB!")
+        engine, metadata, db, db_file = self._connect(db_file=self.db_file)
+        self.engine = engine
+        self.metadata = metadata
+        self.db = db
+        self.db_file = db_file
+        self.is_closed = False
+
     @retry_sqlite_locks
     def execute(
         self,
         query,
         cursor: Optional[sqlite3.Cursor] = None,
         conn=None,
     ) -> sqlite3.Cursor:
+        if self.is_closed:
+            # Reconnect in case of being closed previously.
+            self._reconnect()
         if cursor is not None:
             result = cursor.execute(*self.compile_to_args(query))
         elif conn is not None:
@@ -179,6 +209,7 @@ def cursor(self, factory=None):

     def close(self) -> None:
         self.db.close()
+        self.is_closed = True

     @contextmanager
     def transaction(self):
@@ -359,6 +390,10 @@ def __init__(

         self._init_tables()

+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        """Close connection upon exit from context manager."""
+        self.close()
+
     def clone(
         self,
         uri: StorageURI = StorageURI(""),
@@ -521,6 +556,10 @@ def __init__(

         self.db = db or SQLiteDatabaseEngine.from_db_file(db_file)

+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        """Close connection upon exit from context manager."""
+        self.close()
+
     def clone(self, use_new_connection: bool = False) -> "SQLiteWarehouse":
         return SQLiteWarehouse(self.id_generator.clone(), db=self.db.clone())
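This file carries the "adding reconnect" half of the commit title: close() now records that the engine is closed, and execute() reopens the connection via _reconnect() instead of failing on a dead handle. A self-contained sketch of that behavior against a plain sqlite3 connection (a simplified stand-in, not the real SQLiteDatabaseEngine):

import sqlite3


class Engine:
    def __init__(self, db_file: str) -> None:
        self.db_file = db_file
        self.db = sqlite3.connect(db_file)
        self.is_closed = False

    def close(self) -> None:
        self.db.close()
        self.is_closed = True  # remember, so execute() can recover

    def _reconnect(self) -> None:
        if not self.is_closed:
            raise RuntimeError("Cannot reconnect on still-open DB!")
        self.db = sqlite3.connect(self.db_file)
        self.is_closed = False

    def execute(self, sql: str) -> list:
        if self.is_closed:
            self._reconnect()  # reopen instead of raising ProgrammingError
        return self.db.execute(sql).fetchall()


eng = Engine(":memory:")
eng.close()
# Previously this would fail on the closed handle; now it reconnects.
# (For :memory: the reconnected DB starts empty; file-backed DBs keep data.)
print(eng.execute("SELECT 1"))  # [(1,)]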
13 changes: 13 additions & 0 deletions src/datachain/data_storage/warehouse.py
@@ -70,6 +70,13 @@ class AbstractWarehouse(ABC, Serializable):
     def __init__(self, id_generator: "AbstractIDGenerator"):
         self.id_generator = id_generator

+    def __enter__(self) -> "AbstractWarehouse":
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        # Default behavior is to do nothing, as connections may be shared.
+        pass
+
     def cleanup_for_tests(self):
         """Cleanup for tests."""

@@ -158,6 +165,12 @@ def close(self) -> None:
         """Closes any active database connections."""
         self.db.close()

+    def close_on_exit(self) -> None:
+        """Closes any active database or HTTP connections, called on Session exit or
+        for test cleanup only, as some Warehouse implementations may handle this
+        differently."""
+        self.close()
+
     #
     # Query Tables
     #
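How the pieces are meant to compose: close_on_exit() is the hook a session or test-teardown path calls when it is done with the storage objects it owns. The Session class below is a hypothetical stand-in for that caller (the real Session changes are in one of the files not shown), illustrating the flow:

class Warehouse:
    def close(self) -> None:
        print("warehouse connection closed")

    def close_on_exit(self) -> None:
        self.close()  # default: on-exit cleanup is a plain close


class Session:
    """Hypothetical owner that tears down its storage objects on exit."""

    def __init__(self, warehouse: Warehouse) -> None:
        self.warehouse = warehouse

    def __enter__(self) -> "Session":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.warehouse.close_on_exit()


with Session(Warehouse()):
    pass  # prints "warehouse connection closed" on exit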
(Diffs for the remaining 15 changed files are not shown.)