From 1df10e9a9c3278de54d53be04923c019cc01a496 Mon Sep 17 00:00:00 2001 From: Vladimir Rudnyh Date: Tue, 10 Sep 2024 22:50:29 +0700 Subject: [PATCH 1/2] Fix calculating datasets stats size --- src/datachain/catalog/catalog.py | 5 ++--- src/datachain/data_storage/warehouse.py | 9 ++++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index e32a5e673..df7a1bf81 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -1052,7 +1052,6 @@ def create_new_dataset_version( if create_rows_table: table_name = self.warehouse.dataset_table_name(dataset.name, version) self.warehouse.create_dataset_rows_table(table_name, columns=columns) - self.update_dataset_version_with_warehouse_info(dataset, version) return dataset @@ -1390,12 +1389,12 @@ def dataset_table_export_file_names(self, name: str, version: int) -> list[str]: dataset = self.get_dataset(name) return self.warehouse.dataset_table_export_file_names(dataset, version) - def dataset_stats(self, name: str, version: int) -> DatasetStats: + def dataset_stats(self, name: str, version: Optional[int]) -> DatasetStats: """ Returns tuple with dataset stats: total number of rows and total dataset size. """ dataset = self.get_dataset(name) - dataset_version = dataset.get_version(version) + dataset_version = dataset.get_version(version or dataset.latest_version) return DatasetStats( num_objects=dataset_version.num_objects, size=dataset_version.size, diff --git a/src/datachain/data_storage/warehouse.py b/src/datachain/data_storage/warehouse.py index 22390deea..222a1af80 100644 --- a/src/datachain/data_storage/warehouse.py +++ b/src/datachain/data_storage/warehouse.py @@ -401,13 +401,12 @@ def dataset_stats( expressions: tuple[_ColumnsClauseArgument[Any], ...] = ( sa.func.count(table.c.sys__id), ) - if "file__size" in table.columns: - expressions = (*expressions, sa.func.sum(table.c.file__size)) - elif "size" in table.columns: - expressions = (*expressions, sa.func.sum(table.c.size)) + for c in table.columns: + if c.name.endswith("file__size"): + expressions = (*expressions, sa.func.sum(c)) query = select(*expressions) ((nrows, *rest),) = self.db.execute(query) - return nrows, rest[0] if rest else None + return nrows, sum(rest) if rest else 0 def prepare_entries( self, uri: str, entries: Iterable[Entry] From 2c8be936a3dc58fa578e570365476e4f87a2280d Mon Sep 17 00:00:00 2001 From: Vladimir Rudnyh Date: Wed, 11 Sep 2024 22:39:46 +0700 Subject: [PATCH 2/2] Fix tests --- src/datachain/catalog/catalog.py | 1 + src/datachain/data_storage/warehouse.py | 10 ++++++---- tests/func/test_catalog.py | 25 +++++++++++++++++++++++++ 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index df7a1bf81..d015587bc 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -1052,6 +1052,7 @@ def create_new_dataset_version( if create_rows_table: table_name = self.warehouse.dataset_table_name(dataset.name, version) self.warehouse.create_dataset_rows_table(table_name, columns=columns) + self.update_dataset_version_with_warehouse_info(dataset, version) return dataset diff --git a/src/datachain/data_storage/warehouse.py b/src/datachain/data_storage/warehouse.py index 222a1af80..c210a621e 100644 --- a/src/datachain/data_storage/warehouse.py +++ b/src/datachain/data_storage/warehouse.py @@ -401,12 +401,14 @@ def dataset_stats( expressions: tuple[_ColumnsClauseArgument[Any], ...] = ( sa.func.count(table.c.sys__id), ) - for c in table.columns: - if c.name.endswith("file__size"): - expressions = (*expressions, sa.func.sum(c)) + size_columns = [ + c for c in table.columns if c.name == "size" or c.name.endswith("__size") + ] + if size_columns: + expressions = (*expressions, sa.func.sum(sum(size_columns))) query = select(*expressions) ((nrows, *rest),) = self.db.execute(query) - return nrows, sum(rest) if rest else 0 + return nrows, rest[0] if rest else 0 def prepare_entries( self, uri: str, entries: Iterable[Entry] diff --git a/tests/func/test_catalog.py b/tests/func/test_catalog.py index c36f96665..87f5eed7e 100644 --- a/tests/func/test_catalog.py +++ b/tests/func/test_catalog.py @@ -8,6 +8,7 @@ import yaml from fsspec.implementations.local import LocalFileSystem +from datachain import DataChain, File from datachain.catalog import parse_edatachain_file from datachain.cli import garbage_collect from datachain.error import ( @@ -940,6 +941,30 @@ def test_query_save_size(cloud_test_catalog, mock_popen_dataset_created): assert dataset_version.size == 15 +def test_dataset_stats(test_session): + ids = [1, 2, 3] + values = tuple(zip(["a", "b", "c"], [1, 2, 3])) + + ds1 = DataChain.from_values( + ids=ids, + file=[File(path=name, size=size) for name, size in values], + session=test_session, + ).save() + dataset_version1 = test_session.catalog.get_dataset(ds1.name).get_version(1) + assert dataset_version1.num_objects == 3 + assert dataset_version1.size == 6 + + ds2 = DataChain.from_values( + ids=ids, + file1=[File(path=name, size=size) for name, size in values], + file2=[File(path=name, size=size * 2) for name, size in values], + session=test_session, + ).save() + dataset_version2 = test_session.catalog.get_dataset(ds2.name).get_version(1) + assert dataset_version2.num_objects == 3 + assert dataset_version2.size == 18 + + def test_query_fail_to_compile(cloud_test_catalog): catalog = cloud_test_catalog.catalog