Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Filter OECD data using start_date and end_date parameters #6144

Merged
merged 6 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,11 @@
CLIQueryParams,
)
from openbb_oecd.utils import helpers
from openbb_oecd.utils.constants import CODE_TO_COUNTRY_CLI, COUNTRY_TO_CODE_CLI
from pydantic import Field, field_validator

cli_mapping = {
"USA": "united_states",
"GBR": "united_kingdom",
"JPN": "japan",
"MEX": "mexico",
"IDN": "indonesia",
"AUS": "australia",
"BRA": "brazil",
"CAN": "canada",
"ITA": "italy",
"DEU": "germany",
"TUR": "turkey",
"FRA": "france",
"ZAF": "south_africa",
"KOR": "south_korea",
"ESP": "spain",
"IND": "india",
"CHN": "china",
"G7": "g7",
"G20": "g20",
}


countries = tuple(cli_mapping.values()) + ("all",)
countries = tuple(CODE_TO_COUNTRY_CLI.values()) + ("all",)
CountriesLiteral = Literal[countries] # type: ignore
country_to_code = {v: k for k, v in cli_mapping.items()}


class OECDCLIQueryParams(CLIQueryParams):
Expand Down Expand Up @@ -70,10 +47,10 @@ def date_validate(cls, in_date: Union[date, str]): # pylint: disable=E0213
return date(_year, 12, 31)
# Now match if it is monthly, i.e 2022-01
elif re.match(r"\d{4}-\d{2}$", in_date):
year, month = map(int, in_date.split("-"))
year, month = map(int, in_date.split("-")) # type: ignore
if month == 12:
return date(year, month, 31)
next_month = date(year, month + 1, 1)
return date(year, month, 31) # type: ignore
next_month = date(year, month + 1, 1) # type: ignore
return date(next_month.year, next_month.month, 1) - timedelta(days=1)
# Now match if it is yearly, i.e 2022
elif re.match(r"\d{4}$", in_date):
Expand All @@ -99,9 +76,10 @@ def transform_query(params: Dict[str, Any]) -> OECDCLIQueryParams:

return OECDCLIQueryParams(**transformed_params)

# pylint: disable=unused-argument
@staticmethod
def extract_data(
query: OECDCLIQueryParams, # pylint: disable=W0613
query: OECDCLIQueryParams,
credentials: Optional[Dict[str, str]],
**kwargs: Any,
) -> Dict:
Expand All @@ -112,19 +90,27 @@ def extract_data(
)

if query.country != "all":
data = data.query(f"REF_AREA == '{country_to_code[query.country]}'")
data = data.query(f"REF_AREA == '{COUNTRY_TO_CODE_CLI[query.country]}'")

# Filter down
data = data.reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]].rename(
columns={"REF_AREA": "country", "TIME_PERIOD": "date", "VALUE": "value"}
)
data["country"] = data["country"].map(cli_mapping)
data["country"] = data["country"].map(CODE_TO_COUNTRY_CLI)

return data.to_dict(orient="records")
data = data.to_dict(orient="records")
start_date = query.start_date.strftime("%Y-%m-%d") # type: ignore
end_date = query.end_date.strftime("%Y-%m-%d") # type: ignore
data = list(filter(lambda x: start_date <= x["date"] <= end_date, data))

return data

# pylint: disable=unused-argument
@staticmethod
def transform_data(
query: OECDCLIQueryParams, data: Dict, **kwargs: Any
query: OECDCLIQueryParams,
data: Dict,
**kwargs: Any,
) -> List[OECDCLIData]:
"""Transform the data from the OECD endpoint."""
return [OECDCLIData.model_validate(d) for d in data]
Original file line number Diff line number Diff line change
Expand Up @@ -12,59 +12,11 @@
LTIRQueryParams,
)
from openbb_oecd.utils import helpers
from openbb_oecd.utils.constants import CODE_TO_COUNTRY_IR, COUNTRY_TO_CODE_IR
from pydantic import Field, field_validator

