🔨 Anomalist performance (#3447)
* 🔨 Improve anomalist performance
Marigold authored Oct 22, 2024
1 parent e3a67b3 commit e1eb471
Showing 8 changed files with 200 additions and 88 deletions.
17 changes: 6 additions & 11 deletions apps/anomalist/anomalist_api.py
@@ -24,7 +24,7 @@
from etl.config import OWID_ENV
from etl.db import get_engine, read_sql
from etl.files import create_folder, upload_file_to_server
from etl.grapher_io import variable_data_df_from_s3
from etl.grapher_io import variable_data_df_from_catalog

log = structlog.get_logger()

@@ -261,7 +261,7 @@ def anomaly_detection(
with Session(engine) as session:
dataset = gm.Dataset.load_dataset(session, dataset_id)

log.info("loading_data_from_s3.start")
log.info("loading_data.start")
variables_old = [
variables[variable_id_old]
for variable_id_old in variable_mapping.keys()
@@ -270,7 +270,7 @@
variables_old_and_new = variables_in_dataset + variables_old
t = time.time()
df = load_data_for_variables(engine=engine, variables=variables_old_and_new)
log.info("loading_data_from_s3.end", t=time.time() - t)
log.info("loading_data.end", t=time.time() - t)

for anomaly_type in anomaly_types:
# Instantiate the anomaly detector.
@@ -396,15 +396,10 @@ def export_anomalies_file(df: pd.DataFrame, dataset_id: int, anomaly_type: str)
return path_str


# @memory.cache
def load_data_for_variables(engine: Engine, variables: list[gm.Variable]) -> pd.DataFrame:
# TODO: cache this on disk & re-validate with etags
df_long = variable_data_df_from_s3(engine, [v.id for v in variables], workers=None)

df_long = df_long.rename(columns={"variableId": "variable_id", "entityName": "entity_name"})

# pivot dataframe
df = df_long.pivot(index=["entity_name", "year"], columns="variable_id", values="value")
# Load data from local catalog
df = variable_data_df_from_catalog(engine, variables=variables)
df = df.rename(columns={"country": "entity_name"}).set_index(["entity_name", "year"])

# reorder in the same order as variables
df = df[[v.id for v in variables]]
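For reference, a minimal sketch (not part of the commit) of the frame shape that `load_data_for_variables` now works with. It assumes, as the new code suggests, that `variable_data_df_from_catalog` returns one row per country/year with a column per variable id, which is why the old long-to-wide pivot is gone; the variable ids below are hypothetical.

```python
# Illustrative only: synthetic data standing in for the catalog output.
import pandas as pd

df = pd.DataFrame(
    {"country": ["France", "France"], "year": [2020, 2021], 101: [1.0, 2.0], 102: [3.0, 4.0]}
)
# Same post-processing as the new loader: rename, index, and reorder columns
# to match the caller's variable order.
df = df.rename(columns={"country": "entity_name"}).set_index(["entity_name", "year"])
df = df[[101, 102]]
print(df)
```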
41 changes: 38 additions & 3 deletions apps/anomalist/cli.py
@@ -5,9 +5,10 @@
import structlog
from joblib import Memory
from rich_click.rich_command import RichCommand
from sqlalchemy.engine import Engine

from apps.anomalist.anomalist_api import ANOMALY_TYPE, anomaly_detection
from etl.db import get_engine, read_sql
from etl.db import get_engine, production_or_master_engine, read_sql
from etl.paths import CACHE_DIR

log = structlog.get_logger()
@@ -91,6 +92,12 @@ def cli(
```
$ etl anomalist --anomaly-type gp --dataset-ids 6589
```
**Example 4:** Create anomalies for new datasets
```
$ etl anomalist --anomaly-type gp
```
"""
# Convert variable mapping from JSON to dictionary.
if variable_mapping:
@@ -104,8 +111,15 @@
else:
variable_mapping_dict = {}

# Load all variables from given datasets
if dataset_ids:
# If no variable IDs are given, load all variables from the given datasets.
if not variable_ids:
assert not dataset_ids, "Cannot specify both dataset IDs and variable IDs."

# Use new datasets
if not dataset_ids:
dataset_ids = load_datasets_new_ids(get_engine())

# Load all variables from given datasets
assert not variable_ids, "Cannot specify both dataset IDs and variable IDs."
q = """
select id from variables
@@ -123,5 +137,26 @@
)


def load_datasets_new_ids(source_engine: Engine) -> list[int]:
# Compare against production or staging-site-master
target_engine = production_or_master_engine()

# Get new datasets
# TODO: replace by real catalogPath when we have it in MySQL
q = """SELECT
id,
CONCAT(namespace, "/", version, "/", shortName) as catalogPath
FROM datasets
"""
source_datasets = read_sql(q, source_engine)
target_datasets = read_sql(q, target_engine)

return list(
source_datasets[
source_datasets.catalogPath.isin(set(source_datasets["catalogPath"]) - set(target_datasets["catalogPath"]))
]["id"]
)


if __name__ == "__main__":
cli()
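To make the new-dataset detection concrete, here is a tiny self-contained illustration (synthetic data, not from the repository) of the set difference that `load_datasets_new_ids` performs on the `namespace/version/shortName` paths of the source and target databases.

```python
# Illustrative only: stand-ins for the two read_sql results.
import pandas as pd

source_datasets = pd.DataFrame({"id": [1, 2, 3], "catalogPath": ["a/2024/x", "b/2024/y", "c/2024/z"]})
target_datasets = pd.DataFrame({"id": [10, 11], "catalogPath": ["a/2024/x", "b/2024/y"]})

# Datasets whose catalog path exists on the source (staging) but not on the target.
new_paths = set(source_datasets["catalogPath"]) - set(target_datasets["catalogPath"])
new_ids = list(source_datasets[source_datasets.catalogPath.isin(new_paths)]["id"])
assert new_ids == [3]
```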
11 changes: 8 additions & 3 deletions apps/anomalist/gp_detector.py
@@ -1,3 +1,4 @@
import os
import random
import time
import warnings
@@ -24,6 +25,11 @@

memory = Memory(CACHE_DIR, verbose=0)

# Maximum time for processing in seconds
ANOMALIST_MAX_TIME = int(os.environ.get("ANOMALIST_MAX_TIME", 10))
# Number of jobs for parallel processing
ANOMALIST_N_JOBS = int(os.environ.get("ANOMALIST_N_JOBS", 1))


@memory.cache
def _load_population():
@@ -65,8 +71,7 @@ def _processing_queue(items: list[tuple[str, int]]) -> List[tuple]:
class AnomalyGaussianProcessOutlier(AnomalyDetector):
anomaly_type = "gp_outlier"

# TODO: max_time is hard-coded to 10, but it should be configurable in production
def __init__(self, max_time: Optional[float] = 10, n_jobs: int = 1):
def __init__(self, max_time: Optional[float] = ANOMALIST_MAX_TIME, n_jobs: int = ANOMALIST_N_JOBS):
self.max_time = max_time
self.n_jobs = n_jobs

@@ -76,7 +81,7 @@ def get_text(entity: str, year: int) -> str:

def get_score_df(self, df: pd.DataFrame, variable_ids: List[int], variable_mapping: Dict[int, int]) -> pd.DataFrame:
# Convert to long format
df_wide = df.melt(id_vars=["entity_name", "year"])
df_wide = df.melt(id_vars=["entity_name", "year"], var_name="variable_id")
# Filter to only include the specified variable IDs.
df_wide = (
df_wide[df_wide["variable_id"].isin(variable_ids)]
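A brief usage sketch (assumptions noted in the comments) of the new environment-variable overrides; since the module reads them at import time, they must be set before `apps.anomalist.gp_detector` is imported.

```python
# Illustrative only -- requires the etl package to be importable.
import os

os.environ["ANOMALIST_MAX_TIME"] = "30"  # maximum processing time in seconds (default 10)
os.environ["ANOMALIST_N_JOBS"] = "4"     # parallel jobs (default 1)

# Import after setting the variables, because the module-level constants are
# evaluated once at import time and used as the detector's defaults.
from apps.anomalist.gp_detector import AnomalyGaussianProcessOutlier

detector = AnomalyGaussianProcessOutlier()
assert detector.max_time == 30 and detector.n_jobs == 4
```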
31 changes: 9 additions & 22 deletions apps/owidbot/anomalist.py
@@ -1,12 +1,13 @@
import time

from structlog import get_logger

from apps.anomalist.anomalist_api import anomaly_detection
from apps.anomalist.cli import load_datasets_new_ids
from apps.wizard.app_pages.anomalist.utils import load_variable_mapping
from etl import grapher_model as gm
from etl.config import OWIDEnv
from etl.db import Engine, read_sql

from .chart_diff import production_or_master_engine
from etl.db import read_sql

log = get_logger()

@@ -15,13 +16,12 @@ def run(branch: str) -> None:
"""Compute all anomalist for new and updated datasets."""
# Get engines for branch and production
source_engine = OWIDEnv.from_staging(branch).get_engine()
target_engine = production_or_master_engine()

# Create table with anomalist if it doesn't exist
gm.Anomaly.create_table(source_engine, if_exists="skip")

# Load new dataset ids
datasets_new_ids = _load_datasets_new_ids(source_engine, target_engine)
datasets_new_ids = load_datasets_new_ids(source_engine)

if not datasets_new_ids:
log.info("No new datasets found.")
@@ -36,26 +36,13 @@
# Load variable mapping
variable_mapping_dict = load_variable_mapping(datasets_new_ids)

log.info("owidbot.anomalist.start", n_variables=len(variable_ids))
t = time.time()

# Run anomalist
anomaly_detection(
variable_mapping=variable_mapping_dict,
variable_ids=variable_ids,
)


def _load_datasets_new_ids(source_engine: Engine, target_engine: Engine) -> list[int]:
# Get new datasets
# TODO: replace by real catalogPath when we have it in MySQL
q = """SELECT
id,
CONCAT(namespace, "/", version, "/", shortName) as catalogPath
FROM datasets
"""
source_datasets = read_sql(q, source_engine)
target_datasets = read_sql(q, target_engine)

return list(
source_datasets[
source_datasets.catalogPath.isin(set(source_datasets["catalogPath"]) - set(target_datasets["catalogPath"]))
]["id"]
)
log.info("owidbot.anomalist.end", n_variables=len(variable_ids), t=time.time() - t)
13 changes: 2 additions & 11 deletions apps/owidbot/chart_diff.py
@@ -2,8 +2,8 @@
from structlog import get_logger

from apps.wizard.app_pages.chart_diff.chart_diff import ChartDiffsLoader
from etl.config import OWID_ENV, OWIDEnv, get_container_name
from etl.db import Engine
from etl.config import OWIDEnv, get_container_name
from etl.db import production_or_master_engine

from . import github_utils as gh_utils

@@ -67,15 +67,6 @@ def run(branch: str, charts_df: pd.DataFrame) -> str:
return body


def production_or_master_engine() -> Engine:
"""Return the production engine if available, otherwise connect to staging-site-master."""
if OWID_ENV.env_remote == "production":
return OWID_ENV.get_engine()
else:
log.warning("ENV file doesn't connect to production DB, comparing against staging-site-master")
return OWIDEnv.from_staging("master").get_engine()


def call_chart_diff(branch: str) -> pd.DataFrame:
source_engine = OWIDEnv.from_staging(branch).get_engine()
target_engine = production_or_master_engine()
11 changes: 11 additions & 0 deletions etl/db.py
@@ -95,3 +95,14 @@ def to_sql(df: pd.DataFrame, name: str, engine: Optional[Engine | Session] = Non
return df.to_sql(name, engine.bind, *args, **kwargs)
else:
raise ValueError(f"Unsupported engine type {type(engine)}")


def production_or_master_engine() -> Engine:
"""Return the production engine if available, otherwise connect to staging-site-master."""
if config.OWID_ENV.env_remote == "production":
return config.OWID_ENV.get_engine()
elif config.ENV_FILE_PROD:
return config.OWIDEnv.from_env_file(config.ENV_FILE_PROD).get_engine()
else:
log.warning("ENV file doesn't connect to production DB, comparing against staging-site-master")
return config.OWIDEnv.from_staging("master").get_engine()
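A short sketch (illustrative, not from the commit) of the fallback order the relocated helper implements.

```python
# Illustrative only.
from etl.db import production_or_master_engine

# Resolution order implemented by the helper:
#   1. the active OWID_ENV already points at production -> use it
#   2. an ENV_FILE_PROD is configured                   -> build an engine from that env file
#   3. otherwise                                        -> fall back to staging-site-master
engine = production_or_master_engine()
print(engine.url)
```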