Skip to content

Commit

Permalink
[BugFix] - Provider is added to every response item (#6305)
Browse files Browse the repository at this point in the history
* Exclude provider and don't model_dump

* still missing docstrings

* undo package changes

* minor fix

* mypy

* minor fix

* cleaner

* private var

* docstring

* docstrings

* add package builder tests

* ruff

* rename

* rename

* update tests

* minor fix

* fix test

* handle docstring edge cases

* test
  • Loading branch information
montezdesousa authored Apr 12, 2024
1 parent 3172b9e commit b8d1846
Show file tree
Hide file tree
Showing 17 changed files with 316 additions and 223 deletions.
6 changes: 3 additions & 3 deletions openbb_platform/core/openbb_core/api/router/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def build_new_signature(path: str, func: Callable) -> Signature:
)


def validate_output(c_out: OBBject) -> Dict:
def validate_output(c_out: OBBject) -> OBBject:
"""
Validate OBBject object.
Expand Down Expand Up @@ -170,7 +170,7 @@ def exclude_fields_from_api(key: str, value: Any):
for k, v in c_out.model_copy():
exclude_fields_from_api(k, v)

return c_out.model_dump()
return c_out


def build_api_wrapper(
Expand All @@ -188,7 +188,7 @@ def build_api_wrapper(
func.__annotations__ = new_annotations_map

@wraps(wrapped=func)
async def wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Dict:
async def wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> OBBject:
user_settings: UserSettings = UserSettings.model_validate(
kwargs.pop(
"__authenticated_user_settings",
Expand Down
35 changes: 0 additions & 35 deletions openbb_platform/core/openbb_core/app/model/obbject.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""The OBBject."""

from re import sub
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -82,40 +81,6 @@ def __repr__(self) -> str:
]
return f"{self.__class__.__name__}\n\n" + "\n".join(items)

@classmethod
def results_type_repr(cls, params: Optional[Any] = None) -> str:
"""Return the results type representation."""
results_field = cls.model_fields.get("results")
type_repr = "Any"
if results_field:
type_ = params[0] if params else results_field.annotation
type_repr = getattr(type_, "__name__", str(type_))

if json_schema_extra := getattr(results_field, "json_schema_extra", {}):
model = json_schema_extra.get("model", "Any")

if json_schema_extra.get("is_union"):
return f"Union[List[{model}], {model}]"
if json_schema_extra.get("has_list"):
return f"List[{model}]"

return model

if "typing." in str(type_):
unpack_optional = sub(r"Optional\[(.*)\]", r"\1", str(type_))
type_repr = sub(
r"(\w+\.)*(\w+)?(\, NoneType)?",
r"\2",
unpack_optional,
)

return type_repr

@classmethod
def model_parametrized_name(cls, params: Any) -> str:
"""Return the model name with the parameters."""
return f"OBBject[{cls.results_type_repr(params)}]"

def to_df(
self, index: Optional[Union[str, None]] = "date", sort_by: Optional[str] = None
) -> pd.DataFrame:
Expand Down
93 changes: 85 additions & 8 deletions openbb_platform/core/openbb_core/app/provider_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,33 @@

from dataclasses import dataclass, make_dataclass
from difflib import SequenceMatcher
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
from typing import (
Annotated,
Any,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
Type,
Union,
)

from fastapi import Query
from pydantic import (
BaseModel,
ConfigDict,
Discriminator,
Field,
SerializeAsAny,
Tag,
create_model,
)
from pydantic.fields import FieldInfo

