diff --git a/openbb_platform/providers/oecd/openbb_oecd/models/composite_leading_indicator.py b/openbb_platform/providers/oecd/openbb_oecd/models/composite_leading_indicator.py index 956fdb90c21e..8184bea44205 100644 --- a/openbb_platform/providers/oecd/openbb_oecd/models/composite_leading_indicator.py +++ b/openbb_platform/providers/oecd/openbb_oecd/models/composite_leading_indicator.py @@ -10,34 +10,11 @@ CLIQueryParams, ) from openbb_oecd.utils import helpers +from openbb_oecd.utils.constants import CODE_TO_COUNTRY_CLI, COUNTRY_TO_CODE_CLI from pydantic import Field, field_validator -cli_mapping = { - "USA": "united_states", - "GBR": "united_kingdom", - "JPN": "japan", - "MEX": "mexico", - "IDN": "indonesia", - "AUS": "australia", - "BRA": "brazil", - "CAN": "canada", - "ITA": "italy", - "DEU": "germany", - "TUR": "turkey", - "FRA": "france", - "ZAF": "south_africa", - "KOR": "south_korea", - "ESP": "spain", - "IND": "india", - "CHN": "china", - "G7": "g7", - "G20": "g20", -} - - -countries = tuple(cli_mapping.values()) + ("all",) +countries = tuple(CODE_TO_COUNTRY_CLI.values()) + ("all",) CountriesLiteral = Literal[countries] # type: ignore -country_to_code = {v: k for k, v in cli_mapping.items()} class OECDCLIQueryParams(CLIQueryParams): @@ -70,10 +47,10 @@ def date_validate(cls, in_date: Union[date, str]): # pylint: disable=E0213 return date(_year, 12, 31) # Now match if it is monthly, i.e 2022-01 elif re.match(r"\d{4}-\d{2}$", in_date): - year, month = map(int, in_date.split("-")) + year, month = map(int, in_date.split("-")) # type: ignore if month == 12: - return date(year, month, 31) - next_month = date(year, month + 1, 1) + return date(year, month, 31) # type: ignore + next_month = date(year, month + 1, 1) # type: ignore return date(next_month.year, next_month.month, 1) - timedelta(days=1) # Now match if it is yearly, i.e 2022 elif re.match(r"\d{4}$", in_date): @@ -99,9 +76,10 @@ def transform_query(params: Dict[str, Any]) -> OECDCLIQueryParams: return OECDCLIQueryParams(**transformed_params) + # pylint: disable=unused-argument @staticmethod def extract_data( - query: OECDCLIQueryParams, # pylint: disable=W0613 + query: OECDCLIQueryParams, credentials: Optional[Dict[str, str]], **kwargs: Any, ) -> Dict: @@ -112,19 +90,27 @@ def extract_data( ) if query.country != "all": - data = data.query(f"REF_AREA == '{country_to_code[query.country]}'") + data = data.query(f"REF_AREA == '{COUNTRY_TO_CODE_CLI[query.country]}'") # Filter down data = data.reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]].rename( columns={"REF_AREA": "country", "TIME_PERIOD": "date", "VALUE": "value"} ) - data["country"] = data["country"].map(cli_mapping) + data["country"] = data["country"].map(CODE_TO_COUNTRY_CLI) - return data.to_dict(orient="records") + data = data.to_dict(orient="records") + start_date = query.start_date.strftime("%Y-%m-%d") # type: ignore + end_date = query.end_date.strftime("%Y-%m-%d") # type: ignore + data = list(filter(lambda x: start_date <= x["date"] <= end_date, data)) + return data + + # pylint: disable=unused-argument @staticmethod def transform_data( - query: OECDCLIQueryParams, data: Dict, **kwargs: Any + query: OECDCLIQueryParams, + data: Dict, + **kwargs: Any, ) -> List[OECDCLIData]: """Transform the data from the OECD endpoint.""" return [OECDCLIData.model_validate(d) for d in data] diff --git a/openbb_platform/providers/oecd/openbb_oecd/models/long_term_interest_rate.py b/openbb_platform/providers/oecd/openbb_oecd/models/long_term_interest_rate.py index 6a5d2aa8947b..67b536bb8c71 100644 --- a/openbb_platform/providers/oecd/openbb_oecd/models/long_term_interest_rate.py +++ b/openbb_platform/providers/oecd/openbb_oecd/models/long_term_interest_rate.py @@ -12,59 +12,11 @@ LTIRQueryParams, ) from openbb_oecd.utils import helpers +from openbb_oecd.utils.constants import CODE_TO_COUNTRY_IR, COUNTRY_TO_CODE_IR from pydantic import Field, field_validator -ltir_mapping = { - "BEL": "belgium", - "IRL": "ireland", - "MEX": "mexico", - "IDN": "indonesia", - "NZL": "new_zealand", - "JPN": "japan", - "GBR": "united_kingdom", - "FRA": "france", - "CHL": "chile", - "CAN": "canada", - "NLD": "netherlands", - "USA": "united_states", - "KOR": "south_korea", - "NOR": "norway", - "AUT": "austria", - "ZAF": "south_africa", - "DNK": "denmark", - "CHE": "switzerland", - "HUN": "hungary", - "LUX": "luxembourg", - "AUS": "australia", - "DEU": "germany", - "SWE": "sweden", - "ISL": "iceland", - "TUR": "turkey", - "GRC": "greece", - "ISR": "israel", - "CZE": "czech_republic", - "LVA": "latvia", - "SVN": "slovenia", - "POL": "poland", - "EST": "estonia", - "LTU": "lithuania", - "PRT": "portugal", - "CRI": "costa_rica", - "SVK": "slovakia", - "FIN": "finland", - "ESP": "spain", - "RUS": "russia", - "EA19": "euro_area19", - "COL": "colombia", - "ITA": "italy", - "IND": "india", - "CHN": "china", - "HRV": "croatia", -} - -countries = tuple(ltir_mapping.values()) + ("all",) +countries = tuple(CODE_TO_COUNTRY_IR.values()) + ("all",) CountriesLiteral = Literal[countries] # type: ignore -country_to_code = {v: k for k, v in ltir_mapping.items()} class OECDLTIRQueryParams(LTIRQueryParams): @@ -101,10 +53,10 @@ def date_validate(cls, in_date: Union[date, str]): # pylint: disable=E0213 return date(_year, 12, 31) # Now match if it is monthly, i.e 2022-01 elif re.match(r"\d{4}-\d{2}$", in_date): - year, month = map(int, in_date.split("-")) + year, month = map(int, in_date.split("-")) # type: ignore if month == 12: - return date(year, month, 31) - next_month = date(year, month + 1, 1) + return date(year, month, 31) # type: ignore + next_month = date(year, month + 1, 1) # type: ignore return date(next_month.year, next_month.month, 1) - timedelta(days=1) # Now match if it is yearly, i.e 2022 elif re.match(r"\d{4}$", in_date): @@ -138,24 +90,30 @@ def extract_data( ) -> Dict: """Return the raw data from the OECD endpoint.""" frequency = query.frequency[0].upper() - country = "" if query.country == "all" else country_to_code[query.country] + country = "" if query.country == "all" else COUNTRY_TO_CODE_IR[query.country] url = "https://sdmx.oecd.org/public/rest/data/OECD.SDD.STES,DSD_KEI@DF_KEI,4.0/..IRLT...." data = helpers.get_possibly_cached_data( url, function="economy_long_term_interest_rate" ) - query = f"FREQ=='{frequency}'" - query = query + f" & REF_AREA=='{country}'" if country else query + url_query = f"FREQ=='{frequency}'" + url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query # Filter down data = ( - data.query(query) + data.query(url_query) .reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]] .rename( columns={"REF_AREA": "country", "TIME_PERIOD": "date", "VALUE": "value"} ) ) - data["country"] = data["country"].map(ltir_mapping) + data["country"] = data["country"].map(CODE_TO_COUNTRY_IR) data = data.fillna("N/A").replace("N/A", None) - return data.to_dict(orient="records") + data = data.to_dict(orient="records") + + start_date = query.start_date.strftime("%Y-%m-%d") # type: ignore + end_date = query.end_date.strftime("%Y-%m-%d") # type: ignore + data = list(filter(lambda x: start_date <= x["date"] <= end_date, data)) + + return data @staticmethod def transform_data( diff --git a/openbb_platform/providers/oecd/openbb_oecd/models/short_term_interest_rate.py b/openbb_platform/providers/oecd/openbb_oecd/models/short_term_interest_rate.py index f06dd9291017..cd0183909994 100644 --- a/openbb_platform/providers/oecd/openbb_oecd/models/short_term_interest_rate.py +++ b/openbb_platform/providers/oecd/openbb_oecd/models/short_term_interest_rate.py @@ -12,59 +12,11 @@ STIRQueryParams, ) from openbb_oecd.utils import helpers +from openbb_oecd.utils.constants import CODE_TO_COUNTRY_IR, COUNTRY_TO_CODE_IR from pydantic import Field, field_validator -stir_mapping = { - "BEL": "belgium", - "IRL": "ireland", - "MEX": "mexico", - "IDN": "indonesia", - "NZL": "new_zealand", - "JPN": "japan", - "GBR": "united_kingdom", - "FRA": "france", - "CHL": "chile", - "CAN": "canada", - "NLD": "netherlands", - "USA": "united_states", - "KOR": "south_korea", - "NOR": "norway", - "AUT": "austria", - "ZAF": "south_africa", - "DNK": "denmark", - "CHE": "switzerland", - "HUN": "hungary", - "LUX": "luxembourg", - "AUS": "australia", - "DEU": "germany", - "SWE": "sweden", - "ISL": "iceland", - "TUR": "turkey", - "GRC": "greece", - "ISR": "israel", - "CZE": "czech_republic", - "LVA": "latvia", - "SVN": "slovenia", - "POL": "poland", - "EST": "estonia", - "LTU": "lithuania", - "PRT": "portugal", - "CRI": "costa_rica", - "SVK": "slovakia", - "FIN": "finland", - "ESP": "spain", - "RUS": "russia", - "EA19": "euro_area19", - "COL": "colombia", - "ITA": "italy", - "IND": "india", - "CHN": "china", - "HRV": "croatia", -} - -countries = tuple(stir_mapping.values()) + ("all",) +countries = tuple(CODE_TO_COUNTRY_IR.values()) + ("all",) CountriesLiteral = Literal[countries] # type: ignore -country_to_code = {v: k for k, v in stir_mapping.items()} class OECDSTIRQueryParams(STIRQueryParams): @@ -101,10 +53,10 @@ def date_validate(cls, in_date: Union[date, str]): # pylint: disable=E0213 return date(_year, 12, 31) # Now match if it is monthly, i.e 2022-01 elif re.match(r"\d{4}-\d{2}$", in_date): - year, month = map(int, in_date.split("-")) + year, month = map(int, in_date.split("-")) # type: ignore if month == 12: - return date(year, month, 31) - next_month = date(year, month + 1, 1) + return date(year, month, 31) # type: ignore + next_month = date(year, month + 1, 1) # type: ignore return date(next_month.year, next_month.month, 1) - timedelta(days=1) # Now match if it is yearly, i.e 2022 elif re.match(r"\d{4}$", in_date): @@ -138,24 +90,30 @@ def extract_data( ) -> Dict: """Return the raw data from the OECD endpoint.""" frequency = query.frequency[0].upper() - country = "" if query.country == "all" else country_to_code[query.country] + country = "" if query.country == "all" else COUNTRY_TO_CODE_IR[query.country] url = "https://sdmx.oecd.org/public/rest/data/OECD.SDD.STES,DSD_KEI@DF_KEI,4.0/..IR3TIB...." data = helpers.get_possibly_cached_data( url, function="economy_short_term_interest_rate" ) - query = f"FREQ=='{frequency}'" - query = query + f" & REF_AREA=='{country}'" if country else query + url_query = f"FREQ=='{frequency}'" + url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query # Filter down data = ( - data.query(query) + data.query(url_query) .reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]] .rename( columns={"REF_AREA": "country", "TIME_PERIOD": "date", "VALUE": "value"} ) ) - data["country"] = data["country"].map(stir_mapping) + data["country"] = data["country"].map(CODE_TO_COUNTRY_IR) data = data.fillna("N/A").replace("N/A", None) - return data.to_dict(orient="records") + data = data.to_dict(orient="records") + + start_date = query.start_date.strftime("%Y-%m-%d") # type: ignore + end_date = query.end_date.strftime("%Y-%m-%d") # type: ignore + data = list(filter(lambda x: start_date <= x["date"] <= end_date, data)) + + return data @staticmethod def transform_data( diff --git a/openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py b/openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py index 7cc44d0863d3..406e5d500f19 100644 --- a/openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py +++ b/openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py @@ -10,63 +10,21 @@ UnemploymentQueryParams, ) from openbb_oecd.utils import helpers +from openbb_oecd.utils.constants import ( + CODE_TO_COUNTRY_UNEMPLOYMENT, + COUNTRY_TO_CODE_UNEMPLOYMENT, +) from pydantic import Field, field_validator -country_mapping = { - "COL": "colombia", - "NZL": "new_zealand", - "GBR": "united_kingdom", - "ITA": "italy", - "LUX": "luxembourg", - "EA19": "euro_area19", - "SWE": "sweden", - "OECD": "oecd", - "ZAF": "south_africa", - "DNK": "denmark", - "CAN": "canada", - "CHE": "switzerland", - "SVK": "slovakia", - "HUN": "hungary", - "PRT": "portugal", - "ESP": "spain", - "FRA": "france", - "CZE": "czech_republic", - "CRI": "costa_rica", - "JPN": "japan", - "SVN": "slovenia", - "RUS": "russia", - "AUT": "austria", - "LVA": "latvia", - "NLD": "netherlands", - "ISR": "israel", - "ISL": "iceland", - "USA": "united_states", - "IRL": "ireland", - "MEX": "mexico", - "DEU": "germany", - "GRC": "greece", - "TUR": "turkey", - "AUS": "australia", - "POL": "poland", - "KOR": "south_korea", - "CHL": "chile", - "FIN": "finland", - "EU27_2020": "european_union27_2020", - "NOR": "norway", - "LTU": "lithuania", - "EA20": "euro_area20", - "EST": "estonia", - "BEL": "belgium", - "BRA": "brazil", - "IDN": "indonesia", -} -countries = tuple(country_mapping.values()) + ("all",) +countries = tuple(CODE_TO_COUNTRY_UNEMPLOYMENT.values()) + ("all",) CountriesLiteral = Literal[countries] # type: ignore -country_to_code = {v: k for k, v in country_mapping.items()} class OECDUnemploymentQueryParams(UnemploymentQueryParams): - """OECD Unemployment Query.""" + """OECD Unemployment Query. + + Source: https://data-explorer.oecd.org/?lc=en + """ country: CountriesLiteral = Field( description="Country to get GDP for.", default="united_states" @@ -109,10 +67,10 @@ def date_validate(cls, in_date: Union[date, str]): # pylint: disable=E0213 return date(_year, 12, 31) # Now match if it is monthly, i.e 2022-01 elif re.match(r"\d{4}-\d{2}$", in_date): - year, month = map(int, in_date.split("-")) + year, month = map(int, in_date.split("-")) # type: ignore if month == 12: - return date(year, month, 31) - next_month = date(year, month + 1, 1) + return date(year, month, 31) # type: ignore + next_month = date(year, month + 1, 1) # type: ignore return date(next_month.year, next_month.month, 1) - timedelta(days=1) # Now match if it is yearly, i.e 2022 elif re.match(r"\d{4}$", in_date): @@ -140,9 +98,10 @@ def transform_query(params: Dict[str, Any]) -> OECDUnemploymentQueryParams: return OECDUnemploymentQueryParams(**transformed_params) + # pylint: disable=unused-argument @staticmethod def extract_data( - query: OECDUnemploymentQueryParams, # pylint: disable=W0613 + query: OECDUnemploymentQueryParams, credentials: Optional[Dict[str, str]], **kwargs: Any, ) -> Dict: @@ -157,23 +116,33 @@ def extract_data( "55-64": "Y55T64", }[query.age] seasonal_adjustment = "Y" if query.seasonal_adjustment else "N" - country = "" if query.country == "all" else country_to_code[query.country] + country = ( + "" + if query.country == "all" + else COUNTRY_TO_CODE_UNEMPLOYMENT[query.country] + ) url = "https://sdmx.oecd.org/public/rest/data/OECD.SDD.TPS,DSD_LFS@DF_IALFS_INDIC,1.0/.UNE_LF........" data = helpers.get_possibly_cached_data(url, function="economy_unemployment") - query = f"AGE=='{age}' & SEX=='{sex}' & FREQ=='{frequency}' & ADJUSTMENT=='{seasonal_adjustment}'" - query = query + f" & REF_AREA=='{country}'" if country else query + url_query = f"AGE=='{age}' & SEX=='{sex}' & FREQ=='{frequency}' & ADJUSTMENT=='{seasonal_adjustment}'" + url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query # Filter down data = ( - data.query(query) + data.query(url_query) .reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]] .rename( columns={"REF_AREA": "country", "TIME_PERIOD": "date", "VALUE": "value"} ) ) - data["country"] = data["country"].map(country_mapping) + data["country"] = data["country"].map(CODE_TO_COUNTRY_UNEMPLOYMENT) + + data = data.to_dict(orient="records") + start_date = query.start_date.strftime("%Y-%m-%d") # type: ignore + end_date = query.end_date.strftime("%Y-%m-%d") # type: ignore + data = list(filter(lambda x: start_date <= x["date"] <= end_date, data)) - return data.to_dict(orient="records") + return data + # pylint: disable=unused-argument @staticmethod def transform_data( query: OECDUnemploymentQueryParams, data: Dict, **kwargs: Any diff --git a/openbb_platform/providers/oecd/openbb_oecd/utils/constants.py b/openbb_platform/providers/oecd/openbb_oecd/utils/constants.py index af189d8d8579..f03885d5a835 100644 --- a/openbb_platform/providers/oecd/openbb_oecd/utils/constants.py +++ b/openbb_platform/providers/oecd/openbb_oecd/utils/constants.py @@ -1,3 +1,5 @@ +"""Constants for the OECD provider.""" + COUNTRY_TO_CODE_GDP = { "australia": "AUS", "austria": "AUT", @@ -454,3 +456,128 @@ "united_kingdom": "GBR", "united_states": "USA", } + +COUNTRY_TO_CODE_UNEMPLOYMENT = { + "colombia": "COL", + "new_zealand": "NZL", + "united_kingdom": "GBR", + "italy": "ITA", + "luxembourg": "LUX", + "euro_area19": "EA19", + "sweden": "SWE", + "oecd": "OECD", + "south_africa": "ZAF", + "denmark": "DNK", + "canada": "CAN", + "switzerland": "CHE", + "slovakia": "SVK", + "hungary": "HUN", + "portugal": "PRT", + "spain": "ESP", + "france": "FRA", + "czech_republic": "CZE", + "costa_rica": "CRI", + "japan": "JPN", + "slovenia": "SVN", + "russia": "RUS", + "austria": "AUT", + "latvia": "LVA", + "netherlands": "NLD", + "israel": "ISR", + "iceland": "ISL", + "united_states": "USA", + "ireland": "IRL", + "mexico": "MEX", + "germany": "DEU", + "greece": "GRC", + "turkey": "TUR", + "australia": "AUS", + "poland": "POL", + "south_korea": "KOR", + "chile": "CHL", + "finland": "FIN", + "european_union27_2020": "EU27_2020", + "norway": "NOR", + "lithuania": "LTU", + "euro_area20": "EA20", + "estonia": "EST", + "belgium": "BEL", + "brazil": "BRA", + "indonesia": "IDN", +} + +CODE_TO_COUNTRY_UNEMPLOYMENT = {v: k for k, v in COUNTRY_TO_CODE_UNEMPLOYMENT.items()} + +COUNTRY_TO_CODE_CLI = { + "united_states": "USA", + "united_kingdom": "GBR", + "japan": "JPN", + "mexico": "MEX", + "indonesia": "IDN", + "australia": "AUS", + "brazil": "BRA", + "canada": "CAN", + "italy": "ITA", + "germany": "DEU", + "turkey": "TUR", + "france": "FRA", + "south_africa": "ZAF", + "south_korea": "KOR", + "spain": "ESP", + "india": "IND", + "china": "CHN", + "g7": "G7", + "g20": "G20", +} + +CODE_TO_COUNTRY_CLI = {v: k for k, v in COUNTRY_TO_CODE_CLI.items()} + +COUNTRY_TO_CODE_IR = { + "belgium": "BEL", + "ireland": "IRL", + "mexico": "MEX", + "indonesia": "IDN", + "new_zealand": "NZL", + "japan": "JPN", + "united_kingdom": "GBR", + "france": "FRA", + "chile": "CHL", + "canada": "CAN", + "netherlands": "NLD", + "united_states": "USA", + "south_korea": "KOR", + "norway": "NOR", + "austria": "AUT", + "south_africa": "ZAF", + "denmark": "DNK", + "switzerland": "CHE", + "hungary": "HUN", + "luxembourg": "LUX", + "australia": "AUS", + "germany": "DEU", + "sweden": "SWE", + "iceland": "ISL", + "turkey": "TUR", + "greece": "GRC", + "israel": "ISR", + "czech_republic": "CZE", + "latvia": "LVA", + "slovenia": "SVN", + "poland": "POL", + "estonia": "EST", + "lithuania": "LTU", + "portugal": "PRT", + "costa_rica": "CRI", + "slovakia": "SVK", + "finland": "FIN", + "spain": "ESP", + "russia": "RUS", + "euro_area19": "EA19", + "colombia": "COL", + "italy": "ITA", + "india": "IND", + "china": "CHN", + "croatia": "HRV", +} + +CODE_TO_COUNTRY_IR = {v: k for k, v in COUNTRY_TO_CODE_IR.items()}