Skip to content

Commit

Permalink
add get_dataset_version method to catalog and metastore
Browse files Browse the repository at this point in the history
  • Loading branch information
mattseddon committed Nov 28, 2024
1 parent 3bd22ad commit 4c738ba
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,9 @@ def register_dataset(
def get_dataset(self, name: str) -> DatasetRecord:
    """
    Look up a dataset by name through the metastore.

    Args:
        name: Name of the dataset to fetch.

    Returns:
        Whatever ``self.metastore.get_dataset`` returns for ``name``.
    """
    record = self.metastore.get_dataset(name)
    return record

def get_dataset_version(self, name: str, version: int) -> DatasetRecord:
    """
    Look up one specific version of a dataset through the metastore.

    Args:
        name: Name of the dataset to fetch.
        version: Version number of the dataset to fetch.

    Returns:
        Whatever ``self.metastore.get_dataset_version`` returns for the
        given name/version pair.
    """
    record = self.metastore.get_dataset_version(name, version)
    return record

def get_remote_dataset(self, name: str) -> DatasetRecord:
studio_client = StudioClient()

Expand Down
22 changes: 20 additions & 2 deletions src/datachain/data_storage/metastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from datachain.error import (
DatasetNotFoundError,
DatasetVersionNotFoundError,
TableMissingError,
)
from datachain.job import Job
Expand Down Expand Up @@ -181,6 +182,10 @@ def list_datasets_by_prefix(self, prefix: str) -> Iterator["DatasetListRecord"]:
def get_dataset(self, name: str) -> DatasetRecord:
"""Gets a single dataset by name."""

@abstractmethod
def get_dataset_version(self, name: str, version: int) -> DatasetRecord:
    """
    Gets a single dataset with a single version by name and version.

    Args:
        name: Name of the dataset to look up.
        version: Version number of the dataset to look up.

    Returns:
        The dataset record for the requested name/version pair.

    Note:
        The database-backed implementation in this module raises
        DatasetVersionNotFoundError when nothing matches the pair.
    """

@abstractmethod
def update_dataset_status(
self,
Expand Down Expand Up @@ -727,9 +732,9 @@ def _get_dataset_query(
j = d.join(dv, d.c.id == dv.c.dataset_id, isouter=isouter)
return query.select_from(j)

def _base_dataset_query(self):
def _base_dataset_query(self, isouter=True):
return self._get_dataset_query(
self._dataset_fields, self._dataset_version_fields
self._dataset_fields, self._dataset_version_fields, isouter=isouter
)

def _base_list_datasets_query(self):
Expand Down Expand Up @@ -760,6 +765,19 @@ def get_dataset(self, name: str, conn=None) -> DatasetRecord:
raise DatasetNotFoundError(f"Dataset {name} not found.")
return ds

def get_dataset_version(self, name: str, version: int, conn=None) -> DatasetRecord:
    """
    Gets a single dataset with a single version by name and version.

    Args:
        name: Name of the dataset to look up.
        version: Version number that must exist for that dataset.
        conn: Optional connection forwarded to ``self.db.execute``.

    Returns:
        The dataset parsed from the rows matching both name and version.

    Raises:
        DatasetVersionNotFoundError: If parsing the query result yields no
            dataset, i.e. nothing matched the name/version pair.
    """
    d = self._datasets
    dv = self._datasets_versions
    # Inner join (isouter=False): a dataset without a matching version row
    # produces no rows at all instead of a row with empty version columns.
    query = self._base_dataset_query(isouter=False)
    query = query.where((d.c.name == name) & (dv.c.version == version))  # type: ignore [attr-defined]
    ds = self._parse_dataset(self.db.execute(query, conn=conn))
    if not ds:
        raise DatasetVersionNotFoundError(
            f"Dataset {name} version {version} not found."
        )
    return ds

def remove_dataset_version(
self, dataset: DatasetRecord, version: int
) -> DatasetRecord:
Expand Down
15 changes: 15 additions & 0 deletions tests/func/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,21 @@ def test_get_dataset(cloud_test_catalog, dogs_dataset):
catalog.get_dataset("wrong name")


def test_get_dataset_version(cloud_test_catalog, dogs_dataset):
    """Check that get_dataset_version returns only the requested version or raises."""
    catalog = cloud_test_catalog.catalog

    dataset = catalog.get_dataset_version(dogs_dataset.name, 1)
    assert dataset.name == dogs_dataset.name
    # The lookup filters on one version, so exactly one must come back;
    # a bare truthiness check would also pass if every version were returned.
    assert len(dataset.versions) == 1
    assert dataset.versions[0].version == 1

    # Unknown dataset name
    with pytest.raises(DatasetVersionNotFoundError):
        catalog.get_dataset_version("wrong name", 1)

    # Known dataset, version that does not exist
    with pytest.raises(DatasetVersionNotFoundError):
        catalog.get_dataset_version(dogs_dataset.name, 10000000000000000)


# Returns None if the table does not exist
def get_table_row_count(db, table_name):
if not db.has_table(table_name):
Expand Down

0 comments on commit 4c738ba

Please sign in to comment.