Skip to content

Commit

Permalink
[BugFix] Fix Limit Param For yFinance Financials (#6962)
Browse files Browse the repository at this point in the history
* fix limit param for yfinance financials

* missing decorator
  • Loading branch information
deeleeramone authored Nov 25, 2024
1 parent 7aa7e09 commit 1267573
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 46 deletions.
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
"""Yahoo Finance Balance Sheet Model."""

import json
# pylint: disable=unused-argument

from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Literal, Optional

from openbb_core.provider.abstract.fetcher import Fetcher
from openbb_core.provider.standard_models.balance_sheet import (
BalanceSheetData,
BalanceSheetQueryParams,
)
from openbb_core.provider.utils.descriptions import QUERY_DESCRIPTIONS
from openbb_core.provider.utils.errors import EmptyDataError
from openbb_core.provider.utils.helpers import to_snake_case
from pydantic import Field, field_validator


Expand All @@ -27,10 +26,15 @@ class YFinanceBalanceSheetQueryParams(BalanceSheetQueryParams):
}
}

period: Literal["annual", "quarter", "ttm"] = Field(
period: Literal["annual", "quarter"] = Field(
default="annual",
description=QUERY_DESCRIPTIONS.get("period", ""),
)
limit: Optional[int] = Field(
default=5,
description=QUERY_DESCRIPTIONS.get("limit", ""),
le=5,
)


class YFinanceBalanceSheetData(BalanceSheetData):
Expand Down Expand Up @@ -59,35 +63,41 @@ def date_validate(cls, v): # pylint: disable=E0213
class YFinanceBalanceSheetFetcher(
Fetcher[
YFinanceBalanceSheetQueryParams,
List[YFinanceBalanceSheetData],
list[YFinanceBalanceSheetData],
]
):
"""Transform the query, extract and transform the data from the Yahoo Finance endpoints."""
"""Yahoo Finance Balance Sheet Fetcher."""

@staticmethod
def transform_query(params: Dict[str, Any]) -> YFinanceBalanceSheetQueryParams:
def transform_query(params: dict[str, Any]) -> YFinanceBalanceSheetQueryParams:
"""Transform the query parameters."""
return YFinanceBalanceSheetQueryParams(**params)

@staticmethod
def extract_data(
# pylint: disable=unused-argument
query: YFinanceBalanceSheetQueryParams,
credentials: Optional[Dict[str, str]],
credentials: Optional[dict[str, str]],
**kwargs: Any,
) -> List[Dict]:
) -> list[dict]:
"""Extract the data from the Yahoo Finance endpoints."""
from yfinance import Ticker # pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel
import json # noqa
from numpy import nan
from openbb_core.provider.utils.errors import EmptyDataError
from openbb_core.provider.utils.helpers import to_snake_case
from yfinance import Ticker

period = "yearly" if query.period == "annual" else "quarterly" # type: ignore
data = Ticker(query.symbol).get_balance_sheet(
as_dict=False, pretty=False, freq=period
)
if data is None:
raise EmptyDataError()
if query.limit:
data = data.iloc[:, : query.limit]
data.index = [to_snake_case(i) for i in data.index]
data = data.reset_index().sort_index(ascending=False).set_index("index")
data = data.fillna("N/A").replace("N/A", None).to_dict()
data = data.replace({nan: None}).to_dict()
data = [{"period_ending": str(key), **value} for key, value in data.items()]

data = json.loads(json.dumps(data))
Expand All @@ -97,8 +107,8 @@ def extract_data(
@staticmethod
def transform_data(
query: YFinanceBalanceSheetQueryParams,
data: List[Dict],
data: list[dict],
**kwargs: Any,
) -> List[YFinanceBalanceSheetData]:
) -> list[YFinanceBalanceSheetData]:
"""Transform the data."""
return [YFinanceBalanceSheetData.model_validate(d) for d in data]
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
"""Yahoo Finance Cash Flow Statement Model."""

