Skip to content

Commit

Permalink
feat: Add method to combine multiple result objects
Browse files Browse the repository at this point in the history
This will integrate their info dicts into the dataframe
  • Loading branch information
tomjholland committed Jan 2, 2025
1 parent 6bc9649 commit 744f496
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
35 changes: 31 additions & 4 deletions pyprobe/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ def __init__(self, base_dataframe: pl.LazyFrame | pl.DataFrame) -> None:
self.cache: Dict[str, pl.Series] = {}
self._cached_dataframe = None
self._base_dataframe = base_dataframe
if isinstance(base_dataframe, pl.LazyFrame):
self.base_dataframe = base_dataframe
else:
self.base_dataframe = base_dataframe
if isinstance(base_dataframe, pl.DataFrame):
self.cached_dataframe = base_dataframe

@property
Expand Down Expand Up @@ -575,3 +572,33 @@ def build(
data.append(step_data)
data = pl.concat(data)
return cls(base_dataframe=data, info=info)


def combine_results(
    results: List[Result],
    concat_method: str = "diagonal",
) -> Result:
    """Combine multiple Result objects into a single Result object.

    Use this when the Result objects carry different entries in their info
    dictionaries. Every info entry of each input is written into that input's
    dataframe as a literal column (named after the info key) before the
    dataframes are concatenated, so the per-result metadata ends up as columns
    of the combined data.

    Note:
        The input Result objects are modified in place: each one has its
        ``live_dataframe`` reassigned with the added info columns. The list
        must contain at least one Result, because the first element is used
        as the template for the returned object.

    Args:
        results (List[Result]): The Result objects to combine.
        concat_method (str):
            The method to use for concatenation. Default is 'diagonal'. See the
            polars.concat method documentation for more information.

    Returns:
        Result: A new result object with the combined data and an empty info
        dictionary.
    """
    # Turn each result's info entries into constant-valued columns.
    for item in results:
        info_columns = []
        for name, value in item.info.items():
            info_columns.append(pl.lit(value).alias(name))
        item.live_dataframe = item.live_dataframe.with_columns(info_columns)

    # Start from a clean copy of the first result, then append every input
    # (including the first) to it.
    merged = results[0].clean_copy()
    # The info now lives in the dataframe columns, so the dict is reset.
    merged.info = {}
    merged.extend(results[:], concat_method=concat_method)
    return merged
25 changes: 24 additions & 1 deletion tests/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import polars.testing as pl_testing
import pytest

from pyprobe.result import PolarsColumnCache, Result
from pyprobe.result import PolarsColumnCache, Result, combine_results


def test_PolarsColumnCache_lazyframe():
Expand Down Expand Up @@ -466,3 +466,26 @@ def test_clean_copy(reduced_result_fixture):
assert isinstance(clean_result, Result)
assert isinstance(clean_result.base_dataframe, pl.LazyFrame)
pl_testing.assert_frame_equal(clean_result.data, new_df)


def test_combine_results():
    """Test the combine results method.

    Two results with identical columns but different ``"test index"`` info
    values are combined; the output must hold the rows of both, with the info
    value repeated as a column for the rows of the result it came from.
    """
    frames = [
        pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}),
        pl.DataFrame({"a": [7, 8, 9], "b": [10, 11, 12]}),
    ]
    test_indices = [1.0, 2.0]
    inputs = [
        Result(base_dataframe=frame, info={"test index": index})
        for frame, index in zip(frames, test_indices)
    ]

    combined_result = combine_results(inputs)

    expected_columns = {
        "a": [1, 2, 3, 7, 8, 9],
        "b": [4, 5, 6, 10, 11, 12],
        "test index": [1.0, 1.0, 1.0, 2.0, 2.0, 2.0],
    }
    expected_data = pl.DataFrame(expected_columns)
    # Column order is not part of the contract being tested.
    pl_testing.assert_frame_equal(
        combined_result.data, expected_data, check_column_order=False
    )

0 comments on commit 744f496

Please sign in to comment.