ltir_mapping = {
"BEL": "belgium",
"IRL": "ireland",
"MEX": "mexico",
"IDN": "indonesia",
"NZL": "new_zealand",
"JPN": "japan",
"GBR": "united_kingdom",
"FRA": "france",
"CHL": "chile",
"CAN": "canada",
"NLD": "netherlands",
"USA": "united_states",
"KOR": "south_korea",
"NOR": "norway",
"AUT": "austria",
"ZAF": "south_africa",
"DNK": "denmark",
"CHE": "switzerland",
"HUN": "hungary",
"LUX": "luxembourg",
"AUS": "australia",
"DEU": "germany",
"SWE": "sweden",
"ISL": "iceland",
"TUR": "turkey",
"GRC": "greece",
"ISR": "israel",
"CZE": "czech_republic",
"LVA": "latvia",
"SVN": "slovenia",
"POL": "poland",
"EST": "estonia",
"LTU": "lithuania",
"PRT": "portugal",
"CRI": "costa_rica",
"SVK": "slovakia",
"FIN": "finland",
"ESP": "spain",
"RUS": "russia",
"EA19": "euro_area19",
"COL": "colombia",
"ITA": "italy",
"IND": "india",
"CHN": "china",
"HRV": "croatia",
}

countries = tuple(ltir_mapping.values()) + ("all",)
countries = tuple(CODE_TO_COUNTRY_IR.values()) + ("all",)
CountriesLiteral = Literal[countries] # type: ignore
country_to_code = {v: k for k, v in ltir_mapping.items()}


class OECDLTIRQueryParams(LTIRQueryParams):
Expand Down Expand Up @@ -101,10 +53,10 @@ def date_validate(cls, in_date: Union[date, str]): # pylint: disable=E0213
return date(_year, 12, 31)
# Now match if it is monthly, i.e 2022-01
elif re.match(r"\d{4}-\d{2}$", in_date):
year, month = map(int, in_date.split("-"))
year, month = map(int, in_date.split("-")) # type: ignore
if month == 12:
return date(year, month, 31)
next_month = date(year, month + 1, 1)
return date(year, month, 31) # type: ignore
next_month = date(year, month + 1, 1) # type: ignore
return date(next_month.year, next_month.month, 1) - timedelta(days=1)
# Now match if it is yearly, i.e 2022
elif re.match(r"\d{4}$", in_date):
Expand Down Expand Up @@ -138,24 +90,30 @@ def extract_data(
) -> Dict:
"""Return the raw data from the OECD endpoint."""
frequency = query.frequency[0].upper()
country = "" if query.country == "all" else country_to_code[query.country]
country = "" if query.country == "all" else COUNTRY_TO_CODE_IR[query.country]
url = "https://sdmx.oecd.org/public/rest/data/OECD.SDD.STES,DSD_KEI@DF_KEI,4.0/..IRLT...."
data = helpers.get_possibly_cached_data(
url, function="economy_long_term_interest_rate"
)
query = f"FREQ=='{frequency}'"
query = query + f" & REF_AREA=='{country}'" if country else query
url_query = f"FREQ=='{frequency}'"
url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query
# Filter down
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So here, we would normally want these sections in the "transform_data", correct?

data = (
data.query(query)
data.query(url_query)
.reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]]
.rename(
columns={"REF_AREA": "country", "TIME_PERIOD": "date", "VALUE": "value"}
)
)
data["country"] = data["country"].map(ltir_mapping)
data["country"] = data["country"].map(CODE_TO_COUNTRY_IR)
data = data.fillna("N/A").replace("N/A", None)
return data.to_dict(orient="records")
data = data.to_dict(orient="records")

start_date = query.start_date.strftime("%Y-%m-%d") # type: ignore
end_date = query.end_date.strftime("%Y-%m-%d") # type: ignore
data = list(filter(lambda x: start_date <= x["date"] <= end_date, data))

return data

@staticmethod
def transform_data(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,59 +12,11 @@
STIRQueryParams,
)
from openbb_oecd.utils import helpers
from openbb_oecd.utils.constants import CODE_TO_COUNTRY_IR, COUNTRY_TO_CODE_IR
from pydantic import Field, field_validator