import json
# pylint: disable=unused-argument

from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Literal, Optional

from openbb_core.provider.abstract.fetcher import Fetcher
from openbb_core.provider.standard_models.cash_flow import (
CashFlowStatementData,
CashFlowStatementQueryParams,
)
from openbb_core.provider.utils.descriptions import QUERY_DESCRIPTIONS
from openbb_core.provider.utils.errors import EmptyDataError
from openbb_core.provider.utils.helpers import to_snake_case
from pydantic import Field, field_validator


Expand All @@ -31,6 +30,11 @@ class YFinanceCashFlowStatementQueryParams(CashFlowStatementQueryParams):
default="annual",
description=QUERY_DESCRIPTIONS.get("period", ""),
)
limit: Optional[int] = Field(
default=5,
description=QUERY_DESCRIPTIONS.get("limit", ""),
le=5,
)


class YFinanceCashFlowStatementData(CashFlowStatementData):
Expand All @@ -46,7 +50,7 @@ class YFinanceCashFlowStatementData(CashFlowStatementData):

@field_validator("period_ending", mode="before", check_fields=False)
@classmethod
def date_validate(cls, v): # pylint: disable=E0213
def date_validate(cls, v):
"""Return datetime object from string."""
if isinstance(v, str):
return datetime.strptime(v, "%Y-%m-%d %H:%M:%S").date()
Expand All @@ -56,25 +60,29 @@ def date_validate(cls, v): # pylint: disable=E0213
class YFinanceCashFlowStatementFetcher(
Fetcher[
YFinanceCashFlowStatementQueryParams,
List[YFinanceCashFlowStatementData],
list[YFinanceCashFlowStatementData],
]
):
"""Transform the query, extract and transform the data from the Yahoo Finance endpoints."""
"""Yahoo Finance Cash Flow Statement Fetcher."""

@staticmethod
def transform_query(params: Dict[str, Any]) -> YFinanceCashFlowStatementQueryParams:
def transform_query(params: dict[str, Any]) -> YFinanceCashFlowStatementQueryParams:
"""Transform the query parameters."""
return YFinanceCashFlowStatementQueryParams(**params)

@staticmethod
def extract_data(
# pylint: disable=unused-argument
query: YFinanceCashFlowStatementQueryParams,
credentials: Optional[Dict[str, str]],
credentials: Optional[dict[str, str]],
**kwargs: Any,
) -> List[YFinanceCashFlowStatementData]:
) -> list[YFinanceCashFlowStatementData]:
"""Extract the data from the Yahoo Finance endpoints."""
from yfinance import Ticker # pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel
import json # noqa
from numpy import nan
from openbb_core.provider.utils.errors import EmptyDataError
from openbb_core.provider.utils.helpers import to_snake_case
from yfinance import Ticker

