Skip to content

Commit

Permalink
[BugFix] - multiple items allowed in provider parameters (#6256)
Browse files Browse the repository at this point in the history
* fix extra multiple items in extra parameters

* fix: integration test + parametrized typing

* rebuild
  • Loading branch information
montezdesousa authored Mar 28, 2024
1 parent 4b5787b commit db9960a
Show file tree
Hide file tree
Showing 24 changed files with 289 additions and 280 deletions.
4 changes: 4 additions & 0 deletions openbb_platform/core/openbb_core/api/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ async def lifespan(_: FastAPI):
@app.exception_handler(Exception)
async def api_exception_handler(_: Request, exc: Exception):
"""Exception handler for all other exceptions."""
if Env().DEBUG_MODE:
raise exc
logger.error(exc)
return JSONResponse(
status_code=404,
Expand All @@ -101,6 +103,8 @@ async def api_exception_handler(_: Request, exc: Exception):
@app.exception_handler(OpenBBError)
async def openbb_exception_handler(_: Request, exc: OpenBBError):
"""Exception handler for OpenBB errors."""
if Env().DEBUG_MODE:
raise exc
logger.error(exc.original)
openbb_error = exc.original
status_code = 400 if "No results" in str(openbb_error) else 500
Expand Down
4 changes: 2 additions & 2 deletions openbb_platform/core/openbb_core/app/provider_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,10 @@ def _create_field(
multiple := extra.get("multiple_items_allowed") # type: ignore
):
if provider_name:
additional_description += " Multiple items allowed."
additional_description += " Multiple comma separated items allowed."
else:
additional_description += (
" Multiple items allowed for provider(s): " + ", ".join(multiple) + "." # type: ignore
" Multiple comma separated items allowed for provider(s): " + ", ".join(multiple) + "." # type: ignore
)

provider_field = (
Expand Down
18 changes: 11 additions & 7 deletions openbb_platform/core/openbb_core/app/static/package_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,13 +788,18 @@ def build_command_method_body(path: str, func: Callable):
code += " simplefilter('always', DeprecationWarning)\n"
code += f""" warn("{deprecation_message}", category=DeprecationWarning, stacklevel=2)\n\n"""

extra_info = {}
info = {}

code += " return self._run(\n"
code += f""" "{path}",\n"""
code += " **filter_inputs(\n"
for name, param in parameter_map.items():
if name == "extra_params":
fields = param.annotation.__args__[0].__dataclass_fields__
values = {k: k for k in fields}
for k in values:
if extra := MethodDefinition.get_extra(fields[k]):
info[k] = extra
code += f" {name}=kwargs,\n"
elif name == "provider_choices":
field = param.annotation.__args__[0].__dataclass_fields__["provider"]
Expand All @@ -808,19 +813,18 @@ def build_command_method_body(path: str, func: Callable):
code += " },\n"
elif MethodDefinition.is_annotated_dc(param.annotation):
fields = param.annotation.__args__[0].__dataclass_fields__
value = {k: k for k in fields}
values = {k: k for k in fields}
code += f" {name}={{\n"
for k, v in value.items():
for k, v in values.items():
code += f' "{k}": {v},\n'
# TODO: Extend this to extra_params
if extra := MethodDefinition.get_extra(fields[k]):
extra_info[k] = extra
info[k] = extra
code += " },\n"
else:
code += f" {name}={name},\n"

if extra_info:
code += f" extra_info={extra_info},\n"
if info:
code += f" info={info},\n"

if MethodDefinition.is_data_processing_function(path):
code += " data_processing=True,\n"
Expand Down
6 changes: 3 additions & 3 deletions openbb_platform/core/openbb_core/app/static/utils/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@

def filter_inputs(
data_processing: bool = False,
extra_info: Optional[Dict[str, Dict[str, List[str]]]] = None,
info: Optional[Dict[str, Dict[str, List[str]]]] = None,
**kwargs,
) -> dict:
"""Filter command inputs."""
for key, value in kwargs.items():
if data_processing and key == "data":
kwargs[key] = convert_to_basemodel(value)

if extra_info:
if info:
PROPERTY = "multiple_items_allowed"

# Here we check if list items are passed and multiple items allowed for
# the given provider/input combination. In that case we transform the list
# into a comma-separated string
for field, props in extra_info.items():
for field, props in info.items():
if PROPERTY in props and (
provider := kwargs.get("provider_choices", {}).get("provider")
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ class AnalystSearchQueryParams(QueryParams):

analyst_name: Optional[str] = Field(
default=None,
description="A comma separated list of analyst names to bring back."
+ " Omitting will bring back all available analysts.",
description="Analyst names to return."
+ " Omitting will return all available analysts.",
)
firm_name: Optional[str] = Field(
default=None,
description="A comma separated list of firm names to bring back."
+ " Omitting will bring back all available firms.",
description="Firm names to return."
+ " Omitting will return all available firms.",
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ class CompanyNewsQueryParams(QueryParams):

symbol: Optional[str] = Field(
default=None,
description=QUERY_DESCRIPTIONS.get("symbol", "")
+ " This endpoint will accept multiple symbols separated by commas.",
description=QUERY_DESCRIPTIONS.get("symbol", ""),
)
start_date: Optional[dateType] = Field(
default=None, description=QUERY_DESCRIPTIONS.get("start_date", "")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@
class EquityQuoteQueryParams(QueryParams):
"""Equity Quote Query."""

symbol: str = Field(
description=QUERY_DESCRIPTIONS.get("symbol", "")
+ " This endpoint will accept multiple symbols separated by commas."
)
symbol: str = Field(description=QUERY_DESCRIPTIONS.get("symbol", ""))

@field_validator("symbol", mode="before", check_fields=False)
@classmethod
Expand Down
14 changes: 14 additions & 0 deletions openbb_platform/extensions/equity/integration/test_equity_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,20 @@ def test_equity_estimates_price_target(params, headers):
"firm_ids": None,
"firm_name": "Barclays",
"analyst_name": None,
"page": 0,
}
),
(
{
"limit": 3,
"provider": "benzinga",
# optional provider params
"fields": None,
"analyst_ids": None,
"firm_ids": None,
"firm_name": "Barclays,Credit Suisse",
"analyst_name": None,
"page": 1,
}
),
],
Expand Down
4 changes: 2 additions & 2 deletions openbb_platform/extensions/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Custom pytest configuration for the extensions."""

from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List

import pytest
from openbb_core.app.router import CommandMap
Expand All @@ -13,7 +13,7 @@
# ruff: noqa: SIM114


def parametrize(argnames: str, argvalues: List[Tuple[Any, ...]], **kwargs):
def parametrize(argnames: str, argvalues: List[Dict[str, Any]], **kwargs):
"""Custom parametrize decorator that filters test cases based on the environment."""

routers, providers, obbject_ext = list_openbb_extensions()
Expand Down
Loading

0 comments on commit db9960a

Please sign in to comment.