Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid use of apply in computing infection status in schisto module #683

Merged
merged 1 commit into from
Aug 10, 2022

Conversation

matt-graham
Copy link
Collaborator

The profiling described in #286 (comment) suggested that the use of the DataFrame.apply method in the lines

correct_status = df.loc[idx].apply(
lambda x: _inf_status(x['age_years'], x[prop('aggregate_worm_burden')]),
axis=1
)

in the SchistoSpecies.update_infectious_status_and_symptoms method may be a bottleneck, due to the resulting Python iteration over each row in the (filtered) population dataframe.

This PR replaces the apply call by instead creating a Series object directly and using boolean indexing to implement the corresponding logic while avoiding the row-by-row iteration.

Doing a basic microbenchmark of applying the new boolean indexing based _get_infection_status function to a subset of a randomly generated population dataframe with 50000 rows, compared to the previous apply based approach, suggests the new approach should be significantly quicker:

                   using apply times: min = 439.10ms, max = 513.39ms
        using boolean indexing times: min = 5.33ms, max = 7.30ms
Microbenchmark code
import timeit
import pandas as pd
import numpy as np

# Fixed seed so the synthetic population is identical on every run
seed = 20220809
# Number of rows in the synthetic population dataframe
population_size = 50_000
# Timing configuration handed to timeit.repeat further down:
# n_repeat timing samples, each made of n_iter calls
n_repeat = 7
n_iter = 10

# All random columns below are drawn from this one generator, in the order
# the columns are listed
rng = np.random.RandomState(seed)

# The three labels the classifiers can produce
categories = ("Non-infected", "Low-infection", "High-infection")

population_dataframe = pd.DataFrame(
    {
        # Random alive/dead flag; the benchmark only classifies rows where it is True
        "is_alive": rng.choice((True, False), size=population_size),
        # Integer ages drawn from 0 to 99
        "age_years": rng.randint(0, 100, size=population_size),
        # Uniform floats in [0, 1) standing in for the worm burden
        "aggregate_worm_burden": rng.uniform(size=population_size),
        # Random labels stored as plain strings (no categorical dtype is set
        # here); _get_infection_status only reads this column's dtype
        "infection_status": rng.choice(categories, size=population_size),
    }
)

# Worm-burden thresholds read by both classifiers. The "_PSAC" threshold is
# only consulted for rows with age_years < 5.
params = {
    "low_intensity_threshold": 0.1, 
    "high_intensity_threshold_PSAC": 0.5,
    "high_intensity_threshold": 0.7
}


def _get_infection_status(population: pd.DataFrame) -> pd.Series:
    """Classify every row of ``population`` by infection intensity.

    Works on whole columns with boolean masks, so no Python-level iteration
    over rows takes place.

    Args:
        population: Dataframe with ``age_years``, ``aggregate_worm_burden``
            and ``infection_status`` columns. The last one is only used as
            the source of the result's dtype.

    Returns:
        A series on ``population.index`` holding ``"High-infection"``,
        ``"Low-infection"`` or ``"Non-infected"`` for each row.
    """
    under_five = population["age_years"] < 5
    worm_burden = population["aggregate_worm_burden"]

    # Everyone starts as non-infected; rows are then promoted via the masks.
    labels = pd.Series(
        "Non-infected",
        index=population.index,
        dtype=population["infection_status"].dtype
    )

    # Under-fives have their own high-intensity threshold, but the general
    # threshold still applies to them as well.
    above_psac_high = worm_burden >= params["high_intensity_threshold_PSAC"]
    above_general_high = worm_burden >= params["high_intensity_threshold"]
    is_high = above_general_high | (under_five & above_psac_high)
    is_low = (worm_burden >= params["low_intensity_threshold"]) & ~is_high

    # The two masks are disjoint, so the assignment order does not matter.
    labels.loc[is_low] = "Low-infection"
    labels.loc[is_high] = "High-infection"
    return labels


