✨ Add anomalist to owidbot
Marigold committed Oct 21, 2024
1 parent 9e1caac commit df4f867
Showing 7 changed files with 154 additions and 25 deletions.
39 changes: 36 additions & 3 deletions apps/anomalist/anomalist_api.py
@@ -7,6 +7,7 @@
import structlog
from owid.catalog import find
from sqlalchemy.engine import Engine
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session

from apps.anomalist.detectors import (
@@ -192,13 +193,14 @@ def anomaly_detection(
variable_mapping: Optional[dict[int, int]] = None,
variable_ids: Optional[list[int]] = None,
dry_run: bool = False,
force: bool = False,
reset_db: bool = False,
) -> None:
"""Detect anomalies."""
engine = get_engine()

# Ensure the 'anomalies' table exists. Optionally reset it if reset_db is True.
gm.Anomaly.create_table(engine, reset=reset_db)
gm.Anomaly.create_table(engine, if_exists="replace" if reset_db else "skip")

# If no anomaly types are provided, default to all available types
if not anomaly_types:
@@ -235,6 +237,10 @@ def anomaly_detection(
dataset_variable_ids[variable.datasetId].append(variable)

for dataset_id, variables_in_dataset in dataset_variable_ids.items():
# Get dataset's checksum
with Session(engine) as session:
dataset = gm.Dataset.load_dataset(session, dataset_id)

log.info("Loading data from S3")
variables_old = [
variables[variable_id_old]
@@ -249,6 +255,11 @@
if anomaly_type not in ANOMALY_DETECTORS:
raise ValueError(f"Unsupported anomaly type: {anomaly_type}")

if not force:
if not needs_update(engine, dataset, anomaly_type):
log.info(f"Anomaly type {anomaly_type} for dataset {dataset_id} already exists in the database.")
continue

log.info(f"Detecting anomaly type {anomaly_type} for dataset {dataset_id}")

# Instantiate the anomaly detector.
@@ -276,6 +287,7 @@ def anomaly_detection(
# TODO: validate format of the output dataframe
anomaly = gm.Anomaly(
datasetId=dataset_id,
datasetSourceChecksum=dataset.sourceChecksum,
anomalyType=anomaly_type,
)
anomaly.dfScore = df_score_long
@@ -319,6 +331,22 @@ def anomaly_detection(
session.commit()


def needs_update(engine: Engine, dataset: gm.Dataset, anomaly_type: str) -> bool:
"""If there's an anomaly with the dataset checksum in DB, it doesn't need
to be updated."""
with Session(engine) as session:
try:
anomaly = gm.Anomaly.load(
session,
dataset_id=dataset.id,
anomaly_type=anomaly_type,
)
except NoResultFound:
return True

return anomaly.datasetSourceChecksum != dataset.sourceChecksum


def export_anomalies_file(df: pd.DataFrame, dataset_id: int, anomaly_type: str) -> str:
"""Export anomaly df to local file (and upload to staging server if applicable)."""
filename = f"{dataset_id}_{anomaly_type}.feather"
@@ -353,8 +381,13 @@ def load_data_for_variables(engine: Engine, variables: list[gm.Variable]) -> pd.
# reorder in the same order as variables
df = df[[v.id for v in variables]]

# try converting to numeric
df = df.astype(float)
# TODO: how should we treat non-numeric variables? We can exclude it here, but then we need to
# fix it in detectors
# HACK: set non-numeric variables to zero
numeric_cols = df.select_dtypes(include="number").columns
for col in df.columns:
if col not in numeric_cols:
df[col] = 0

# TODO:
# remove countries with all nulls or all zeros or constant values
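To make the new skipping behaviour concrete, here is a minimal sketch of how force and the checksum comparison are meant to interact (the variable id is a placeholder; the function and parameter names are taken from this diff):

    from apps.anomalist.anomalist_api import anomaly_detection

    # Default run: for each dataset, needs_update() compares the dataset's
    # sourceChecksum with the checksum stored on its Anomaly row for that
    # anomaly type, and skips types that are already up to date.
    anomaly_detection(variable_ids=[123456])

    # force=True bypasses the needs_update() check and recomputes all anomalies.
    anomaly_detection(variable_ids=[123456], force=True)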
8 changes: 8 additions & 0 deletions apps/anomalist/cli.py
@@ -49,6 +49,12 @@
type=bool,
help="Do not write to target database.",
)
@click.option(
"--force",
"-f",
is_flag=True,
help="TBD",
)
@click.option(
"--reset-db/--no-reset-db",
default=False,
@@ -61,6 +67,7 @@ def cli(
variable_mapping: str,
variable_ids: Optional[list[int]],
dry_run: bool,
force: bool,
reset_db: bool,
) -> None:
"""TBD
@@ -111,6 +118,7 @@ def cli(
variable_mapping=variable_mapping_dict,
variable_ids=list(variable_ids) if variable_ids else None,
dry_run=dry_run,
force=force,
reset_db=reset_db,
)

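A quick way to exercise the new flag wiring, as a sketch (the CliRunner invocation is an assumption for illustration; only the --force/-f option itself comes from this commit, and in practice it would be combined with the dry-run option above):

    from click.testing import CliRunner

    from apps.anomalist.cli import cli

    runner = CliRunner()
    # --force is forwarded as force=True to anomaly_detection(), disabling the
    # checksum-based skipping added in anomalist_api.py.
    result = runner.invoke(cli, ["--force"])
    print(result.exit_code)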
61 changes: 61 additions & 0 deletions apps/owidbot/anomalist.py
@@ -0,0 +1,61 @@
from structlog import get_logger

from apps.anomalist.anomalist_api import anomaly_detection
from apps.wizard.app_pages.anomalist.utils import load_variable_mapping
from etl import grapher_model as gm
from etl.config import OWIDEnv
from etl.db import Engine, read_sql

from .chart_diff import production_or_master_engine

log = get_logger()


def run(branch: str) -> None:
"""Compute all anomalist for new and updated datasets."""
# Get engines for branch and production
source_engine = OWIDEnv.from_staging(branch).get_engine()
target_engine = production_or_master_engine()

# Create table with anomalist if it doesn't exist
gm.Anomaly.create_table(source_engine, if_exists="skip")

# Load new dataset ids
datasets_new_ids = _load_datasets_new_ids(source_engine, target_engine)

if not datasets_new_ids:
log.info("No new datasets found.")
return

log.info(f"New datasets: {datasets_new_ids}")

# Load all their variables
q = """SELECT id FROM variables WHERE datasetId IN %(dataset_ids)s"""
variable_ids = list(read_sql(q, source_engine, params={"dataset_ids": datasets_new_ids})["id"])

# Load variable mapping
variable_mapping_dict = load_variable_mapping(datasets_new_ids)

# Run anomalist
anomaly_detection(
variable_mapping=variable_mapping_dict,
variable_ids=variable_ids,
)


def _load_datasets_new_ids(source_engine: Engine, target_engine: Engine) -> list[int]:
# Get new datasets
# TODO: replace by real catalogPath when we have it in MySQL
q = """SELECT
id,
CONCAT(namespace, "/", version, "/", shortName) as catalogPath
FROM datasets
"""
source_datasets = read_sql(q, source_engine)
target_datasets = read_sql(q, target_engine)

return list(
source_datasets[
source_datasets.catalogPath.isin(set(source_datasets["catalogPath"]) - set(target_datasets["catalogPath"]))
]["id"]
)
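For context, a hedged sketch of how this new service is driven (the branch name is a placeholder; owidbot calls run() from its CLI dispatcher, shown in apps/owidbot/cli.py below):

    from apps.owidbot import anomalist

    # Runs anomaly detection for every dataset whose catalogPath exists on the
    # staging branch but not yet in production, reusing the indicator-upgrader
    # variable mapping.
    anomalist.run("my-feature-branch")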
15 changes: 10 additions & 5 deletions apps/owidbot/chart_diff.py
@@ -3,6 +3,7 @@

from apps.wizard.app_pages.chart_diff.chart_diff import ChartDiffsLoader
from etl.config import OWID_ENV, OWIDEnv, get_container_name
from etl.db import Engine

from . import github_utils as gh_utils

@@ -66,14 +67,18 @@ def run(branch: str, charts_df: pd.DataFrame) -> str:
return body


def call_chart_diff(branch: str) -> pd.DataFrame:
source_engine = OWIDEnv.from_staging(branch).get_engine()

def production_or_master_engine() -> Engine:
"""Return the production engine if available, otherwise connect to staging-site-master."""
if OWID_ENV.env_remote == "production":
target_engine = OWID_ENV.get_engine()
return OWID_ENV.get_engine()
else:
log.warning("ENV file doesn't connect to production DB, comparing against staging-site-master")
target_engine = OWIDEnv.from_staging("master").get_engine()
return OWIDEnv.from_staging("master").get_engine()


def call_chart_diff(branch: str) -> pd.DataFrame:
source_engine = OWIDEnv.from_staging(branch).get_engine()
target_engine = production_or_master_engine()

df = ChartDiffsLoader(source_engine, target_engine).get_diffs_summary_df(
config=True,
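The change above only extracts the engine-selection logic so the anomalist service can reuse it; behaviour is unchanged. Roughly:

    from apps.owidbot.chart_diff import production_or_master_engine

    # Production engine when the local env points at production, otherwise the
    # staging-site-master engine (with a warning), exactly as before.
    target_engine = production_or_master_engine()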
9 changes: 7 additions & 2 deletions apps/owidbot/cli.py
@@ -8,15 +8,15 @@
from rich import print
from rich_click.rich_command import RichCommand

from apps.owidbot import chart_diff, data_diff, grapher
from apps.owidbot import anomalist, chart_diff, data_diff, grapher
from etl.config import get_container_name

from . import github_utils as gh_utils

log = structlog.get_logger()

REPOS = Literal["etl", "owid-grapher"]
SERVICES = Literal["data-diff", "chart-diff", "grapher"]
SERVICES = Literal["data-diff", "chart-diff", "grapher", "anomalist"]


@click.command("owidbot", cls=RichCommand, help=__doc__)
@@ -76,6 +76,11 @@ def cli(

elif service == "grapher":
services_body["grapher"] = grapher.run(branch)

elif service == "anomalist":
# TODO: anomalist could post a summary of anomalies to the PR
anomalist.run(branch)

else:
raise AssertionError("Invalid service")

24 changes: 15 additions & 9 deletions apps/wizard/app_pages/anomalist/utils.py
@@ -45,6 +45,19 @@ def get_datasets_and_mapping_inputs() -> Tuple[Dict[int, str], Dict[int, str], D
steps_df_grapher["id_name"] = [f"[{ds['id']}] {ds['name']}" for ds in steps_df_grapher.to_dict(orient="records")]

# Load mapping created by indicator upgrader (if any).
variable_mapping = load_variable_mapping(datasets_new_ids)

# List all grapher datasets.
datasets_all = steps_df_grapher["id_name"].to_list()
datasets_all = steps_df_grapher[["id", "id_name"]].set_index("id").squeeze().to_dict()

# List new datasets.
datasets_new = {k: v for k, v in datasets_all.items() if k in datasets_new_ids}

return datasets_all, datasets_new, variable_mapping # type: ignore


def load_variable_mapping(datasets_new_ids: List[int]) -> Dict[int, int]:
mapping = WizardDB.get_variable_mapping_raw()
if len(mapping) > 0:
# Set of ids of new datasets that appear in the mapping generated by indicator upgrader.
@@ -66,22 +79,15 @@ def get_datasets_and_mapping_inputs() -> Tuple[Dict[int, str], Dict[int, str], D
# This could be useful if a user wants to compare two arbitrary versions of existing grapher datasets.
variable_mapping = dict()

# List all grapher datasets.
datasets_all = steps_df_grapher["id_name"].to_list()
datasets_all = steps_df_grapher[["id", "id_name"]].set_index("id").squeeze().to_dict()

# List new datasets.
datasets_new = {k: v for k, v in datasets_all.items() if k in datasets_new_ids}

return datasets_all, datasets_new, variable_mapping # type: ignore
return variable_mapping # type: ignore


def create_tables(_owid_env: OWIDEnv = OWID_ENV):
"""Create all required tables.
If exist, nothing is created.
"""
gm.Anomaly.create_table(_owid_env.engine)
gm.Anomaly.create_table(_owid_env.engine, if_exists="skip")


@st.cache_data(show_spinner=False)
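The wizard change is mostly a refactor: the mapping lookup moves into load_variable_mapping() so owidbot can call it directly. A sketch of the extracted helper's contract (the dataset id is a placeholder):

    from apps.wizard.app_pages.anomalist.utils import load_variable_mapping

    # Maps old variable ids to new ones for the given new datasets, based on the
    # mapping saved by indicator upgrader; an empty dict if no mapping applies.
    variable_mapping = load_variable_mapping(datasets_new_ids=[6789])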
23 changes: 17 additions & 6 deletions etl/grapher_model.py
@@ -112,13 +112,19 @@ def from_dict(cls, d: Dict[str, Any]) -> Self:
return x

@classmethod
def create_table(cls, engine: Engine, reset: bool = False) -> None:
if reset:
# Drop the table if it exists
def create_table(cls, engine: Engine, if_exists: Literal["fail", "replace", "skip"] = "fail") -> None:
if if_exists == "replace":
# Drop the table if it exists and create a new one
cls.__table__.drop(engine, checkfirst=True) # type: ignore

# Create table
cls.__table__.create(engine, checkfirst=True) # type: ignore
cls.__table__.create(engine, checkfirst=False) # type: ignore
elif if_exists == "skip":
# Create the table only if it doesn't already exist
cls.__table__.create(engine, checkfirst=True) # type: ignore
elif if_exists == "fail":
# Attempt to create the table; fail if it already exists
cls.__table__.create(engine, checkfirst=False) # type: ignore
else:
raise ValueError(f"Unrecognized value for if_exists: {if_exists}")


class Entity(Base):
@@ -1754,6 +1760,7 @@ class Anomaly(Base):
DateTime, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"), init=False
)
datasetId: Mapped[int] = mapped_column(Integer)
datasetSourceChecksum: Mapped[Optional[str]] = mapped_column(VARCHAR(64), default=None)
anomalyType: Mapped[str] = mapped_column(VARCHAR(255), default=str)
path_file: Mapped[Optional[str]] = mapped_column(VARCHAR(255), default=None)
_dfScore: Mapped[Optional[bytes]] = mapped_column("dfScore", LONGBLOB, default=None)
@@ -1774,6 +1781,10 @@ def __repr__(self) -> str:
f"datasetId={self.datasetId}, anomalyType={self.anomalyType})>"
)

@classmethod
def load(cls, session: Session, dataset_id: int, anomaly_type: str) -> "Anomaly":
return session.scalars(select(cls).where(cls.datasetId == dataset_id, cls.anomalyType == anomaly_type)).one()

@hybrid_property
def dfScore(self) -> Optional[pd.DataFrame]: # type: ignore
if self._dfScore is None:
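To summarise the new create_table() semantics, a short sketch (the get_engine import path is an assumption; the three modes are taken from the code above):

    from etl import grapher_model as gm
    from etl.db import get_engine

    engine = get_engine()

    gm.Anomaly.create_table(engine, if_exists="skip")     # create only if the table is missing (used by owidbot and wizard)
    gm.Anomaly.create_table(engine, if_exists="replace")  # drop and recreate (the old reset=True behaviour)
    gm.Anomaly.create_table(engine, if_exists="fail")     # default: raise if the table already exists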
