Skip to content

Commit

Permalink
Modify unix socket tests to short paths
Browse files Browse the repository at this point in the history
  • Loading branch information
cjntaylor committed Dec 23, 2023
1 parent befba29 commit 27e38b1
Showing 1 changed file with 38 additions and 29 deletions.
67 changes: 38 additions & 29 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import io
import os
import platform
import random
import socket
import string
import sys
import threading
import time
Expand All @@ -14,7 +16,7 @@
from socket import AddressFamily
from ssl import SSLContext, SSLError
from threading import Thread
from typing import Any, Iterable, Iterator, NoReturn, TypeVar, cast
from typing import Any, Generator, Iterable, Iterator, NoReturn, TypeVar, cast

import psutil
import pytest
Expand Down Expand Up @@ -702,14 +704,41 @@ async def test_bind_link_local(self) -> None:
pass


class TestUNIX:
# MacOS requires unix socket paths under 107 bytes, but the default mktemp
# implementation generates paths that are too long under /private/var/folders
#
# Linux has a similar limitation of 108 bytes, but doesn't normally hit this
# case because the mktemp implementation creates folders in /tmp. However, there
# are cases where mount points could cause /tmp to resolve to a path over the
# limit
#
# Reference: https://github.com/python/cpython/issues/93852
@classmethod
def _short_socket_path(cls) -> Path:
name = "".join(random.choice(string.ascii_letters) for _ in range(10))
socket_path = Path("/tmp") / name
return socket_path

# Harcode an implementation that should work in both cases by locating the
# socket directly under /tmp, and removing it when done
@pytest.fixture
def socket_path(self) -> Generator[Path, None, None]:
socket_path = TestUNIX._short_socket_path()
yield socket_path
socket_path.unlink(missing_ok=True)

@pytest.fixture
def peer_socket_path(self) -> Generator[Path, None, None]:
socket_path = TestUNIX._short_socket_path()
yield socket_path
socket_path.unlink(missing_ok=True)


@pytest.mark.skipif(
sys.platform == "win32", reason="UNIX sockets are not available on Windows"
)
class TestUNIXStream:
@pytest.fixture
def socket_path(self, tmp_path_factory: TempPathFactory) -> Path:
return tmp_path_factory.mktemp("unix").joinpath("socket")

class TestUNIXStream(TestUNIX):
@pytest.fixture(params=[False, True], ids=["str", "path"])
def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str:
return socket_path if request.param else str(socket_path)
Expand Down Expand Up @@ -1024,11 +1053,7 @@ async def test_connecting_with_non_utf8(self, socket_path: Path) -> None:
@pytest.mark.skipif(
sys.platform == "win32", reason="UNIX sockets are not available on Windows"
)
class TestUNIXListener:
@pytest.fixture
def socket_path(self, tmp_path_factory: TempPathFactory) -> Path:
return tmp_path_factory.mktemp("unix").joinpath("socket")

class TestUNIXListener(TestUNIX):
@pytest.fixture(params=[False, True], ids=["str", "path"])
def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str:
return socket_path if request.param else str(socket_path)
Expand Down Expand Up @@ -1421,19 +1446,11 @@ async def test_send_after_close(self, family: AnyIPAddressFamily) -> None:
@pytest.mark.skipif(
sys.platform == "win32", reason="UNIX sockets are not available on Windows"
)
class TestUNIXDatagramSocket:
@pytest.fixture
def socket_path(self, tmp_path_factory: TempPathFactory) -> Path:
return tmp_path_factory.mktemp("unix").joinpath("socket")

class TestUNIXDatagramSocket(TestUNIX):
@pytest.fixture(params=[False, True], ids=["str", "path"])
def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str:
return socket_path if request.param else str(socket_path)

@pytest.fixture
def peer_socket_path(self, tmp_path_factory: TempPathFactory) -> Path:
return tmp_path_factory.mktemp("unix").joinpath("peer_socket")

async def test_extra_attributes(self, socket_path: Path) -> None:
async with await create_unix_datagram_socket(local_path=socket_path) as unix_dg:
raw_socket = unix_dg.extra(SocketAttribute.raw_socket)
Expand Down Expand Up @@ -1543,19 +1560,11 @@ async def test_local_path_invalid_ascii(self, socket_path: Path) -> None:
@pytest.mark.skipif(
sys.platform == "win32", reason="UNIX sockets are not available on Windows"
)
class TestConnectedUNIXDatagramSocket:
@pytest.fixture
def socket_path(self, tmp_path_factory: TempPathFactory) -> Path:
return tmp_path_factory.mktemp("unix").joinpath("socket")

class TestConnectedUNIXDatagramSocket(TestUNIX):
@pytest.fixture(params=[False, True], ids=["str", "path"])
def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str:
return socket_path if request.param else str(socket_path)

@pytest.fixture
def peer_socket_path(self, tmp_path_factory: TempPathFactory) -> Path:
return tmp_path_factory.mktemp("unix").joinpath("peer_socket")

@pytest.fixture(params=[False, True], ids=["peer_str", "peer_path"])
def peer_socket_path_or_str(
self, request: SubRequest, peer_socket_path: Path
Expand Down

0 comments on commit 27e38b1

Please sign in to comment.