From 58db7928cbbab3530aac9bf67d33f93b9882b5ef Mon Sep 17 00:00:00 2001
From: Julius Busecke
Date: Fri, 31 May 2024 12:42:09 -0400
Subject: [PATCH] Move bq code to separate module (#42)

* Move bq code to separate module

* Update pyproject.toml

* Update pyproject.toml

* Update pyproject.toml

* Add tests and pin pangeo-forge-esgf>0.3.0

---
 .gitignore                                    |   1 +
 leap_data_management_utils/bq_interfaces.py   | 252 ++++++++++++++++++
 leap_data_management_utils/cmip_transforms.py | 162 +----------
 .../data_management_transforms.py             |  93 +------
 .../tests/test_cmip_catalog.py                |  45 ++++
 pyproject.toml                                |  10 +-
 6 files changed, 307 insertions(+), 256 deletions(-)
 create mode 100644 leap_data_management_utils/bq_interfaces.py
 create mode 100644 leap_data_management_utils/tests/test_cmip_catalog.py

diff --git a/.gitignore b/.gitignore
index 8b276f9..df3bd95 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,3 +160,4 @@ cython_debug/
 #.idea/
 
 _version.py
+.vscode/settings.json
diff --git a/leap_data_management_utils/bq_interfaces.py b/leap_data_management_utils/bq_interfaces.py
new file mode 100644
index 0000000..32921c0
--- /dev/null
+++ b/leap_data_management_utils/bq_interfaces.py
@@ -0,0 +1,252 @@
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Optional
+
+import pandas as pd
+from google.api_core.exceptions import NotFound
+from google.cloud import bigquery
+from tqdm.auto import tqdm
+
+
+@dataclass
+class BQInterface:
+    """Class to read/write information from BigQuery table
+    :param table_id: BigQuery table ID
+    :param client: BigQuery client object
+    :param result_limit: Maximum number of results to return from query
+    :param schema: List of bigquery.SchemaField objects defining the table schema
+    """
+
+    table_id: str
+    client: Optional[bigquery.client.Client] = None
+    result_limit: Optional[int] = 10
+    schema: Optional[list] = None
+
+    def __post_init__(self):
+        # TODO how do I handle the schema? This class could be used for any table, but for
+        # TODO this specific case I want to prescribe the schema
+        # for now just hardcode it
+        if not self.schema:
+            self.schema = [
+                bigquery.SchemaField('dataset_id', 'STRING', mode='REQUIRED'),
+                bigquery.SchemaField('dataset_url', 'STRING', mode='REQUIRED'),
+                bigquery.SchemaField('timestamp', 'TIMESTAMP', mode='REQUIRED'),
+            ]
+        if self.client is None:
+            self.client = bigquery.Client()
+
+        # check if table exists, otherwise create it
+        try:
+            self._get_table()
+        except NotFound:
+            self.create_table()
+
+    def create_table(self) -> bigquery.table.Table:
+        """Create the table if it does not exist"""
+        print(f'Creating {self.table_id =}')
+        table = bigquery.Table(self.table_id, schema=self.schema)
+        return self.client.create_table(table)  # Make an API request.
+
+    def _get_table(self) -> bigquery.table.Table:
+        """Get the table object"""
+        return self.client.get_table(self.table_id)
+
+    def insert(self, fields: dict = {}):
+        """Insert a single row, stamped with the current time"""
+        timestamp = datetime.now().isoformat()
+
+        rows_to_insert = [
+            fields | {'timestamp': timestamp}  # timestamp is always overridden
+        ]
+
+        errors = self.client.insert_rows_json(self._get_table(), rows_to_insert)
+        if errors:
+            raise RuntimeError(f'Error inserting row: {errors}')
+
+    def catalog_insert(self, dataset_id: str, dataset_url: str, extra_fields: dict = {}):
+        # NOTE: .insert() expects a single dict of fields (it builds the row list
+        # itself), so merge everything into one dict rather than passing a list
+        fields = {
+            'dataset_id': dataset_id,
+            'dataset_url': dataset_url,
+        } | extra_fields
+        self.insert(fields)
+
+    def _get_query_job(self, query: str) -> bigquery.job.query.QueryJob:
+        return self.client.query(query)
+
+    def get_all(self) -> pd.DataFrame:
+        """Get all rows in the table"""
+        query = f"""
+        SELECT * FROM {self.table_id};
+        """
+        results = self._get_query_job(query)
+        return results.to_dataframe()
+
+    def get_latest(self) -> pd.DataFrame:
+        """Get the latest row for all iids in the table"""
+        # adopted from https://stackoverflow.com/a/1313293
+        query = f"""
+        WITH ranked_iids AS (
+            SELECT i.*, ROW_NUMBER() OVER (PARTITION BY instance_id ORDER BY timestamp DESC) AS rn
+            FROM {self.table_id} AS i
+        )
+        SELECT * FROM ranked_iids WHERE rn = 1;
+        """
+        results = self._get_query_job(query)
+        return results.to_dataframe().drop(columns=['rn'])
+
+
+@dataclass
+class IIDEntry:
+    """Single row/entry for an iid
+    :param iid: CMIP6 instance id
+    :param store: URL to zarr store
+    :param retracted: Whether the dataset has been retracted upstream
+    :param tests_passed: Whether the dataset passed the quality-control tests
+    """
+
+    iid: str
+    store: str  # TODO: should this allow other objects?
+    retracted: bool
+    tests_passed: bool
+
+    # Check if the iid conforms to a schema
+    def __post_init__(self):
+        schema = 'mip_era.activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.version'
+        facets = self.iid.split('.')
+        if len(facets) != len(schema.split('.')):
+            raise ValueError(f'IID does not conform to CMIP6 {schema =}. Got {self.iid =}')
+        assert self.store.startswith('gs://')
+        assert self.retracted in [True, False]
+        assert self.tests_passed in [True, False]
+
+        # TODO: Check each facet with the controlled CMIP vocabulary
+
+        # TODO Check store validity?
+
+
+@dataclass
+class IIDResult:
+    """Class to handle the results pertaining to a single IID."""
+
+    results: bigquery.table.RowIterator
+    iid: str
+
+    def __post_init__(self):
+        if self.results.total_rows > 0:
+            self.exists = True
+            self.rows = [r for r in self.results]
+            # rows arrive ordered by timestamp (newest first), see _get_iid_results
+            self.latest_row = self.rows[0]
+        else:
+            self.exists = False
+
+
+class CMIPBQInterface(BQInterface):
+    """Class to read/write information from BigQuery table
+    :param table_id: BigQuery table ID
+    :param client: BigQuery client object
+    :param result_limit: Maximum number of results to return from query
+    """
+
+    def __post_init__(self):
+        # TODO how do I handle the schema? This class could be used for any table, but for
+        # TODO this specific case I want to prescribe the schema
+        # for now just hardcode it
+        if not self.schema:
+            self.schema = [
+                bigquery.SchemaField('instance_id', 'STRING', mode='REQUIRED'),
+                bigquery.SchemaField('store', 'STRING', mode='REQUIRED'),
+                bigquery.SchemaField('timestamp', 'TIMESTAMP', mode='REQUIRED'),
+                bigquery.SchemaField('retracted', 'BOOL', mode='REQUIRED'),
+                bigquery.SchemaField('tests_passed', 'BOOL', mode='REQUIRED'),
+            ]
+        super().__post_init__()
+
+    def _get_timestamp(self) -> str:
+        """Get the current timestamp"""
+        # the module imports `from datetime import datetime`, so call utcnow() directly
+        return datetime.utcnow().isoformat()
+
+    def insert_iid(self, IID_entry: IIDEntry):
+        """Insert a row into the table for a given IID_entry object"""
+        fields = {
+            'instance_id': IID_entry.iid,
+            'store': IID_entry.store,
+            'retracted': IID_entry.retracted,
+            'tests_passed': IID_entry.tests_passed,
+            'timestamp': self._get_timestamp(),  # NOTE: overridden by .insert(), which stamps datetime.now()
+        }
+        self.insert(fields)
+
+    def insert_multiple_iids(self, IID_entries: list[IIDEntry]):
+        """Insert multiple rows into the table for a given list of IID_entry objects"""
+        # FIXME This repeats a bunch of code from the parent class .insert() method
+        timestamp = self._get_timestamp()
+        rows_to_insert = [
+            {
+                'instance_id': IID_entry.iid,
+                'store': IID_entry.store,
+                'retracted': IID_entry.retracted,
+                'tests_passed': IID_entry.tests_passed,
+                'timestamp': timestamp,
+            }
+            for IID_entry in IID_entries
+        ]
+        errors = self.client.insert_rows_json(self._get_table(), rows_to_insert)
+        if errors:
+            raise RuntimeError(f'Error inserting rows: {errors}')
+
+    def _get_iid_results(self, iid: str) -> IIDResult:
+        """Get the full result object for a given iid"""
+        # keep this in case I ever need the row index again...
+        # query = f"""
+        # WITH table_with_index AS (SELECT *, ROW_NUMBER() OVER ()-1 as row_index FROM `{self.table_id}`)
+        # SELECT *
+        # FROM `table_with_index`
+        # WHERE instance_id='{iid}'
+        # """
+        query = f"""
+        SELECT *
+        FROM `{self.table_id}`
+        WHERE instance_id='{iid}'
+        ORDER BY timestamp DESC
+        LIMIT {self.result_limit}
+        """
+        results = self._get_query_job(
+            query
+        ).result()  # TODO: `.result()` is waiting for the query. Should I do this here?
+        return IIDResult(results, iid)
+
+    def iid_exists(self, iid: str) -> bool:
+        """Check if iid exists in the table"""
+        return self._get_iid_results(iid).exists
+
+    def _iid_list_exists_batch(self, iids: list[str]) -> list[str]:
+        """More efficient way to check if a list of iids exists in the table.
+        Passes the entire list to a single SQL query.
+        Returns a list of iids that exist in the table.
+        """
+        if len(iids) > 10000:
+            raise ValueError('List of iids is too long. Please work in batches.')
+
+        # source: https://stackoverflow.com/questions/26441928/how-do-i-check-if-multiple-values-exists-in-database
+        query = f"""
+        SELECT instance_id, store
+        FROM {self.table_id}
+        WHERE instance_id IN ({",".join([f"'{iid}'" for iid in iids])})
+        """
+        results = self._get_query_job(query).result()
+        # this is a full row iterator, for now just return the iids
+        return list(set([r['instance_id'] for r in results]))
+
+    def iid_list_exists(self, iids: list[str]) -> list[str]:
+        """More efficient way to check if a list of iids exists in the table.
+        Passes the entire list in batches into SQL queries for maximum efficiency.
+ Returns a list of iids that exist in the table + """ + + # make batches of the input, since bq cannot handle more than 10k elements here + iids_in_bq = [] + batchsize = 10000 + iid_batches = [iids[i : i + batchsize] for i in range(0, len(iids), batchsize)] + for iids_batch in tqdm(iid_batches): + iids_in_bq_batch = self._iid_list_exists_batch(iids_batch) + iids_in_bq.extend(iids_in_bq_batch) + return iids_in_bq diff --git a/leap_data_management_utils/cmip_transforms.py b/leap_data_management_utils/cmip_transforms.py index 7b8f7c9..4dc5d49 100644 --- a/leap_data_management_utils/cmip_transforms.py +++ b/leap_data_management_utils/cmip_transforms.py @@ -2,7 +2,6 @@ utils that are specific to CMIP data management """ -import datetime import logging import warnings from dataclasses import dataclass @@ -16,12 +15,10 @@ even_divisor_algo, iterative_ratio_increase_algo, ) -from google.cloud import bigquery from pangeo_forge_recipes.transforms import Indexed, T -from tqdm.auto import tqdm +from leap_data_management_utils.bq_interfaces import CMIPBQInterface, IIDEntry from leap_data_management_utils.cmip_testing import test_all -from leap_data_management_utils.data_management_transforms import BQInterface # TODO: I am not sure the chunking function belongs here, but it clutters the recipe and I did not want # To open a whole file for this. @@ -93,163 +90,6 @@ def dynamic_chunking_func(ds: xr.Dataset) -> dict[str, int]: return target_chunks -@dataclass -class IIDEntry: - """Single row/entry for an iid - :param iid: CMIP6 instance id - :param store: URL to zarr store - """ - - iid: str - store: str # TODO: should this allow other objects? - retracted: bool - tests_passed: bool - - # Check if the iid conforms to a schema - def __post_init__(self): - schema = 'mip_era.activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.version' - facets = self.iid.split('.') - if len(facets) != len(schema.split('.')): - raise ValueError(f'IID does not conform to CMIP6 {schema =}. Got {self.iid =}') - assert self.store.startswith('gs://') - assert self.retracted in [True, False] - assert self.tests_passed in [True, False] - - # TODO: Check each facet with the controlled CMIP vocabulary - - # TODO Check store validity? - - -@dataclass -class IIDResult: - """Class to handle the results pertaining to a single IID.""" - - results: bigquery.table.RowIterator - iid: str - - def __post_init__(self): - if self.results.total_rows > 0: - self.exists = True - self.rows = [r for r in self.results] - self.latest_row = self.rows[0] - else: - self.exists = False - - -class CMIPBQInterface(BQInterface): - """Class to read/write information from BigQuery table - :param table_id: BigQuery table ID - :param client: BigQuery client object - :param result_limit: Maximum number of results to return from query - """ - - def __post_init__(self): - # TODO how do I handle the schema? 
This class could be used for any table, but for - # TODO this specific case I want to prescribe the schema - # for now just hardcode it - if not self.schema: - self.schema = [ - bigquery.SchemaField('instance_id', 'STRING', mode='REQUIRED'), - bigquery.SchemaField('store', 'STRING', mode='REQUIRED'), - bigquery.SchemaField('timestamp', 'TIMESTAMP', mode='REQUIRED'), - bigquery.SchemaField('retracted', 'BOOL', mode='REQUIRED'), - bigquery.SchemaField('tests_passed', 'BOOL', mode='REQUIRED'), - ] - super().__post_init__() - - def _get_timestamp(self) -> str: - """Get the current timestamp""" - return datetime.datetime.utcnow().isoformat() - - def insert_iid(self, IID_entry): - """Insert a row into the table for a given IID_entry object""" - fields = { - 'instance_id': IID_entry.iid, - 'store': IID_entry.store, - 'retracted': IID_entry.retracted, - 'tests_passed': IID_entry.tests_passed, - 'timestamp': self._get_timestamp(), - } - self.insert(fields) - - def insert_multiple_iids(self, IID_entries: list[IIDEntry]): - """Insert multiple rows into the table for a given list of IID_entry objects""" - # FIXME This repeats a bunch of code from the parent class .insert() method - timestamp = self._get_timestamp() - rows_to_insert = [ - { - 'instance_id': IID_entry.iid, - 'store': IID_entry.store, - 'retracted': IID_entry.retracted, - 'tests_passed': IID_entry.tests_passed, - 'timestamp': timestamp, - } - for IID_entry in IID_entries - ] - errors = self.client.insert_rows_json(self._get_table(), rows_to_insert) - if errors: - raise RuntimeError(f'Error inserting row: {errors}') - - def _get_iid_results(self, iid: str) -> IIDResult: - # keep this in case I ever need the row index again... - # query = f""" - # WITH table_with_index AS (SELECT *, ROW_NUMBER() OVER ()-1 as row_index FROM `{self.table_id}`) - # SELECT * - # FROM `table_with_index` - # WHERE instance_id='{iid}' - # """ - """Get the full result object for a given iid""" - query = f""" - SELECT * - FROM `{self.table_id}` - WHERE instance_id='{iid}' - ORDER BY timestamp DESC - LIMIT {self.result_limit} - """ - results = self._get_query_job( - query - ).result() # TODO: `.result()` is waiting for the query. Should I do this here? - return IIDResult(results, iid) - - def iid_exists(self, iid: str) -> bool: - """Check if iid exists in the table""" - return self._get_iid_results(iid).exists - - def _iid_list_exists_batch(self, iids: list[str]) -> list[str]: - """More efficient way to check if a list of iids exists in the table - Passes the entire list to a single SQL query. - Returns a list of iids that exist in the table - ``` - """ - if len(iids) > 10000: - raise ValueError('List of iids is too long. Please work in batches.') - - # source: https://stackoverflow.com/questions/26441928/how-do-i-check-if-multiple-values-exists-in-database - query = f""" - SELECT instance_id, store - FROM {self.table_id} - WHERE instance_id IN ({",".join([f"'{iid}'" for iid in iids])}) - """ - results = self._get_query_job(query).result() - # this is a full row iterator, for now just return the iids - return list(set([r['instance_id'] for r in results])) - - def iid_list_exists(self, iids: list[str]) -> list[str]: - """More efficient way to check if a list of iids exists in the table - Passes the entire list in batches into SQL querys for maximum efficiency. 
- Returns a list of iids that exist in the table - """ - - # make batches of the input, since bq cannot handle more than 10k elements here - iids_in_bq = [] - batchsize = 10000 - iid_batches = [iids[i : i + batchsize] for i in range(0, len(iids), batchsize)] - for iids_batch in tqdm(iid_batches): - iids_in_bq_batch = self._iid_list_exists_batch(iids_batch) - iids_in_bq.extend(iids_in_bq_batch) - return iids_in_bq - - # ---------------------------------------------------------------------------------------------- # apache Beam stages # ---------------------------------------------------------------------------------------------- diff --git a/leap_data_management_utils/data_management_transforms.py b/leap_data_management_utils/data_management_transforms.py index b953a88..6362139 100644 --- a/leap_data_management_utils/data_management_transforms.py +++ b/leap_data_management_utils/data_management_transforms.py @@ -5,14 +5,13 @@ import subprocess from dataclasses import dataclass from datetime import datetime, timezone -from typing import Optional import apache_beam as beam import zarr -from google.api_core.exceptions import NotFound -from google.cloud import bigquery from ruamel.yaml import YAML +from leap_data_management_utils.bq_interfaces import BQInterface + yaml = YAML(typ='safe') @@ -85,94 +84,6 @@ def get_catalog_store_urls(catalog_yaml_path: str) -> dict[str, str]: return {d['id']: d['url'] for d in catalog_meta['stores']} -@dataclass -class BQInterface: - """Class to read/write information from BigQuery table - :param table_id: BigQuery table ID - :param client: BigQuery client object - :param result_limit: Maximum number of results to return from query - """ - - table_id: str - client: Optional[bigquery.client.Client] = None - result_limit: Optional[int] = 10 - schema: Optional[list] = None - - def __post_init__(self): - # TODO how do I handle the schema? This class could be used for any table, but for - # TODO this specific case I want to prescribe the schema - # for now just hardcode it - if not self.schema: - self.schema = [ - bigquery.SchemaField('dataset_id', 'STRING', mode='REQUIRED'), - bigquery.SchemaField('dataset_url', 'STRING', mode='REQUIRED'), - bigquery.SchemaField('timestamp', 'TIMESTAMP', mode='REQUIRED'), - ] - if self.client is None: - self.client = bigquery.Client() - - # check if table exists, otherwise create it - try: - self._get_table() - except NotFound: - self.create_table() - - def create_table(self) -> bigquery.table.Table: - """Create the table if it does not exist""" - print(f'Creating {self.table_id =}') - table = bigquery.Table(self.table_id, schema=self.schema) - self.client.create_table(table) # Make an API request. 
-
-    def _get_table(self) -> bigquery.table.Table:
-        """Get the table object"""
-        return self.client.get_table(self.table_id)
-
-    def insert(self, fields: dict = {}):
-        timestamp = datetime.now().isoformat()
-
-        rows_to_insert = [
-            fields | {'timestamp': timestamp}  # timestamp is always overridden
-        ]
-
-        errors = self.client.insert_rows_json(self._get_table(), rows_to_insert)
-        if errors:
-            raise RuntimeError(f'Error inserting row: {errors}')
-
-    def catalog_insert(self, dataset_id: str, dataset_url: str, extra_fields: dict = {}):
-        rows_to_insert = [
-            {
-                'dataset_id': dataset_id,
-                'dataset_url': dataset_url,
-            }
-            | extra_fields
-        ]
-        self.insert(rows_to_insert)
-
-    def _get_query_job(self, query: str) -> bigquery.job.query.QueryJob:
-        return self.client.query(query)
-
-    def get_all(self) -> list[bigquery.table.Row]:
-        """Get all rows in the table"""
-        query = f"""
-        SELECT * FROM {self.table_id};
-        """
-        results = self._get_query_job(query)
-        return results.to_dataframe()
-
-    def get_latest(self) -> list[bigquery.table.Row]:
-        """Get the latest row for all iids in the table"""
-        # adopted from https://stackoverflow.com/a/1313293
-        query = f"""
-        WITH ranked_iids AS (
-            SELECT i.*, ROW_NUMBER() OVER (PARTITION BY instance_id ORDER BY timestamp DESC) AS rn
-            FROM {self.table_id} AS i
-        )
-        SELECT * FROM ranked_iids WHERE rn = 1;
-        """
-        results = self._get_query_job(query)
-        return results.to_dataframe().drop(columns=['rn'])
-
-
 # ----------------------------------------------------------------------------------------------
 # apache Beam stages
 # ----------------------------------------------------------------------------------------------
diff --git a/leap_data_management_utils/tests/test_cmip_catalog.py b/leap_data_management_utils/tests/test_cmip_catalog.py
new file mode 100644
index 0000000..28549ab
--- /dev/null
+++ b/leap_data_management_utils/tests/test_cmip_catalog.py
@@ -0,0 +1,45 @@
+import pandas as pd
+
+from leap_data_management_utils.cmip_catalog import bq_df_to_intake_esm
+
+
+def test_bq_df_to_intake_esm():
+    bq_df = pd.DataFrame(
+        {
+            'instance_id': [
+                'CMIP6.AerChemMIP.MIROC.MIROC6.piClim-NTCF.sub-r1i1p1f1.Amon.tasmin.gn.v20190807',  # modified to test the sub_experiment split
+                'CMIP6.AerChemMIP.MIROC.MIROC6.piClim-OC.r1i1p1f1.Amon.rlut.gn.v20190807',
+            ],
+            'store': [
+                'gs://cmip6/CMIP6/AerChemMIP/MIROC/MIROC6/piClim-NTCF/r1i1p1f1/Amon/tasmin/gn/v20190807/',
+                'gs://cmip6/CMIP6/AerChemMIP/MIROC/MIROC6/piClim-OC/r1i1p1f1/Amon/rlut/gn/v20190807/',
+            ],
+            'retracted': [False, False],
+            'tests_passed': [True, True],
+        }
+    )
+    intake_df = bq_df_to_intake_esm(bq_df)
+    # print the per-column values to ease debugging when the assertion below fails
+    for c in intake_df.columns:
+        print(c)
+        print(intake_df[c].to_list())
+
+    expected_intake_df = pd.DataFrame(
+        {
+            'activity_id': ['AerChemMIP', 'AerChemMIP'],
+            'institution_id': ['MIROC', 'MIROC'],
+            'source_id': ['MIROC6', 'MIROC6'],
+            'experiment_id': ['piClim-NTCF', 'piClim-OC'],
+            'member_id': ['sub-r1i1p1f1', 'r1i1p1f1'],
+            'table_id': ['Amon', 'Amon'],
+            'variable_id': ['tasmin', 'rlut'],
+            'grid_label': ['gn', 'gn'],
+            'sub_experiment_id': ['sub', 'none'],
+            'variant_label': ['r1i1p1f1', 'r1i1p1f1'],
+            'version': ['v20190807', 'v20190807'],
+            'zstore': [
+                'gs://cmip6/CMIP6/AerChemMIP/MIROC/MIROC6/piClim-NTCF/r1i1p1f1/Amon/tasmin/gn/v20190807/',
+                'gs://cmip6/CMIP6/AerChemMIP/MIROC/MIROC6/piClim-OC/r1i1p1f1/Amon/rlut/gn/v20190807/',
+            ],
+        }
+    )
+    assert intake_df.equals(expected_intake_df)
diff --git a/pyproject.toml b/pyproject.toml
index 3046d29..420f596 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,16 +31,18 @@ dependencies = [
 
 [project.optional-dependencies]
-pangeo-forge=[
+bigquery=[
     "tqdm",
-    "db_dtypes",
     "google-api-core",
     "google-cloud-bigquery",
-    "pangeo-forge-esgf",
+    "db_dtypes",
+    "pangeo-forge-esgf>0.3.0",
+]
+pangeo-forge=[
     "pangeo-forge-recipes",
     "apache-beam",
     "dynamic-chunks",
-
+    "leap-data-management-utils[bigquery]",
 ]
 catalog = [
     "aiohttp",
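
Note on downstream usage: with this refactor the BigQuery helpers live in
`leap_data_management_utils.bq_interfaces`, and their dependencies move into the
new `bigquery` extra (pulled in transitively by the `pangeo-forge` extra, so
`pip install 'leap-data-management-utils[pangeo-forge]'` still works). A minimal
sketch of the new import path, reusing an iid/store pair from the test above;
the table ID is a hypothetical placeholder, and Google application-default
credentials are assumed:

    from leap_data_management_utils.bq_interfaces import CMIPBQInterface, IIDEntry

    # hypothetical table; BQInterface.__post_init__ creates it if it does not exist
    bq = CMIPBQInterface(table_id='my-project.my_dataset.cmip6_feedstock')

    entry = IIDEntry(
        iid='CMIP6.AerChemMIP.MIROC.MIROC6.piClim-OC.r1i1p1f1.Amon.rlut.gn.v20190807',
        store='gs://cmip6/CMIP6/AerChemMIP/MIROC/MIROC6/piClim-OC/r1i1p1f1/Amon/rlut/gn/v20190807/',
        retracted=False,
        tests_passed=True,
    )

    if not bq.iid_exists(entry.iid):
        bq.insert_iid(entry)

    # batch lookup; splits the input into <=10k-element SQL queries internally
    existing = bq.iid_list_exists([entry.iid])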
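
The generic `BQInterface` keeps the catalog-logging behavior previously housed in
`data_management_transforms`; a sketch against its default
dataset_id/dataset_url/timestamp schema, again with a placeholder table ID:

    from leap_data_management_utils.bq_interfaces import BQInterface

    catalog = BQInterface(table_id='my-project.my_dataset.catalog')  # hypothetical table
    # 'timestamp' is stamped automatically by .insert()
    catalog.catalog_insert(
        dataset_id='some-dataset-id',
        dataset_url='gs://some-bucket/some.zarr',
    )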