Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] - Test parametrize skips charting tests #6264

Merged
merged 25 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/linting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ jobs:
# Run linters for openbb_platform
if [ -n "${{ env.platform_files }}" ]; then
pylint ${{ env.platform_files }}
mypy ${{ env.platform_files }} --ignore-missing-imports --check-untyped-defs
mypy ${{ env.platform_files }} --ignore-missing-imports --scripts-are-modules --check-untyped-defs
else
echo "No Python files changed in openbb_platform"
fi
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ repos:
entry: mypy
language: python
"types_or": [python, pyi]
args: ["--ignore-missing-imports", "--scripts-are-modules"]
args: ["--ignore-missing-imports", "--scripts-are-modules", "--check-untyped-defs"]
require_serial: true
- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
Expand Down
8 changes: 6 additions & 2 deletions openbb_platform/extensions/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Custom pytest configuration for the extensions."""

from typing import Any, Dict, List
from typing import Dict, List

import pytest
from openbb_core.app.router import CommandMap
Expand All @@ -13,7 +13,7 @@
# ruff: noqa: SIM114


def parametrize(argnames: str, argvalues: List[Dict[str, Any]], **kwargs):
def parametrize(argnames: str, argvalues: List, **kwargs):
"""Custom parametrize decorator that filters test cases based on the environment."""

routers, providers, obbject_ext = list_openbb_extensions()
Expand Down Expand Up @@ -49,6 +49,10 @@ def decorator(function):
elif "provider" not in args and function_name_v3 in commands:
# Handle edge case
filtered_argvalues.append(args)
elif extension_name in obbject_ext:
filtered_argvalues.append(args)

# If filtered_argvalues is empty, pytest will skip the test!
return pytest.mark.parametrize(argnames, filtered_argvalues, **kwargs)(
function
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,29 +75,6 @@ def test_charting_equity_price_historical(params, headers):
assert list(chart.keys()) == ["content", "format"]


@parametrize(
"params",
[({"symbol": "AAPL", "limit": 100, "chart": True})],
)
@pytest.mark.integration
def test_charting_equity_fundamental_multiples(params, headers):
"""Test chart equity multiples."""
params = {p: v for p, v in params.items() if v}

query_str = get_querystring(params, [])
url = f"http://0.0.0.0:8000/api/v1/equity/fundamental/multiples?{query_str}"
result = requests.get(url, headers=headers, timeout=10)
assert isinstance(result, requests.Response)
assert result.status_code == 200

chart = result.json()["chart"]
fig = chart.pop("fig", {})

assert chart
assert not fig
assert list(chart.keys()) == ["content", "format"]


@parametrize(
"params",
[
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Charting router."""

import json
from typing import Any, Dict, Tuple
from typing import Any, Dict, Optional, Tuple, Union

import pandas as pd
from openbb_core.app.model.charts.chart import ChartFormat
Expand Down Expand Up @@ -188,7 +188,6 @@ def technical_cones(
**kwargs: TechnicalConesChartQueryParams,
) -> Tuple["OpenBBFigure", Dict[str, Any]]:
"""Volatility Cones Chart."""

data = kwargs.get("data")

