Skip to content

Commit

Permalink
[Feature] Custom Provider choices available on the reference.json (#…
Browse files Browse the repository at this point in the history
…6409)

* change package builder and argparse translator to account for custom chocies defined in providers

* default reference
  • Loading branch information
hjoaquim authored May 14, 2024
1 parent 0eee602 commit 88cdd75
Show file tree
Hide file tree
Showing 3 changed files with 6,375 additions and 3,054 deletions.
14 changes: 8 additions & 6 deletions cli/openbb_cli/argparse_translator/argparse_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class CustomArgument(BaseModel):
action: Literal["store_true", "store"]
help: str
nargs: Optional[Literal["+"]]
choices: Optional[Any]
choices: Optional[Tuple]

@model_validator(mode="after") # type: ignore
@classmethod
Expand Down Expand Up @@ -117,7 +117,7 @@ def _get_nargs(self, type_: type) -> Optional[Union[int, str]]:
return "+"
return None

def _get_choices(self, type_: str) -> Tuple:
def _get_choices(self, type_: str, custom_choices: Any) -> Tuple:
"""Get the choices for the given type."""
type_ = self._make_type_parsable(type_) # type: ignore
type_origin = get_origin(type_)
Expand All @@ -126,14 +126,12 @@ def _get_choices(self, type_: str) -> Tuple:

if type_origin is Literal:
choices = get_args(type_)
# param_type = type(choices[0])

if type_origin is list:
type_ = get_args(type_)[0]

if get_origin(type_) is Literal:
choices = get_args(type_)
# param_type = type(choices[0])

if type_origin is Union and type(None) in get_args(type_):
# remove NoneType from the args
Expand All @@ -145,7 +143,9 @@ def _get_choices(self, type_: str) -> Tuple:

if get_origin(type_) is Literal:
choices = get_args(type_)
# param_type = type(choices[0])

if custom_choices:
return tuple(custom_choices)

return choices

Expand Down Expand Up @@ -174,7 +174,9 @@ def build_custom_groups(self):
action="store" if type_ != bool else "store_true",
help=arg["description"],
nargs=self._get_nargs(type_), # type: ignore
choices=self._get_choices(arg["type"]),
choices=self._get_choices(
arg["type"], custom_choices=arg["choices"]
),
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1483,8 +1483,10 @@ def _get_provider_field_params(
.strip().replace("\n", " ").replace(" ", " ").replace('"', "'")
) # fmt: skip

extra = field_info.json_schema_extra or {}

# Add information for the providers supporting multiple symbols
if params_type == "QueryParams" and (extra := field_info.json_schema_extra):
if params_type == "QueryParams" and extra:

providers = []
for p, v in extra.items(): # type: ignore[union-attr]
Expand Down Expand Up @@ -1512,6 +1514,7 @@ def _get_provider_field_params(
"description": cleaned_description,
"default": default_value,
"optional": not is_required,
"choices": extra.get("choices"),
}
)

Expand Down
Loading

0 comments on commit 88cdd75

Please sign in to comment.