Skip to content

Commit

Permalink
added data_service attr where needed
Browse files Browse the repository at this point in the history
  • Loading branch information
lrdossan committed Dec 11, 2023
1 parent a254575 commit 349a04e
Show file tree
Hide file tree
Showing 15 changed files with 145 additions and 118 deletions.
22 changes: 11 additions & 11 deletions caimira/apps/calculator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ class ConcentrationModel(BaseRequestHandler):
async def post(self) -> None:
debug = self.settings.get("debug", False)

data_registry = self.settings.get("data_registry")
data_service = self.settings.get("data_service")
data_registry: DataRegistry = self.settings["data_registry"]
data_service: typing.Optional[DataService] = self.settings.get("data_service", None)
if data_service:
data_service.update_registry(data_registry)

Expand Down Expand Up @@ -159,10 +159,10 @@ async def post(self) -> None:
"""
debug = self.settings.get("debug", False)

data_registry = self.settings.get("data_registry")
data_service = self.settings.get("data_service")
data_registry: DataRegistry = self.settings["data_registry"]
data_service: typing.Optional[DataService] = self.settings.get("data_service", None)
if data_service:
data_service.update_configuration(data_registry)
data_service.update_registry(data_registry)

requested_model_config = json.loads(self.request.body)
LOG.debug(pformat(requested_model_config))
Expand Down Expand Up @@ -190,10 +190,10 @@ class StaticModel(BaseRequestHandler):
async def get(self) -> None:
debug = self.settings.get("debug", False)

data_registry = self.settings.get("data_registry")
data_service = self.settings.get("data_service")
data_registry: DataRegistry = self.settings["data_registry"]
data_service: typing.Optional[DataService] = self.settings.get("data_service", None)
if data_service:
data_service.update_configuration(data_registry)
data_service.update_registry(data_registry)

form = model_generator.VirusFormData.from_dict(model_generator.baseline_raw_form_data(), data_registry)
base_url = self.request.protocol + "://" + self.request.host
Expand Down Expand Up @@ -368,10 +368,10 @@ def check_xsrf_cookie(self):
pass

async def post(self, endpoint: str) -> None:
data_registry = self.settings.get("data_registry")
data_service = self.settings.get("data_service")
data_registry: DataRegistry = self.settings["data_registry"]
data_service: typing.Optional[DataService] = self.settings.get("data_service", None)
if data_service:
data_service.update_configuration(data_registry)
data_service.update_registry(data_registry)

requested_model_config = tornado.escape.json_decode(self.request.body)
try:
Expand Down
1 change: 0 additions & 1 deletion caimira/apps/calculator/form_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import dataclasses
import datetime
import html
import logging
import typing
Expand Down
9 changes: 4 additions & 5 deletions caimira/apps/calculator/report_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,17 +216,17 @@ def conditional_prob_inf_given_vl_dist(


def manufacture_conditional_probability_data(
data_registry: DataRegistry,
exposure_model: models.ExposureModel,
infection_probability: models._VectorisedFloat
):

data_registry: DataRegistry = exposure_model.data_registry

min_vl = data_registry.conditional_prob_inf_given_viral_load['min_vl']
max_vl = data_registry.conditional_prob_inf_given_viral_load['max_vl']
step = (max_vl - min_vl)/100
viral_loads = np.arange(min_vl, max_vl, step)
specific_vl = np.log10(exposure_model.concentration_model.virus.viral_load_in_sputum)
pi_means, lower_percentiles, upper_percentiles = conditional_prob_inf_given_vl_dist(infection_probability, viral_loads,
pi_means, lower_percentiles, upper_percentiles = conditional_prob_inf_given_vl_dist(data_registry, infection_probability, viral_loads,
specific_vl, step)

return list(viral_loads), list(pi_means), list(lower_percentiles), list(upper_percentiles)
Expand Down Expand Up @@ -414,12 +414,11 @@ def manufacture_alternative_scenarios(form: VirusFormData) -> typing.Dict[str, m


def scenario_statistics(
data_registry: DataRegistry,
mc_model: mc.ExposureModel,
sample_times: typing.List[float],
compute_prob_exposure: bool
):
model = mc_model.build_model(size=data_registry.monte_carlo_sample_size)
model = mc_model.build_model(size=mc_model.data_registry.monte_carlo_sample_size)
if (compute_prob_exposure):
# It means we have data to calculate the total_probability_rule
prob_probabilistic_exposure = model.total_probability_rule()
Expand Down
12 changes: 7 additions & 5 deletions caimira/monte_carlo/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def covid_overal_vl_data(data_registry):
function=lambda d: np.interp(
d,
viral_load(data_registry),
frequencies_pdf,
frequencies_pdf(data_registry),
data_registry.covid_overal_vl_data['interpolation_fp_left'],
data_registry.covid_overal_vl_data['interpolation_fp_right']
),
Expand Down Expand Up @@ -441,22 +441,24 @@ def expiration_BLO_factors(data_registry):
def expiration_distributions(data_registry):
return {
exp_type: expiration_distribution(
BLO_factors,
data_registry=data_registry,
BLO_factors=BLO_factors,
d_min=param_evaluation(data_registry.long_range_expiration_distributions, 'minimum_diameter'),
d_max=param_evaluation(data_registry.long_range_expiration_distributions, 'maximum_diameter')
)
for exp_type, BLO_factors in expiration_BLO_factors.items()
for exp_type, BLO_factors in expiration_BLO_factors(data_registry).items()
}


def short_range_expiration_distributions(data_registry):
return {
exp_type: expiration_distribution(
BLO_factors,
data_registry=data_registry,
BLO_factors=BLO_factors,
d_min=param_evaluation(data_registry.short_range_expiration_distributions, 'minimum_diameter'),
d_max=param_evaluation(data_registry.short_range_expiration_distributions, 'maximum_diameter')
)
for exp_type, BLO_factors in expiration_BLO_factors.items()
for exp_type, BLO_factors in expiration_BLO_factors(data_registry).items()
}


Expand Down
4 changes: 2 additions & 2 deletions caimira/store/data_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ class DataService:

def __init__(
self,
credentials: typing.Dict[str, str],
credentials: typing.Dict[str, typing.Optional[str]],
host: str,
):
self._credentials = credentials
self._host = host

@classmethod
def create(cls, credentials: typing.Dict[str, str], host: str = "https://caimira-data-api.app.cern.ch"):
def create(cls, credentials: typing.Dict[str, typing.Optional[str]], host: str = "https://caimira-data-api.app.cern.ch"):
"""Factory."""
return cls(credentials, host)

Expand Down
16 changes: 8 additions & 8 deletions caimira/tests/apps/calculator/test_model_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def test_infected_less_than_total_people(activity, total_people, infected_people
baseline_form.total_people = total_people
baseline_form.infected_people = infected_people
with pytest.raises(ValueError, match=error):
baseline_form.validate(data_registry)
baseline_form.validate()


def present_times(interval: models.Interval) -> models.BoundarySequence_t:
Expand Down Expand Up @@ -275,7 +275,7 @@ def test_exposed_present_lunch_end_before_beginning(baseline_form: model_generat
baseline_form.exposed_lunch_start = minutes_since_midnight(14 * 60)
baseline_form.exposed_lunch_finish = minutes_since_midnight(13 * 60)
with pytest.raises(ValueError):
baseline_form.validate(data_registry)
baseline_form.validate()


@pytest.mark.parametrize(
Expand All @@ -291,7 +291,7 @@ def test_exposed_presence_lunch_break(baseline_form: model_generator.VirusFormDa
baseline_form.exposed_lunch_start = minutes_since_midnight(exposed_lunch_start * 60)
baseline_form.exposed_lunch_finish = minutes_since_midnight(exposed_lunch_finish * 60)
with pytest.raises(ValueError, match='exposed lunch break must be within presence times.'):
baseline_form.validate(data_registry)
baseline_form.validate()


@pytest.mark.parametrize(
Expand All @@ -307,7 +307,7 @@ def test_infected_presence_lunch_break(baseline_form: model_generator.VirusFormD
baseline_form.infected_lunch_start = minutes_since_midnight(infected_lunch_start * 60)
baseline_form.infected_lunch_finish = minutes_since_midnight(infected_lunch_finish * 60)
with pytest.raises(ValueError, match='infected lunch break must be within presence times.'):
baseline_form.validate(data_registry)
baseline_form.validate()


def test_exposed_breaks_length(baseline_form: model_generator.VirusFormData, data_registry: DataRegistry):
Expand All @@ -317,7 +317,7 @@ def test_exposed_breaks_length(baseline_form: model_generator.VirusFormData, dat
baseline_form.exposed_finish = minutes_since_midnight(11 * 60)
baseline_form.exposed_lunch_option = False
with pytest.raises(ValueError, match='Length of breaks >= Length of exposed presence.'):
baseline_form.validate(data_registry)
baseline_form.validate()


def test_infected_breaks_length(baseline_form: model_generator.VirusFormData, data_registry: DataRegistry):
Expand All @@ -328,7 +328,7 @@ def test_infected_breaks_length(baseline_form: model_generator.VirusFormData, da
baseline_form.infected_coffee_break_option = 'coffee_break_4'
baseline_form.infected_coffee_duration = 30
with pytest.raises(ValueError, match='Length of breaks >= Length of infected presence.'):
baseline_form.validate(data_registry)
baseline_form.validate()


@pytest.fixture
Expand Down Expand Up @@ -440,7 +440,7 @@ def test_valid_no_lunch(baseline_form: model_generator.VirusFormData, data_regis
baseline_form.exposed_lunch_option = False
baseline_form.exposed_lunch_start = minutes_since_midnight(0)
baseline_form.exposed_lunch_finish = minutes_since_midnight(0)
assert baseline_form.validate(data_registry) is None
assert baseline_form.validate() is None


def test_no_breaks(baseline_form: model_generator.VirusFormData):
Expand Down Expand Up @@ -516,7 +516,7 @@ def test_natural_ventilation_window_opening_periodically(baseline_form: model_ge
baseline_form.windows_duration = 20
baseline_form.windows_frequency = 10
with pytest.raises(ValueError, match='Duration cannot be bigger than frequency.'):
baseline_form.validate(data_registry)
baseline_form.validate()


def test_key_validation_mech_ventilation_type_na(baseline_form_data, data_registry):
Expand Down
10 changes: 5 additions & 5 deletions caimira/tests/apps/calculator/test_specific_model_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
def test_specific_break_structure(break_input, error, baseline_form: model_generator.VirusFormData, data_registry: DataRegistry):
baseline_form.specific_breaks = break_input
with pytest.raises(TypeError, match=error):
baseline_form.validate(data_registry)
baseline_form.validate()


@pytest.mark.parametrize(
Expand All @@ -34,7 +34,7 @@ def test_specific_break_structure(break_input, error, baseline_form: model_gener
def test_specific_population_break_data_structure(population_break_input, error, baseline_form: model_generator.VirusFormData, data_registry: DataRegistry):
baseline_form.specific_breaks = {'exposed_breaks': population_break_input, 'infected_breaks': population_break_input}
with pytest.raises(TypeError, match=error):
baseline_form.validate(data_registry)
baseline_form.validate()


@pytest.mark.parametrize(
Expand Down Expand Up @@ -67,7 +67,7 @@ def test_specific_break_time(break_input, error, baseline_form: model_generator.
def test_precise_activity_structure(precise_activity_input, error, baseline_form: model_generator.VirusFormData, data_registry: DataRegistry):
baseline_form.precise_activity = precise_activity_input
with pytest.raises(TypeError, match=error):
baseline_form.validate(data_registry)
baseline_form.validate()


@pytest.mark.parametrize(
Expand All @@ -79,7 +79,7 @@ def test_precise_activity_structure(precise_activity_input, error, baseline_form
[{"physical_activity": "Light activity", "respiratory_activity": [{"type": "Breathing", "percentage": 50}]}, 'The sum of all respiratory activities should be 100. Got 50.'],
]
)
def test_sum_precise_activity(precise_activity_input, error, baseline_form: model_generator.VirusFormData, data_registry: DataRegistry):
def test_sum_precise_activity(precise_activity_input, error, baseline_form: model_generator.VirusFormData):
baseline_form.precise_activity = precise_activity_input
with pytest.raises(ValueError, match=error):
baseline_form.validate(data_registry)
baseline_form.validate()
8 changes: 7 additions & 1 deletion caimira/tests/models/test_concentration_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import numpy.testing as npt
import pytest
from dataclasses import dataclass
import typing

from caimira import models
from caimira.store.data_registry import DataRegistry

@dataclass(frozen=True)
class KnownConcentrationModelBase(models._ConcentrationModelBase):
Expand Down Expand Up @@ -198,12 +198,14 @@ def test_integrated_concentration(simple_conc_model):
]
)
def test_normed_integrated_concentration_with_background_concentration(
data_registry: DataRegistry,
simple_conc_model: models.ConcentrationModel,
dummy_population: models.Population,
known_min_background_concentration: float,
expected_normed_integrated_concentration: float):

known_conc_model = KnownConcentrationModelBase(
data_registry,
room = simple_conc_model.room,
ventilation = simple_conc_model.ventilation,
known_population = dummy_population,
Expand All @@ -229,6 +231,7 @@ def test_normed_integrated_concentration_with_background_concentration(
]
)
def test_normed_integrated_concentration_vectorisation(
data_registry: DataRegistry,
simple_conc_model: models.ConcentrationModel,
dummy_population: models.Population,
known_removal_rate: float,
Expand All @@ -237,6 +240,7 @@ def test_normed_integrated_concentration_vectorisation(
expected_normed_integrated_concentration: float):

known_conc_model = KnownConcentrationModelBase(
data_registry = data_registry,
room = simple_conc_model.room,
ventilation = simple_conc_model.ventilation,
known_population = dummy_population,
Expand Down Expand Up @@ -264,13 +268,15 @@ def test_normed_integrated_concentration_vectorisation(
]
)
def test_zero_ventilation_rate(
data_registry: DataRegistry,
simple_conc_model: models.ConcentrationModel,
dummy_population: models.Population,
known_removal_rate: float,
known_min_background_concentration: float,
expected_concentration: float):

known_conc_model = KnownConcentrationModelBase(
data_registry = data_registry,
room = simple_conc_model.room,
ventilation = simple_conc_model.ventilation,
known_population = dummy_population,
Expand Down
5 changes: 4 additions & 1 deletion caimira/tests/models/test_dynamic_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
@pytest.fixture
def full_exposure_model(data_registry):
return models.ExposureModel(
data_registry=data_registry,
concentration_model=models.ConcentrationModel(
data_registry=data_registry,
room=models.Room(volume=100),
Expand All @@ -25,6 +26,7 @@ def full_exposure_model(data_registry):
virus=models.Virus.types['SARS_CoV_2'],
host_immunity=0.
),
evaporation_factor=0.3,
),
short_range=(),
exposed=models.Population(
Expand Down Expand Up @@ -148,12 +150,13 @@ def test_linearity_with_number_of_infected(full_exposure_model: models.ExposureM
@pytest.mark.parametrize(
"time", (8., 9., 10., 11., 12., 13., 14.),
)
def test_dynamic_dose(full_exposure_model: models.ExposureModel, time: float):
def test_dynamic_dose(data_registry, full_exposure_model: models.ExposureModel, time: float):

dynamic_infected: models.ExposureModel = dc_utils.nested_replace(
full_exposure_model,
{
'concentration_model.infected': models.InfectedPopulation(
data_registry=data_registry,
number=models.IntPiecewiseConstant(
(8, 10, 12, 13, 17), (1, 2, 0, 3)),
presence=None,
Expand Down
Loading

0 comments on commit 349a04e

Please sign in to comment.