def _inf_status(age_years: int, agg_wb: int) -> str:
    """Return the infection-intensity label for a single person.

    This is the scalar, row-at-a-time classifier used with ``DataFrame.apply``.

    Args:
        age_years: Age of the person in years.
        agg_wb: Aggregate worm burden of the person.

    Returns:
        ``'High-infection'``, ``'Low-infection'`` or ``'Non-infected'``.
    """
    # Under-fives are checked against their own high-intensity threshold
    # first; the general threshold applies to every age.
    is_high = (
        age_years < 5 and agg_wb >= params['high_intensity_threshold_PSAC']
    ) or agg_wb >= params['high_intensity_threshold']
    if is_high:
        return 'High-infection'

    if agg_wb < params['low_intensity_threshold']:
        return 'Non-infected'
    return 'Low-infection'


def print_results_string(results_dict, key, num_iter):
    """Print the best and worst per-call time, in milliseconds, for one benchmark.

    Args:
        results_dict: Mapping from benchmark label to a sequence of total
            elapsed times in seconds, one entry per timing repeat.
        key: Label of the benchmark to report; also printed right-aligned
            in a 30-character column.
        num_iter: Number of calls that each total time covers.
    """
    per_call_seconds = np.array(
        [elapsed / num_iter for elapsed in results_dict[key]]
    )
    fastest_ms = per_call_seconds.min() * 1000
    slowest_ms = per_call_seconds.max() * 1000
    print(f"{key:>30} times: min = {fastest_ms:.2f}ms, max = {slowest_ms:.2f}ms")


def _status_via_boolean_indexing():
    """Classify the alive rows with the vectorised implementation."""
    return _get_infection_status(
        population_dataframe[population_dataframe.is_alive]
    )


def _status_via_apply():
    """Classify the alive rows one at a time through DataFrame.apply."""
    return population_dataframe[population_dataframe.is_alive].apply(
        lambda row: _inf_status(row.age_years, row.aggregate_worm_burden),
        axis=1,
    )


# Both implementations must agree on every alive row before timing them.
assert (_status_via_boolean_indexing() == _status_via_apply()).all()

results = {}

# Time each implementation (row filtering included) and report it right away.
for key, func in (
    ("using apply", _status_via_apply),
    ("using boolean indexing", _status_via_boolean_indexing),
):
    results[key] = timeit.repeat(func, repeat=n_repeat, number=n_iter)
    print_results_string(results, key, n_iter)

@matt-graham
Copy link
Collaborator Author

/run scaled-sim

@matt-graham matt-graham requested review from tamuri and tbhallett August 9, 2022 13:26
@matt-graham
Copy link
Collaborator Author

Realised the above microbenchmark didn't correctly set the dtype of the infection_status column to be categorical - updating the benchmark code to correct this gives the following results:

                   using apply times: min = 417.55ms, max = 491.83ms
        using boolean indexing times: min = 14.03ms, max = 19.38ms

which surprisingly suggests that applying the operations to a categorical-dtyped column is quite a lot slower (though still a lot quicker than the apply based approach). This perhaps suggests we should drop the dtype argument in the Series initialiser in _get_infection_status and just use the default type, as Pandas will still allow comparing with the (categorical-dtyped) infection status column (the current apply approach does not create a categorical-dtyped Series).

Updated microbenchmark code
import timeit
import pandas as pd
import numpy as np

# Fixed seed so the synthetic population is identical on every run
seed = 20220809
# Number of rows in the synthetic population dataframe
population_size = 50_000
# Timing configuration handed to timeit.repeat further down:
# n_repeat timing samples, each made of n_iter calls
n_repeat = 7
n_iter = 10

# All random columns below are drawn from this one generator, in the order
# the columns are listed
rng = np.random.RandomState(seed)

# The three labels the classifiers can produce; also used as the categories
# of the infection_status column
categories = ("Non-infected", "Low-infection", "High-infection")

population_dataframe = pd.DataFrame(
    {
        # Random alive/dead flag; the benchmark only classifies rows where it is True
        "is_alive": rng.choice((True, False), size=population_size),
        # Integer ages drawn from 0 to 99
        "age_years": rng.randint(0, 100, size=population_size),
        # Uniform floats in [0, 1) standing in for the worm burden
        "aggregate_worm_burden": rng.uniform(size=population_size),
        # Random labels stored with an explicit categorical dtype;
        # _get_infection_status copies this dtype onto the series it returns
        "infection_status": pd.Series(
            rng.choice(categories, size=population_size),
            dtype=pd.CategoricalDtype(categories=categories),
        )
    }
)

