From 16dd70cda8ea9dd835d5e99575492bd0d8b75366 Mon Sep 17 00:00:00 2001 From: Matthieu Doutreligne Date: Thu, 8 Dec 2022 09:09:13 +0100 Subject: [PATCH 1/7] plot_age_pyramid: add the possibility to have one datetime_ref that differs for each patient: eg. the date of their inclusion in the study. --- eds_scikit/plot/data_quality.py | 20 +++++++++++++------- tests/test_age_pyramid.py | 25 +++++++++++++++---------- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/eds_scikit/plot/data_quality.py b/eds_scikit/plot/data_quality.py index ac8b7268..3bb9e720 100644 --- a/eds_scikit/plot/data_quality.py +++ b/eds_scikit/plot/data_quality.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import altair as alt import numpy as np @@ -12,7 +12,7 @@ def plot_age_pyramid( person: DataFrame, - datetime_ref: datetime = None, + datetime_ref: Union[datetime, str] = None, savefig: bool = False, filename: Optional[str] = None, ) -> Tuple[alt.Chart, Series]: @@ -26,8 +26,10 @@ def plot_age_pyramid( - `person_id`, dtype : any - `gender_source_value`, dtype : str, {'m', 'f'} - datetime_ref : datetime, + datetime_ref : Union[datetime, str], The reference date to compute population age from. + If a string, it searches for a column with the same name in the person table: each patient has his own datetime reference. + If a datetime, the reference datetime is the same for all patients. If set to None, datetime.today() will be used instead. savefig : bool, @@ -55,11 +57,15 @@ def plot_age_pyramid( person_ = person.copy() - if datetime_ref: - today = pd.to_datetime(datetime_ref) + if type(datetime_ref) == datetime: + datetime_ref_ = pd.to_datetime(datetime_ref) + elif type(datetime_ref) == str: + if datetime_ref not in person_.columns: + raise ValueError(f"{datetime_ref} should be either a column of the dataframe, or a datetime.") + datetime_ref_ = person_[datetime_ref] else: - today = datetime.today() - person_["age"] = (today - person_["birth_datetime"]).dt.total_seconds() + datetime_ref_ = datetime.today() + person_["age"] = (datetime_ref_ - person_["birth_datetime"]).dt.total_seconds() person_["age"] /= 365 * 24 * 3600 bins = np.arange(0, 100, 10) diff --git a/tests/test_age_pyramid.py b/tests/test_age_pyramid.py index 6edb233b..4850316f 100644 --- a/tests/test_age_pyramid.py +++ b/tests/test_age_pyramid.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timedelta from pathlib import Path import altair as alt @@ -8,31 +8,36 @@ from eds_scikit.datasets.synthetic.person import load_person from eds_scikit.plot.data_quality import plot_age_pyramid +import pandas as pd +import numpy as np data = load_person() +person_with_inclusion_date = data.person.copy() +person_with_inclusion_date["inclusion_datetime"] = person_with_inclusion_date["birth_datetime"] + pd.to_timedelta(np.random.randint(0, 1000, len(person_with_inclusion_date)), unit='d') -def test_plot_age_pyramid(): - original_person = data.person.copy() - - datetime_ref = datetime(2020, 1, 1) - chart, group_gender_age = plot_age_pyramid(data.person, datetime_ref, savefig=False) +@pytest.mark.parametrize( + "datetime_ref", [datetime(2020, 1, 1), "inclusion_datetime"] + ) +def test_plot_age_pyramid(datetime_ref): + original_person = person_with_inclusion_date.copy() + chart, group_gender_age = plot_age_pyramid(person_with_inclusion_date, datetime_ref, savefig=False) assert isinstance(chart, alt.vegalite.v4.api.ConcatChart) assert isinstance(group_gender_age, Series) # Check that the data is unchanged - assert_frame_equal(original_person, data.person) + assert_frame_equal(original_person, person_with_inclusion_date) filename = "test.html" - _ = plot_age_pyramid(data.person, savefig=True, filename=filename) + _ = plot_age_pyramid(person_with_inclusion_date, savefig=True, filename=filename) path = Path(filename) assert path.exists() path.unlink() with pytest.raises(ValueError, match="You have to set a filename"): - _ = plot_age_pyramid(data.person, savefig=True, filename=None) + _ = plot_age_pyramid(person_with_inclusion_date, savefig=True, filename=None) with pytest.raises( ValueError, match="'filename' type must be str, got " ): - _ = plot_age_pyramid(data.person, savefig=True, filename=[1]) + _ = plot_age_pyramid(person_with_inclusion_date, savefig=True, filename=[1]) From 605cde7a72ef8aa50a291b6fc412345fa4ec0750 Mon Sep 17 00:00:00 2001 From: Matthieu Doutreligne Date: Thu, 8 Dec 2022 09:22:10 +0100 Subject: [PATCH 2/7] added to changelog --- changelog.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/changelog.md b/changelog.md index 8b72c9d4..0f8787e5 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,9 @@ # Changelog +## v0.1.3 (2022-12-06) + +- Adding person-dependant `datetime_ref` to `plot_age_pyramid` + ## v0.1.2 (2022-12-05) ### Added From 38439e1c853cb007cde3bc03e672cb74b442c508 Mon Sep 17 00:00:00 2001 From: Matthieu Doutreligne Date: Thu, 8 Dec 2022 14:29:04 +0100 Subject: [PATCH 3/7] plot_age_pyramid: better type testing - add a new test for pytest.raise - add a parsable datetime string - add better error messages when bad value or bad type --- eds_scikit/plot/data_quality.py | 33 +++++++++++++++++++--------- tests/test_age_pyramid.py | 39 +++++++++++++++++++++++++++------ 2 files changed, 55 insertions(+), 17 deletions(-) diff --git a/eds_scikit/plot/data_quality.py b/eds_scikit/plot/data_quality.py index 3bb9e720..22a0c522 100644 --- a/eds_scikit/plot/data_quality.py +++ b/eds_scikit/plot/data_quality.py @@ -1,3 +1,4 @@ +from copy import copy from datetime import datetime from typing import Optional, Tuple, Union @@ -54,18 +55,30 @@ def plot_age_pyramid( raise ValueError("You have to set a filename") if not isinstance(filename, str): raise ValueError(f"'filename' type must be str, got {type(filename)}") - + datetime_ref_raw = copy(datetime_ref) person_ = person.copy() - - if type(datetime_ref) == datetime: - datetime_ref_ = pd.to_datetime(datetime_ref) - elif type(datetime_ref) == str: - if datetime_ref not in person_.columns: - raise ValueError(f"{datetime_ref} should be either a column of the dataframe, or a datetime.") - datetime_ref_ = person_[datetime_ref] + if datetime_ref is None: + datetime_ref = datetime.today() + elif isinstance(datetime_ref, datetime): + datetime_ref = pd.to_datetime(datetime_ref) + elif isinstance(datetime_ref, str): + if datetime_ref in person_.columns: + datetime_ref = person_[datetime_ref] + else: + datetime_ref = pd.to_datetime( + datetime_ref, errors="coerce" + ) # In case of error, will return NaT + if pd.isnull(datetime_ref): + raise ValueError( + f"`datetime_ref` must either be a column name or parseable date, " + f"got string '{datetime_ref_raw}'" + ) else: - datetime_ref_ = datetime.today() - person_["age"] = (datetime_ref_ - person_["birth_datetime"]).dt.total_seconds() + raise TypeError( + f"`datetime_ref` must be either None, a parseable string date" + f", a column name or a datetime. Got type: {type(datetime_ref)}, {datetime_ref}" + ) + person_["age"] = (datetime_ref - person_["birth_datetime"]).dt.total_seconds() person_["age"] /= 365 * 24 * 3600 bins = np.arange(0, 100, 10) diff --git a/tests/test_age_pyramid.py b/tests/test_age_pyramid.py index 4850316f..29f92e2c 100644 --- a/tests/test_age_pyramid.py +++ b/tests/test_age_pyramid.py @@ -1,27 +1,35 @@ -from datetime import datetime, timedelta +from datetime import datetime from pathlib import Path import altair as alt +import numpy as np +import pandas as pd import pytest from pandas.core.series import Series from pandas.testing import assert_frame_equal from eds_scikit.datasets.synthetic.person import load_person from eds_scikit.plot.data_quality import plot_age_pyramid -import pandas as pd -import numpy as np data = load_person() person_with_inclusion_date = data.person.copy() -person_with_inclusion_date["inclusion_datetime"] = person_with_inclusion_date["birth_datetime"] + pd.to_timedelta(np.random.randint(0, 1000, len(person_with_inclusion_date)), unit='d') +N = len(person_with_inclusion_date) +delta_days = pd.to_timedelta(np.random.randint(0, 1000, N), unit="d") + +person_with_inclusion_date["inclusion_datetime"] = ( + person_with_inclusion_date["birth_datetime"] + delta_days +) + @pytest.mark.parametrize( - "datetime_ref", [datetime(2020, 1, 1), "inclusion_datetime"] - ) + "datetime_ref", [datetime(2020, 1, 1), "inclusion_datetime", "2020-01-01"] +) def test_plot_age_pyramid(datetime_ref): original_person = person_with_inclusion_date.copy() - chart, group_gender_age = plot_age_pyramid(person_with_inclusion_date, datetime_ref, savefig=False) + chart, group_gender_age = plot_age_pyramid( + person_with_inclusion_date, datetime_ref, savefig=False + ) assert isinstance(chart, alt.vegalite.v4.api.ConcatChart) assert isinstance(group_gender_age, Series) @@ -41,3 +49,20 @@ def test_plot_age_pyramid(datetime_ref): ValueError, match="'filename' type must be str, got " ): _ = plot_age_pyramid(person_with_inclusion_date, savefig=True, filename=[1]) + + +def test_plot_age_pyramid_datetime_ref_error(): + with pytest.raises( + ValueError, + match="`datetime_ref` must either be a column name or parseable date, got string '20x2-01-01'", + ): + _ = plot_age_pyramid( + person_with_inclusion_date, datetime_ref="20x2-01-01", savefig=False + ) + with pytest.raises( + TypeError, + match="`datetime_ref` must be either None, a parseable string date, a column name or a datetime. Got type: , 2022", + ): + _ = plot_age_pyramid( + person_with_inclusion_date, datetime_ref=2022, savefig=False + ) From 1b024f2139fc256ffbb4a49b37d9c2e6664125c1 Mon Sep 17 00:00:00 2001 From: Matthieu Date: Fri, 9 Dec 2022 13:39:55 +0100 Subject: [PATCH 4/7] Update changelog.md Co-authored-by: Thomas Petit-Jean <30775613+Thomzoy@users.noreply.github.com> --- changelog.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/changelog.md b/changelog.md index 0f8787e5..57d968c6 100644 --- a/changelog.md +++ b/changelog.md @@ -1,6 +1,8 @@ # Changelog -## v0.1.3 (2022-12-06) +## Pending + +### Added - Adding person-dependant `datetime_ref` to `plot_age_pyramid` From b2876679cabd33aa71194f9e8adb16133e1c1433 Mon Sep 17 00:00:00 2001 From: Vincent M Date: Thu, 12 Jan 2023 19:21:34 +0100 Subject: [PATCH 5/7] improve speed for koalas op and add bd.cache --- eds_scikit/plot/data_quality.py | 46 ++++++++++++------- .../utils/custom_implem/custom_implem.py | 15 ++++-- eds_scikit/utils/custom_implem/cut.py | 18 ++++---- tests/test_age_pyramid.py | 4 +- 4 files changed, 53 insertions(+), 30 deletions(-) diff --git a/eds_scikit/plot/data_quality.py b/eds_scikit/plot/data_quality.py index ce472145..2f840baa 100644 --- a/eds_scikit/plot/data_quality.py +++ b/eds_scikit/plot/data_quality.py @@ -1,10 +1,11 @@ from copy import copy from datetime import datetime -from typing import Optional, Tuple, Union +from typing import Tuple, Union import altair as alt import numpy as np import pandas as pd +from pandas.api.types import is_integer_dtype from pandas.core.frame import DataFrame from pandas.core.series import Series @@ -14,7 +15,7 @@ def plot_age_pyramid( person: DataFrame, - datetime_ref: datetime = None, + datetime_ref: Union[datetime, str] = None, filename: str = None, savefig: bool = False, return_vector: bool = False, @@ -57,15 +58,15 @@ def plot_age_pyramid( raise ValueError("You have to set a filename") if not isinstance(filename, str): raise ValueError(f"'filename' type must be str, got {type(filename)}") - datetime_ref_raw = copy(datetime_ref) - person_ = person.copy() + datetime_ref_original = copy(datetime_ref) + if datetime_ref is None: datetime_ref = datetime.today() elif isinstance(datetime_ref, datetime): datetime_ref = pd.to_datetime(datetime_ref) elif isinstance(datetime_ref, str): - if datetime_ref in person_.columns: - datetime_ref = person_[datetime_ref] + if datetime_ref in person.columns: + datetime_ref = person[datetime_ref] else: datetime_ref = pd.to_datetime( datetime_ref, errors="coerce" @@ -73,33 +74,46 @@ def plot_age_pyramid( if pd.isnull(datetime_ref): raise ValueError( f"`datetime_ref` must either be a column name or parseable date, " - f"got string '{datetime_ref_raw}'" + f"got string '{datetime_ref_original}'" ) else: raise TypeError( f"`datetime_ref` must be either None, a parseable string date" f", a column name or a datetime. Got type: {type(datetime_ref)}, {datetime_ref}" ) - person_["age"] = (datetime_ref - person_["birth_datetime"]).dt.total_seconds() - person_["age"] /= 365 * 24 * 3600 + + person = person.loc[person["gender_source_value"].isin(["m", "f"])] + + deltas = datetime_ref - person["birth_datetime"] + if not is_integer_dtype(deltas): + deltas = deltas.dt.total_seconds() + person["age"] = deltas / 365 * 24 * 3600 bins = np.arange(0, 100, 10) labels = [f"{left}-{right}" for left, right in zip(bins[:-1], bins[1:])] - person_["age_bins"] = bd.cut(person_["age"], bins=bins, labels=labels) - person_["age_bins"] = ( - person_["age_bins"].astype(str).str.lower().str.replace("nan", "90+") - ) + # This is equivalent to `pd.cut()` for pandas and this call our custom `cut` + # implementation for koalas. + person["age_bins"] = bd.cut(person["age"], bins=bins, labels=labels) - person_ = person_.loc[person_["gender_source_value"].isin(["m", "f"])] - group_gender_age = person_.groupby(["gender_source_value", "age_bins"])[ + # This is equivalent to `person.cache()` for koalas and this is a no-op + # for pandas. + # Cache the intermediate results of the transformation so that other transformation + # runs on top of cached will perform faster. + # TODO: try to remove it and check perfs. + bd.cache(person) + + group_gender_age = person.groupby(["gender_source_value", "age_bins"])[ "person_id" ].count() # Convert to pandas to ease plotting. - # Since we have aggregated the data, this operation won't crash. group_gender_age = bd.to_pandas(group_gender_age) + group_gender_age["age_bins"] = ( + person["age_bins"].astype(str).str.lower().str.replace("nan", "90+") + ) + male = group_gender_age["m"].reset_index() female = group_gender_age["f"].reset_index() diff --git a/eds_scikit/utils/custom_implem/custom_implem.py b/eds_scikit/utils/custom_implem/custom_implem.py index 81c5ae2a..55430eb9 100644 --- a/eds_scikit/utils/custom_implem/custom_implem.py +++ b/eds_scikit/utils/custom_implem/custom_implem.py @@ -13,6 +13,17 @@ class CustomImplem: All public facing methods must be stateless and defined as classmethods. """ + @classmethod + def cache(cls, obj: DataFrame, backend=None) -> None: + """Run df.cache() for Koalas. No-op for pandas.""" + if backend is pd: + return + elif backend is ks: + obj.spark.cache() + return + else: + raise ValueError(f"Unknown backend {backend}") + @classmethod def add_unique_id( cls, @@ -27,9 +38,7 @@ def add_unique_id( elif backend is ks: return obj.koalas.attach_id_column(id_type="distributed", column=col_name) else: - raise NotImplementedError( - f"No method 'add_unique_id' is available for backend '{backend}'." - ) + raise ValueError(f"Unknown backend {backend}") @classmethod def cut( diff --git a/eds_scikit/utils/custom_implem/cut.py b/eds_scikit/utils/custom_implem/cut.py index 85dbc2b5..7cc247c7 100644 --- a/eds_scikit/utils/custom_implem/cut.py +++ b/eds_scikit/utils/custom_implem/cut.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd import pandas.core.algorithms as algos +from databricks import koalas as ks from pandas import IntervalIndex, to_datetime, to_timedelta from pandas._libs import Timedelta, Timestamp from pandas._libs.lib import infer_dtype @@ -371,7 +372,9 @@ def _bins_to_cuts( # hack to bypass "TypeError: 'Series' object does not support item assignment" ids = ids.to_frame() ids.loc[na_mask] = 0 - ids = ids[ids.columns[0]] + ids.columns = ["key"] + ids["key"] -= 1 + # ids = ids[ids.columns[0]] if labels: if not (labels is None or is_list_like(labels)): @@ -400,17 +403,16 @@ def _bins_to_cuts( ordered=ordered, ) - label_mapping = dict(zip(range(len(labels)), labels)) + labels = ks.DataFrame({"key": range(len(labels)), "val": labels}) # x values outside of bins edges (i.e. when ids = 0) are mapped to NaN - result = (ids - 1).map(label_mapping) - result.fillna(np.nan, inplace=True) + result = ids.merge(labels, on="key", how="left") + # result = (ids - 1).map(label_mapping) + result = result["val"].fillna(np.nan) else: - result = ids - 1 # hack to bypass "TypeError: 'Series' object does not support item assignment" - result = result.to_frame() - result.loc[na_mask] = np.nan - result = result[result.columns[0]] + ids.loc[na_mask] = np.nan + result = result["val"] return result, bins diff --git a/tests/test_age_pyramid.py b/tests/test_age_pyramid.py index f4338d0f..7b049a01 100644 --- a/tests/test_age_pyramid.py +++ b/tests/test_age_pyramid.py @@ -27,9 +27,7 @@ ) def test_plot_age_pyramid(datetime_ref): original_person = person_with_inclusion_date.copy() - chart = plot_age_pyramid( - person_with_inclusion_date, datetime_ref, savefig=False - ) + chart = plot_age_pyramid(person_with_inclusion_date, datetime_ref, savefig=False) assert isinstance(chart, alt.vegalite.v4.api.ConcatChart) # Check that the data is unchanged From 53c18550de31188f68fe321df5148931b0134d01 Mon Sep 17 00:00:00 2001 From: Vincent M Date: Thu, 12 Jan 2023 19:38:08 +0100 Subject: [PATCH 6/7] fix keyval error in plot func --- eds_scikit/plot/data_quality.py | 20 +++++++++----------- tests/test_age_pyramid.py | 6 +++--- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/eds_scikit/plot/data_quality.py b/eds_scikit/plot/data_quality.py index 2f840baa..1a2ee83b 100644 --- a/eds_scikit/plot/data_quality.py +++ b/eds_scikit/plot/data_quality.py @@ -7,7 +7,6 @@ import pandas as pd from pandas.api.types import is_integer_dtype from pandas.core.frame import DataFrame -from pandas.core.series import Series from ..utils.checks import check_columns from ..utils.framework import bd @@ -19,7 +18,7 @@ def plot_age_pyramid( filename: str = None, savefig: bool = False, return_vector: bool = False, -) -> Tuple[alt.Chart, Series]: +) -> Tuple[alt.Chart, DataFrame]: """Plot an age pyramid from a 'person' pandas DataFrame. Parameters @@ -103,19 +102,18 @@ def plot_age_pyramid( # TODO: try to remove it and check perfs. bd.cache(person) - group_gender_age = person.groupby(["gender_source_value", "age_bins"])[ - "person_id" - ].count() + group = person.groupby(["gender_source_value", "age_bins"])["person_id"].count() # Convert to pandas to ease plotting. - group_gender_age = bd.to_pandas(group_gender_age) + group = bd.to_pandas(group) - group_gender_age["age_bins"] = ( + group = group.to_frame().reset_index() + group["age_bins"] = ( person["age_bins"].astype(str).str.lower().str.replace("nan", "90+") ) - male = group_gender_age["m"].reset_index() - female = group_gender_age["f"].reset_index() + male = group.loc[group["gender_source_value"] == "m"].reset_index() + female = group.loc[group["gender_source_value"] == "f"].reset_index() left = ( alt.Chart(male) @@ -151,9 +149,9 @@ def plot_age_pyramid( if savefig: chart.save(filename) if return_vector: - return group_gender_age + return group if return_vector: - return chart, group_gender_age + return chart, group return chart diff --git a/tests/test_age_pyramid.py b/tests/test_age_pyramid.py index 7b049a01..bd37118a 100644 --- a/tests/test_age_pyramid.py +++ b/tests/test_age_pyramid.py @@ -5,7 +5,7 @@ import numpy as np import pandas as pd import pytest -from pandas.core.series import Series +from pandas.core.frame import DataFrame from pandas.testing import assert_frame_equal from eds_scikit.datasets.synthetic.person import load_person @@ -45,13 +45,13 @@ def test_age_pyramid_output(): group_gender_age = plot_age_pyramid( data.person, savefig=True, return_vector=True, filename=filename ) - assert isinstance(group_gender_age, Series) + assert isinstance(group_gender_age, DataFrame) chart, group_gender_age = plot_age_pyramid( data.person, savefig=False, return_vector=True ) assert isinstance(chart, alt.vegalite.v4.api.ConcatChart) - assert isinstance(group_gender_age, Series) + assert isinstance(group_gender_age, DataFrame) chart = plot_age_pyramid(data.person, savefig=False, return_vector=False) assert isinstance(chart, alt.vegalite.v4.api.ConcatChart) From 1b4f6a7a6c968ec045da8111e6d46975d00fe0a3 Mon Sep 17 00:00:00 2001 From: Vincent M Date: Thu, 12 Jan 2023 19:41:22 +0100 Subject: [PATCH 7/7] typo --- eds_scikit/plot/data_quality.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eds_scikit/plot/data_quality.py b/eds_scikit/plot/data_quality.py index 1a2ee83b..8c6bc613 100644 --- a/eds_scikit/plot/data_quality.py +++ b/eds_scikit/plot/data_quality.py @@ -109,7 +109,7 @@ def plot_age_pyramid( group = group.to_frame().reset_index() group["age_bins"] = ( - person["age_bins"].astype(str).str.lower().str.replace("nan", "90+") + group["age_bins"].astype(str).str.lower().str.replace("nan", "90+") ) male = group.loc[group["gender_source_value"] == "m"].reset_index()