Skip to content

Commit

Permalink
Don't share API objects between event loops (#353)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobtomlinson authored Apr 19, 2024
1 parent 557c70b commit fed84dd
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 9 deletions.
12 changes: 9 additions & 3 deletions kr8s/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD 3-Clause License
from __future__ import annotations

import asyncio
import contextlib
import json
import ssl
Expand Down Expand Up @@ -56,9 +57,14 @@ def __init__(self, **kwargs) -> None:
context=kwargs.get("context"),
)
thread_id = threading.get_ident()
if thread_id not in Api._instances:
Api._instances[thread_id] = weakref.WeakValueDictionary()
Api._instances[thread_id][frozenset(kwargs.items())] = self
try:
loop_id = id(asyncio.get_running_loop())
except RuntimeError:
loop_id = 0
thread_loop_id = f"{thread_id}.{loop_id}"
if thread_loop_id not in Api._instances:
Api._instances[thread_loop_id] = weakref.WeakValueDictionary()
Api._instances[thread_loop_id][frozenset(kwargs.items())] = self

def __await__(self):
async def f():
Expand Down
18 changes: 12 additions & 6 deletions kr8s/asyncio/_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2024, Kr8s Developers (See LICENSE for list)
# SPDX-License-Identifier: BSD 3-Clause License
import asyncio
import threading

from kr8s._api import Api as _AsyncApi
Expand Down Expand Up @@ -53,18 +54,23 @@ async def api(
async def _f(**kwargs):
key = frozenset(kwargs.items())
thread_id = threading.get_ident()
try:
loop_id = id(asyncio.get_running_loop())
except RuntimeError:
loop_id = 0
thread_loop_id = f"{thread_id}.{loop_id}"
if (
_cls._instances
and thread_id in _cls._instances
and key in _cls._instances[thread_id]
and thread_loop_id in _cls._instances
and key in _cls._instances[thread_loop_id]
):
return await _cls._instances[thread_id][key]
return await _cls._instances[thread_loop_id][key]
if (
all(k is None for k in kwargs.values())
and thread_id in _cls._instances
and list(_cls._instances[thread_id].values())
and thread_loop_id in _cls._instances
and list(_cls._instances[thread_loop_id].values())
):
return await list(_cls._instances[thread_id].values())[0]
return await list(_cls._instances[thread_loop_id].values())[0]
return await _cls(**kwargs, bypass_factory=True)

return await _f(
Expand Down
11 changes: 11 additions & 0 deletions kr8s/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ async def create_api(q):
assert type(k1) is type(k2)


def test_api_factory_multi_event_loop():
assert len(kr8s.Api._instances) == 0

async def create_api():
return await kr8s.asyncio.api()

k1 = anyio.run(create_api)
k2 = anyio.run(create_api)
assert k1 is not k2


async def test_api_factory_with_kubeconfig(k8s_cluster, serviceaccount):
k1 = await kr8s.asyncio.api(kubeconfig=k8s_cluster.kubeconfig_path)
k2 = await kr8s.asyncio.api(serviceaccount=serviceaccount)
Expand Down

0 comments on commit fed84dd

Please sign in to comment.