From a84687959c6d04fc100abb4a4ac0597cfec8fd06 Mon Sep 17 00:00:00 2001 From: Peter Ke Date: Tue, 15 Oct 2024 11:14:39 -0700 Subject: [PATCH] update to return pytransaction --- python/deltalake/_internal.pyi | 11 ++++++- python/deltalake/table.py | 15 ++------- python/src/lib.rs | 57 +++++++++++++++++++++++++++------- python/tests/test_writer.py | 23 +++++++------- 4 files changed, 70 insertions(+), 36 deletions(-) diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index 41e0bb5196..66b5dc8f8f 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -221,7 +221,7 @@ class RawDeltaTable: starting_timestamp: Optional[str] = None, ending_timestamp: Optional[str] = None, ) -> pyarrow.RecordBatchReader: ... - def transaction_versions(self) -> Dict[str, str]: ... + def transaction_versions(self) -> Dict[str, Transaction]: ... def rust_core_version() -> str: ... def write_new_deltalake( @@ -907,3 +907,12 @@ FilterConjunctionType = List[FilterLiteralType] FilterDNFType = List[FilterConjunctionType] FilterType = Union[FilterConjunctionType, FilterDNFType] PartitionFilterType = List[Tuple[str, str, Union[str, List[str]]]] + +class Transaction: + app_id: str + version: int + last_updated: Optional[int] + + def __init__( + self, app_id: str, version: int, last_updated: Optional[int] = None + ) -> None: ... diff --git a/python/deltalake/table.py b/python/deltalake/table.py index f3c9e3bf6f..e54a1c3f8c 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -43,6 +43,7 @@ PyMergeBuilder, RawDeltaTable, TableFeatures, + Transaction, ) from deltalake._internal import create_deltalake as _create_deltalake from deltalake._util import encode_partition_value @@ -150,13 +151,6 @@ def __init__( self.cleanup_expired_logs = cleanup_expired_logs -@dataclass -class Transaction: - app_id: str - version: int - last_updated: Optional[int] = None - - @dataclass(init=True) class CommitProperties: """The commit properties. Controls the behaviour of the commit.""" @@ -1426,11 +1420,8 @@ def repair( ) return json.loads(metrics) - def transaction_versions(self) -> Dict[str, Dict[str, Any]]: - return { - app_id: json.loads(transaction) - for app_id, transaction in self._table.transaction_versions().items() - } + def transaction_versions(self) -> Dict[str, Transaction]: + return self._table.transaction_versions() class TableMerger: diff --git a/python/src/lib.rs b/python/src/lib.rs index d48fab5a5b..005076c719 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1233,16 +1233,11 @@ impl RawDeltaTable { Ok(serde_json::to_string(&metrics).unwrap()) } - pub fn transaction_versions(&self) -> HashMap { + pub fn transaction_versions(&self) -> HashMap { self._table .get_app_transaction_version() - .iter() - .map(|(app_id, transaction)| { - ( - app_id.to_owned(), - serde_json::to_string(transaction).unwrap(), - ) - }) + .into_iter() + .map(|(app_id, transaction)| (app_id, PyTransaction::from(transaction))) .collect() } } @@ -1674,11 +1669,48 @@ pub struct PyPostCommitHookProperties { cleanup_expired_logs: Option, } -#[derive(FromPyObject)] +#[derive(Clone)] +#[pyclass(name = "Transaction", module = "deltalake._internal")] pub struct PyTransaction { - app_id: String, - version: i64, - last_updated: Option, + #[pyo3(get)] + pub app_id: String, + #[pyo3(get)] + pub version: i64, + #[pyo3(get)] + pub last_updated: Option, +} + +#[pymethods] +impl PyTransaction { + #[new] + #[pyo3(signature = (app_id, version, last_updated = None))] + fn new(app_id: String, version: i64, last_updated: Option) -> Self { + Self { + app_id, + version, + last_updated, + } + } + + fn __repr__(&self) -> String { + format!( + "Transaction(app_id={}, version={}, last_updated={})", + self.app_id, + self.version, + self.last_updated + .map_or("None".to_owned(), |n| n.to_string()) + ) + } +} + +impl From for PyTransaction { + fn from(value: Transaction) -> Self { + PyTransaction { + app_id: value.app_id, + version: value.version, + last_updated: value.last_updated, + } + } } impl From<&PyTransaction> for Transaction { @@ -2039,6 +2071,7 @@ fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; // There are issues with submodules, so we will expose them flat for now // See also: https://github.com/PyO3/pyo3/issues/759 m.add_class::()?; diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 2ee1770e62..c43e5d1136 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -1996,11 +1996,11 @@ def test_write_timestamp(tmp_path: pathlib.Path): def test_write_transactions(tmp_path: pathlib.Path, sample_data: pa.Table): - transactions = [ + expected_transactions = [ Transaction(app_id="app_1", version=1), Transaction(app_id="app_2", version=2, last_updated=123456), ] - commit_properties = CommitProperties(app_transactions=transactions) + commit_properties = CommitProperties(app_transactions=expected_transactions) write_deltalake( table_or_uri=tmp_path, data=sample_data, @@ -2013,12 +2013,13 @@ def test_write_transactions(tmp_path: pathlib.Path, sample_data: pa.Table): transactions = delta_table.transaction_versions() assert len(transactions) == 2 - assert transactions["app_1"] == { - "appId": "app_1", - "version": 1, - } - assert transactions["app_2"] == { - "appId": "app_2", - "version": 2, - "lastUpdated": 123456, - } + + transaction_1 = transactions["app_1"] + assert transaction_1.app_id == "app_1" + assert transaction_1.version == 1 + assert transaction_1.last_updated is None + + transaction_2 = transactions["app_2"] + assert transaction_2.app_id == "app_2" + assert transaction_2.version == 2 + assert transaction_2.last_updated == 123456