Skip to content

Commit

Permalink
Merge branch 'feature/inject-data-service' into 'master'
Browse files Browse the repository at this point in the history
inject data registry instance instead of static config

Closes #365

See merge request caimira/caimira!476
  • Loading branch information
andrejhenriques committed Dec 20, 2023
2 parents 17a5fa6 + ddecd91 commit 45b81b1
Show file tree
Hide file tree
Showing 38 changed files with 1,272 additions and 990 deletions.
2 changes: 1 addition & 1 deletion app-config/auth-service/auth_service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from tornado.web import Application, RequestHandler
import tornado.log

LOG = logging.getLogger(__name__)
LOG = logging.getLogger("AUTH")


class BaseHandler(RequestHandler):
Expand Down
115 changes: 75 additions & 40 deletions caimira/apps/calculator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import html
import json
import pandas as pd
from pprint import pformat
from io import StringIO
import os
from pathlib import Path
Expand All @@ -24,6 +25,9 @@
from tornado.web import Application, RequestHandler, StaticFileHandler
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
import tornado.log
from caimira.store.data_registry import DataRegistry

from caimira.store.data_service import DataService

from . import markdown_tools
from . import model_generator, co2_model_generator
Expand All @@ -37,13 +41,13 @@
# calculator version. If the calculator needs to make breaking changes (e.g. change
# form attributes) then it can also increase its MAJOR version without needing to
# increase the overall CAiMIRA version (found at ``caimira.__version__``).
__version__ = "4.14.2"
__version__ = "4.14.3"

LOG = logging.getLogger("APP")

LOG = logging.getLogger(__name__)


class BaseRequestHandler(RequestHandler):

async def prepare(self):
"""Called at the beginning of a request before `get`/`post`/etc."""

Expand Down Expand Up @@ -97,20 +101,22 @@ async def prepare(self):

class ConcentrationModel(BaseRequestHandler):
async def post(self) -> None:
debug = self.settings.get("debug", False)

data_registry: DataRegistry = self.settings["data_registry"]
data_service: typing.Optional[DataService] = self.settings.get("data_service", None)
if data_service:
data_service.update_registry(data_registry)

requested_model_config = {
name: self.get_argument(name) for name in self.request.arguments
}
if self.settings.get("debug", False):
from pprint import pprint
pprint(requested_model_config)
start = datetime.datetime.now()

LOG.debug(pformat(requested_model_config))

try:
form = model_generator.VirusFormData.from_dict(requested_model_config)
form = model_generator.VirusFormData.from_dict(requested_model_config, data_registry)
except Exception as err:
if self.settings.get("debug", False):
import traceback
print(traceback.format_exc())
LOG.exception(err)
response_json = {'code': 400, 'error': f'Your request was invalid {html.escape(str(err))}'}
self.set_status(400)
self.finish(json.dumps(response_json))
Expand All @@ -126,7 +132,7 @@ async def post(self) -> None:
if self.get_cookie('conditional_plot'):
form.conditional_probability_plot = True if self.get_cookie('conditional_plot') == '1' else False
self.clear_cookie('conditional_plot') # Clears cookie after changing the form value.