stir_mapping = {
"BEL": "belgium",
"IRL": "ireland",
"MEX": "mexico",
"IDN": "indonesia",
"NZL": "new_zealand",
"JPN": "japan",
"GBR": "united_kingdom",
"FRA": "france",
"CHL": "chile",
"CAN": "canada",
"NLD": "netherlands",
"USA": "united_states",
"KOR": "south_korea",
"NOR": "norway",
"AUT": "austria",
"ZAF": "south_africa",
"DNK": "denmark",
"CHE": "switzerland",
"HUN": "hungary",
"LUX": "luxembourg",
"AUS": "australia",
"DEU": "germany",
"SWE": "sweden",
"ISL": "iceland",
"TUR": "turkey",
"GRC": "greece",
"ISR": "israel",
"CZE": "czech_republic",
"LVA": "latvia",
"SVN": "slovenia",
"POL": "poland",
"EST": "estonia",
"LTU": "lithuania",
"PRT": "portugal",
"CRI": "costa_rica",
"SVK": "slovakia",
"FIN": "finland",
"ESP": "spain",
"RUS": "russia",
"EA19": "euro_area19",
"COL": "colombia",
"ITA": "italy",
"IND": "india",
"CHN": "china",
"HRV": "croatia",
}

countries = tuple(stir_mapping.values()) + ("all",)
countries = tuple(CODE_TO_COUNTRY_IR.values()) + ("all",)
CountriesLiteral = Literal[countries] # type: ignore
country_to_code = {v: k for k, v in stir_mapping.items()}


class OECDSTIRQueryParams(STIRQueryParams):
Expand Down Expand Up @@ -101,10 +53,10 @@ def date_validate(cls, in_date: Union[date, str]): # pylint: disable=E0213
return date(_year, 12, 31)
# Now match if it is monthly, i.e 2022-01
elif re.match(r"\d{4}-\d{2}$", in_date):
year, month = map(int, in_date.split("-"))
year, month = map(int, in_date.split("-")) # type: ignore
if month == 12:
return date(year, month, 31)
next_month = date(year, month + 1, 1)
return date(year, month, 31) # type: ignore
next_month = date(year, month + 1, 1) # type: ignore
return date(next_month.year, next_month.month, 1) - timedelta(days=1)
# Now match if it is yearly, i.e 2022
elif re.match(r"\d{4}$", in_date):
Expand Down Expand Up @@ -138,24 +90,30 @@ def extract_data(
) -> Dict:
"""Return the raw data from the OECD endpoint."""
frequency = query.frequency[0].upper()
country = "" if query.country == "all" else country_to_code[query.country]
country = "" if query.country == "all" else COUNTRY_TO_CODE_IR[query.country]
url = "https://sdmx.oecd.org/public/rest/data/OECD.SDD.STES,DSD_KEI@DF_KEI,4.0/..IR3TIB...."
data = helpers.get_possibly_cached_data(
url, function="economy_short_term_interest_rate"
)
query = f"FREQ=='{frequency}'"
query = query + f" & REF_AREA=='{country}'" if country else query
url_query = f"FREQ=='{frequency}'"
url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query
# Filter down
data = (
data.query(query)
data.query(url_query)
.reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]]
.rename(
columns={"REF_AREA": "country", "TIME_PERIOD": "date", "VALUE": "value"}
)
)
data["country"] = data["country"].map(stir_mapping)
data["country"] = data["country"].map(CODE_TO_COUNTRY_IR)
data = data.fillna("N/A").replace("N/A", None)
return data.to_dict(orient="records")
data = data.to_dict(orient="records")

start_date = query.start_date.strftime("%Y-%m-%d") # type: ignore
end_date = query.end_date.strftime("%Y-%m-%d") # type: ignore
data = list(filter(lambda x: start_date <= x["date"] <= end_date, data))

return data

@staticmethod
def transform_data(
Expand Down
Loading
Loading