Skip to content

Commit

Permalink
plots: support date and group customisation in get_epicurve()
Browse files Browse the repository at this point in the history
  • Loading branch information
abhidg committed Aug 29, 2024
1 parent 13b0839 commit 1c2a470
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 49 deletions.
73 changes: 42 additions & 31 deletions src/obr/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Library of plots used in most outbreaks
"""

import re
import logging

import pandas as pd
Expand Down Expand Up @@ -69,31 +68,43 @@ def get_delays(
return both[target_col] - both[onset_col]


def get_epicurve(
    df: pd.DataFrame,
    date_col: str,
    groups: list[tuple[str, list[str]]],
    cumulative: bool = True,
) -> pd.DataFrame:
    """Returns epidemic curve

    Parameters
    ----------
    df
        Data from which epicurve is obtained
    date_col
        Date column to use
    groups
        List of tuples of column name and value that will be grouped
        such as ("Case_status", ["confirmed", "probable"]) or
        ("Outcome", ["death", "recovered"])
    cumulative
        Whether to return cumulative counts (default = true)

    Returns
    -------
    DataFrame with date_col as the first column and one count column
    per group value.
    """
    epicurve = None
    for group_column, group_values in groups:
        # counts by (date, group value), pivoted to one column per value
        group_epicurve = (
            df[~pd.isna(df[date_col]) & df[group_column].isin(group_values)]
            .groupby([date_col, group_column])
            .size()
            .reset_index()
            .pivot(index=date_col, columns=group_column, values=0)
            .fillna(0)
            .astype(int)
        )
        if epicurve is None:
            epicurve = group_epicurve
        else:
            # align on the shared date index; a row-wise concat would
            # interleave dates from different groups, leaving NaN-filled
            # misaligned rows and an unsorted index that breaks cumsum
            epicurve = (
                epicurve.join(group_epicurve, how="outer").fillna(0).astype(int)
            )
    if epicurve is None:
        # no groups requested: return an empty frame instead of crashing
        # on None.cumsum() below
        return pd.DataFrame(columns=[date_col])
    # ensure chronological order so cumulative counts are monotonic
    epicurve = epicurve.sort_index()
    if cumulative:
        epicurve = epicurve.cumsum()
    return epicurve.reset_index()


Expand All @@ -119,29 +130,29 @@ def get_timeseries_location_status(
statuses = ["confirmed", "probable"]
df = df[
df.Case_status.isin(statuses)
& ~pd.isna(df.Date_onset_estimated)
& ~pd.isna(df.Date_onset)
& ~pd.isna(df.Location_District)
]
locations = sorted(set(df.Location_District)) + [None]
mindate, maxdate = df.Date_onset_estimated.min(), df.Date_onset_estimated.max()
mindate, maxdate = df.Date_onset.min(), df.Date_onset.max()

def timeseries_for_location(location: str | None) -> pd.DataFrame:
if location is None:
counts = (
df.groupby(["Date_onset_estimated", "Case_status"])
df.groupby(["Date_onset", "Case_status"])
.size()
.reset_index()
.pivot(index="Date_onset_estimated", columns="Case_status", values=0)
.pivot(index="Date_onset", columns="Case_status", values=0)
.fillna(0)
.astype(int)
)
else:
counts = (
df[df.Location_District == location]
.groupby(["Date_onset_estimated", "Case_status"])
.groupby(["Date_onset", "Case_status"])
.size()
.reset_index()
.pivot(index="Date_onset_estimated", columns="Case_status", values=0)
.pivot(index="Date_onset", columns="Case_status", values=0)
.fillna(0)
.astype(int)
)
Expand All @@ -158,7 +169,7 @@ def timeseries_for_location(location: str | None) -> pd.DataFrame:
timeseries = pd.concat(map(timeseries_for_location, locations)).fillna(0)
for col in ["daily_" + s for s in statuses] + ["cumulative_" + s for s in statuses]:
timeseries[col] = timeseries[col].astype(int)
return timeseries.reset_index(names="Date_onset_estimated")
return timeseries.reset_index(names="Date_onset")


def plot_timeseries_location_status(
Expand Down
31 changes: 25 additions & 6 deletions src/obr/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,17 @@
Briefing report generator for Marburg 2023 outbreak
"""

import re
import logging
import datetime
from pathlib import Path
from typing import Any, Callable
from functools import cache

import boto3
import chevron
import pandas as pd
import numpy as np
from dateutil.parser import ParserError
import plotly.graph_objects as go
import plotly.io
import plotly.express as px
from plotly.subplots import make_subplots

PlotFunction = Callable[..., dict[str, Any] | go.Figure]
PlotData = tuple[str, PlotFunction, dict[str, Any]]
Expand All @@ -34,6 +30,22 @@
(70, 79),
(80, 120),
]
# Matches ISO-like dates in the 2020s, e.g. "2023-03-05".
# NOTE: the class is [01], not [0,1] — a comma in the class would
# wrongly accept a literal "," as the month's tens digit.
REGEX_DATE = r"^202\d-[01]\d-[0-3]\d"


