plots: add date and groupby column customisation in get_epicurve()
abhidg committed Aug 29, 2024
1 parent 13b0839 commit 84b9f5a
Showing 6 changed files with 126 additions and 66 deletions.
10 changes: 9 additions & 1 deletion src/obr/outbreaks/__init__.py
@@ -12,7 +12,15 @@

outbreak_marburg = [
("data", get_counts, dict(date_col="Data_up_to")),
("figure/epicurve", plot_epicurve),
(
"figure/epicurve",
plot_epicurve,
dict(
title="Date of symptom onset",
date_col="Date_onset_estimated",
groupby_col="Case_status",
),
),
(
"figure/epicurve_location_status",
plot_timeseries_location_status,
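For context, each entry in outbreak_marburg is a (output_name, plot_function, kwargs) tuple; the build() loop updated later in this commit unpacks the optional third element as keyword arguments. A minimal sketch of that dispatch, assuming df is the outbreak line list already loaded by read_csv(); how figure outputs are written is not visible in this diff, so the last line is illustrative only:

# Sketch of how a spec entry is consumed, mirroring the build() loop in util.py.
# df is assumed to be the outbreak line list returned by read_csv().
spec = (
    "figure/epicurve",
    plot_epicurve,
    dict(
        title="Date of symptom onset",
        date_col="Date_onset_estimated",
        groupby_col="Case_status",
    ),
)
name, func = spec[0], spec[1]
kwargs = {} if len(spec) == 2 else spec[2]
result = func(df, **kwargs)  # a plotly Figure for "figure/..." entries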
102 changes: 58 additions & 44 deletions src/obr/plots.py
@@ -2,7 +2,6 @@
Library of plots used in most outbreaks
"""

import re
import logging

import pandas as pd
@@ -12,11 +11,17 @@
import plotly.express as px
from plotly.subplots import make_subplots

from .util import percentage_occurrence, name_bin, AGE_BINS, get_age_bins
from .util import (
percentage_occurrence,
name_bin,
AGE_BINS,
get_age_bins,
non_null_unique,
)
from .theme import (
REGEX_DATE,
FONT,
TITLE_FONT,
PALETTE,
LEGEND_FONT_SIZE,
BLUE_PRIMARY_COLOR,
PRIMARY_COLOR,
@@ -26,6 +31,8 @@
GRID_COLOR,
)

REGEX_DATE = r"^202\d-[0,1]\d-[0-3]\d"

pd.options.mode.chained_assignment = None


@@ -69,32 +76,39 @@ def get_delays(
return both[target_col] - both[onset_col]


def get_epicurve(df: pd.DataFrame, cumulative: bool = True) -> pd.DataFrame:
"""Returns epidemic curve - number of cases by (estimated) date of symptom onset"""
df["Date_onset_estimated"] = df.Date_onset_estimated.map(
lambda x: (
pd.to_datetime(x)
if isinstance(x, str) and re.match(REGEX_DATE, x)
else None
)
)

def get_epicurve(
df: pd.DataFrame,
date_col: str,
groupby_col: str,
values: list[str] | None = None,
cumulative: bool = True,
) -> pd.DataFrame:
"""Returns epidemic curve
Parameters
----------
df
Data from which epicurve is obtained
date_col
Date column to use
groupby_col
Column to group by, e.g. Case_status
values
Values of the column to plot, e.g. ['confirmed', 'probable']
cumulative
Whether to return cumulative counts (default: True)
"""
values = non_null_unique(df[groupby_col]) if values is None else values
epicurve = (
df[
~pd.isna(df.Date_onset_estimated)
& df.Case_status.isin(["confirmed", "probable"])
]
.groupby(["Date_onset_estimated", "Case_status"])
df[~pd.isna(df[date_col]) & df[groupby_col].isin(values)]
.groupby([date_col, groupby_col])
.size()
.reset_index()
.pivot(index="Date_onset_estimated", columns="Case_status", values=0)
.pivot(index=date_col, columns=groupby_col, values=0)
.fillna(0)
.astype(int)
)
if cumulative:
epicurve["confirmed"] = epicurve.confirmed.cumsum()
epicurve["probable"] = epicurve.probable.cumsum()
return epicurve.reset_index()
return epicurve.cumsum() if cumulative else epicurve
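A usage sketch for the new signature on a toy frame (illustrative data, not the repository's test fixture; assumes the package imports as obr):

import pandas as pd

from obr.plots import get_epicurve

toy = pd.DataFrame(
    {
        "Date_onset": pd.to_datetime(["2023-01-05", "2023-01-05", "2023-01-13"]),
        "Case_status": ["confirmed", "probable", "confirmed"],
    }
)
curve = get_epicurve(
    toy, date_col="Date_onset", groupby_col="Case_status", cumulative=True
)
# index: Date_onset; columns: one per Case_status value; values: cumulative counts
print(curve)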


def get_counts(df: pd.DataFrame, date_col: str) -> dict[str, int]:
@@ -229,31 +243,31 @@ def plot_timeseries_location_status(
return fig


def plot_epicurve(df: pd.DataFrame, non_confirmed_col: str, cumulative: bool = True):
data = get_epicurve(df, cumulative=cumulative)
def plot_epicurve(
df: pd.DataFrame,
title: str,
date_col: str,
groupby_col: str,
values: list[str] | None = None,
cumulative: bool = True,
palette: list[str] = PALETTE,
):
values = non_null_unique(df[groupby_col]) if values is None else values
data = get_epicurve(df, date_col, groupby_col, values, cumulative=cumulative)
fig = go.Figure()

fig.add_trace(
go.Scatter(
x=data.Date_onset_estimated,
y=data.confirmed,
name="confirmed",
line_color=PRIMARY_COLOR,
line_width=3,
),
)
fig.add_trace(
go.Scatter(
x=data.Date_onset_estimated,
y=data[non_confirmed_col],
name="probable",
line_color=SECONDARY_COLOR,
line_width=3,
for idx, value in enumerate(values):
fig.add_trace(
go.Scatter(
x=data.index,
y=data[value],
name=value,
line_color=palette[idx],
line_width=3,
),
)
)

fig.update_xaxes(
title_text="Date of symptom onset",
title_text=title,
title_font_family=TITLE_FONT,
title_font_color=FG_COLOR,
gridcolor=GRID_COLOR,
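A usage sketch for the generalised plot_epicurve, with column names taken from the test data in this commit (the CSV path and obr import are assumptions for illustration):

from obr.plots import plot_epicurve
from obr.util import read_csv

df = read_csv("tests/test_data.csv")  # illustrative path
fig = plot_epicurve(
    df,
    title="Date of symptom onset",
    date_col="Date_onset",
    groupby_col="Case_status",
    values=["confirmed", "probable"],
)
fig.show()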
17 changes: 15 additions & 2 deletions src/obr/theme.py
@@ -2,8 +2,6 @@
Theme and color settings for obr
"""

REGEX_DATE = r"^202\d-[0,1]\d-[0-3]\d"
OVERRIDES = {}
FONT = "Inter"
TITLE_FONT = "mabry-regular-pro"
LEGEND_FONT_SIZE = 13
@@ -15,3 +13,18 @@
BG_COLOR = "#ECF3F0"
FG_COLOR = "#1E1E1E"
GRID_COLOR = "#DEDEDE"

PALETTE = [
PRIMARY_COLOR,
SECONDARY_COLOR,
"#FFD700", # Gold
"#FFAA00", # Amber
"#FFCC33", # Saffron
"#F0E68C", # Khaki
"#FFFF00", # Yellow
"#00BFFF", # Deep Sky Blue
"#1E90FF", # Dodger Blue
"#6495ED", # Cornflower Blue
"#87CEEB", # Sky Blue
"#4682B4", # Steel Blue
]
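The palette is consumed positionally in plot_epicurve, so the nth distinct group value gets the nth colour; a minimal illustration (the 'suspected' value is hypothetical):

from obr.theme import PALETTE

values = ["confirmed", "probable", "suspected"]
colour_map = dict(zip(values, PALETTE))
# confirmed -> PRIMARY_COLOR, probable -> SECONDARY_COLOR, suspected -> "#FFD700"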
34 changes: 28 additions & 6 deletions src/obr/util.py
@@ -2,21 +2,17 @@
Briefing report generator for Marburg 2023 outbreak
"""

import re
import logging
import datetime
from pathlib import Path
from typing import Any, Callable
from functools import cache

import boto3
import chevron
import pandas as pd
import numpy as np
from dateutil.parser import ParserError
import plotly.graph_objects as go
import plotly.io
import plotly.express as px
from plotly.subplots import make_subplots

PlotFunction = Callable[..., dict[str, Any] | go.Figure]
PlotData = tuple[str, PlotFunction, dict[str, Any]]
@@ -34,6 +30,25 @@
(70, 79),
(80, 120),
]
REGEX_DATE = r"^202\d-[0,1]\d-[0-3]\d"


def non_null_unique(arr: pd.Series) -> pd.Series:
uniq = arr.unique()
return uniq[~pd.isna(uniq)]


def fix_datetimes(df: pd.DataFrame):
"Convert date fields to datetime in place"
date_columns = [c for c in df.columns if c.startswith("Date_") or "Date " in c]
for date_col in date_columns:
df[date_col] = df[date_col].map(
lambda x: (
pd.to_datetime(x)
if isinstance(x, str) and re.match(REGEX_DATE, x)
else None
)
)
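A short sketch of the two helpers above on a toy frame (values are illustrative; assumes the package imports as obr):

import pandas as pd

from obr.util import fix_datetimes, non_null_unique

toy = pd.DataFrame(
    {
        "Date_onset": ["2023-02-19", "N/K", None],
        "Case_status": ["confirmed", None, "probable"],
    }
)
fix_datetimes(toy)  # "2023-02-19" -> Timestamp; "N/K" and None -> None
print(non_null_unique(toy["Case_status"]))  # ['confirmed' 'probable']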


def get_age_bins(age: str) -> range:
@@ -114,6 +129,13 @@ def invalidate_cache(
raise


def read_csv(filename: str) -> pd.DataFrame:
"Helper function with post-processing steps after pd.read_csv"
df = pd.read_csv(filename, dtype=str, na_values=["N/K", "NK"])
fix_datetimes(df)
return df


def build(
outbreak_name: str,
data_url: str,
@@ -128,7 +150,7 @@ def build(
if not (template := Path(__file__).parent / "outbreaks" / output_file).exists():
raise FileNotFoundError(f"Template for outbreak not found at: {template}")
var = {"published_date": str(date)}
df = pd.read_csv(data_url, na_values=["NK", "N/K"])
df = read_csv(data_url)
for plot in plots:
kwargs = {} if len(plot) == 2 else plot[2]
if plot[0] == "data":
16 changes: 8 additions & 8 deletions tests/test_data.csv
@@ -1,8 +1,8 @@
ID,Case_status,Age,Gender,Date_onset,Date_onset_estimated,Date_death,Date_of_first_consult,Data_up_to,Location_District
1,confirmed,50-55,male,2023-03-05,2023-03-05,2023-03-09,,2023-04-04,Bata
2,probable,40-46,female,2023-02-06,2023-02-06,2023-02-14,,2023-04-04,Bata
3,confirmed,20,male,2023-02-19,2023-02-19,,2023-02-25,2023-04-04,Nsoc Nsomo
4,confirmed,99,female,2023-01-05,2023-01-05,2023-01-11,,2023-04-04,Nsoc Nsomo
5,probable,65,male,,2023-01-13,2023-01-19,,2023-04-04,Ebiebyin
6,confirmed,59,female,,2023-03-29,,2023-04-02,2023-04-04,Ebiebyin
7,confirmed,0,male,2023-02-11,2023-02-11,,2023-02-13,2023-04-04,Nsork
ID,Case_status,Age,Gender,Date_onset,Date_death,Date_of_first_consult,Data_up_to,Location_District
1,confirmed,50-55,male,2023-03-05,2023-03-09,,2023-04-04,Bata
2,probable,40-46,female,2023-02-06,2023-02-14,,2023-04-04,Bata
3,confirmed,20,male,2023-02-19,,2023-02-25,2023-04-04,Nsoc Nsomo
4,confirmed,99,female,2023-01-05,2023-01-11,,2023-04-04,Nsoc Nsomo
5,probable,65,male,2023-01-13,2023-01-19,,2023-04-04,Ebiebyin
6,confirmed,59,female,2023-03-29,,2023-04-02,2023-04-04,Ebiebyin
7,confirmed,0,male,2023-02-11,,2023-02-13,2023-04-04,Nsork
13 changes: 8 additions & 5 deletions tests/test_plots.py
@@ -33,7 +33,7 @@

@pytest.mark.parametrize(
"column,expected_delay_series",
[("Date_death", [4, 8, 6]), ("Date_of_first_consult", [6, 2])],
[("Date_death", [4, 8, 6, 6]), ("Date_of_first_consult", [6, 4, 2])],
)
def test_get_delays(column, expected_delay_series):
assert list(get_delays(DATA, column).dt.days) == expected_delay_series
@@ -53,10 +53,12 @@ def test_get_age_bin_data():


def test_get_epicurve():
epicurve = get_epicurve(DATA)
epicurve = get_epicurve(
DATA, "Date_onset", "Case_status", ["confirmed", "probable"]
)
assert (
epicurve.to_csv(index=False)
== """Date_onset_estimated,confirmed,probable
epicurve.to_csv()
== """Date_onset,confirmed,probable
2023-01-05,1,0
2023-01-13,1,1
2023-02-06,1,2
@@ -79,7 +81,8 @@ def test_get_counts():


def test_get_timeseries_location_status():
data = DATA.rename(columns={"Date_onset": "Date_onset_estimated"})
assert (
get_timeseries_location_status(DATA).to_csv(index=False)
get_timeseries_location_status(data).to_csv(index=False)
== EXPECTED_TIMESERIES_LOCATION_STATUS
)
