Skip to content

Commit

Permalink
use parse csv data types function in util
Browse files Browse the repository at this point in the history
  • Loading branch information
mnjowe committed Dec 17, 2024
1 parent 32b0d35 commit a93cdea
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 33 deletions.
29 changes: 4 additions & 25 deletions src/tlo/methods/tb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
It schedules TB treatment and follow-up appointments along with preventive therapy
for eligible people (HIV+ and paediatric contacts of active TB cases)
"""
import ast
from functools import reduce
from typing import Any

import pandas as pd

Expand All @@ -17,33 +15,12 @@
from tlo.methods.dxmanager import DxTest
from tlo.methods.hsi_event import HSI_Event
from tlo.methods.symptommanager import Symptom
from tlo.util import random_date, read_csv_files
from tlo.util import parse_csv_values_for_columns_with_mixed_datatypes, random_date, read_csv_files

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def parse_csv_columns_with_mixed_datatypes(value: Any):
    """Convert one cell of a mixed-datatype CSV column to its most specific Python type.

    ``pd.read_csv`` leaves the cells of a column that mixes datatypes as strings, so a caller
    expecting a number or a list receives text instead. This function tries, in order, to read
    a string as an ``int``, then as a ``float``, then as a Python literal (e.g. a list) using
    ``ast.literal_eval``. If none of these succeed the string is returned unchanged.

    Values that are not strings are returned untouched: they already carry a datatype, and
    passing them through ``int`` would truncate floats (``int(3.7) == 3``) or raise
    ``TypeError`` (e.g. for ``None`` or a list).

    :param value: mixed datatype column value
    :return: the parsed ``int``, ``float`` or literal; otherwise ``value`` itself
    """
    if not isinstance(value, str):
        return value

    # Order matters: ``int`` is tried before ``float`` so that "5" stays an integer.
    for convert in (int, float):
        try:
            return convert(value)
        except ValueError:
            continue

    try:
        return ast.literal_eval(value)  # lists and any other Python literal
    except (ValueError, SyntaxError, TypeError):
        # ``literal_eval`` raises SyntaxError (not ValueError) for ordinary text such as
        # "hello world" or an empty string; such cells are kept as plain strings.
        return value


class Tb(Module):
"""Set up the baseline population with TB prevalence"""

Expand Down Expand Up @@ -932,7 +909,9 @@ def update_parameters_for_program_scaleup(self):
p = self.parameters
scaled_params_workbook = p["scaleup_parameters"]
for col in scaled_params_workbook.columns:
scaled_params_workbook[col] = scaled_params_workbook[col].apply(parse_csv_columns_with_mixed_datatypes)
scaled_params_workbook[col] = scaled_params_workbook[col].apply(
parse_csv_values_for_columns_with_mixed_datatypes
)

if p['type_of_scaleup'] == 'target':
scaled_params = scaled_params_workbook.set_index('parameter')['target_value'].to_dict()
Expand Down
15 changes: 7 additions & 8 deletions tests/test_htm_scaleup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
symptommanager,
tb,
)
from tlo.methods.tb import parse_csv_columns_with_mixed_datatypes
from tlo.util import read_csv_files
from tlo.util import parse_csv_values_for_columns_with_mixed_datatypes, read_csv_files

resourcefilepath = Path(os.path.dirname(__file__)) / "../resources"

Expand Down Expand Up @@ -57,7 +56,7 @@ def get_sim(seed):
def check_initial_params(sim):

original_params = read_csv_files(resourcefilepath / 'ResourceFile_HIV', files='parameters')
original_params.value = original_params.value.apply(parse_csv_columns_with_mixed_datatypes)
original_params.value = original_params.value.apply(parse_csv_values_for_columns_with_mixed_datatypes)

# check initial parameters
assert sim.modules["Hiv"].parameters["beta"] == \
Expand All @@ -77,7 +76,7 @@ def test_hiv_scale_up(seed):
and on correct date """

original_params = read_csv_files(resourcefilepath / 'ResourceFile_HIV', files="parameters")
original_params.value = original_params.value.apply(parse_csv_columns_with_mixed_datatypes)
original_params.value = original_params.value.apply(parse_csv_values_for_columns_with_mixed_datatypes)
new_params = read_csv_files(resourcefilepath / 'ResourceFile_HIV', files="scaleup_parameters")

popsize = 100
Expand Down Expand Up @@ -110,7 +109,7 @@ def test_hiv_scale_up(seed):
# check malaria parameters unchanged
mal_original_params = read_csv_files(resourcefilepath / 'malaria' / 'ResourceFile_malaria',
files="parameters")
mal_original_params.value = mal_original_params.value.apply(parse_csv_columns_with_mixed_datatypes)
mal_original_params.value = mal_original_params.value.apply(parse_csv_values_for_columns_with_mixed_datatypes)

mal_rdt_testing = read_csv_files(resourcefilepath / 'malaria' / 'ResourceFile_malaria',
files="WHO_TestData2023")
Expand All @@ -128,7 +127,7 @@ def test_hiv_scale_up(seed):

# check tb parameters unchanged
tb_original_params = read_csv_files(resourcefilepath / 'ResourceFile_TB', files="parameters")
tb_original_params.value = tb_original_params.value.apply(parse_csv_columns_with_mixed_datatypes)
tb_original_params.value = tb_original_params.value.apply(parse_csv_values_for_columns_with_mixed_datatypes)
tb_testing = read_csv_files(resourcefilepath / 'ResourceFile_TB', files="NTP2019")

pd.testing.assert_series_equal(sim.modules["Tb"].parameters["rate_testing_active_tb"]["treatment_coverage"],
Expand All @@ -151,7 +150,7 @@ def test_htm_scale_up(seed):

# Load data on HIV prevalence
original_hiv_params = read_csv_files(resourcefilepath / 'ResourceFile_HIV', files="parameters")
original_hiv_params.value = original_hiv_params.value.apply(parse_csv_columns_with_mixed_datatypes)
original_hiv_params.value = original_hiv_params.value.apply(parse_csv_values_for_columns_with_mixed_datatypes)
new_hiv_params = read_csv_files(resourcefilepath / 'ResourceFile_HIV', files="scaleup_parameters")

popsize = 100
Expand Down Expand Up @@ -202,7 +201,7 @@ def test_htm_scale_up(seed):

# check tb parameters changed
new_tb_params = read_csv_files(resourcefilepath / 'ResourceFile_TB', files="scaleup_parameters")
new_tb_params.target_value = new_tb_params.target_value.apply(parse_csv_columns_with_mixed_datatypes)
new_tb_params.target_value = new_tb_params.target_value.apply(parse_csv_values_for_columns_with_mixed_datatypes)

assert sim.modules["Tb"].parameters["rate_testing_active_tb"]["treatment_coverage"].eq(new_tb_params.loc[
new_tb_params.parameter == "tb_treatment_coverage", "target_value"].values[0]).all()
Expand Down

0 comments on commit a93cdea

Please sign in to comment.