Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] - Provider is added to every response item #6305

Merged
merged 21 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions openbb_platform/core/openbb_core/api/router/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def build_new_signature(path: str, func: Callable) -> Signature:
)


def validate_output(c_out: OBBject) -> Dict:
def validate_output(c_out: OBBject) -> OBBject:
"""
Validate OBBject object.

Expand Down Expand Up @@ -170,7 +170,7 @@ def exclude_fields_from_api(key: str, value: Any):
for k, v in c_out.model_copy():
exclude_fields_from_api(k, v)

return c_out.model_dump()
return c_out


def build_api_wrapper(
Expand All @@ -188,7 +188,7 @@ def build_api_wrapper(
func.__annotations__ = new_annotations_map

@wraps(wrapped=func)
async def wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Dict:
async def wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> OBBject:
user_settings: UserSettings = UserSettings.model_validate(
kwargs.pop(
"__authenticated_user_settings",
Expand Down
35 changes: 0 additions & 35 deletions openbb_platform/core/openbb_core/app/model/obbject.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""The OBBject."""

from re import sub
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -82,40 +81,6 @@ def __repr__(self) -> str:
]
return f"{self.__class__.__name__}\n\n" + "\n".join(items)

@classmethod
def results_type_repr(cls, params: Optional[Any] = None) -> str:
"""Return the results type representation."""
results_field = cls.model_fields.get("results")
type_repr = "Any"
if results_field:
type_ = params[0] if params else results_field.annotation
type_repr = getattr(type_, "__name__", str(type_))

if json_schema_extra := getattr(results_field, "json_schema_extra", {}):
model = json_schema_extra.get("model", "Any")

if json_schema_extra.get("is_union"):
return f"Union[List[{model}], {model}]"
if json_schema_extra.get("has_list"):
return f"List[{model}]"

return model

if "typing." in str(type_):
unpack_optional = sub(r"Optional\[(.*)\]", r"\1", str(type_))
type_repr = sub(
r"(\w+\.)*(\w+)?(\, NoneType)?",
r"\2",
unpack_optional,
)

return type_repr

@classmethod
def model_parametrized_name(cls, params: Any) -> str:
"""Return the model name with the parameters."""
return f"OBBject[{cls.results_type_repr(params)}]"

def to_df(
self, index: Optional[Union[str, None]] = "date", sort_by: Optional[str] = None
) -> pd.DataFrame:
Expand Down
93 changes: 85 additions & 8 deletions openbb_platform/core/openbb_core/app/provider_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,33 @@

from dataclasses import dataclass, make_dataclass
from difflib import SequenceMatcher
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
from typing import (
Annotated,
Any,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
Type,
Union,
)

from fastapi import Query
from pydantic import (
BaseModel,
ConfigDict,
Discriminator,
Field,
SerializeAsAny,
Tag,
create_model,
)
from pydantic.fields import FieldInfo

