Skip to content

Commit

Permalink
✨ use async when upserting to grapher
Browse files Browse the repository at this point in the history
  • Loading branch information
Marigold committed Dec 19, 2024
1 parent df10676 commit d01f9d6
Show file tree
Hide file tree
Showing 17 changed files with 10,159 additions and 292 deletions.
42 changes: 27 additions & 15 deletions apps/backport/datasync/data_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def variable_data(data_df: pd.DataFrame) -> Dict[str, Any]:
return data # type: ignore


def _load_variable(session: Session, variable_id: int) -> Dict[str, Any]:
async def _load_variable(session: AsyncSession, variable_id: int) -> Dict[str, Any]:
sql = """
SELECT
variables.*,
Expand All @@ -41,14 +41,14 @@ def _load_variable(session: Session, variable_id: int) -> Dict[str, Any]:
"""

# Using the session to execute raw SQL and fetching one row as a result
result = session.execute(text(sql), {"variable_id": variable_id}).fetchone()
result = (await session.execute(text(sql), {"variable_id": variable_id})).fetchone()

# Ensure result exists and convert to dictionary
assert result, f"variableId `{variable_id}` not found"
return dict(result._mapping)


def _load_topic_tags(session: Session, variable_id: int) -> List[str]:
async def _load_topic_tags(session: AsyncSession, variable_id: int) -> List[str]:
sql = """
SELECT
tags.name
Expand All @@ -59,13 +59,13 @@ def _load_topic_tags(session: Session, variable_id: int) -> List[str]:
"""

# Using the session to execute raw SQL
result = session.execute(text(sql), {"variable_id": variable_id}).fetchall()
result = (await session.execute(text(sql), {"variable_id": variable_id})).fetchall()

# Extract tag names from the result and return as a list
return [row[0] for row in result]


def _load_faqs(session: Session, variable_id: int) -> List[Dict[str, Any]]:
async def _load_faqs(session: AsyncSession, variable_id: int) -> List[Dict[str, Any]]:
sql = """
SELECT
gdocId,
Expand All @@ -76,13 +76,13 @@ def _load_faqs(session: Session, variable_id: int) -> List[Dict[str, Any]]:
"""

# Using the session to execute raw SQL
result = session.execute(text(sql), {"variable_id": variable_id}).fetchall()
result = (await session.execute(text(sql), {"variable_id": variable_id})).fetchall()

# Convert the result rows to a list of dictionaries
return [dict(row._mapping) for row in result]


def _load_origins_df(session: Session, variable_id: int) -> pd.DataFrame:
async def _load_origins_df(session: AsyncSession, variable_id: int) -> pd.DataFrame:
sql = """
SELECT
origins.*
Expand All @@ -93,7 +93,7 @@ def _load_origins_df(session: Session, variable_id: int) -> pd.DataFrame:
"""

# Use the session to execute the raw SQL
result_proxy = session.execute(text(sql), {"variable_id": variable_id})
result_proxy = await session.execute(text(sql), {"variable_id": variable_id})

