Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] - multiple items allowed in provider parameters #6256

Merged
merged 5 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions openbb_platform/core/openbb_core/api/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ async def lifespan(_: FastAPI):
@app.exception_handler(Exception)
async def api_exception_handler(_: Request, exc: Exception):
"""Exception handler for all other exceptions."""
if Env().DEBUG_MODE:
raise exc
logger.error(exc)
return JSONResponse(
status_code=404,
Expand All @@ -101,6 +103,8 @@ async def api_exception_handler(_: Request, exc: Exception):
@app.exception_handler(OpenBBError)
async def openbb_exception_handler(_: Request, exc: OpenBBError):
"""Exception handler for OpenBB errors."""
if Env().DEBUG_MODE:
raise exc
logger.error(exc.original)
openbb_error = exc.original
status_code = 400 if "No results" in str(openbb_error) else 500
Expand Down
4 changes: 2 additions & 2 deletions openbb_platform/core/openbb_core/app/provider_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,10 @@ def _create_field(
multiple := extra.get("multiple_items_allowed") # type: ignore
):
if provider_name:
additional_description += " Multiple items allowed."
additional_description += " Multiple comma separated items allowed."
else:
additional_description += (
" Multiple items allowed for provider(s): " + ", ".join(multiple) + "." # type: ignore
" Multiple comma separated items allowed for provider(s): " + ", ".join(multiple) + "." # type: ignore
)

provider_field = (
Expand Down
18 changes: 11 additions & 7 deletions openbb_platform/core/openbb_core/app/static/package_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,13 +788,18 @@ def build_command_method_body(path: str, func: Callable):
code += " simplefilter('always', DeprecationWarning)\n"
code += f""" warn("{deprecation_message}", category=DeprecationWarning, stacklevel=2)\n\n"""

extra_info = {}
info = {}

code += " return self._run(\n"
code += f""" "{path}",\n"""
code += " **filter_inputs(\n"
for name, param in parameter_map.items():
if name == "extra_params":
fields = param.annotation.__args__[0].__dataclass_fields__
values = {k: k for k in fields}
for k in values:
if extra := MethodDefinition.get_extra(fields[k]):
info[k] = extra
code += f" {name}=kwargs,\n"
elif name == "provider_choices":
field = param.annotation.__args__[0].__dataclass_fields__["provider"]
Expand All @@ -808,19 +813,18 @@ def build_command_method_body(path: str, func: Callable):
code += " },\n"
elif MethodDefinition.is_annotated_dc(param.annotation):
fields = param.annotation.__args__[0].__dataclass_fields__
value = {k: k for k in fields}
values = {k: k for k in fields}
code += f" {name}={{\n"
for k, v in value.items():
for k, v in values.items():
code += f' "{k}": {v},\n'
# TODO: Extend this to extra_params
if extra := MethodDefinition.get_extra(fields[k]):
extra_info[k] = extra
info[k] = extra
code += " },\n"
else:
code += f" {name}={name},\n"

if extra_info:
code += f" extra_info={extra_info},\n"
if info:
code += f" info={info},\n"

if MethodDefinition.is_data_processing_function(path):
code += " data_processing=True,\n"
Expand Down
6 changes: 3 additions & 3 deletions openbb_platform/core/openbb_core/app/static/utils/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@

def filter_inputs(
data_processing: bool = False,
extra_info: Optional[Dict[str, Dict[str, List[str]]]] = None,
info: Optional[Dict[str, Dict[str, List[str]]]] = None,
**kwargs,
) -> dict:
"""Filter command inputs."""
for key, value in kwargs.items():
if data_processing and key == "data":
kwargs[key] = convert_to_basemodel(value)

if extra_info:
if info:
PROPERTY = "multiple_items_allowed"

# Here we check if list items are passed and multiple items allowed for
# the given provider/input combination. In that case we transform the list
# into a comma-separated string
for field, props in extra_info.items():
for field, props in info.items():
if PROPERTY in props and (
provider := kwargs.get("provider_choices", {}).get("provider")
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ class AnalystSearchQueryParams(QueryParams):

analyst_name: Optional[str] = Field(
default=None,
description="A comma separated list of analyst names to bring back."
+ " Omitting will bring back all available analysts.",
description="Analyst names to return."
+ " Omitting will return all available analysts.",
)
firm_name: Optional[str] = Field(
default=None,
description="A comma separated list of firm names to bring back."
+ " Omitting will bring back all available firms.",
description="Firm names to return."
+ " Omitting will return all available firms.",
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ class CompanyNewsQueryParams(QueryParams):

symbol: Optional[str] = Field(
default=None,
description=QUERY_DESCRIPTIONS.get("symbol", "")
+ " This endpoint will accept multiple symbols separated by commas.",
description=QUERY_DESCRIPTIONS.get("symbol", ""),
)
start_date: Optional[dateType] = Field(
default=None, description=QUERY_DESCRIPTIONS.get("start_date", "")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@
class EquityQuoteQueryParams(QueryParams):
"""Equity Quote Query."""

symbol: str = Field(
description=QUERY_DESCRIPTIONS.get("symbol", "")
+ " This endpoint will accept multiple symbols separated by commas."
)
symbol: str = Field(description=QUERY_DESCRIPTIONS.get("symbol", ""))

@field_validator("symbol", mode="before", check_fields=False)
@classmethod
Expand Down
14 changes: 14 additions & 0 deletions openbb_platform/extensions/equity/integration/test_equity_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,20 @@ def test_equity_estimates_price_target(params, headers):
"firm_ids": None,
"firm_name": "Barclays",
"analyst_name": None,
"page": 0,
}
),
(
{
"limit": 3,
"provider": "benzinga",
# optional provider params
"fields": None,
"analyst_ids": None,
"firm_ids": None,
"firm_name": "Barclays,Credit Suisse",
"analyst_name": None,
"page": 1,
}
),
],
Expand Down
4 changes: 2 additions & 2 deletions openbb_platform/extensions/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Custom pytest configuration for the extensions."""

from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List

import pytest
from openbb_core.app.router import CommandMap
Expand All @@ -13,7 +13,7 @@
# ruff: noqa: SIM114


def parametrize(argnames: str, argvalues: List[Tuple[Any, ...]], **kwargs):
def parametrize(argnames: str, argvalues: List[Dict[str, Any]], **kwargs):
"""Custom parametrize decorator that filters test cases based on the environment."""

routers, providers, obbject_ext = list_openbb_extensions()
Expand Down
Loading
Loading