period = "yearly" if query.period == "annual" else "quarterly" # type: ignore
data = Ticker(query.symbol).get_cash_flow(
Expand All @@ -84,9 +92,12 @@ def extract_data(
if data is None:
raise EmptyDataError()

if query.limit:
data = data.iloc[:, : query.limit]

data.index = [to_snake_case(i) for i in data.index]
data = data.reset_index().sort_index(ascending=False).set_index("index")
data = data.fillna("N/A").replace("N/A", None).to_dict()
data = data.replace({nan: None}).to_dict()
data = [{"period_ending": str(key), **value} for key, value in data.items()]

data = json.loads(json.dumps(data))
Expand All @@ -95,10 +106,9 @@ def extract_data(

@staticmethod
def transform_data(
# pylint: disable=unused-argument
query: YFinanceCashFlowStatementQueryParams,
data: List[Dict],
data: list[dict],
**kwargs: Any,
) -> List[YFinanceCashFlowStatementData]:
) -> list[YFinanceCashFlowStatementData]:
"""Transform the data."""
return [YFinanceCashFlowStatementData.model_validate(d) for d in data]
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
"""Yahoo Finance Income Statement Model."""

import json
# pylint: disable=unused-argument

from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Literal, Optional

from openbb_core.provider.abstract.fetcher import Fetcher
from openbb_core.provider.standard_models.income_statement import (
IncomeStatementData,
IncomeStatementQueryParams,
)
from openbb_core.provider.utils.descriptions import QUERY_DESCRIPTIONS
from openbb_core.provider.utils.errors import EmptyDataError
from openbb_core.provider.utils.helpers import to_snake_case
from pydantic import Field, field_validator


Expand All @@ -31,6 +30,11 @@ class YFinanceIncomeStatementQueryParams(IncomeStatementQueryParams):
default="annual",
description=QUERY_DESCRIPTIONS.get("period", ""),
)
limit: Optional[int] = Field(
default=5,
description=QUERY_DESCRIPTIONS.get("limit", ""),
le=5,
)


class YFinanceIncomeStatementData(IncomeStatementData):
Expand All @@ -48,7 +52,8 @@ class YFinanceIncomeStatementData(IncomeStatementData):
}

@field_validator("period_ending", mode="before", check_fields=False)
def date_validate(cls, v): # pylint: disable=E0213
@classmethod
def date_validate(cls, v):
"""Validate the date field."""
if isinstance(v, str):
return datetime.strptime(v, "%Y-%m-%d %H:%M:%S").date()
Expand All @@ -58,35 +63,44 @@ def date_validate(cls, v): # pylint: disable=E0213
class YFinanceIncomeStatementFetcher(
Fetcher[
YFinanceIncomeStatementQueryParams,
List[YFinanceIncomeStatementData],
list[YFinanceIncomeStatementData],
]
):
"""Transform the query, extract and transform the data from the Yahoo Finance endpoints."""
"""Yahoo Finance Income Statement Fetcher."""

@staticmethod
def transform_query(params: Dict[str, Any]) -> YFinanceIncomeStatementQueryParams:
def transform_query(params: dict[str, Any]) -> YFinanceIncomeStatementQueryParams:
"""Transform the query parameters."""
return YFinanceIncomeStatementQueryParams(**params)

@staticmethod
def extract_data(
# pylint: disable=unused-argument
query: YFinanceIncomeStatementQueryParams,
credentials: Optional[Dict[str, str]],
credentials: Optional[dict[str, str]],
**kwargs: Any,
) -> List[YFinanceIncomeStatementData]:
) -> list[YFinanceIncomeStatementData]:
"""Extract the data from the Yahoo Finance endpoints."""
from yfinance import Ticker # pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel
import json # noqa
from numpy import nan
from openbb_core.provider.utils.errors import EmptyDataError
from openbb_core.provider.utils.helpers import to_snake_case
from yfinance import Ticker

period = "yearly" if query.period == "annual" else "quarterly"
data = Ticker(query.symbol).get_income_stmt(
as_dict=False, pretty=False, freq=period
)

if data is None:
raise EmptyDataError()

if query.limit:
data = data.iloc[:, : query.limit]

data.index = [to_snake_case(i) for i in data.index]
data = data.reset_index().sort_index(ascending=False).set_index("index")
data = data.fillna("N/A").replace("N/A", None).to_dict()
data = data.replace({nan: None}).to_dict()
data = [{"period_ending": str(key), **value} for key, value in data.items()]

data = json.loads(json.dumps(data))
Expand All @@ -96,8 +110,8 @@ def extract_data(
@staticmethod
def transform_data(
query: YFinanceIncomeStatementQueryParams,
data: List[Dict],
data: list[dict],
**kwargs: Any,
) -> List[YFinanceIncomeStatementData]:
) -> list[YFinanceIncomeStatementData]:
"""Transform the data."""
return [YFinanceIncomeStatementData.model_validate(d) for d in data]

0 comments on commit 1267573

Please sign in to comment.