# Fetch the results into a DataFrame
df = pd.DataFrame(result_proxy.fetchall(), columns=result_proxy.keys())
Expand Down Expand Up @@ -212,12 +212,17 @@ def _variable_metadata(
return variableMetadata


def _move_population_origin_to_end(origins: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def _move_population_origin_to_end(
origins: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""Move population origin to the end of the list of origins. This way it gets displayed last on data page."""
new_origins = []
pop_origin = None
for origin in origins:
if origin.get("title") == "Population" and origin.get("producer") == "Various sources":
if (
origin.get("title") == "Population"
and origin.get("producer") == "Various sources"
):
pop_origin = origin
else:
new_origins.append(origin)
Expand All @@ -226,17 +231,24 @@ def _move_population_origin_to_end(origins: List[Dict[str, Any]]) -> List[Dict[s
return new_origins


def variable_metadata(session: Session, variable_id: int, variable_data: pd.DataFrame) -> Dict[str, Any]:
async def variable_metadata(
    session: AsyncSession, variable_id: int, variable_data: pd.DataFrame
) -> Dict[str, Any]:
    """Fetch metadata for a single variable from database. This function was initially based on the
    one from owid-grapher repository and uses raw SQL commands. It'd be interesting to rewrite it
    using SQLAlchemy ORM in grapher_model.py.

    NOTE(review): the four queries are awaited one after another on purpose —
    they share a single session, so they must not run concurrently. Presumably
    a single AsyncSession cannot serve overlapping queries; confirm before
    parallelizing with asyncio.gather.
    """
    db_variable_row = await _load_variable(session, variable_id)
    db_origins_df = await _load_origins_df(session, variable_id)
    db_topic_tags = await _load_topic_tags(session, variable_id)
    db_faqs = await _load_faqs(session, variable_id)

    return _variable_metadata(
        db_variable_row=db_variable_row,
        variable_data=variable_data,
        db_origins_df=db_origins_df,
        db_topic_tags=db_topic_tags,
        db_faqs=db_faqs,
    )


Expand Down
24 changes: 24 additions & 0 deletions apps/backport/datasync/datasync.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import gzip
import json
from typing import Any, Dict
Expand All @@ -17,6 +18,9 @@
config.enable_bugsnag()


# Cap the number of concurrent R2 uploads; upload_gzip_string_async acquires
# this semaphore around each put_object call.
R2_UPLOAD_SEMAPHORE = asyncio.Semaphore(10)


def upload_gzip_dict(d: Dict[str, Any], s3_path: str, private: bool = False) -> None:
    """Serialize *d* to JSON and upload it gzip-compressed to *s3_path*.

    Values that are not JSON-serializable are stringified (``default=str``).
    """
    payload = json.dumps(d, default=str)
    return upload_gzip_string(payload, s3_path=s3_path, private=private)

Expand Down Expand Up @@ -46,3 +50,23 @@ def upload_gzip_string(s: str, s3_path: str, private: bool = False) -> None:
ContentType="application/json",
**extra_args,
)


async def upload_gzip_string_async(client: Any, s: str, s3_path: str, private: bool = False) -> None:
    """Gzip-compress string *s* and upload it to *s3_path* with an async S3 client.

    Concurrency is throttled by R2_UPLOAD_SEMAPHORE so only a bounded number of
    uploads are in flight at once.

    :param client: async S3-compatible client exposing ``put_object`` (e.g. aioboto3).
    :param s: uncompressed string payload; uploaded with ContentEncoding gzip.
    :param s3_path: full s3 path, split into bucket and key via s3_utils.
    :param private: private files are not supported on R2 yet; must be False.
    :raises NotImplementedError: if ``private`` is True.
    """
    body_gzip = gzip.compress(s.encode())

    bucket, key = s3_utils.s3_bucket_key(s3_path)

    if private:
        # Raise explicitly instead of `assert`: asserts are stripped under
        # `python -O`, which would silently upload a private file as public.
        raise NotImplementedError("r2 does not support private files yet")
    # Kept for parity with the synchronous upload path; currently always empty.
    extra_args: Dict[str, Any] = {}

    async with R2_UPLOAD_SEMAPHORE:
        await client.put_object(
            Bucket=bucket,
            Body=body_gzip,
            Key=key,
            ContentEncoding="gzip",
            ContentType="application/json",
            **extra_args,
        )
60 changes: 50 additions & 10 deletions apps/chart_sync/admin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import cache
from typing import Any, Dict, List, Optional

import aiohttp
import requests
import structlog
from requests.adapters import HTTPAdapter, Retry
Expand All @@ -25,7 +26,9 @@ def is_502_error(exception):


class AdminAPI(object):
def __init__(self, owid_env: OWIDEnv, grapher_user_id: Optional[int] = GRAPHER_USER_ID):
def __init__(
self, owid_env: OWIDEnv, grapher_user_id: Optional[int] = GRAPHER_USER_ID
):
self.owid_env = owid_env
self.session_id = create_session_id(owid_env, grapher_user_id)

Expand Down Expand Up @@ -64,7 +67,9 @@ def create_chart(self, chart_config: dict, user_id: Optional[int] = None) -> dic
raise AdminAPIError({"error": js["error"], "chart_config": chart_config})
return js

def update_chart(self, chart_id: int, chart_config: dict, user_id: Optional[int] = None) -> dict:
def update_chart(
self, chart_id: int, chart_config: dict, user_id: Optional[int] = None
) -> dict:
resp = requests.put(
f"{self.owid_env.admin_api}/charts/{chart_id}",
cookies={"sessionid": self._get_session_id(user_id)},
Expand All @@ -75,7 +80,9 @@ def update_chart(self, chart_id: int, chart_config: dict, user_id: Optional[int]
raise AdminAPIError({"error": js["error"], "chart_config": chart_config})
return js

def set_tags(self, chart_id: int, tags: List[Dict[str, Any]], user_id: Optional[int] = None) -> dict:
def set_tags(
self, chart_id: int, tags: List[Dict[str, Any]], user_id: Optional[int] = None
) -> dict:
resp = requests.post(
f"{self.owid_env.admin_api}/charts/{chart_id}/setTags",
cookies={"sessionid": self._get_session_id(user_id)},
Expand All @@ -86,7 +93,9 @@ def set_tags(self, chart_id: int, tags: List[Dict[str, Any]], user_id: Optional[
raise AdminAPIError({"error": js["error"], "tags": tags})
return js

def put_grapher_config(self, variable_id: int, grapher_config: Dict[str, Any]) -> dict:
def put_grapher_config(
self, variable_id: int, grapher_config: Dict[str, Any]
) -> dict:
# If schema is missing, use the default one
grapher_config.setdefault("$schema", DEFAULT_GRAPHER_SCHEMA)

Expand All @@ -98,9 +107,31 @@ def put_grapher_config(self, variable_id: int, grapher_config: Dict[str, Any]) -
)
js = self._json_from_response(resp)
if not js["success"]:
raise AdminAPIError({"error": js["error"], "variable_id": variable_id, "grapher_config": grapher_config})
raise AdminAPIError(
{
"error": js["error"],
"variable_id": variable_id,
"grapher_config": grapher_config,
}
)
return js

async def put_grapher_config(
    self, variable_id: int, grapher_config: Dict[str, Any]
) -> dict:
    """Upsert the ETL grapher config for *variable_id* via the admin API (async).

    Mirrors the synchronous admin-API methods: applies the default ``$schema``
    when missing and raises :class:`AdminAPIError` with full context on failure.

    :raises AdminAPIError: if the API response reports ``success == False``.
    """
    # If schema is missing, use the default one (same as the sync method).
    grapher_config.setdefault("$schema", DEFAULT_GRAPHER_SCHEMA)

    async with aiohttp.ClientSession(
        cookies={"sessionid": self.session_id}
    ) as session:
        async with session.put(
            self.base_url + f"/admin/api/variables/{variable_id}/grapherConfigETL",
            json=grapher_config,
        ) as resp:
            # TODO: make _json_from_response async
            js = await resp.json()
            if not js["success"]:
                # Raise like the other admin-API methods instead of a bare
                # assert, so callers get the error and config that failed.
                raise AdminAPIError(
                    {
                        "error": js["error"],
                        "variable_id": variable_id,
                        "grapher_config": grapher_config,
                    }
                )
            return js

# TODO: make it async
def delete_grapher_config(self, variable_id: int) -> dict:
resp = requests.delete(
self.owid_env.admin_api + f"/variables/{variable_id}/grapherConfigETL",
Expand All @@ -111,7 +142,9 @@ def delete_grapher_config(self, variable_id: int) -> dict:
raise AdminAPIError({"error": js["error"], "variable_id": variable_id})
return js

def put_mdim_config(self, slug: str, mdim_config: dict, user_id: Optional[int] = None) -> dict:
def put_mdim_config(
self, slug: str, mdim_config: dict, user_id: Optional[int] = None
) -> dict:
# Retry in case we're restarting Admin on staging server
resp = requests_with_retry().put(
self.owid_env.admin_api + f"/multi-dim/{slug}",
Expand All @@ -120,7 +153,9 @@ def put_mdim_config(self, slug: str, mdim_config: dict, user_id: Optional[int] =
)
js = self._json_from_response(resp)
if not js["success"]:
raise AdminAPIError({"error": js["error"], "slug": slug, "mdim_config": mdim_config})
raise AdminAPIError(
{"error": js["error"], "slug": slug, "mdim_config": mdim_config}
)
return js


Expand Down Expand Up @@ -153,15 +188,19 @@ def _generate_random_string(length=32) -> str:
return result_str


def _create_user_session(session: Session, user_email: str, expiration_seconds=3600) -> str:
def _create_user_session(
session: Session, user_email: str, expiration_seconds=3600
) -> str:
"""Create a new short-lived session for given user and return its session id."""
# Generate a random string
session_key = _generate_random_string()

json_str = json.dumps({"user_email": user_email})

# Base64 encode
session_data = base64.b64encode(("prefix:" + json_str).encode("utf-8")).decode("utf-8")
session_data = base64.b64encode(("prefix:" + json_str).encode("utf-8")).decode(
"utf-8"
)

query = text(
"""
Expand All @@ -174,7 +213,8 @@ def _create_user_session(session: Session, user_email: str, expiration_seconds=3
params={
"session_key": session_key,
"session_data": session_data,
"expire_date": dt.datetime.utcnow() + dt.timedelta(seconds=expiration_seconds),
"expire_date": dt.datetime.utcnow()
+ dt.timedelta(seconds=expiration_seconds),
},
)

Expand Down
10 changes: 1 addition & 9 deletions etl/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,9 @@ def main_cli(

# make everything single threaded, useful for debugging
if not use_threads:
config.GRAPHER_INSERT_WORKERS = 1
config.DIRTY_STEPS_WORKERS = 1
workers = 1

# GRAPHER_INSERT_WORKERS should be split among workers
if workers > 1:
config.GRAPHER_INSERT_WORKERS = config.GRAPHER_INSERT_WORKERS // workers

kwargs = dict(
steps=steps,
dry_run=dry_run,
Expand All @@ -223,7 +218,6 @@ def main_cli(
for _ in runs:
if ipdb:
config.IPDB_ENABLED = True
config.GRAPHER_INSERT_WORKERS = 1
config.DIRTY_STEPS_WORKERS = 1
kwargs["workers"] = 1
with launch_ipdb_on_exception():
Expand Down Expand Up @@ -384,9 +378,7 @@ def run_dag(
)
return exec_steps(steps, strict=strict)
else:
print(
f"--- Running {len(steps)} steps with {workers} processes ({config.GRAPHER_INSERT_WORKERS} threads each):"
)
print(f"--- Running {len(steps)} steps with {workers} processes:")
return exec_steps_parallel(steps, workers, dag=dag, strict=strict)


Expand Down
4 changes: 0 additions & 4 deletions etl/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,6 @@ def variable_metadata_url(variable_id):
# because we're making a lot of HTTP requests
DIRTY_STEPS_WORKERS = int(env.get("DIRTY_STEPS_WORKERS", 5))

# number of workers for grapher inserts to DB, this is for all processes, so if
# --workers is higher than 1, this will be divided among them
GRAPHER_INSERT_WORKERS = int(env.get("GRAPHER_WORKERS", 40))

# only upsert indicators matching this filter, this is useful for fast development
# of data pages for a single indicator
GRAPHER_FILTER = env.get("GRAPHER_FILTER", None)
Expand Down
Loading

0 comments on commit d01f9d6

Please sign in to comment.