✨ use async when upserting to grapher

Marigold committed Dec 19, 2024
1 parent bf1a645 commit 3bd5e5b

Showing 22 changed files with 1,170 additions and 502 deletions.
46 changes: 20 additions & 26 deletions apps/backport/datasync/data_metadata.py
@@ -10,14 +10,15 @@
import requests
from sqlalchemy import text
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from structlog import get_logger
from tenacity import Retrying
from tenacity.retry import retry_if_exception_type
from tenacity.stop import stop_after_attempt
from tenacity.wait import wait_fixed

from etl import config, files
from etl import config
from etl.config import OWIDEnv
from etl.db import read_sql

@@ -156,7 +157,7 @@ def variable_data(data_df: pd.DataFrame) -> Dict[str, Any]:
return data # type: ignore


def _load_variable(session: Session, variable_id: int) -> Dict[str, Any]:
async def _load_variable(session: AsyncSession, variable_id: int) -> Dict[str, Any]:
sql = """
SELECT
variables.*,
@@ -173,14 +174,14 @@ def _load_variable(session: Session, variable_id: int) -> Dict[str, Any]:
"""

# Using the session to execute raw SQL and fetching one row as a result
result = session.execute(text(sql), {"variable_id": variable_id}).fetchone()
result = (await session.execute(text(sql), {"variable_id": variable_id})).fetchone()

# Ensure result exists and convert to dictionary
assert result, f"variableId `{variable_id}` not found"
return dict(result._mapping)
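
The same change repeats in each loader below: execute() on an AsyncSession returns a coroutine, so the call is awaited first and fetchone() or fetchall() is then called on the resulting Result. A minimal before/after sketch of the pattern (illustrative only, not part of the diff):

# synchronous Session (before)
result = session.execute(text(sql), {"variable_id": variable_id}).fetchone()

# AsyncSession (after): await the execute() coroutine, then fetch rows
result = (await session.execute(text(sql), {"variable_id": variable_id})).fetchone()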


def _load_topic_tags(session: Session, variable_id: int) -> List[str]:
async def _load_topic_tags(session: AsyncSession, variable_id: int) -> List[str]:
sql = """
SELECT
tags.name
@@ -191,13 +192,13 @@ def _load_topic_tags(session: Session, variable_id: int) -> List[str]:
"""

# Using the session to execute raw SQL
result = session.execute(text(sql), {"variable_id": variable_id}).fetchall()
result = (await session.execute(text(sql), {"variable_id": variable_id})).fetchall()

# Extract tag names from the result and return as a list
return [row[0] for row in result]


def _load_faqs(session: Session, variable_id: int) -> List[Dict[str, Any]]:
async def _load_faqs(session: AsyncSession, variable_id: int) -> List[Dict[str, Any]]:
sql = """
SELECT
gdocId,
@@ -208,13 +209,13 @@ def _load_faqs(session: Session, variable_id: int) -> List[Dict[str, Any]]:
"""

# Using the session to execute raw SQL
result = session.execute(text(sql), {"variable_id": variable_id}).fetchall()
result = (await session.execute(text(sql), {"variable_id": variable_id})).fetchall()

# Convert the result rows to a list of dictionaries
return [dict(row._mapping) for row in result]


def _load_origins_df(session: Session, variable_id: int) -> pd.DataFrame:
async def _load_origins_df(session: AsyncSession, variable_id: int) -> pd.DataFrame:
sql = """
SELECT
origins.*
@@ -225,7 +226,7 @@ def _load_origins_df(session: Session, variable_id: int) -> pd.DataFrame:
"""

# Use the session to execute the raw SQL
result_proxy = session.execute(text(sql), {"variable_id": variable_id})
result_proxy = await session.execute(text(sql), {"variable_id": variable_id})

# Fetch the results into a DataFrame
df = pd.DataFrame(result_proxy.fetchall(), columns=result_proxy.keys())
@@ -358,17 +359,22 @@ def _move_population_origin_to_end(origins: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
return new_origins


def variable_metadata(session: Session, variable_id: int, variable_data: pd.DataFrame) -> Dict[str, Any]:
async def variable_metadata(session: AsyncSession, variable_id: int, variable_data: pd.DataFrame) -> Dict[str, Any]:
"""Fetch metadata for a single variable from database. This function was initially based on the
one from owid-grapher repository and uses raw SQL commands. It'd be interesting to rewrite it
using SQLAlchemy ORM in grapher_model.py.
"""
task_variable = _load_variable(session, variable_id)
task_origins = _load_origins_df(session, variable_id)
task_topic_tags = _load_topic_tags(session, variable_id)
task_faqs = _load_faqs(session, variable_id)

return _variable_metadata(
db_variable_row=_load_variable(session, variable_id),
db_variable_row=await task_variable,
variable_data=variable_data,
db_origins_df=_load_origins_df(session, variable_id),
db_topic_tags=_load_topic_tags(session, variable_id),
db_faqs=_load_faqs(session, variable_id),
db_origins_df=await task_origins,
db_topic_tags=await task_topic_tags,
db_faqs=await task_faqs,
)


@@ -394,18 +400,6 @@ def _omit_nullable_values(d: dict) -> dict:
return {k: v for k, v in d.items() if v is not None and (isinstance(v, list) and len(v) or not pd.isna(v))}


def checksum_data_str(var_data_str: str) -> str:
return files.checksum_str(var_data_str)


def checksum_metadata(meta: Dict[str, Any]) -> str:
"""Calculate checksum for metadata. It modifies the metadata dict!"""
# Drop fields not needed for checksum computation
meta = filter_out_fields_in_metadata_for_checksum(meta)

return files.checksum_str(json.dumps(meta, default=str))


def filter_out_fields_in_metadata_for_checksum(meta: Dict[str, Any]) -> Dict[str, Any]:
"""Drop fields that are not needed to estimate the checksum."""
meta_ = deepcopy(meta)
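
In variable_metadata above, the four loader coroutines are created up front and then awaited one at a time, so the queries still run sequentially on the single session. Below is a hedged sketch of a variant that overlaps them with asyncio.gather. It is an illustration only, not what this commit does: load_all is a hypothetical helper, each query gets its own session because a single AsyncSession must not run concurrent statements, and it assumes SQLAlchemy 2.0's async_sessionmaker together with the get_engine_async added in etl/db.py below.

import asyncio

from sqlalchemy.ext.asyncio import async_sessionmaker

from apps.backport.datasync import data_metadata as dm
from etl.db import get_engine_async


async def load_all(variable_id: int):
    engine = get_engine_async()
    new_session = async_sessionmaker(engine, expire_on_commit=False)

    async def run(loader):
        # one session per query; sharing a single AsyncSession across
        # concurrently running statements is not supported
        async with new_session() as session:
            return await loader(session, variable_id)

    # results come back in the same order as the awaitables
    return await asyncio.gather(
        run(dm._load_variable),
        run(dm._load_origins_df),
        run(dm._load_topic_tags),
        run(dm._load_faqs),
    )
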
24 changes: 24 additions & 0 deletions apps/backport/datasync/datasync.py
@@ -1,3 +1,4 @@
import asyncio
import gzip
import json
from typing import Any, Dict
@@ -17,6 +18,9 @@
config.enable_bugsnag()


R2_UPLOAD_SEMAPHORE = asyncio.Semaphore(10)


def upload_gzip_dict(d: Dict[str, Any], s3_path: str, private: bool = False) -> None:
return upload_gzip_string(json.dumps(d, default=str), s3_path=s3_path, private=private)

@@ -46,3 +50,23 @@ def upload_gzip_string(s: str, s3_path: str, private: bool = False) -> None:
ContentType="application/json",
**extra_args,
)


async def upload_gzip_string_async(client: Any, s: str, s3_path: str, private: bool = False) -> None:
"""Upload compressed dictionary to S3 and return its URL."""
body_gzip = gzip.compress(s.encode())

bucket, key = s3_utils.s3_bucket_key(s3_path)

assert not private, "r2 does not support private files yet"
extra_args = {}

async with R2_UPLOAD_SEMAPHORE:
await client.put_object(
Bucket=bucket,
Body=body_gzip,
Key=key,
ContentEncoding="gzip",
ContentType="application/json",
**extra_args,
)
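
upload_gzip_string_async takes an already-created async S3/R2 client (typed Any above), and R2_UPLOAD_SEMAPHORE caps the number of in-flight PUTs at 10 no matter how many coroutines are scheduled. A hedged usage sketch follows; the aioboto3 client and the R2 endpoint URL are assumptions, since client construction is not part of this hunk.

import asyncio
from typing import Dict

import aioboto3

from apps.backport.datasync.datasync import upload_gzip_string_async


async def upload_many(payloads: Dict[str, str]) -> None:
    # hypothetical helper: payloads maps s3://bucket/key paths to JSON strings
    session = aioboto3.Session()
    async with session.client("s3", endpoint_url="https://<account-id>.r2.cloudflarestorage.com") as client:
        # schedule everything at once; the semaphore inside
        # upload_gzip_string_async keeps at most 10 uploads running
        await asyncio.gather(
            *(upload_gzip_string_async(client, body, s3_path=path) for path, body in payloads.items())
        )
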
22 changes: 12 additions & 10 deletions apps/chart_sync/admin_api.py
@@ -5,6 +5,7 @@
import string
from typing import Any, Dict, List, Optional

import aiohttp
import requests
import structlog
from sqlalchemy import text
@@ -81,16 +82,17 @@ def set_tags(self, chart_id: int, tags: List[Dict[str, Any]]) -> dict:
assert js["success"]
return js

def put_grapher_config(self, variable_id: int, grapher_config: Dict[str, Any]) -> dict:
resp = requests.put(
self.base_url + f"/admin/api/variables/{variable_id}/grapherConfigETL",
cookies={"sessionid": self.session_id},
json=grapher_config,
)
js = self._json_from_response(resp)
assert js["success"]
return js

async def put_grapher_config(self, variable_id: int, grapher_config: Dict[str, Any]) -> dict:
async with aiohttp.ClientSession(cookies={"sessionid": self.session_id}) as session:
async with session.put(
self.base_url + f"/admin/api/variables/{variable_id}/grapherConfigETL", json=grapher_config
) as resp:
# TODO: make _json_from_response async
js = await resp.json()
assert js["success"]
return js

# TODO: make it async
def delete_grapher_config(self, variable_id: int) -> dict:
resp = requests.delete(
self.base_url + f"/admin/api/variables/{variable_id}/grapherConfigETL",
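
Since put_grapher_config is now a coroutine, callers must await it, which also makes it easy to push several configs concurrently; each call opens its own aiohttp.ClientSession, so nothing is shared between calls. A hedged sketch follows; push_configs is a hypothetical helper and the admin client instance is typed loosely because its class definition is not shown in this hunk.

import asyncio
from typing import Any, Dict


async def push_configs(admin_api: Any, configs: Dict[int, Dict[str, Any]]) -> None:
    # configs maps variable_id -> grapherConfigETL payload (hypothetical)
    results = await asyncio.gather(
        *(admin_api.put_grapher_config(variable_id, cfg) for variable_id, cfg in configs.items())
    )
    # each call already asserts js["success"]; results are the parsed JSON responses
    assert all(r["success"] for r in results)
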
10 changes: 1 addition & 9 deletions etl/command.py
@@ -185,14 +185,9 @@ def main_cli(

# make everything single threaded, useful for debugging
if not use_threads:
config.GRAPHER_INSERT_WORKERS = 1
config.DIRTY_STEPS_WORKERS = 1
workers = 1

# GRAPHER_INSERT_WORKERS should be split among workers
if workers > 1:
config.GRAPHER_INSERT_WORKERS = config.GRAPHER_INSERT_WORKERS // workers

kwargs = dict(
steps=steps,
dry_run=dry_run,
@@ -217,7 +212,6 @@ def main_cli(
for _ in runs:
if ipdb:
config.IPDB_ENABLED = True
config.GRAPHER_INSERT_WORKERS = 1
config.DIRTY_STEPS_WORKERS = 1
kwargs["workers"] = 1
with launch_ipdb_on_exception():
@@ -378,9 +372,7 @@ def run_dag(
)
return exec_steps(steps, strict=strict)
else:
print(
f"--- Running {len(steps)} steps with {workers} processes ({config.GRAPHER_INSERT_WORKERS} threads each):"
)
print(f"--- Running {len(steps)} steps with {workers} processes:")
return exec_steps_parallel(steps, workers, dag=dag, strict=strict)


4 changes: 0 additions & 4 deletions etl/config.py
@@ -161,10 +161,6 @@ def variable_metadata_url(variable_id):
# because we're making a lot of HTTP requests
DIRTY_STEPS_WORKERS = int(env.get("DIRTY_STEPS_WORKERS", 5))

# number of workers for grapher inserts to DB, this is for all processes, so if
# --workers is higher than 1, this will be divided among them
GRAPHER_INSERT_WORKERS = int(env.get("GRAPHER_WORKERS", 40))

# only upsert indicators matching this filter, this is useful for fast development
# of data pages for a single indicator
GRAPHER_FILTER = env.get("GRAPHER_FILTER", None)
11 changes: 11 additions & 0 deletions etl/db.py
@@ -10,6 +10,7 @@
import validators
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlalchemy.orm import Session

from etl import config
@@ -60,6 +61,16 @@ def get_engine(conf: Optional[Dict[str, Any]] = None) -> Engine:
return _get_engine_cached(cf, pid)


def get_engine_async(conf: Optional[Dict[str, Any]] = None) -> AsyncEngine:
cf: Any = dict_to_object(conf) if conf else config
engine = create_async_engine(
f"mysql+aiomysql://{cf.DB_USER}:{quote(cf.DB_PASS)}@{cf.DB_HOST}:{cf.DB_PORT}/{cf.DB_NAME}",
pool_size=30, # Increase pool size
max_overflow=50, # Increase overflow limit
)
return engine


def get_dataset_id(
dataset_name: str, db_conn: Optional[pymysql.Connection] = None, version: Optional[str] = None
) -> Any:
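
get_engine_async builds an aiomysql-backed AsyncEngine with a larger connection pool so that many concurrent upserts can each hold a connection. A minimal sketch of using it with an AsyncSession and one of the async loaders above (the variable id is made up):

import asyncio

from sqlalchemy.ext.asyncio import AsyncSession

from apps.backport.datasync.data_metadata import _load_variable
from etl.db import get_engine_async


async def main() -> None:
    engine = get_engine_async()
    async with AsyncSession(engine) as session:
        row = await _load_variable(session, 123456)  # hypothetical variable id
        print(row["id"], row["name"])
    await engine.dispose()  # return pooled aiomysql connections


asyncio.run(main())
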
18 changes: 0 additions & 18 deletions etl/files.py
@@ -53,24 +53,6 @@ def clear(self) -> None:

CACHE_CHECKSUM_FILE = RuntimeCache()

TEXT_CHARS = bytes(range(32, 127)) + b"\n\r\t\f\b"
DEFAULT_CHUNK_SIZE = 512


def istextblock(block: bytes) -> bool:
if not block:
# An empty file is considered a valid text file
return True

if b"\x00" in block:
# Files with null bytes are binary
return False

# Use translate's 'deletechars' argument to efficiently remove all
# occurrences of TEXT_CHARS from the block
nontext = block.translate(None, TEXT_CHARS)
return float(len(nontext)) / len(block) <= 0.30


def checksum_str(s: str) -> str:
"Return the md5 hex digest of the string."
30 changes: 14 additions & 16 deletions etl/grapher_helpers.py
@@ -18,6 +18,7 @@
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session

from apps.backport.datasync import data_metadata as dm
from etl.db import get_engine, read_sql
from etl.files import checksum_str

@@ -92,8 +93,8 @@ def _yield_wide_table(
# Validation
if "year" not in table.primary_key:
raise Exception("Table is missing `year` primary key")
if "entity_id" not in table.primary_key:
raise Exception("Table is missing `entity_id` primary key")
if "entityId" not in table.primary_key:
raise Exception("Table is missing `entityId` primary key")
if na_action == "raise":
for col in table.columns:
if table[col].isna().any():
Expand All @@ -102,7 +103,7 @@ def _yield_wide_table(
if cols_with_none_units:
raise Exception("Columns with missing units: " + ", ".join(cols_with_none_units))

dim_names = [k for k in table.primary_key if k not in ("year", "entity_id")]
dim_names = [k for k in table.primary_key if k not in ("year", "entityId", "entityCode", "entityName")]

# Keep only entity_id and year in index
table = table.reset_index(level=dim_names)
@@ -188,7 +189,6 @@ def _yield_wide_table(
# traverse metadata and expand Jinja
tab[short_name].metadata = _expand_jinja(tab[short_name].metadata, dim_dict)

# Keep only entity_id and year in index
yield tab


@@ -504,20 +504,14 @@ def _adapt_dataset_metadata_for_grapher(
return metadata


def _adapt_table_for_grapher(
table: catalog.Table, engine: Engine | None = None, country_col: str = "country", year_col: str = "year"
) -> catalog.Table:
def _adapt_table_for_grapher(table: catalog.Table, engine: Engine) -> catalog.Table:
"""Adapt table (from a garden dataset) to be used in a grapher step. This function
is not meant to be run explicitly, but by default in the grapher step.
Parameters
----------
table : catalog.Table
Table from garden dataset.
country_col : str
Name of country column in table.
year_col : str
Name of year column in table.
Returns
-------
Expand All @@ -534,7 +528,7 @@ def _adapt_table_for_grapher(
), f"Variable titles are not unique ({variable_titles_counts[variable_titles_counts > 1].index})."

# Remember original dimensions
dim_names = [n for n in table.index.names if n and n not in ("year", "date", "entity_id", country_col)]
dim_names = [n for n in table.index.names if n and n not in ("year", "date", "entity_id", "country")]

# Reset index unless we have default index
if table.index.names != [None]:
@@ -546,14 +540,18 @@ def _adapt_table_for_grapher(table: catalog.Table, engine: Engine) -> catalog.Table:
assert "year" not in table.columns, "Table cannot have both `date` and `year` columns."
table = adapt_table_with_dates_to_grapher(table)

assert {"year", country_col} <= set(table.columns), f"Table must have columns {country_col} and year."
assert {"year", "country"} <= set(table.columns), "Table must have columns country and year."
assert "entity_id" not in table.columns, "Table must not have column entity_id."

# Grapher needs a column entity id, that is constructed based on the unique entity names in the database.
table["entity_id"] = country_to_entity_id(table[country_col], create_entities=True, engine=engine)
table = table.drop(columns=[country_col]).rename(columns={year_col: "year"})
table["entityId"] = country_to_entity_id(table["country"], create_entities=True, engine=engine)
table = table.drop(columns=["country"])

# Add entity code and name
with Session(engine) as session:
table = dm.add_entity_code_and_name(session, table).copy_metadata(table)

table = table.set_index(["entity_id", "year"] + dim_names)
table = table.set_index(["entityId", "entityCode", "entityName", "year"] + dim_names)

# Ensure the default source of each column includes the description of the table (since that is the description that
# will appear in grapher on the SOURCES tab).
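
After this change the grapher adaptation keys tables on entityId, entityCode and entityName rather than a bare entity_id, with dm.add_entity_code_and_name filling in code and name from the database and _yield_wide_table's validation updated to match. A hedged illustration of the resulting index shape; the values below are made up, and in the real pipeline they come from the entities table via country_to_entity_id and add_entity_code_and_name.

import pandas as pd

# illustrative only: real ids, codes and names come from the grapher DB
table = pd.DataFrame(
    {
        "entityId": [13, 13],
        "entityCode": ["USA", "USA"],
        "entityName": ["United States", "United States"],
        "year": [2020, 2021],
        "life_expectancy": [77.0, 76.4],
    }
).set_index(["entityId", "entityCode", "entityName", "year"])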