diff --git a/.bumpversion.cfg b/.bumpversion.cfg index ca30bd520..0535dd429 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.4.4 +current_version = 0.4.5 commit = False tag = False diff --git a/dev/local/setup.cfg b/dev/local/setup.cfg index ac1338ca9..005f1d30b 100644 --- a/dev/local/setup.cfg +++ b/dev/local/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = Delphi Development -version = 0.4.4 +version = 0.4.5 [options] packages = diff --git a/docs/api/covidcast-signals/covid-act-now.md b/docs/api/covidcast-signals/covid-act-now.md index f56ecf015..72a5a9a20 100644 --- a/docs/api/covidcast-signals/covid-act-now.md +++ b/docs/api/covidcast-signals/covid-act-now.md @@ -1,6 +1,6 @@ --- title: COVID Act Now -parent: Data Sources and Signals +parent: Inactive Signals grand_parent: COVIDcast Main Endpoint --- diff --git a/docs/api/covidcast-signals/quidel-inactive.md b/docs/api/covidcast-signals/quidel-inactive.md new file mode 100644 index 000000000..19b98ee6a --- /dev/null +++ b/docs/api/covidcast-signals/quidel-inactive.md @@ -0,0 +1,52 @@ +--- +title: Quidel +parent: Inactive Signals +grand_parent: COVIDcast Main Endpoint +--- + +# Quidel +{: .no_toc} + +* **Source name:** `quidel` + +## Table of Contents +{: .no_toc .text-delta} + +1. TOC +{:toc} + + +## COVID-19 Tests +These signals are still active. Documentation is available on the [Quidel page](quidel.md). + +## Flu Tests + +* **Earliest issue available:** April 29, 2020 +* **Last issued:** May 19, 2020 +* **Number of data revisions since May 19, 2020:** 0 +* **Date of last change:** Never +* **Available for:** msa, state (see [geography coding docs](../covidcast_geography.md)) +* **Time type:** day (see [date format docs](../covidcast_times.md)) + +Data source based on flu lab tests, provided to us by Quidel, Inc. When a +patient (whether at a doctor’s office, clinic, or hospital) has COVID-like +symptoms, doctors may perform a flu test to rule out seasonal flu (influenza), +because these two diseases have similar symptoms. Using this lab test data, we +estimate the total number of flu tests per medical device (a measure of testing +frequency), and the percentage of flu tests that are *negative* (since ruling +out flu leaves open another cause---possibly covid---for the patient's +symptoms), in a given location, on a given day. + +The number of flu tests conducted in individual counties can be quite small, so +we do not report these signals at the county level. + +The flu test data is no longer updated as of May 19, 2020, as the number of flu +tests conducted during the summer (outside of the normal flu season) is quite +small. The data may be updated again when the Winter 2020 flu season begins. + +| Signal | Description | +| --- | --- | +| `raw_pct_negative` | The percentage of flu tests that are negative, suggesting the patient's illness has another cause, possibly COVID-19
**Earliest date available:** 2020-01-31 | +| `smoothed_pct_negative` | Same as above, but smoothed in time
**Earliest date available:** 2020-01-31 | +| `raw_tests_per_device` | The average number of flu tests conducted by each testing device; measures volume of testing
**Earliest date available:** 2020-01-31 | +| `smoothed_tests_per_device` | Same as above, but smoothed in time
**Earliest date available:** 2020-01-31 | diff --git a/docs/api/covidcast-signals/quidel.md b/docs/api/covidcast-signals/quidel.md index 5f003904f..54296dbb3 100644 --- a/docs/api/covidcast-signals/quidel.md +++ b/docs/api/covidcast-signals/quidel.md @@ -157,34 +157,5 @@ not enough samples can be filled in from the parent state for smoothed signals s no data is reported for that area on that day; an API query for all reported geographic areas on that day will not include it. -## Flu Tests - -* **Earliest issue available:** April 29, 2020 -* **Last issued:** May 19, 2020 -* **Number of data revisions since May 19, 2020:** 0 -* **Date of last change:** Never -* **Available for:** msa, state (see [geography coding docs](../covidcast_geography.md)) -* **Time type:** day (see [date format docs](../covidcast_times.md)) - -Data source based on flu lab tests, provided to us by Quidel, Inc. When a -patient (whether at a doctor’s office, clinic, or hospital) has COVID-like -symptoms, doctors may perform a flu test to rule out seasonal flu (influenza), -because these two diseases have similar symptoms. Using this lab test data, we -estimate the total number of flu tests per medical device (a measure of testing -frequency), and the percentage of flu tests that are *negative* (since ruling -out flu leaves open another cause---possibly covid---for the patient's -symptoms), in a given location, on a given day. - -The number of flu tests conducted in individual counties can be quite small, so -we do not report these signals at the county level. - -The flu test data is no longer updated as of May 19, 2020, as the number of flu -tests conducted during the summer (outside of the normal flu season) is quite -small. The data may be updated again when the Winter 2020 flu season begins. - -| Signal | Description | -| --- | --- | -| `raw_pct_negative` | The percentage of flu tests that are negative, suggesting the patient's illness has another cause, possibly COVID-19
**Earliest date available:** 2020-01-31 | -| `smoothed_pct_negative` | Same as above, but smoothed in time
**Earliest date available:** 2020-01-31 | -| `raw_tests_per_device` | The average number of flu tests conducted by each testing device; measures volume of testing
**Earliest date available:** 2020-01-31 | -| `smoothed_tests_per_device` | Same as above, but smoothed in time
**Earliest date available:** 2020-01-31 | +## Flu Tests (inactive) +These signals were updated until May 19, 2020. Documentation is still available on the [inactive Quidel page](quidel-inactive.md). \ No newline at end of file diff --git a/docs/api/covidcast-signals/safegraph.md b/docs/api/covidcast-signals/safegraph.md deleted file mode 100644 index 3da66ed79..000000000 --- a/docs/api/covidcast-signals/safegraph.md +++ /dev/null @@ -1,87 +0,0 @@ ---- -title: SafeGraph -parent: Data Sources and Signals -grand_parent: COVIDcast Main Endpoint ---- - -# SafeGraph -{: .no_toc} -* **Source name:** `safegraph` -* **Available for:** county, MSA, HRR, state (see [geography coding docs](../covidcast_geography.md)) -* **Time type:** day (see [date format docs](../covidcast_times.md)) -* **License:** [CC BY](../covidcast_licensing.md#creative-commons-attribution) - -This data source uses data reported by [SafeGraph](https://www.safegraph.com/) -using anonymized location data from mobile phones. SafeGraph provides several -different datasets to eligible researchers. We currently surface signals from one such -dataset. - -## Table of contents -{: .no_toc .text-delta} - -1. TOC -{:toc} - -## SafeGraph Social Distancing Metrics (Inactive) - -These signals were updated until April 19th, 2021, when Safegraph ceased updating the dataset. -Documentation for these signals is still available on the [inactive Safegraph page](safegraph-inactive.md). - -## SafeGraph Weekly Patterns - -* **Earliest issue available:** November 30, 2020 -* **Number of data revisions since June 23, 2020:** 0 -* **Date of last change:** never - -Data source based on [Weekly -Patterns](https://docs.safegraph.com/docs/weekly-patterns) dataset. SafeGraph -provides this data for different points of interest -([POIs](https://docs.safegraph.com/v4.0/docs#section-core-places)) considering -individual census block groups, using differential privacy to protect individual -people's data privacy. - -Delphi gathers the number of daily visits to POIs of certain types(bars, -restaurants, etc.) from SafeGraph's Weekly Patterns data at the 5-digit ZipCode -level, then aggregates and reports these features to the county, MSA, HRR, and -state levels. The aggregated data is freely available through the COVIDcast API. - -For precise definitions of the quantities below, consult the [SafeGraph Weekly -Patterns documentation](https://docs.safegraph.com/docs/weekly-patterns). - -| Signal | Description | -| --- | --- | -| `bars_visit_num` | The number of daily visits made by those with SafeGraph's apps to bar-related POIs in a certain region
**Earliest date available:** 01/01/2019 | -| `bars_visit_prop` | The number of daily visits made by those with SafeGraph's apps to bar-related POIs in a certain region, per 100,000 population
**Earliest date available:** 01/01/2019 | -| `restaurants_visit_num` | The number of daily visits made by those with SafeGraph's apps to restaurant-related POIs in a certain region
**Earliest date available:** 01/01/2019 | -| `restaurants_visit_prop` | The number of daily visits made by those with SafeGraph's apps to restaurant-related POIs in a certain region, per 100,000 population
**Earliest date available:** 01/01/2019 | - -SafeGraph delivers the number of daily visits to U.S. POIs, the details of which -are described in the [Places -Manual](https://readme.safegraph.com/docs/places-manual#section-placekey) -dataset. Delphi aggregates the number of visits to certain types of places, -such as bars (places with [NAICS code = -722410](https://www.census.gov/cgi-bin/sssd/naics/naicsrch?input=722410&search=2017+NAICS+Search&search=2017)) -and restaurants (places with [NAICS code = -722511](https://www.census.gov/cgi-bin/sssd/naics/naicsrch)). For example, -Adagio Teas is coded as a bar because it serves alcohol, while Napkin Burger is -considered to be a full-service restaurant. More information on NAICS codes is -available from the [US Census Bureau: North American Industry Classification -System](https://www.census.gov/eos/www/naics/index.html). - -The number of POIs coded as bars is much smaller than the number of POIs coded as restaurants. -SafeGraph's Weekly Patterns data consistently lacks data on bar visits for Alaska, Delaware, Maine, North Dakota, New Hampshire, South Dakota, Vermont, West Virginia, and Wyoming. -For certain dates, bar visits data is also missing for District of Columbia, Idaho and Washington. Restaurant visits data is available for all of the states, as well as the District of Columbia and Puerto Rico. - -### Lag - -SafeGraph provides newly updated data for the previous week every Wednesday, -meaning estimates for a specific day are only available 3-9 days later. It may -take up to an additional day for SafeGraph's data to be ingested into the -COVIDcast API. - -## Limitations - -SafeGraph's Social Distancing Metrics and Weekly Patterns are based on mobile devices that are members of SafeGraph panels, which is not necessarily the same thing as measuring the general public. These counts do not represent absolute counts, and only count visits by members of the panel in that region. This can result in several biases: - -* **Geographic bias.** If some regions have a greater density of SafeGraph panel members as a percentage of the population than other regions, comparisons of metrics between regions may be biased. Regions with more SafeGraph panel members will appear to have more visits counted, even if the rate of visits in the general population is the same. -* **Demographic bias.** SafeGraph panels may not be representative of the local population as a whole. For example, [some research suggests](https://doi.org/10.1145/3442188.3445881) that "older and non-white voters are less likely to be captured by mobility data", so this data will not accurately reflect behavior in those populations. Since population demographics vary across the United States, this can also contribute to geographic biases. 
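
For orientation only (this is not part of the changeset): the signal tables above, including the `quidel` flu signals being moved to quidel-inactive.md, map directly onto COVIDcast API queries. Below is a minimal sketch of fetching one of those documented signals with the Python client; it assumes the `delphi_epidata` package is installed and uses the `Epidata.covidcast(data_source, signals, time_type, geo_type, time_values, geo_value)` call as documented elsewhere in this repo, with a hypothetical date and state chosen for illustration.

```python
# Sketch: query one documented signal through the COVIDcast endpoint.
# Assumes the delphi_epidata Python client; date/state values are illustrative.
from delphi_epidata import Epidata

# Percentage of negative flu tests ("raw_pct_negative") for Pennsylvania
# on a single day within the documented availability window.
res = Epidata.covidcast('quidel', 'raw_pct_negative', 'day', 'state', 20200415, 'pa')

if res['result'] == 1:
    for row in res['epidata']:
        print(row['geo_value'], row['time_value'], row['value'])
else:
    print('request failed:', res['message'])
```
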
diff --git a/docs/api/covidcast-signals/usa-facts.md b/docs/api/covidcast-signals/usa-facts.md index fc68c2dea..e0160cfde 100644 --- a/docs/api/covidcast-signals/usa-facts.md +++ b/docs/api/covidcast-signals/usa-facts.md @@ -1,6 +1,6 @@ --- title: USAFacts Cases and Deaths -parent: Data Sources and Signals +parent: Inactive Signals grand_parent: COVIDcast Main Endpoint --- diff --git a/integrations/server/test_covidcast.py b/integrations/server/test_covidcast.py index 3de69f02c..bcca3b199 100644 --- a/integrations/server/test_covidcast.py +++ b/integrations/server/test_covidcast.py @@ -73,6 +73,18 @@ def _insert_placeholder_set_four(self): for i in [4, 5, 6] ] self._insert_rows(rows) + return rows + + def _insert_placeholder_set_five(self): + rows = [ + self._make_placeholder_row(time_value=2000_01_01, value=i*1., stderr=i*10., sample_size=i*100., issue=2000_01_03+i)[0] + for i in [1, 2, 3] + ] + [ + # different time_values, same issues + self._make_placeholder_row(time_value=2000_01_01+i-3, value=i*1., stderr=i*10., sample_size=i*100., issue=2000_01_03+i-3)[0] + for i in [4, 5, 6] + ] + self._insert_rows(rows) return rows def test_round_trip(self): @@ -237,6 +249,46 @@ def test_location_wildcard(self): 'message': 'success', }) + def test_time_values_wildcard(self): + """Select all time_values with a wildcard query.""" + + # insert placeholder data + rows = self._insert_placeholder_set_three() + expected_time_values = [ + self.expected_from_row(r) for r in rows[:3] + ] + + # make the request + response, _ = self.request_based_on_row(rows[0], time_values="*") + + self.maxDiff = None + # assert that the right data came back + self.assertEqual(response, { + 'result': 1, + 'epidata': expected_time_values, + 'message': 'success', + }) + + def test_issues_wildcard(self): + """Select all issues with a wildcard query.""" + + # insert placeholder data + rows = self._insert_placeholder_set_five() + expected_issues = [ + self.expected_from_row(r) for r in rows[:3] + ] + + # make the request + response, _ = self.request_based_on_row(rows[0], issues="*") + + self.maxDiff = None + # assert that the right data came back + self.assertEqual(response, { + 'result': 1, + 'epidata': expected_issues, + 'message': 'success', + }) + def test_signal_wildcard(self): """Select all signals with a wildcard query.""" diff --git a/src/acquisition/covid_hosp/common/database.py b/src/acquisition/covid_hosp/common/database.py index 4bdfb4222..8875828fa 100644 --- a/src/acquisition/covid_hosp/common/database.py +++ b/src/acquisition/covid_hosp/common/database.py @@ -19,6 +19,7 @@ class Database: def __init__(self, connection, table_name=None, + hhs_dataset_id=None, columns_and_types=None, key_columns=None, additional_fields=None): @@ -30,6 +31,8 @@ def __init__(self, An open connection to a database. table_name : str The name of the table which holds the dataset. + hhs_dataset_id : str + The 9-character healthdata.gov identifier for this dataset. columns_and_types : tuple[str, str, Callable] List of 3-tuples of (CSV header name, SQL column name, data type) for all the columns in the CSV file. 
@@ -40,6 +43,7 @@ def __init__(self, self.connection = connection self.table_name = table_name + self.hhs_dataset_id = hhs_dataset_id self.publication_col_name = "issue" if table_name == 'covid_hosp_state_timeseries' else \ 'publication_date' self.columns_and_types = { @@ -115,8 +119,8 @@ def contains_revision(self, revision): FROM `covid_hosp_meta` WHERE - `dataset_name` = %s AND `revision_timestamp` = %s - ''', (self.table_name, revision)) + `hhs_dataset_id` = %s AND `revision_timestamp` = %s + ''', (self.hhs_dataset_id, revision)) for (result,) in cursor: return bool(result) @@ -138,14 +142,15 @@ def insert_metadata(self, publication_date, revision, meta_json): INSERT INTO `covid_hosp_meta` ( `dataset_name`, + `hhs_dataset_id`, `publication_date`, `revision_timestamp`, `metadata_json`, `acquisition_datetime` ) VALUES - (%s, %s, %s, %s, NOW()) - ''', (self.table_name, publication_date, revision, meta_json)) + (%s, %s, %s, %s, %s, NOW()) + ''', (self.table_name, self.hhs_dataset_id, publication_date, revision, meta_json)) def insert_dataset(self, publication_date, dataframe): """Add a dataset to the database. @@ -232,7 +237,7 @@ def get_max_issue(self): from `covid_hosp_meta` WHERE - dataset_name = "{self.table_name}" + hhs_dataset_id = "{self.hhs_dataset_id}" ''') for (result,) in cursor: if result is not None: diff --git a/src/acquisition/covid_hosp/common/utils.py b/src/acquisition/covid_hosp/common/utils.py index 1ac414968..99a6b4f33 100644 --- a/src/acquisition/covid_hosp/common/utils.py +++ b/src/acquisition/covid_hosp/common/utils.py @@ -169,19 +169,20 @@ def update_dataset(database, network, newer_than=None, older_than=None): # download the dataset and add it to the database dataset = Utils.merge_by_key_cols([network.fetch_dataset(url) for url, _ in revisions], db.KEY_COLS) - # add metadata to the database using the last revision seen. 
- last_url, last_index = revisions[-1] - metadata_json = metadata.loc[last_index].reset_index().to_json() + # add metadata to the database + all_metadata = [] + for url, index in revisions: + all_metadata.append((url, metadata.loc[index].reset_index().to_json())) datasets.append(( issue_int, dataset, - last_url, - metadata_json + all_metadata )) with database.connect() as db: - for issue_int, dataset, last_url, metadata_json in datasets: + for issue_int, dataset, all_metadata in datasets: db.insert_dataset(issue_int, dataset) - db.insert_metadata(issue_int, last_url, metadata_json) + for url, metadata_json in all_metadata: + db.insert_metadata(issue_int, url, metadata_json) print(f'successfully acquired {len(dataset)} rows') # note that the transaction is committed by exiting the `with` block diff --git a/src/acquisition/covid_hosp/facility/database.py b/src/acquisition/covid_hosp/facility/database.py index 20bb74100..665256a4f 100644 --- a/src/acquisition/covid_hosp/facility/database.py +++ b/src/acquisition/covid_hosp/facility/database.py @@ -2,6 +2,7 @@ from delphi.epidata.acquisition.covid_hosp.common.database import Database as BaseDatabase from delphi.epidata.acquisition.covid_hosp.common.database import Columndef from delphi.epidata.acquisition.covid_hosp.common.utils import Utils +from delphi.epidata.acquisition.covid_hosp.facility.network import Network class Database(BaseDatabase): @@ -213,5 +214,6 @@ def __init__(self, *args, **kwargs): *args, **kwargs, table_name=Database.TABLE_NAME, + hhs_dataset_id=Network.DATASET_ID, key_columns=Database.KEY_COLS, columns_and_types=Database.ORDERED_CSV_COLUMNS) diff --git a/src/acquisition/covid_hosp/state_daily/database.py b/src/acquisition/covid_hosp/state_daily/database.py index 58eaf8190..6a8228994 100644 --- a/src/acquisition/covid_hosp/state_daily/database.py +++ b/src/acquisition/covid_hosp/state_daily/database.py @@ -2,6 +2,7 @@ from delphi.epidata.acquisition.covid_hosp.common.database import Database as BaseDatabase from delphi.epidata.acquisition.covid_hosp.common.database import Columndef from delphi.epidata.acquisition.covid_hosp.common.utils import Utils +from delphi.epidata.acquisition.covid_hosp.state_daily.network import Network class Database(BaseDatabase): @@ -223,6 +224,7 @@ def __init__(self, *args, **kwargs): *args, **kwargs, table_name=Database.TABLE_NAME, + hhs_dataset_id=Network.DATASET_ID, columns_and_types=Database.ORDERED_CSV_COLUMNS, key_columns=Database.KEY_COLS, additional_fields=[Columndef('D', 'record_type', None)]) diff --git a/src/acquisition/covid_hosp/state_timeseries/database.py b/src/acquisition/covid_hosp/state_timeseries/database.py index aa6ba7580..348d9fc0b 100644 --- a/src/acquisition/covid_hosp/state_timeseries/database.py +++ b/src/acquisition/covid_hosp/state_timeseries/database.py @@ -2,6 +2,7 @@ from delphi.epidata.acquisition.covid_hosp.common.database import Database as BaseDatabase from delphi.epidata.acquisition.covid_hosp.common.database import Columndef from delphi.epidata.acquisition.covid_hosp.common.utils import Utils +from delphi.epidata.acquisition.covid_hosp.state_timeseries.network import Network class Database(BaseDatabase): @@ -222,6 +223,7 @@ def __init__(self, *args, **kwargs): *args, **kwargs, table_name=Database.TABLE_NAME, + hhs_dataset_id=Network.DATASET_ID, columns_and_types=Database.ORDERED_CSV_COLUMNS, key_columns=Database.KEY_COLS, additional_fields=[Columndef('T', 'record_type', None)]) diff --git a/src/client/delphi_epidata.R b/src/client/delphi_epidata.R index 
7d9f7f5e6..598bba814 100644 --- a/src/client/delphi_epidata.R +++ b/src/client/delphi_epidata.R @@ -15,7 +15,7 @@ Epidata <- (function() { # API base url BASE_URL <- 'https://delphi.cmu.edu/epidata/api.php' - client_version <- '0.4.4' + client_version <- '0.4.5' # Helper function to cast values and/or ranges to strings .listitem <- function(value) { diff --git a/src/client/delphi_epidata.js b/src/client/delphi_epidata.js index 978fa715b..e92b4abb9 100644 --- a/src/client/delphi_epidata.js +++ b/src/client/delphi_epidata.js @@ -22,7 +22,7 @@ } })(this, function (exports, fetchImpl, jQuery) { const BASE_URL = "https://delphi.cmu.edu/epidata/"; - const client_version = "0.4.4"; + const client_version = "0.4.5"; // Helper function to cast values and/or ranges to strings function _listitem(value) { diff --git a/src/client/packaging/npm/package.json b/src/client/packaging/npm/package.json index 2c5618498..b9e20df9d 100644 --- a/src/client/packaging/npm/package.json +++ b/src/client/packaging/npm/package.json @@ -2,7 +2,7 @@ "name": "delphi_epidata", "description": "Delphi Epidata API Client", "authors": "Delphi Group", - "version": "0.4.4", + "version": "0.4.5", "license": "MIT", "homepage": "https://github.com/cmu-delphi/delphi-epidata", "bugs": { diff --git a/src/client/packaging/pypi/delphi_epidata/__init__.py b/src/client/packaging/pypi/delphi_epidata/__init__.py index d53321b3f..9306b81b7 100644 --- a/src/client/packaging/pypi/delphi_epidata/__init__.py +++ b/src/client/packaging/pypi/delphi_epidata/__init__.py @@ -1,4 +1,4 @@ from .delphi_epidata import Epidata name = 'delphi_epidata' -__version__ = '0.4.4' +__version__ = '0.4.5' diff --git a/src/client/packaging/pypi/setup.py b/src/client/packaging/pypi/setup.py index 64b1952ab..1878dda74 100644 --- a/src/client/packaging/pypi/setup.py +++ b/src/client/packaging/pypi/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="delphi_epidata", - version="0.4.4", + version="0.4.5", author="David Farrow", author_email="dfarrow0@gmail.com", description="A programmatic interface to Delphi's Epidata API.", diff --git a/src/ddl/covid_hosp.sql b/src/ddl/covid_hosp.sql index 92ad87efc..2ffe7c71a 100644 --- a/src/ddl/covid_hosp.sql +++ b/src/ddl/covid_hosp.sql @@ -48,6 +48,7 @@ surfaced through the Epidata API. 
CREATE TABLE `covid_hosp_meta` ( `id` INT NOT NULL AUTO_INCREMENT, `dataset_name` VARCHAR(64) NOT NULL, + `hhs_dataset_id` CHAR(9) NOT NULL DEFAULT "????-????", `publication_date` INT NOT NULL, `revision_timestamp` VARCHAR(512) NOT NULL, `metadata_json` JSON NOT NULL, diff --git a/src/ddl/migrations/covid_hosp_meta_v0.4.4-v0.4.5.sql b/src/ddl/migrations/covid_hosp_meta_v0.4.4-v0.4.5.sql new file mode 100644 index 000000000..bab2623b6 --- /dev/null +++ b/src/ddl/migrations/covid_hosp_meta_v0.4.4-v0.4.5.sql @@ -0,0 +1,4 @@ +ALTER TABLE covid_hosp_meta ADD COLUMN hhs_dataset_id CHAR(9) NOT NULL DEFAULT "????-????"; +UPDATE covid_hosp_meta SET hhs_dataset_id="g62h-syeh" WHERE revision_timestamp LIKE "%g62h-syeh%"; +UPDATE covid_hosp_meta SET hhs_dataset_id="6xf2-c3ie" WHERE revision_timestamp LIKE "%6xf2-c3ie%"; +UPDATE covid_hosp_meta SET hhs_dataset_id="anag-cw7u" WHERE revision_timestamp LIKE "%anag-cw7u%"; diff --git a/src/server/_common.py b/src/server/_common.py index 45813e451..d8e2bc068 100644 --- a/src/server/_common.py +++ b/src/server/_common.py @@ -2,15 +2,16 @@ import time from flask import Flask, g, request -from sqlalchemy import event -from sqlalchemy.engine import Connection +from sqlalchemy import create_engine, event +from sqlalchemy.engine import Connection, Engine from werkzeug.local import LocalProxy from .utils.logger import get_structured_logger -from ._config import SECRET -from ._db import engine +from ._config import SECRET, SQLALCHEMY_DATABASE_URI, SQLALCHEMY_ENGINE_OPTIONS from ._exceptions import DatabaseErrorException, EpiDataException +engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS) + app = Flask("EpiData", static_url_path="") app.config["SECRET"] = SECRET diff --git a/src/server/_config.py b/src/server/_config.py index 8f8fcacd1..187d4581a 100644 --- a/src/server/_config.py +++ b/src/server/_config.py @@ -4,7 +4,7 @@ load_dotenv() -VERSION = "0.4.4" +VERSION = "0.4.5" MAX_RESULTS = int(10e6) MAX_COMPATIBILITY_RESULTS = int(3650) diff --git a/src/server/_db.py b/src/server/_db.py deleted file mode 100644 index 9d15ef5b4..000000000 --- a/src/server/_db.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Dict, List -from sqlalchemy import MetaData, create_engine, inspect -from sqlalchemy.engine import Engine -from sqlalchemy.engine.reflection import Inspector - -from ._config import SQLALCHEMY_DATABASE_URI, SQLALCHEMY_ENGINE_OPTIONS - -engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS) -metadata = MetaData(bind=engine) - -TABLE_OPTIONS = dict( - mysql_engine="InnoDB", - # mariadb_engine="InnoDB", - mysql_charset="utf8mb4", - # mariadb_charset="utf8", -) - - -def sql_table_has_columns(table: str, columns: List[str]) -> bool: - """ - checks whether the given table has all the given columns defined - """ - inspector: Inspector = inspect(engine) - table_columns: List[Dict] = inspector.get_columns(table) - table_column_names = set(str(d.get("name", "")).lower() for d in table_columns) - return all(c.lower() in table_column_names for c in columns) diff --git a/src/server/_params.py b/src/server/_params.py index 2cef9725b..a7d36353c 100644 --- a/src/server/_params.py +++ b/src/server/_params.py @@ -7,7 +7,8 @@ from ._exceptions import ValidationFailedException -from .utils import days_in_range, weeks_in_range, guess_time_value_is_day, guess_time_value_is_week, TimeValues, days_to_ranges, weeks_to_ranges +from .utils import days_in_range, weeks_in_range, guess_time_value_is_day, 
guess_time_value_is_week, IntRange, TimeValues, days_to_ranges, weeks_to_ranges +from ._validate import require_any, require_all def _parse_common_multi_arg(key: str) -> List[Tuple[str, Union[bool, Sequence[str]]]]: @@ -48,7 +49,7 @@ def _parse_single_arg(key: str) -> Tuple[str, str]: @dataclass -class GeoPair: +class GeoSet: geo_type: str geo_values: Union[bool, Sequence[str]] @@ -57,27 +58,27 @@ def matches(self, geo_type: str, geo_value: str) -> bool: def count(self) -> float: """ - returns the count of items in this pair + returns the count of items in this set """ if isinstance(self.geo_values, bool): return inf if self.geo_values else 0 return len(self.geo_values) -def parse_geo_arg(key: str = "geo") -> List[GeoPair]: - return [GeoPair(geo_type, geo_values) for [geo_type, geo_values] in _parse_common_multi_arg(key)] +def parse_geo_arg(key: str = "geo") -> List[GeoSet]: + return [GeoSet(geo_type, geo_values) for [geo_type, geo_values] in _parse_common_multi_arg(key)] -def parse_single_geo_arg(key: str) -> GeoPair: +def parse_single_geo_arg(key: str) -> GeoSet: """ - parses a single geo pair with only one value + parses a single geo set with only one value """ r = _parse_single_arg(key) - return GeoPair(r[0], [r[1]]) + return GeoSet(r[0], [r[1]]) @dataclass -class SourceSignalPair: +class SourceSignalSet: source: str signal: Union[bool, Sequence[str]] @@ -86,27 +87,27 @@ def matches(self, source: str, signal: str) -> bool: def count(self) -> float: """ - returns the count of items in this pair + returns the count of items in this set """ if isinstance(self.signal, bool): return inf if self.signal else 0 return len(self.signal) -def parse_source_signal_arg(key: str = "signal") -> List[SourceSignalPair]: - return [SourceSignalPair(source, signals) for [source, signals] in _parse_common_multi_arg(key)] +def parse_source_signal_arg(key: str = "signal") -> List[SourceSignalSet]: + return [SourceSignalSet(source, signals) for [source, signals] in _parse_common_multi_arg(key)] -def parse_single_source_signal_arg(key: str) -> SourceSignalPair: +def parse_single_source_signal_arg(key: str) -> SourceSignalSet: """ - parses a single source signal pair with only one value + parses a single source signal set with only one value """ r = _parse_single_arg(key) - return SourceSignalPair(r[0], [r[1]]) + return SourceSignalSet(r[0], [r[1]]) @dataclass -class TimePair: +class TimeSet: time_type: str time_values: Union[bool, TimeValues] @@ -121,7 +122,7 @@ def is_day(self) -> bool: def count(self) -> float: """ - returns the count of items in this pair + returns the count of items in this set """ if isinstance(self.time_values, bool): return inf if self.time_values else 0 @@ -131,16 +132,16 @@ def count(self) -> float: def to_ranges(self): """ - returns this pair with times converted to ranges + returns this set with times converted to ranges """ if isinstance(self.time_values, bool): - return TimePair(self.time_type, self.time_values) + return TimeSet(self.time_type, self.time_values) if self.time_type == 'week': - return TimePair(self.time_type, weeks_to_ranges(self.time_values)) - return TimePair(self.time_type, days_to_ranges(self.time_values)) + return TimeSet(self.time_type, weeks_to_ranges(self.time_values)) + return TimeSet(self.time_type, days_to_ranges(self.time_values)) -def _verify_range(start: int, end: int) -> Union[int, Tuple[int, int]]: +def _verify_range(start: int, end: int) -> IntRange: if start == end: # the first and last numbers are the same, just treat it as a singe value return 
start @@ -151,7 +152,7 @@ def _verify_range(start: int, end: int) -> Union[int, Tuple[int, int]]: raise ValidationFailedException(f"the given range {start}-{end} is inverted") -def parse_week_value(time_value: str) -> Union[int, Tuple[int, int]]: +def parse_week_value(time_value: str) -> IntRange: count_dashes = time_value.count("-") msg = f"{time_value} does not match a known format YYYYWW or YYYYWW-YYYYWW" @@ -171,7 +172,7 @@ def parse_week_value(time_value: str) -> Union[int, Tuple[int, int]]: raise ValidationFailedException(msg) -def parse_day_value(time_value: str) -> Union[int, Tuple[int, int]]: +def parse_day_value(time_value: str) -> IntRange: count_dashes = time_value.count("-") msg = f"{time_value} does not match a known format YYYYMMDD, YYYY-MM-DD, YYYYMMDD-YYYYMMDD, or YYYY-MM-DD--YYYY-MM-DD" @@ -204,47 +205,47 @@ def parse_day_value(time_value: str) -> Union[int, Tuple[int, int]]: raise ValidationFailedException(msg) -def _parse_time_pair(time_type: str, time_values: Union[bool, Sequence[str]]) -> TimePair: +def _parse_time_set(time_type: str, time_values: Union[bool, Sequence[str]]) -> TimeSet: if isinstance(time_values, bool): - return TimePair(time_type, time_values) + return TimeSet(time_type, time_values) if time_type == "week": - return TimePair("week", [parse_week_value(t) for t in time_values]) + return TimeSet("week", [parse_week_value(t) for t in time_values]) elif time_type == "day": - return TimePair("day", [parse_day_value(t) for t in time_values]) + return TimeSet("day", [parse_day_value(t) for t in time_values]) raise ValidationFailedException(f'time param: {time_type} is not one of "day" or "week"') -def parse_time_arg(key: str = "time") -> Optional[TimePair]: - time_pairs = [_parse_time_pair(time_type, time_values) for [time_type, time_values] in _parse_common_multi_arg(key)] +def parse_time_arg(key: str = "time") -> Optional[TimeSet]: + time_sets = [_parse_time_set(time_type, time_values) for [time_type, time_values] in _parse_common_multi_arg(key)] # single value - if len(time_pairs) == 0: + if len(time_sets) == 0: return None - if len(time_pairs) == 1: - return time_pairs[0] + if len(time_sets) == 1: + return time_sets[0] # make sure 'day' and 'week' aren't mixed - time_types = set(time_pair.time_type for time_pair in time_pairs) + time_types = set(time_set.time_type for time_set in time_sets) if len(time_types) >= 2: - raise ValidationFailedException(f'{key}: {time_pairs} mixes "day" and "week" time types') + raise ValidationFailedException(f'{key}: {time_sets} mixes "day" and "week" time types') - # merge all time pairs into one + # merge all time sets into one merged = [] - for time_pair in time_pairs: - if time_pair.time_values is True: - return time_pair + for time_set in time_sets: + if time_set.time_values is True: + return time_set else: - merged.extend(time_pair.time_values) - return TimePair(time_pairs[0].time_type, merged).to_ranges() + merged.extend(time_set.time_values) + return TimeSet(time_sets[0].time_type, merged).to_ranges() -def parse_single_time_arg(key: str) -> TimePair: +def parse_single_time_arg(key: str) -> TimeSet: """ - parses a single time pair with only one value + parses a single time set with only one value """ r = _parse_single_arg(key) - return _parse_time_pair(r[0], [r[1]]) + return _parse_time_set(r[0], [r[1]]) def parse_day_range_arg(key: str) -> Tuple[int, int]: @@ -285,20 +286,20 @@ def parse_week_range_arg(key: str) -> Tuple[int, int]: raise ValidationFailedException(f"{key} must match YYYYWW-YYYYWW") return r -def 
parse_day_or_week_arg(key: str, default_value: Optional[int] = None) -> TimePair: +def parse_day_or_week_arg(key: str, default_value: Optional[int] = None) -> TimeSet: v = request.values.get(key) if not v: if default_value is not None: time_type = "day" if guess_time_value_is_day(default_value) else "week" - return TimePair(time_type, [default_value]) + return TimeSet(time_type, [default_value]) raise ValidationFailedException(f"{key} param is required") # format is either YYYY-MM-DD or YYYYMMDD or YYYYMM is_week = guess_time_value_is_week(v) if is_week: - return TimePair("week", [parse_week_arg(key)]) - return TimePair("day", [parse_day_arg(key)]) + return TimeSet("week", [parse_week_arg(key)]) + return TimeSet("day", [parse_day_arg(key)]) -def parse_day_or_week_range_arg(key: str) -> TimePair: +def parse_day_or_week_range_arg(key: str) -> TimeSet: v = request.values.get(key) if not v: raise ValidationFailedException(f"{key} param is required") @@ -306,5 +307,184 @@ def parse_day_or_week_range_arg(key: str) -> TimePair: # so if the first before the - has length 6, it must be a week is_week = guess_time_value_is_week(v.split('-', 2)[0]) if is_week: - return TimePair("week", [parse_week_range_arg(key)]) - return TimePair("day", [parse_day_range_arg(key)]) + return TimeSet("week", [parse_week_range_arg(key)]) + return TimeSet("day", [parse_day_range_arg(key)]) + + +def _extract_value(key: Union[str, Sequence[str]]) -> Optional[str]: + if isinstance(key, str): + return request.values.get(key) + for k in key: + if k in request.values: + return request.values[k] + return None + + +def _extract_list_value(key: Union[str, Sequence[str]]) -> List[str]: + if isinstance(key, str): + return request.values.getlist(key) + for k in key: + if k in request.values: + return request.values.getlist(k) + return [] + + +def extract_strings(key: Union[str, Sequence[str]]) -> Optional[List[str]]: + s = _extract_list_value(key) + if not s: + # nothing to do + return None + # we can have multiple values + return [v for vs in s for v in vs.split(",")] + + +def extract_integer(key: Union[str, Sequence[str]]) -> Optional[int]: + s = _extract_value(key) + if not s: + # nothing to do + return None + try: + return int(s) + except ValueError: + raise ValidationFailedException(f"{key}: not a number: {s}") + + +def extract_integers(key: Union[str, Sequence[str]]) -> Optional[List[IntRange]]: + parts = extract_strings(key) + if not parts: + # nothing to do + return None + + def _parse_range(part: str): + if "-" not in part: + return int(part) + r = part.split("-", 2) + first = int(r[0]) + last = int(r[1]) + if first == last: + # the first and last numbers are the same, just treat it as a singe value + return first + elif last > first: + # add the range as an array + return (first, last) + # the range is inverted, this is an error + raise ValidationFailedException(f"{key}: the given range is inverted") + + try: + values = [_parse_range(part) for part in parts] + # check for invalid values + return None if any(v is None for v in values) else values + except ValueError as e: + raise ValidationFailedException(f"{key}: not a number: {str(e)}") + + +def parse_date(s: str) -> int: + # parses a given string in format YYYYMMDD or YYYY-MM-DD to a number in the form YYYYMMDD + try: + if s == "*": + return s + else: + return int(s.replace("-", "")) + except ValueError: + raise ValidationFailedException(f"not a valid date: {s}") + + +def extract_date(key: Union[str, Sequence[str]]) -> Optional[int]: + s = _extract_value(key) + if not s: 
+ return None + return parse_date(s) + + +def extract_dates(key: Union[str, Sequence[str]]) -> Optional[TimeValues]: + parts = extract_strings(key) + if not parts: + return None + values: TimeValues = [] + + def push_range(first: str, last: str): + first_d = parse_date(first) + last_d = parse_date(last) + if first_d == last_d: + # the first and last numbers are the same, just treat it as a singe value + return first_d + if last_d > first_d: + # add the range as an array + return (first_d, last_d) + # the range is inverted, this is an error + raise ValidationFailedException(f"{key}: the given range is inverted") + + for part in parts: + if "-" not in part and ":" not in part: + # YYYYMMDD + values.append(parse_date(part)) + continue + if ":" in part: + # YYYY-MM-DD:YYYY-MM-DD + range_part = part.split(":", 2) + r = push_range(range_part[0], range_part[1]) + if r is None: + return None + values.append(r) + continue + # YYYY-MM-DD or YYYYMMDD-YYYYMMDD + # split on the dash + range_part = part.split("-") + if len(range_part) == 2: + # YYYYMMDD-YYYYMMDD + r = push_range(range_part[0], range_part[1]) + if r is None: + return None + values.append(r) + continue + # YYYY-MM-DD + values.append(parse_date(part)) + # success, return the list + return values + +def parse_source_signal_sets() -> List[SourceSignalSet]: + ds = request.values.get("data_source") + if ds: + # old version + require_any("signal", "signals", empty=True) + signals = extract_strings(("signals", "signal")) + if len(signals) == 1 and signals[0] == "*": + return [SourceSignalSet(ds, True)] + return [SourceSignalSet(ds, signals)] + + if ":" not in request.values.get("signal", ""): + raise ValidationFailedException("missing parameter: signal or (data_source and signal[s])") + + return parse_source_signal_arg() + + +def parse_geo_sets() -> List[GeoSet]: + geo_type = request.values.get("geo_type") + if geo_type: + # old version + require_any("geo_value", "geo_values", empty=True) + geo_values = extract_strings(("geo_values", "geo_value")) + if len(geo_values) == 1 and geo_values[0] == "*": + return [GeoSet(geo_type, True)] + return [GeoSet(geo_type, geo_values)] + + if ":" not in request.values.get("geo", ""): + raise ValidationFailedException("missing parameter: geo or (geo_type and geo_value[s])") + + return parse_geo_arg() + + +def parse_time_set() -> TimeSet: + time_type = request.values.get("time_type") + if time_type: + # old version + require_all("time_type", "time_values") + time_values = extract_dates("time_values") + if time_values == ["*"]: + return TimeSet(time_type, True) + return TimeSet(time_type, time_values) + + if ":" not in request.values.get("time", ""): + raise ValidationFailedException("missing parameter: time or (time_type and time_values)") + + return parse_time_arg() diff --git a/src/server/_printer.py b/src/server/_printer.py index bbe3ee10e..04196c71d 100644 --- a/src/server/_printer.py +++ b/src/server/_printer.py @@ -58,7 +58,7 @@ def gen(): r = self._print_row(row) if r is not None: yield r - except: + except Exception as e: get_structured_logger('server_error').error("Exception while executing printer", exception=e) self.result = -1 yield self._error(e) diff --git a/src/server/_query.py b/src/server/_query.py index 69607255f..3c23f94ad 100644 --- a/src/server/_query.py +++ b/src/server/_query.py @@ -18,9 +18,8 @@ from ._common import db from ._printer import create_printer, APrinter from ._exceptions import DatabaseErrorException -from ._validate import extract_strings -from ._params import GeoPair, 
SourceSignalPair, TimePair -from .utils import time_values_to_ranges, TimeValues +from ._params import extract_strings, GeoSet, SourceSignalSet, TimeSet +from .utils import time_values_to_ranges, IntRange, TimeValues def date_string(value: int) -> str: @@ -34,7 +33,7 @@ def date_string(value: int) -> str: def to_condition( field: str, - value: Union[str, Tuple[int, int], int], + value: Union[str, IntRange], param_key: str, params: Dict[str, Any], formatter=lambda x: x, @@ -50,7 +49,7 @@ def to_condition( def filter_values( field: str, - values: Optional[Sequence[Union[str, Tuple[int, int], int]]], + values: Optional[Sequence[Union[str, IntRange]]], param_key: str, params: Dict[str, Any], formatter=lambda x: x, @@ -75,7 +74,7 @@ def filter_strings( def filter_integers( field: str, - values: Optional[Sequence[Union[Tuple[int, int], int]]], + values: Optional[Sequence[IntRange]], param_key: str, params: Dict[str, Any], ): @@ -115,25 +114,25 @@ def filter_fields(generator: Iterable[Dict[str, Any]]): yield filtered -def filter_geo_pairs( +def filter_geo_sets( type_field: str, value_field: str, - values: Sequence[GeoPair], + values: Sequence[GeoSet], param_key: str, params: Dict[str, Any], ) -> str: """ - returns the SQL sub query to filter by the given geo pairs + returns the SQL sub query to filter by the given geo sets """ - def filter_pair(pair: GeoPair, i) -> str: + def filter_set(gset: GeoSet, i) -> str: type_param = f"{param_key}_{i}t" - params[type_param] = pair.geo_type - if isinstance(pair.geo_values, bool) and pair.geo_values: + params[type_param] = gset.geo_type + if isinstance(gset.geo_values, bool) and gset.geo_values: return f"{type_field} = :{type_param}" - return f"({type_field} = :{type_param} AND {filter_strings(value_field, cast(Sequence[str], pair.geo_values), type_param, params)})" + return f"({type_field} = :{type_param} AND {filter_strings(value_field, cast(Sequence[str], gset.geo_values), type_param, params)})" - parts = [filter_pair(p, i) for i, p in enumerate(values)] + parts = [filter_set(p, i) for i, p in enumerate(values)] if not parts: # something has to be selected @@ -142,25 +141,25 @@ def filter_pair(pair: GeoPair, i) -> str: return f"({' OR '.join(parts)})" -def filter_source_signal_pairs( +def filter_source_signal_sets( source_field: str, signal_field: str, - values: Sequence[SourceSignalPair], + values: Sequence[SourceSignalSet], param_key: str, params: Dict[str, Any], ) -> str: """ - returns the SQL sub query to filter by the given source signal pairs + returns the SQL sub query to filter by the given source signal sets """ - def filter_pair(pair: SourceSignalPair, i) -> str: + def filter_set(ssset: SourceSignalSet, i) -> str: source_param = f"{param_key}_{i}t" - params[source_param] = pair.source - if isinstance(pair.signal, bool) and pair.signal: + params[source_param] = ssset.source + if isinstance(ssset.signal, bool) and ssset.signal: return f"{source_field} = :{source_param}" - return f"({source_field} = :{source_param} AND {filter_strings(signal_field, cast(Sequence[str], pair.signal), source_param, params)})" + return f"({source_field} = :{source_param} AND {filter_strings(signal_field, cast(Sequence[str], ssset.signal), source_param, params)})" - parts = [filter_pair(p, i) for i, p in enumerate(values)] + parts = [filter_set(p, i) for i, p in enumerate(values)] if not parts: # something has to be selected @@ -169,26 +168,26 @@ def filter_pair(pair: SourceSignalPair, i) -> str: return f"({' OR '.join(parts)})" -def filter_time_pair( +def 
filter_time_set( type_field: str, time_field: str, - pair: Optional[TimePair], + tset: Optional[TimeSet], param_key: str, params: Dict[str, Any], ) -> str: """ - returns the SQL sub query to filter by the given time pair + returns the SQL sub query to filter by the given time set """ - # safety path; should normally not be reached as time pairs are enforced by the API - if not pair: + # safety path; should normally not be reached as time sets are enforced by the API + if not tset: return "FALSE" type_param = f"{param_key}_0t" - params[type_param] = pair.time_type - if isinstance(pair.time_values, bool) and pair.time_values: + params[type_param] = tset.time_type + if isinstance(tset.time_values, bool) and tset.time_values: parts = f"{type_field} = :{type_param}" else: - ranges = pair.to_ranges().time_values + ranges = tset.to_ranges().time_values parts = f"({type_field} = :{type_param} AND {filter_integers(time_field, ranges, type_param, params)})" return f"({parts})" @@ -399,24 +398,24 @@ def _fq_field(self, field: str) -> str: def where_integers( self, field: str, - values: Optional[Sequence[Union[Tuple[int, int], int]]], + values: Optional[Sequence[IntRange]], param_key: Optional[str] = None, ) -> "QueryBuilder": fq_field = self._fq_field(field) self.conditions.append(filter_integers(fq_field, values, param_key or field, self.params)) return self - def where_geo_pairs( + def apply_geo_filters( self, type_field: str, value_field: str, - values: Sequence[GeoPair], + values: Sequence[GeoSet], param_key: Optional[str] = None, ) -> "QueryBuilder": fq_type_field = self._fq_field(type_field) fq_value_field = self._fq_field(value_field) self.conditions.append( - filter_geo_pairs( + filter_geo_sets( fq_type_field, fq_value_field, values, @@ -426,17 +425,17 @@ def where_geo_pairs( ) return self - def where_source_signal_pairs( + def apply_source_signal_filters( self, type_field: str, value_field: str, - values: Sequence[SourceSignalPair], + values: Sequence[SourceSignalSet], param_key: Optional[str] = None, ) -> "QueryBuilder": fq_type_field = self._fq_field(type_field) fq_value_field = self._fq_field(value_field) self.conditions.append( - filter_source_signal_pairs( + filter_source_signal_sets( fq_type_field, fq_value_field, values, @@ -446,17 +445,17 @@ def where_source_signal_pairs( ) return self - def where_time_pair( + def apply_time_filter( self, type_field: str, value_field: str, - values: Optional[TimePair], + values: Optional[TimeSet], param_key: Optional[str] = None, ) -> "QueryBuilder": fq_type_field = self._fq_field(type_field) fq_value_field = self._fq_field(value_field) self.conditions.append( - filter_time_pair( + filter_time_set( fq_type_field, fq_value_field, values, @@ -466,25 +465,44 @@ def where_time_pair( ) return self + def apply_lag_filter(self, history_table: str, lag: Optional[int]) -> "QueryBuilder": + if lag is not None: + self.retable(history_table) + # history_table has full spectrum of lag values to search from whereas the latest_table does not + self.where(lag=lag) + return self + + def apply_issues_filter(self, history_table: str, issues: Optional[TimeValues]) -> "QueryBuilder": + if issues: + if issues == ["*"]: + self.retable(history_table) + else: + self.retable(history_table) + self.where_integers("issue", issues) + return self + + def apply_as_of_filter(self, history_table: str, as_of: Optional[int]) -> "QueryBuilder": + if as_of is not None: + self.retable(history_table) + sub_condition_asof = "(issue <= :as_of)" + self.params["as_of"] = as_of + sub_fields = 
"max(issue) max_issue, time_type, time_value, `source`, `signal`, geo_type, geo_value" + sub_group = "time_type, time_value, `source`, `signal`, geo_type, geo_value" + alias = self.alias + sub_condition = f"x.max_issue = {alias}.issue AND x.time_type = {alias}.time_type AND x.time_value = {alias}.time_value AND x.source = {alias}.source AND x.signal = {alias}.signal AND x.geo_type = {alias}.geo_type AND x.geo_value = {alias}.geo_value" + self.subquery = f"JOIN (SELECT {sub_fields} FROM {self.table} WHERE {self.conditions_clause} AND {sub_condition_asof} GROUP BY {sub_group}) x ON {sub_condition}" + return self + def set_fields(self, *fields: Iterable[str]) -> "QueryBuilder": self.fields = [f"{self.alias}.{field}" for field_list in fields for field in field_list] return self - def set_order(self, *args: str, **kwargs: Union[str, bool]) -> "QueryBuilder": + def set_sort_order(self, *args: str) -> "QueryBuilder": """ sets the order for the given fields (as key word arguments), True = ASC, False = DESC """ - def to_asc(v: Union[str, bool]) -> str: - if v is True: - return "ASC" - elif v is False: - return "DESC" - return cast(str, v) - - args_order = [f"{self.alias}.{k} ASC" for k in args] - kw_order = [f"{self.alias}.{k} {to_asc(v)}" for k, v in kwargs.items()] - self.order = args_order + kw_order + self.order = [f"{self.alias}.{k} ASC" for k in args] return self def with_max_issue(self, *args: str) -> "QueryBuilder": diff --git a/src/server/_validate.py b/src/server/_validate.py index 59e5aa7d0..ffdd15232 100644 --- a/src/server/_validate.py +++ b/src/server/_validate.py @@ -3,7 +3,7 @@ from flask import request from ._exceptions import UnAuthenticatedException, ValidationFailedException -from .utils import TimeValues +from .utils import IntRange, TimeValues def resolve_auth_token() -> Optional[str]: @@ -55,135 +55,3 @@ def require_any(*values: str, empty=False) -> bool: if request.values.get(value) or (empty and value in request.values): return True raise ValidationFailedException(f"missing parameter: need one of [{', '.join(values)}]") - - -def _extract_value(key: Union[str, Sequence[str]]) -> Optional[str]: - if isinstance(key, str): - return request.values.get(key) - for k in key: - if k in request.values: - return request.values[k] - return None - - -def _extract_list_value(key: Union[str, Sequence[str]]) -> List[str]: - if isinstance(key, str): - return request.values.getlist(key) - for k in key: - if k in request.values: - return request.values.getlist(k) - return [] - - -def extract_strings(key: Union[str, Sequence[str]]) -> Optional[List[str]]: - s = _extract_list_value(key) - if not s: - # nothing to do - return None - # we can have multiple values - return [v for vs in s for v in vs.split(",")] - - -IntRange = Union[Tuple[int, int], int] - - -def extract_integer(key: Union[str, Sequence[str]]) -> Optional[int]: - s = _extract_value(key) - if not s: - # nothing to do - return None - try: - return int(s) - except ValueError: - raise ValidationFailedException(f"{key}: not a number: {s}") - - -def extract_integers(key: Union[str, Sequence[str]]) -> Optional[List[IntRange]]: - parts = extract_strings(key) - if not parts: - # nothing to do - return None - - def _parse_range(part: str): - if "-" not in part: - return int(part) - r = part.split("-", 2) - first = int(r[0]) - last = int(r[1]) - if first == last: - # the first and last numbers are the same, just treat it as a singe value - return first - elif last > first: - # add the range as an array - return (first, last) - # the range is 
inverted, this is an error - raise ValidationFailedException(f"{key}: the given range is inverted") - - try: - values = [_parse_range(part) for part in parts] - # check for invalid values - return None if any(v is None for v in values) else values - except ValueError as e: - raise ValidationFailedException(f"{key}: not a number: {str(e)}") - - -def parse_date(s: str) -> int: - # parses a given string in format YYYYMMDD or YYYY-MM-DD to a number in the form YYYYMMDD - try: - return int(s.replace("-", "")) - except ValueError: - raise ValidationFailedException(f"not a valid date: {s}") - - -def extract_date(key: Union[str, Sequence[str]]) -> Optional[int]: - s = _extract_value(key) - if not s: - return None - return parse_date(s) - - -def extract_dates(key: Union[str, Sequence[str]]) -> Optional[TimeValues]: - parts = extract_strings(key) - if not parts: - return None - values: TimeValues = [] - - def push_range(first: str, last: str): - first_d = parse_date(first) - last_d = parse_date(last) - if first_d == last_d: - # the first and last numbers are the same, just treat it as a singe value - return first_d - if last_d > first_d: - # add the range as an array - return (first_d, last_d) - # the range is inverted, this is an error - raise ValidationFailedException(f"{key}: the given range is inverted") - - for part in parts: - if "-" not in part and ":" not in part: - # YYYYMMDD - values.append(parse_date(part)) - continue - if ":" in part: - # YYYY-MM-DD:YYYY-MM-DD - range_part = part.split(":", 2) - r = push_range(range_part[0], range_part[1]) - if r is None: - return None - values.append(r) - continue - # YYYY-MM-DD or YYYYMMDD-YYYYMMDD - # split on the dash - range_part = part.split("-") - if len(range_part) == 2: - # YYYYMMDD-YYYYMMDD - r = push_range(range_part[0], range_part[1]) - if r is None: - return None - values.append(r) - continue - # YYYY-MM-DD - values.append(parse_date(part)) - # success, return the list - return values diff --git a/src/server/endpoints/afhsb.py b/src/server/endpoints/afhsb.py index 9f05eac9d..69c2d2431 100644 --- a/src/server/endpoints/afhsb.py +++ b/src/server/endpoints/afhsb.py @@ -3,8 +3,9 @@ from flask import Blueprint from .._config import AUTH +from .._params import extract_integers, extract_strings from .._query import execute_queries, filter_integers, filter_strings -from .._validate import check_auth_token, extract_integers, extract_strings, require_all +from .._validate import check_auth_token, require_all # first argument is the endpoint name bp = Blueprint("afhsb", __name__) diff --git a/src/server/endpoints/cdc.py b/src/server/endpoints/cdc.py index 6023d4f16..6b7b9450d 100644 --- a/src/server/endpoints/cdc.py +++ b/src/server/endpoints/cdc.py @@ -1,8 +1,9 @@ from flask import Blueprint from .._config import AUTH, NATION_REGION, REGION_TO_STATE -from .._validate import require_all, extract_strings, extract_integers, check_auth_token +from .._params import extract_strings, extract_integers from .._query import filter_strings, execute_queries, filter_integers +from .._validate import require_all, check_auth_token # first argument is the endpoint name bp = Blueprint("cdc", __name__) diff --git a/src/server/endpoints/covid_hosp_facility.py b/src/server/endpoints/covid_hosp_facility.py index 8d5cb0b77..d1c9fad8a 100644 --- a/src/server/endpoints/covid_hosp_facility.py +++ b/src/server/endpoints/covid_hosp_facility.py @@ -1,7 +1,8 @@ from flask import Blueprint +from .._params import extract_integers, extract_strings from .._query import execute_query, 
QueryBuilder -from .._validate import extract_integers, extract_strings, require_all +from .._validate import require_all # first argument is the endpoint name bp = Blueprint("covid_hosp_facility", __name__) @@ -139,7 +140,7 @@ def handle(): q.set_fields(fields_string, fields_int, fields_float) # basic query info - q.set_order("collection_week", "hospital_pk", "publication_date") + q.set_sort_order("collection_week", "hospital_pk", "publication_date") # build the filter q.where_integers("collection_week", collection_weeks) diff --git a/src/server/endpoints/covid_hosp_facility_lookup.py b/src/server/endpoints/covid_hosp_facility_lookup.py index 880767135..54a3b9183 100644 --- a/src/server/endpoints/covid_hosp_facility_lookup.py +++ b/src/server/endpoints/covid_hosp_facility_lookup.py @@ -1,7 +1,8 @@ from flask import Blueprint +from .._params import extract_strings from .._query import execute_query, QueryBuilder -from .._validate import extract_strings, require_any +from .._validate import require_any # first argument is the endpoint name bp = Blueprint("covid_hosp_facility_lookup", __name__) @@ -33,7 +34,7 @@ def handle(): ] ) # basic query info - q.set_order("hospital_pk") + q.set_sort_order("hospital_pk") # build the filter # these are all fast because the table has indexes on each of these fields if state: diff --git a/src/server/endpoints/covid_hosp_state_timeseries.py b/src/server/endpoints/covid_hosp_state_timeseries.py index 5da4d4e16..a20e74d25 100644 --- a/src/server/endpoints/covid_hosp_state_timeseries.py +++ b/src/server/endpoints/covid_hosp_state_timeseries.py @@ -1,7 +1,8 @@ from flask import Blueprint +from .._params import extract_integers, extract_strings, extract_date from .._query import execute_query, QueryBuilder -from .._validate import extract_integers, extract_strings, extract_date, require_all +from .._validate import require_all # first argument is the endpoint name bp = Blueprint("covid_hosp_state_timeseries", __name__) @@ -145,7 +146,7 @@ def handle(): ] q.set_fields(fields_string, fields_int, fields_float) - q.set_order("date", "state", "issue") + q.set_sort_order("date", "state", "issue") # build the filter q.where_integers("date", dates) diff --git a/src/server/endpoints/covidcast.py b/src/server/endpoints/covidcast.py index 0c22e4573..05f7cfc3f 100644 --- a/src/server/endpoints/covidcast.py +++ b/src/server/endpoints/covidcast.py @@ -11,28 +11,26 @@ from .._common import is_compatibility_mode, db from .._exceptions import ValidationFailedException, DatabaseErrorException from .._params import ( - GeoPair, - SourceSignalPair, - TimePair, + GeoSet, + SourceSignalSet, + TimeSet, + extract_date, + extract_dates, + extract_integer, parse_geo_arg, parse_source_signal_arg, - parse_time_arg, parse_day_or_week_arg, parse_day_or_week_range_arg, parse_single_source_signal_arg, parse_single_time_arg, parse_single_geo_arg, + parse_geo_sets, + parse_source_signal_sets, + parse_time_set, ) from .._query import QueryBuilder, execute_query, run_query, parse_row, filter_fields from .._printer import create_printer, CSVPrinter -from .._validate import ( - extract_date, - extract_dates, - extract_integer, - extract_strings, - require_all, - require_any, -) +from .._validate import require_all from .._pandas import as_pandas, print_pandas from .covidcast_utils import compute_trend, compute_trends, compute_correlations, compute_trend_value, CovidcastMetaEntry from ..utils import shift_day_value, day_to_time_value, time_value_to_iso, time_value_to_day, shift_week_value, 
time_value_to_week, guess_time_value_is_day, week_to_time_value, TimeValues @@ -45,80 +43,12 @@ latest_table = "epimetric_latest_v" history_table = "epimetric_full_v" -def parse_source_signal_pairs() -> List[SourceSignalPair]: - ds = request.values.get("data_source") - if ds: - # old version - require_any("signal", "signals", empty=True) - signals = extract_strings(("signals", "signal")) - if len(signals) == 1 and signals[0] == "*": - return [SourceSignalPair(ds, True)] - return [SourceSignalPair(ds, signals)] - - if ":" not in request.values.get("signal", ""): - raise ValidationFailedException("missing parameter: signal or (data_source and signal[s])") - - return parse_source_signal_arg() - - -def parse_geo_pairs() -> List[GeoPair]: - geo_type = request.values.get("geo_type") - if geo_type: - # old version - require_any("geo_value", "geo_values", empty=True) - geo_values = extract_strings(("geo_values", "geo_value")) - if len(geo_values) == 1 and geo_values[0] == "*": - return [GeoPair(geo_type, True)] - return [GeoPair(geo_type, geo_values)] - - if ":" not in request.values.get("geo", ""): - raise ValidationFailedException("missing parameter: geo or (geo_type and geo_value[s])") - - return parse_geo_arg() - - -def parse_time_pairs() -> TimePair: - time_type = request.values.get("time_type") - if time_type: - # old version - require_all("time_type", "time_values") - time_values = extract_dates("time_values") - return TimePair(time_type, time_values) - - if ":" not in request.values.get("time", ""): - raise ValidationFailedException("missing parameter: time or (time_type and time_values)") - - return parse_time_arg() - - -def _handle_lag_issues_as_of(q: QueryBuilder, issues: Optional[TimeValues] = None, lag: Optional[int] = None, as_of: Optional[int] = None): - if issues: - q.retable(history_table) - q.where_integers("issue", issues) - elif lag is not None: - q.retable(history_table) - # history_table has full spectrum of lag values to search from whereas the latest_table does not - q.where(lag=lag) - elif as_of is not None: - # fetch the most recent issue as of a certain date (not to be confused w/ plain-old "most recent issue" - q.retable(history_table) - sub_condition_asof = "(issue <= :as_of)" - q.params["as_of"] = as_of - sub_fields = "max(issue) max_issue, time_type, time_value, `source`, `signal`, geo_type, geo_value" - sub_group = "time_type, time_value, `source`, `signal`, geo_type, geo_value" - sub_condition = f"x.max_issue = {q.alias}.issue AND x.time_type = {q.alias}.time_type AND x.time_value = {q.alias}.time_value AND x.source = {q.alias}.source AND x.signal = {q.alias}.signal AND x.geo_type = {q.alias}.geo_type AND x.geo_value = {q.alias}.geo_value" - q.subquery = f"JOIN (SELECT {sub_fields} FROM {q.table} WHERE {q.conditions_clause} AND {sub_condition_asof} GROUP BY {sub_group}) x ON {sub_condition}" - else: - # else we are using the (standard/default) `latest_table`, to fetch the most recent issue quickly - pass - - @bp.route("/", methods=("GET", "POST")) def handle(): - source_signal_pairs = parse_source_signal_pairs() - source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs) - time_pair = parse_time_pairs() - geo_pairs = parse_geo_pairs() + source_signal_sets = parse_source_signal_sets() + source_signal_sets, alias_mapper = create_source_signal_alias_mapper(source_signal_sets) + time_set = parse_time_set() + geo_sets = parse_geo_sets() as_of = extract_date("as_of") issues = extract_dates("issues") @@ -132,22 +62,24 @@ def handle(): 
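
As a reading aid for this covidcast.py refactor: the hand-rolled `_handle_lag_issues_as_of` helper removed above is replaced by dedicated `QueryBuilder` methods (`apply_issues_filter`, `apply_lag_filter`, `apply_as_of_filter`, used in the hunks that follow), and all of them switch the query onto the full-history table `epimetric_full_v`, since only that table carries every issue. The snippet below is a self-contained sketch of the "as of" rule the old helper expressed in SQL: for each (source, signal, geo, time) key, keep only the newest issue published on or before the requested date, via a grouped self-join. The function name `build_as_of_join` and the standalone structure are illustrative, not the new `QueryBuilder` API.

```python
# Illustrative sketch of the "most recent issue as of a date" join; not the
# actual QueryBuilder.apply_as_of_filter implementation.
from typing import Dict, Tuple

KEY_COLUMNS = ["time_type", "time_value", "source", "signal", "geo_type", "geo_value"]


def build_as_of_join(table: str, alias: str, conditions_clause: str, as_of: int) -> Tuple[str, Dict[str, int]]:
    """Return (join_sql, bound_params) restricting rows to the latest issue <= :as_of."""
    params = {"as_of": as_of}
    sub_fields = "max(issue) max_issue, " + ", ".join(KEY_COLUMNS)
    sub_group = ", ".join(KEY_COLUMNS)
    on_clause = " AND ".join(f"x.{col} = {alias}.{col}" for col in KEY_COLUMNS)
    join_sql = (
        f"JOIN (SELECT {sub_fields} FROM {table} "
        f"WHERE {conditions_clause} AND (issue <= :as_of) "
        f"GROUP BY {sub_group}) x "
        f"ON x.max_issue = {alias}.issue AND {on_clause}"
    )
    return join_sql, params


if __name__ == "__main__":
    sql, params = build_as_of_join("epimetric_full_v", "t", "t.time_type = :tt", 20201215)
    print(sql)
    print(params)
```
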
fields_float = ["value", "stderr", "sample_size"] is_compatibility = is_compatibility_mode() if is_compatibility: - q.set_order("signal", "time_value", "geo_value", "issue") + q.set_sort_order("signal", "time_value", "geo_value", "issue") else: # transfer also the new detail columns fields_string.extend(["source", "geo_type", "time_type"]) - q.set_order("source", "signal", "time_type", "time_value", "geo_type", "geo_value", "issue") + q.set_sort_order("source", "signal", "time_type", "time_value", "geo_type", "geo_value", "issue") q.set_fields(fields_string, fields_int, fields_float) # basic query info # data type of each field # build the source, signal, time, and location (type and id) filters - q.where_source_signal_pairs("source", "signal", source_signal_pairs) - q.where_geo_pairs("geo_type", "geo_value", geo_pairs) - q.where_time_pair("time_type", "time_value", time_pair) + q.apply_source_signal_filters("source", "signal", source_signal_sets) + q.apply_geo_filters("geo_type", "geo_value", geo_sets) + q.apply_time_filter("time_type", "time_value", time_set) - _handle_lag_issues_as_of(q, issues, lag, as_of) + q.apply_issues_filter(history_table, issues) + q.apply_lag_filter(history_table, lag) + q.apply_as_of_filter(history_table, as_of) def transform_row(row, proxy): if is_compatibility or not alias_mapper or "source" not in row: @@ -169,15 +101,15 @@ def _verify_argument_time_type_matches(is_day_argument: bool, count_daily_signal @bp.route("/trend", methods=("GET", "POST")) def handle_trend(): require_all("window", "date") - source_signal_pairs = parse_source_signal_pairs() - daily_signals, weekly_signals = count_signal_time_types(source_signal_pairs) - source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs) - geo_pairs = parse_geo_pairs() + source_signal_sets = parse_source_signal_sets() + daily_signals, weekly_signals = count_signal_time_types(source_signal_sets) + source_signal_sets, alias_mapper = create_source_signal_alias_mapper(source_signal_sets) + geo_sets = parse_geo_sets() time_window = parse_day_or_week_range_arg("window") is_day = time_window.is_day - time_pair = parse_day_or_week_arg("date") - time_value, is_also_day = time_pair.time_values[0], time_pair.is_day + time_set = parse_day_or_week_arg("date") + time_value, is_also_day = time_set.time_values[0], time_set.is_day if is_day != is_also_day: raise ValidationFailedException("mixing weeks with day arguments") _verify_argument_time_type_matches(is_day, daily_signals, weekly_signals) @@ -195,14 +127,11 @@ def handle_trend(): fields_int = ["time_value"] fields_float = ["value"] q.set_fields(fields_string, fields_int, fields_float) - q.set_order("geo_type", "geo_value", "source", "signal", "time_value") + q.set_sort_order("geo_type", "geo_value", "source", "signal", "time_value") - q.where_source_signal_pairs("source", "signal", source_signal_pairs) - q.where_geo_pairs("geo_type", "geo_value", geo_pairs) - q.where_time_pair("time_type", "time_value", time_window) - - # fetch most recent issue fast - _handle_lag_issues_as_of(q, None, None, None) + q.apply_source_signal_filters("source", "signal", source_signal_sets) + q.apply_geo_filters("geo_type", "geo_value", geo_sets) + q.apply_time_filter("time_type", "time_value", time_window) p = create_printer() @@ -227,10 +156,10 @@ def gen(rows): @bp.route("/trendseries", methods=("GET", "POST")) def handle_trendseries(): require_all("window") - source_signal_pairs = parse_source_signal_pairs() - daily_signals, weekly_signals = 
count_signal_time_types(source_signal_pairs) - source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs) - geo_pairs = parse_geo_pairs() + source_signal_sets = parse_source_signal_sets() + daily_signals, weekly_signals = count_signal_time_types(source_signal_sets) + source_signal_sets, alias_mapper = create_source_signal_alias_mapper(source_signal_sets) + geo_sets = parse_geo_sets() time_window = parse_day_or_week_range_arg("window") is_day = time_window.is_day @@ -246,14 +175,11 @@ def handle_trendseries(): fields_int = ["time_value"] fields_float = ["value"] q.set_fields(fields_string, fields_int, fields_float) - q.set_order("geo_type", "geo_value", "source", "signal", "time_value") - - q.where_source_signal_pairs("source", "signal", source_signal_pairs) - q.where_geo_pairs("geo_type", "geo_value", geo_pairs) - q.where_time_pair("time_type", "time_value", time_window) + q.set_sort_order("geo_type", "geo_value", "source", "signal", "time_value") - # fetch most recent issue fast - _handle_lag_issues_as_of(q, None, None, None) + q.apply_source_signal_filters("source", "signal", source_signal_sets) + q.apply_geo_filters("geo_type", "geo_value", geo_sets) + q.apply_time_filter("time_type", "time_value", time_window) p = create_printer() @@ -284,10 +210,10 @@ def gen(rows): def handle_correlation(): require_all("reference", "window", "others", "geo") reference = parse_single_source_signal_arg("reference") - other_pairs = parse_source_signal_arg("others") - daily_signals, weekly_signals = count_signal_time_types(other_pairs + [reference]) - source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(other_pairs + [reference]) - geo_pairs = parse_geo_arg() + other_sets = parse_source_signal_arg("others") + daily_signals, weekly_signals = count_signal_time_types(other_sets + [reference]) + source_signal_sets, alias_mapper = create_source_signal_alias_mapper(other_sets + [reference]) + geo_sets = parse_geo_arg() time_window = parse_day_or_week_range_arg("window") is_day = time_window.is_day _verify_argument_time_type_matches(is_day, daily_signals, weekly_signals) @@ -303,15 +229,15 @@ def handle_correlation(): fields_int = ["time_value"] fields_float = ["value"] q.set_fields(fields_string, fields_int, fields_float) - q.set_order("geo_type", "geo_value", "source", "signal", "time_value") + q.set_sort_order("geo_type", "geo_value", "source", "signal", "time_value") - q.where_source_signal_pairs( + q.apply_source_signal_filters( "source", "signal", - source_signal_pairs, + source_signal_sets, ) - q.where_geo_pairs("geo_type", "geo_value", geo_pairs) - q.where_time_pair("time_type", "time_value", time_window) + q.apply_geo_filters("geo_type", "geo_value", geo_sets) + q.apply_time_filter("time_type", "time_value", time_window) df = as_pandas(str(q), q.params) if is_day: @@ -356,13 +282,13 @@ def gen(): @bp.route("/csv", methods=("GET", "POST")) def handle_export(): source, signal = request.values.get("signal", "jhu-csse:confirmed_incidence_num").split(":") - source_signal_pairs = [SourceSignalPair(source, [signal])] - daily_signals, weekly_signals = count_signal_time_types(source_signal_pairs) - source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs) - start_pair = parse_day_or_week_arg("start_day", 202001 if weekly_signals > 0 else 20200401) - start_day, is_day = start_pair.time_values[0], start_pair.is_day - end_pair = parse_day_or_week_arg("end_day", 202020 if weekly_signals > 0 else 20200901) - end_day, is_end_day = 
end_pair.time_values[0], end_pair.is_day + source_signal_sets = [SourceSignalSet(source, [signal])] + daily_signals, weekly_signals = count_signal_time_types(source_signal_sets) + source_signal_sets, alias_mapper = create_source_signal_alias_mapper(source_signal_sets) + start_time_set = parse_day_or_week_arg("start_day", 202001 if weekly_signals > 0 else 20200401) + start_day, is_day = start_time_set.time_values[0], start_time_set.is_day + end_time_set = parse_day_or_week_arg("end_day", 202020 if weekly_signals > 0 else 20200901) + end_day, is_end_day = end_time_set.time_values[0], end_time_set.is_day if is_day != is_end_day: raise ValidationFailedException("mixing weeks with day arguments") _verify_argument_time_type_matches(is_day, daily_signals, weekly_signals) @@ -381,12 +307,13 @@ def handle_export(): q = QueryBuilder(latest_table, "t") q.set_fields(["geo_value", "signal", "time_value", "issue", "lag", "value", "stderr", "sample_size", "geo_type", "source"], [], []) - q.set_order("time_value", "geo_value") - q.where_source_signal_pairs("source", "signal", source_signal_pairs) - q.where_time_pair("time_type", "time_value", TimePair("day" if is_day else "week", [(start_day, end_day)])) - q.where_geo_pairs("geo_type", "geo_value", [GeoPair(geo_type, True if geo_values == "*" else geo_values)]) - _handle_lag_issues_as_of(q, None, None, as_of) + q.set_sort_order("time_value", "geo_value") + q.apply_source_signal_filters("source", "signal", source_signal_sets) + q.apply_time_filter("time_type", "time_value", TimeSet("day" if is_day else "week", [(start_day, end_day)])) + q.apply_geo_filters("geo_type", "geo_value", [GeoSet(geo_type, True if geo_values == "*" else geo_values)]) + + q.apply_as_of_filter(history_table, as_of) format_date = time_value_to_iso if is_day else lambda x: time_value_to_week(x).cdcformat() # tag as_of in filename, if it was specified @@ -438,16 +365,16 @@ def handle_backfill(): example query: http://localhost:5000/covidcast/backfill?signal=fb-survey:smoothed_cli&time=day:20200101-20220101&geo=state:ny&anchor_lag=60 """ require_all("geo", "time", "signal") - signal_pair = parse_single_source_signal_arg("signal") - daily_signals, weekly_signals = count_signal_time_types([signal_pair]) - source_signal_pairs, _ = create_source_signal_alias_mapper([signal_pair]) + source_signal_set = parse_single_source_signal_arg("signal") + daily_signals, weekly_signals = count_signal_time_types([source_signal_set]) + source_signal_sets, _ = create_source_signal_alias_mapper([source_signal_set]) # don't need the alias mapper since we don't return the source - time_pair = parse_single_time_arg("time") - is_day = time_pair.is_day + time_set = parse_single_time_arg("time") + is_day = time_set.is_day _verify_argument_time_type_matches(is_day, daily_signals, weekly_signals) - geo_pair = parse_single_geo_arg("geo") + geo_set = parse_single_geo_arg("geo") reference_anchor_lag = extract_integer("anchor_lag") # in days or weeks if reference_anchor_lag is None: reference_anchor_lag = 60 @@ -459,15 +386,12 @@ def handle_backfill(): fields_int = ["time_value", "issue"] fields_float = ["value", "sample_size"] # sort by time value and issue asc - q.set_order(time_value=True, issue=True) + q.set_sort_order("time_value", "issue") q.set_fields(fields_string, fields_int, fields_float, ["is_latest_issue"]) - q.where_source_signal_pairs("source", "signal", source_signal_pairs) - q.where_geo_pairs("geo_type", "geo_value", [geo_pair]) - q.where_time_pair("time_type", "time_value", time_pair) - - # no 
restriction of issues or dates since we want all issues - # _handle_lag_issues_as_of(q, issues, lag, as_of) + q.apply_source_signal_filters("source", "signal", source_signal_sets) + q.apply_geo_filters("geo_type", "geo_value", [geo_set]) + q.apply_time_filter("time_type", "time_value", time_set) p = create_printer() @@ -596,9 +520,9 @@ def handle_coverage(): similar to /signal_dashboard_coverage for a specific signal returns the coverage (number of locations for a given geo_type) """ - source_signal_pairs = parse_source_signal_pairs() - daily_signals, weekly_signals = count_signal_time_types(source_signal_pairs) - source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs) + source_signal_sets = parse_source_signal_sets() + daily_signals, weekly_signals = count_signal_time_types(source_signal_sets) + source_signal_sets, alias_mapper = create_source_signal_alias_mapper(source_signal_sets) geo_type = request.values.get("geo_type", "county") if "window" in request.values: @@ -615,13 +539,13 @@ def handle_coverage(): last_weeks = last or 30 is_day = False now_week = Week.thisweek() if now_time is None else time_value_to_week(now_time) - time_window = TimePair("week", [(week_to_time_value(now_week - last_weeks), week_to_time_value(now_week))]) + time_window = TimeSet("week", [(week_to_time_value(now_week - last_weeks), week_to_time_value(now_week))]) else: is_day = True if last is None: last = 30 now = date.today() if now_time is None else time_value_to_day(now_time) - time_window = TimePair("day", [(day_to_time_value(now - timedelta(days=last)), day_to_time_value(now))]) + time_window = TimeSet("day", [(day_to_time_value(now - timedelta(days=last)), day_to_time_value(now))]) _verify_argument_time_type_matches(is_day, daily_signals, weekly_signals) q = QueryBuilder(latest_table, "c") @@ -639,12 +563,10 @@ def handle_coverage(): q.conditions.append('geo_value not like "%000"') else: q.where(geo_type=geo_type) - q.where_source_signal_pairs("source", "signal", source_signal_pairs) - q.where_time_pair("time_type", "time_value", time_window) + q.apply_source_signal_filters("source", "signal", source_signal_sets) + q.apply_time_filter("time_type", "time_value", time_window) q.group_by = "c.source, c.signal, c.time_value" - q.set_order("source", "signal", "time_value") - - _handle_lag_issues_as_of(q, None, None, None) + q.set_sort_order("source", "signal", "time_value") def transform_row(row, proxy): if not alias_mapper or "source" not in row: diff --git a/src/server/endpoints/covidcast_meta.py b/src/server/endpoints/covidcast_meta.py index 08e919d24..86eeb8b64 100644 --- a/src/server/endpoints/covidcast_meta.py +++ b/src/server/endpoints/covidcast_meta.py @@ -6,9 +6,9 @@ from sqlalchemy import text from .._common import db +from .._params import extract_strings from .._printer import create_printer from .._query import filter_fields -from .._validate import extract_strings from ..utils.logger import get_structured_logger bp = Blueprint("covidcast_meta", __name__) diff --git a/src/server/endpoints/covidcast_nowcast.py b/src/server/endpoints/covidcast_nowcast.py index ae47259f8..d71ff9404 100644 --- a/src/server/endpoints/covidcast_nowcast.py +++ b/src/server/endpoints/covidcast_nowcast.py @@ -1,11 +1,13 @@ from flask import Blueprint, request -from .._query import execute_query, filter_integers, filter_strings -from .._validate import ( +from .._params import ( extract_date, extract_dates, extract_integer, - extract_strings, + extract_strings +) +from .._query import 
execute_query, filter_integers, filter_strings +from .._validate import ( require_all, require_any, ) diff --git a/src/server/endpoints/covidcast_utils/model.py b/src/server/endpoints/covidcast_utils/model.py index 154bb3668..abab0033b 100644 --- a/src/server/endpoints/covidcast_utils/model.py +++ b/src/server/endpoints/covidcast_utils/model.py @@ -6,7 +6,7 @@ import pandas as pd import numpy as np -from ..._params import SourceSignalPair +from ..._params import SourceSignalSet class HighValuesAre(str, Enum): @@ -236,23 +236,18 @@ def _load_data_signals(sources: List[DataSource]): data_signals_by_key[(source.db_source, d.signal)] = d - -def get_related_signals(signal: DataSignal) -> List[DataSignal]: - return [s for s in data_signals if s != signal and s.signal_basename == signal.signal_basename] - - -def count_signal_time_types(source_signals: List[SourceSignalPair]) -> Tuple[int, int]: +def count_signal_time_types(source_signals: List[SourceSignalSet]) -> Tuple[int, int]: """ count the number of signals in this query for each time type @returns daily counts, weekly counts """ weekly = 0 daily = 0 - for pair in source_signals: - if pair.signal == True: + for ssset in source_signals: + if ssset.signal == True: continue - for s in pair.signal: - signal = data_signals_by_key.get((pair.source, s)) + for s in ssset.signal: + signal = data_signals_by_key.get((ssset.source, s)) if not signal: continue if signal.time_type == TimeType.week: @@ -262,21 +257,21 @@ def count_signal_time_types(source_signals: List[SourceSignalPair]) -> Tuple[int return daily, weekly -def create_source_signal_alias_mapper(source_signals: List[SourceSignalPair]) -> Tuple[List[SourceSignalPair], Optional[Callable[[str, str], str]]]: +def create_source_signal_alias_mapper(source_signals: List[SourceSignalSet]) -> Tuple[List[SourceSignalSet], Optional[Callable[[str, str], str]]]: alias_to_data_sources: Dict[str, List[DataSource]] = {} - transformed_pairs: List[SourceSignalPair] = [] - for pair in source_signals: - source = data_source_by_id.get(pair.source) + transformed_sets: List[SourceSignalSet] = [] + for ssset in source_signals: + source = data_source_by_id.get(ssset.source) if not source or not source.uses_db_alias: - transformed_pairs.append(pair) + transformed_sets.append(ssset) continue # uses an alias alias_to_data_sources.setdefault(source.db_source, []).append(source) - if pair.signal is True: + if ssset.signal is True: # list all signals of this source (*) so resolve to a plain list of all in this alias - transformed_pairs.append(SourceSignalPair(source.db_source, [s.signal for s in source.signals])) + transformed_sets.append(SourceSignalSet(source.db_source, [s.signal for s in source.signals])) else: - transformed_pairs.append(SourceSignalPair(source.db_source, pair.signal)) + transformed_sets.append(SourceSignalSet(source.db_source, ssset.signal)) if not alias_to_data_sources: # no alias needed @@ -299,4 +294,4 @@ def map_row(source: str, signal: str) -> str: signal_source = possible_data_sources[0] return signal_source.source - return transformed_pairs, map_row + return transformed_sets, map_row diff --git a/src/server/endpoints/dengue_nowcast.py b/src/server/endpoints/dengue_nowcast.py index 206d4dff0..f77f6bd18 100644 --- a/src/server/endpoints/dengue_nowcast.py +++ b/src/server/endpoints/dengue_nowcast.py @@ -1,7 +1,8 @@ from flask import Blueprint +from .._params import extract_integers, extract_strings from .._query import execute_query, QueryBuilder -from .._validate import extract_integers, 
extract_strings, require_all +from .._validate import require_all # first argument is the endpoint name bp = Blueprint("dengue_nowcast", __name__) @@ -22,7 +23,7 @@ def handle(): fields_float = ["value", "std"] q.set_fields(fields_string, fields_int, fields_float) - q.set_order("epiweek", "location") + q.set_sort_order("epiweek", "location") # build the filter q.where_strings("location", locations) diff --git a/src/server/endpoints/dengue_sensors.py b/src/server/endpoints/dengue_sensors.py index df3672209..0837dc3fc 100644 --- a/src/server/endpoints/dengue_sensors.py +++ b/src/server/endpoints/dengue_sensors.py @@ -1,8 +1,9 @@ from flask import Blueprint from .._config import AUTH +from .._params import extract_integers, extract_strings from .._query import execute_query, QueryBuilder -from .._validate import check_auth_token, extract_integers, extract_strings, require_all +from .._validate import check_auth_token, require_all # first argument is the endpoint name bp = Blueprint("dengue_sensors", __name__) @@ -26,7 +27,7 @@ def handle(): fields_float = ["value"] q.set_fields(fields_string, fields_int, fields_float) - q.set_order('epiweek', 'name', 'location') + q.set_sort_order('epiweek', 'name', 'location') q.where_strings('name', names) q.where_strings('location', locations) diff --git a/src/server/endpoints/ecdc_ili.py b/src/server/endpoints/ecdc_ili.py index 75b253b1e..b15dc7cb2 100644 --- a/src/server/endpoints/ecdc_ili.py +++ b/src/server/endpoints/ecdc_ili.py @@ -1,7 +1,8 @@ from flask import Blueprint +from .._params import extract_integer, extract_integers, extract_strings from .._query import execute_query, QueryBuilder -from .._validate import extract_integer, extract_integers, extract_strings, require_all +from .._validate import require_all # first argument is the endpoint name bp = Blueprint("ecdc_ili", __name__) @@ -24,7 +25,7 @@ def handle(): fields_float = ["incidence_rate"] q.set_fields(fields_string, fields_int, fields_float) - q.set_order("epiweek", "region", "issue") + q.set_sort_order("epiweek", "region", "issue") q.where_integers("epiweek", epiweeks) q.where_strings("region", regions) diff --git a/src/server/endpoints/flusurv.py b/src/server/endpoints/flusurv.py index 5205056db..67e842cb8 100644 --- a/src/server/endpoints/flusurv.py +++ b/src/server/endpoints/flusurv.py @@ -1,7 +1,8 @@ from flask import Blueprint +from .._params import extract_integer, extract_integers, extract_strings from .._query import execute_query, QueryBuilder -from .._validate import extract_integer, extract_integers, extract_strings, require_all +from .._validate import require_all bp = Blueprint("flusurv", __name__) @@ -29,7 +30,7 @@ def handle(): "rate_overall", ] q.set_fields(fields_string, fields_int, fields_float) - q.set_order("epiweek", "location", "issue") + q.set_sort_order("epiweek", "location", "issue") q.where_integers("epiweek", epiweeks) q.where_strings("location", locations) diff --git a/src/server/endpoints/fluview.py b/src/server/endpoints/fluview.py index 8b92fa052..75e928c86 100644 --- a/src/server/endpoints/fluview.py +++ b/src/server/endpoints/fluview.py @@ -3,12 +3,14 @@ from flask import Blueprint from .._config import AUTH -from .._query import execute_queries, filter_integers, filter_strings -from .._validate import ( - check_auth_token, +from .._params import ( extract_integer, extract_integers, extract_strings, +) +from .._query import execute_queries, filter_integers, filter_strings +from .._validate import ( + check_auth_token, require_all, ) diff --git 
a/src/server/endpoints/fluview_clinicial.py b/src/server/endpoints/fluview_clinicial.py index 650ec9add..e213a1638 100644 --- a/src/server/endpoints/fluview_clinicial.py +++ b/src/server/endpoints/fluview_clinicial.py @@ -1,7 +1,8 @@ from flask import Blueprint +from .._params import extract_integer, extract_integers, extract_strings from .._query import execute_query, QueryBuilder -from .._validate import extract_integer, extract_integers, extract_strings, require_all +from .._validate import require_all bp = Blueprint("fluview_clinical", __name__) @@ -22,7 +23,7 @@ def handle(): fields_int = ["issue", "epiweek", "lag", "total_specimens", "total_a", "total_b"] fields_float = ["percent_positive", "percent_a", "percent_b"] q.set_fields(fields_string, fields_int, fields_float) - q.set_order("epiweek", "region", "issue") + q.set_sort_order("epiweek", "region", "issue") q.where_integers("epiweek", epiweeks) q.where_strings("region", regions) diff --git a/src/server/endpoints/gft.py b/src/server/endpoints/gft.py index 8179b3522..343f565f4 100644 --- a/src/server/endpoints/gft.py +++ b/src/server/endpoints/gft.py @@ -1,7 +1,8 @@ from flask import Blueprint +from .._params import extract_integers, extract_strings from .._query import execute_query, QueryBuilder -from .._validate import extract_integers, extract_strings, require_all +from .._validate import require_all # first argument is the endpoint name bp = Blueprint("gft", __name__) @@ -22,7 +23,7 @@ def handle(): fields_int = ["epiweek", "num"] fields_float = [] q.set_fields(fields_string, fields_int, fields_float) - q.set_order("epiweek", "location") + q.set_sort_order("epiweek", "location") # build the filter q.where_integers("epiweek", epiweeks) diff --git a/src/server/endpoints/ght.py b/src/server/endpoints/ght.py index 3d5c0dec1..ab858e79c 100644 --- a/src/server/endpoints/ght.py +++ b/src/server/endpoints/ght.py @@ -1,8 +1,9 @@ from flask import Blueprint, request from .._config import AUTH +from .._params import extract_integers, extract_strings from .._query import execute_query, QueryBuilder -from .._validate import check_auth_token, extract_integers, extract_strings, require_all +from .._validate import check_auth_token, require_all # first argument is the endpoint name bp = Blueprint("ght", __name__) @@ -26,7 +27,7 @@ def handle(): fields_float = ["value"] q.set_fields(fields_string, fields_int, fields_float) - q.set_order("epiweek", "location") + q.set_sort_order("epiweek", "location") # build the filter q.where_strings("location", locations) diff --git a/src/server/endpoints/kcdc_ili.py b/src/server/endpoints/kcdc_ili.py index 08158cdaf..fc9328898 100644 --- a/src/server/endpoints/kcdc_ili.py +++ b/src/server/endpoints/kcdc_ili.py @@ -1,7 +1,8 @@ from flask import Blueprint +from .._params import extract_integer, extract_integers, extract_strings from .._query import execute_query, QueryBuilder -from .._validate import extract_integer, extract_integers, extract_strings, require_all +from .._validate import require_all # first argument is the endpoint name bp = Blueprint("kcdc_ili", __name__) @@ -24,7 +25,7 @@ def handle(): fields_float = ["ili"] q.set_fields(fields_string, fields_int, fields_float) - q.set_order("epiweek", "region", "issue") + q.set_sort_order("epiweek", "region", "issue") # build the filter q.where_integers("epiweek", epiweeks) q.where_strings("region", regions) diff --git a/src/server/endpoints/nidss_dengue.py b/src/server/endpoints/nidss_dengue.py index 131f6eb9a..8d7c12624 100644 --- 
a/src/server/endpoints/nidss_dengue.py +++ b/src/server/endpoints/nidss_dengue.py @@ -2,8 +2,9 @@ from flask import Blueprint +from .._params import extract_integers, extract_strings from .._query import execute_queries, filter_integers -from .._validate import extract_integers, extract_strings, require_all +from .._validate import require_all # first argument is the endpoint name bp = Blueprint("nidss_dengue", __name__) diff --git a/src/server/endpoints/nidss_flu.py b/src/server/endpoints/nidss_flu.py index 3caf099dc..8eb7d3b56 100644 --- a/src/server/endpoints/nidss_flu.py +++ b/src/server/endpoints/nidss_flu.py @@ -1,7 +1,8 @@ from flask import Blueprint +from .._params import extract_integer, extract_integers, extract_strings from .._query import execute_query, QueryBuilder -from .._validate import extract_integer, extract_integers, extract_strings, require_all +from .._validate import require_all # first argument is the endpoint name bp = Blueprint("nidss_flu", __name__) @@ -23,7 +24,7 @@ def handle(): fields_int = ["issue", "epiweek", "lag", "visits"] fields_float = ["ili"] q.set_fields(fields_string, fields_int, fields_float) - q.set_order("epiweek", "region", "issue") + q.set_sort_order("epiweek", "region", "issue") # build the filter q.where_integers("epiweek", epiweeks) diff --git a/src/server/endpoints/norostat.py b/src/server/endpoints/norostat.py index 9586f8e3f..24867a8d4 100644 --- a/src/server/endpoints/norostat.py +++ b/src/server/endpoints/norostat.py @@ -1,8 +1,9 @@ from flask import Blueprint, request from .._config import AUTH +from .._params import extract_integers from .._query import execute_query, filter_integers, filter_strings -from .._validate import check_auth_token, extract_integers, require_all +from .._validate import check_auth_token, require_all # first argument is the endpoint name bp = Blueprint("norostat", __name__) diff --git a/src/server/endpoints/nowcast.py b/src/server/endpoints/nowcast.py index 88ee83400..77c535ee6 100644 --- a/src/server/endpoints/nowcast.py +++ b/src/server/endpoints/nowcast.py @@ -1,7 +1,8 @@ from flask import Blueprint +from .._params import extract_integers, extract_strings from .._query import execute_query, QueryBuilder -from .._validate import extract_integers, extract_strings, require_all +from .._validate import require_all # first argument is the endpoint name bp = Blueprint("nowcast", __name__) @@ -22,7 +23,7 @@ def handle(): fields_float = ["value", "std"] q.set_fields(fields_string, fields_int, fields_float) - q.set_order(epiweek=True, location=True) + q.set_sort_order("epiweek", "location") # build the filter q.where_strings("location", locations) diff --git a/src/server/endpoints/paho_dengue.py b/src/server/endpoints/paho_dengue.py index 3afd11a6f..e793a7c17 100644 --- a/src/server/endpoints/paho_dengue.py +++ b/src/server/endpoints/paho_dengue.py @@ -1,7 +1,8 @@ from flask import Blueprint +from .._params import extract_integer, extract_integers, extract_strings from .._query import execute_query, QueryBuilder -from .._validate import extract_integer, extract_integers, extract_strings, require_all +from .._validate import require_all # first argument is the endpoint name bp = Blueprint("paho_dengue", __name__) @@ -32,7 +33,7 @@ def handle(): fields_float = ["incidence_rate"] q.set_fields(fields_string, fields_int, fields_float) - q.set_order(epiweek=True, region=True, issue=True) + q.set_sort_order("epiweek", "region", "issue") # build the filter q.where_integers("epiweek", epiweeks) diff --git 
a/src/server/endpoints/quidel.py b/src/server/endpoints/quidel.py index c32a8a040..081706190 100644 --- a/src/server/endpoints/quidel.py +++ b/src/server/endpoints/quidel.py @@ -1,8 +1,9 @@ from flask import Blueprint from .._config import AUTH +from .._params import extract_integers, extract_strings from .._query import execute_query, QueryBuilder -from .._validate import check_auth_token, extract_integers, extract_strings, require_all +from .._validate import check_auth_token, require_all # first argument is the endpoint name bp = Blueprint("quidel", __name__) @@ -25,7 +26,7 @@ def handle(): fields_float = ["value"] q.set_fields(fields_string, fields_int, fields_float) - q.set_order(epiweek=True, location=True) + q.set_sort_order("epiweek", "location") # build the filter q.where_strings("location", locations) diff --git a/src/server/endpoints/sensors.py b/src/server/endpoints/sensors.py index 68199e2b1..f803dd396 100644 --- a/src/server/endpoints/sensors.py +++ b/src/server/endpoints/sensors.py @@ -1,14 +1,16 @@ from flask import Blueprint from .._config import AUTH, GRANULAR_SENSOR_AUTH_TOKENS, OPEN_SENSORS -from .._validate import ( - require_all, +from .._exceptions import EpiDataException +from .._params import ( extract_strings, extract_integers, - resolve_auth_token, ) from .._query import filter_strings, execute_query, filter_integers -from .._exceptions import EpiDataException +from .._validate import ( + require_all, + resolve_auth_token, +) from typing import List # first argument is the endpoint name diff --git a/src/server/endpoints/twitter.py b/src/server/endpoints/twitter.py index 78b297ef8..41cbe3492 100644 --- a/src/server/endpoints/twitter.py +++ b/src/server/endpoints/twitter.py @@ -1,11 +1,13 @@ from flask import Blueprint, request from .._config import AUTH, NATION_REGION, REGION_TO_STATE +from .._params import ( + extract_integers, + extract_strings, +) from .._query import execute_queries, filter_dates, filter_integers, filter_strings from .._validate import ( check_auth_token, - extract_integers, - extract_strings, require_all, require_any, ) diff --git a/src/server/endpoints/wiki.py b/src/server/endpoints/wiki.py index a6bfcb27f..61139578f 100644 --- a/src/server/endpoints/wiki.py +++ b/src/server/endpoints/wiki.py @@ -1,7 +1,8 @@ from flask import Blueprint, request +from .._params import extract_integers, extract_strings from .._query import execute_query, filter_dates, filter_integers, filter_strings -from .._validate import extract_integers, extract_strings, require_all, require_any +from .._validate import require_all, require_any # first argument is the endpoint name bp = Blueprint("wiki", __name__) diff --git a/src/server/utils/__init__.py b/src/server/utils/__init__.py index efab6c030..2e99dfeba 100644 --- a/src/server/utils/__init__.py +++ b/src/server/utils/__init__.py @@ -1 +1 @@ -from .dates import shift_day_value, day_to_time_value, time_value_to_iso, time_value_to_day, days_in_range, weeks_in_range, shift_week_value, week_to_time_value, time_value_to_week, guess_time_value_is_day, guess_time_value_is_week, time_values_to_ranges, days_to_ranges, weeks_to_ranges, TimeValues +from .dates import shift_day_value, day_to_time_value, time_value_to_iso, time_value_to_day, days_in_range, weeks_in_range, shift_week_value, week_to_time_value, time_value_to_week, guess_time_value_is_day, guess_time_value_is_week, time_values_to_ranges, days_to_ranges, weeks_to_ranges, IntRange, TimeValues diff --git a/src/server/utils/dates.py b/src/server/utils/dates.py index 
b85465bb8..126f79383 100644 --- a/src/server/utils/dates.py +++ b/src/server/utils/dates.py @@ -13,7 +13,8 @@ from .logger import get_structured_logger # Alias for a sequence of date ranges (int, int) or date integers -TimeValues: TypeAlias = Sequence[Union[Tuple[int, int], int]] +IntRange: TypeAlias = Union[Tuple[int, int], int] +TimeValues: TypeAlias = Sequence[IntRange] def time_value_to_day(value: int) -> date: year, month, day = value // 10000, (value % 10000) // 100, value % 100 diff --git a/tests/acquisition/covid_hosp/common/test_database.py b/tests/acquisition/covid_hosp/common/test_database.py index 12f834032..09244dd2f 100644 --- a/tests/acquisition/covid_hosp/common/test_database.py +++ b/tests/acquisition/covid_hosp/common/test_database.py @@ -68,7 +68,7 @@ def test_contains_revision(self): mock_connection = MagicMock() mock_cursor = mock_connection.cursor() - database = Database(mock_connection, table_name=sentinel.table_name) + database = Database(mock_connection, table_name=sentinel.table_name, hhs_dataset_id=sentinel.hhs_dataset_id) with self.subTest(name='new revision'): mock_cursor.__iter__.return_value = [(0,)] @@ -78,7 +78,7 @@ def test_contains_revision(self): # compare with boolean literal to test the type cast self.assertIs(result, False) query_values = mock_cursor.execute.call_args[0][-1] - self.assertEqual(query_values, (sentinel.table_name, sentinel.revision)) + self.assertEqual(query_values, (sentinel.hhs_dataset_id, sentinel.revision)) with self.subTest(name='old revision'): mock_cursor.__iter__.return_value = [(1,)] @@ -88,7 +88,7 @@ def test_contains_revision(self): # compare with boolean literal to test the type cast self.assertIs(result, True) query_values = mock_cursor.execute.call_args[0][-1] - self.assertEqual(query_values, (sentinel.table_name, sentinel.revision)) + self.assertEqual(query_values, (sentinel.hhs_dataset_id, sentinel.revision)) def test_insert_metadata(self): """Add new metadata to the database.""" @@ -98,7 +98,7 @@ def test_insert_metadata(self): mock_connection = MagicMock() mock_cursor = mock_connection.cursor() - database = Database(mock_connection, table_name=sentinel.dataset_name) + database = Database(mock_connection, table_name=sentinel.table_name, hhs_dataset_id=sentinel.hhs_dataset_id) result = database.insert_metadata( sentinel.publication_date, @@ -108,7 +108,8 @@ def test_insert_metadata(self): self.assertIsNone(result) actual_values = mock_cursor.execute.call_args[0][-1] expected_values = ( - sentinel.dataset_name, + sentinel.table_name, + sentinel.hhs_dataset_id, sentinel.publication_date, sentinel.revision, sentinel.meta_json, diff --git a/tests/acquisition/covid_hosp/common/test_utils.py b/tests/acquisition/covid_hosp/common/test_utils.py index 1284e1a87..85dbd110c 100644 --- a/tests/acquisition/covid_hosp/common/test_utils.py +++ b/tests/acquisition/covid_hosp/common/test_utils.py @@ -129,7 +129,10 @@ def test_run_acquire_new_dataset(self): self.assertTrue(result) - mock_connection.insert_metadata.assert_called_once() + # should have been called twice + mock_connection.insert_metadata.assert_called() + assert mock_connection.insert_metadata.call_count == 2 + # most recent call should be for the final revision at url2 args = mock_connection.insert_metadata.call_args[0] self.assertEqual(args[:2], (20210315, "url2")) pd.testing.assert_frame_equal( diff --git a/tests/server/endpoints/test_covidcast.py b/tests/server/endpoints/test_covidcast.py index b7ecdc263..bb4fee873 100644 --- a/tests/server/endpoints/test_covidcast.py 
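
The test changes that follow exercise the renamed containers (`GeoSet`, `SourceSignalSet`, `TimeSet`). The toy definitions below are illustrative only, not the real `_params` classes (which presumably use the calendar-aware `days_in_range` helper for day ranges rather than plain integer subtraction), but they spell out the semantics the assertions in `tests/server/test_params.py` check: `True` means every value of the type (count is infinite), `False` matches nothing, a list matches exactly its members, and range endpoints count inclusively.

```python
# Toy stand-ins for the renamed set containers; semantics only, not the real
# _params implementations.
from dataclasses import dataclass
from math import inf
from typing import Sequence, Tuple, Union


@dataclass
class GeoSet:
    geo_type: str
    geo_values: Union[bool, Sequence[str]]  # True = wildcard, list = explicit values

    def matches(self, geo_type: str, geo_value: str) -> bool:
        return self.geo_type == geo_type and (
            self.geo_values is True
            or (self.geo_values is not False and geo_value in self.geo_values)
        )

    def count(self) -> float:
        if isinstance(self.geo_values, bool):
            return inf if self.geo_values else 0
        return len(self.geo_values)


@dataclass
class TimeSet:
    time_type: str
    time_values: Union[bool, Sequence[Union[int, Tuple[int, int]]]]

    def count(self) -> float:
        if isinstance(self.time_values, bool):
            return inf if self.time_values else 0
        # naive inclusive count; the real class does calendar arithmetic for day ranges
        return sum(1 if isinstance(v, int) else v[1] - v[0] + 1 for v in self.time_values)


assert GeoSet("hrr", True).matches("hrr", "any") and not GeoSet("hrr", True).matches("msa", "any")
assert GeoSet("a", ["a", "b"]).count() == 2
assert TimeSet("day", [(20200201, 20200205), 20201212]).count() == 6
```
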
+++ b/tests/server/endpoints/test_covidcast.py @@ -5,11 +5,6 @@ from flask import Response from delphi.epidata.server.main import app -from delphi.epidata.server._params import ( - GeoPair, - TimePair, -) - # py3tester coverage target __test_target__ = "delphi.epidata.server.endpoints.covidcast" diff --git a/tests/server/test_params.py b/tests/server/test_params.py index 2d22a5d37..177ff5cba 100644 --- a/tests/server/test_params.py +++ b/tests/server/test_params.py @@ -7,6 +7,11 @@ # from flask.testing import FlaskClient from delphi.epidata.server._common import app from delphi.epidata.server._params import ( + extract_strings, + extract_integers, + extract_integer, + extract_date, + extract_dates, parse_geo_arg, parse_single_geo_arg, parse_source_signal_arg, @@ -16,9 +21,9 @@ parse_week_value, parse_day_range_arg, parse_day_arg, - GeoPair, - TimePair, - SourceSignalPair, + GeoSet, + TimeSet, + SourceSignalSet, ) from delphi.epidata.server._exceptions import ( ValidationFailedException, @@ -38,46 +43,46 @@ def setUp(self): app.config["WTF_CSRF_ENABLED"] = False app.config["DEBUG"] = False - def test_geo_pair(self): + def test_geo_set(self): with self.subTest("*"): - p = GeoPair("hrr", True) + p = GeoSet("hrr", True) self.assertTrue(p.matches("hrr", "any")) self.assertFalse(p.matches("msa", "any")) with self.subTest("subset"): - p = GeoPair("hrr", ["a", "b"]) + p = GeoSet("hrr", ["a", "b"]) self.assertTrue(p.matches("hrr", "a")) self.assertTrue(p.matches("hrr", "b")) self.assertFalse(p.matches("hrr", "c")) self.assertFalse(p.matches("msa", "any")) with self.subTest("count"): - self.assertEqual(GeoPair("a", True).count(), inf) - self.assertEqual(GeoPair("a", False).count(), 0) - self.assertEqual(GeoPair("a", ["a", "b"]).count(), 2) + self.assertEqual(GeoSet("a", True).count(), inf) + self.assertEqual(GeoSet("a", False).count(), 0) + self.assertEqual(GeoSet("a", ["a", "b"]).count(), 2) - def test_source_signal_pair(self): + def test_source_signal_set(self): with self.subTest("*"): - p = SourceSignalPair("src1", True) + p = SourceSignalSet("src1", True) self.assertTrue(p.matches("src1", "any")) self.assertFalse(p.matches("src2", "any")) with self.subTest("subset"): - p = SourceSignalPair("src1", ["a", "b"]) + p = SourceSignalSet("src1", ["a", "b"]) self.assertTrue(p.matches("src1", "a")) self.assertTrue(p.matches("src1", "b")) self.assertFalse(p.matches("src1", "c")) self.assertFalse(p.matches("src2", "any")) with self.subTest("count"): - self.assertEqual(SourceSignalPair("a", True).count(), inf) - self.assertEqual(SourceSignalPair("a", False).count(), 0) - self.assertEqual(SourceSignalPair("a", ["a", "b"]).count(), 2) + self.assertEqual(SourceSignalSet("a", True).count(), inf) + self.assertEqual(SourceSignalSet("a", False).count(), 0) + self.assertEqual(SourceSignalSet("a", ["a", "b"]).count(), 2) - def test_time_pair(self): + def test_time_set(self): with self.subTest("count"): - self.assertEqual(TimePair("day", True).count(), inf) - self.assertEqual(TimePair("day", False).count(), 0) - self.assertEqual(TimePair("day", [20200202, 20200201]).count(), 2) - self.assertEqual(TimePair("day", [(20200201, 20200202)]).count(), 2) - self.assertEqual(TimePair("day", [(20200201, 20200205)]).count(), 5) - self.assertEqual(TimePair("day", [(20200201, 20200205), 20201212]).count(), 6) + self.assertEqual(TimeSet("day", True).count(), inf) + self.assertEqual(TimeSet("day", False).count(), 0) + self.assertEqual(TimeSet("day", [20200202, 20200201]).count(), 2) + self.assertEqual(TimeSet("day", [(20200201, 
20200202)]).count(), 2) + self.assertEqual(TimeSet("day", [(20200201, 20200205)]).count(), 5) + self.assertEqual(TimeSet("day", [(20200201, 20200205), 20201212]).count(), 6) def test_parse_geo_arg(self): with self.subTest("empty"): @@ -85,32 +90,32 @@ def test_parse_geo_arg(self): self.assertEqual(parse_geo_arg(), []) with self.subTest("single"): with app.test_request_context("/?geo=state:*"): - self.assertEqual(parse_geo_arg(), [GeoPair("state", True)]) + self.assertEqual(parse_geo_arg(), [GeoSet("state", True)]) with app.test_request_context("/?geo=state:AK"): - self.assertEqual(parse_geo_arg(), [GeoPair("state", ["ak"])]) + self.assertEqual(parse_geo_arg(), [GeoSet("state", ["ak"])]) with self.subTest("single list"): with app.test_request_context("/?geo=state:AK,TK"): - self.assertEqual(parse_geo_arg(), [GeoPair("state", ["ak", "tk"])]) + self.assertEqual(parse_geo_arg(), [GeoSet("state", ["ak", "tk"])]) with self.subTest("multi"): with app.test_request_context("/?geo=state:*;nation:*"): - self.assertEqual(parse_geo_arg(), [GeoPair("state", True), GeoPair("nation", True)]) + self.assertEqual(parse_geo_arg(), [GeoSet("state", True), GeoSet("nation", True)]) with app.test_request_context("/?geo=state:AK;nation:US"): self.assertEqual( parse_geo_arg(), - [GeoPair("state", ["ak"]), GeoPair("nation", ["us"])], + [GeoSet("state", ["ak"]), GeoSet("nation", ["us"])], ) with app.test_request_context("/?geo=state:AK;state:KY"): self.assertEqual( parse_geo_arg(), - [GeoPair("state", ["ak"]), GeoPair("state", ["ky"])], + [GeoSet("state", ["ak"]), GeoSet("state", ["ky"])], ) with self.subTest("multi list"): with app.test_request_context("/?geo=state:AK,TK;county:42003,40556"): self.assertEqual( parse_geo_arg(), [ - GeoPair("state", ["ak", "tk"]), - GeoPair("county", ["42003", "40556"]), + GeoSet("state", ["ak", "tk"]), + GeoSet("county", ["42003", "40556"]), ], ) with self.subTest("hybrid"): @@ -118,9 +123,9 @@ def test_parse_geo_arg(self): self.assertEqual( parse_geo_arg(), [ - GeoPair("nation", True), - GeoPair("state", ["pa"]), - GeoPair("county", ["42003", "42002"]), + GeoSet("nation", True), + GeoSet("state", ["pa"]), + GeoSet("county", ["42003", "42002"]), ], ) @@ -136,7 +141,7 @@ def test_single_parse_geo_arg(self): self.assertRaises(ValidationFailedException, parse_single_geo_arg, "geo") with self.subTest("single"): with app.test_request_context("/?geo=state:AK"): - self.assertEqual(parse_single_geo_arg("geo"), GeoPair("state", ["ak"])) + self.assertEqual(parse_single_geo_arg("geo"), GeoSet("state", ["ak"])) with self.subTest("single list"): with app.test_request_context("/?geo=state:AK,TK"): self.assertRaises(ValidationFailedException, parse_single_geo_arg, "geo") @@ -155,35 +160,35 @@ def test_parse_source_signal_arg(self): self.assertEqual(parse_source_signal_arg(), []) with self.subTest("single"): with app.test_request_context("/?signal=src1:*"): - self.assertEqual(parse_source_signal_arg(), [SourceSignalPair("src1", True)]) + self.assertEqual(parse_source_signal_arg(), [SourceSignalSet("src1", True)]) with app.test_request_context("/?signal=src1:sig1"): - self.assertEqual(parse_source_signal_arg(), [SourceSignalPair("src1", ["sig1"])]) + self.assertEqual(parse_source_signal_arg(), [SourceSignalSet("src1", ["sig1"])]) with self.subTest("single list"): with app.test_request_context("/?signal=src1:sig1,sig2"): self.assertEqual( parse_source_signal_arg(), - [SourceSignalPair("src1", ["sig1", "sig2"])], + [SourceSignalSet("src1", ["sig1", "sig2"])], ) with self.subTest("multi"): with 
app.test_request_context("/?signal=src1:*;src2:*"): self.assertEqual( parse_source_signal_arg(), - [SourceSignalPair("src1", True), SourceSignalPair("src2", True)], + [SourceSignalSet("src1", True), SourceSignalSet("src2", True)], ) with app.test_request_context("/?signal=src1:sig1;src2:sig3"): self.assertEqual( parse_source_signal_arg(), [ - SourceSignalPair("src1", ["sig1"]), - SourceSignalPair("src2", ["sig3"]), + SourceSignalSet("src1", ["sig1"]), + SourceSignalSet("src2", ["sig3"]), ], ) with app.test_request_context("/?signal=src1:sig1;src1:sig4"): self.assertEqual( parse_source_signal_arg(), [ - SourceSignalPair("src1", ["sig1"]), - SourceSignalPair("src1", ["sig4"]), + SourceSignalSet("src1", ["sig1"]), + SourceSignalSet("src1", ["sig4"]), ], ) with self.subTest("multi list"): @@ -191,8 +196,8 @@ def test_parse_source_signal_arg(self): self.assertEqual( parse_source_signal_arg(), [ - SourceSignalPair("src1", ["sig1", "sig2"]), - SourceSignalPair("county", ["sig5", "sig6"]), + SourceSignalSet("src1", ["sig1", "sig2"]), + SourceSignalSet("county", ["sig5", "sig6"]), ], ) with self.subTest("hybrid"): @@ -200,9 +205,9 @@ def test_parse_source_signal_arg(self): self.assertEqual( parse_source_signal_arg(), [ - SourceSignalPair("src2", True), - SourceSignalPair("src1", ["sig4"]), - SourceSignalPair("src3", ["sig5", "sig6"]), + SourceSignalSet("src2", True), + SourceSignalSet("src1", ["sig4"]), + SourceSignalSet("src3", ["sig5", "sig6"]), ], ) @@ -218,7 +223,7 @@ def test_single_parse_source_signal_arg(self): self.assertRaises(ValidationFailedException, parse_single_source_signal_arg, "signal") with self.subTest("single"): with app.test_request_context("/?signal=src1:sig1"): - self.assertEqual(parse_single_source_signal_arg("signal"), SourceSignalPair("src1", ["sig1"])) + self.assertEqual(parse_single_source_signal_arg("signal"), SourceSignalSet("src1", ["sig1"])) with self.subTest("single list"): with app.test_request_context("/?signal=src1:sig1,sig2"): self.assertRaises(ValidationFailedException, parse_single_source_signal_arg, "signal") @@ -270,35 +275,35 @@ def test_parse_time_arg(self): self.assertEqual(parse_time_arg(), None) with self.subTest("single"): with app.test_request_context("/?time=day:*"): - self.assertEqual(parse_time_arg(), TimePair("day", True)) + self.assertEqual(parse_time_arg(), TimeSet("day", True)) with app.test_request_context("/?time=day:20201201"): - self.assertEqual(parse_time_arg(), TimePair("day", [20201201])) + self.assertEqual(parse_time_arg(), TimeSet("day", [20201201])) with self.subTest("single list"): with app.test_request_context("/?time=day:20201201,20201202"): - self.assertEqual(parse_time_arg(), TimePair("day", [20201201, 20201202])) + self.assertEqual(parse_time_arg(), TimeSet("day", [20201201, 20201202])) with self.subTest("single range"): with app.test_request_context("/?time=day:20201201-20201204"): - self.assertEqual(parse_time_arg(), TimePair("day", [(20201201, 20201204)])) + self.assertEqual(parse_time_arg(), TimeSet("day", [(20201201, 20201204)])) with self.subTest("multi"): with app.test_request_context("/?time=day:*;day:20201201"): self.assertEqual( parse_time_arg(), - TimePair("day", True) + TimeSet("day", True) ) with app.test_request_context("/?time=week:*;week:202012"): self.assertEqual( parse_time_arg(), - TimePair("week", True) + TimeSet("week", True) ) with app.test_request_context("/?time=day:20201201;day:20201202-20201205"): self.assertEqual( parse_time_arg(), - TimePair("day", [(20201201, 20201205)]) + TimeSet("day", [(20201201, 
20201205)]) ) with app.test_request_context("/?time=week:202012;week:202013-202015"): self.assertEqual( parse_time_arg(), - TimePair("week", [(202012, 202015)]) + TimeSet("week", [(202012, 202015)]) ) with self.subTest("wrong"): @@ -366,3 +371,128 @@ def test_parse_day_arg(self): self.assertRaises(ValidationFailedException, parse_day_arg, "time") with app.test_request_context("/?time=week:20121010"): self.assertRaises(ValidationFailedException, parse_day_arg, "time") + + def test_extract_strings(self): + with self.subTest("empty"): + with app.test_request_context("/"): + self.assertIsNone(extract_strings("s")) + with self.subTest("single"): + with app.test_request_context("/?s=a"): + self.assertEqual(extract_strings("s"), ["a"]) + with self.subTest("multiple"): + with app.test_request_context("/?s=a,b"): + self.assertEqual(extract_strings("s"), ["a", "b"]) + with self.subTest("multiple param"): + with app.test_request_context("/?s=a&s=b"): + self.assertEqual(extract_strings("s"), ["a", "b"]) + with self.subTest("multiple param mixed"): + with app.test_request_context("/?s=a&s=b,c"): + self.assertEqual(extract_strings("s"), ["a", "b", "c"]) + + def test_extract_integer(self): + with self.subTest("empty"): + with app.test_request_context("/"): + self.assertIsNone(extract_integer("s")) + with self.subTest("single"): + with app.test_request_context("/?s=1"): + self.assertEqual(extract_integer("s"), 1) + with self.subTest("not a number"): + with app.test_request_context("/?s=a"): + self.assertRaises(ValidationFailedException, lambda: extract_integer("s")) + + def test_extract_integers(self): + with self.subTest("empty"): + with app.test_request_context("/"): + self.assertIsNone(extract_integers("s")) + with self.subTest("single"): + with app.test_request_context("/?s=1"): + self.assertEqual(extract_integers("s"), [1]) + with self.subTest("multiple"): + with app.test_request_context("/?s=1,2"): + self.assertEqual(extract_integers("s"), [1,2]) + with self.subTest("multiple param"): + with app.test_request_context("/?s=1&s=2"): + self.assertEqual(extract_integers("s"), [1,2]) + with self.subTest("multiple param mixed"): + with app.test_request_context("/?s=1&s=2,3"): + self.assertEqual(extract_integers("s"), [1, 2, 3]) + + with self.subTest("not a number"): + with app.test_request_context("/?s=a"): + self.assertRaises(ValidationFailedException, lambda: extract_integers("s")) + + with self.subTest("simple range"): + with app.test_request_context("/?s=1-2"): + self.assertEqual(extract_integers("s"), [(1, 2)]) + with self.subTest("inverted range"): + with app.test_request_context("/?s=2-1"): + self.assertRaises(ValidationFailedException, lambda: extract_integers("s")) + with self.subTest("single range"): + with app.test_request_context("/?s=1-1"): + self.assertEqual(extract_integers("s"), [1]) + + def test_extract_date(self): + with self.subTest("empty"): + with app.test_request_context("/"): + self.assertIsNone(extract_date("s")) + with self.subTest("single"): + with app.test_request_context("/?s=2020-01-01"): + self.assertEqual(extract_date("s"), 20200101) + with app.test_request_context("/?s=20200101"): + self.assertEqual(extract_date("s"), 20200101) + with self.subTest("not a date"): + with app.test_request_context("/?s=abc"): + self.assertRaises(ValidationFailedException, lambda: extract_date("s")) + + def test_extract_dates(self): + with self.subTest("empty"): + with app.test_request_context("/"): + self.assertIsNone(extract_dates("s")) + with self.subTest("single"): + with 
app.test_request_context("/?s=20200101"): + self.assertEqual(extract_dates("s"), [20200101]) + with self.subTest("multiple"): + with app.test_request_context("/?s=20200101,20200102"): + self.assertEqual(extract_dates("s"), [20200101, 20200102]) + with self.subTest("multiple param"): + with app.test_request_context("/?s=20200101&s=20200102"): + self.assertEqual(extract_dates("s"), [20200101, 20200102]) + with self.subTest("multiple param mixed"): + with app.test_request_context("/?s=20200101&s=20200102,20200103"): + self.assertEqual(extract_dates("s"), [20200101, 20200102, 20200103]) + with self.subTest("single iso"): + with app.test_request_context("/?s=2020-01-01"): + self.assertEqual(extract_dates("s"), [20200101]) + with self.subTest("multiple iso"): + with app.test_request_context("/?s=2020-01-01,2020-01-02"): + self.assertEqual(extract_dates("s"), [20200101, 20200102]) + with self.subTest("multiple param iso"): + with app.test_request_context("/?s=2020-01-01&s=2020-01-02"): + self.assertEqual(extract_dates("s"), [20200101, 20200102]) + with self.subTest("multiple param mixed iso"): + with app.test_request_context("/?s=2020-01-01&s=2020-01-02,2020-01-03"): + self.assertEqual(extract_dates("s"), [20200101, 20200102, 20200103]) + + with self.subTest("not a date"): + with app.test_request_context("/?s=a"): + self.assertRaises(ValidationFailedException, lambda: extract_dates("s")) + + with self.subTest("simple range"): + with app.test_request_context("/?s=20200101-20200102"): + self.assertEqual(extract_dates("s"), [(20200101, 20200102)]) + with self.subTest("inverted range"): + with app.test_request_context("/?s=20200102-20200101"): + self.assertRaises(ValidationFailedException, lambda: extract_dates("s")) + with self.subTest("single range"): + with app.test_request_context("/?s=20200101-20200101"): + self.assertEqual(extract_dates("s"), [20200101]) + + with self.subTest("simple range iso"): + with app.test_request_context("/?s=2020-01-01:2020-01-02"): + self.assertEqual(extract_dates("s"), [(20200101, 20200102)]) + with self.subTest("inverted range iso"): + with app.test_request_context("/?s=2020-01-02:2020-01-01"): + self.assertRaises(ValidationFailedException, lambda: extract_dates("s")) + with self.subTest("single range iso"): + with app.test_request_context("/?s=2020-01-01:2020-01-01"): + self.assertEqual(extract_dates("s"), [20200101]) diff --git a/tests/server/test_query.py b/tests/server/test_query.py index a1292764f..53aca5621 100644 --- a/tests/server/test_query.py +++ b/tests/server/test_query.py @@ -12,14 +12,14 @@ filter_strings, filter_integers, filter_dates, - filter_geo_pairs, - filter_source_signal_pairs, - filter_time_pair, + filter_geo_sets, + filter_source_signal_sets, + filter_time_set, ) from delphi.epidata.server._params import ( - GeoPair, - TimePair, - SourceSignalPair, + GeoSet, + TimeSet, + SourceSignalSet, ) # py3tester coverage target @@ -130,39 +130,39 @@ def test_filter_dates(self): }, ) - def test_filter_geo_pairs(self): + def test_filter_geo_sets(self): with self.subTest("empty"): params = {} - self.assertEqual(filter_geo_pairs("t", "v", [], "p", params), "FALSE") + self.assertEqual(filter_geo_sets("t", "v", [], "p", params), "FALSE") self.assertEqual(params, {}) with self.subTest("*"): params = {} self.assertEqual( - filter_geo_pairs("t", "v", [GeoPair("state", True)], "p", params), + filter_geo_sets("t", "v", [GeoSet("state", True)], "p", params), "(t = :p_0t)", ) self.assertEqual(params, {"p_0t": "state"}) with self.subTest("single"): params = {} 
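
The `test_query.py` changes in this area rename the filter helpers to the new set terminology, and the strings they assert make the generated SQL pattern easy to restate: each set becomes `type = :param` (wildcard) or `(type = :param AND (value = :p0 OR value = :p1 ...))`, the sets are OR-ed together, and an empty list yields the literal `FALSE`. The snippet below is a condensed stand-in, not the real `_query.filter_geo_sets`, but it reproduces the expectations these tests pin down.

```python
# Simplified re-implementation of the set-to-SQL pattern asserted by the tests;
# illustrative only.
from dataclasses import dataclass
from typing import Dict, List, Sequence, Union


@dataclass
class GeoSet:
    geo_type: str
    geo_values: Union[bool, Sequence[str]]  # True = every value of this geo type


def filter_geo_sets(type_field: str, value_field: str, geo_sets: Sequence[GeoSet],
                    param_key: str, params: Dict[str, str]) -> str:
    """Build a parameterized SQL condition matching rows in any of the given sets."""
    parts: List[str] = []
    for i, geo_set in enumerate(geo_sets):
        type_param = f"{param_key}_{i}t"
        params[type_param] = geo_set.geo_type
        if geo_set.geo_values is True:
            # wildcard: any value of this geo type
            parts.append(f"{type_field} = :{type_param}")
            continue
        value_conditions = []
        for j, value in enumerate(geo_set.geo_values):
            value_param = f"{type_param}_{j}"
            params[value_param] = value
            value_conditions.append(f"{value_field} = :{value_param}")
        parts.append(f"({type_field} = :{type_param} AND ({' OR '.join(value_conditions)}))")
    # an empty list of sets can never match anything
    return f"({' OR '.join(parts)})" if parts else "FALSE"


params: Dict[str, str] = {}
assert filter_geo_sets("t", "v", [GeoSet("state", True)], "p", params) == "(t = :p_0t)"

params = {}
assert (filter_geo_sets("t", "v", [GeoSet("state", ["KY", "AK"])], "p", params)
        == "((t = :p_0t AND (v = :p_0t_0 OR v = :p_0t_1)))")
assert params == {"p_0t": "state", "p_0t_0": "KY", "p_0t_1": "AK"}
```
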
self.assertEqual( - filter_geo_pairs("t", "v", [GeoPair("state", ["KY"])], "p", params), + filter_geo_sets("t", "v", [GeoSet("state", ["KY"])], "p", params), "((t = :p_0t AND (v = :p_0t_0)))", ) self.assertEqual(params, {"p_0t": "state", "p_0t_0": "KY"}) with self.subTest("multi"): params = {} self.assertEqual( - filter_geo_pairs("t", "v", [GeoPair("state", ["KY", "AK"])], "p", params), + filter_geo_sets("t", "v", [GeoSet("state", ["KY", "AK"])], "p", params), "((t = :p_0t AND (v = :p_0t_0 OR v = :p_0t_1)))", ) self.assertEqual(params, {"p_0t": "state", "p_0t_0": "KY", "p_0t_1": "AK"}) with self.subTest("multiple pairs"): params = {} self.assertEqual( - filter_geo_pairs( + filter_geo_sets( "t", "v", - [GeoPair("state", True), GeoPair("nation", True)], + [GeoSet("state", True), GeoSet("nation", True)], "p", params, ), @@ -172,10 +172,10 @@ def test_filter_geo_pairs(self): with self.subTest("multiple pairs with value"): params = {} self.assertEqual( - filter_geo_pairs( + filter_geo_sets( "t", "v", - [GeoPair("state", ["AK"]), GeoPair("nation", ["US"])], + [GeoSet("state", ["AK"]), GeoSet("nation", ["US"])], "p", params, ), @@ -186,39 +186,39 @@ def test_filter_geo_pairs(self): {"p_0t": "state", "p_0t_0": "AK", "p_1t": "nation", "p_1t_0": "US"}, ) - def test_filter_source_signal_pairs(self): + def test_filter_source_signal_sets(self): with self.subTest("empty"): params = {} - self.assertEqual(filter_source_signal_pairs("t", "v", [], "p", params), "FALSE") + self.assertEqual(filter_source_signal_sets("t", "v", [], "p", params), "FALSE") self.assertEqual(params, {}) with self.subTest("*"): params = {} self.assertEqual( - filter_source_signal_pairs("t", "v", [SourceSignalPair("src1", True)], "p", params), + filter_source_signal_sets("t", "v", [SourceSignalSet("src1", True)], "p", params), "(t = :p_0t)", ) self.assertEqual(params, {"p_0t": "src1"}) with self.subTest("single"): params = {} self.assertEqual( - filter_source_signal_pairs("t", "v", [SourceSignalPair("src1", ["sig1"])], "p", params), + filter_source_signal_sets("t", "v", [SourceSignalSet("src1", ["sig1"])], "p", params), "((t = :p_0t AND (v = :p_0t_0)))", ) self.assertEqual(params, {"p_0t": "src1", "p_0t_0": "sig1"}) with self.subTest("multi"): params = {} self.assertEqual( - filter_source_signal_pairs("t", "v", [SourceSignalPair("src1", ["sig1", "sig2"])], "p", params), + filter_source_signal_sets("t", "v", [SourceSignalSet("src1", ["sig1", "sig2"])], "p", params), "((t = :p_0t AND (v = :p_0t_0 OR v = :p_0t_1)))", ) self.assertEqual(params, {"p_0t": "src1", "p_0t_0": "sig1", "p_0t_1": "sig2"}) with self.subTest("multiple pairs"): params = {} self.assertEqual( - filter_source_signal_pairs( + filter_source_signal_sets( "t", "v", - [SourceSignalPair("src1", True), SourceSignalPair("src2", True)], + [SourceSignalSet("src1", True), SourceSignalSet("src2", True)], "p", params, ), @@ -228,12 +228,12 @@ def test_filter_source_signal_pairs(self): with self.subTest("multiple pairs with value"): params = {} self.assertEqual( - filter_source_signal_pairs( + filter_source_signal_sets( "t", "v", [ - SourceSignalPair("src1", ["sig2"]), - SourceSignalPair("src2", ["srcx"]), + SourceSignalSet("src1", ["sig2"]), + SourceSignalSet("src2", ["srcx"]), ], "p", params, @@ -245,57 +245,57 @@ def test_filter_source_signal_pairs(self): {"p_0t": "src1", "p_0t_0": "sig2", "p_1t": "src2", "p_1t_0": "srcx"}, ) - def test_filter_time_pair(self): + def test_filter_time_set(self): with self.subTest("empty"): params = {} - self.assertEqual(filter_time_pair("t", "v", 
None, "p", params), "FALSE") + self.assertEqual(filter_time_set("t", "v", None, "p", params), "FALSE") self.assertEqual(params, {}) with self.subTest("*"): params = {} self.assertEqual( - filter_time_pair("t", "v", TimePair("day", True), "p", params), + filter_time_set("t", "v", TimeSet("day", True), "p", params), "(t = :p_0t)", ) self.assertEqual(params, {"p_0t": "day"}) with self.subTest("single"): params = {} self.assertEqual( - filter_time_pair("t", "v", TimePair("day", [20201201]), "p", params), + filter_time_set("t", "v", TimeSet("day", [20201201]), "p", params), "((t = :p_0t AND (v = :p_0t_0)))", ) self.assertEqual(params, {"p_0t": "day", "p_0t_0": 20201201}) with self.subTest("multi"): params = {} self.assertEqual( - filter_time_pair("t", "v", TimePair("day", [20201201, 20201203]), "p", params), + filter_time_set("t", "v", TimeSet("day", [20201201, 20201203]), "p", params), "((t = :p_0t AND (v = :p_0t_0 OR v = :p_0t_1)))", ) self.assertEqual(params, {"p_0t": "day", "p_0t_0": 20201201, "p_0t_1": 20201203}) with self.subTest("range"): params = {} self.assertEqual( - filter_time_pair("t", "v", TimePair("day", [(20201201, 20201203)]), "p", params), + filter_time_set("t", "v", TimeSet("day", [(20201201, 20201203)]), "p", params), "((t = :p_0t AND (v BETWEEN :p_0t_0 AND :p_0t_0_2)))", ) self.assertEqual(params, {"p_0t": "day", "p_0t_0": 20201201, "p_0t_0_2": 20201203}) with self.subTest("dedupe"): params = {} self.assertEqual( - filter_time_pair("t", "v", TimePair("day", [20200101, 20200101, (20200101, 20200101), 20200101]), "p", params), + filter_time_set("t", "v", TimeSet("day", [20200101, 20200101, (20200101, 20200101), 20200101]), "p", params), "((t = :p_0t AND (v = :p_0t_0)))", ) self.assertEqual(params, {"p_0t": "day", "p_0t_0": 20200101}) with self.subTest("merge single range"): params = {} self.assertEqual( - filter_time_pair("t", "v", TimePair("day", [20200101, 20200102, (20200101, 20200104)]), "p", params), + filter_time_set("t", "v", TimeSet("day", [20200101, 20200102, (20200101, 20200104)]), "p", params), "((t = :p_0t AND (v BETWEEN :p_0t_0 AND :p_0t_0_2)))", ) self.assertEqual(params, {"p_0t": "day", "p_0t_0": 20200101, "p_0t_0_2": 20200104}) with self.subTest("merge ranges and singles"): params = {} self.assertEqual( - filter_time_pair("t", "v", TimePair("day", [20200101, 20200103, (20200105, 20200107)]), "p", params), + filter_time_set("t", "v", TimeSet("day", [20200101, 20200103, (20200105, 20200107)]), "p", params), "((t = :p_0t AND (v = :p_0t_0 OR v = :p_0t_1 OR v BETWEEN :p_0t_2 AND :p_0t_2_2)))", ) self.assertEqual(params, {"p_0t": "day", "p_0t_0": 20200101, "p_0t_1": 20200103, 'p_0t_2': 20200105, 'p_0t_2_2': 20200107}) diff --git a/tests/server/test_validate.py b/tests/server/test_validate.py index c254950ff..ca45c78e2 100644 --- a/tests/server/test_validate.py +++ b/tests/server/test_validate.py @@ -11,11 +11,6 @@ check_auth_token, require_all, require_any, - extract_strings, - extract_integers, - extract_integer, - extract_date, - extract_dates ) from delphi.epidata.server._exceptions import ( ValidationFailedException, @@ -110,129 +105,3 @@ def test_require_any(self): with self.subTest("one options given with is empty but ok"): with app.test_request_context("/?abc="): self.assertTrue(require_any("abc", empty=True)) - - def test_extract_strings(self): - with self.subTest("empty"): - with app.test_request_context("/"): - self.assertIsNone(extract_strings("s")) - with self.subTest("single"): - with app.test_request_context("/?s=a"): - 
self.assertEqual(extract_strings("s"), ["a"]) - with self.subTest("multiple"): - with app.test_request_context("/?s=a,b"): - self.assertEqual(extract_strings("s"), ["a", "b"]) - with self.subTest("multiple param"): - with app.test_request_context("/?s=a&s=b"): - self.assertEqual(extract_strings("s"), ["a", "b"]) - with self.subTest("multiple param mixed"): - with app.test_request_context("/?s=a&s=b,c"): - self.assertEqual(extract_strings("s"), ["a", "b", "c"]) - - def test_extract_integer(self): - with self.subTest("empty"): - with app.test_request_context("/"): - self.assertIsNone(extract_integer("s")) - with self.subTest("single"): - with app.test_request_context("/?s=1"): - self.assertEqual(extract_integer("s"), 1) - with self.subTest("not a number"): - with app.test_request_context("/?s=a"): - self.assertRaises(ValidationFailedException, lambda: extract_integer("s")) - - def test_extract_integers(self): - with self.subTest("empty"): - with app.test_request_context("/"): - self.assertIsNone(extract_integers("s")) - with self.subTest("single"): - with app.test_request_context("/?s=1"): - self.assertEqual(extract_integers("s"), [1]) - with self.subTest("multiple"): - with app.test_request_context("/?s=1,2"): - self.assertEqual(extract_integers("s"), [1,2]) - with self.subTest("multiple param"): - with app.test_request_context("/?s=1&s=2"): - self.assertEqual(extract_integers("s"), [1,2]) - with self.subTest("multiple param mixed"): - with app.test_request_context("/?s=1&s=2,3"): - self.assertEqual(extract_integers("s"), [1, 2, 3]) - - with self.subTest("not a number"): - with app.test_request_context("/?s=a"): - self.assertRaises(ValidationFailedException, lambda: extract_integers("s")) - - with self.subTest("simple range"): - with app.test_request_context("/?s=1-2"): - self.assertEqual(extract_integers("s"), [(1, 2)]) - with self.subTest("inverted range"): - with app.test_request_context("/?s=2-1"): - self.assertRaises(ValidationFailedException, lambda: extract_integers("s")) - with self.subTest("single range"): - with app.test_request_context("/?s=1-1"): - self.assertEqual(extract_integers("s"), [1]) - - def test_extract_date(self): - with self.subTest("empty"): - with app.test_request_context("/"): - self.assertIsNone(extract_date("s")) - with self.subTest("single"): - with app.test_request_context("/?s=2020-01-01"): - self.assertEqual(extract_date("s"), 20200101) - with app.test_request_context("/?s=20200101"): - self.assertEqual(extract_date("s"), 20200101) - with self.subTest("not a date"): - with app.test_request_context("/?s=abc"): - self.assertRaises(ValidationFailedException, lambda: extract_date("s")) - - def test_extract_dates(self): - with self.subTest("empty"): - with app.test_request_context("/"): - self.assertIsNone(extract_dates("s")) - with self.subTest("single"): - with app.test_request_context("/?s=20200101"): - self.assertEqual(extract_dates("s"), [20200101]) - with self.subTest("multiple"): - with app.test_request_context("/?s=20200101,20200102"): - self.assertEqual(extract_dates("s"), [20200101, 20200102]) - with self.subTest("multiple param"): - with app.test_request_context("/?s=20200101&s=20200102"): - self.assertEqual(extract_dates("s"), [20200101, 20200102]) - with self.subTest("multiple param mixed"): - with app.test_request_context("/?s=20200101&s=20200102,20200103"): - self.assertEqual(extract_dates("s"), [20200101, 20200102, 20200103]) - with self.subTest("single iso"): - with app.test_request_context("/?s=2020-01-01"): - 
self.assertEqual(extract_dates("s"), [20200101]) - with self.subTest("multiple iso"): - with app.test_request_context("/?s=2020-01-01,2020-01-02"): - self.assertEqual(extract_dates("s"), [20200101, 20200102]) - with self.subTest("multiple param iso"): - with app.test_request_context("/?s=2020-01-01&s=2020-01-02"): - self.assertEqual(extract_dates("s"), [20200101, 20200102]) - with self.subTest("multiple param mixed iso"): - with app.test_request_context("/?s=2020-01-01&s=2020-01-02,2020-01-03"): - self.assertEqual(extract_dates("s"), [20200101, 20200102, 20200103]) - - with self.subTest("not a date"): - with app.test_request_context("/?s=a"): - self.assertRaises(ValidationFailedException, lambda: extract_dates("s")) - - with self.subTest("simple range"): - with app.test_request_context("/?s=20200101-20200102"): - self.assertEqual(extract_dates("s"), [(20200101, 20200102)]) - with self.subTest("inverted range"): - with app.test_request_context("/?s=20200102-20200101"): - self.assertRaises(ValidationFailedException, lambda: extract_dates("s")) - with self.subTest("single range"): - with app.test_request_context("/?s=20200101-20200101"): - self.assertEqual(extract_dates("s"), [20200101]) - - with self.subTest("simple range iso"): - with app.test_request_context("/?s=2020-01-01:2020-01-02"): - self.assertEqual(extract_dates("s"), [(20200101, 20200102)]) - with self.subTest("inverted range iso"): - with app.test_request_context("/?s=2020-01-02:2020-01-01"): - self.assertRaises(ValidationFailedException, lambda: extract_dates("s")) - with self.subTest("single range iso"): - with app.test_request_context("/?s=2020-01-01:2020-01-01"): - self.assertEqual(extract_dates("s"), [20200101]) -