Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] plot_age_pyramid: add the possibility to have one datetime_ref that differs for each patient #18

Closed
4 changes: 4 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## v0.1.3 (2022-12-06)

- Adding person-dependant `datetime_ref` to `plot_age_pyramid`

## v0.1.2 (2022-12-05)

### Added
Expand Down
20 changes: 13 additions & 7 deletions eds_scikit/plot/data_quality.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import altair as alt
import numpy as np
Expand All @@ -12,7 +12,7 @@

def plot_age_pyramid(
person: DataFrame,
datetime_ref: datetime = None,
datetime_ref: Union[datetime, str] = None,
savefig: bool = False,
filename: Optional[str] = None,
) -> Tuple[alt.Chart, Series]:
Expand All @@ -26,8 +26,10 @@ def plot_age_pyramid(
- `person_id`, dtype : any
- `gender_source_value`, dtype : str, {'m', 'f'}

datetime_ref : datetime,
datetime_ref : Union[datetime, str],
The reference date to compute population age from.
If a string, it searches for a column with the same name in the person table: each patient has his own datetime reference.
If a datetime, the reference datetime is the same for all patients.
If set to None, datetime.today() will be used instead.

savefig : bool,
Expand Down Expand Up @@ -55,11 +57,15 @@ def plot_age_pyramid(

person_ = person.copy()

if datetime_ref:
today = pd.to_datetime(datetime_ref)
if type(datetime_ref) == datetime:
datetime_ref_ = pd.to_datetime(datetime_ref)
elif type(datetime_ref) == str:
if datetime_ref not in person_.columns:
raise ValueError(f"{datetime_ref} should be either a column of the dataframe, or a datetime.")
datetime_ref_ = person_[datetime_ref]
else:
today = datetime.today()
person_["age"] = (today - person_["birth_datetime"]).dt.total_seconds()
datetime_ref_ = datetime.today()
person_["age"] = (datetime_ref_ - person_["birth_datetime"]).dt.total_seconds()
person_["age"] /= 365 * 24 * 3600

bins = np.arange(0, 100, 10)
Expand Down
25 changes: 15 additions & 10 deletions tests/test_age_pyramid.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime
from datetime import datetime, timedelta
from pathlib import Path

import altair as alt
Expand All @@ -8,31 +8,36 @@

from eds_scikit.datasets.synthetic.person import load_person
from eds_scikit.plot.data_quality import plot_age_pyramid
import pandas as pd
import numpy as np

data = load_person()

person_with_inclusion_date = data.person.copy()
person_with_inclusion_date["inclusion_datetime"] = person_with_inclusion_date["birth_datetime"] + pd.to_timedelta(np.random.randint(0, 1000, len(person_with_inclusion_date)), unit='d')

def test_plot_age_pyramid():
original_person = data.person.copy()

datetime_ref = datetime(2020, 1, 1)
chart, group_gender_age = plot_age_pyramid(data.person, datetime_ref, savefig=False)
@pytest.mark.parametrize(
"datetime_ref", [datetime(2020, 1, 1), "inclusion_datetime"]
)
def test_plot_age_pyramid(datetime_ref):
original_person = person_with_inclusion_date.copy()
chart, group_gender_age = plot_age_pyramid(person_with_inclusion_date, datetime_ref, savefig=False)
assert isinstance(chart, alt.vegalite.v4.api.ConcatChart)
assert isinstance(group_gender_age, Series)

# Check that the data is unchanged
assert_frame_equal(original_person, data.person)
assert_frame_equal(original_person, person_with_inclusion_date)

filename = "test.html"
_ = plot_age_pyramid(data.person, savefig=True, filename=filename)
_ = plot_age_pyramid(person_with_inclusion_date, savefig=True, filename=filename)
path = Path(filename)
assert path.exists()
path.unlink()

with pytest.raises(ValueError, match="You have to set a filename"):
_ = plot_age_pyramid(data.person, savefig=True, filename=None)
_ = plot_age_pyramid(person_with_inclusion_date, savefig=True, filename=None)

with pytest.raises(
ValueError, match="'filename' type must be str, got <class 'list'>"
):
_ = plot_age_pyramid(data.person, savefig=True, filename=[1])
_ = plot_age_pyramid(person_with_inclusion_date, savefig=True, filename=[1])