include year parameter in EpiScanner class
luabida committed Dec 21, 2023
1 parent 3cc857c commit 629c1a7
Showing 1 changed file with 73 additions and 54 deletions.
127 changes: 73 additions & 54 deletions src/scanner/scanner.py
@@ -1,6 +1,7 @@
import asyncio
import os
import asyncio
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Literal

@@ -55,6 +56,7 @@ def __init__(
"SC", "SP", "SE", "TO", "DF"
],
# fmt: on
year: int,
verbose: bool = False,
):
"""
@@ -82,12 +84,17 @@ def __init__(
f"Unknown uf {uf}. Options: {list(STATES.keys())}"
)

cur_year = datetime.now().year
if year > cur_year or year < 2010:
raise ValueError("Year must be < current year and > 2010")

self.disease = disease
self.uf = uf
self.year = year
self.verbose = verbose
self.data = self._get_alerta_table()

asyncio.run(self._scan_all())
asyncio.run(self._scan_all_geocodes())

def export(
self,
@@ -117,7 +124,9 @@ def export(
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

file_name = self.uf + "_" + self.disease + "." + format
file_name = (
self.uf + "_" + self.disease + "_" + str(self.year) + "." + format
)
file = output_dir / file_name

df = self._parse_results()
@@ -164,7 +173,8 @@ def _get_alerta_table(self) -> pd.DataFrame:
FROM "Municipio"."Historico_alerta{table_suffix}" historico
JOIN "Dengue_global"."Municipio" municipio
ON historico.municipio_geocodigo=municipio.geocodigo
WHERE municipio.uf=\'{state_name}\'
WHERE municipio.uf='{state_name}'
AND EXTRACT(YEAR FROM "data_iniSE") = {self.year}
ORDER BY "data_iniSE" DESC;
"""

@@ -181,18 +191,17 @@ def _filter_city(self, geocode):
dfcity["casos_cum"] = dfcity.casos.cumsum()
return dfcity

def _save_results(self, geocode, year, results, curve):
def _save_results(self, geocode, results, curve):
self.results[geocode].append(
{
"year": year,
"success": results.success,
"params": results.params.valuesdict(),
"sir_pars": get_SIR_pars(results.params.valuesdict()),
}
)
self.curves[geocode].append(
{
"year": year,
"year": self.year,
"df": curve,
"residuals": abs(curve.richards - curve.casos_cum),
"sum_res": (
@@ -206,8 +215,8 @@ def _save_results(self, geocode, year, results, curve):
def _parse_results(self) -> pd.DataFrame:
data = {
"geocode": [],
"muni_name": [],
"year": [],
"muni_name": [],
"peak_week": [],
"beta": [],
"gamma": [],
@@ -228,12 +237,10 @@
params = [
p["params"]
for p in self.results[gc]
if p["year"] == c["year"]
][0]
sir_params = [
p["sir_pars"]
for p in self.results[gc]
if p["year"] == c["year"]
][0]
data["peak_week"].append(params["tp1"])
data["total_cases"].append(params["L1"])
@@ -254,46 +261,39 @@ async def _scan(self, geocode):
df = self._filter_city(geocode)
df = df.assign(year=[i.year for i in df.index])

async def scan_year(y):
if self.verbose:
logger.info(f"Scanning year {y}")

dfy = df[df.year == y]
window = int(max([str(x)[-2:] for x in dfy.SE]))
has_transmission = dfy.transmissao.sum() > 3

if not has_transmission:
if self.verbose:
logger.info(
f"""
There were less than 3 weeks with Rt>1
in {geocode} in {y}.\nSkipping analysis
"""
)
return

out, curve = otim(
dfy[["casos", "casos_cum"]].iloc[0:window], # NOQA E203
0,
window,
)

self._save_results(geocode, y, out, curve)
dfy = df[df.year == self.year]
window = int(max([str(x)[-2:] for x in dfy.SE]))
has_transmission = dfy.transmissao.sum() > 3

if out.success:
if self.verbose:
logger.info(
f"""
R0 in {y}: {
self.results[geocode][-1]['sir_pars']['R0']
}
"""
)
if not has_transmission:
if self.verbose:
logger.info(
f"""
There were less than 3 weeks with Rt>1
in {geocode}.\nSkipping analysis
"""
)
return

out, curve = otim(
df[["casos", "casos_cum"]].iloc[0:window], # NOQA E203
0,
window,
)

tasks = [scan_year(y) for y in set(df.year.values)]
await asyncio.gather(*tasks)
self._save_results(geocode, out, curve)

async def _scan_all(self):
if out.success:
if self.verbose:
logger.info(
f"""
R0: {
self.results[geocode][-1]['sir_pars']['R0']
}
"""
)

async def _scan_all_geocodes(self):
tasks = [
self._scan(geocode)
for geocode in self.data.municipio_geocodigo.unique()
@@ -303,17 +303,36 @@ async def _scan_all(self):
def _to_duckdb(self, output_dir: str):
output_dir = Path(output_dir)
db = output_dir / "episcanner.duckdb"
con = duckdb.connect(str(db.absolute()))

df = self._parse_results()
table_name = self.uf

con.register("df", df)
con.execute(
f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM df"
)
con.unregister("df")
con.close()
try:
con = duckdb.connect(str(db.absolute()))
con.register("df", df)

try:
rows = con.execute(
f"SELECT COUNT(*) FROM '{table_name}'"
f" WHERE year = {self.year}"
).fetchone()[0]

if rows > 0:
con.execute(
f"REPLACE TABLE '{table_name}'"
f" WHERE year = {self.year}"
)
else:
con.execute(
f"INSERT INTO '{table_name}' SELECT * FROM df"
)
except duckdb.duckdb.CatalogException:
con.execute(
f"CREATE TABLE '{table_name}' AS SELECT * FROM df"
)
finally:
con.unregister("df")
con.close()

if self.verbose:
logger.info(f"{self.uf} data wrote into {db.absolute()}")

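For reference, a minimal usage sketch of the class after this change. This is not part of the commit: the import path, the "dengue" disease value, the output directory, and the export keyword arguments are assumptions inferred from the diff.

# Illustrative sketch only; names below are inferred from the diff,
# not confirmed by the repository.
from scanner.scanner import EpiScanner

# The new `year` argument is validated in __init__
# (roughly 2010 <= year <= current year).
scanner = EpiScanner(
    disease="dengue",   # assumed disease name
    uf="SP",            # one of the state codes accepted by the class
    year=2023,
    verbose=True,
)

# The exported file name now embeds the year, e.g. "SP_dengue_2023.csv".
scanner.export(format="csv", output_dir="/tmp/episcanner")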