if isinstance(data, pd.DataFrame) and not data.empty and "window" in data.columns:
Expand Down Expand Up @@ -286,10 +285,9 @@ def technical_cones(


def economy_fred_series(
**kwargs: FredSeriesChartQueryParams,
**kwargs: Union[Any, FredSeriesChartQueryParams],
) -> Tuple["OpenBBFigure", Dict[str, Any]]:
"""FRED Series Chart."""

ytitle_dict = {
"chg": "Change",
"ch1": "Change From Year Ago",
Expand Down Expand Up @@ -385,12 +383,9 @@ def z_score_standardization(data: pd.Series) -> pd.Series:
+ " Override this error by setting `allow_unsafe = True`."
)

y1_units = y_units[0]

y1_units = y_units[0] if y_units else None
y1title = y1_units

y2title = y_units[1] if len(y_units) > 1 else None

xtitle = ""

# If the request was transformed, the y-axis will be shared under these conditions.
Expand All @@ -401,8 +396,9 @@ def z_score_standardization(data: pd.Series) -> pd.Series:
y2title = None

# Set the title for the chart.
if kwargs.get("title"):
title = kwargs.get("title")
title: str = ""
if isinstance(kwargs, dict) and title in kwargs:
title = kwargs["title"]
else:
if metadata.get(columns[0]):
title = metadata.get(columns[0]).get("title") if len(columns) == 1 else "FRED Series" # type: ignore
Expand All @@ -412,7 +408,7 @@ def z_score_standardization(data: pd.Series) -> pd.Series:
title = f"{title} - {transform_title}" if transform_title else title

# Define this to use as a check.
y3title = ""
y3title: Optional[str] = ""

# Create the figure object with subplots.
fig = OpenBBFigure().create_subplots(
Expand Down Expand Up @@ -456,14 +452,14 @@ def z_score_standardization(data: pd.Series) -> pd.Series:
if kwargs.get("y2title") and y2title is not None:
y2title = kwargs.get("y2title")
# Set the x-axis title, if suppiled.
if kwargs.get("xtitle"):
xtitle = kwargs.get("xtitle")
if isinstance(kwargs, dict) and "xtitle" in kwargs:
xtitle = kwargs["xtitle"]
# If the data was normalized, set the title to reflect this.
if normalize:
y1title = None
y2title = None
y3title = None
title = f"{title} - Normalized"
title = f"{title} - Normalized" if title else "Normalized"

# Now update the layout of the complete figure.
fig.update_layout(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def send_figure(
self.send_outgoing(outgoing)

if export_image and isinstance(export_image, Path):
if self.loop.is_closed():
if self.loop.is_closed(): # type: ignore[has-type]
# Create a new event loop
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Type
"""Base class for charting plugins."""

from typing import Any, Callable, Dict, Iterator, List, Optional, Type, Union

import pandas as pd

from .data_classes import ChartIndicators, TAIndicator


def columns_regex(df_ta: pd.DataFrame, name: str) -> List[str]:
"""Return columns that match regex name"""
"""Return columns that match regex name."""
column_name = df_ta.filter(regex=rf"{name}(?=[^\d]|$)").columns.tolist()

return column_name
Expand All @@ -26,6 +28,7 @@ def __init__(
self.attrs = attrs

def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Call the indicator function."""
return self.func(*args, **kwargs)


Expand All @@ -39,6 +42,7 @@ class PluginMeta(type):
__subplots__: List[str] = []

def __new__(mcs: Type["PluginMeta"], *args: Any, **kwargs: Any) -> "PluginMeta":
"""Create a new instance of the class."""
name, bases, attrs = args
indicators: Dict[str, Indicator] = {}
cls_attrs: Dict[str, list] = {
Expand Down Expand Up @@ -76,6 +80,7 @@ def __new__(mcs: Type["PluginMeta"], *args: Any, **kwargs: Any) -> "PluginMeta":
return new_cls

def __iter__(cls: Type["PluginMeta"]) -> Iterator[Indicator]: # type: ignore
"""Iterate over the indicators."""
return iter(cls.__indicators__)

# pylint: disable=unused-argument
Expand All @@ -88,11 +93,11 @@ class PltTA(metaclass=PluginMeta):

indicators: ChartIndicators
intraday: bool = False
df_stock: pd.DataFrame
df_ta: pd.DataFrame
df_stock: Union[pd.DataFrame, pd.Series]
df_ta: Optional[pd.DataFrame] = None
df_fib: pd.DataFrame
close_column: Optional[str] = "close"
params: Dict[str, TAIndicator] = {}
params: Optional[Dict[str, TAIndicator]] = {}
inchart_colors: List[str] = []
show_volume: bool = True

Expand All @@ -104,6 +109,7 @@ class PltTA(metaclass=PluginMeta):

# pylint: disable=unused-argument
def __new__(cls, *args: Any, **kwargs: Any) -> "PltTA":
"""Create a new instance of the class."""
if cls is PltTA:
raise TypeError("Can't instantiate abstract class Plugin directly")
self = super().__new__(cls)
Expand Down Expand Up @@ -132,14 +138,15 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "PltTA":

@property
def ma_mode(self) -> List[str]:
"""Moving average mode."""
return list(set(self.__ma_mode__))

@ma_mode.setter
def ma_mode(self, value: List[str]):
self.__ma_mode__ = value

def add_plugins(self, plugins: List["PltTA"]) -> None:
"""Add plugins to current instance"""
"""Add plugins to current instance."""
for plugin in plugins:
for item in plugin.__indicators__:
# pylint: disable=unnecessary-dunder-call
Expand All @@ -161,7 +168,7 @@ def add_plugins(self, plugins: List["PltTA"]) -> None:
getattr(self, attr).extend(value)

def remove_plugins(self, plugins: List["PltTA"]) -> None:
"""Remove plugins from current instance"""
"""Remove plugins from current instance."""
for plugin in plugins:
for item in plugin.__indicators__:
delattr(self, item.name)
Expand All @@ -171,10 +178,11 @@ def remove_plugins(self, plugins: List["PltTA"]) -> None:
delattr(self, static_method)

def __iter__(self) -> Iterator[Indicator]:
"""Iterate over the indicators."""
return iter(self.__indicators__)

def get_float_precision(self) -> str:
"""Returns f-string precision format"""
"""Returns f-string precision format."""
price = self.df_stock[self.close_column].tail(1).values[0]
float_precision = (
",.2f" if price > 1.10 else "" if len(str(price)) < 8 else ".6f"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,6 @@ class TAIndicator:
]
args: List[Arguments]

def __post_init__(self):
"""Post init."""
self.args = [Arguments(**arg) for arg in self.args]

def __iter__(self):
"""Return iterator."""
return iter(self.args)
Expand Down Expand Up @@ -98,14 +94,6 @@ class ChartIndicators:

indicators: Optional[List[TAIndicator]] = None

def __post_init__(self):
"""Post init."""
self.indicators = (
[TAIndicator(**indicator) for indicator in self.indicators]
if self.indicators
else []
)

def get_indicator(self, name: str) -> Union[TAIndicator, None]:
"""Return indicator with given name."""
output = None
Expand Down Expand Up @@ -165,21 +153,43 @@ def get_options_dict(self, name: str) -> Dict[str, Optional[Arguments]]:
@staticmethod
def get_available_indicators() -> Tuple[str, ...]:
"""Return tuple of available indicators."""
return list(
return tuple(
TAIndicator.__annotations__["name"].__args__ # pylint: disable=E1101
)

@classmethod
def from_dict(cls, indicators: Dict[str, Dict[str, Any]]) -> "ChartIndicators":
"""Return ChartIndicators from dictionary."""
data = []
for indicator in indicators:
args = []
for arg in indicators[indicator]:
args.append({"label": arg, "values": indicators[indicator][arg]})
data.append({"name": indicator, "args": args})

return cls(indicators=data) # type: ignore
def from_dict(
cls, indicators: Dict[str, Dict[str, List[Dict[str, Any]]]]
) -> "ChartIndicators":
"""Return ChartIndicators from dictionary.

Example
-------
ChartIndicators.from_dict(
{
"ad": {
"args": [
{
"label": "AD_LABEL",
"values": [1, 2, 3],
}
]
}
}
)
"""
return cls(
indicators=[
TAIndicator(
name=name, # type: ignore[arg-type]
args=[
Arguments(label=label, values=values)
for label, values in args.items()
],
)
for name, args in indicators.items()
]
)

def to_dataframe(
self, df_ta: pd.DataFrame, ma_mode: Optional[List[str]] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class PlotlyTA(PltTA):

inchart_colors: List[str] = []
plugins: List[Type[PltTA]] = []
df_ta: pd.DataFrame = None
df_ta: Optional[pd.DataFrame] = None
close_column: Optional[str] = "close"
has_volume: bool = True
show_volume: bool = True
Expand All @@ -112,11 +112,11 @@ def __new__(cls, *args, **kwargs):
# Creates the instance of the class and loads the plugins
# We set the global variable to the instance of the class so that
# the plugins are only loaded once
PLOTLY_TA = super().__new__(cls)
PLOTLY_TA._locate_plugins(
PLOTLY_TA = super().__new__(cls) # type: ignore[attr-defined, assignment]
PLOTLY_TA._locate_plugins( # type: ignore[attr-defined]
getattr(cls.charting_settings, "debug_mode", False)
)
PLOTLY_TA.add_plugins(PLOTLY_TA.plugins)
PLOTLY_TA.add_plugins(PLOTLY_TA.plugins) # type: ignore[attr-defined, assignment]

return PLOTLY_TA

Expand Down Expand Up @@ -180,7 +180,7 @@ def __plot__(
df_stock = df_stock.to_frame()

if not isinstance(indicators, ChartIndicators):
indicators = ChartIndicators.from_dict(indicators or dict(dict()))
indicators = ChartIndicators.from_dict(indicators or {})

# Apply to_datetime to the index in a way that handles daylight savings.
df_stock.loc[:, "date"] = df_stock.index # type: ignore
Expand Down Expand Up @@ -289,7 +289,7 @@ def _locate_plugins(debug: Optional[bool] = False) -> None:
def _clear_data(self):
"""Clear and reset all data to default values."""
self.df_stock = None
self.indicators = {}
self.indicators = ChartIndicators.from_dict({})
self.params = None
self.intraday = False
self.show_volume = True
Expand Down
Loading
Loading