report_task = executor.submit(
report_generator.build_report, base_url, form,
executor_factory=functools.partial(
Expand All @@ -151,17 +157,20 @@ async def post(self) -> None:
Expects algorithm input in HTTP POST request body in JSON format.
Returns report data (algorithm output) in HTTP POST response body in JSON format.
"""
debug = self.settings.get("debug", False)

data_registry: DataRegistry = self.settings["data_registry"]
data_service: typing.Optional[DataService] = self.settings.get("data_service", None)
if data_service:
data_service.update_registry(data_registry)

requested_model_config = json.loads(self.request.body)
if self.settings.get("debug", False):
from pprint import pprint
pprint(requested_model_config)
LOG.debug(pformat(requested_model_config))

try:
form = model_generator.VirusFormData.from_dict(requested_model_config)
form = model_generator.VirusFormData.from_dict(requested_model_config, data_registry)
except Exception as err:
if self.settings.get("debug", False):
import traceback
print(traceback.format_exc())
LOG.exception(err)
response_json = {'code': 400, 'error': f'Your request was invalid {html.escape(str(err))}'}
self.set_status(400)
await self.finish(json.dumps(response_json))
Expand All @@ -171,14 +180,22 @@ async def post(self) -> None:
max_workers=self.settings['handler_worker_pool_size'],
timeout=300,
)
report_data_task = executor.submit(calculate_report_data, form, form.build_model())
model = form.build_model()
report_data_task = executor.submit(calculate_report_data, form, model)
report_data: dict = await asyncio.wrap_future(report_data_task)
await self.finish(report_data)


class StaticModel(BaseRequestHandler):
async def get(self) -> None:
form = model_generator.VirusFormData.from_dict(model_generator.baseline_raw_form_data())
debug = self.settings.get("debug", False)

data_registry: DataRegistry = self.settings["data_registry"]
data_service: typing.Optional[DataService] = self.settings.get("data_service", None)
if data_service:
data_service.update_registry(data_registry)

form = model_generator.VirusFormData.from_dict(model_generator.baseline_raw_form_data(), data_registry)
base_url = self.request.protocol + "://" + self.request.host
report_generator: ReportGenerator = self.settings['report_generator']
executor = loky.get_reusable_executor(max_workers=self.settings['handler_worker_pool_size'])
Expand Down Expand Up @@ -245,14 +262,14 @@ async def get(self, hotel_id, floor_id):
if (client_id == None or client_secret == None or arve_api_key == None):
# If the credentials are not defined, we skip the ARVE API connection
return self.send_error(401)

http_client = AsyncHTTPClient()

URL = 'https://arveapi.auth.eu-central-1.amazoncognito.com/oauth2/token'
headers = { "Content-Type": "application/x-www-form-urlencoded",
"Authorization": b"Basic " + base64.b64encode(f'{client_id}:{client_secret}'.encode())
}

try:
response = await http_client.fetch(HTTPRequest(
url=URL,
Expand All @@ -263,9 +280,9 @@ async def get(self, hotel_id, floor_id):
raise_error=True)
except Exception as e:
print("Something went wrong: %s" % e)

access_token = json.loads(response.body)['access_token']

URL = f'https://api.arve.swiss/v1/{hotel_id}/{floor_id}'
headers = {
"x-api-key": arve_api_key,
Expand All @@ -280,11 +297,11 @@ async def get(self, hotel_id, floor_id):
raise_error=True)
except Exception as e:
print("Something went wrong: %s" % e)

self.set_header("Content-Type", 'application/json')
return self.finish(response.body)


class CasesData(BaseRequestHandler):
async def get(self, country):
http_client = AsyncHTTPClient()
Expand All @@ -300,7 +317,7 @@ async def get(self, country):
print("Something went wrong: %s" % e)

country_name = json.loads(response.body)['name']['common']

# Get global incident rates
URL = 'https://covid19.who.int/WHO-COVID-19-global-data.csv'
try:
Expand All @@ -321,7 +338,7 @@ async def get(self, country):
# If any of the 'New_cases' is 0, it means the data is not updated.
if (cases.loc[eight_days_ago:current_date]['New_cases'] == 0).any(): return self.finish('')
return self.finish(str(round(cases.loc[eight_days_ago:current_date]['New_cases'].mean())))


class GenericExtraPage(BaseRequestHandler):

Expand All @@ -340,7 +357,7 @@ def get(self):
active_page=self.active_page,
text_blocks=template_environment.globals["common_text"]
))


class CO2ModelResponse(BaseRequestHandler):
def check_xsrf_cookie(self):
Expand All @@ -349,11 +366,16 @@ def check_xsrf_cookie(self):
Thus, XSRF cookies are disabled by overriding base class implementation of this method with a pass statement.
"""
pass

async def post(self, endpoint: str) -> None:
data_registry: DataRegistry = self.settings["data_registry"]
data_service: typing.Optional[DataService] = self.settings.get("data_service", None)
if data_service:
data_service.update_registry(data_registry)

requested_model_config = tornado.escape.json_decode(self.request.body)
try:
form = co2_model_generator.CO2FormData.from_dict(requested_model_config)
form = co2_model_generator.CO2FormData.from_dict(requested_model_config, data_registry)
except Exception as err:
if self.settings.get("debug", False):
import traceback
Expand All @@ -376,17 +398,17 @@ async def post(self, endpoint: str) -> None:
co2_model_generator.CO2FormData.build_model, form,
)
report = await asyncio.wrap_future(report_task)

result = dict(report.CO2_fit_params())
ventilation_transition_times = report.ventilation_transition_times

result['fitting_ventilation_type'] = form.fitting_ventilation_type
result['transition_times'] = ventilation_transition_times
result['CO2_plot'] = co2_model_generator.CO2FormData.generate_ventilation_plot(CO2_data=form.CO2_data,
transition_times=ventilation_transition_times[:-1],
result['CO2_plot'] = co2_model_generator.CO2FormData.generate_ventilation_plot(CO2_data=form.CO2_data,
transition_times=ventilation_transition_times[:-1],
predictive_CO2=result['predictive_CO2'])
self.finish(result)


def get_url(app_root: str, relative_path: str = '/'):
return app_root.rstrip('/') + relative_path.rstrip('/')
Expand All @@ -413,7 +435,7 @@ def make_app(
(get_root_calculator_url(r'/report'), ConcentrationModel),
(get_root_url(r'/static/(.*)'), StaticFileHandler, {'path': static_dir}),
(get_root_calculator_url(r'/static/(.*)'), StaticFileHandler, {'path': calculator_static_dir}),
]
]

urls: typing.List = base_urls + [
(get_root_url(r'/_c/(.*)'), CompressedCalculatorFormInputs),
Expand All @@ -429,7 +451,7 @@ def make_app(
'active_page': 'calculator/user-guide',
'filename': 'userguide.html.j2'}),
]