def fix_datetimes(df: pd.DataFrame) -> None:
    """Convert date fields to datetime in place.

    Columns whose name starts with "date_" (case-insensitive) are
    converted; values that are not strings matching REGEX_DATE become
    None (NaT after assignment).
    """
    date_columns = [c for c in df.columns if c.lower().startswith("date_")]
    # Only some datasets carry this column; appending it unconditionally
    # raises KeyError on frames (such as the test fixture) that lack it.
    if "Source_I_Date report" in df.columns:
        date_columns.append("Source_I_Date report")
    # hoist the compiled pattern out of the per-value lambda
    pattern = re.compile(REGEX_DATE)
    for date_col in date_columns:
        df[date_col] = df[date_col].map(
            lambda x: (
                pd.to_datetime(x)
                if isinstance(x, str) and pattern.match(x)
                else None
            )
        )


def get_age_bins(age: str) -> range:
Expand Down Expand Up @@ -114,6 +126,13 @@ def invalidate_cache(
raise


def read_csv(filename: str, **kwargs) -> pd.DataFrame:
    """Helper function with post-processing steps after pd.read_csv.

    Reads every column as str with "N/K"/"NK" treated as missing, then
    converts date columns in place via fix_datetimes().

    Extra keyword arguments are forwarded to pd.read_csv and may
    override the dtype / na_values defaults — build() calls this with
    na_values=... and dtype=..., which the original single-parameter
    signature rejected with TypeError.
    """
    kwargs.setdefault("dtype", str)
    kwargs.setdefault("na_values", ["N/K", "NK"])
    df = pd.read_csv(filename, **kwargs)
    fix_datetimes(df)
    return df


def build(
outbreak_name: str,
data_url: str,
Expand All @@ -128,7 +147,7 @@ def build(
if not (template := Path(__file__).parent / "outbreaks" / output_file).exists():
raise FileNotFoundError(f"Template for outbreak not found at: {template}")
var = {"published_date": str(date)}
df = pd.read_csv(data_url, na_values=["NK", "N/K"])
df = read_csv(data_url, na_values=["NK", "N/K"], dtype=str)
for plot in plots:
kwargs = {} if len(plot) == 2 else plot[2]
if plot[0] == "data":
Expand Down
16 changes: 8 additions & 8 deletions tests/test_data.csv
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
ID,Case_status,Age,Gender,Date_onset,Date_onset_estimated,Date_death,Date_of_first_consult,Data_up_to,Location_District
1,confirmed,50-55,male,2023-03-05,2023-03-05,2023-03-09,,2023-04-04,Bata
2,probable,40-46,female,2023-02-06,2023-02-06,2023-02-14,,2023-04-04,Bata
3,confirmed,20,male,2023-02-19,2023-02-19,,2023-02-25,2023-04-04,Nsoc Nsomo
4,confirmed,99,female,2023-01-05,2023-01-05,2023-01-11,,2023-04-04,Nsoc Nsomo
5,probable,65,male,,2023-01-13,2023-01-19,,2023-04-04,Ebiebyin
6,confirmed,59,female,,2023-03-29,,2023-04-02,2023-04-04,Ebiebyin
7,confirmed,0,male,2023-02-11,2023-02-11,,2023-02-13,2023-04-04,Nsork
ID,Case_status,Age,Gender,Date_onset,Date_death,Date_of_first_consult,Data_up_to,Location_District
1,confirmed,50-55,male,2023-03-05,2023-03-09,,2023-04-04,Bata
2,probable,40-46,female,2023-02-06,2023-02-14,,2023-04-04,Bata
3,confirmed,20,male,2023-02-19,,2023-02-25,2023-04-04,Nsoc Nsomo
4,confirmed,99,female,2023-01-05,2023-01-11,,2023-04-04,Nsoc Nsomo
5,probable,65,male,2023-01-13,2023-01-19,,2023-04-04,Ebiebyin
6,confirmed,59,female,2023-03-29,,2023-04-02,2023-04-04,Ebiebyin
7,confirmed,0,male,2023-02-11,,2023-02-13,2023-04-04,Nsork
10 changes: 6 additions & 4 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

DATA = pd.read_csv(Path(__file__).with_name("test_data.csv"))

EXPECTED_TIMESERIES_LOCATION_STATUS = """Date_onset_estimated,daily_confirmed,daily_probable,cumulative_confirmed,cumulative_probable,Location_District
EXPECTED_TIMESERIES_LOCATION_STATUS = """Date_onset,daily_confirmed,daily_probable,cumulative_confirmed,cumulative_probable,Location_District
2023-02-06,0,1,0,1,Bata
2023-03-05,1,0,1,1,Bata
2023-01-13,0,1,0,1,Ebiebyin
Expand All @@ -33,7 +33,7 @@

@pytest.mark.parametrize(
"column,expected_delay_series",
[("Date_death", [4, 8, 6]), ("Date_of_first_consult", [6, 2])],
[("Date_death", [4, 8, 6, 6]), ("Date_of_first_consult", [6, 4, 2])],
)
def test_get_delays(column, expected_delay_series):
assert list(get_delays(DATA, column).dt.days) == expected_delay_series
Expand All @@ -53,10 +53,12 @@ def test_get_age_bin_data():


def test_get_epicurve():
epicurve = get_epicurve(DATA)
epicurve = get_epicurve(
DATA, date_col="Date_onset", groups=[("Case_status", ["confirmed", "probable"])]
)
assert (
epicurve.to_csv(index=False)
== """Date_onset_estimated,confirmed,probable
== """Date_onset,confirmed,probable
2023-01-05,1,0
2023-01-13,1,1
2023-02-06,1,2
Expand Down

0 comments on commit 1c2a470

Please sign in to comment.