# Worm-burden thresholds read by both classifiers. The "_PSAC" threshold is
# only consulted for rows with age_years < 5.
params = {
    "low_intensity_threshold": 0.1,
    "high_intensity_threshold_PSAC": 0.5,
    "high_intensity_threshold": 0.7
}


def _get_infection_status(population: pd.DataFrame) -> pd.Series:
    """Classify every row of ``population`` by infection intensity.

    Works on whole columns with boolean masks, so no Python-level iteration
    over rows takes place.

    Args:
        population: Dataframe with ``age_years``, ``aggregate_worm_burden``
            and ``infection_status`` columns. The last one is only used as
            the source of the result's dtype.

    Returns:
        A series on ``population.index``, with the same dtype as
        ``population["infection_status"]``, holding ``"High-infection"``,
        ``"Low-infection"`` or ``"Non-infected"`` for each row.
    """
    under_five = population["age_years"] < 5
    worm_burden = population["aggregate_worm_burden"]

    # Everyone starts as non-infected; rows are then promoted via the masks.
    labels = pd.Series(
        "Non-infected",
        index=population.index,
        dtype=population["infection_status"].dtype,
    )

    # Under-fives have their own high-intensity threshold, but the general
    # threshold still applies to them as well.
    above_psac_high = worm_burden >= params["high_intensity_threshold_PSAC"]
    above_general_high = worm_burden >= params["high_intensity_threshold"]
    is_high = above_general_high | (under_five & above_psac_high)
    is_low = (worm_burden >= params["low_intensity_threshold"]) & ~is_high

    # The two masks are disjoint, so the assignment order does not matter.
    labels.loc[is_low] = "Low-infection"
    labels.loc[is_high] = "High-infection"
    return labels


def _inf_status(age_years: int, agg_wb: int) -> str:
    """Return the infection-intensity label for a single person.

    This is the scalar, row-at-a-time classifier used with ``DataFrame.apply``.

    Args:
        age_years: Age of the person in years.
        agg_wb: Aggregate worm burden of the person.

    Returns:
        ``'High-infection'``, ``'Low-infection'`` or ``'Non-infected'``.
    """
    # Under-fives are checked against their own high-intensity threshold
    # first; the general threshold applies to every age.
    is_high = (
        age_years < 5 and agg_wb >= params['high_intensity_threshold_PSAC']
    ) or agg_wb >= params['high_intensity_threshold']
    if is_high:
        return 'High-infection'

    if agg_wb < params['low_intensity_threshold']:
        return 'Non-infected'
    return 'Low-infection'


def print_results_string(results_dict, key, num_iter):
    """Print the best and worst per-call time, in milliseconds, for one benchmark.

    Args:
        results_dict: Mapping from benchmark label to a sequence of total
            elapsed times in seconds, one entry per timing repeat.
        key: Label of the benchmark to report; also printed right-aligned
            in a 30-character column.
        num_iter: Number of calls that each total time covers.
    """
    per_call_seconds = np.array(
        [elapsed / num_iter for elapsed in results_dict[key]]
    )
    fastest_ms = per_call_seconds.min() * 1000
    slowest_ms = per_call_seconds.max() * 1000
    print(f"{key:>30} times: min = {fastest_ms:.2f}ms, max = {slowest_ms:.2f}ms")


def _status_via_boolean_indexing():
    """Classify the alive rows with the vectorised implementation."""
    return _get_infection_status(
        population_dataframe[population_dataframe.is_alive]
    )


def _status_via_apply():
    """Classify the alive rows one at a time through DataFrame.apply."""
    return population_dataframe[population_dataframe.is_alive].apply(
        lambda row: _inf_status(row.age_years, row.aggregate_worm_burden),
        axis=1,
    )


# Both implementations must agree on every alive row before timing them.
assert (_status_via_boolean_indexing() == _status_via_apply()).all()

results = {}

# Time each implementation (row filtering included) and report it right away.
for key, func in (
    ("using apply", _status_via_apply),
    ("using boolean indexing", _status_via_boolean_indexing),
):
    results[key] = timeit.repeat(func, repeat=n_repeat, number=n_iter)
    print_results_string(results, key, n_iter)

@tamuri tamuri merged commit 2a675cb into master Aug 10, 2022
@tamuri tamuri deleted the mmg/schisto-infection-status-optimization branch August 10, 2022 11:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants