#2456 switch to lazy loading for incremental data sets
bludau-peter committed Apr 6, 2023
1 parent a2746f2 commit 78c483a
Showing 3 changed files with 18 additions and 21 deletions.
2 changes: 1 addition & 1 deletion docs/source/data/kedro_io.md
@@ -485,7 +485,7 @@ The checkpoint file is only created _after_ [the partitioned dataset is explicit
#### Incremental dataset load

Loading `IncrementalDataSet` works similarly to [`PartitionedDataSet`](#partitioned-dataset-load) with several exceptions:
-1. `IncrementalDataSet` loads the data _eagerly_, so the values in the returned dictionary represent the actual data stored in the corresponding partition, rather than a pointer to the load function. `IncrementalDataSet` considers a partition relevant for processing if its ID satisfies the comparison function, given the checkpoint value.
+1. `IncrementalDataSet` considers a partition relevant for processing if its ID satisfies the comparison function, given the checkpoint value. A load function is returned for all relevant partitions.
2. `IncrementalDataSet` _does not_ raise a `DataSetError` if load finds no partitions to return - an empty dictionary is returned instead. An empty list of available partitions is part of a normal workflow for `IncrementalDataSet`.

#### Incremental dataset save
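The documentation change above captures the new contract: `load()` on an `IncrementalDataSet` now returns a dictionary mapping each relevant partition ID to a load function rather than to already-materialised data. As a minimal sketch (not part of this commit; the node name and the pandas `concat` post-processing are illustrative assumptions), a consuming node might iterate that dictionary like this:

```python
from typing import Any, Callable, Dict

import pandas as pd


def concat_new_partitions(partitions: Dict[str, Callable[[], Any]]) -> pd.DataFrame:
    """Concatenate whichever partitions the IncrementalDataSet deems new."""
    frames = []
    for partition_id, load_func in sorted(partitions.items()):
        # The file behind `partition_id` is only read from storage here.
        frames.append(load_func())
    # An empty dictionary is a normal outcome for IncrementalDataSet,
    # so handle it instead of expecting a DataSetError.
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)
```

Once a run has processed these partitions, calling `confirm()` on the dataset updates the checkpoint, so the same partitions are not returned on the next load.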
28 changes: 11 additions & 17 deletions kedro/io/partitioned_dataset.py
@@ -276,6 +276,13 @@ def _path_to_partition(self, path: str) -> str:
path = path[: -len(self._filename_suffix)]
return path

+    def _get_valid_loaded_partitions(
+        self, partitions: Dict[str, Callable[[], Any]]
+    ) -> Dict[str, Callable[[], Any]]:
+        if not partitions:
+            raise DataSetError(f"No partitions found in '{self._path}'")
+        return partitions
+
def _load(self) -> Dict[str, Callable[[], Any]]:
partitions = {}

@@ -287,10 +294,7 @@ def _load(self) -> Dict[str, Callable[[], Any]]:
partition_id = self._path_to_partition(partition)
partitions[partition_id] = dataset.load

-        if not partitions:
-            raise DataSetError(f"No partitions found in '{self._path}'")
-
-        return partitions
+        return self._get_valid_loaded_partitions(partitions)

def _save(self, data: Dict[str, Any]) -> None:
if self._overwrite and self._filesystem.exists(self._normalized_path):
@@ -382,7 +386,6 @@ def __init__(
load_args: Dict[str, Any] = None,
fs_args: Dict[str, Any] = None,
):
-
"""Creates a new instance of ``IncrementalDataSet``.
Args:
@@ -521,18 +524,9 @@ def _read_checkpoint(self) -> Union[str, None]:
except DataSetError:
return None

-    def _load(self) -> Dict[str, Callable[[], Any]]:
-        partitions = {}
-
-        for partition in self._list_partitions():
-            partition_id = self._path_to_partition(partition)
-            kwargs = deepcopy(self._dataset_config)
-            # join the protocol back since PySpark may rely on it
-            kwargs[self._filepath_arg] = self._join_protocol(partition)
-            partitions[partition_id] = self._dataset_type(  # type: ignore
-                **kwargs
-            ).load()
-
+    def _get_valid_loaded_partitions(
+        self, partitions: Dict[str, Callable[[], Any]]
+    ) -> Dict[str, Callable[[], Any]]:
        return partitions

def confirm(self) -> None:
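The refactor above removes the eager `_load` override from `IncrementalDataSet` and introduces a small validation hook instead: `PartitionedDataSet._load` builds the dictionary of load functions and passes it through `_get_valid_loaded_partitions`, which the base class uses to reject an empty result and the subclass overrides to accept one. A stripped-down sketch of that pattern follows; the two hook methods mirror the diff, while the surrounding class names and scaffolding are simplified and hypothetical:

```python
from typing import Any, Callable, Dict


class DataSetError(Exception):
    """Stand-in for kedro.io.core.DataSetError in this sketch."""


class PartitionedSketch:
    """Minimal stand-in for PartitionedDataSet's loading behaviour."""

    def __init__(self, path: str, partitions: Dict[str, Callable[[], Any]]):
        self._path = path
        self._partitions = partitions  # stand-in for listing the filesystem

    def _get_valid_loaded_partitions(
        self, partitions: Dict[str, Callable[[], Any]]
    ) -> Dict[str, Callable[[], Any]]:
        # Base behaviour: finding no partitions is an error.
        if not partitions:
            raise DataSetError(f"No partitions found in '{self._path}'")
        return partitions

    def _load(self) -> Dict[str, Callable[[], Any]]:
        # One shared _load: build the lazy dict, then delegate validation.
        return self._get_valid_loaded_partitions(dict(self._partitions))


class IncrementalSketch(PartitionedSketch):
    def _get_valid_loaded_partitions(
        self, partitions: Dict[str, Callable[[], Any]]
    ) -> Dict[str, Callable[[], Any]]:
        # Incremental behaviour: an empty result is part of a normal workflow.
        return partitions
```

Because the subclass no longer defines its own `_load`, it inherits the lazy dictionary of load functions, which is exactly what the test changes below exercise by calling `load_func()`.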
9 changes: 6 additions & 3 deletions tests/io/test_incremental_dataset.py
@@ -68,7 +68,8 @@ def test_load_and_confirm(self, local_csvs, partitioned_data_pandas):
pds = IncrementalDataSet(str(local_csvs), DATASET)
loaded = pds.load()
assert loaded.keys() == partitioned_data_pandas.keys()
-        for partition_id, data in loaded.items():
+        for partition_id, load_func in loaded.items():
+            data = load_func()
assert_frame_equal(data, partitioned_data_pandas[partition_id])

checkpoint_path = local_csvs / pds.DEFAULT_CHECKPOINT_FILENAME
@@ -97,7 +98,8 @@ def test_save(self, local_csvs):
pds.save({new_partition_key: df})
assert new_partition_path.exists()
loaded = pds.load()
-        assert_frame_equal(loaded[new_partition_key], df)
+        reloaded_data = loaded[new_partition_key]()
+        assert_frame_equal(reloaded_data, df)

@pytest.mark.parametrize(
"filename_suffix,expected_partitions",
@@ -381,7 +383,8 @@ def test_load_and_confirm(self, mocked_csvs_in_s3, partitioned_data_pandas):
assert pds._checkpoint._protocol == "s3"
loaded = pds.load()
assert loaded.keys() == partitioned_data_pandas.keys()
-        for partition_id, data in loaded.items():
+        for partition_id, load_func in loaded.items():
+            data = load_func()
assert_frame_equal(data, partitioned_data_pandas[partition_id])

assert not pds._checkpoint.exists()
