diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index 89951d446..5543353df 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -236,36 +236,36 @@ def do_task(self, urls): import lz4.frame import pandas as pd - metastore = self.metastore.clone() # metastore is not thread safe - warehouse = self.warehouse.clone() # warehouse is not thread safe - dataset = metastore.get_dataset(self.dataset_name) - - urls = list(urls) - while urls: - for url in urls: - if self.should_check_for_status(): - self.check_for_status() - - r = requests.get(url, timeout=PULL_DATASET_CHUNK_TIMEOUT) - if r.status_code == 404: - time.sleep(PULL_DATASET_SLEEP_INTERVAL) - # moving to the next url - continue + # metastore and warehouse are not thread safe + with self.metastore.clone() as metastore, self.warehouse.clone() as warehouse: + dataset = metastore.get_dataset(self.dataset_name) - r.raise_for_status() + urls = list(urls) + while urls: + for url in urls: + if self.should_check_for_status(): + self.check_for_status() - df = pd.read_parquet(io.BytesIO(lz4.frame.decompress(r.content))) + r = requests.get(url, timeout=PULL_DATASET_CHUNK_TIMEOUT) + if r.status_code == 404: + time.sleep(PULL_DATASET_SLEEP_INTERVAL) + # moving to the next url + continue - self.fix_columns(df) + r.raise_for_status() - # id will be autogenerated in DB - df = df.drop("sys__id", axis=1) + df = pd.read_parquet(io.BytesIO(lz4.frame.decompress(r.content))) - inserted = warehouse.insert_dataset_rows( - df, dataset, self.dataset_version - ) - self.increase_counter(inserted) # type: ignore [arg-type] - urls.remove(url) + self.fix_columns(df) + + # id will be autogenerated in DB + df = df.drop("sys__id", axis=1) + + inserted = warehouse.insert_dataset_rows( + df, dataset, self.dataset_version + ) + self.increase_counter(inserted) # type: ignore [arg-type] + urls.remove(url) @dataclass @@ -720,7 +720,6 @@ def enlist_source( client.uri, posixpath.join(prefix, "") ) source_metastore = self.metastore.clone(client.uri) - source_warehouse = self.warehouse.clone() columns = [ Column("vtype", String), @@ -1835,25 +1834,29 @@ def _instantiate_dataset(): if signed_urls: shuffle(signed_urls) - rows_fetcher = DatasetRowsFetcher( - self.metastore.clone(), - self.warehouse.clone(), - remote_config, - dataset.name, - version, - schema, - ) - try: - rows_fetcher.run( - batched( - signed_urls, - math.ceil(len(signed_urls) / PULL_DATASET_MAX_THREADS), - ), - dataset_save_progress_bar, + with ( + self.metastore.clone() as metastore, + self.warehouse.clone() as warehouse, + ): + rows_fetcher = DatasetRowsFetcher( + metastore, + warehouse, + remote_config, + dataset.name, + version, + schema, ) - except: - self.remove_dataset(dataset.name, version) - raise + try: + rows_fetcher.run( + batched( + signed_urls, + math.ceil(len(signed_urls) / PULL_DATASET_MAX_THREADS), + ), + dataset_save_progress_bar, + ) + except: + self.remove_dataset(dataset.name, version) + raise dataset = self.metastore.update_dataset_status( dataset, diff --git a/src/datachain/data_storage/db_engine.py b/src/datachain/data_storage/db_engine.py index e21d0f0cd..508ae6ced 100644 --- a/src/datachain/data_storage/db_engine.py +++ b/src/datachain/data_storage/db_engine.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union import sqlalchemy as sa -from attrs import frozen from sqlalchemy.sql import FROM_LINTING from sqlalchemy.sql.roles import DDLRole @@ -23,13 +22,18 @@ SELECT_BATCH_SIZE = 100_000 # number of 
rows to fetch at a time -@frozen class DatabaseEngine(ABC, Serializable): dialect: ClassVar["Dialect"] engine: "Engine" metadata: "MetaData" + def __enter__(self) -> "DatabaseEngine": + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.close() + @abstractmethod def clone(self) -> "DatabaseEngine": """Clones DatabaseEngine implementation.""" diff --git a/src/datachain/data_storage/id_generator.py b/src/datachain/data_storage/id_generator.py index b311cb5d2..583f99206 100644 --- a/src/datachain/data_storage/id_generator.py +++ b/src/datachain/data_storage/id_generator.py @@ -33,6 +33,16 @@ def init(self) -> None: def cleanup_for_tests(self): """Cleanup for tests.""" + def close(self) -> None: + """Closes any active database connections.""" + + def close_on_exit(self) -> None: + """Closes any active database or HTTP connections, called on Session exit or + for test cleanup only, as some ID Generator implementations may handle this + differently. + """ + self.close() + @abstractmethod def init_id(self, uri: str) -> None: """Initializes the ID generator for the given URI with zero last_id.""" @@ -83,6 +93,10 @@ def __init__( def clone(self) -> "AbstractDBIDGenerator": """Clones AbstractIDGenerator implementation.""" + def close(self) -> None: + """Closes any active database connections.""" + self.db.close() + @property def db(self) -> "DatabaseEngine": return self._db diff --git a/src/datachain/data_storage/metastore.py b/src/datachain/data_storage/metastore.py index 85c1fd90f..4df2a85b3 100644 --- a/src/datachain/data_storage/metastore.py +++ b/src/datachain/data_storage/metastore.py @@ -78,6 +78,13 @@ def __init__( self.uri = uri self.partial_id: Optional[int] = partial_id + def __enter__(self) -> "AbstractMetastore": + """Returns self upon entering context manager.""" + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + """Default behavior is to do nothing, as connections may be shared.""" + @abstractmethod def clone( self, @@ -97,6 +104,12 @@ def init(self, uri: StorageURI) -> None: def close(self) -> None: """Closes any active database or HTTP connections.""" + def close_on_exit(self) -> None: + """Closes any active database or HTTP connections, called on Session exit or + for test cleanup only, as some Metastore implementations may handle this + differently.""" + self.close() + def cleanup_tables(self, temp_table_names: list[str]) -> None: """Cleanup temp tables.""" diff --git a/src/datachain/data_storage/sqlite.py b/src/datachain/data_storage/sqlite.py index a280552d0..18f48ddd0 100644 --- a/src/datachain/data_storage/sqlite.py +++ b/src/datachain/data_storage/sqlite.py @@ -15,7 +15,6 @@ ) import sqlalchemy -from attrs import frozen from sqlalchemy import MetaData, Table, UniqueConstraint, exists, select from sqlalchemy.dialects import sqlite from sqlalchemy.schema import CreateIndex, CreateTable, DropTable @@ -40,6 +39,7 @@ if TYPE_CHECKING: from sqlalchemy.dialects.sqlite import Insert + from sqlalchemy.engine.base import Engine from sqlalchemy.schema import SchemaItem from sqlalchemy.sql.elements import ColumnClause, ColumnElement, TextClause from sqlalchemy.sql.selectable import Select @@ -52,6 +52,8 @@ RETRY_MAX_TIMES = 10 RETRY_FACTOR = 2 +DETECT_TYPES = sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES + Column = Union[str, "ColumnClause[Any]", "TextClause"] datachain.sql.sqlite.setup() @@ -80,26 +82,41 @@ def wrapper(*args, **kwargs): return wrapper -@frozen class SQLiteDatabaseEngine(DatabaseEngine): dialect = 
sqlite_dialect db: sqlite3.Connection db_file: Optional[str] + is_closed: bool + + def __init__( + self, + engine: "Engine", + metadata: "MetaData", + db: sqlite3.Connection, + db_file: Optional[str] = None, + ): + self.engine = engine + self.metadata = metadata + self.db = db + self.db_file = db_file + self.is_closed = False @classmethod def from_db_file(cls, db_file: Optional[str] = None) -> "SQLiteDatabaseEngine": - detect_types = sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES + return cls(*cls._connect(db_file=db_file)) + @staticmethod + def _connect(db_file: Optional[str] = None): try: if db_file == ":memory:": # Enable multithreaded usage of the same in-memory db db = sqlite3.connect( - "file::memory:?cache=shared", uri=True, detect_types=detect_types + "file::memory:?cache=shared", uri=True, detect_types=DETECT_TYPES ) else: db = sqlite3.connect( - db_file or DataChainDir.find().db, detect_types=detect_types + db_file or DataChainDir.find().db, detect_types=DETECT_TYPES ) create_user_defined_sql_functions(db) engine = sqlalchemy.create_engine( @@ -118,7 +135,7 @@ def from_db_file(cls, db_file: Optional[str] = None) -> "SQLiteDatabaseEngine": load_usearch_extension(db) - return cls(engine, MetaData(), db, db_file) + return engine, MetaData(), db, db_file except RuntimeError: raise DataChainError("Can't connect to SQLite DB") from None @@ -138,6 +155,16 @@ def clone_params(self) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]: {}, ) + def _reconnect(self) -> None: + if not self.is_closed: + raise RuntimeError("Cannot reconnect on still-open DB!") + engine, metadata, db, db_file = self._connect(db_file=self.db_file) + self.engine = engine + self.metadata = metadata + self.db = db + self.db_file = db_file + self.is_closed = False + @retry_sqlite_locks def execute( self, @@ -145,6 +172,9 @@ def execute( cursor: Optional[sqlite3.Cursor] = None, conn=None, ) -> sqlite3.Cursor: + if self.is_closed: + # Reconnect in case of being closed previously. + self._reconnect() if cursor is not None: result = cursor.execute(*self.compile_to_args(query)) elif conn is not None: @@ -179,6 +209,7 @@ def cursor(self, factory=None): def close(self) -> None: self.db.close() + self.is_closed = True @contextmanager def transaction(self): @@ -359,6 +390,10 @@ def __init__( self._init_tables() + def __exit__(self, exc_type, exc_value, traceback) -> None: + """Close connection upon exit from context manager.""" + self.close() + def clone( self, uri: StorageURI = StorageURI(""), @@ -521,6 +556,10 @@ def __init__( self.db = db or SQLiteDatabaseEngine.from_db_file(db_file) + def __exit__(self, exc_type, exc_value, traceback) -> None: + """Close connection upon exit from context manager.""" + self.close() + def clone(self, use_new_connection: bool = False) -> "SQLiteWarehouse": return SQLiteWarehouse(self.id_generator.clone(), db=self.db.clone()) diff --git a/src/datachain/data_storage/warehouse.py b/src/datachain/data_storage/warehouse.py index 5fd9dbe36..aac04781f 100644 --- a/src/datachain/data_storage/warehouse.py +++ b/src/datachain/data_storage/warehouse.py @@ -70,6 +70,13 @@ class AbstractWarehouse(ABC, Serializable): def __init__(self, id_generator: "AbstractIDGenerator"): self.id_generator = id_generator + def __enter__(self) -> "AbstractWarehouse": + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + # Default behavior is to do nothing, as connections may be shared. 
+ pass + def cleanup_for_tests(self): """Cleanup for tests.""" @@ -158,6 +165,12 @@ def close(self) -> None: """Closes any active database connections.""" self.db.close() + def close_on_exit(self) -> None: + """Closes any active database or HTTP connections, called on Session exit or + for test cleanup only, as some Warehouse implementations may handle this + differently.""" + self.close() + # # Query Tables # diff --git a/src/datachain/listing.py b/src/datachain/listing.py index 1fe62bbb2..c4658ce45 100644 --- a/src/datachain/listing.py +++ b/src/datachain/listing.py @@ -44,6 +44,16 @@ def clone(self) -> "Listing": self.dataset, ) + def __enter__(self) -> "Listing": + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.close() + + def close(self) -> None: + self.metastore.close() + self.warehouse.close() + @property def id(self): return self.storage.id @@ -56,16 +66,18 @@ def fetch(self, start_prefix="", method: str = "default") -> None: sync(get_loop(), self._fetch, start_prefix, method) async def _fetch(self, start_prefix: str, method: str) -> None: - self = self.clone() - if start_prefix: - start_prefix = start_prefix.rstrip("/") - try: - async for entries in self.client.scandir(start_prefix, method=method): - self.insert_entries(entries) - if len(entries) > 1: - self.metastore.update_last_inserted_at() - finally: - self.insert_entries_done() + with self.clone() as fetch_listing: + if start_prefix: + start_prefix = start_prefix.rstrip("/") + try: + async for entries in fetch_listing.client.scandir( + start_prefix, method=method + ): + fetch_listing.insert_entries(entries) + if len(entries) > 1: + fetch_listing.metastore.update_last_inserted_at() + finally: + fetch_listing.insert_entries_done() def insert_entry(self, entry: Entry) -> None: self.warehouse.insert_rows( diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index 5a932a773..8ebebfe26 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -1051,8 +1051,11 @@ def __init__( if anon: client_config["anon"] = True + self.session = Session.get( + session, catalog=catalog, client_config=client_config + ) + self.catalog = catalog or self.session.catalog self.steps: list[Step] = [] - self.catalog = catalog or get_catalog(client_config=client_config) self._chunk_index: Optional[int] = None self._chunk_total: Optional[int] = None self.temp_table_names: list[str] = [] @@ -1063,7 +1066,6 @@ def __init__( self.version: Optional[int] = None self.feature_schema: Optional[dict] = None self.column_types: Optional[dict[str, Any]] = None - self.session = Session.get(session, catalog=catalog) if path: kwargs = {"update": True} if update else {} @@ -1200,12 +1202,10 @@ def cleanup(self) -> None: # This is needed to always use a new connection with all metastore and warehouse # implementations, as errors may close or render unusable the existing # connections. 
- metastore = self.catalog.metastore.clone(use_new_connection=True) - metastore.cleanup_tables(self.temp_table_names) - metastore.close() - warehouse = self.catalog.warehouse.clone(use_new_connection=True) - warehouse.cleanup_tables(self.temp_table_names) - warehouse.close() + with self.catalog.metastore.clone(use_new_connection=True) as metastore: + metastore.cleanup_tables(self.temp_table_names) + with self.catalog.warehouse.clone(use_new_connection=True) as warehouse: + warehouse.cleanup_tables(self.temp_table_names) self.temp_table_names = [] def db_results(self, row_factory=None, **kwargs): @@ -1248,19 +1248,12 @@ def extract( def row_iter() -> Generator[RowDict, None, None]: # warehouse isn't threadsafe, we need to clone() it # in the thread that uses the results - warehouse = None - try: - warehouse = self.catalog.warehouse.clone() + with self.catalog.warehouse.clone() as warehouse: gen = warehouse.dataset_select_paginated( query, limit=query._limit, order_by=query._order_by_clauses ) with contextlib.closing(gen) as rows: yield from rows - finally: - # clone doesn't necessarily create a new connection - # we can't do `warehouse.close()` for now. It is a bad design - # in clone / close interface that needs to be fixed. - pass async def get_params(row: RowDict) -> tuple: return tuple( diff --git a/src/datachain/query/session.py b/src/datachain/query/session.py index 8a8ae1e67..f0a34ec91 100644 --- a/src/datachain/query/session.py +++ b/src/datachain/query/session.py @@ -41,7 +41,12 @@ class Session: SESSION_UUID_LEN = 6 TEMP_TABLE_UUID_LEN = 6 - def __init__(self, name="", catalog: Optional["Catalog"] = None): + def __init__( + self, + name="", + catalog: Optional["Catalog"] = None, + client_config: Optional[dict] = None, + ): if re.match(r"^[0-9a-zA-Z]+$", name) is None: raise ValueError( f"Session name can contain only letters or numbers - '{name}' given." @@ -52,13 +57,18 @@ def __init__(self, name="", catalog: Optional["Catalog"] = None): session_uuid = uuid4().hex[: self.SESSION_UUID_LEN] self.name = f"{name}_{session_uuid}" - self.catalog = catalog or get_catalog() + self.is_new_catalog = not catalog + self.catalog = catalog or get_catalog(client_config=client_config) def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self._cleanup_temp_datasets() + if self.is_new_catalog: + self.catalog.metastore.close_on_exit() + self.catalog.warehouse.close_on_exit() + self.catalog.id_generator.close_on_exit() def generate_temp_dataset_name(self) -> str: tmp_table_uid = uuid4().hex[: self.TEMP_TABLE_UUID_LEN] @@ -75,7 +85,10 @@ def _cleanup_temp_datasets(self) -> None: @classmethod def get( - cls, session: Optional["Session"] = None, catalog: Optional["Catalog"] = None + cls, + session: Optional["Session"] = None, + catalog: Optional["Catalog"] = None, + client_config: Optional[dict] = None, ) -> "Session": """Creates a Session() object from a catalog. 
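The session.py hunks above and below give `Session` ownership of the catalog lifecycle: when no catalog is supplied, one is built from `client_config`, and `Session.__exit__` releases that catalog's metastore, warehouse, and ID generator via `close_on_exit()`. Below is a minimal sketch of how the new context-manager hooks are meant to combine, assuming the `datachain` package from this branch; the session name, dataset name, and client options are illustrative only.

```python
from datachain.query.session import Session

# A session created without an explicit catalog builds one from client_config,
# so Session.__exit__ calls close_on_exit() on the catalog's metastore,
# warehouse and id_generator when the block ends.
with Session("Sketch", client_config={"anon": True}) as session:
    catalog = session.catalog

    # metastore/warehouse clones are not thread safe to share; the new
    # __enter__/__exit__ support lets the SQLite implementations (e.g.
    # SQLiteWarehouse) close the cloned connection deterministically,
    # mirroring the DatasetRowsFetcher and cleanup() changes above.
    with catalog.metastore.clone() as metastore, catalog.warehouse.clone() as warehouse:
        dataset = metastore.get_dataset("my-dataset")  # hypothetical dataset name
        print(dataset.name)
# Connections owned by the session-created catalog are released here.
```

A catalog passed in explicitly is left untouched on exit (`is_new_catalog` is False), which is why the `test_session` fixture later in this diff can hand the shared `catalog` fixture to `Session` without closing connections out from under the other fixtures.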
@@ -88,7 +101,9 @@ def get( return session if cls.GLOBAL_SESSION is None: - cls.GLOBAL_SESSION_CTX = Session(cls.GLOBAL_SESSION_NAME, catalog) + cls.GLOBAL_SESSION_CTX = Session( + cls.GLOBAL_SESSION_NAME, catalog, client_config=client_config + ) cls.GLOBAL_SESSION = cls.GLOBAL_SESSION_CTX.__enter__() atexit.register(cls._global_cleanup) return cls.GLOBAL_SESSION diff --git a/tests/conftest.py b/tests/conftest.py index 49ad4a193..b5da7cda8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -65,7 +65,8 @@ def clean_environment( @pytest.fixture def sqlite_db(): - return SQLiteDatabaseEngine.from_db_file(":memory:") + with SQLiteDatabaseEngine.from_db_file(":memory:") as db: + yield db def cleanup_sqlite_db( @@ -105,9 +106,10 @@ def id_generator(): _id_generator.cleanup_for_tests() - # Close the connection so that the SQLite file is no longer open, to avoid - # pytest throwing: OSError: [Errno 24] Too many open files - _id_generator.db.close() + # Close the connection so that the SQLite file is no longer open, to avoid + # pytest throwing: OSError: [Errno 24] Too many open files + # Or, for other implementations, prevent "too many clients" errors. + _id_generator.close_on_exit() @pytest.fixture @@ -122,23 +124,23 @@ def metastore(id_generator): yield _metastore cleanup_sqlite_db(_metastore.db.clone(), _metastore.default_table_names) - Session.cleanup_for_tests() - # Close the connection so that the SQLite file is no longer open, to avoid - # pytest throwing: OSError: [Errno 24] Too many open files - _metastore.db.close() + # Close the connection so that the SQLite file is no longer open, to avoid + # pytest throwing: OSError: [Errno 24] Too many open files + # Or, for other implementations, prevent "too many clients" errors. + _metastore.close_on_exit() def check_temp_tables_cleaned_up(original_warehouse): """Ensure that temporary tables are cleaned up.""" - warehouse = original_warehouse.clone() - assert [ - t - for t in sqlalchemy.inspect(warehouse.db.engine).get_table_names() - if t.startswith( - (warehouse.UDF_TABLE_NAME_PREFIX, warehouse.TMP_TABLE_NAME_PREFIX) - ) - ] == [] + with original_warehouse.clone() as warehouse: + assert [ + t + for t in sqlalchemy.inspect(warehouse.db.engine).get_table_names() + if t.startswith( + (warehouse.UDF_TABLE_NAME_PREFIX, warehouse.TMP_TABLE_NAME_PREFIX) + ) + ] == [] @pytest.fixture @@ -168,6 +170,12 @@ def catalog(id_generator, metastore, warehouse): return Catalog(id_generator=id_generator, metastore=metastore, warehouse=warehouse) +@pytest.fixture +def test_session(catalog): + with Session("TestSession", catalog=catalog) as session: + yield session + + @pytest.fixture def id_generator_tmpfile(tmp_path): if os.environ.get("DATACHAIN_ID_GENERATOR"): @@ -182,9 +190,10 @@ def id_generator_tmpfile(tmp_path): _id_generator.cleanup_for_tests() - # Close the connection so that the SQLite file is no longer open, to avoid - # pytest throwing: OSError: [Errno 24] Too many open files - _id_generator.db.close() + # Close the connection so that the SQLite file is no longer open, to avoid + # pytest throwing: OSError: [Errno 24] Too many open files + # Or, for other implementations, prevent "too many clients" errors. 
+ _id_generator.close_on_exit() @pytest.fixture @@ -199,11 +208,11 @@ def metastore_tmpfile(tmp_path, id_generator_tmpfile): yield _metastore cleanup_sqlite_db(_metastore.db.clone(), _metastore.default_table_names) - Session.cleanup_for_tests() - # Close the connection so that the SQLite file is no longer open, to avoid - # pytest throwing: OSError: [Errno 24] Too many open files - _metastore.db.close() + # Close the connection so that the SQLite file is no longer open, to avoid + # pytest throwing: OSError: [Errno 24] Too many open files + # Or, for other implementations, prevent "too many clients" errors. + _metastore.close_on_exit() @pytest.fixture @@ -230,6 +239,24 @@ def warehouse_tmpfile(tmp_path, id_generator_tmpfile, metastore_tmpfile): _warehouse.db.close() +@pytest.fixture +def catalog_tmpfile(id_generator_tmpfile, metastore_tmpfile, warehouse_tmpfile): + # For testing parallel and distributed processing, as these cannot use + # in-memory databases. + return Catalog( + id_generator=id_generator_tmpfile, + metastore=metastore_tmpfile, + warehouse=warehouse_tmpfile, + ) + + +@pytest.fixture +def test_session_tmpfile(catalog_tmpfile): + # For testing parallel and distributed processing, as these cannot use + # in-memory databases. + return Session("TestSession", catalog=catalog_tmpfile) + + @pytest.fixture def tmp_dir(tmp_path_factory, monkeypatch): dpath = tmp_path_factory.mktemp("datachain-test") diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index 896df2367..0b7954909 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -86,7 +86,7 @@ def test_export_files( tmp_dir, cloud_test_catalog, placement, use_map, use_cache, file_type ): ctc = cloud_test_catalog - df = DataChain.from_storage(ctc.src_uri, type=file_type) + df = DataChain.from_storage(ctc.src_uri, type=file_type, catalog=ctc.catalog) if use_map: df.export_files(tmp_dir / "output", placement=placement, use_cache=use_cache) df.map( @@ -117,7 +117,7 @@ def test_export_files( @pytest.mark.parametrize("use_cache", [True, False]) -def test_export_images_files(tmp_dir, tmp_path, use_cache): +def test_export_images_files(test_session, tmp_dir, tmp_path, use_cache): images = [ {"name": "img1.jpg", "data": Image.new(mode="RGB", size=(64, 64))}, {"name": "img2.jpg", "data": Image.new(mode="RGB", size=(128, 128))}, @@ -130,6 +130,7 @@ def test_export_images_files(tmp_dir, tmp_path, use_cache): file=[ ImageFile(name=img["name"], source=f"file://{tmp_path}") for img in images ], + session=test_session, ).export_files(tmp_dir / "output", placement="filename", use_cache=use_cache) for img in images: @@ -137,7 +138,7 @@ def test_export_images_files(tmp_dir, tmp_path, use_cache): assert images_equal(img["data"], exported_img) -def test_export_files_filename_placement_not_unique_files(tmp_dir, catalog): +def test_export_files_filename_placement_not_unique_files(tmp_dir, test_session): data = b"some\x00data\x00is\x48\x65\x6c\x57\x6f\x72\x6c\x64\xff\xffheRe" bucket_name = "mybucket" files = ["dir1/a.json", "dir1/dir2/a.json"] @@ -151,12 +152,12 @@ def test_export_files_filename_placement_not_unique_files(tmp_dir, catalog): with open(file_path, "wb") as fd: fd.write(data) - df = DataChain.from_storage((tmp_dir / bucket_name).as_uri()) + df = DataChain.from_storage((tmp_dir / bucket_name).as_uri(), session=test_session) with pytest.raises(ValueError): df.export_files(tmp_dir / "output", placement="filename") -def test_show(capsys, catalog): +def test_show(capsys, test_session): first_name = 
["Alice", "Bob", "Charlie"] DataChain.from_values( first_name=first_name, @@ -166,6 +167,7 @@ def test_show(capsys, catalog): "Los Angeles", None, ], + session=test_session, ).show() captured = capsys.readouterr() normalized_output = re.sub(r"\s+", " ", captured.out) @@ -174,7 +176,7 @@ def test_show(capsys, catalog): assert f"{i} {first_name[i]}" in normalized_output -def test_show_limit(capsys, catalog): +def test_show_limit(capsys, test_session): first_name = ["Alice", "Bob", "Charlie"] DataChain.from_values( first_name=first_name, @@ -184,18 +186,20 @@ def test_show_limit(capsys, catalog): "Los Angeles", None, ], + session=test_session, ).limit(1).show() captured = capsys.readouterr() new_line_count = captured.out.count("\n") assert new_line_count == 2 -def test_show_transpose(capsys, catalog): +def test_show_transpose(capsys, test_session): first_name = ["Alice", "Bob", "Charlie"] last_name = ["A", "B", "C"] DataChain.from_values( first_name=first_name, last_name=last_name, + session=test_session, ).show(transpose=True) captured = capsys.readouterr() stripped_output = re.sub(r"\s+", " ", captured.out) @@ -203,7 +207,7 @@ def test_show_transpose(capsys, catalog): assert " ".join(last_name) in stripped_output -def test_show_truncate(capsys, catalog): +def test_show_truncate(capsys, test_session): client = ["Alice A", "Bob B", "Charles C"] details = [ "This is a very long piece of text that would not fit in the default output " @@ -215,6 +219,7 @@ def test_show_truncate(capsys, catalog): dc = DataChain.from_values( client=client, details=details, + session=test_session, ) dc.show() @@ -226,7 +231,7 @@ def test_show_truncate(capsys, catalog): assert f"{client[i]} {details[i]}" in normalized_output -def test_show_no_truncate(capsys, catalog): +def test_show_no_truncate(capsys, test_session): client = ["Alice A", "Bob B", "Charles C"] details = [ "This is a very long piece of text that would not fit in the default output " @@ -238,6 +243,7 @@ def test_show_no_truncate(capsys, catalog): dc = DataChain.from_values( client=client, details=details, + session=test_session, ) dc.show(truncate=False) @@ -248,25 +254,29 @@ def test_show_no_truncate(capsys, catalog): assert details[i] in normalized_output -def test_from_storage_dataset_stats(tmp_dir, catalog): +def test_from_storage_dataset_stats(tmp_dir, test_session): for i in range(4): (tmp_dir / f"file{i}.txt").write_text(f"file{i}") - dc = DataChain.from_storage(tmp_dir.as_uri(), catalog=catalog).save("test-data") - stats = catalog.dataset_stats(dc.name, dc.version) + dc = DataChain.from_storage(tmp_dir.as_uri(), session=test_session).save( + "test-data" + ) + stats = test_session.catalog.dataset_stats(dc.name, dc.version) assert stats == DatasetStats(num_objects=4, size=20) -def test_from_storage_check_rows(tmp_dir, catalog): +def test_from_storage_check_rows(tmp_dir, test_session): stats = {} for i in range(4): file = tmp_dir / f"{i}.txt" file.write_text(f"file{i}") stats[file.name] = file.stat() - dc = DataChain.from_storage(tmp_dir.as_uri(), catalog=catalog).save("test-data") + dc = DataChain.from_storage(tmp_dir.as_uri(), session=test_session).save( + "test-data" + ) - is_sqlite = isinstance(catalog.warehouse, SQLiteWarehouse) + is_sqlite = isinstance(test_session.catalog.warehouse, SQLiteWarehouse) tz = timezone.utc if is_sqlite else pytz.UTC for (file,) in dc.collect(): diff --git a/tests/func/test_dataset_query.py b/tests/func/test_dataset_query.py index aacf88ba8..79394e1af 100644 --- a/tests/func/test_dataset_query.py +++ 
b/tests/func/test_dataset_query.py @@ -931,6 +931,49 @@ def sum(self, size): ] +@pytest.mark.parametrize( + "cloud_type,version_aware", + [("s3", True)], + indirect=True, +) +def test_udf_reuse_on_error(cloud_test_catalog_tmpfile): + catalog = cloud_test_catalog_tmpfile.catalog + sources = [cloud_test_catalog_tmpfile.src_uri] + globs = [s.rstrip("/") + "/*" for s in sources] + catalog.index(sources) + catalog.create_dataset_from_sources("animals", globs, recursive=True) + + error_state = {"error": True} + + @udf((C.name,), {"name_len": Int}) + def name_len_maybe_error(name): + if error_state["error"]: + # A udf that raises an exception + raise RuntimeError("Test Error!") + return (len(name),) + + q = ( + DatasetQuery(name="animals", version=1, catalog=catalog) + .filter(C.size < 13) + .filter(C.parent.glob("cats*") | (C.size < 4)) + .add_signals(name_len_maybe_error) + .select(C.name, C.name_len) + ) + with pytest.raises(RuntimeError, match="Test Error!"): + q.db_results() + + # Simulate fixing the error + error_state["error"] = False + + # Retry Query + result = q.db_results() + + assert len(result) == 3 + for r in result: + # Check that the UDF ran successfully + assert len(r[0]) == r[1] + + @pytest.mark.parametrize( "cloud_type,version_aware", [("s3", True)], diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index a126beefb..db976c4b6 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -49,53 +49,60 @@ class MyNested(BaseModel): ] -def test_pandas_conversion(catalog): +def test_pandas_conversion(test_session): df = pd.DataFrame(DF_DATA) - df1 = DataChain.from_pandas(df) + df1 = DataChain.from_pandas(df, session=test_session) df1 = df1.select("first_name", "age", "city").to_pandas() assert df1.equals(df) -def test_pandas_file_column_conflict(catalog): +def test_pandas_file_column_conflict(test_session): file_records = {"name": ["aa.txt", "bb.txt", "ccc.jpg", "dd", "e.txt"]} with pytest.raises(DataChainParamsError): - DataChain.from_pandas(pd.DataFrame(DF_DATA | file_records)) + DataChain.from_pandas( + pd.DataFrame(DF_DATA | file_records), session=test_session + ) file_records = {"etag": [1, 2, 3, 4, 5]} with pytest.raises(DataChainParamsError): - DataChain.from_pandas(pd.DataFrame(DF_DATA | file_records)) + DataChain.from_pandas( + pd.DataFrame(DF_DATA | file_records), session=test_session + ) -def test_pandas_uppercase_columns(catalog): +def test_pandas_uppercase_columns(test_session): data = { "FirstName": ["Alice", "Bob", "Charlie", "David", "Eva"], "Age": [25, 30, 35, 40, 45], "City": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"], } - df = DataChain.from_pandas(pd.DataFrame(data)).to_pandas() + df = DataChain.from_pandas(pd.DataFrame(data), session=test_session).to_pandas() assert all(col not in df.columns for col in data) assert all(col.lower() in df.columns for col in data) -def test_pandas_incorrect_column_names(catalog): +def test_pandas_incorrect_column_names(test_session): with pytest.raises(DataChainParamsError): DataChain.from_pandas( - pd.DataFrame({"First Name": ["Alice", "Bob", "Charlie", "David", "Eva"]}) + pd.DataFrame({"First Name": ["Alice", "Bob", "Charlie", "David", "Eva"]}), + session=test_session, ) with pytest.raises(DataChainParamsError): DataChain.from_pandas( - pd.DataFrame({"": ["Alice", "Bob", "Charlie", "David", "Eva"]}) + pd.DataFrame({"": ["Alice", "Bob", "Charlie", "David", "Eva"]}), + session=test_session, ) with pytest.raises(DataChainParamsError): DataChain.from_pandas( 
- pd.DataFrame({"First@Name": ["Alice", "Bob", "Charlie", "David", "Eva"]}) + pd.DataFrame({"First@Name": ["Alice", "Bob", "Charlie", "David", "Eva"]}), + session=test_session, ) -def test_from_features_basic(catalog): - ds = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD) +def test_from_features_basic(test_session): + ds = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD, session=test_session) ds = ds.gen(lambda prm: [File(name="")] * 5, params="parent", output={"file": File}) ds_name = "my_ds" @@ -108,8 +115,8 @@ def test_from_features_basic(catalog): assert set(ds.schema.values()) == {File} -def test_from_features(catalog): - ds = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD) +def test_from_features(test_session): + ds = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD, session=test_session) ds = ds.gen( lambda prm: list(zip([File(name="")] * len(features), features)), params="parent", @@ -119,26 +126,28 @@ def test_from_features(catalog): assert t1 == features[i] -def test_datasets(catalog): - ds = DataChain.datasets() +def test_datasets(test_session): + ds = DataChain.datasets(session=test_session) datasets = [d for d in ds.collect("dataset") if d.name == "fibonacci"] assert len(datasets) == 0 - DataChain.from_values(fib=[1, 1, 2, 3, 5, 8]).save("fibonacci") + DataChain.from_values(fib=[1, 1, 2, 3, 5, 8], session=test_session).save( + "fibonacci" + ) - ds = DataChain.datasets() + ds = DataChain.datasets(session=test_session) datasets = [d for d in ds.collect("dataset") if d.name == "fibonacci"] assert len(datasets) == 1 assert datasets[0].num_objects == 6 - ds = DataChain.datasets(object_name="foo") + ds = DataChain.datasets(object_name="foo", session=test_session) datasets = [d for d in ds.collect("foo") if d.name == "fibonacci"] assert len(datasets) == 1 assert datasets[0].num_objects == 6 -def test_preserve_feature_schema(catalog): - ds = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD) +def test_preserve_feature_schema(test_session): + ds = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD, session=test_session) ds = ds.gen( lambda prm: list(zip([File(name="")] * len(features), features, features)), params="parent", @@ -155,11 +164,11 @@ def test_preserve_feature_schema(catalog): assert set(ds.schema.values()) == {MyFr, File} -def test_from_features_simple_types(catalog): +def test_from_features_simple_types(test_session): fib = [1, 1, 2, 3, 5, 8] values = ["odd" if num % 2 else "even" for num in fib] - ds = DataChain.from_values(fib=fib, odds=values) + ds = DataChain.from_values(fib=fib, odds=values, session=test_session) df = ds.to_pandas() assert len(df) == len(fib) @@ -167,7 +176,7 @@ def test_from_features_simple_types(catalog): assert df["odds"].tolist() == values -def test_from_features_more_simple_types(catalog): +def test_from_features_more_simple_types(test_session): ds_name = "my_ds_type" DataChain.from_values( t1=features, @@ -180,6 +189,7 @@ def test_from_features_more_simple_types(catalog): datetime.datetime.today(), ], f=[3.14, 2.72, 1.62], + session=test_session, ).save(ds_name) ds = DataChain(name=ds_name) @@ -201,24 +211,24 @@ def test_from_features_more_simple_types(catalog): } -def test_file_list(catalog): +def test_file_list(test_session): names = ["f1.jpg", "f1.json", "f1.txt", "f2.jpg", "f2.json"] sizes = [1, 2, 3, 4, 5] files = [File(name=name, size=size) for name, size in zip(names, sizes)] - ds = DataChain.from_values(file=files) + ds = DataChain.from_values(file=files, session=test_session) for i, values in 
enumerate(ds.collect()): assert values[0] == files[i] -def test_gen(catalog): +def test_gen(test_session): class _TestFr(BaseModel): file: File sqrt: float my_name: str - ds = DataChain.from_values(t1=features) + ds = DataChain.from_values(t1=features, session=test_session) ds = ds.gen( x=lambda m_fr: [ _TestFr( @@ -241,12 +251,12 @@ class _TestFr(BaseModel): assert x.my_name == test_fr.my_name -def test_map(catalog): +def test_map(test_session): class _TestFr(BaseModel): sqrt: float my_name: str - dc = DataChain.from_values(t1=features).map( + dc = DataChain.from_values(t1=features, session=test_session).map( x=lambda m_fr: _TestFr( sqrt=math.sqrt(m_fr.count), my_name=m_fr.nnn + "_suf", @@ -267,13 +277,13 @@ class _TestFr(BaseModel): assert x.my_name == test_fr.my_name -def test_agg(catalog): +def test_agg(test_session): class _TestFr(BaseModel): f: File cnt: int my_name: str - dc = DataChain.from_values(t1=features).agg( + dc = DataChain.from_values(t1=features, session=test_session).agg( x=lambda frs: [ _TestFr( f=File(name=""), @@ -300,7 +310,7 @@ class _TestFr(BaseModel): ] -def test_agg_two_params(catalog): +def test_agg_two_params(test_session): class _TestFr(BaseModel): f: File cnt: int @@ -312,7 +322,7 @@ class _TestFr(BaseModel): MyFr(nnn="n1", count=2), ] - ds = DataChain.from_values(t1=features, t2=features2).agg( + ds = DataChain.from_values(t1=features, t2=features2, session=test_session).agg( x=lambda frs1, frs2: [ _TestFr( f=File(name=""), @@ -329,22 +339,22 @@ class _TestFr(BaseModel): assert list(ds.collect("x.cnt")) == [12, 15] -def test_agg_simple_iterator(catalog): +def test_agg_simple_iterator(test_session): def func(key, val) -> Iterator[tuple[File, str]]: for i in range(val): yield File(name=""), f"{key}_{i}" keys = ["a", "b", "c"] values = [3, 1, 2] - ds = DataChain.from_values(key=keys, val=values).gen(res=func) + ds = DataChain.from_values(key=keys, val=values, session=test_session).gen(res=func) df = ds.to_pandas() res = df["res_1"].tolist() assert res == ["a_0", "a_1", "a_2", "b_0", "c_0", "c_1"] -def test_agg_simple_iterator_error(catalog): - chain = DataChain.from_values(key=["a", "b", "c"]) +def test_agg_simple_iterator_error(test_session): + chain = DataChain.from_values(key=["a", "b", "c"], session=test_session) with pytest.raises(UdfSignatureError): @@ -371,7 +381,7 @@ def func(key) -> tuple[File, str]: # type: ignore[misc] chain.gen(res=func) -def test_agg_tuple_result_iterator(catalog): +def test_agg_tuple_result_iterator(test_session): class _ImageGroup(BaseModel): name: str size: int @@ -383,13 +393,15 @@ def func(key, val) -> Iterator[tuple[File, _ImageGroup]]: keys = ["n1", "n2", "n1"] values = [1, 5, 9] - ds = DataChain.from_values(key=keys, val=values).agg(x=func, partition_by=C("key")) + ds = DataChain.from_values(key=keys, val=values, session=test_session).agg( + x=func, partition_by=C("key") + ) assert list(ds.collect("x_1.name")) == ["n1-n1", "n2"] assert list(ds.collect("x_1.size")) == [10, 5] -def test_agg_tuple_result_generator(catalog): +def test_agg_tuple_result_generator(test_session): class _ImageGroup(BaseModel): name: str size: int @@ -401,18 +413,20 @@ def func(key, val) -> Generator[tuple[File, _ImageGroup], None, None]: keys = ["n1", "n2", "n1"] values = [1, 5, 9] - ds = DataChain.from_values(key=keys, val=values).agg(x=func, partition_by=C("key")) + ds = DataChain.from_values(key=keys, val=values, session=test_session).agg( + x=func, partition_by=C("key") + ) assert list(ds.collect("x_1.name")) == ["n1-n1", "n2"] assert 
list(ds.collect("x_1.size")) == [10, 5] -def test_batch_map(catalog): +def test_batch_map(test_session): class _TestFr(BaseModel): sqrt: float my_name: str - dc = DataChain.from_values(t1=features).batch_map( + dc = DataChain.from_values(t1=features, session=test_session).batch_map( x=lambda m_frs: [ _TestFr( sqrt=math.sqrt(m_fr.count), @@ -436,12 +450,12 @@ class _TestFr(BaseModel): assert x.my_name == test_fr.my_name -def test_batch_map_wrong_size(catalog): +def test_batch_map_wrong_size(test_session): class _TestFr(BaseModel): total: int names: str - dc = DataChain.from_values(t1=features).batch_map( + dc = DataChain.from_values(t1=features, session=test_session).batch_map( x=lambda m_frs: [ _TestFr( total=sum(m_fr.count for m_fr in m_frs), @@ -456,7 +470,7 @@ class _TestFr(BaseModel): list(dc.collect()) -def test_batch_map_two_params(catalog): +def test_batch_map_two_params(test_session): class _TestFr(BaseModel): f: File cnt: int @@ -468,7 +482,9 @@ class _TestFr(BaseModel): MyFr(nnn="n1", count=2), ] - ds = DataChain.from_values(t1=features, t2=features2).batch_map( + ds = DataChain.from_values( + t1=features, t2=features2, session=test_session + ).batch_map( x=lambda frs1, frs2: [ _TestFr( f=File(name=""), @@ -485,18 +501,20 @@ class _TestFr(BaseModel): assert list(ds.collect("x.cnt")) == [9, 15, 3] -def test_batch_map_tuple_result_iterator(catalog): +def test_batch_map_tuple_result_iterator(test_session): def sqrt(t1: list[int]) -> Iterator[float]: for val in t1: yield math.sqrt(val) - dc = DataChain.from_values(t1=[1, 4, 9]).batch_map(x=sqrt) + dc = DataChain.from_values(t1=[1, 4, 9], session=test_session).batch_map(x=sqrt) assert list(dc.collect("x")) == [1, 2, 3] -def test_collect(catalog): - dc = DataChain.from_values(f1=features, num=range(len(features))) +def test_collect(test_session): + dc = DataChain.from_values( + f1=features, num=range(len(features)), session=test_session + ) n = 0 for sample in dc.collect(): @@ -513,8 +531,8 @@ def test_collect(catalog): assert n == len(features) -def test_collect_nested_feature(catalog): - dc = DataChain.from_values(sign1=features_nested) +def test_collect_nested_feature(test_session): + dc = DataChain.from_values(sign1=features_nested, session=test_session) for n, sample in enumerate(dc.collect()): assert len(sample) == 1 @@ -524,8 +542,8 @@ def test_collect_nested_feature(catalog): assert nested == features_nested[n] -def test_select_feature(catalog): - dc = DataChain.from_values(my_n=features_nested) +def test_select_feature(test_session): + dc = DataChain.from_values(my_n=features_nested, session=test_session) samples = dc.select("my_n").collect() n = 0 @@ -551,8 +569,8 @@ def test_select_feature(catalog): assert n == len(features_nested) -def test_select_columns_intersection(catalog): - dc = DataChain.from_values(my_n=features_nested) +def test_select_columns_intersection(test_session): + dc = DataChain.from_values(my_n=features_nested, session=test_session) samples = dc.select("my_n.fr", "my_n.fr.count").collect() n = 0 @@ -564,8 +582,8 @@ def test_select_columns_intersection(catalog): assert n == len(features_nested) -def test_select_except(catalog): - dc = DataChain.from_values(fr1=features_nested, fr2=features) +def test_select_except(test_session): + dc = DataChain.from_values(fr1=features_nested, fr2=features, session=test_session) samples = dc.select_except("fr2").collect() n = 0 @@ -576,8 +594,8 @@ def test_select_except(catalog): assert n == len(features_nested) -def test_select_wrong_type(catalog): - dc = 
DataChain.from_values(fr1=features_nested, fr2=features) +def test_select_wrong_type(test_session): + dc = DataChain.from_values(fr1=features_nested, fr2=features, session=test_session) with pytest.raises(SignalResolvingTypeError): list(dc.select(4).collect()) @@ -586,8 +604,8 @@ def test_select_wrong_type(catalog): list(dc.select_except(features[0]).collect()) -def test_select_except_error(catalog): - dc = DataChain.from_values(fr1=features_nested, fr2=features) +def test_select_except_error(test_session): + dc = DataChain.from_values(fr1=features_nested, fr2=features, session=test_session) with pytest.raises(SignalResolvingError): list(dc.select_except("not_exist", "file").collect()) @@ -596,8 +614,8 @@ def test_select_except_error(catalog): list(dc.select_except("fr1.label", "file").collect()) -def test_select_restore_from_saving(catalog): - dc = DataChain.from_values(my_n=features_nested) +def test_select_restore_from_saving(test_session): + dc = DataChain.from_values(my_n=features_nested, session=test_session) name = "test_test_select_save" dc.select("my_n.fr").save(name) @@ -612,7 +630,7 @@ def test_select_restore_from_saving(catalog): assert n == len(features_nested) -def test_select_distinct(catalog): +def test_select_distinct(test_session): class Embedding(BaseModel): id: int filename: str @@ -634,6 +652,7 @@ class Embedding(BaseModel): Embedding(id=4, filename="d.jpg", values=expected[1]), Embedding(id=5, filename="e.jpg", values=expected[3]), ], + session=test_session, ) .select("embedding.values", "embedding.filename") .distinct("embedding.values") @@ -647,7 +666,7 @@ class Embedding(BaseModel): assert np.allclose([emb[i] for emb in actual], [emp[i] for emp in expected]) -def test_from_dataset_name_version(catalog): +def test_from_dataset_name_version(test_session): name = "test-version" DataChain.from_values( first_name=["Alice", "Bob", "Charlie"], @@ -657,6 +676,7 @@ def test_from_dataset_name_version(catalog): "Los Angeles", None, ], + session=test_session, ).save(name) dc = DataChain.from_dataset(name) @@ -664,9 +684,9 @@ def test_from_dataset_name_version(catalog): assert dc.version -def test_chain_of_maps(catalog): +def test_chain_of_maps(test_session): dc = ( - DataChain.from_values(my_n=features_nested) + DataChain.from_values(my_n=features_nested, session=test_session) .map(full_name=lambda my_n: my_n.label + "-" + my_n.fr.nnn, output=str) .map(square=lambda my_n: my_n.fr.count**2, output=int) ) @@ -681,25 +701,25 @@ def test_chain_of_maps(catalog): assert signal in preserved.schema -def test_vector(catalog): +def test_vector(test_session): vector = [3.14, 2.72, 1.62] def get_vector(key) -> list[float]: return vector - ds = DataChain.from_values(key=[123]).map(emd=get_vector) + ds = DataChain.from_values(key=[123], session=test_session).map(emd=get_vector) df = ds.to_pandas() assert np.allclose(df["emd"].tolist()[0], vector) -def test_vector_of_vectors(catalog): +def test_vector_of_vectors(test_session): vector = [[3.14, 2.72, 1.62], [1.0, 2.0, 3.0]] def get_vector(key) -> list[list[float]]: return vector - ds = DataChain.from_values(key=[123]).map(emd_list=get_vector) + ds = DataChain.from_values(key=[123], session=test_session).map(emd_list=get_vector) df = ds.to_pandas() actual = df["emd_list"].tolist()[0] @@ -708,24 +728,24 @@ def get_vector(key) -> list[list[float]]: assert np.allclose(actual[1], vector[1]) -def test_unsupported_output_type(catalog): +def test_unsupported_output_type(test_session): vector = [3.14, 2.72, 1.62] def get_vector(key) -> 
list[np.float64]: return [vector] with pytest.raises(TypeError): - DataChain.from_values(key=[123]).map(emd=get_vector) + DataChain.from_values(key=[123], session=test_session).map(emd=get_vector) -def test_collect_single_item(catalog): +def test_collect_single_item(test_session): names = ["f1.jpg", "f1.json", "f1.txt", "f2.jpg", "f2.json"] sizes = [1, 2, 3, 4, 5] files = [File(name=name, size=size) for name, size in zip(names, sizes)] scores = [0.1, 0.2, 0.3, 0.4, 0.5] - chain = DataChain.from_values(file=files, score=scores) + chain = DataChain.from_values(file=files, score=scores, session=test_session) assert list(chain.collect("file")) == files assert list(chain.collect("file.name")) == names @@ -741,40 +761,44 @@ def test_collect_single_item(catalog): assert math.isclose(actual[1], expected[1], rel_tol=1e-7) -def test_default_output_type(catalog): +def test_default_output_type(test_session): names = ["f1.jpg", "f1.json", "f1.txt", "f2.jpg", "f2.json"] suffix = "-new" - chain = DataChain.from_values(name=names).map(res1=lambda name: name + suffix) + chain = DataChain.from_values(name=names, session=test_session).map( + res1=lambda name: name + suffix + ) assert list(chain.collect("res1")) == [t + suffix for t in names] -def test_parse_tabular(tmp_dir, catalog): +def test_parse_tabular(tmp_dir, test_session): df = pd.DataFrame(DF_DATA) path = tmp_dir / "test.parquet" df.to_parquet(path) - dc = DataChain.from_storage(path.as_uri()).parse_tabular() + dc = DataChain.from_storage(path.as_uri(), session=test_session).parse_tabular() df1 = dc.select("first_name", "age", "city").to_pandas() assert df1.equals(df) -def test_parse_tabular_format(tmp_dir, catalog): +def test_parse_tabular_format(tmp_dir, test_session): df = pd.DataFrame(DF_DATA) path = tmp_dir / "test.jsonl" path.write_text(df.to_json(orient="records", lines=True)) - dc = DataChain.from_storage(path.as_uri()).parse_tabular(format="json") + dc = DataChain.from_storage(path.as_uri(), session=test_session).parse_tabular( + format="json" + ) df1 = dc.select("first_name", "age", "city").to_pandas() assert df1.equals(df) -def test_parse_tabular_partitions(tmp_dir, catalog): +def test_parse_tabular_partitions(tmp_dir, test_session): df = pd.DataFrame(DF_DATA) path = tmp_dir / "test.parquet" df.to_parquet(path, partition_cols=["first_name"]) dc = ( - DataChain.from_storage(path.as_uri()) + DataChain.from_storage(path.as_uri(), session=test_session) .filter(C("parent").glob("*first_name=Alice*")) .parse_tabular(partitioning="hive") ) @@ -783,13 +807,13 @@ def test_parse_tabular_partitions(tmp_dir, catalog): assert df1.equals(df.loc[:0]) -def test_parse_tabular_empty(tmp_dir, catalog): +def test_parse_tabular_empty(tmp_dir, test_session): path = tmp_dir / "test.parquet" with pytest.raises(FileNotFoundError): - DataChain.from_storage(path.as_uri()).parse_tabular() + DataChain.from_storage(path.as_uri(), session=test_session).parse_tabular() -def test_parse_tabular_unify_schema(tmp_dir, catalog): +def test_parse_tabular_unify_schema(tmp_dir, test_session): df1 = pd.DataFrame(DF_DATA) df2 = pd.DataFrame(DF_OTHER_DATA) path1 = tmp_dir / "df1.parquet" @@ -804,7 +828,7 @@ def test_parse_tabular_unify_schema(tmp_dir, catalog): .reset_index(drop=True) ) dc = ( - DataChain.from_storage(tmp_dir.as_uri()) + DataChain.from_storage(tmp_dir.as_uri(), session=test_session) .filter(C("name").glob("*.parquet")) .parse_tabular() ) @@ -817,12 +841,12 @@ def test_parse_tabular_unify_schema(tmp_dir, catalog): assert df.equals(df_combined) -def 
test_parse_tabular_output_dict(tmp_dir, catalog): +def test_parse_tabular_output_dict(tmp_dir, test_session): df = pd.DataFrame(DF_DATA) path = tmp_dir / "test.jsonl" path.write_text(df.to_json(orient="records", lines=True)) output = {"fname": str, "age": int, "loc": str} - dc = DataChain.from_storage(path.as_uri()).parse_tabular( + dc = DataChain.from_storage(path.as_uri(), session=test_session).parse_tabular( format="json", output=output ) df1 = dc.select("fname", "age", "loc").to_pandas() @@ -830,7 +854,7 @@ def test_parse_tabular_output_dict(tmp_dir, catalog): assert df1.equals(df) -def test_parse_tabular_output_feature(tmp_dir, catalog): +def test_parse_tabular_output_feature(tmp_dir, test_session): class Output(BaseModel): fname: str age: int @@ -839,7 +863,7 @@ class Output(BaseModel): df = pd.DataFrame(DF_DATA) path = tmp_dir / "test.jsonl" path.write_text(df.to_json(orient="records", lines=True)) - dc = DataChain.from_storage(path.as_uri()).parse_tabular( + dc = DataChain.from_storage(path.as_uri(), session=test_session).parse_tabular( format="json", output=Output ) df1 = dc.select("fname", "age", "loc").to_pandas() @@ -847,12 +871,12 @@ class Output(BaseModel): assert df1.equals(df) -def test_parse_tabular_output_list(tmp_dir, catalog): +def test_parse_tabular_output_list(tmp_dir, test_session): df = pd.DataFrame(DF_DATA) path = tmp_dir / "test.jsonl" path.write_text(df.to_json(orient="records", lines=True)) output = ["fname", "age", "loc"] - dc = DataChain.from_storage(path.as_uri()).parse_tabular( + dc = DataChain.from_storage(path.as_uri(), session=test_session).parse_tabular( format="json", output=output ) df1 = dc.select("fname", "age", "loc").to_pandas() @@ -860,53 +884,60 @@ def test_parse_tabular_output_list(tmp_dir, catalog): assert df1.equals(df) -def test_parse_tabular_nrows(tmp_dir, catalog): +def test_parse_tabular_nrows(tmp_dir, test_session): df = pd.DataFrame(DF_DATA) path = tmp_dir / "test.parquet" df.to_json(path, orient="records", lines=True) - dc = DataChain.from_storage(path.as_uri()).parse_tabular(nrows=2, format="json") + dc = DataChain.from_storage(path.as_uri(), session=test_session).parse_tabular( + nrows=2, format="json" + ) df1 = dc.select("first_name", "age", "city").to_pandas() assert df1.equals(df[:2]) -def test_parse_tabular_nrows_invalid(tmp_dir, catalog): +def test_parse_tabular_nrows_invalid(tmp_dir, test_session): df = pd.DataFrame(DF_DATA) path = tmp_dir / "test.parquet" df.to_parquet(path) with pytest.raises(DataChainParamsError): - DataChain.from_storage(path.as_uri()).parse_tabular(nrows=2) + DataChain.from_storage(path.as_uri(), session=test_session).parse_tabular( + nrows=2 + ) -def test_from_csv(tmp_dir, catalog): +def test_from_csv(tmp_dir, test_session): df = pd.DataFrame(DF_DATA) path = tmp_dir / "test.csv" df.to_csv(path, index=False) - dc = DataChain.from_csv(path.as_uri()) + dc = DataChain.from_csv(path.as_uri(), session=test_session) df1 = dc.select("first_name", "age", "city").to_pandas() assert df1.equals(df) -def test_from_csv_no_header_error(tmp_dir, catalog): +def test_from_csv_no_header_error(tmp_dir, test_session): df = pd.DataFrame(DF_DATA.values()).transpose() path = tmp_dir / "test.csv" df.to_csv(path, header=False, index=False) with pytest.raises(DataChainParamsError): - DataChain.from_csv(path.as_uri(), header=False) + DataChain.from_csv(path.as_uri(), header=False, session=test_session) -def test_from_csv_no_header_output_dict(tmp_dir, catalog): +def test_from_csv_no_header_output_dict(tmp_dir, test_session): df = 
pd.DataFrame(DF_DATA.values()).transpose()
     path = tmp_dir / "test.csv"
     df.to_csv(path, header=False, index=False)
     dc = DataChain.from_csv(
-        path.as_uri(), header=False, output={"first_name": str, "age": int, "city": str}
+        path.as_uri(),
+        header=False,
+        output={"first_name": str, "age": int, "city": str},
+        session=test_session,
     )
     df1 = dc.select("first_name", "age", "city").to_pandas()
     assert (df1.values != df.values).sum() == 0


-def test_from_csv_no_header_output_feature(tmp_dir, catalog):
+def test_from_csv_no_header_output_feature(tmp_dir, test_session):
     class Output(BaseModel):
         first_name: str
         age: int
@@ -915,32 +946,37 @@ class Output(BaseModel):
     df = pd.DataFrame(DF_DATA.values()).transpose()
     path = tmp_dir / "test.csv"
     df.to_csv(path, header=False, index=False)
-    dc = DataChain.from_csv(path.as_uri(), header=False, output=Output)
+    dc = DataChain.from_csv(
+        path.as_uri(), header=False, output=Output, session=test_session
+    )
     df1 = dc.select("first_name", "age", "city").to_pandas()
     assert (df1.values != df.values).sum() == 0


-def test_from_csv_no_header_output_list(tmp_dir, catalog):
+def test_from_csv_no_header_output_list(tmp_dir, test_session):
     df = pd.DataFrame(DF_DATA.values()).transpose()
     path = tmp_dir / "test.csv"
     df.to_csv(path, header=False, index=False)
     dc = DataChain.from_csv(
-        path.as_uri(), header=False, output=["first_name", "age", "city"]
+        path.as_uri(),
+        header=False,
+        output=["first_name", "age", "city"],
+        session=test_session,
     )
     df1 = dc.select("first_name", "age", "city").to_pandas()
     assert (df1.values != df.values).sum() == 0


-def test_from_csv_tab_delimited(tmp_dir, catalog):
+def test_from_csv_tab_delimited(tmp_dir, test_session):
     df = pd.DataFrame(DF_DATA)
     path = tmp_dir / "test.csv"
     df.to_csv(path, sep="\t", index=False)
-    dc = DataChain.from_csv(path.as_uri(), delimiter="\t")
+    dc = DataChain.from_csv(path.as_uri(), delimiter="\t", session=test_session)
     df1 = dc.select("first_name", "age", "city").to_pandas()
     assert df1.equals(df)


-def test_from_csv_null_collect(tmp_dir, catalog):
+def test_from_csv_null_collect(tmp_dir, test_session):
     # Clickhouse requires setting type to Nullable(Type).
     # See https://github.com/xzkostyan/clickhouse-sqlalchemy/issues/189.
     skip_if_not_sqlite()
@@ -949,43 +985,43 @@ def test_from_csv_null_collect(tmp_dir, catalog):
     df["height"] = height
     path = tmp_dir / "test.csv"
     df.to_csv(path, index=False)
-    dc = DataChain.from_csv(path.as_uri(), object_name="csv")
+    dc = DataChain.from_csv(path.as_uri(), object_name="csv", session=test_session)
     for i, row in enumerate(dc.collect()):
         assert row[1].height == height[i]


-def test_from_csv_nrows(tmp_dir, catalog):
+def test_from_csv_nrows(tmp_dir, test_session):
     df = pd.DataFrame(DF_DATA)
     path = tmp_dir / "test.csv"
     df.to_csv(path, index=False)
-    dc = DataChain.from_csv(path.as_uri(), nrows=2)
+    dc = DataChain.from_csv(path.as_uri(), nrows=2, session=test_session)
     df1 = dc.select("first_name", "age", "city").to_pandas()
     assert df1.equals(df[:2])


-def test_from_parquet(tmp_dir, catalog):
+def test_from_parquet(tmp_dir, test_session):
     df = pd.DataFrame(DF_DATA)
     path = tmp_dir / "test.parquet"
     df.to_parquet(path)
-    dc = DataChain.from_parquet(path.as_uri())
+    dc = DataChain.from_parquet(path.as_uri(), session=test_session)
     df1 = dc.select("first_name", "age", "city").to_pandas()
     assert df1.equals(df)


-def test_from_parquet_partitioned(tmp_dir, catalog):
+def test_from_parquet_partitioned(tmp_dir, test_session):
     df = pd.DataFrame(DF_DATA)
     path = tmp_dir / "test.parquet"
     df.to_parquet(path, partition_cols=["first_name"])
-    dc = DataChain.from_parquet(path.as_uri())
+    dc = DataChain.from_parquet(path.as_uri(), session=test_session)
     df1 = dc.select("first_name", "age", "city").to_pandas()
     df1 = df1.sort_values("first_name").reset_index(drop=True)
     assert df1.equals(df)


-def test_to_parquet(tmp_dir, catalog):
+def test_to_parquet(tmp_dir, test_session):
     df = pd.DataFrame(DF_DATA)
-    dc = DataChain.from_pandas(df)
+    dc = DataChain.from_pandas(df, session=test_session)

     path = tmp_dir / "test.parquet"
     dc.to_parquet(path)
@@ -994,9 +1030,9 @@ def test_to_parquet(tmp_dir, catalog):
     pd.testing.assert_frame_equal(pd.read_parquet(path), df)


-def test_to_parquet_partitioned(tmp_dir, catalog):
+def test_to_parquet_partitioned(tmp_dir, test_session):
     df = pd.DataFrame(DF_DATA)
-    dc = DataChain.from_pandas(df)
+    dc = DataChain.from_pandas(df, session=test_session)

     path = tmp_dir / "parquets"
     dc.to_parquet(path, partition_cols=["first_name"])
@@ -1012,12 +1048,12 @@ def test_to_parquet_partitioned(tmp_dir, catalog):


 @pytest.mark.parametrize("processes", [False, 2, True])
-def test_parallel(processes, catalog):
+def test_parallel(processes, test_session_tmpfile):
     prefix = "t & "
     vals = ["a", "b", "c", "d", "e", "f", "g", "h", "i"]

     res = list(
-        DataChain.from_values(key=vals)
+        DataChain.from_values(key=vals, session=test_session_tmpfile)
         .settings(parallel=processes)
         .map(res=lambda key: prefix + key)
         .collect("res")
@@ -1026,12 +1062,12 @@ def test_parallel(processes, catalog):
     assert res == [prefix + v for v in vals]


-def test_exec(catalog):
+def test_exec(test_session):
     names = ("f1.jpg", "f1.json", "f1.txt", "f2.jpg", "f2.json")
     all_names = set()

     dc = (
-        DataChain.from_values(name=names)
+        DataChain.from_values(name=names, session=test_session)
         .map(nop=lambda name: all_names.add(name))
         .exec()
     )
@@ -1039,8 +1075,10 @@ def test_exec(catalog):
     assert all_names == set(names)


-def test_extend_features(catalog):
-    dc = DataChain.from_values(f1=features, num=range(len(features)))
+def test_extend_features(test_session):
+    dc = DataChain.from_values(
+        f1=features, num=range(len(features)), session=test_session
+    )

     res = dc._extend_to_data_model("select", "num")
     assert isinstance(res, DataChain)
@@ -1050,32 +1088,38 @@ def test_extend_features(catalog):
     assert res == sum(range(len(features)))


-def test_from_storage_object_name(tmp_dir, catalog):
+def test_from_storage_object_name(tmp_dir, test_session):
     df = pd.DataFrame(DF_DATA)
     path = tmp_dir / "test.parquet"
     df.to_parquet(path)
-    dc = DataChain.from_storage(path.as_uri(), object_name="custom")
+    dc = DataChain.from_storage(
+        path.as_uri(), object_name="custom", session=test_session
+    )
     assert dc.schema["custom"] == File


-def test_from_features_object_name(tmp_dir, catalog):
+def test_from_features_object_name(test_session):
     fib = [1, 1, 2, 3, 5, 8]
     values = ["odd" if num % 2 else "even" for num in fib]

-    dc = DataChain.from_values(fib=fib, odds=values, object_name="custom")
+    dc = DataChain.from_values(
+        fib=fib, odds=values, object_name="custom", session=test_session
+    )
     assert "custom.fib" in dc.to_pandas(flatten=True).columns


-def test_parse_tabular_object_name(tmp_dir, catalog):
+def test_parse_tabular_object_name(tmp_dir, test_session):
     df = pd.DataFrame(DF_DATA)
     path = tmp_dir / "test.parquet"
     df.to_parquet(path)
-    dc = DataChain.from_storage(path.as_uri()).parse_tabular(object_name="tbl")
+    dc = DataChain.from_storage(path.as_uri(), session=test_session).parse_tabular(
+        object_name="tbl"
+    )
     assert "tbl.first_name" in dc.to_pandas(flatten=True).columns


-def test_sys_feature(tmp_dir, catalog):
-    ds = DataChain.from_values(t1=features)
+def test_sys_feature(test_session):
+    ds = DataChain.from_values(t1=features, session=test_session)
     ds_sys = ds.settings(sys=True)
     assert not ds._sys
     assert ds_sys._sys
@@ -1104,8 +1148,8 @@ def test_sys_feature(tmp_dir, catalog):
     assert "sys" not in ds_no_sys.catalog.get_dataset("ds_no_sys").feature_schema


-def test_to_pandas_multi_level():
-    df = DataChain.from_values(t1=features).to_pandas()
+def test_to_pandas_multi_level(test_session):
+    df = DataChain.from_values(t1=features, session=test_session).to_pandas()

     assert "t1" in df.columns
     assert "nnn" in df["t1"].columns
@@ -1113,8 +1157,8 @@ def test_to_pandas_multi_level():
     assert df["t1"]["count"].tolist() == [3, 5, 1]


-def test_mutate():
-    chain = DataChain.from_values(t1=features).mutate(
+def test_mutate(test_session):
+    chain = DataChain.from_values(t1=features, session=test_session).mutate(
         circle=2 * 3.14 * Column("t1.count"), place="pref_" + Column("t1.nnn")
     )
@@ -1126,10 +1170,12 @@ def test_mutate():


 @pytest.mark.parametrize("with_function", [True, False])
-def test_order_by_with_nested_columns(with_function):
+def test_order_by_with_nested_columns(test_session, with_function):
     names = ["a.txt", "c.txt", "d.txt", "a.txt", "b.txt"]
-    dc = DataChain.from_values(file=[File(name=name) for name in names])
+    dc = DataChain.from_values(
+        file=[File(name=name) for name in names], session=test_session
+    )

     if with_function:
         from datachain.sql.functions import rand
@@ -1147,10 +1193,12 @@ def test_order_by_with_nested_columns(with_function):


 @pytest.mark.parametrize("with_function", [True, False])
-def test_order_by_descending(with_function):
+def test_order_by_descending(test_session, with_function):
     names = ["a.txt", "c.txt", "d.txt", "a.txt", "b.txt"]
-    dc = DataChain.from_values(file=[File(name=name) for name in names])
+    dc = DataChain.from_values(
+        file=[File(name=name) for name in names], session=test_session
+    )

     if with_function:
         from datachain.sql.functions import rand
@@ -1167,44 +1215,44 @@ def test_order_by_descending(with_function):
     ]


-def test_union(catalog):
-    chain1 = DataChain.from_values(value=[1, 2])
-    chain2 = DataChain.from_values(value=[3, 4])
+def test_union(test_session):
+    chain1 = DataChain.from_values(value=[1, 2], session=test_session)
+    chain2 = DataChain.from_values(value=[3, 4], session=test_session)
     chain3 = chain1 | chain2
     assert chain3.count() == 4
     assert sorted(chain3.collect("value")) == [1, 2, 3, 4]


-def test_subtract(catalog):
-    chain1 = DataChain.from_values(a=[1, 1, 2], b=["x", "y", "z"])
-    chain2 = DataChain.from_values(a=[1, 2], b=["x", "y"])
+def test_subtract(test_session):
+    chain1 = DataChain.from_values(a=[1, 1, 2], b=["x", "y", "z"], session=test_session)
+    chain2 = DataChain.from_values(a=[1, 2], b=["x", "y"], session=test_session)
     assert set(chain1.subtract(chain2, on=["a", "b"]).collect()) == {(1, "y"), (2, "z")}
     assert set(chain1.subtract(chain2, on=["b"]).collect()) == {(2, "z")}
     assert set(chain1.subtract(chain2, on=["a"]).collect()) == set()
     assert set(chain1.subtract(chain2).collect()) == {(1, "y"), (2, "z")}
     assert chain1.subtract(chain1).count() == 0

-    chain3 = DataChain.from_values(a=[1, 3], c=["foo", "bar"])
+    chain3 = DataChain.from_values(a=[1, 3], c=["foo", "bar"], session=test_session)
     assert set(chain1.subtract(chain3, on="a").collect()) == {(2, "z")}
     assert set(chain1.subtract(chain3).collect()) == {(2, "z")}


-def test_subtract_error(catalog):
-    chain1 = DataChain.from_values(a=[1, 1, 2], b=["x", "y", "z"])
-    chain2 = DataChain.from_values(a=[1, 2], b=["x", "y"])
+def test_subtract_error(test_session):
+    chain1 = DataChain.from_values(a=[1, 1, 2], b=["x", "y", "z"], session=test_session)
+    chain2 = DataChain.from_values(a=[1, 2], b=["x", "y"], session=test_session)
     with pytest.raises(DataChainParamsError):
         chain1.subtract(chain2, on=[])
     with pytest.raises(TypeError):
         chain1.subtract(chain2, on=42)

-    chain3 = DataChain.from_values(c=["foo", "bar"])
+    chain3 = DataChain.from_values(c=["foo", "bar"], session=test_session)
     with pytest.raises(DataChainParamsError):
         chain1.subtract(chain3)


-def test_column_math():
+def test_column_math(test_session):
     fib = [1, 1, 2, 3, 5, 8]
-    chain = DataChain.from_values(num=fib)
+    chain = DataChain.from_values(num=fib, session=test_session)

     ch = chain.mutate(add2=Column("num") + 2)
     assert list(ch.collect("add2")) == [x + 2 for x in fib]
@@ -1216,14 +1264,14 @@ def test_column_math():
     assert list(ch2.collect("x")) == [1 - (x / 2.0) for x in fib]


-def test_from_values_array_of_floats():
+def test_from_values_array_of_floats(test_session):
     embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
-    chain = DataChain.from_values(emd=embeddings)
+    chain = DataChain.from_values(emd=embeddings, session=test_session)

     assert list(chain.collect("emd")) == embeddings


-def test_custom_model_with_nested_lists():
+def test_custom_model_with_nested_lists(test_session):
     ds_name = "nested"

     class Trace(BaseModel):
@@ -1246,6 +1294,7 @@ class Nested(BaseModel):
             )
         ],
         nums=[1],
+        session=test_session,
     ).save(ds_name)

     assert list(DataChain(name=ds_name).collect("nested")) == [
diff --git a/tests/unit/lib/test_datachain_merge.py b/tests/unit/lib/test_datachain_merge.py
index 6efd6be1b..869ab3fd6 100644
--- a/tests/unit/lib/test_datachain_merge.py
+++ b/tests/unit/lib/test_datachain_merge.py
@@ -44,13 +44,13 @@ class TeamMember(BaseModel):
 ]


-def test_merge_objects(catalog):
-    ch1 = DataChain.from_values(emp=employees)
-    ch2 = DataChain.from_values(team=team)
+def test_merge_objects(test_session):
+    ch1 = DataChain.from_values(emp=employees, session=test_session)
+    ch2 = DataChain.from_values(team=team, session=test_session)
     ch = ch1.merge(ch2, "emp.person.name", "team.player")

-    str_default = String.default_value(catalog.warehouse.db.dialect)
-    float_default = Float.default_value(catalog.warehouse.db.dialect)
+    str_default = String.default_value(test_session.catalog.warehouse.db.dialect)
+    float_default = Float.default_value(test_session.catalog.warehouse.db.dialect)

     i = 0
     j = 0
@@ -79,15 +79,15 @@ def test_merge_objects(catalog):
     assert j == len(team)


-def test_merge_similar_objects(catalog):
+def test_merge_similar_objects(test_session):
     new_employees = [
         Employee(id=152, person=User(name="Bob", age=27)),
         Employee(id=201, person=User(name="Karl", age=18)),
         Employee(id=154, person=User(name="David", age=29)),
     ]

-    ch1 = DataChain.from_values(emp=employees)
-    ch2 = DataChain.from_values(emp=new_employees)
+    ch1 = DataChain.from_values(emp=employees, session=test_session)
+    ch2 = DataChain.from_values(emp=new_employees, session=test_session)

     rname = "qq"
     ch = ch1.merge(ch2, "emp.person.name", rname=rname)
@@ -102,17 +102,19 @@ def test_merge_similar_objects(catalog):
     assert len(list(ch_inner.collect())) == 2


-def test_merge_values(catalog):
+def test_merge_values(test_session):
     order_ids = [11, 22, 33, 44]
     order_descr = ["water", "water", "paper", "water"]

     delivery_ids = [11, 44]
     delivery_time = [24.0, 16.5]

-    float_default = Float.default_value(catalog.warehouse.db.dialect)
+    float_default = Float.default_value(test_session.catalog.warehouse.db.dialect)

-    ch1 = DataChain.from_values(id=order_ids, descr=order_descr)
-    ch2 = DataChain.from_values(id=delivery_ids, time=delivery_time)
+    ch1 = DataChain.from_values(id=order_ids, descr=order_descr, session=test_session)
+    ch2 = DataChain.from_values(
+        id=delivery_ids, time=delivery_time, session=test_session
+    )

     ch = ch1.merge(ch2, "id")
@@ -144,7 +146,7 @@ def test_merge_values(catalog):
     assert j == len(delivery_ids)


-def test_merge_multi_conditions(catalog):
+def test_merge_multi_conditions(test_session):
     order_ids = [11, 22, 33, 44]
     order_name = ["water", "water", "paper", "water"]
     order_descr = ["still water", "still water", "white paper", "sparkling water"]
@@ -153,9 +155,11 @@ def test_merge_multi_conditions(catalog):
     delivery_name = ["water", "unknown"]
     delivery_time = [24.0, 16.5]

-    ch1 = DataChain.from_values(id=order_ids, name=order_name, descr=order_descr)
+    ch1 = DataChain.from_values(
+        id=order_ids, name=order_name, descr=order_descr, session=test_session
+    )
     ch2 = DataChain.from_values(
-        id=delivery_ids, d_name=delivery_name, time=delivery_time
+        id=delivery_ids, d_name=delivery_name, time=delivery_time, session=test_session
     )

     ch = ch1.merge(ch2, ("id", "name"), ("id", "d_name"))
@@ -171,9 +175,9 @@ def test_merge_multi_conditions(catalog):
     assert success_ids == {11}


-def test_merge_errors(catalog):
-    ch1 = DataChain.from_values(emp=employees)
-    ch2 = DataChain.from_values(team=team)
+def test_merge_errors(test_session):
+    ch1 = DataChain.from_values(emp=employees, session=test_session)
+    ch2 = DataChain.from_values(team=team, session=test_session)
     with pytest.raises(SignalResolvingError):
         ch1.merge(ch2, "unknown")
@@ -190,8 +194,8 @@ def test_merge_errors(catalog):
         ch1.merge(ch2, "emp.person.name", True)


-def test_merge_with_itself(catalog):
-    ch = DataChain.from_values(emp=employees)
+def test_merge_with_itself(test_session):
+    ch = DataChain.from_values(emp=employees, session=test_session)
     merged = ch.merge(ch, "emp.id")

     count = 0
diff --git a/tests/unit/lib/test_feature_utils.py b/tests/unit/lib/test_feature_utils.py
index 47066001f..d3f22622e 100644
--- a/tests/unit/lib/test_feature_utils.py
+++ b/tests/unit/lib/test_feature_utils.py
@@ -25,11 +25,11 @@ def test_basic():
     assert vals[-1] == (fib[-1], values[-1])


-def test_e2e(catalog):
+def test_e2e(test_session):
     fib = [1, 1, 2, 3, 5, 8]
     values = ["odd" if num % 2 else "even" for num in fib]

-    dc = DataChain.from_values(fib=fib, odds=values)
+    dc = DataChain.from_values(fib=fib, odds=values, session=test_session)
     vals = list(dc.collect())

     lst1 = [item[0] for item in vals]
@@ -48,10 +48,10 @@ def test_single_value():
     assert vals == fib


-def test_single_e2e(catalog):
+def test_single_e2e(test_session):
     fib = [1, 1, 2, 3, 5, 8]

-    dc = DataChain.from_values(fib=fib)
+    dc = DataChain.from_values(fib=fib, session=test_session)
     vals = list(dc.collect())

     flattened = [item for sublist in vals for item in sublist]
diff --git a/tests/unit/test_catalog_loader.py b/tests/unit/test_catalog_loader.py
index d1893a894..8069e48a9 100644
--- a/tests/unit/test_catalog_loader.py
+++ b/tests/unit/test_catalog_loader.py
@@ -11,7 +11,6 @@
     get_warehouse,
 )
 from datachain.data_storage.sqlite import (
-    SQLiteDatabaseEngine,
     SQLiteIDGenerator,
     SQLiteMetastore,
     SQLiteWarehouse,
@@ -24,37 +23,34 @@ def __init__(self, **kwargs):
         self.kwargs = kwargs


-def test_get_id_generator():
-    db = SQLiteDatabaseEngine.from_db_file(":memory:")
-
-    id_generator = SQLiteIDGenerator(db, table_prefix="prefix")
-    assert id_generator.db == db
+def test_get_id_generator(sqlite_db):
+    id_generator = SQLiteIDGenerator(sqlite_db, table_prefix="prefix")
+    assert id_generator.db == sqlite_db
     assert id_generator._table_prefix == "prefix"

     with patch.dict(os.environ, {"DATACHAIN__ID_GENERATOR": id_generator.serialize()}):
         id_generator2 = get_id_generator()
     assert id_generator2
     assert isinstance(id_generator2, SQLiteIDGenerator)
-    assert id_generator2._db.db_file == db.db_file
+    assert id_generator2._db.db_file == sqlite_db.db_file
     assert id_generator2._table_prefix == "prefix"

     assert id_generator2.clone_params() == id_generator.clone_params()

-    with patch.dict(os.environ, {"DATACHAIN__ID_GENERATOR": db.serialize()}):
+    with patch.dict(os.environ, {"DATACHAIN__ID_GENERATOR": sqlite_db.serialize()}):
         with pytest.raises(RuntimeError, match="instance of AbstractIDGenerator"):
             get_id_generator()


-def test_get_metastore():
-    db = SQLiteDatabaseEngine.from_db_file(":memory:")
-    id_generator = SQLiteIDGenerator(db, table_prefix="prefix")
+def test_get_metastore(sqlite_db):
+    id_generator = SQLiteIDGenerator(sqlite_db, table_prefix="prefix")
     uri = StorageURI("s3://bucket")
     partial_id = 37

-    metastore = SQLiteMetastore(id_generator, uri, partial_id, db)
+    metastore = SQLiteMetastore(id_generator, uri, partial_id, sqlite_db)
     assert metastore.id_generator == id_generator
     assert metastore.uri == uri
     assert metastore.partial_id == partial_id
-    assert metastore.db == db
+    assert metastore.db == sqlite_db

     with patch.dict(os.environ, {"DATACHAIN__METASTORE": metastore.serialize()}):
         metastore2 = get_metastore(None)
@@ -67,21 +63,20 @@ def test_get_metastore():
     )
     assert metastore2.uri == uri
     assert metastore2.partial_id == partial_id
-    assert metastore2.db.db_file == db.db_file
+    assert metastore2.db.db_file == sqlite_db.db_file

     assert metastore2.clone_params() == metastore.clone_params()

-    with patch.dict(os.environ, {"DATACHAIN__METASTORE": db.serialize()}):
+    with patch.dict(os.environ, {"DATACHAIN__METASTORE": sqlite_db.serialize()}):
         with pytest.raises(RuntimeError, match="instance of AbstractMetastore"):
             get_metastore(None)


-def test_get_warehouse():
-    db = SQLiteDatabaseEngine.from_db_file(":memory:")
-    id_generator = SQLiteIDGenerator(db, table_prefix="prefix")
+def test_get_warehouse(sqlite_db):
+    id_generator = SQLiteIDGenerator(sqlite_db, table_prefix="prefix")

-    warehouse = SQLiteWarehouse(id_generator, db)
+    warehouse = SQLiteWarehouse(id_generator, sqlite_db)
     assert warehouse.id_generator == id_generator
-    assert warehouse.db == db
+    assert warehouse.db == sqlite_db

     with patch.dict(os.environ, {"DATACHAIN__WAREHOUSE": warehouse.serialize()}):
         warehouse2 = get_warehouse(None)
@@ -92,10 +87,10 @@ def test_get_warehouse():
         warehouse2.id_generator._table_prefix
         == warehouse.id_generator._table_prefix
     )
-    assert warehouse2.db.db_file == db.db_file
+    assert warehouse2.db.db_file == sqlite_db.db_file

     assert warehouse2.clone_params() == warehouse.clone_params()

-    with patch.dict(os.environ, {"DATACHAIN__WAREHOUSE": db.serialize()}):
+    with patch.dict(os.environ, {"DATACHAIN__WAREHOUSE": sqlite_db.serialize()}):
         with pytest.raises(RuntimeError, match="instance of AbstractWarehouse"):
             get_warehouse(None)
@@ -135,13 +130,12 @@ def test_get_distributed_class():
         get_distributed_class()


-def test_get_catalog():
-    db = SQLiteDatabaseEngine.from_db_file(":memory:")
-    id_generator = SQLiteIDGenerator(db, table_prefix="prefix")
+def test_get_catalog(sqlite_db):
+    id_generator = SQLiteIDGenerator(sqlite_db, table_prefix="prefix")
     uri = StorageURI("s3://bucket")
     partial_id = 73
-    metastore = SQLiteMetastore(id_generator, uri, partial_id, db)
-    warehouse = SQLiteWarehouse(id_generator, db)
+    metastore = SQLiteMetastore(id_generator, uri, partial_id, sqlite_db)
+    warehouse = SQLiteWarehouse(id_generator, sqlite_db)

     env = {
         "DATACHAIN__ID_GENERATOR": id_generator.serialize(),
         "DATACHAIN__METASTORE": metastore.serialize(),
@@ -154,7 +148,7 @@ def test_get_catalog():

     assert catalog.id_generator
     assert isinstance(catalog.id_generator, SQLiteIDGenerator)
-    assert catalog.id_generator._db.db_file == db.db_file
+    assert catalog.id_generator._db.db_file == sqlite_db.db_file
     assert catalog.id_generator._table_prefix == "prefix"
     assert catalog.id_generator.clone_params() == id_generator.clone_params()

@@ -170,7 +164,7 @@ def test_get_catalog():
     )
     assert catalog.metastore.uri == uri
     assert catalog.metastore.partial_id == partial_id
-    assert catalog.metastore.db.db_file == db.db_file
+    assert catalog.metastore.db.db_file == sqlite_db.db_file
     assert catalog.metastore.clone_params() == metastore.clone_params()

     assert catalog.warehouse
@@ -183,5 +177,5 @@ def test_get_catalog():
         catalog.warehouse.id_generator._table_prefix
         == warehouse.id_generator._table_prefix
     )
-    assert catalog.warehouse.db.db_file == db.db_file
+    assert catalog.warehouse.db.db_file == sqlite_db.db_file
     assert catalog.warehouse.clone_params() == warehouse.clone_params()
diff --git a/tests/unit/test_data_storage.py b/tests/unit/test_data_storage.py
index 73a97a28f..b79b703ed 100644
--- a/tests/unit/test_data_storage.py
+++ b/tests/unit/test_data_storage.py
@@ -36,23 +36,23 @@ def test_dir_expansion(cloud_test_catalog, version_aware, cloud_type):
     ds = create_tar_dataset(catalog, ctc.src_uri, "ds2")
     dataset = catalog.get_dataset(ds.name)

-    st = catalog.warehouse.clone()
-    q = st.dataset_rows(dataset).dir_expansion()
-    columns = (
-        "id",
-        "vtype",
-        "is_dir",
-        "source",
-        "parent",
-        "name",
-        "version",
-        "location",
-    )
-    result = [dict(zip(columns, r)) for r in st.db.execute(q)]
-    to_compare = [
-        (r["parent"], r["name"], r["vtype"], r["is_dir"], r["version"] != "")
-        for r in result
-    ]
+    with catalog.warehouse.clone() as warehouse:
+        q = warehouse.dataset_rows(dataset).dir_expansion()
+        columns = (
+            "id",
+            "vtype",
+            "is_dir",
+            "source",
+            "parent",
+            "name",
+            "version",
+            "location",
+        )
+        result = [dict(zip(columns, r)) for r in warehouse.db.execute(q)]
+        to_compare = [
+            (r["parent"], r["name"], r["vtype"], r["is_dir"], r["version"] != "")
+            for r in result
+        ]

     assert all(r["source"] == ctc.storage_uri for r in result)

     if cloud_type == "file":
diff --git a/tests/unit/test_database_engine.py b/tests/unit/test_database_engine.py
index 3f1dc9f88..998815c6d 100644
--- a/tests/unit/test_database_engine.py
+++ b/tests/unit/test_database_engine.py
@@ -10,20 +10,18 @@
 @pytest.mark.parametrize("db_file", [":memory:", "file.db"])
 def test_init_clone(db_file):
-    db = SQLiteDatabaseEngine.from_db_file(db_file)
-    assert db.db_file == db_file
+    with SQLiteDatabaseEngine.from_db_file(db_file) as db:
+        assert db.db_file == db_file

-    # Test clone
-    db2 = db.clone()
-    assert isinstance(db2, SQLiteDatabaseEngine)
-    assert db2.db_file == db_file
+        # Test clone
+        with db.clone() as db2:
+            assert isinstance(db2, SQLiteDatabaseEngine)
+            assert db2.db_file == db_file


-def test_serialize():
-    obj = SQLiteDatabaseEngine.from_db_file(":memory:")
-
+def test_serialize(sqlite_db):
     # Test serialization
-    serialized = obj.serialize()
+    serialized = sqlite_db.serialize()
     assert serialized
     serialized_pickled = base64.b64decode(serialized.encode())
     assert serialized_pickled
@@ -36,7 +34,7 @@ def test_serialize():
     obj3 = deserialize(serialized)
     assert isinstance(obj3, SQLiteDatabaseEngine)
     assert obj3.db_file == ":memory:"
-    assert obj3.clone_params() == obj.clone_params()
+    assert obj3.clone_params() == sqlite_db.clone_params()


 def test_table(sqlite_db):
diff --git a/tests/unit/test_id_generator.py b/tests/unit/test_id_generator.py
index 7c0a665e6..054f9b515 100644
--- a/tests/unit/test_id_generator.py
+++ b/tests/unit/test_id_generator.py
@@ -4,7 +4,7 @@
 from sqlalchemy import select

 from datachain.data_storage.serializer import deserialize
-from datachain.data_storage.sqlite import SQLiteDatabaseEngine, SQLiteIDGenerator
+from datachain.data_storage.sqlite import SQLiteIDGenerator


 def get_rows(id_generator):
@@ -79,11 +79,9 @@ def test_clone_params(id_generator):
     assert not clone.db.has_table("id_generator")


-def test_serialize():
-    db = SQLiteDatabaseEngine.from_db_file(":memory:")
-
-    obj = SQLiteIDGenerator(db, table_prefix="prefix")
-    assert obj.db == db
+def test_serialize(sqlite_db):
+    obj = SQLiteIDGenerator(sqlite_db, table_prefix="prefix")
+    assert obj.db == sqlite_db
     assert obj._table_prefix == "prefix"

     # Test clone
@@ -101,13 +99,13 @@ def test_serialize():
     (f, args, kwargs) = pickle.loads(serialized_pickled)  # noqa: S301
     assert str(f) == str(SQLiteIDGenerator.init_after_clone)
     assert args == []
-    assert str(kwargs["db_clone_params"]) == str(db.clone_params())
+    assert str(kwargs["db_clone_params"]) == str(sqlite_db.clone_params())
     assert kwargs["table_prefix"] == "prefix"

     # Test deserialization
     obj3 = deserialize(serialized)
     assert isinstance(obj3, SQLiteIDGenerator)
-    assert obj3.db.db_file == db.db_file
+    assert obj3.db.db_file == sqlite_db.db_file
     assert obj3._table_prefix == "prefix"
diff --git a/tests/unit/test_metastore.py b/tests/unit/test_metastore.py
index 113528850..4be751e32 100644
--- a/tests/unit/test_metastore.py
+++ b/tests/unit/test_metastore.py
@@ -3,23 +3,21 @@
 from datachain.data_storage.serializer import deserialize
 from datachain.data_storage.sqlite import (
-    SQLiteDatabaseEngine,
     SQLiteIDGenerator,
     SQLiteMetastore,
 )
 from datachain.storage import StorageURI


-def test_sqlite_metastore():
-    db = SQLiteDatabaseEngine.from_db_file(":memory:")
-    id_generator = SQLiteIDGenerator(db, table_prefix="prefix")
+def test_sqlite_metastore(sqlite_db):
+    id_generator = SQLiteIDGenerator(sqlite_db, table_prefix="prefix")
     uri = StorageURI("s3://bucket")

-    obj = SQLiteMetastore(id_generator, uri, 1, db)
+    obj = SQLiteMetastore(id_generator, uri, 1, sqlite_db)
     assert obj.id_generator == id_generator
     assert obj.uri == uri
     assert obj.partial_id == 1
-    assert obj.db == db
+    assert obj.db == sqlite_db

     # Test clone
     obj2 = obj.clone()
@@ -28,7 +26,7 @@ def test_sqlite_metastore():
     assert obj2.id_generator._table_prefix == obj.id_generator._table_prefix
     assert obj2.uri == uri
     assert obj2.partial_id == 1
-    assert obj2.db.db_file == db.db_file
+    assert obj2.db.db_file == sqlite_db.db_file
     assert obj2.clone_params() == obj.clone_params()

     # Test serialization
@@ -42,7 +40,7 @@ def test_sqlite_metastore():
     assert str(kwargs["id_generator_clone_params"]) == str(id_generator.clone_params())
     assert kwargs["uri"] == uri
     assert kwargs["partial_id"] == 1
-    assert str(kwargs["db_clone_params"]) == str(db.clone_params())
+    assert str(kwargs["db_clone_params"]) == str(sqlite_db.clone_params())

     # Test deserialization
     obj3 = deserialize(serialized)
@@ -51,5 +49,5 @@ def test_sqlite_metastore():
     assert obj3.id_generator._table_prefix == id_generator._table_prefix
     assert obj3.uri == uri
     assert obj3.partial_id == 1
-    assert obj3.db.db_file == db.db_file
+    assert obj3.db.db_file == sqlite_db.db_file
     assert obj3.clone_params() == obj.clone_params()
diff --git a/tests/unit/test_warehouse.py b/tests/unit/test_warehouse.py
index b1b27a255..441ba8a6a 100644
--- a/tests/unit/test_warehouse.py
+++ b/tests/unit/test_warehouse.py
@@ -3,26 +3,24 @@
 from datachain.data_storage.serializer import deserialize
 from datachain.data_storage.sqlite import (
-    SQLiteDatabaseEngine,
     SQLiteIDGenerator,
     SQLiteWarehouse,
 )


-def test_serialize():
-    db = SQLiteDatabaseEngine.from_db_file(":memory:")
-    id_generator = SQLiteIDGenerator(db, table_prefix="prefix")
+def test_serialize(sqlite_db):
+    id_generator = SQLiteIDGenerator(sqlite_db, table_prefix="prefix")

-    obj = SQLiteWarehouse(id_generator, db)
+    obj = SQLiteWarehouse(id_generator, sqlite_db)
     assert obj.id_generator == id_generator
-    assert obj.db == db
+    assert obj.db == sqlite_db

     # Test clone
     obj2 = obj.clone()
     assert isinstance(obj2, SQLiteWarehouse)
     assert obj2.id_generator.db.db_file == obj.id_generator.db.db_file
     assert obj2.id_generator._table_prefix == obj.id_generator._table_prefix
-    assert obj2.db.db_file == db.db_file
+    assert obj2.db.db_file == sqlite_db.db_file
     assert obj2.clone_params() == obj.clone_params()

     # Test serialization
@@ -34,14 +32,14 @@ def test_serialize():
     assert str(f) == str(SQLiteWarehouse.init_after_clone)
     assert args == []
     assert str(kwargs["id_generator_clone_params"]) == str(id_generator.clone_params())
-    assert str(kwargs["db_clone_params"]) == str(db.clone_params())
+    assert str(kwargs["db_clone_params"]) == str(sqlite_db.clone_params())

     # Test deserialization
     obj3 = deserialize(serialized)
     assert isinstance(obj3, SQLiteWarehouse)
     assert obj3.id_generator.db.db_file == id_generator.db.db_file
     assert obj3.id_generator._table_prefix == id_generator._table_prefix
-    assert obj3.db.db_file == db.db_file
+    assert obj3.db.db_file == sqlite_db.db_file
     assert obj3.clone_params() == obj.clone_params()
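
Note: the rewritten unit tests above rely on a shared `sqlite_db` pytest fixture (alongside `test_session` / `test_session_tmpfile`) expected to live in the suite's conftest, which is not part of this excerpt. A minimal sketch of what such a fixture could look like, based only on the `SQLiteDatabaseEngine.from_db_file(...)` context-manager usage already exercised in this patch (the fixture name, location, and body are assumptions, not taken from the diff):

# conftest.py -- hypothetical sketch; the actual fixture in the repository may differ
import pytest

from datachain.data_storage.sqlite import SQLiteDatabaseEngine


@pytest.fixture
def sqlite_db():
    # Open an in-memory SQLite engine and close it after the test,
    # mirroring the `with SQLiteDatabaseEngine.from_db_file(...) as db:` pattern above.
    with SQLiteDatabaseEngine.from_db_file(":memory:") as db:
        yield db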