from openbb_core.app.model.abstract.singleton import SingletonMeta
from openbb_core.app.model.obbject import OBBject
from openbb_core.provider.query_executor import QueryExecutor
from openbb_core.provider.registry_map import MapType, RegistryMap
from openbb_core.provider.utils.helpers import to_snake_case
Expand Down Expand Up @@ -92,12 +107,15 @@ def __init__(
self._registry_map = registry_map or RegistryMap()
self._query_executor = query_executor or QueryExecutor

self._map = self._registry_map.map
self._map = self._registry_map.standard_extra
# TODO: Try these 4 methods in a single iteration
self._model_providers_map = self._generate_model_providers_dc(self._map)
self._params = self._generate_params_dc(self._map)
self._data = self._generate_data_dc(self._map)
self._return_schema = self._generate_return_schema(self._data)
self._return_annotations = self._generate_return_annotations(
self._registry_map.original_models
)

self._available_providers = self._registry_map.available_providers
self._provider_choices = self._get_provider_choices(self._available_providers)
Expand Down Expand Up @@ -148,9 +166,9 @@ def models(self) -> List[str]:
return self._registry_map.models

@property
def return_map(self) -> Dict[str, Dict[str, Any]]:
def return_annotations(self) -> Dict[str, Type[OBBject]]:
"""Return map."""
return self._registry_map.return_map
return self._return_annotations

def create_executor(self) -> QueryExecutor:
"""Get query executor."""
Expand Down Expand Up @@ -242,7 +260,9 @@ def _create_field(
additional_description += " Multiple comma separated items allowed."
else:
additional_description += (
" Multiple comma separated items allowed for provider(s): " + ", ".join(multiple) + "." # type: ignore
" Multiple comma separated items allowed for provider(s): "
+ ", ".join(multiple) # type: ignore[arg-type]
+ "."
)

provider_field = (
Expand Down Expand Up @@ -396,7 +416,7 @@ def _generate_params_dc(
This creates a dictionary of dataclasses that can be injected as a FastAPI
dependency.

Example:
Example
-------
@dataclass
class CompanyNews(StandardParams):
Expand Down Expand Up @@ -437,7 +457,7 @@ def _generate_model_providers_dc(self, map_: MapType) -> Dict[str, ProviderChoic
This creates a dictionary that maps model names to dataclasses that can be
injected as a FastAPI dependency.

Example:
Example
-------
@dataclass
class CompanyNews(ProviderChoices):
Expand Down Expand Up @@ -471,7 +491,7 @@ def _generate_data_dc(

This creates a dictionary of dataclasses.

Example:
Example
-------
class EquityHistoricalData(StandardData):
date: date
Expand Down Expand Up @@ -546,3 +566,60 @@ def _get_provider_choices(self, available_providers: List[str]) -> type:
fields=[("provider", Literal[tuple(available_providers)])], # type: ignore
bases=(ProviderChoices,),
)

def _generate_return_annotations(
self, original_models: Dict[str, Dict[str, Any]]
) -> Dict[str, Type[OBBject]]:
"""Generate return annotations for FastAPI.

Example
-------
class Data(BaseModel):
...

class EquityData(Data):
price: float

class YFEquityData(EquityData):
yf_field: str

class AVEquityData(EquityData):
av_field: str

class OBBject(BaseModel):
results: List[
SerializeAsAny[
Annotated[
Union[
Annotated[YFEquityData, Tag("yf")],
Annotated[AVEquityData, Tag("av")],
],
Discriminator(get_provider),
]
]
]
"""

def get_provider(v: Type[BaseModel]):
"""Callable to discriminate which BaseModel to use."""
return getattr(v, "_provider", None)

annotations = {}
for name, models in original_models.items():
outer = set()
args = set()
for provider, model in models.items():
data = model["data"]
outer.add(model["results_type"])
args.add(Annotated[data, Tag(provider)])
# We set the provider to use it in discriminator function
setattr(data, "_provider", provider)
meta = Discriminator(get_provider) if len(args) > 1 else None
inner = SerializeAsAny[Annotated[Union[tuple(args)], meta]] # type: ignore[misc,valid-type]
full = Union[tuple((o[inner] if o else inner) for o in outer)] # type: ignore[valid-type]
annotations[name] = create_model(
f"OBBject_{name}",
__base__=OBBject[full], # type: ignore[valid-type]
__doc__=f"OBBject with results of type {name}",
)
return annotations
70 changes: 11 additions & 59 deletions openbb_platform/core/openbb_core/app/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
Mapping,
Optional,
Type,
Union,
get_args,
get_origin,
get_type_hints,
overload,
)

from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field, SerializeAsAny, Tag, create_model
from pydantic import BaseModel
from pydantic.v1.validators import find_validators
from typing_extensions import Annotated, ParamSpec, _AnnotatedAlias

Expand Down Expand Up @@ -397,10 +396,9 @@ def complete(
callable_=provider_interface.params[model]["extra"],
)

func = cls.inject_return_type(
func = cls.inject_return_annotation(
func=func,
return_map=provider_interface.return_map.get(model, {}),
model=model,
annotation=provider_interface.return_annotations[model],
)

else:
Expand All @@ -417,60 +415,6 @@ def complete(

return func

@staticmethod
def inject_return_type(
func: Callable[P, OBBject],
return_map: Dict[str, dict],
model: str,
) -> Callable[P, OBBject]:
"""Inject full return model into the function. Also updates __name__ and __doc__ for API schemas."""
results: Dict[str, Any] = {"list_type": [], "dict_type": []}

for provider, return_data in return_map.items():
if return_data["is_list"]:
results["list_type"].append(
Annotated[return_data["model"], Tag(provider)]
)
continue

results["dict_type"].append(Annotated[return_data["model"], Tag(provider)])

list_models, union_models = results.values()

return_types = []
for t, v in results.items():
if not v:
continue

inner_type: Any = SerializeAsAny[ # type: ignore[misc,valid-type]
Annotated[
Union[tuple(v)], # type: ignore
Field(discriminator="provider"),
]
]
return_types.append(List[inner_type] if t == "list_type" else inner_type)

return_type = create_model(
f"OBBject_{model}",
__base__=OBBject,
__doc__=f"OBBject with results of type {model}",
results=(
Optional[Union[tuple(return_types)]], # type: ignore
Field(
None,
description="Serializable results.",
json_schema_extra={
"model": model,
"has_list": bool(len(list_models) > 0),
"is_union": bool(list_models and union_models),
},
),
),
)

func.__annotations__["return"] = return_type
return func

@staticmethod
def polish_return_schema(func: Callable[P, OBBject]) -> Callable[P, OBBject]:
"""Polish API schemas by filling `__doc__` and `__name__`."""
Expand Down Expand Up @@ -517,6 +461,14 @@ def inject_dependency(
func.__annotations__[arg] = Annotated[callable_, Depends()] # type: ignore
return func

@staticmethod
def inject_return_annotation(
func: Callable[P, OBBject], annotation: Type[OBBject]
) -> Callable[P, OBBject]:
"""Annotate function with return annotation."""
func.__annotations__["return"] = annotation
return func

@staticmethod
def get_description(func: Callable) -> str:
"""Get description from docstring."""
Expand Down
Loading
Loading