interface: str = os.environ.get('CAIMIRA_THEME', '<undefined>')
if interface != '<undefined>' and (interface != '<undefined>' and 'cern' not in interface): urls = list(filter(lambda i: i in base_urls, urls))

Expand Down Expand Up @@ -468,9 +490,22 @@ def make_app(
if debug:
tornado.log.enable_pretty_logging()

data_registry = DataRegistry()
data_service = None
data_service_enabled = os.environ.get("DATA_SERVICE_ENABLED", "False")
is_enabled = data_service_enabled.lower() == "true"
if is_enabled:
credentials = {
"email": os.environ.get("DATA_SERVICE_CLIENT_EMAIL", None),
"password": os.environ.get("DATA_SERVICE_CLIENT_PASSWORD", None),
}
data_service = DataService.create(credentials)

return Application(
urls,
debug=debug,
data_registry=data_registry,
data_service=data_service,
template_environment=template_environment,
default_handler_class=Missing404Handler,
report_generator=ReportGenerator(loader, get_root_url, get_root_calculator_url),
Expand Down
21 changes: 19 additions & 2 deletions caimira/apps/calculator/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import logging
from pathlib import Path

from tornado.ioloop import IOLoop
Expand Down Expand Up @@ -34,16 +35,32 @@ def configure_parser(parser) -> argparse.ArgumentParser:
return parser


def _init_logging(debug=False):
# Set the logging level for urllib3 and requests to WARNING
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("requests").setLevel(logging.WARNING)

# set app root log level
logger = logging.getLogger()
root_log_level = logging.DEBUG if debug else logging.WARNING
logger.setLevel(root_log_level)


def main():
parser = configure_parser(argparse.ArgumentParser())
args = parser.parse_args()

debug = args.no_debug
_init_logging(debug)

theme_dir = args.theme
if theme_dir is not None:
theme_dir = Path(theme_dir).absolute()
assert theme_dir.exists()
app = make_app(debug=args.no_debug, APPLICATION_ROOT=args.app_root, calculator_prefix=args.prefix, theme_dir=theme_dir)

app = make_app(debug=debug, APPLICATION_ROOT=args.app_root, calculator_prefix=args.prefix, theme_dir=theme_dir)
app.listen(args.port)
IOLoop.instance().start()
IOLoop.current().start()


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit 45b81b1

Please sign in to comment.