Skip to content

Commit

Permalink
feat(result): add cache_columns and data_with_columns method to result
Browse files Browse the repository at this point in the history
- cache_columns allows the user to specify which columns to put in the cache
- data_with_columns returns a dataframe containing only the specified columns
This commit also adds ruff private member access checks
  • Loading branch information
tomjholland committed Feb 23, 2025
1 parent 41ffbaf commit dc7e73d
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyprobe/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def validate_required_columns(self) -> "AnalysisValidator":
Raises:
ValueError: If any of the required columns are missing.
"""
self.input_data._polars_cache.collect_columns(*self.required_columns)
self.input_data.cache_columns(*self.required_columns)
return self

@property
Expand Down
5 changes: 2 additions & 3 deletions pyprobe/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,13 @@ def _retrieve_relevant_columns(

except ValueError:
continue
if quantity in result_obj._polars_cache.quantities:
if quantity in result_obj.quantities:
relevant_columns.append(arg)
if len(relevant_columns) == 0:
raise ValueError(
f"None of the columns in {all_args} are present in the Result object."
)
result_obj._polars_cache.collect_columns(*relevant_columns)
return result_obj._get_data_subset(*relevant_columns)
return result_obj.data_with_columns(*relevant_columns)


try:
Expand Down
27 changes: 22 additions & 5 deletions pyprobe/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,11 @@ def cached_dataframe(self) -> pl.DataFrame:
"""Return the cached dataframe as a Polars DataFrame."""
if self._cached_dataframe is None:
self._cached_dataframe = pl.DataFrame(self.cache)
return pl.DataFrame(self.cache)
elif set(self._cached_dataframe.collect_schema().names()) != set(
self.cache.keys()
):
self._cached_dataframe = pl.DataFrame(self.cache)
return self._cached_dataframe

@cached_dataframe.setter
def cached_dataframe(self, value: pl.DataFrame) -> None:
Expand Down Expand Up @@ -202,6 +206,19 @@ def live_dataframe(self, value: pl.DataFrame) -> None:
"""Set the data as a polars DataFrame."""
self._polars_cache.base_dataframe = value

def cache_columns(self, *columns: str) -> None:
    """Collect columns from the base dataframe and add them to the cache.

    If no columns are provided, all columns will be cached.

    Args:
        *columns (str): The columns to cache.
    """
    # An empty selection means "cache everything": fall back to this
    # object's full column list. ``or`` short-circuits, so ``column_list``
    # is only evaluated when no explicit columns were requested.
    self._polars_cache.collect_columns(*(columns or self.column_list))

@property
def data(self) -> pl.DataFrame:
"""Return the data as a polars DataFrame.
Expand Down Expand Up @@ -261,11 +278,11 @@ def hvplot(self, *args: Any, **kwargs: Any) -> None:
":code:`hvplot.extension('plotly')`.\n\n" + (hvplot.__doc__ or "")
)

def _get_data_subset(self, *column_names: str) -> pl.DataFrame:
def data_with_columns(self, *column_names: str) -> pl.DataFrame:
"""Return a subset of the data with the specified columns.
Args:
*column_names: The columns to include in the new result object.
*column_names: The columns to include in the returned dataframe.
Returns:
A subset of the data with the specified columns.
Expand All @@ -283,7 +300,7 @@ def __getitem__(self, *column_names: str) -> "Result":
Result: A new result object with the specified columns.
"""
return Result(
base_dataframe=self._get_data_subset(*column_names), info=self.info
base_dataframe=self.data_with_columns(*column_names), info=self.info
)

def get(
Expand All @@ -302,7 +319,7 @@ def get(
ValueError: If no column names are provided.
ValueError: If a column name is not in the data.
"""
array = self._get_data_subset(*column_names).to_numpy()
array = self.data_with_columns(*column_names).to_numpy()
if len(column_names) == 0:
error_msg = "At least one column name must be provided."
logger.error(error_msg)
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ select = [
"A", # flake8-builtins: Check for Python builtins being used as variables or parameters
"D", # flake8-docstrings: Check docstrings
"I", # isort: Check and enforce import ordering
"TID", # flake8-tidy-imports
"SLF001", # private-member-access: Checks for accesses on "private" class members.
]

[tool.ruff.lint.pydocstyle]
Expand Down
50 changes: 50 additions & 0 deletions tests/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def test_collect_columns():
assert cache.cache["a"].to_list() == expected_a.to_list()
pl_testing.assert_frame_equal(cache.cached_dataframe, lf.select("a").collect())

# Test making a second collection
cache.collect_columns("b")
pl.testing.assert_frame_equal(cache.cached_dataframe, lf.select("a", "b").collect())

# Test multiple column collection
cache = PolarsColumnCache(lf)
cache.collect_columns("a", "b")
Expand All @@ -69,6 +73,26 @@ def test_collect_columns():
assert cache.cache["Current [mA]"].to_list() == expected_current.to_list()


def test_cached_dataframe():
    """Test the cached_dataframe property.

    The stored ``_cached_dataframe`` should only be rebuilt when the public
    ``cached_dataframe`` property is read after the set of cached columns has
    changed.
    """
    lf = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
    cache = PolarsColumnCache(lf)
    # Nothing is stored until the property is read for the first time.
    assert cache._cached_dataframe is None
    assert cache.cached_dataframe.is_empty()

    cache.collect_columns("a")
    # Collecting a column does not refresh the stored frame by itself: it is
    # still the empty frame built by the property read above.
    assert cache._cached_dataframe.is_empty()
    pl_testing.assert_frame_equal(cache.cached_dataframe, lf.select("a").collect())
    pl_testing.assert_frame_equal(cache._cached_dataframe, lf.select("a").collect())

    cache.collect_columns("b")
    # Same again: the stored frame stays stale until the property is read.
    pl_testing.assert_frame_equal(cache._cached_dataframe, lf.select("a").collect())
    pl_testing.assert_frame_equal(
        cache.cached_dataframe, lf.select("a", "b").collect()
    )
    pl_testing.assert_frame_equal(
        cache._cached_dataframe, lf.select("a", "b").collect()
    )


def test_live_dataframe():
"""Test the live_dataframe property."""
lf = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
Expand Down Expand Up @@ -115,6 +139,32 @@ def test_init(Result_fixture):
assert isinstance(Result_fixture.info, dict)


def test_cache_columns():
    """Test the cache_columns method."""
    lf = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})

    # A single requested column ends up in the cache.
    result_object = Result(base_dataframe=lf, info={})
    result_object.cache_columns("a")
    pl_testing.assert_frame_equal(
        result_object._polars_cache.cached_dataframe, lf.select("a").collect()
    )

    # Several requested columns are cached together; order is not asserted.
    result_object = Result(base_dataframe=lf, info={})
    result_object.cache_columns("a", "b")
    pl_testing.assert_frame_equal(
        result_object._polars_cache.cached_dataframe,
        lf.select("a", "b").collect(),
        check_column_order=False,
    )

    # Calling with no arguments caches every column of the base dataframe.
    result_object = Result(base_dataframe=lf, info={})
    result_object.cache_columns()
    pl_testing.assert_frame_equal(
        result_object._polars_cache.cached_dataframe,
        lf.collect(),
        check_column_order=False,
    )


def test_get(Result_fixture):
"""Test the get method."""
current = Result_fixture.get("Current [A]")
Expand Down

0 comments on commit dc7e73d

Please sign in to comment.