diff --git a/.changeset/add_literal_enums_config_setting.md b/.changeset/add_literal_enums_config_setting.md new file mode 100644 index 000000000..82b7a9468 --- /dev/null +++ b/.changeset/add_literal_enums_config_setting.md @@ -0,0 +1,14 @@ +--- +default: minor +--- + +# Add `literal_enums` config setting + +Instead of the default `Enum` classes for enums, you can now generate `Literal` sets wherever `enum` appears in the OpenAPI spec by setting `literal_enums: true` in your config file. + +```yaml +literal_enums: true +``` + +Thanks to @emosenkis for PR #1114 closes #587, #725, #1076, and probably many more. +Thanks also to @eli-bl, @expobrain, @theorm, @chrisguillory, and anyone else who helped getting to this design! diff --git a/README.md b/README.md index efd3dad6e..871f3a296 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,17 @@ class_overrides: The easiest way to find what needs to be overridden is probably to generate your client and go look at everything in the `models` folder. +### literal_enums + +By default, `openapi-python-client` generates classes inheriting for `Enum` for enums. It can instead use `Literal` +values for enums by setting this to `true`: + +```yaml +literal_enums: true +``` + +This is especially useful if enum values, when transformed to their Python names, end up conflicting due to case sensitivity or special symbols. + ### project_name_override and package_name_override Used to change the name of generated client library project/package. If the project name is changed but an override for the package name diff --git a/end_to_end_tests/literal-enums-golden-record/.gitignore b/end_to_end_tests/literal-enums-golden-record/.gitignore new file mode 100644 index 000000000..79a2c3d73 --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/.gitignore @@ -0,0 +1,23 @@ +__pycache__/ +build/ +dist/ +*.egg-info/ +.pytest_cache/ + +# pyenv +.python-version + +# Environments +.env +.venv + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# JetBrains +.idea/ + +/coverage.xml +/.coverage diff --git a/end_to_end_tests/literal-enums-golden-record/README.md b/end_to_end_tests/literal-enums-golden-record/README.md new file mode 100644 index 000000000..2c6268349 --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/README.md @@ -0,0 +1,124 @@ +# my-enum-api-client +A client library for accessing My Enum API + +## Usage +First, create a client: + +```python +from my_enum_api_client import Client + +client = Client(base_url="https://api.example.com") +``` + +If the endpoints you're going to hit require authentication, use `AuthenticatedClient` instead: + +```python +from my_enum_api_client import AuthenticatedClient + +client = AuthenticatedClient(base_url="https://api.example.com", token="SuperSecretToken") +``` + +Now call your endpoint and use your models: + +```python +from my_enum_api_client.models import MyDataModel +from my_enum_api_client.api.my_tag import get_my_data_model +from my_enum_api_client.types import Response + +with client as client: + my_data: MyDataModel = get_my_data_model.sync(client=client) + # or if you need more info (e.g. status_code) + response: Response[MyDataModel] = get_my_data_model.sync_detailed(client=client) +``` + +Or do the same thing with an async version: + +```python +from my_enum_api_client.models import MyDataModel +from my_enum_api_client.api.my_tag import get_my_data_model +from my_enum_api_client.types import Response + +async with client as client: + my_data: MyDataModel = await get_my_data_model.asyncio(client=client) + response: Response[MyDataModel] = await get_my_data_model.asyncio_detailed(client=client) +``` + +By default, when you're calling an HTTPS API it will attempt to verify that SSL is working correctly. Using certificate verification is highly recommended most of the time, but sometimes you may need to authenticate to a server (especially an internal server) using a custom certificate bundle. + +```python +client = AuthenticatedClient( + base_url="https://internal_api.example.com", + token="SuperSecretToken", + verify_ssl="/path/to/certificate_bundle.pem", +) +``` + +You can also disable certificate validation altogether, but beware that **this is a security risk**. + +```python +client = AuthenticatedClient( + base_url="https://internal_api.example.com", + token="SuperSecretToken", + verify_ssl=False +) +``` + +Things to know: +1. Every path/method combo becomes a Python module with four functions: + 1. `sync`: Blocking request that returns parsed data (if successful) or `None` + 1. `sync_detailed`: Blocking request that always returns a `Request`, optionally with `parsed` set if the request was successful. + 1. `asyncio`: Like `sync` but async instead of blocking + 1. `asyncio_detailed`: Like `sync_detailed` but async instead of blocking + +1. All path/query params, and bodies become method arguments. +1. If your endpoint had any tags on it, the first tag will be used as a module name for the function (my_tag above) +1. Any endpoint which did not have a tag will be in `my_enum_api_client.api.default` + +## Advanced customizations + +There are more settings on the generated `Client` class which let you control more runtime behavior, check out the docstring on that class for more info. You can also customize the underlying `httpx.Client` or `httpx.AsyncClient` (depending on your use-case): + +```python +from my_enum_api_client import Client + +def log_request(request): + print(f"Request event hook: {request.method} {request.url} - Waiting for response") + +def log_response(response): + request = response.request + print(f"Response event hook: {request.method} {request.url} - Status {response.status_code}") + +client = Client( + base_url="https://api.example.com", + httpx_args={"event_hooks": {"request": [log_request], "response": [log_response]}}, +) + +# Or get the underlying httpx client to modify directly with client.get_httpx_client() or client.get_async_httpx_client() +``` + +You can even set the httpx client directly, but beware that this will override any existing settings (e.g., base_url): + +```python +import httpx +from my_enum_api_client import Client + +client = Client( + base_url="https://api.example.com", +) +# Note that base_url needs to be re-set, as would any shared cookies, headers, etc. +client.set_httpx_client(httpx.Client(base_url="https://api.example.com", proxies="http://localhost:8030")) +``` + +## Building / publishing this package +This project uses [Poetry](https://python-poetry.org/) to manage dependencies and packaging. Here are the basics: +1. Update the metadata in pyproject.toml (e.g. authors, version) +1. If you're using a private repository, configure it with Poetry + 1. `poetry config repositories. ` + 1. `poetry config http-basic. ` +1. Publish the client with `poetry publish --build -r ` or, if for public PyPI, just `poetry publish --build` + +If you want to install this client into another project without publishing it (e.g. for development) then: +1. If that project **is using Poetry**, you can simply do `poetry add ` from that project +1. If that project is not using Poetry: + 1. Build a wheel with `poetry build -f wheel` + 1. Install that wheel from the other project `pip install ` diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/__init__.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/__init__.py new file mode 100644 index 000000000..5d1901164 --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/__init__.py @@ -0,0 +1,8 @@ +"""A client library for accessing My Enum API""" + +from .client import AuthenticatedClient, Client + +__all__ = ( + "AuthenticatedClient", + "Client", +) diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/__init__.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/__init__.py new file mode 100644 index 000000000..81f9fa241 --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/__init__.py @@ -0,0 +1 @@ +"""Contains methods for accessing the API""" diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/enums/__init__.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/enums/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/enums/bool_enum_tests_bool_enum_post.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/enums/bool_enum_tests_bool_enum_post.py new file mode 100644 index 000000000..92e95162c --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/enums/bool_enum_tests_bool_enum_post.py @@ -0,0 +1,101 @@ +from http import HTTPStatus +from typing import Any, Dict, Optional, Union + +import httpx + +from ... import errors +from ...client import AuthenticatedClient, Client +from ...types import UNSET, Response + + +def _get_kwargs( + *, + bool_enum: bool, +) -> Dict[str, Any]: + params: Dict[str, Any] = {} + + params["bool_enum"] = bool_enum + + params = {k: v for k, v in params.items() if v is not UNSET and v is not None} + + _kwargs: Dict[str, Any] = { + "method": "post", + "url": "/enum/bool", + "params": params, + } + + return _kwargs + + +def _parse_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Optional[Any]: + if response.status_code == HTTPStatus.OK: + return None + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Response[Any]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + *, + client: Union[AuthenticatedClient, Client], + bool_enum: bool, +) -> Response[Any]: + """Bool Enum + + Args: + bool_enum (bool): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Any] + """ + + kwargs = _get_kwargs( + bool_enum=bool_enum, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +async def asyncio_detailed( + *, + client: Union[AuthenticatedClient, Client], + bool_enum: bool, +) -> Response[Any]: + """Bool Enum + + Args: + bool_enum (bool): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Any] + """ + + kwargs = _get_kwargs( + bool_enum=bool_enum, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/enums/int_enum_tests_int_enum_post.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/enums/int_enum_tests_int_enum_post.py new file mode 100644 index 000000000..77e362b44 --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/enums/int_enum_tests_int_enum_post.py @@ -0,0 +1,103 @@ +from http import HTTPStatus +from typing import Any, Dict, Optional, Union + +import httpx + +from ... import errors +from ...client import AuthenticatedClient, Client +from ...models.an_int_enum import AnIntEnum +from ...types import UNSET, Response + + +def _get_kwargs( + *, + int_enum: AnIntEnum, +) -> Dict[str, Any]: + params: Dict[str, Any] = {} + + json_int_enum: int = int_enum + params["int_enum"] = json_int_enum + + params = {k: v for k, v in params.items() if v is not UNSET and v is not None} + + _kwargs: Dict[str, Any] = { + "method": "post", + "url": "/enum/int", + "params": params, + } + + return _kwargs + + +def _parse_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Optional[Any]: + if response.status_code == HTTPStatus.OK: + return None + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Response[Any]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + *, + client: Union[AuthenticatedClient, Client], + int_enum: AnIntEnum, +) -> Response[Any]: + """Int Enum + + Args: + int_enum (AnIntEnum): An enumeration. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Any] + """ + + kwargs = _get_kwargs( + int_enum=int_enum, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +async def asyncio_detailed( + *, + client: Union[AuthenticatedClient, Client], + int_enum: AnIntEnum, +) -> Response[Any]: + """Int Enum + + Args: + int_enum (AnIntEnum): An enumeration. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Any] + """ + + kwargs = _get_kwargs( + int_enum=int_enum, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/tests/__init__.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/tests/get_user_list.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/tests/get_user_list.py new file mode 100644 index 000000000..b97c078db --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/tests/get_user_list.py @@ -0,0 +1,257 @@ +from http import HTTPStatus +from typing import Any, Dict, List, Optional, Union + +import httpx + +from ... import errors +from ...client import AuthenticatedClient, Client +from ...models.a_model import AModel +from ...models.an_enum import AnEnum +from ...models.an_enum_with_null import AnEnumWithNull +from ...models.get_user_list_int_enum_header import GetUserListIntEnumHeader +from ...models.get_user_list_string_enum_header import ( + GetUserListStringEnumHeader, +) +from ...types import UNSET, Response, Unset + + +def _get_kwargs( + *, + an_enum_value: List[AnEnum], + an_enum_value_with_null: List[Union[AnEnumWithNull, None]], + an_enum_value_with_only_null: List[None], + int_enum_header: Union[Unset, GetUserListIntEnumHeader] = UNSET, + string_enum_header: Union[Unset, GetUserListStringEnumHeader] = UNSET, +) -> Dict[str, Any]: + headers: Dict[str, Any] = {} + if not isinstance(int_enum_header, Unset): + headers["Int-Enum-Header"] = str(int_enum_header) + + if not isinstance(string_enum_header, Unset): + headers["String-Enum-Header"] = str(string_enum_header) + + params: Dict[str, Any] = {} + + json_an_enum_value = [] + for an_enum_value_item_data in an_enum_value: + an_enum_value_item: str = an_enum_value_item_data + json_an_enum_value.append(an_enum_value_item) + + params["an_enum_value"] = json_an_enum_value + + json_an_enum_value_with_null = [] + for an_enum_value_with_null_item_data in an_enum_value_with_null: + an_enum_value_with_null_item: Union[None, str] + if isinstance(an_enum_value_with_null_item_data, str): + an_enum_value_with_null_item = an_enum_value_with_null_item_data + else: + an_enum_value_with_null_item = an_enum_value_with_null_item_data + json_an_enum_value_with_null.append(an_enum_value_with_null_item) + + params["an_enum_value_with_null"] = json_an_enum_value_with_null + + json_an_enum_value_with_only_null = an_enum_value_with_only_null + + params["an_enum_value_with_only_null"] = json_an_enum_value_with_only_null + + params = {k: v for k, v in params.items() if v is not UNSET and v is not None} + + _kwargs: Dict[str, Any] = { + "method": "get", + "url": "/tests/", + "params": params, + } + + _kwargs["headers"] = headers + return _kwargs + + +def _parse_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Optional[List["AModel"]]: + if response.status_code == HTTPStatus.OK: + response_200 = [] + _response_200 = response.json() + for response_200_item_data in _response_200: + response_200_item = AModel.from_dict(response_200_item_data) + + response_200.append(response_200_item) + + return response_200 + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Response[List["AModel"]]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + *, + client: Union[AuthenticatedClient, Client], + an_enum_value: List[AnEnum], + an_enum_value_with_null: List[Union[AnEnumWithNull, None]], + an_enum_value_with_only_null: List[None], + int_enum_header: Union[Unset, GetUserListIntEnumHeader] = UNSET, + string_enum_header: Union[Unset, GetUserListStringEnumHeader] = UNSET, +) -> Response[List["AModel"]]: + """Get List + + Get a list of things + + Args: + an_enum_value (List[AnEnum]): + an_enum_value_with_null (List[Union[AnEnumWithNull, None]]): + an_enum_value_with_only_null (List[None]): + int_enum_header (Union[Unset, GetUserListIntEnumHeader]): + string_enum_header (Union[Unset, GetUserListStringEnumHeader]): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[List['AModel']] + """ + + kwargs = _get_kwargs( + an_enum_value=an_enum_value, + an_enum_value_with_null=an_enum_value_with_null, + an_enum_value_with_only_null=an_enum_value_with_only_null, + int_enum_header=int_enum_header, + string_enum_header=string_enum_header, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + *, + client: Union[AuthenticatedClient, Client], + an_enum_value: List[AnEnum], + an_enum_value_with_null: List[Union[AnEnumWithNull, None]], + an_enum_value_with_only_null: List[None], + int_enum_header: Union[Unset, GetUserListIntEnumHeader] = UNSET, + string_enum_header: Union[Unset, GetUserListStringEnumHeader] = UNSET, +) -> Optional[List["AModel"]]: + """Get List + + Get a list of things + + Args: + an_enum_value (List[AnEnum]): + an_enum_value_with_null (List[Union[AnEnumWithNull, None]]): + an_enum_value_with_only_null (List[None]): + int_enum_header (Union[Unset, GetUserListIntEnumHeader]): + string_enum_header (Union[Unset, GetUserListStringEnumHeader]): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + List['AModel'] + """ + + return sync_detailed( + client=client, + an_enum_value=an_enum_value, + an_enum_value_with_null=an_enum_value_with_null, + an_enum_value_with_only_null=an_enum_value_with_only_null, + int_enum_header=int_enum_header, + string_enum_header=string_enum_header, + ).parsed + + +async def asyncio_detailed( + *, + client: Union[AuthenticatedClient, Client], + an_enum_value: List[AnEnum], + an_enum_value_with_null: List[Union[AnEnumWithNull, None]], + an_enum_value_with_only_null: List[None], + int_enum_header: Union[Unset, GetUserListIntEnumHeader] = UNSET, + string_enum_header: Union[Unset, GetUserListStringEnumHeader] = UNSET, +) -> Response[List["AModel"]]: + """Get List + + Get a list of things + + Args: + an_enum_value (List[AnEnum]): + an_enum_value_with_null (List[Union[AnEnumWithNull, None]]): + an_enum_value_with_only_null (List[None]): + int_enum_header (Union[Unset, GetUserListIntEnumHeader]): + string_enum_header (Union[Unset, GetUserListStringEnumHeader]): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[List['AModel']] + """ + + kwargs = _get_kwargs( + an_enum_value=an_enum_value, + an_enum_value_with_null=an_enum_value_with_null, + an_enum_value_with_only_null=an_enum_value_with_only_null, + int_enum_header=int_enum_header, + string_enum_header=string_enum_header, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + *, + client: Union[AuthenticatedClient, Client], + an_enum_value: List[AnEnum], + an_enum_value_with_null: List[Union[AnEnumWithNull, None]], + an_enum_value_with_only_null: List[None], + int_enum_header: Union[Unset, GetUserListIntEnumHeader] = UNSET, + string_enum_header: Union[Unset, GetUserListStringEnumHeader] = UNSET, +) -> Optional[List["AModel"]]: + """Get List + + Get a list of things + + Args: + an_enum_value (List[AnEnum]): + an_enum_value_with_null (List[Union[AnEnumWithNull, None]]): + an_enum_value_with_only_null (List[None]): + int_enum_header (Union[Unset, GetUserListIntEnumHeader]): + string_enum_header (Union[Unset, GetUserListStringEnumHeader]): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + List['AModel'] + """ + + return ( + await asyncio_detailed( + client=client, + an_enum_value=an_enum_value, + an_enum_value_with_null=an_enum_value_with_null, + an_enum_value_with_only_null=an_enum_value_with_only_null, + int_enum_header=int_enum_header, + string_enum_header=string_enum_header, + ) + ).parsed diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/tests/post_user_list.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/tests/post_user_list.py new file mode 100644 index 000000000..e76e4be8b --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/api/tests/post_user_list.py @@ -0,0 +1,172 @@ +from http import HTTPStatus +from typing import Any, Dict, List, Optional, Union + +import httpx + +from ... import errors +from ...client import AuthenticatedClient, Client +from ...models.a_model import AModel +from ...models.post_user_list_body import PostUserListBody +from ...types import Response + + +def _get_kwargs( + *, + body: PostUserListBody, +) -> Dict[str, Any]: + headers: Dict[str, Any] = {} + + _kwargs: Dict[str, Any] = { + "method": "post", + "url": "/tests/", + } + + _body = body.to_multipart() + + _kwargs["files"] = _body + + _kwargs["headers"] = headers + return _kwargs + + +def _parse_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Optional[List["AModel"]]: + if response.status_code == HTTPStatus.OK: + response_200 = [] + _response_200 = response.json() + for response_200_item_data in _response_200: + response_200_item = AModel.from_dict(response_200_item_data) + + response_200.append(response_200_item) + + return response_200 + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Response[List["AModel"]]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + *, + client: Union[AuthenticatedClient, Client], + body: PostUserListBody, +) -> Response[List["AModel"]]: + """Post List + + Post a list of things + + Args: + body (PostUserListBody): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[List['AModel']] + """ + + kwargs = _get_kwargs( + body=body, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + *, + client: Union[AuthenticatedClient, Client], + body: PostUserListBody, +) -> Optional[List["AModel"]]: + """Post List + + Post a list of things + + Args: + body (PostUserListBody): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + List['AModel'] + """ + + return sync_detailed( + client=client, + body=body, + ).parsed + + +async def asyncio_detailed( + *, + client: Union[AuthenticatedClient, Client], + body: PostUserListBody, +) -> Response[List["AModel"]]: + """Post List + + Post a list of things + + Args: + body (PostUserListBody): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[List['AModel']] + """ + + kwargs = _get_kwargs( + body=body, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + *, + client: Union[AuthenticatedClient, Client], + body: PostUserListBody, +) -> Optional[List["AModel"]]: + """Post List + + Post a list of things + + Args: + body (PostUserListBody): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + List['AModel'] + """ + + return ( + await asyncio_detailed( + client=client, + body=body, + ) + ).parsed diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/client.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/client.py new file mode 100644 index 000000000..0f6d15e84 --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/client.py @@ -0,0 +1,268 @@ +import ssl +from typing import Any, Dict, Optional, Union + +import httpx +from attrs import define, evolve, field + + +@define +class Client: + """A class for keeping track of data related to the API + + The following are accepted as keyword arguments and will be used to construct httpx Clients internally: + + ``base_url``: The base URL for the API, all requests are made to a relative path to this URL + + ``cookies``: A dictionary of cookies to be sent with every request + + ``headers``: A dictionary of headers to be sent with every request + + ``timeout``: The maximum amount of a time a request can take. API functions will raise + httpx.TimeoutException if this is exceeded. + + ``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production, + but can be set to False for testing purposes. + + ``follow_redirects``: Whether or not to follow redirects. Default value is False. + + ``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor. + + + Attributes: + raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a + status code that was not documented in the source OpenAPI document. Can also be provided as a keyword + argument to the constructor. + """ + + raise_on_unexpected_status: bool = field(default=False, kw_only=True) + _base_url: str = field(alias="base_url") + _cookies: Dict[str, str] = field(factory=dict, kw_only=True, alias="cookies") + _headers: Dict[str, str] = field(factory=dict, kw_only=True, alias="headers") + _timeout: Optional[httpx.Timeout] = field(default=None, kw_only=True, alias="timeout") + _verify_ssl: Union[str, bool, ssl.SSLContext] = field(default=True, kw_only=True, alias="verify_ssl") + _follow_redirects: bool = field(default=False, kw_only=True, alias="follow_redirects") + _httpx_args: Dict[str, Any] = field(factory=dict, kw_only=True, alias="httpx_args") + _client: Optional[httpx.Client] = field(default=None, init=False) + _async_client: Optional[httpx.AsyncClient] = field(default=None, init=False) + + def with_headers(self, headers: Dict[str, str]) -> "Client": + """Get a new client matching this one with additional headers""" + if self._client is not None: + self._client.headers.update(headers) + if self._async_client is not None: + self._async_client.headers.update(headers) + return evolve(self, headers={**self._headers, **headers}) + + def with_cookies(self, cookies: Dict[str, str]) -> "Client": + """Get a new client matching this one with additional cookies""" + if self._client is not None: + self._client.cookies.update(cookies) + if self._async_client is not None: + self._async_client.cookies.update(cookies) + return evolve(self, cookies={**self._cookies, **cookies}) + + def with_timeout(self, timeout: httpx.Timeout) -> "Client": + """Get a new client matching this one with a new timeout (in seconds)""" + if self._client is not None: + self._client.timeout = timeout + if self._async_client is not None: + self._async_client.timeout = timeout + return evolve(self, timeout=timeout) + + def set_httpx_client(self, client: httpx.Client) -> "Client": + """Manually set the underlying httpx.Client + + **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout. + """ + self._client = client + return self + + def get_httpx_client(self) -> httpx.Client: + """Get the underlying httpx.Client, constructing a new one if not previously set""" + if self._client is None: + self._client = httpx.Client( + base_url=self._base_url, + cookies=self._cookies, + headers=self._headers, + timeout=self._timeout, + verify=self._verify_ssl, + follow_redirects=self._follow_redirects, + **self._httpx_args, + ) + return self._client + + def __enter__(self) -> "Client": + """Enter a context manager for self.client—you cannot enter twice (see httpx docs)""" + self.get_httpx_client().__enter__() + return self + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + """Exit a context manager for internal httpx.Client (see httpx docs)""" + self.get_httpx_client().__exit__(*args, **kwargs) + + def set_async_httpx_client(self, async_client: httpx.AsyncClient) -> "Client": + """Manually the underlying httpx.AsyncClient + + **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout. + """ + self._async_client = async_client + return self + + def get_async_httpx_client(self) -> httpx.AsyncClient: + """Get the underlying httpx.AsyncClient, constructing a new one if not previously set""" + if self._async_client is None: + self._async_client = httpx.AsyncClient( + base_url=self._base_url, + cookies=self._cookies, + headers=self._headers, + timeout=self._timeout, + verify=self._verify_ssl, + follow_redirects=self._follow_redirects, + **self._httpx_args, + ) + return self._async_client + + async def __aenter__(self) -> "Client": + """Enter a context manager for underlying httpx.AsyncClient—you cannot enter twice (see httpx docs)""" + await self.get_async_httpx_client().__aenter__() + return self + + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: + """Exit a context manager for underlying httpx.AsyncClient (see httpx docs)""" + await self.get_async_httpx_client().__aexit__(*args, **kwargs) + + +@define +class AuthenticatedClient: + """A Client which has been authenticated for use on secured endpoints + + The following are accepted as keyword arguments and will be used to construct httpx Clients internally: + + ``base_url``: The base URL for the API, all requests are made to a relative path to this URL + + ``cookies``: A dictionary of cookies to be sent with every request + + ``headers``: A dictionary of headers to be sent with every request + + ``timeout``: The maximum amount of a time a request can take. API functions will raise + httpx.TimeoutException if this is exceeded. + + ``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production, + but can be set to False for testing purposes. + + ``follow_redirects``: Whether or not to follow redirects. Default value is False. + + ``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor. + + + Attributes: + raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a + status code that was not documented in the source OpenAPI document. Can also be provided as a keyword + argument to the constructor. + token: The token to use for authentication + prefix: The prefix to use for the Authorization header + auth_header_name: The name of the Authorization header + """ + + raise_on_unexpected_status: bool = field(default=False, kw_only=True) + _base_url: str = field(alias="base_url") + _cookies: Dict[str, str] = field(factory=dict, kw_only=True, alias="cookies") + _headers: Dict[str, str] = field(factory=dict, kw_only=True, alias="headers") + _timeout: Optional[httpx.Timeout] = field(default=None, kw_only=True, alias="timeout") + _verify_ssl: Union[str, bool, ssl.SSLContext] = field(default=True, kw_only=True, alias="verify_ssl") + _follow_redirects: bool = field(default=False, kw_only=True, alias="follow_redirects") + _httpx_args: Dict[str, Any] = field(factory=dict, kw_only=True, alias="httpx_args") + _client: Optional[httpx.Client] = field(default=None, init=False) + _async_client: Optional[httpx.AsyncClient] = field(default=None, init=False) + + token: str + prefix: str = "Bearer" + auth_header_name: str = "Authorization" + + def with_headers(self, headers: Dict[str, str]) -> "AuthenticatedClient": + """Get a new client matching this one with additional headers""" + if self._client is not None: + self._client.headers.update(headers) + if self._async_client is not None: + self._async_client.headers.update(headers) + return evolve(self, headers={**self._headers, **headers}) + + def with_cookies(self, cookies: Dict[str, str]) -> "AuthenticatedClient": + """Get a new client matching this one with additional cookies""" + if self._client is not None: + self._client.cookies.update(cookies) + if self._async_client is not None: + self._async_client.cookies.update(cookies) + return evolve(self, cookies={**self._cookies, **cookies}) + + def with_timeout(self, timeout: httpx.Timeout) -> "AuthenticatedClient": + """Get a new client matching this one with a new timeout (in seconds)""" + if self._client is not None: + self._client.timeout = timeout + if self._async_client is not None: + self._async_client.timeout = timeout + return evolve(self, timeout=timeout) + + def set_httpx_client(self, client: httpx.Client) -> "AuthenticatedClient": + """Manually set the underlying httpx.Client + + **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout. + """ + self._client = client + return self + + def get_httpx_client(self) -> httpx.Client: + """Get the underlying httpx.Client, constructing a new one if not previously set""" + if self._client is None: + self._headers[self.auth_header_name] = f"{self.prefix} {self.token}" if self.prefix else self.token + self._client = httpx.Client( + base_url=self._base_url, + cookies=self._cookies, + headers=self._headers, + timeout=self._timeout, + verify=self._verify_ssl, + follow_redirects=self._follow_redirects, + **self._httpx_args, + ) + return self._client + + def __enter__(self) -> "AuthenticatedClient": + """Enter a context manager for self.client—you cannot enter twice (see httpx docs)""" + self.get_httpx_client().__enter__() + return self + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + """Exit a context manager for internal httpx.Client (see httpx docs)""" + self.get_httpx_client().__exit__(*args, **kwargs) + + def set_async_httpx_client(self, async_client: httpx.AsyncClient) -> "AuthenticatedClient": + """Manually the underlying httpx.AsyncClient + + **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout. + """ + self._async_client = async_client + return self + + def get_async_httpx_client(self) -> httpx.AsyncClient: + """Get the underlying httpx.AsyncClient, constructing a new one if not previously set""" + if self._async_client is None: + self._headers[self.auth_header_name] = f"{self.prefix} {self.token}" if self.prefix else self.token + self._async_client = httpx.AsyncClient( + base_url=self._base_url, + cookies=self._cookies, + headers=self._headers, + timeout=self._timeout, + verify=self._verify_ssl, + follow_redirects=self._follow_redirects, + **self._httpx_args, + ) + return self._async_client + + async def __aenter__(self) -> "AuthenticatedClient": + """Enter a context manager for underlying httpx.AsyncClient—you cannot enter twice (see httpx docs)""" + await self.get_async_httpx_client().__aenter__() + return self + + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: + """Exit a context manager for underlying httpx.AsyncClient (see httpx docs)""" + await self.get_async_httpx_client().__aexit__(*args, **kwargs) diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/errors.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/errors.py new file mode 100644 index 000000000..5f92e76ac --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/errors.py @@ -0,0 +1,16 @@ +"""Contains shared errors types that can be raised from API functions""" + + +class UnexpectedStatus(Exception): + """Raised by api functions when the response status an undocumented status and Client.raise_on_unexpected_status is True""" + + def __init__(self, status_code: int, content: bytes): + self.status_code = status_code + self.content = content + + super().__init__( + f"Unexpected status code: {status_code}\n\nResponse content:\n{content.decode(errors='ignore')}" + ) + + +__all__ = ["UnexpectedStatus"] diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/__init__.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/__init__.py new file mode 100644 index 000000000..2bdeafad7 --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/__init__.py @@ -0,0 +1,23 @@ +"""Contains all the data models used in inputs/outputs""" + +from .a_model import AModel +from .an_all_of_enum import AnAllOfEnum +from .an_enum import AnEnum +from .an_enum_with_null import AnEnumWithNull +from .an_int_enum import AnIntEnum +from .different_enum import DifferentEnum +from .get_user_list_int_enum_header import GetUserListIntEnumHeader +from .get_user_list_string_enum_header import GetUserListStringEnumHeader +from .post_user_list_body import PostUserListBody + +__all__ = ( + "AModel", + "AnAllOfEnum", + "AnEnum", + "AnEnumWithNull", + "AnIntEnum", + "DifferentEnum", + "GetUserListIntEnumHeader", + "GetUserListStringEnumHeader", + "PostUserListBody", +) diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/a_model.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/a_model.py new file mode 100644 index 000000000..e05fdaa6d --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/a_model.py @@ -0,0 +1,105 @@ +from typing import Any, Dict, List, Type, TypeVar, Union + +from attrs import define as _attrs_define + +from ..models.an_all_of_enum import AnAllOfEnum, check_an_all_of_enum +from ..models.an_enum import AnEnum, check_an_enum +from ..models.different_enum import DifferentEnum, check_different_enum +from ..types import UNSET, Unset + +T = TypeVar("T", bound="AModel") + + +@_attrs_define +class AModel: + """A Model for testing all the ways enums can be used + + Attributes: + an_enum_value (AnEnum): For testing Enums in all the ways they can be used + an_allof_enum_with_overridden_default (AnAllOfEnum): Default: 'overridden_default'. + any_value (Union[Unset, Any]): + an_optional_allof_enum (Union[Unset, AnAllOfEnum]): + nested_list_of_enums (Union[Unset, List[List[DifferentEnum]]]): + """ + + an_enum_value: AnEnum + an_allof_enum_with_overridden_default: AnAllOfEnum = "overridden_default" + any_value: Union[Unset, Any] = UNSET + an_optional_allof_enum: Union[Unset, AnAllOfEnum] = UNSET + nested_list_of_enums: Union[Unset, List[List[DifferentEnum]]] = UNSET + + def to_dict(self) -> Dict[str, Any]: + an_enum_value: str = self.an_enum_value + + an_allof_enum_with_overridden_default: str = self.an_allof_enum_with_overridden_default + + any_value = self.any_value + + an_optional_allof_enum: Union[Unset, str] = UNSET + if not isinstance(self.an_optional_allof_enum, Unset): + an_optional_allof_enum = self.an_optional_allof_enum + + nested_list_of_enums: Union[Unset, List[List[str]]] = UNSET + if not isinstance(self.nested_list_of_enums, Unset): + nested_list_of_enums = [] + for nested_list_of_enums_item_data in self.nested_list_of_enums: + nested_list_of_enums_item = [] + for nested_list_of_enums_item_item_data in nested_list_of_enums_item_data: + nested_list_of_enums_item_item: str = nested_list_of_enums_item_item_data + nested_list_of_enums_item.append(nested_list_of_enums_item_item) + + nested_list_of_enums.append(nested_list_of_enums_item) + + field_dict: Dict[str, Any] = {} + field_dict.update( + { + "an_enum_value": an_enum_value, + "an_allof_enum_with_overridden_default": an_allof_enum_with_overridden_default, + } + ) + if any_value is not UNSET: + field_dict["any_value"] = any_value + if an_optional_allof_enum is not UNSET: + field_dict["an_optional_allof_enum"] = an_optional_allof_enum + if nested_list_of_enums is not UNSET: + field_dict["nested_list_of_enums"] = nested_list_of_enums + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + an_enum_value = check_an_enum(d.pop("an_enum_value")) + + an_allof_enum_with_overridden_default = check_an_all_of_enum(d.pop("an_allof_enum_with_overridden_default")) + + any_value = d.pop("any_value", UNSET) + + _an_optional_allof_enum = d.pop("an_optional_allof_enum", UNSET) + an_optional_allof_enum: Union[Unset, AnAllOfEnum] + if isinstance(_an_optional_allof_enum, Unset): + an_optional_allof_enum = UNSET + else: + an_optional_allof_enum = check_an_all_of_enum(_an_optional_allof_enum) + + nested_list_of_enums = [] + _nested_list_of_enums = d.pop("nested_list_of_enums", UNSET) + for nested_list_of_enums_item_data in _nested_list_of_enums or []: + nested_list_of_enums_item = [] + _nested_list_of_enums_item = nested_list_of_enums_item_data + for nested_list_of_enums_item_item_data in _nested_list_of_enums_item: + nested_list_of_enums_item_item = check_different_enum(nested_list_of_enums_item_item_data) + + nested_list_of_enums_item.append(nested_list_of_enums_item_item) + + nested_list_of_enums.append(nested_list_of_enums_item) + + a_model = cls( + an_enum_value=an_enum_value, + an_allof_enum_with_overridden_default=an_allof_enum_with_overridden_default, + any_value=any_value, + an_optional_allof_enum=an_optional_allof_enum, + nested_list_of_enums=nested_list_of_enums, + ) + + return a_model diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/an_all_of_enum.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/an_all_of_enum.py new file mode 100644 index 000000000..e238b15a9 --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/an_all_of_enum.py @@ -0,0 +1,16 @@ +from typing import Literal, Set, cast + +AnAllOfEnum = Literal["a_default", "bar", "foo", "overridden_default"] + +AN_ALL_OF_ENUM_VALUES: Set[AnAllOfEnum] = { + "a_default", + "bar", + "foo", + "overridden_default", +} + + +def check_an_all_of_enum(value: str) -> AnAllOfEnum: + if value in AN_ALL_OF_ENUM_VALUES: + return cast(AnAllOfEnum, value) + raise TypeError(f"Unexpected value {value!r}. Expected one of {AN_ALL_OF_ENUM_VALUES!r}") diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/an_enum.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/an_enum.py new file mode 100644 index 000000000..608b22fc4 --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/an_enum.py @@ -0,0 +1,14 @@ +from typing import Literal, Set, cast + +AnEnum = Literal["FIRST_VALUE", "SECOND_VALUE"] + +AN_ENUM_VALUES: Set[AnEnum] = { + "FIRST_VALUE", + "SECOND_VALUE", +} + + +def check_an_enum(value: str) -> AnEnum: + if value in AN_ENUM_VALUES: + return cast(AnEnum, value) + raise TypeError(f"Unexpected value {value!r}. Expected one of {AN_ENUM_VALUES!r}") diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/an_enum_with_null.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/an_enum_with_null.py new file mode 100644 index 000000000..1519ec27c --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/an_enum_with_null.py @@ -0,0 +1,14 @@ +from typing import Literal, Set, cast + +AnEnumWithNull = Literal["FIRST_VALUE", "SECOND_VALUE"] + +AN_ENUM_WITH_NULL_VALUES: Set[AnEnumWithNull] = { + "FIRST_VALUE", + "SECOND_VALUE", +} + + +def check_an_enum_with_null(value: str) -> AnEnumWithNull: + if value in AN_ENUM_WITH_NULL_VALUES: + return cast(AnEnumWithNull, value) + raise TypeError(f"Unexpected value {value!r}. Expected one of {AN_ENUM_WITH_NULL_VALUES!r}") diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/an_int_enum.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/an_int_enum.py new file mode 100644 index 000000000..a3c1108ea --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/an_int_enum.py @@ -0,0 +1,15 @@ +from typing import Literal, Set, cast + +AnIntEnum = Literal[-1, 1, 2] + +AN_INT_ENUM_VALUES: Set[AnIntEnum] = { + -1, + 1, + 2, +} + + +def check_an_int_enum(value: int) -> AnIntEnum: + if value in AN_INT_ENUM_VALUES: + return cast(AnIntEnum, value) + raise TypeError(f"Unexpected value {value!r}. Expected one of {AN_INT_ENUM_VALUES!r}") diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/different_enum.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/different_enum.py new file mode 100644 index 000000000..d40045c50 --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/different_enum.py @@ -0,0 +1,14 @@ +from typing import Literal, Set, cast + +DifferentEnum = Literal["DIFFERENT", "OTHER"] + +DIFFERENT_ENUM_VALUES: Set[DifferentEnum] = { + "DIFFERENT", + "OTHER", +} + + +def check_different_enum(value: str) -> DifferentEnum: + if value in DIFFERENT_ENUM_VALUES: + return cast(DifferentEnum, value) + raise TypeError(f"Unexpected value {value!r}. Expected one of {DIFFERENT_ENUM_VALUES!r}") diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/get_user_list_int_enum_header.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/get_user_list_int_enum_header.py new file mode 100644 index 000000000..50e8114ae --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/get_user_list_int_enum_header.py @@ -0,0 +1,15 @@ +from typing import Literal, Set, cast + +GetUserListIntEnumHeader = Literal[1, 2, 3] + +GET_USER_LIST_INT_ENUM_HEADER_VALUES: Set[GetUserListIntEnumHeader] = { + 1, + 2, + 3, +} + + +def check_get_user_list_int_enum_header(value: int) -> GetUserListIntEnumHeader: + if value in GET_USER_LIST_INT_ENUM_HEADER_VALUES: + return cast(GetUserListIntEnumHeader, value) + raise TypeError(f"Unexpected value {value!r}. Expected one of {GET_USER_LIST_INT_ENUM_HEADER_VALUES!r}") diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/get_user_list_string_enum_header.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/get_user_list_string_enum_header.py new file mode 100644 index 000000000..d73cea6a6 --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/get_user_list_string_enum_header.py @@ -0,0 +1,15 @@ +from typing import Literal, Set, cast + +GetUserListStringEnumHeader = Literal["one", "three", "two"] + +GET_USER_LIST_STRING_ENUM_HEADER_VALUES: Set[GetUserListStringEnumHeader] = { + "one", + "three", + "two", +} + + +def check_get_user_list_string_enum_header(value: str) -> GetUserListStringEnumHeader: + if value in GET_USER_LIST_STRING_ENUM_HEADER_VALUES: + return cast(GetUserListStringEnumHeader, value) + raise TypeError(f"Unexpected value {value!r}. Expected one of {GET_USER_LIST_STRING_ENUM_HEADER_VALUES!r}") diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/post_user_list_body.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/post_user_list_body.py new file mode 100644 index 000000000..e61cb4183 --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/models/post_user_list_body.py @@ -0,0 +1,255 @@ +import json +from typing import Any, Dict, List, Tuple, Type, TypeVar, Union, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..models.an_all_of_enum import AnAllOfEnum, check_an_all_of_enum +from ..models.an_enum import AnEnum, check_an_enum +from ..models.an_enum_with_null import AnEnumWithNull, check_an_enum_with_null +from ..models.different_enum import DifferentEnum, check_different_enum +from ..types import UNSET, Unset + +T = TypeVar("T", bound="PostUserListBody") + + +@_attrs_define +class PostUserListBody: + """ + Attributes: + an_enum_value (Union[Unset, List[AnEnum]]): + an_enum_value_with_null (Union[Unset, List[Union[AnEnumWithNull, None]]]): + an_enum_value_with_only_null (Union[Unset, List[None]]): + an_allof_enum_with_overridden_default (Union[Unset, AnAllOfEnum]): Default: 'overridden_default'. + an_optional_allof_enum (Union[Unset, AnAllOfEnum]): + nested_list_of_enums (Union[Unset, List[List[DifferentEnum]]]): + """ + + an_enum_value: Union[Unset, List[AnEnum]] = UNSET + an_enum_value_with_null: Union[Unset, List[Union[AnEnumWithNull, None]]] = UNSET + an_enum_value_with_only_null: Union[Unset, List[None]] = UNSET + an_allof_enum_with_overridden_default: Union[Unset, AnAllOfEnum] = "overridden_default" + an_optional_allof_enum: Union[Unset, AnAllOfEnum] = UNSET + nested_list_of_enums: Union[Unset, List[List[DifferentEnum]]] = UNSET + additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> Dict[str, Any]: + an_enum_value: Union[Unset, List[str]] = UNSET + if not isinstance(self.an_enum_value, Unset): + an_enum_value = [] + for an_enum_value_item_data in self.an_enum_value: + an_enum_value_item: str = an_enum_value_item_data + an_enum_value.append(an_enum_value_item) + + an_enum_value_with_null: Union[Unset, List[Union[None, str]]] = UNSET + if not isinstance(self.an_enum_value_with_null, Unset): + an_enum_value_with_null = [] + for an_enum_value_with_null_item_data in self.an_enum_value_with_null: + an_enum_value_with_null_item: Union[None, str] + if isinstance(an_enum_value_with_null_item_data, str): + an_enum_value_with_null_item = an_enum_value_with_null_item_data + else: + an_enum_value_with_null_item = an_enum_value_with_null_item_data + an_enum_value_with_null.append(an_enum_value_with_null_item) + + an_enum_value_with_only_null: Union[Unset, List[None]] = UNSET + if not isinstance(self.an_enum_value_with_only_null, Unset): + an_enum_value_with_only_null = self.an_enum_value_with_only_null + + an_allof_enum_with_overridden_default: Union[Unset, str] = UNSET + if not isinstance(self.an_allof_enum_with_overridden_default, Unset): + an_allof_enum_with_overridden_default = self.an_allof_enum_with_overridden_default + + an_optional_allof_enum: Union[Unset, str] = UNSET + if not isinstance(self.an_optional_allof_enum, Unset): + an_optional_allof_enum = self.an_optional_allof_enum + + nested_list_of_enums: Union[Unset, List[List[str]]] = UNSET + if not isinstance(self.nested_list_of_enums, Unset): + nested_list_of_enums = [] + for nested_list_of_enums_item_data in self.nested_list_of_enums: + nested_list_of_enums_item = [] + for nested_list_of_enums_item_item_data in nested_list_of_enums_item_data: + nested_list_of_enums_item_item: str = nested_list_of_enums_item_item_data + nested_list_of_enums_item.append(nested_list_of_enums_item_item) + + nested_list_of_enums.append(nested_list_of_enums_item) + + field_dict: Dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update({}) + if an_enum_value is not UNSET: + field_dict["an_enum_value"] = an_enum_value + if an_enum_value_with_null is not UNSET: + field_dict["an_enum_value_with_null"] = an_enum_value_with_null + if an_enum_value_with_only_null is not UNSET: + field_dict["an_enum_value_with_only_null"] = an_enum_value_with_only_null + if an_allof_enum_with_overridden_default is not UNSET: + field_dict["an_allof_enum_with_overridden_default"] = an_allof_enum_with_overridden_default + if an_optional_allof_enum is not UNSET: + field_dict["an_optional_allof_enum"] = an_optional_allof_enum + if nested_list_of_enums is not UNSET: + field_dict["nested_list_of_enums"] = nested_list_of_enums + + return field_dict + + def to_multipart(self) -> Dict[str, Any]: + an_enum_value: Union[Unset, Tuple[None, bytes, str]] = UNSET + if not isinstance(self.an_enum_value, Unset): + _temp_an_enum_value = [] + for an_enum_value_item_data in self.an_enum_value: + an_enum_value_item: str = an_enum_value_item_data + _temp_an_enum_value.append(an_enum_value_item) + an_enum_value = (None, json.dumps(_temp_an_enum_value).encode(), "application/json") + + an_enum_value_with_null: Union[Unset, Tuple[None, bytes, str]] = UNSET + if not isinstance(self.an_enum_value_with_null, Unset): + _temp_an_enum_value_with_null = [] + for an_enum_value_with_null_item_data in self.an_enum_value_with_null: + an_enum_value_with_null_item: Union[None, str] + if isinstance(an_enum_value_with_null_item_data, str): + an_enum_value_with_null_item = an_enum_value_with_null_item_data + else: + an_enum_value_with_null_item = an_enum_value_with_null_item_data + _temp_an_enum_value_with_null.append(an_enum_value_with_null_item) + an_enum_value_with_null = (None, json.dumps(_temp_an_enum_value_with_null).encode(), "application/json") + + an_enum_value_with_only_null: Union[Unset, Tuple[None, bytes, str]] = UNSET + if not isinstance(self.an_enum_value_with_only_null, Unset): + _temp_an_enum_value_with_only_null = self.an_enum_value_with_only_null + an_enum_value_with_only_null = ( + None, + json.dumps(_temp_an_enum_value_with_only_null).encode(), + "application/json", + ) + + an_allof_enum_with_overridden_default: Union[Unset, Tuple[None, bytes, str]] = UNSET + if not isinstance(self.an_allof_enum_with_overridden_default, Unset): + an_allof_enum_with_overridden_default = ( + None, + str(self.an_allof_enum_with_overridden_default).encode(), + "text/plain", + ) + + an_optional_allof_enum: Union[Unset, Tuple[None, bytes, str]] = UNSET + if not isinstance(self.an_optional_allof_enum, Unset): + an_optional_allof_enum = (None, str(self.an_optional_allof_enum).encode(), "text/plain") + + nested_list_of_enums: Union[Unset, Tuple[None, bytes, str]] = UNSET + if not isinstance(self.nested_list_of_enums, Unset): + _temp_nested_list_of_enums = [] + for nested_list_of_enums_item_data in self.nested_list_of_enums: + nested_list_of_enums_item = [] + for nested_list_of_enums_item_item_data in nested_list_of_enums_item_data: + nested_list_of_enums_item_item: str = nested_list_of_enums_item_item_data + nested_list_of_enums_item.append(nested_list_of_enums_item_item) + + _temp_nested_list_of_enums.append(nested_list_of_enums_item) + nested_list_of_enums = (None, json.dumps(_temp_nested_list_of_enums).encode(), "application/json") + + field_dict: Dict[str, Any] = {} + for prop_name, prop in self.additional_properties.items(): + field_dict[prop_name] = (None, str(prop).encode(), "text/plain") + + field_dict.update({}) + if an_enum_value is not UNSET: + field_dict["an_enum_value"] = an_enum_value + if an_enum_value_with_null is not UNSET: + field_dict["an_enum_value_with_null"] = an_enum_value_with_null + if an_enum_value_with_only_null is not UNSET: + field_dict["an_enum_value_with_only_null"] = an_enum_value_with_only_null + if an_allof_enum_with_overridden_default is not UNSET: + field_dict["an_allof_enum_with_overridden_default"] = an_allof_enum_with_overridden_default + if an_optional_allof_enum is not UNSET: + field_dict["an_optional_allof_enum"] = an_optional_allof_enum + if nested_list_of_enums is not UNSET: + field_dict["nested_list_of_enums"] = nested_list_of_enums + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + an_enum_value = [] + _an_enum_value = d.pop("an_enum_value", UNSET) + for an_enum_value_item_data in _an_enum_value or []: + an_enum_value_item = check_an_enum(an_enum_value_item_data) + + an_enum_value.append(an_enum_value_item) + + an_enum_value_with_null = [] + _an_enum_value_with_null = d.pop("an_enum_value_with_null", UNSET) + for an_enum_value_with_null_item_data in _an_enum_value_with_null or []: + + def _parse_an_enum_value_with_null_item(data: object) -> Union[AnEnumWithNull, None]: + if data is None: + return data + try: + if not isinstance(data, str): + raise TypeError() + componentsschemas_an_enum_with_null_type_1 = check_an_enum_with_null(data) + + return componentsschemas_an_enum_with_null_type_1 + except: # noqa: E722 + pass + return cast(Union[AnEnumWithNull, None], data) + + an_enum_value_with_null_item = _parse_an_enum_value_with_null_item(an_enum_value_with_null_item_data) + + an_enum_value_with_null.append(an_enum_value_with_null_item) + + an_enum_value_with_only_null = cast(List[None], d.pop("an_enum_value_with_only_null", UNSET)) + + _an_allof_enum_with_overridden_default = d.pop("an_allof_enum_with_overridden_default", UNSET) + an_allof_enum_with_overridden_default: Union[Unset, AnAllOfEnum] + if isinstance(_an_allof_enum_with_overridden_default, Unset): + an_allof_enum_with_overridden_default = UNSET + else: + an_allof_enum_with_overridden_default = check_an_all_of_enum(_an_allof_enum_with_overridden_default) + + _an_optional_allof_enum = d.pop("an_optional_allof_enum", UNSET) + an_optional_allof_enum: Union[Unset, AnAllOfEnum] + if isinstance(_an_optional_allof_enum, Unset): + an_optional_allof_enum = UNSET + else: + an_optional_allof_enum = check_an_all_of_enum(_an_optional_allof_enum) + + nested_list_of_enums = [] + _nested_list_of_enums = d.pop("nested_list_of_enums", UNSET) + for nested_list_of_enums_item_data in _nested_list_of_enums or []: + nested_list_of_enums_item = [] + _nested_list_of_enums_item = nested_list_of_enums_item_data + for nested_list_of_enums_item_item_data in _nested_list_of_enums_item: + nested_list_of_enums_item_item = check_different_enum(nested_list_of_enums_item_item_data) + + nested_list_of_enums_item.append(nested_list_of_enums_item_item) + + nested_list_of_enums.append(nested_list_of_enums_item) + + post_user_list_body = cls( + an_enum_value=an_enum_value, + an_enum_value_with_null=an_enum_value_with_null, + an_enum_value_with_only_null=an_enum_value_with_only_null, + an_allof_enum_with_overridden_default=an_allof_enum_with_overridden_default, + an_optional_allof_enum=an_optional_allof_enum, + nested_list_of_enums=nested_list_of_enums, + ) + + post_user_list_body.additional_properties = d + return post_user_list_body + + @property + def additional_keys(self) -> List[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/py.typed b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/py.typed new file mode 100644 index 000000000..1aad32711 --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561 \ No newline at end of file diff --git a/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/types.py b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/types.py new file mode 100644 index 000000000..21fac106f --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/my_enum_api_client/types.py @@ -0,0 +1,45 @@ +"""Contains some shared types for properties""" + +from http import HTTPStatus +from typing import BinaryIO, Generic, Literal, MutableMapping, Optional, Tuple, TypeVar + +from attrs import define + + +class Unset: + def __bool__(self) -> Literal[False]: + return False + + +UNSET: Unset = Unset() + +FileJsonType = Tuple[Optional[str], BinaryIO, Optional[str]] + + +@define +class File: + """Contains information for file uploads""" + + payload: BinaryIO + file_name: Optional[str] = None + mime_type: Optional[str] = None + + def to_tuple(self) -> FileJsonType: + """Return a tuple representation that httpx will accept for multipart/form-data""" + return self.file_name, self.payload, self.mime_type + + +T = TypeVar("T") + + +@define +class Response(Generic[T]): + """A response from an endpoint""" + + status_code: HTTPStatus + content: bytes + headers: MutableMapping[str, str] + parsed: Optional[T] + + +__all__ = ["File", "Response", "FileJsonType", "Unset", "UNSET"] diff --git a/end_to_end_tests/literal-enums-golden-record/pyproject.toml b/end_to_end_tests/literal-enums-golden-record/pyproject.toml new file mode 100644 index 000000000..d32c2d72c --- /dev/null +++ b/end_to_end_tests/literal-enums-golden-record/pyproject.toml @@ -0,0 +1,27 @@ +[tool.poetry] +name = "my-enum-api-client" +version = "0.1.0" +description = "A client library for accessing My Enum API" +authors = [] +readme = "README.md" +packages = [ + {include = "my_enum_api_client"}, +] +include = ["CHANGELOG.md", "my_enum_api_client/py.typed"] + + +[tool.poetry.dependencies] +python = "^3.8" +httpx = ">=0.20.0,<0.28.0" +attrs = ">=21.3.0" +python-dateutil = "^2.8.0" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = ["F", "I", "UP"] diff --git a/end_to_end_tests/literal_enums.config.yml b/end_to_end_tests/literal_enums.config.yml new file mode 100644 index 000000000..120eae0a7 --- /dev/null +++ b/end_to_end_tests/literal_enums.config.yml @@ -0,0 +1 @@ +literal_enums: true diff --git a/end_to_end_tests/openapi_3.1_enums.yaml b/end_to_end_tests/openapi_3.1_enums.yaml new file mode 100644 index 000000000..b77d4ff74 --- /dev/null +++ b/end_to_end_tests/openapi_3.1_enums.yaml @@ -0,0 +1,226 @@ +openapi: 3.1.0 +info: + title: My Enum API + description: An API for testing enum handling in openapi-python-client + version: 0.1.0 +paths: + /tests/: + get: + tags: + - tests + summary: Get List + description: 'Get a list of things ' + operationId: getUserList + parameters: + - required: true + schema: + title: An Enum Value + type: array + items: + $ref: '#/components/schemas/AnEnum' + name: an_enum_value + in: query + - required: true + schema: + title: An Enum Value With Null And String Values + type: array + items: + $ref: '#/components/schemas/AnEnumWithNull' + name: an_enum_value_with_null + in: query + - required: true + schema: + title: An Enum Value With Only Null Values + type: array + items: + $ref: '#/components/schemas/AnEnumWithOnlyNull' + name: an_enum_value_with_only_null + in: query + - in: header + name: Int-Enum-Header + required: false + schema: + type: integer + enum: + - 1 + - 2 + - 3 + - in: header + name: String-Enum-Header + required: false + schema: + type: string + enum: + - one + - two + - three + responses: + '200': + description: Successful Response + content: + application/json: + schema: + title: Response Get List Tests Get + type: array + items: + $ref: '#/components/schemas/AModel' + post: + tags: + - tests + summary: Post List + description: 'Post a list of things ' + operationId: postUserList + requestBody: + content: + multipart/form-data: + schema: + type: object + properties: + an_enum_value: + title: An Enum Value + type: array + items: + $ref: '#/components/schemas/AnEnum' + an_enum_value_with_null: + title: An Enum Value With Null And String Values + type: array + items: + $ref: '#/components/schemas/AnEnumWithNull' + an_enum_value_with_only_null: + title: An Enum Value With Only Null Values + type: array + items: + $ref: '#/components/schemas/AnEnumWithOnlyNull' + an_allof_enum_with_overridden_default: + title: An AllOf Enum With Overridden Default + allOf: + - $ref: '#/components/schemas/AnAllOfEnum' + default: overridden_default + an_optional_allof_enum: + title: An Optional AllOf Enum + $ref: '#/components/schemas/AnAllOfEnum' + nested_list_of_enums: + title: Nested List Of Enums + type: array + items: + type: array + items: + $ref: '#/components/schemas/DifferentEnum' + default: [] + responses: + '200': + description: Successful Response + content: + application/json: + schema: + title: Response Get List Tests Get + type: array + items: + $ref: '#/components/schemas/AModel' + /enum/int: + post: + tags: + - enums + summary: Int Enum + operationId: int_enum_tests_int_enum_post + parameters: + - required: true + schema: + $ref: '#/components/schemas/AnIntEnum' + name: int_enum + in: query + responses: + '200': + description: Successful Response + content: + application/json: + schema: {} + /enum/bool: + post: + tags: + - enums + summary: Bool Enum + operationId: bool_enum_tests_bool_enum_post + parameters: + - required: true + schema: + type: boolean + enum: + - true + - false + name: bool_enum + in: query + responses: + '200': + description: Successful Response + content: + application/json: + schema: {} +components: + schemas: + AModel: + title: AModel + required: + - an_enum_value + - an_allof_enum_with_overridden_default + type: object + properties: + any_value: {} + an_enum_value: + $ref: '#/components/schemas/AnEnum' + an_allof_enum_with_overridden_default: + allOf: + - $ref: '#/components/schemas/AnAllOfEnum' + default: overridden_default + an_optional_allof_enum: + $ref: '#/components/schemas/AnAllOfEnum' + nested_list_of_enums: + title: Nested List Of Enums + type: array + items: + type: array + items: + $ref: '#/components/schemas/DifferentEnum' + default: [] + description: 'A Model for testing all the ways enums can be used ' + additionalProperties: false + AnEnum: + title: AnEnum + enum: + - FIRST_VALUE + - SECOND_VALUE + description: 'For testing Enums in all the ways they can be used ' + AnEnumWithNull: + title: AnEnumWithNull + enum: + - FIRST_VALUE + - SECOND_VALUE + - null + description: 'For testing Enums with mixed string / null values ' + AnEnumWithOnlyNull: + title: AnEnumWithOnlyNull + enum: + - null + description: 'For testing Enums with only null values ' + AnAllOfEnum: + title: AnAllOfEnum + enum: + - foo + - bar + - a_default + - overridden_default + default: a_default + AnIntEnum: + title: AnIntEnum + enum: + - -1 + - 1 + - 2 + type: integer + description: An enumeration. + DifferentEnum: + title: DifferentEnum + enum: + - DIFFERENT + - OTHER + description: An enumeration. diff --git a/end_to_end_tests/regen_golden_record.py b/end_to_end_tests/regen_golden_record.py index 0bffe132a..2471e1340 100644 --- a/end_to_end_tests/regen_golden_record.py +++ b/end_to_end_tests/regen_golden_record.py @@ -51,6 +51,26 @@ def regen_golden_record_3_1_features(): output_path.rename(gr_path) +def regen_literal_enums_golden_record(): + runner = CliRunner() + openapi_path = Path(__file__).parent / "openapi_3.1_enums.yaml" + + gr_path = Path(__file__).parent / "literal-enums-golden-record" + output_path = Path.cwd() / "my-enum-api-client" + config_path = Path(__file__).parent / "literal_enums.config.yml" + + shutil.rmtree(gr_path, ignore_errors=True) + shutil.rmtree(output_path, ignore_errors=True) + + result = runner.invoke(app, ["generate", f"--path={openapi_path}", f"--config={config_path}"]) + + if result.stdout: + print(result.stdout) + if result.exception: + raise result.exception + output_path.rename(gr_path) + + def regen_metadata_snapshots(): runner = CliRunner() openapi_path = Path(__file__).parent / "3.1_specific.openapi.yaml" @@ -124,3 +144,4 @@ def regen_custom_template_golden_record(): regen_golden_record_3_1_features() regen_metadata_snapshots() regen_custom_template_golden_record() + regen_literal_enums_golden_record() diff --git a/end_to_end_tests/test_end_to_end.py b/end_to_end_tests/test_end_to_end.py index 621d8ecc4..a448a0698 100644 --- a/end_to_end_tests/test_end_to_end.py +++ b/end_to_end_tests/test_end_to_end.py @@ -148,6 +148,17 @@ def test_3_1_specific_features(): ) +def test_literal_enums_end_to_end(): + config_path = Path(__file__).parent / "literal_enums.config.yml" + run_e2e_test( + "openapi_3.1_enums.yaml", + [f"--config={config_path}"], + {}, + "literal-enums-golden-record", + "my-enum-api-client" + ) + + @pytest.mark.parametrize( "meta,generated_file,expected_file", ( diff --git a/openapi_python_client/__init__.py b/openapi_python_client/__init__.py index 90bea54ee..f2cfb40ec 100644 --- a/openapi_python_client/__init__.py +++ b/openapi_python_client/__init__.py @@ -20,6 +20,7 @@ from .config import Config, MetaType from .parser import GeneratorData, import_string_from_class from .parser.errors import ErrorLevel, GeneratorError +from .parser.properties import LiteralEnumProperty __version__ = version(__package__) @@ -227,9 +228,12 @@ def _build_models(self) -> None: # Generate enums str_enum_template = self.env.get_template("str_enum.py.jinja") int_enum_template = self.env.get_template("int_enum.py.jinja") + literal_enum_template = self.env.get_template("literal_enum.py.jinja") for enum in self.openapi.enums: module_path = models_dir / f"{enum.class_info.module_name}.py" - if enum.value_type is int: + if isinstance(enum, LiteralEnumProperty): + module_path.write_text(literal_enum_template.render(enum=enum), encoding=self.config.file_encoding) + elif enum.value_type is int: module_path.write_text(int_enum_template.render(enum=enum), encoding=self.config.file_encoding) else: module_path.write_text(str_enum_template.render(enum=enum), encoding=self.config.file_encoding) diff --git a/openapi_python_client/config.py b/openapi_python_client/config.py index 740e06309..6625bda1f 100644 --- a/openapi_python_client/config.py +++ b/openapi_python_client/config.py @@ -43,6 +43,7 @@ class ConfigFile(BaseModel): post_hooks: Optional[List[str]] = None field_prefix: str = "field_" http_timeout: int = 5 + literal_enums: bool = False @staticmethod def load_from_path(path: Path) -> "ConfigFile": @@ -70,6 +71,7 @@ class Config: post_hooks: List[str] field_prefix: str http_timeout: int + literal_enums: bool document_source: Union[Path, str] file_encoding: str content_type_overrides: Dict[str, str] @@ -109,6 +111,7 @@ def from_sources( post_hooks=post_hooks, field_prefix=config_file.field_prefix, http_timeout=config_file.http_timeout, + literal_enums=config_file.literal_enums, document_source=document_source, file_encoding=file_encoding, overwrite=overwrite, diff --git a/openapi_python_client/parser/openapi.py b/openapi_python_client/parser/openapi.py index 9d62e1df5..acc8998cd 100644 --- a/openapi_python_client/parser/openapi.py +++ b/openapi_python_client/parser/openapi.py @@ -15,6 +15,7 @@ from .properties import ( Class, EnumProperty, + LiteralEnumProperty, ModelProperty, Parameters, Property, @@ -488,7 +489,7 @@ class GeneratorData: models: Iterator[ModelProperty] errors: List[ParseError] endpoint_collections_by_tag: Dict[utils.PythonIdentifier, EndpointCollection] - enums: Iterator[EnumProperty] + enums: Iterator[Union[EnumProperty, LiteralEnumProperty]] @staticmethod def from_dict(data: Dict[str, Any], *, config: Config) -> Union["GeneratorData", GeneratorError]: @@ -517,7 +518,9 @@ def from_dict(data: Dict[str, Any], *, config: Config) -> Union["GeneratorData", data=openapi.paths, schemas=schemas, parameters=parameters, request_bodies=request_bodies, config=config ) - enums = (prop for prop in schemas.classes_by_name.values() if isinstance(prop, EnumProperty)) + enums = ( + prop for prop in schemas.classes_by_name.values() if isinstance(prop, (EnumProperty, LiteralEnumProperty)) + ) models = (prop for prop in schemas.classes_by_name.values() if isinstance(prop, ModelProperty)) return GeneratorData( diff --git a/openapi_python_client/parser/properties/__init__.py b/openapi_python_client/parser/properties/__init__.py index c1e94c3c8..94c6e3d08 100644 --- a/openapi_python_client/parser/properties/__init__.py +++ b/openapi_python_client/parser/properties/__init__.py @@ -4,6 +4,7 @@ "AnyProperty", "Class", "EnumProperty", + "LiteralEnumProperty", "ModelProperty", "Parameters", "Property", @@ -30,6 +31,7 @@ from .float import FloatProperty from .int import IntProperty from .list_property import ListProperty +from .literal_enum_property import LiteralEnumProperty from .model_property import ModelProperty, process_model from .none import NoneProperty from .property import Property @@ -194,6 +196,15 @@ def property_from_data( # noqa: PLR0911, PLR0912 schemas, ) if data.enum: + if config.literal_enums: + return LiteralEnumProperty.build( + data=data, + name=name, + required=required, + schemas=schemas, + parent_name=parent_name, + config=config, + ) return EnumProperty.build( data=data, name=name, diff --git a/openapi_python_client/parser/properties/enum_property.py b/openapi_python_client/parser/properties/enum_property.py index b6a27254f..29609864f 100644 --- a/openapi_python_client/parser/properties/enum_property.py +++ b/openapi_python_client/parser/properties/enum_property.py @@ -121,7 +121,7 @@ def build( # noqa: PLR0911 if parent_name: class_name = f"{utils.pascal_case(parent_name)}{utils.pascal_case(class_name)}" class_info = Class.from_string(string=class_name, config=config) - values = EnumProperty.values_from_list(value_list) + values = EnumProperty.values_from_list(value_list, class_info) if class_info.name in schemas.classes_by_name: existing = schemas.classes_by_name[class_info.name] @@ -183,7 +183,7 @@ def get_imports(self, *, prefix: str) -> set[str]: return imports @staticmethod - def values_from_list(values: list[str] | list[int]) -> dict[str, ValueType]: + def values_from_list(values: list[str] | list[int], class_info: Class) -> dict[str, ValueType]: """Convert a list of values into dict of {name: value}, where value can sometimes be None""" output: dict[str, ValueType] = {} @@ -200,7 +200,10 @@ def values_from_list(values: list[str] | list[int]) -> dict[str, ValueType]: else: key = f"VALUE_{i}" if key in output: - raise ValueError(f"Duplicate key {key} in Enum") + raise ValueError( + f"Duplicate key {key} in enum {class_info.module_name}.{class_info.name}; " + f"consider setting literal_enums in your config" + ) sanitized_key = utils.snake_case(key).upper() output[sanitized_key] = utils.remove_string_escapes(value) return output diff --git a/openapi_python_client/parser/properties/literal_enum_property.py b/openapi_python_client/parser/properties/literal_enum_property.py new file mode 100644 index 000000000..c305a9a41 --- /dev/null +++ b/openapi_python_client/parser/properties/literal_enum_property.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +__all__ = ["LiteralEnumProperty"] + +from typing import Any, ClassVar, List, Union, cast + +from attr import evolve +from attrs import define + +from ... import Config, utils +from ... import schema as oai +from ...schema import DataType +from ..errors import PropertyError +from .none import NoneProperty +from .protocol import PropertyProtocol, Value +from .schemas import Class, Schemas +from .union import UnionProperty + +ValueType = Union[str, int] + + +@define +class LiteralEnumProperty(PropertyProtocol): + """A property that should use a literal enum""" + + name: str + required: bool + default: Value | None + python_name: utils.PythonIdentifier + description: str | None + example: str | None + values: set[ValueType] + class_info: Class + value_type: type[ValueType] + + template: ClassVar[str] = "literal_enum_property.py.jinja" + + _allowed_locations: ClassVar[set[oai.ParameterLocation]] = { + oai.ParameterLocation.QUERY, + oai.ParameterLocation.PATH, + oai.ParameterLocation.COOKIE, + oai.ParameterLocation.HEADER, + } + + @classmethod + def build( # noqa: PLR0911 + cls, + *, + data: oai.Schema, + name: str, + required: bool, + schemas: Schemas, + parent_name: str, + config: Config, + ) -> tuple[LiteralEnumProperty | NoneProperty | UnionProperty | PropertyError, Schemas]: + """ + Create a LiteralEnumProperty from schema data. + + Args: + data: The OpenAPI Schema which defines this enum. + name: The name to use for variables which receive this Enum's value (e.g. model property name) + required: Whether or not this Property is required in the calling context + schemas: The Schemas which have been defined so far (used to prevent naming collisions) + parent_name: The context in which this LiteralEnumProperty is defined, used to create more specific class names. + config: The global config for this run of the generator + + Returns: + A tuple containing either the created property or a PropertyError AND update schemas. + """ + + enum = data.enum or [] # The outer function checks for this, but mypy doesn't know that + + # OpenAPI allows for null as an enum value, but it doesn't make sense with how enums are constructed in Python. + # So instead, if null is a possible value, make the property nullable. + # Mypy is not smart enough to know that the type is right though + unchecked_value_list = [value for value in enum if value is not None] # type: ignore + + # It's legal to have an enum that only contains null as a value, we don't bother constructing an enum for that + if len(unchecked_value_list) == 0: + return ( + NoneProperty.build( + name=name, + required=required, + default="None", + python_name=utils.PythonIdentifier(value=name, prefix=config.field_prefix), + description=None, + example=None, + ), + schemas, + ) + + value_types = {type(value) for value in unchecked_value_list} + if len(value_types) > 1: + return PropertyError( + header="Enum values must all be the same type", detail=f"Got {value_types}", data=data + ), schemas + value_type = next(iter(value_types)) + if value_type not in (str, int): + return PropertyError(header=f"Unsupported enum type {value_type}", data=data), schemas + value_list = cast( + Union[List[int], List[str]], unchecked_value_list + ) # We checked this with all the value_types stuff + + if len(value_list) < len(enum): # Only one of the values was None, that becomes a union + data.oneOf = [ + oai.Schema(type=DataType.NULL), + data.model_copy(update={"enum": value_list, "default": data.default}), + ] + data.enum = None + return UnionProperty.build( + data=data, + name=name, + required=required, + schemas=schemas, + parent_name=parent_name, + config=config, + ) + + class_name = data.title or name + if parent_name: + class_name = f"{utils.pascal_case(parent_name)}{utils.pascal_case(class_name)}" + class_info = Class.from_string(string=class_name, config=config) + values: set[str | int] = set(value_list) + + if class_info.name in schemas.classes_by_name: + existing = schemas.classes_by_name[class_info.name] + if not isinstance(existing, LiteralEnumProperty) or values != existing.values: + return ( + PropertyError( + detail=f"Found conflicting enums named {class_info.name} with incompatible values.", data=data + ), + schemas, + ) + + prop = LiteralEnumProperty( + name=name, + required=required, + class_info=class_info, + values=values, + value_type=value_type, + default=None, + python_name=utils.PythonIdentifier(value=name, prefix=config.field_prefix), + description=data.description, + example=data.example, + ) + checked_default = prop.convert_value(data.default) + if isinstance(checked_default, PropertyError): + checked_default.data = data + return checked_default, schemas + prop = evolve(prop, default=checked_default) + + schemas = evolve(schemas, classes_by_name={**schemas.classes_by_name, class_info.name: prop}) + return prop, schemas + + def convert_value(self, value: Any) -> Value | PropertyError | None: + if value is None or isinstance(value, Value): + return value + if isinstance(value, self.value_type): + if value in self.values: + return Value(python_code=repr(value), raw_value=value) + else: + return PropertyError(detail=f"Value {value} is not valid for enum {self.name}") + return PropertyError(detail=f"Cannot convert {value} to enum {self.name} of type {self.value_type}") + + def get_base_type_string(self, *, quoted: bool = False) -> str: + return self.class_info.name + + def get_base_json_type_string(self, *, quoted: bool = False) -> str: + return self.value_type.__name__ + + def get_instance_type_string(self) -> str: + return self.value_type.__name__ + + def get_imports(self, *, prefix: str) -> set[str]: + """ + Get a set of import strings that should be included when this property is used somewhere + + Args: + prefix: A prefix to put before any relative (local) module names. This should be the number of . to get + back to the root of the generated client. + """ + imports = super().get_imports(prefix=prefix) + imports.add("from typing import cast") + imports.add(f"from {prefix}models.{self.class_info.module_name} import {self.class_info.name}") + imports.add( + f"from {prefix}models.{self.class_info.module_name} import check_{self.get_class_name_snake_case()}" + ) + return imports + + def get_class_name_snake_case(self) -> str: + return utils.snake_case(self.class_info.name) diff --git a/openapi_python_client/parser/properties/merge_properties.py b/openapi_python_client/parser/properties/merge_properties.py index dc7b3e5eb..db6424a7c 100644 --- a/openapi_python_client/parser/properties/merge_properties.py +++ b/openapi_python_client/parser/properties/merge_properties.py @@ -3,6 +3,7 @@ from openapi_python_client.parser.properties.date import DateProperty from openapi_python_client.parser.properties.datetime import DateTimeProperty from openapi_python_client.parser.properties.file import FileProperty +from openapi_python_client.parser.properties.literal_enum_property import LiteralEnumProperty __all__ = ["merge_properties"] @@ -53,6 +54,9 @@ def merge_properties(prop1: Property, prop2: Property) -> Property | PropertyErr if isinstance(prop1, EnumProperty) or isinstance(prop2, EnumProperty): return _merge_with_enum(prop1, prop2) + if isinstance(prop1, LiteralEnumProperty) or isinstance(prop2, LiteralEnumProperty): + return _merge_with_literal_enum(prop1, prop2) + if (merged := _merge_same_type(prop1, prop2)) is not None: return merged @@ -136,6 +140,32 @@ def _merge_with_enum(prop1: PropertyProtocol, prop2: PropertyProtocol) -> EnumPr ) +def _merge_with_literal_enum(prop1: PropertyProtocol, prop2: PropertyProtocol) -> LiteralEnumProperty | PropertyError: + if isinstance(prop1, LiteralEnumProperty) and isinstance(prop2, LiteralEnumProperty): + # We want the narrowest validation rules that fit both, so use whichever values list is a + # subset of the other. + if prop1.values <= prop2.values: + values = prop1.values + class_info = prop1.class_info + elif prop2.values <= prop1.values: + values = prop2.values + class_info = prop2.class_info + else: + return PropertyError(detail="can't redefine a literal enum property with incompatible lists of values") + return _merge_common_attributes(evolve(prop1, values=values, class_info=class_info), prop2) + + # If enum values were specified for just one of the properties, use those. + enum_prop = prop1 if isinstance(prop1, LiteralEnumProperty) else cast(LiteralEnumProperty, prop2) + non_enum_prop = prop2 if isinstance(prop1, LiteralEnumProperty) else prop1 + if (isinstance(non_enum_prop, IntProperty) and enum_prop.value_type is int) or ( + isinstance(non_enum_prop, StringProperty) and enum_prop.value_type is str + ): + return _merge_common_attributes(enum_prop, prop1, prop2) + return PropertyError( + detail=f"can't combine literal enum of type {enum_prop.value_type} with {non_enum_prop.get_type_string(no_optional=True)}" + ) + + def _merge_common_attributes(base: PropertyT, *extend_with: PropertyProtocol) -> PropertyT | PropertyError: """Create a new instance based on base, overriding basic attributes with values from extend_with, in order. diff --git a/openapi_python_client/parser/properties/property.py b/openapi_python_client/parser/properties/property.py index aeac32a7f..6e73a01ae 100644 --- a/openapi_python_client/parser/properties/property.py +++ b/openapi_python_client/parser/properties/property.py @@ -14,6 +14,7 @@ from .float import FloatProperty from .int import IntProperty from .list_property import ListProperty +from .literal_enum_property import LiteralEnumProperty from .model_property import ModelProperty from .none import NoneProperty from .string import StringProperty @@ -27,6 +28,7 @@ DateProperty, DateTimeProperty, EnumProperty, + LiteralEnumProperty, FileProperty, FloatProperty, IntProperty, diff --git a/openapi_python_client/templates/endpoint_module.py.jinja b/openapi_python_client/templates/endpoint_module.py.jinja index 4db1c3546..c6d79b9a7 100644 --- a/openapi_python_client/templates/endpoint_module.py.jinja +++ b/openapi_python_client/templates/endpoint_module.py.jinja @@ -7,7 +7,7 @@ from ...client import AuthenticatedClient, Client from ...types import Response, UNSET from ... import errors -{% for relative in endpoint.relative_imports %} +{% for relative in endpoint.relative_imports | sort %} {{ relative }} {% endfor %} diff --git a/openapi_python_client/templates/literal_enum.py.jinja b/openapi_python_client/templates/literal_enum.py.jinja new file mode 100644 index 000000000..df993adb7 --- /dev/null +++ b/openapi_python_client/templates/literal_enum.py.jinja @@ -0,0 +1,10 @@ +from typing import Literal, Set, cast + +{{ enum.class_info.name }} = Literal{{ "%r" | format(enum.values|list|sort) }} + +{{ enum.get_class_name_snake_case() | upper }}_VALUES: Set[{{ enum.class_info.name }}] = { {% for v in enum.values|list|sort %}{{"%r"|format(v)}}, {% endfor %} } + +def check_{{ enum.get_class_name_snake_case() }}(value: {{ enum.get_instance_type_string() }}) -> {{ enum.class_info.name}}: + if value in {{ enum.get_class_name_snake_case() | upper }}_VALUES: + return cast({{enum.class_info.name}}, value) + raise TypeError(f"Unexpected value {value!r}. Expected one of {{"{"}}{{ enum.get_class_name_snake_case() | upper }}_VALUES!r}") diff --git a/openapi_python_client/templates/model.py.jinja b/openapi_python_client/templates/model.py.jinja index 2d22efe05..012201426 100644 --- a/openapi_python_client/templates/model.py.jinja +++ b/openapi_python_client/templates/model.py.jinja @@ -13,7 +13,7 @@ import json from ..types import UNSET, Unset -{% for relative in model.relative_imports %} +{% for relative in model.relative_imports | sort %} {{ relative }} {% endfor %} diff --git a/openapi_python_client/templates/property_templates/literal_enum_property.py.jinja b/openapi_python_client/templates/property_templates/literal_enum_property.py.jinja new file mode 100644 index 000000000..680ebfabe --- /dev/null +++ b/openapi_python_client/templates/property_templates/literal_enum_property.py.jinja @@ -0,0 +1,38 @@ +{% macro construct_function(property, source) %} +check_{{ property.get_class_name_snake_case() }}({{ source }}) +{% endmacro %} + +{% from "property_templates/property_macros.py.jinja" import construct_template %} + +{% macro construct(property, source) %} +{{ construct_template(construct_function, property, source) }} +{% endmacro %} + +{% macro check_type_for_construct(property, source) %}isinstance({{ source }}, {{ property.get_instance_type_string() }}){% endmacro %} + +{% macro transform(property, source, destination, declare_type=True, multipart=False) %} +{% set type_string = property.get_type_string(json=True) %} +{% if property.required %} +{{ destination }}{% if declare_type %}: {{ type_string }}{% endif %} = {{ source }} +{%- else %} +{{ destination }}{% if declare_type %}: {{ type_string }}{% endif %} = UNSET +if not isinstance({{ source }}, Unset): + {{ destination }} = {{ source }} +{% endif %} +{% endmacro %} + +{% macro transform_multipart(property, source, destination) %} +{% set transformed = "(None, str(" + source + ").encode(), \"text/plain\")" %} +{% set type_string = "Union[Unset, Tuple[None, bytes, str]]" %} +{% if property.required %} +{{ destination }} = {{ transformed }} +{%- else %} +{{ destination }}: {{ type_string }} = UNSET +if not isinstance({{ source }}, Unset): + {{ destination }} = {{ transformed }} +{% endif %} +{% endmacro %} + +{% macro transform_header(source) %} +str({{ source }}) +{% endmacro %} diff --git a/tests/conftest.py b/tests/conftest.py index c01b4ce87..969e57cbd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,6 +19,7 @@ FileProperty, IntProperty, ListProperty, + LiteralEnumProperty, ModelProperty, NoneProperty, StringProperty, @@ -106,7 +107,7 @@ def __call__( ) -> PropertyType: ... -class EnumFactory(Protocol): +class EnumFactory(Protocol[PropertyType]): def __call__( self, *, @@ -119,11 +120,11 @@ def __call__( python_name: PythonIdentifier | None = None, description: str | None = None, example: str | None = None, - ) -> EnumProperty: ... + ) -> PropertyType: ... @pytest.fixture -def enum_property_factory() -> EnumFactory: +def enum_property_factory() -> EnumFactory[EnumProperty]: """ This fixture surfaces in the test as a function which manufactures EnumProperties with defaults. @@ -141,6 +142,25 @@ def enum_property_factory() -> EnumFactory: ) +@pytest.fixture +def literal_enum_property_factory() -> EnumFactory[LiteralEnumProperty]: + """ + This fixture surfaces in the test as a function which manufactures LiteralEnumProperties with defaults. + + You can pass the same params into this as the LiteralEnumProerty constructor to override defaults. + """ + from openapi_python_client.parser.properties import Class + + return _simple_factory( + LiteralEnumProperty, + lambda kwargs: { + "class_info": Class(name=kwargs["name"], module_name=kwargs["name"]), + "values": set(), + "value_type": str, + }, + ) + + @pytest.fixture def any_property_factory() -> SimpleFactory[AnyProperty]: """ diff --git a/tests/test_parser/test_properties/test_enum_property.py b/tests/test_parser/test_properties/test_enum_property.py index dce27ce10..21b183f8e 100644 --- a/tests/test_parser/test_properties/test_enum_property.py +++ b/tests/test_parser/test_properties/test_enum_property.py @@ -1,16 +1,28 @@ +from typing import Type, Union + +import pytest + import openapi_python_client.schema as oai from openapi_python_client import Config from openapi_python_client.parser.errors import PropertyError -from openapi_python_client.parser.properties import EnumProperty, Schemas +from openapi_python_client.parser.properties import LiteralEnumProperty, Schemas +from openapi_python_client.parser.properties.enum_property import EnumProperty + +PropertyClass = Union[Type[EnumProperty], Type[LiteralEnumProperty]] + + +@pytest.fixture(params=[EnumProperty, LiteralEnumProperty]) +def property_class(request) -> PropertyClass: + return request.param -def test_conflict(config: Config) -> None: +def test_conflict(config: Config, property_class: PropertyClass) -> None: schemas = Schemas() - _, schemas = EnumProperty.build( + _, schemas = property_class.build( data=oai.Schema(enum=["a"]), name="Existing", required=True, schemas=schemas, parent_name="", config=config ) - err, new_schemas = EnumProperty.build( + err, new_schemas = property_class.build( data=oai.Schema(enum=["a", "b"]), name="Existing", required=True, @@ -23,11 +35,11 @@ def test_conflict(config: Config) -> None: assert err.detail == "Found conflicting enums named Existing with incompatible values." -def test_bad_default_value(config: Config) -> None: +def test_bad_default_value(config: Config, property_class: PropertyClass) -> None: data = oai.Schema(default="B", enum=["A"]) schemas = Schemas() - err, new_schemas = EnumProperty.build( + err, new_schemas = property_class.build( data=data, name="Existing", required=True, schemas=schemas, parent_name="parent", config=config ) @@ -35,11 +47,11 @@ def test_bad_default_value(config: Config) -> None: assert err == PropertyError(detail="Value B is not valid for enum Existing", data=data) -def test_bad_default_type(config: Config) -> None: +def test_bad_default_type(config: Config, property_class: PropertyClass) -> None: data = oai.Schema(default=123, enum=["A"]) schemas = Schemas() - err, new_schemas = EnumProperty.build( + err, new_schemas = property_class.build( data=data, name="Existing", required=True, schemas=schemas, parent_name="parent", config=config ) @@ -47,22 +59,22 @@ def test_bad_default_type(config: Config) -> None: assert isinstance(err, PropertyError) -def test_mixed_types(config: Config) -> None: +def test_mixed_types(config: Config, property_class: PropertyClass) -> None: data = oai.Schema(enum=["A", 1]) schemas = Schemas() - err, _ = EnumProperty.build( + err, _ = property_class.build( data=data, name="Enum", required=True, schemas=schemas, parent_name="parent", config=config ) assert isinstance(err, PropertyError) -def test_unsupported_type(config: Config) -> None: +def test_unsupported_type(config: Config, property_class: PropertyClass) -> None: data = oai.Schema(enum=[1.4, 1.5]) schemas = Schemas() - err, _ = EnumProperty.build( + err, _ = property_class.build( data=data, name="Enum", required=True, schemas=schemas, parent_name="parent", config=config ) diff --git a/tests/test_parser/test_properties/test_init.py b/tests/test_parser/test_properties/test_init.py index 5bfd8bc41..f56bf065d 100644 --- a/tests/test_parser/test_properties/test_init.py +++ b/tests/test_parser/test_properties/test_init.py @@ -13,6 +13,7 @@ UnionProperty, ) from openapi_python_client.parser.properties.protocol import ModelProperty, Value +from openapi_python_client.parser.properties.schemas import Class from openapi_python_client.schema import DataType from openapi_python_client.utils import ClassName, PythonIdentifier @@ -352,7 +353,7 @@ def test_values_from_list(self): data = ["abc", "123", "a23", "1bc", 4, -3, "a Thing WIth spaces", ""] - result = EnumProperty.values_from_list(data) + result = EnumProperty.values_from_list(data, Class("ClassName", "module_name")) assert result == { "ABC": "abc", @@ -371,7 +372,44 @@ def test_values_from_list_duplicate(self): data = ["abc", "123", "a23", "abc"] with pytest.raises(ValueError): - EnumProperty.values_from_list(data) + EnumProperty.values_from_list(data, Class("ClassName", "module_name")) + + +class TestLiteralEnumProperty: + def test_is_base_type(self, literal_enum_property_factory): + assert literal_enum_property_factory().is_base_type is True + + @pytest.mark.parametrize( + "required, expected", + ( + (False, "Union[Unset, {}]"), + (True, "{}"), + ), + ) + def test_get_type_string(self, mocker, literal_enum_property_factory, required, expected): + fake_class = mocker.MagicMock() + fake_class.name = "MyTestEnum" + + p = literal_enum_property_factory(class_info=fake_class, required=required) + + assert p.get_type_string() == expected.format(fake_class.name) + assert p.get_type_string(no_optional=True) == fake_class.name + assert p.get_type_string(json=True) == expected.format("str") + + def test_get_imports(self, mocker, literal_enum_property_factory): + fake_class = mocker.MagicMock(module_name="my_test_enum") + fake_class.name = "MyTestEnum" + prefix = "..." + + literal_enum_property = literal_enum_property_factory(class_info=fake_class, required=False) + + assert literal_enum_property.get_imports(prefix=prefix) == { + "from typing import cast", + f"from {prefix}models.{fake_class.module_name} import {fake_class.name}", + f"from {prefix}models.{fake_class.module_name} import check_my_test_enum", + "from typing import Union", # Makes sure unset is handled via base class + "from ...types import UNSET, Unset", + } class TestPropertyFromData: diff --git a/tests/test_parser/test_properties/test_merge_properties.py b/tests/test_parser/test_properties/test_merge_properties.py index 12ddb79fa..819f9ec26 100644 --- a/tests/test_parser/test_properties/test_merge_properties.py +++ b/tests/test_parser/test_properties/test_merge_properties.py @@ -1,5 +1,6 @@ from itertools import permutations +import pytest from attr import evolve from openapi_python_client.parser.errors import PropertyError @@ -104,17 +105,31 @@ def test_merge_with_any( assert merge_properties(prop, any_prop) == prop -def test_merge_enums(enum_property_factory, config): - enum_with_fewer_values = enum_property_factory( - description="desc1", - values={"A": "A", "B": "B"}, - value_type=str, - ) - enum_with_more_values = enum_property_factory( - example="example2", - values={"A": "A", "B": "B", "C": "C"}, - value_type=str, - ) +@pytest.mark.parametrize("literal_enums", (False, True)) +def test_merge_enums(literal_enums, enum_property_factory, literal_enum_property_factory, config): + if literal_enums: + enum_with_fewer_values = literal_enum_property_factory( + description="desc1", + values={"A", "B"}, + value_type=str, + ) + enum_with_more_values = literal_enum_property_factory( + example="example2", + values={"A", "B", "C"}, + value_type=str, + ) + else: + enum_with_fewer_values = enum_property_factory( + description="desc1", + values={"A": "A", "B": "B"}, + value_type=str, + ) + enum_with_more_values = enum_property_factory( + example="example2", + values={"A": "A", "B": "B", "C": "C"}, + value_type=str, + ) + # Setting class_info separately because it doesn't get initialized by the constructor - we want # to make sure the right enum class name gets used in the merged property enum_with_fewer_values.class_info = Class.from_string(string="FewerValuesEnum", config=config) @@ -132,36 +147,60 @@ def test_merge_enums(enum_property_factory, config): ) -def test_merge_string_with_string_enum(string_property_factory, enum_property_factory): - values = {"A": "A", "B": "B"} +@pytest.mark.parametrize("literal_enums", (False, True)) +def test_merge_string_with_string_enum( + literal_enums, string_property_factory, enum_property_factory, literal_enum_property_factory +): string_prop = string_property_factory(default=Value("A", "A"), description="desc1", example="example1") - enum_prop = enum_property_factory( - default=Value("test.B", "B"), - description="desc2", - example="example2", - values=values, - value_type=str, + enum_prop = ( + literal_enum_property_factory( + default=Value("'B'", "B"), + description="desc2", + example="example2", + values={"A", "B"}, + value_type=str, + ) + if literal_enums + else enum_property_factory( + default=Value("test.B", "B"), + description="desc2", + example="example2", + values={"A": "A", "B": "B"}, + value_type=str, + ) ) assert merge_properties(string_prop, enum_prop) == evolve(enum_prop, required=True) assert merge_properties(enum_prop, string_prop) == evolve( enum_prop, required=True, - default=Value("test.A", "A"), + default=Value("'A'" if literal_enums else "test.A", "A"), description=string_prop.description, example=string_prop.example, ) -def test_merge_int_with_int_enum(int_property_factory, enum_property_factory): - values = {"VALUE_1": 1, "VALUE_2": 2} +@pytest.mark.parametrize("literal_enums", (False, True)) +def test_merge_int_with_int_enum( + literal_enums, int_property_factory, enum_property_factory, literal_enum_property_factory +): int_prop = int_property_factory(default=Value("1", 1), description="desc1", example="example1") - enum_prop = enum_property_factory( - default=Value("test.VALUE_1", 1), - description="desc2", - example="example2", - values=values, - value_type=int, + enum_prop = ( + literal_enum_property_factory( + default=Value("1", 1), + description="desc2", + example="example2", + values={1, 2}, + value_type=int, + ) + if literal_enums + else enum_property_factory( + default=Value("test.VALUE_1", 1), + description="desc2", + example="example2", + values={"VALUE_1": 1, "VALUE_2": 2}, + value_type=int, + ) ) assert merge_properties(int_prop, enum_prop) == evolve(enum_prop, required=True) @@ -170,12 +209,15 @@ def test_merge_int_with_int_enum(int_property_factory, enum_property_factory): ) +@pytest.mark.parametrize("literal_enums", (False, True)) def test_merge_with_incompatible_enum( + literal_enums, boolean_property_factory, int_property_factory, float_property_factory, string_property_factory, enum_property_factory, + literal_enum_property_factory, model_property_factory, ): props = [ @@ -184,9 +226,19 @@ def test_merge_with_incompatible_enum( float_property_factory(), string_property_factory(), model_property_factory(), + enum_property_factory(values={"INCOMPATIBLE": "INCOMPATIBLE"}), + literal_enum_property_factory(values={"INCOMPATIBLE"}), ] - string_enum_prop = enum_property_factory(value_type=str) - int_enum_prop = enum_property_factory(value_type=int) + string_enum_prop = ( + literal_enum_property_factory(value_type=str, values={"A"}) + if literal_enums + else enum_property_factory(value_type=str, values={"A": "A"}) + ) + int_enum_prop = ( + literal_enum_property_factory(value_type=int, values={1}) + if literal_enums + else enum_property_factory(value_type=int, values={"VALUE_1": 1}) + ) for prop in props: if not isinstance(prop, StringProperty): assert isinstance(merge_properties(prop, string_enum_prop), PropertyError)