from openbb_core.app.model.abstract.singleton import SingletonMeta
from openbb_core.app.model.obbject import OBBject
from openbb_core.provider.query_executor import QueryExecutor
from openbb_core.provider.registry_map import MapType, RegistryMap
from openbb_core.provider.utils.helpers import to_snake_case
Expand Down Expand Up @@ -92,12 +107,15 @@ def __init__(
self._registry_map = registry_map or RegistryMap()
self._query_executor = query_executor or QueryExecutor

self._map = self._registry_map.map
self._map = self._registry_map.standard_extra
# TODO: Try these 4 methods in a single iteration
self._model_providers_map = self._generate_model_providers_dc(self._map)
self._params = self._generate_params_dc(self._map)
self._data = self._generate_data_dc(self._map)
self._return_schema = self._generate_return_schema(self._data)
self._return_annotations = self._generate_return_annotations(
self._registry_map.original_models
)

self._available_providers = self._registry_map.available_providers
self._provider_choices = self._get_provider_choices(self._available_providers)
Expand Down Expand Up @@ -148,9 +166,9 @@ def models(self) -> List[str]:
return self._registry_map.models

@property
def return_map(self) -> Dict[str, Dict[str, Any]]:
def return_annotations(self) -> Dict[str, Type[OBBject]]:
"""Return map."""
return self._registry_map.return_map
return self._return_annotations

def create_executor(self) -> QueryExecutor:
"""Get query executor."""
Expand Down Expand Up @@ -242,7 +260,9 @@ def _create_field(
additional_description += " Multiple comma separated items allowed."
else:
additional_description += (
" Multiple comma separated items allowed for provider(s): " + ", ".join(multiple) + "." # type: ignore
" Multiple comma separated items allowed for provider(s): "
+ ", ".join(multiple) # type: ignore[arg-type]
+ "."
)

provider_field = (
Expand Down Expand Up @@ -396,7 +416,7 @@ def _generate_params_dc(
This creates a dictionary of dataclasses that can be injected as a FastAPI
dependency.
Example:
Example
-------
@dataclass
class CompanyNews(StandardParams):
Expand Down Expand Up @@ -437,7 +457,7 @@ def _generate_model_providers_dc(self, map_: MapType) -> Dict[str, ProviderChoic
This creates a dictionary that maps model names to dataclasses that can be
injected as a FastAPI dependency.
Example:
Example
-------
@dataclass
class CompanyNews(ProviderChoices):
Expand Down Expand Up @@ -471,7 +491,7 @@ def _generate_data_dc(
This creates a dictionary of dataclasses.
Example:
Example
-------
class EquityHistoricalData(StandardData):
date: date
Expand Down Expand Up @@ -546,3 +566,60 @@ def _get_provider_choices(self, available_providers: List[str]) -> type:
fields=[("provider", Literal[tuple(available_providers)])], # type: ignore
bases=(ProviderChoices,),
)

def _generate_return_annotations(
self, original_models: Dict[str, Dict[str, Any]]
) -> Dict[str, Type[OBBject]]:
"""Generate return annotations for FastAPI.
Example
-------
class Data(BaseModel):
...
class EquityData(Data):
price: float
class YFEquityData(EquityData):
yf_field: str
class AVEquityData(EquityData):
av_field: str
class OBBject(BaseModel):
results: List[
SerializeAsAny[
Annotated[
Union[
Annotated[YFEquityData, Tag("yf")],
Annotated[AVEquityData, Tag("av")],
],
Discriminator(get_provider),
]
]
]
"""

def get_provider(v: Type[BaseModel]):
"""Callable to discriminate which BaseModel to use."""
return getattr(v, "_provider", None)

annotations = {}
for name, models in original_models.items():
outer = set()
args = set()
for provider, model in models.items():
data = model["data"]
outer.add(model["results_type"])
args.add(Annotated[data, Tag(provider)])
# We set the provider to use it in discriminator function
setattr(data, "_provider", provider)
meta = Discriminator(get_provider) if len(args) > 1 else None
inner = SerializeAsAny[Annotated[Union[tuple(args)], meta]] # type: ignore[misc,valid-type]
full = Union[tuple((o[inner] if o else inner) for o in outer)] # type: ignore[valid-type]
annotations[name] = create_model(
f"OBBject_{name}",
__base__=OBBject[full], # type: ignore[valid-type]
__doc__=f"OBBject with results of type {name}",
)
return annotations
70 changes: 11 additions & 59 deletions openbb_platform/core/openbb_core/app/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
Mapping,
Optional,
Type,
Union,
get_args,
get_origin,
get_type_hints,
overload,
)

from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field, SerializeAsAny, Tag, create_model
from pydantic import BaseModel
from pydantic.v1.validators import find_validators
from typing_extensions import Annotated, ParamSpec, _AnnotatedAlias

Expand Down Expand Up @@ -397,10 +396,9 @@ def complete(
callable_=provider_interface.params[model]["extra"],
)

func = cls.inject_return_type(
func = cls.inject_return_annotation(
func=func,
return_map=provider_interface.return_map.get(model, {}),
model=model,
annotation=provider_interface.return_annotations[model],
)

else:
Expand All @@ -417,60 +415,6 @@ def complete(

return func

@staticmethod
def inject_return_type(
func: Callable[P, OBBject],
return_map: Dict[str, dict],
model: str,
) -> Callable[P, OBBject]:
"""Inject full return model into the function. Also updates __name__ and __doc__ for API schemas."""
results: Dict[str, Any] = {"list_type": [], "dict_type": []}

for provider, return_data in return_map.items():
if return_data["is_list"]:
results["list_type"].append(
Annotated[return_data["model"], Tag(provider)]
)
continue

results["dict_type"].append(Annotated[return_data["model"], Tag(provider)])

list_models, union_models = results.values()

return_types = []
for t, v in results.items():
if not v:
continue

inner_type: Any = SerializeAsAny[ # type: ignore[misc,valid-type]
Annotated[
Union[tuple(v)], # type: ignore
Field(discriminator="provider"),
]
]
return_types.append(List[inner_type] if t == "list_type" else inner_type)

return_type = create_model(
f"OBBject_{model}",
__base__=OBBject,
__doc__=f"OBBject with results of type {model}",
results=(
Optional[Union[tuple(return_types)]], # type: ignore
Field(
None,
description="Serializable results.",
json_schema_extra={
"model": model,
"has_list": bool(len(list_models) > 0),
"is_union": bool(list_models and union_models),
},
),
),
)

func.__annotations__["return"] = return_type
return func

@staticmethod
def polish_return_schema(func: Callable[P, OBBject]) -> Callable[P, OBBject]:
"""Polish API schemas by filling `__doc__` and `__name__`."""
Expand Down Expand Up @@ -517,6 +461,14 @@ def inject_dependency(
func.__annotations__[arg] = Annotated[callable_, Depends()] # type: ignore
return func

@staticmethod
def inject_return_annotation(
func: Callable[P, OBBject], annotation: Type[OBBject]
) -> Callable[P, OBBject]:
"""Annotate function with return annotation."""
func.__annotations__["return"] = annotation
return func

@staticmethod
def get_description(func: Callable) -> str:
"""Get description from docstring."""
Expand Down
Loading

0 comments on commit b8d1846

Please sign in to comment.