Skip to content

Commit

Permalink
Ensure portfoward's local_port keyword follows kubectl behavior (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobtomlinson authored Oct 14, 2024
1 parent b73f4c7 commit 25d74f4
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 20 deletions.
25 changes: 23 additions & 2 deletions kr8s/_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from kr8s._exec import Exec
from kr8s._types import SpecType, SupportsKeysAndGetItem
from kr8s.asyncio.portforward import PortForward as AsyncPortForward
from kr8s.portforward import LocalPortType
from kr8s.portforward import PortForward as SyncPortForward

JSONPATH_CONDITION_EXPRESSION = r"jsonpath='{(?P<expression>.*?)}'=(?P<condition>.*)"
Expand Down Expand Up @@ -971,13 +972,23 @@ async def logs(
def portforward(
self,
remote_port: int,
local_port: int | None = None,
local_port: LocalPortType = "match",
address: list[str] | str = "127.0.0.1",
) -> SyncPortForward | AsyncPortForward:
"""Port forward a pod.
Returns an instance of :class:`kr8s.portforward.PortForward` for this Pod.
Args:
remote_port:
The port on the Pod to forward to.
local_port:
The local port to listen on. Defaults to ``"match"``, which will match the ``remote_port``.
Set to ``"auto"`` or ``None`` to find an available high port.
Set to an ``int`` to specify a specific port.
address:
List of addresses or address to listen on. Defaults to ["127.0.0.1"], will listen only on 127.0.0.1.
Example:
This can be used as a an async context manager or with explicit start/stop methods.
Expand Down Expand Up @@ -1360,13 +1371,23 @@ async def ready(self) -> bool:
def portforward(
self,
remote_port: int,
local_port: int | None = None,
local_port: LocalPortType = "match",
address: str | list[str] = "127.0.0.1",
) -> SyncPortForward | AsyncPortForward:
"""Port forward a service.
Returns an instance of :class:`kr8s.portforward.PortForward` for this Service.
Args:
remote_port:
The port on the Pod to forward to.
local_port:
The local port to listen on. Defaults to ``"match"``, which will match the ``remote_port``.
Set to ``"auto"`` or ``None`` to find an available high port.
Set to an ``int`` to specify a specific port.
address:
List of addresses or address to listen on. Defaults to ["127.0.0.1"], will listen only on 127.0.0.1.
Example:
This can be used as a an async context manager or with explicit start/stop methods.
Expand Down
33 changes: 22 additions & 11 deletions kr8s/_portforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import socket
import sys
from contextlib import asynccontextmanager, suppress
from typing import TYPE_CHECKING, AsyncGenerator
from typing import TYPE_CHECKING, AsyncGenerator, Literal, Union

import anyio
import httpx_ws
Expand All @@ -17,6 +17,8 @@
from ._exceptions import ConnectionClosedError
from ._types import APIObjectWithPods

LocalPortType = Union[Literal["match", "auto"], int, None]

if TYPE_CHECKING:
from ._objects import APIObject

Expand All @@ -39,14 +41,16 @@ class PortForward:
Currently Port Forwards only work when using ``asyncio`` and not ``trio``.
Args:
``resource`` (Pod or Resource): The Pod or Resource to forward to.
``remote_port`` (int): The port on the Pod to forward to.
``local_port`` (int, optional): The local port to listen on. Defaults to 0, which will choose a random port.
``address``(List[str] | str, optional): List of addresses or address to listen on. Defaults to ["127.0.0.1"],
will listen only on 127.0.0.1
resource:
The Pod or Resource to forward to.
remote_port:
The port on the Pod to forward to.
local_port:
The local port to listen on. Defaults to ``"match"``, which will match the ``remote_port``.
Set to ``"auto"`` or ``None`` to find an available high port.
Set to an ``int`` to specify a specific port.
address:
List of addresses or address to listen on. Defaults to ["127.0.0.1"], will listen only on 127.0.0.1.
Example:
This class can be used as a an async context manager or with explicit start/stop methods.
Expand Down Expand Up @@ -78,7 +82,7 @@ def __init__(
self,
resource: APIObject,
remote_port: int,
local_port: int | None = None,
local_port: LocalPortType = "match",
address: list[str] | str = "127.0.0.1",
) -> None:
with suppress(sniffio.AsyncLibraryNotFoundError):
Expand All @@ -90,7 +94,14 @@ def __init__(
self.server = None
self.servers: list[asyncio.Server] = []
self.remote_port = remote_port
self.local_port = local_port if local_port is not None else 0
if local_port == "match":
self.local_port = remote_port
elif local_port == "auto" or local_port is None:
self.local_port = 0
elif isinstance(local_port, int):
self.local_port = local_port
else:
raise TypeError("local_port must be 'match', 'auto', an int or None")
if isinstance(address, str):
address = [address]
self.address = address
Expand Down
5 changes: 5 additions & 0 deletions kr8s/portforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@
This module provides a class for managing a port forward connection to a Kubernetes Pod or Service.
"""
from __future__ import annotations

import threading
import time

from ._async_utils import sync
from ._portforward import LocalPortType
from ._portforward import PortForward as _PortForward

__all__ = ["PortForward", "LocalPortType"]


@sync
class PortForward(_PortForward):
Expand Down
14 changes: 7 additions & 7 deletions kr8s/tests/test_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ async def test_pod_logs(example_pod_spec):

async def test_pod_port_forward_context_manager(nginx_service):
[nginx_pod, *_] = await nginx_service.ready_pods()
async with nginx_pod.portforward(80) as port:
async with nginx_pod.portforward(80, local_port=None) as port:
async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as session:
resp = await session.get(f"http://localhost:{port}/")
assert resp.status_code == 200
Expand All @@ -683,7 +683,7 @@ def test_pod_port_forward_context_manager_sync(nginx_service):
nginx_service = SyncService.get(
nginx_service.name, namespace=nginx_service.namespace
)
with nginx_service.portforward(80) as port:
with nginx_service.portforward(80, local_port=None) as port:
with httpx.Client(timeout=DEFAULT_TIMEOUT) as session:
resp = session.get(f"http://localhost:{port}/")
assert resp.status_code == 200
Expand All @@ -708,7 +708,7 @@ async def test_pod_port_forward_context_manager_manual(nginx_service):
async def test_pod_port_forward_start_stop(nginx_service):
[nginx_pod, *_] = await nginx_service.ready_pods()
for _ in range(5):
pf = nginx_pod.portforward(80)
pf = nginx_pod.portforward(80, local_port=None)
assert pf._bg_task is None
port = await pf.start()
assert pf._bg_task is not None
Expand All @@ -725,7 +725,7 @@ async def test_pod_port_forward_start_stop(nginx_service):


async def test_service_port_forward_context_manager(nginx_service):
async with nginx_service.portforward(80) as port:
async with nginx_service.portforward(80, local_port=None) as port:
async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as session:
resp = await session.get(f"http://localhost:{port}/")
assert resp.status_code == 200
Expand All @@ -734,7 +734,7 @@ async def test_service_port_forward_context_manager(nginx_service):


async def test_service_port_forward_start_stop(nginx_service):
pf = nginx_service.portforward(80)
pf = nginx_service.portforward(80, local_port=None)
assert pf._bg_task is None
port = await pf.start()
assert pf._bg_task is not None
Expand All @@ -752,9 +752,9 @@ async def test_service_port_forward_start_stop(nginx_service):
async def test_unsupported_port_forward():
pv = await PersistentVolume({"metadata": {"name": "foo"}})
with pytest.raises(AttributeError):
await pv.portforward(80)
await pv.portforward(80, local_port=None)
with pytest.raises(ValueError):
await PortForward(pv, 80).start()
await PortForward(pv, 80, local_port=None).start()


@pytest.mark.skipif(
Expand Down

0 comments on commit 25d74f4

Please sign in to comment.