From 12675734429e03cb05773c99a322c4cfeacc9f13 Mon Sep 17 00:00:00 2001 From: Danglewood <85772166+deeleeramone@users.noreply.github.com> Date: Mon, 25 Nov 2024 10:59:32 -0800 Subject: [PATCH] [BugFix] Fix Limit Param For yFinance Financials (#6962) * fix limit param for yfinance financials * missing decorator --- .../openbb_yfinance/models/balance_sheet.py | 40 ++++++++++------- .../openbb_yfinance/models/cash_flow.py | 42 +++++++++++------- .../models/income_statement.py | 44 ++++++++++++------- 3 files changed, 80 insertions(+), 46 deletions(-) diff --git a/openbb_platform/providers/yfinance/openbb_yfinance/models/balance_sheet.py b/openbb_platform/providers/yfinance/openbb_yfinance/models/balance_sheet.py index 75c94f411e57..2e60294f39ec 100644 --- a/openbb_platform/providers/yfinance/openbb_yfinance/models/balance_sheet.py +++ b/openbb_platform/providers/yfinance/openbb_yfinance/models/balance_sheet.py @@ -1,8 +1,9 @@ """Yahoo Finance Balance Sheet Model.""" -import json +# pylint: disable=unused-argument + from datetime import datetime -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Literal, Optional from openbb_core.provider.abstract.fetcher import Fetcher from openbb_core.provider.standard_models.balance_sheet import ( @@ -10,8 +11,6 @@ BalanceSheetQueryParams, ) from openbb_core.provider.utils.descriptions import QUERY_DESCRIPTIONS -from openbb_core.provider.utils.errors import EmptyDataError -from openbb_core.provider.utils.helpers import to_snake_case from pydantic import Field, field_validator @@ -27,10 +26,15 @@ class YFinanceBalanceSheetQueryParams(BalanceSheetQueryParams): } } - period: Literal["annual", "quarter", "ttm"] = Field( + period: Literal["annual", "quarter"] = Field( default="annual", description=QUERY_DESCRIPTIONS.get("period", ""), ) + limit: Optional[int] = Field( + default=5, + description=QUERY_DESCRIPTIONS.get("limit", ""), + le=5, + ) class YFinanceBalanceSheetData(BalanceSheetData): @@ -59,25 +63,29 @@ def date_validate(cls, v): # pylint: disable=E0213 class YFinanceBalanceSheetFetcher( Fetcher[ YFinanceBalanceSheetQueryParams, - List[YFinanceBalanceSheetData], + list[YFinanceBalanceSheetData], ] ): - """Transform the query, extract and transform the data from the Yahoo Finance endpoints.""" + """Yahoo Finance Balance Sheet Fetcher.""" @staticmethod - def transform_query(params: Dict[str, Any]) -> YFinanceBalanceSheetQueryParams: + def transform_query(params: dict[str, Any]) -> YFinanceBalanceSheetQueryParams: """Transform the query parameters.""" return YFinanceBalanceSheetQueryParams(**params) @staticmethod def extract_data( - # pylint: disable=unused-argument query: YFinanceBalanceSheetQueryParams, - credentials: Optional[Dict[str, str]], + credentials: Optional[dict[str, str]], **kwargs: Any, - ) -> List[Dict]: + ) -> list[dict]: """Extract the data from the Yahoo Finance endpoints.""" - from yfinance import Ticker # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import json # noqa + from numpy import nan + from openbb_core.provider.utils.errors import EmptyDataError + from openbb_core.provider.utils.helpers import to_snake_case + from yfinance import Ticker period = "yearly" if query.period == "annual" else "quarterly" # type: ignore data = Ticker(query.symbol).get_balance_sheet( @@ -85,9 +93,11 @@ def extract_data( ) if data is None: raise EmptyDataError() + if query.limit: + data = data.iloc[:, : query.limit] data.index = [to_snake_case(i) for i in data.index] data = data.reset_index().sort_index(ascending=False).set_index("index") - data = data.fillna("N/A").replace("N/A", None).to_dict() + data = data.replace({nan: None}).to_dict() data = [{"period_ending": str(key), **value} for key, value in data.items()] data = json.loads(json.dumps(data)) @@ -97,8 +107,8 @@ def extract_data( @staticmethod def transform_data( query: YFinanceBalanceSheetQueryParams, - data: List[Dict], + data: list[dict], **kwargs: Any, - ) -> List[YFinanceBalanceSheetData]: + ) -> list[YFinanceBalanceSheetData]: """Transform the data.""" return [YFinanceBalanceSheetData.model_validate(d) for d in data] diff --git a/openbb_platform/providers/yfinance/openbb_yfinance/models/cash_flow.py b/openbb_platform/providers/yfinance/openbb_yfinance/models/cash_flow.py index 523bc1926685..da27d3a3db13 100644 --- a/openbb_platform/providers/yfinance/openbb_yfinance/models/cash_flow.py +++ b/openbb_platform/providers/yfinance/openbb_yfinance/models/cash_flow.py @@ -1,8 +1,9 @@ """Yahoo Finance Cash Flow Statement Model.""" -import json +# pylint: disable=unused-argument + from datetime import datetime -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Literal, Optional from openbb_core.provider.abstract.fetcher import Fetcher from openbb_core.provider.standard_models.cash_flow import ( @@ -10,8 +11,6 @@ CashFlowStatementQueryParams, ) from openbb_core.provider.utils.descriptions import QUERY_DESCRIPTIONS -from openbb_core.provider.utils.errors import EmptyDataError -from openbb_core.provider.utils.helpers import to_snake_case from pydantic import Field, field_validator @@ -31,6 +30,11 @@ class YFinanceCashFlowStatementQueryParams(CashFlowStatementQueryParams): default="annual", description=QUERY_DESCRIPTIONS.get("period", ""), ) + limit: Optional[int] = Field( + default=5, + description=QUERY_DESCRIPTIONS.get("limit", ""), + le=5, + ) class YFinanceCashFlowStatementData(CashFlowStatementData): @@ -46,7 +50,7 @@ class YFinanceCashFlowStatementData(CashFlowStatementData): @field_validator("period_ending", mode="before", check_fields=False) @classmethod - def date_validate(cls, v): # pylint: disable=E0213 + def date_validate(cls, v): """Return datetime object from string.""" if isinstance(v, str): return datetime.strptime(v, "%Y-%m-%d %H:%M:%S").date() @@ -56,25 +60,29 @@ def date_validate(cls, v): # pylint: disable=E0213 class YFinanceCashFlowStatementFetcher( Fetcher[ YFinanceCashFlowStatementQueryParams, - List[YFinanceCashFlowStatementData], + list[YFinanceCashFlowStatementData], ] ): - """Transform the query, extract and transform the data from the Yahoo Finance endpoints.""" + """Yahoo Finance Cash Flow Statement Fetcher.""" @staticmethod - def transform_query(params: Dict[str, Any]) -> YFinanceCashFlowStatementQueryParams: + def transform_query(params: dict[str, Any]) -> YFinanceCashFlowStatementQueryParams: """Transform the query parameters.""" return YFinanceCashFlowStatementQueryParams(**params) @staticmethod def extract_data( - # pylint: disable=unused-argument query: YFinanceCashFlowStatementQueryParams, - credentials: Optional[Dict[str, str]], + credentials: Optional[dict[str, str]], **kwargs: Any, - ) -> List[YFinanceCashFlowStatementData]: + ) -> list[YFinanceCashFlowStatementData]: """Extract the data from the Yahoo Finance endpoints.""" - from yfinance import Ticker # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import json # noqa + from numpy import nan + from openbb_core.provider.utils.errors import EmptyDataError + from openbb_core.provider.utils.helpers import to_snake_case + from yfinance import Ticker period = "yearly" if query.period == "annual" else "quarterly" # type: ignore data = Ticker(query.symbol).get_cash_flow( @@ -84,9 +92,12 @@ def extract_data( if data is None: raise EmptyDataError() + if query.limit: + data = data.iloc[:, : query.limit] + data.index = [to_snake_case(i) for i in data.index] data = data.reset_index().sort_index(ascending=False).set_index("index") - data = data.fillna("N/A").replace("N/A", None).to_dict() + data = data.replace({nan: None}).to_dict() data = [{"period_ending": str(key), **value} for key, value in data.items()] data = json.loads(json.dumps(data)) @@ -95,10 +106,9 @@ def extract_data( @staticmethod def transform_data( - # pylint: disable=unused-argument query: YFinanceCashFlowStatementQueryParams, - data: List[Dict], + data: list[dict], **kwargs: Any, - ) -> List[YFinanceCashFlowStatementData]: + ) -> list[YFinanceCashFlowStatementData]: """Transform the data.""" return [YFinanceCashFlowStatementData.model_validate(d) for d in data] diff --git a/openbb_platform/providers/yfinance/openbb_yfinance/models/income_statement.py b/openbb_platform/providers/yfinance/openbb_yfinance/models/income_statement.py index 825cb38d82b0..7383b3112855 100644 --- a/openbb_platform/providers/yfinance/openbb_yfinance/models/income_statement.py +++ b/openbb_platform/providers/yfinance/openbb_yfinance/models/income_statement.py @@ -1,8 +1,9 @@ """Yahoo Finance Income Statement Model.""" -import json +# pylint: disable=unused-argument + from datetime import datetime -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Literal, Optional from openbb_core.provider.abstract.fetcher import Fetcher from openbb_core.provider.standard_models.income_statement import ( @@ -10,8 +11,6 @@ IncomeStatementQueryParams, ) from openbb_core.provider.utils.descriptions import QUERY_DESCRIPTIONS -from openbb_core.provider.utils.errors import EmptyDataError -from openbb_core.provider.utils.helpers import to_snake_case from pydantic import Field, field_validator @@ -31,6 +30,11 @@ class YFinanceIncomeStatementQueryParams(IncomeStatementQueryParams): default="annual", description=QUERY_DESCRIPTIONS.get("period", ""), ) + limit: Optional[int] = Field( + default=5, + description=QUERY_DESCRIPTIONS.get("limit", ""), + le=5, + ) class YFinanceIncomeStatementData(IncomeStatementData): @@ -48,7 +52,8 @@ class YFinanceIncomeStatementData(IncomeStatementData): } @field_validator("period_ending", mode="before", check_fields=False) - def date_validate(cls, v): # pylint: disable=E0213 + @classmethod + def date_validate(cls, v): """Validate the date field.""" if isinstance(v, str): return datetime.strptime(v, "%Y-%m-%d %H:%M:%S").date() @@ -58,35 +63,44 @@ def date_validate(cls, v): # pylint: disable=E0213 class YFinanceIncomeStatementFetcher( Fetcher[ YFinanceIncomeStatementQueryParams, - List[YFinanceIncomeStatementData], + list[YFinanceIncomeStatementData], ] ): - """Transform the query, extract and transform the data from the Yahoo Finance endpoints.""" + """Yahoo Finance Income Statement Fetcher.""" @staticmethod - def transform_query(params: Dict[str, Any]) -> YFinanceIncomeStatementQueryParams: + def transform_query(params: dict[str, Any]) -> YFinanceIncomeStatementQueryParams: """Transform the query parameters.""" return YFinanceIncomeStatementQueryParams(**params) @staticmethod def extract_data( - # pylint: disable=unused-argument query: YFinanceIncomeStatementQueryParams, - credentials: Optional[Dict[str, str]], + credentials: Optional[dict[str, str]], **kwargs: Any, - ) -> List[YFinanceIncomeStatementData]: + ) -> list[YFinanceIncomeStatementData]: """Extract the data from the Yahoo Finance endpoints.""" - from yfinance import Ticker # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import json # noqa + from numpy import nan + from openbb_core.provider.utils.errors import EmptyDataError + from openbb_core.provider.utils.helpers import to_snake_case + from yfinance import Ticker period = "yearly" if query.period == "annual" else "quarterly" data = Ticker(query.symbol).get_income_stmt( as_dict=False, pretty=False, freq=period ) + if data is None: raise EmptyDataError() + + if query.limit: + data = data.iloc[:, : query.limit] + data.index = [to_snake_case(i) for i in data.index] data = data.reset_index().sort_index(ascending=False).set_index("index") - data = data.fillna("N/A").replace("N/A", None).to_dict() + data = data.replace({nan: None}).to_dict() data = [{"period_ending": str(key), **value} for key, value in data.items()] data = json.loads(json.dumps(data)) @@ -96,8 +110,8 @@ def extract_data( @staticmethod def transform_data( query: YFinanceIncomeStatementQueryParams, - data: List[Dict], + data: list[dict], **kwargs: Any, - ) -> List[YFinanceIncomeStatementData]: + ) -> list[YFinanceIncomeStatementData]: """Transform the data.""" return [YFinanceIncomeStatementData.model_validate(d) for d in data]