Skip to content

Commit

Permalink
fix: avoid using default mutable parameters
Browse files Browse the repository at this point in the history
Using mutable default arguments is a common Python problem, see e.g.
https://docs.python-guide.org/writing/gotchas/@mutable-default-arguments

In this specific case the default argument even tries to setup some
infrastructure settings at import time, which can potentially fail.
  • Loading branch information
marlamb committed Jul 9, 2024
1 parent a19e3b1 commit 9fa773a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
14 changes: 10 additions & 4 deletions src/msgraph_core/base_graph_request_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
SerializationWriterFactoryRegistry,
)
from kiota_http.httpx_request_adapter import HttpxRequestAdapter
from typing import Optional

from .graph_client_factory import GraphClientFactory

Expand All @@ -16,11 +17,16 @@ class BaseGraphRequestAdapter(HttpxRequestAdapter):
def __init__(
self,
authentication_provider: AuthenticationProvider,
parse_node_factory: ParseNodeFactory = ParseNodeFactoryRegistry(),
serialization_writer_factory:
SerializationWriterFactory = SerializationWriterFactoryRegistry(),
http_client: httpx.AsyncClient = GraphClientFactory.create_with_default_middleware()
parse_node_factory: Optional[ParseNodeFactory] = None,
serialization_writer_factory: Optional[SerializationWriterFactory] = None,
http_client: Optional[httpx.AsyncClient] = None
) -> None:
if parse_node_factory is None:
parse_node_factory = ParseNodeFactoryRegistry()
if serialization_writer_factory is None:
serialization_writer_factory = SerializationWriterFactoryRegistry()
if http_client is None:
http_client = GraphClientFactory.create_with_default_middleware()
super().__init__(
authentication_provider=authentication_provider,
parse_node_factory=parse_node_factory,
Expand Down
8 changes: 6 additions & 2 deletions src/msgraph_core/graph_client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class GraphClientFactory(KiotaClientFactory):
@staticmethod
def create_with_default_middleware(
api_version: APIVersion = APIVersion.v1,
client: httpx.AsyncClient = KiotaClientFactory.get_default_client(),
client: Optional[httpx.AsyncClient] = None,
host: NationalClouds = NationalClouds.Global,
options: Optional[Dict[str, RequestOption]] = None
) -> httpx.AsyncClient:
Expand All @@ -44,6 +44,8 @@ def create_with_default_middleware(
Returns:
httpx.AsyncClient: An instance of the AsyncClient object
"""
if client is None:
client = KiotaClientFactory.get_default_client()
client.base_url = GraphClientFactory._get_base_url(host, api_version) # type: ignore
middleware = KiotaClientFactory.get_default_middleware(options)
telemetry_handler = GraphClientFactory._get_telemetry_handler(options)
Expand All @@ -54,7 +56,7 @@ def create_with_default_middleware(
def create_with_custom_middleware(
middleware: Optional[List[BaseMiddleware]],
api_version: APIVersion = APIVersion.v1,
client: httpx.AsyncClient = KiotaClientFactory.get_default_client(),
client: Optional[httpx.AsyncClient] = None,
host: NationalClouds = NationalClouds.Global,
) -> httpx.AsyncClient:
"""Applies a custom middleware chain to the HTTP Client
Expand All @@ -70,6 +72,8 @@ def create_with_custom_middleware(
host (NationalClouds): The national clound endpoint to be used.
Defaults to NationalClouds.Global.
"""
if client is None:
client = KiotaClientFactory.get_default_client()
client.base_url = GraphClientFactory._get_base_url(host, api_version) # type: ignore
return GraphClientFactory._load_middleware_to_client(client, middleware)

Expand Down

0 comments on commit 9fa773a

Please sign in to comment.