Skip to content

Commit

Permalink
Accept abstract namespace paths for unix domain sockets
Browse files Browse the repository at this point in the history
Accept paths starting with a null byte in create_unix_listener and
connect_unix_socket to allow creating abstract unix sockets. Fixes agronholm#781
  • Loading branch information
tapetersen committed Sep 5, 2024
1 parent ee8165b commit 16683b6
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 36 deletions.
30 changes: 22 additions & 8 deletions src/anyio/_core/_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,16 +683,30 @@ async def setup_unix_local_socket(
path_str: str | bytes | None
if path is not None:
path_str = os.fspath(path)
is_abstract = (
path_str.startswith(b"\0")
if isinstance(path_str, bytes)
else path_str.startswith("\0")
)

# Copied from pathlib...
try:
stat_result = os.stat(path)
except OSError as e:
if e.errno not in (errno.ENOENT, errno.ENOTDIR, errno.EBADF, errno.ELOOP):
raise
if is_abstract:
# Unix abstract namespace socket. No file backing so skip stat call
pass
else:
if stat.S_ISSOCK(stat_result.st_mode):
os.unlink(path)
# Copied from pathlib...
try:
stat_result = os.stat(path)
except OSError as e:
if e.errno not in (
errno.ENOENT,
errno.ENOTDIR,
errno.EBADF,
errno.ELOOP,
):
raise
else:
if stat.S_ISSOCK(stat_result.st_mode):
os.unlink(path)
else:
path_str = None

Expand Down
115 changes: 87 additions & 28 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@
has_ipv6 = True

skip_ipv6_mark = pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available")
skip_unix_abstract_mark = pytest.mark.skipif(
not sys.platform.startswith("linux"),
reason="Abstract namespace sockets is a Linux only feature",
)


@pytest.fixture
Expand Down Expand Up @@ -735,12 +739,20 @@ async def test_bind_link_local(self) -> None:
sys.platform == "win32", reason="UNIX sockets are not available on Windows"
)
class TestUNIXStream:
@pytest.fixture
def socket_path(self) -> Generator[Path, None, None]:
@pytest.fixture(
params=[
"path",
pytest.param("abstract", marks=[skip_unix_abstract_mark]),
]
)
def socket_path(self, request: SubRequest) -> Generator[Path, None, None]:
# Use stdlib tempdir generation
# Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path
with tempfile.TemporaryDirectory() as path:
yield Path(path) / "socket"
if request.param == "path":
yield Path(path) / "socket"
else:
yield Path(f"\0{path}") / "socket"

@pytest.fixture(params=[False, True], ids=["str", "path"])
def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str:
Expand All @@ -764,7 +776,15 @@ async def test_extra_attributes(
assert (
stream.extra(SocketAttribute.local_address) == raw_socket.getsockname()
)
assert stream.extra(SocketAttribute.remote_address) == str(socket_path)
remote_addr = stream.extra(SocketAttribute.remote_address)
if isinstance(remote_addr, str):
assert stream.extra(SocketAttribute.remote_address) == str(socket_path)
else:
assert isinstance(remote_addr, bytes)
assert stream.extra(SocketAttribute.remote_address) == bytes(
socket_path
)

pytest.raises(
TypedAttributeLookupError, stream.extra, SocketAttribute.local_port
)
Expand Down Expand Up @@ -1031,8 +1051,12 @@ async def test_send_after_close(
await stream.send(b"foo")

async def test_cannot_connect(self, socket_path: Path) -> None:
with pytest.raises(FileNotFoundError):
await connect_unix(socket_path)
if str(socket_path).startswith("\0"):
with pytest.raises(ConnectionRefusedError):
await connect_unix(socket_path)
else:
with pytest.raises(FileNotFoundError):
await connect_unix(socket_path)

async def test_connecting_using_bytes(
self, server_sock: socket.socket, socket_path: Path
Expand All @@ -1057,12 +1081,20 @@ async def test_connecting_with_non_utf8(self, socket_path: Path) -> None:
sys.platform == "win32", reason="UNIX sockets are not available on Windows"
)
class TestUNIXListener:
@pytest.fixture
def socket_path(self) -> Generator[Path, None, None]:
@pytest.fixture(
params=[
"path",
pytest.param("abstract", marks=[skip_unix_abstract_mark]),
]
)
def socket_path(self, request: SubRequest) -> Generator[Path, None, None]:
# Use stdlib tempdir generation
# Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path
with tempfile.TemporaryDirectory() as path:
yield Path(path) / "socket"
if request.param == "path":
yield Path(path) / "socket"
else:
yield Path(f"\0{path}") / "socket"

@pytest.fixture(params=[False, True], ids=["str", "path"])
def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str:
Expand Down Expand Up @@ -1461,12 +1493,20 @@ async def test_send_after_close(self, family: AnyIPAddressFamily) -> None:
sys.platform == "win32", reason="UNIX sockets are not available on Windows"
)
class TestUNIXDatagramSocket:
@pytest.fixture
def socket_path(self) -> Generator[Path, None, None]:
@pytest.fixture(
params=[
"path",
pytest.param("abstract", marks=[skip_unix_abstract_mark]),
]
)
def socket_path(self, request: SubRequest) -> Generator[Path, None, None]:
# Use stdlib tempdir generation
# Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path
with tempfile.TemporaryDirectory() as path:
yield Path(path) / "socket"
if request.param == "path":
yield Path(path) / "socket"
else:
yield Path(f"\0{path}") / "socket"

@pytest.fixture(params=[False, True], ids=["str", "path"])
def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str:
Expand Down Expand Up @@ -1506,12 +1546,18 @@ async def test_send_receive(self, socket_path_or_str: Path | str) -> None:
await sock.sendto(b"blah", path)
request, addr = await sock.receive()
assert request == b"blah"
assert addr == path
if isinstance(addr, bytes):
assert addr == path.encode()
else:
assert addr == path

await sock.sendto(b"halb", path)
response, addr = await sock.receive()
assert response == b"halb"
assert addr == path
if isinstance(addr, bytes):
assert addr == path.encode()
else:
assert addr == path

async def test_iterate(self, peer_socket_path: Path, socket_path: Path) -> None:
async def serve() -> None:
Expand Down Expand Up @@ -1589,18 +1635,31 @@ async def test_local_path_invalid_ascii(self, socket_path: Path) -> None:
sys.platform == "win32", reason="UNIX sockets are not available on Windows"
)
class TestConnectedUNIXDatagramSocket:
@pytest.fixture
def socket_path(self) -> Generator[Path, None, None]:
@pytest.fixture(
params=[
"path",
pytest.param("abstract", marks=[skip_unix_abstract_mark]),
]
)
def socket_path(self, request: SubRequest) -> Generator[Path, None, None]:
# Use stdlib tempdir generation
# Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path
with tempfile.TemporaryDirectory() as path:
yield Path(path) / "socket"
if request.param == "path":
yield Path(path) / "socket"
else:
yield Path(f"\0{path}") / "socket"

@pytest.fixture(params=[False, True], ids=["str", "path"])
def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str:
return socket_path if request.param else str(socket_path)

@pytest.fixture
@pytest.fixture(
params=[
pytest.param("path", id="path-peer"),
pytest.param("abstract", marks=[skip_unix_abstract_mark], id="abstract-peer"),
]
)
def peer_socket_path(self) -> Generator[Path, None, None]:
# Use stdlib tempdir generation
# Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path
Expand Down Expand Up @@ -1634,10 +1693,8 @@ async def test_extra_attributes(
raw_socket = unix_dg.extra(SocketAttribute.raw_socket)
assert raw_socket is not None
assert unix_dg.extra(SocketAttribute.family) == AddressFamily.AF_UNIX
assert unix_dg.extra(SocketAttribute.local_address) == str(socket_path)
assert unix_dg.extra(SocketAttribute.remote_address) == str(
peer_socket_path
)
assert os.fsencode(unix_dg.extra(SocketAttribute.local_address)) == os.fsencode(socket_path)
assert os.fsencode(unix_dg.extra(SocketAttribute.remote_address)) == os.fsencode(peer_socket_path)
pytest.raises(
TypedAttributeLookupError, unix_dg.extra, SocketAttribute.local_port
)
Expand All @@ -1657,11 +1714,11 @@ async def test_send_receive(
peer_socket_path_or_str,
local_path=socket_path_or_str,
) as unix_dg2:
socket_path = str(socket_path_or_str)
socket_path = os.fsdecode(socket_path_or_str)

await unix_dg2.send(b"blah")
request = await unix_dg1.receive()
assert request == (b"blah", socket_path)
data, remote_addr = await unix_dg1.receive()
assert (data, os.fsdecode(remote_addr)) == (b"blah", socket_path)

await unix_dg1.sendto(b"halb", socket_path)
response = await unix_dg2.receive()
Expand All @@ -1682,13 +1739,15 @@ async def serve() -> None:
async with await create_connected_unix_datagram_socket(
peer_socket_path, local_path=socket_path
) as unix_dg2:
path = str(socket_path)
path = os.fsdecode(socket_path)
async with create_task_group() as tg:
tg.start_soon(serve)
await unix_dg1.sendto(b"FOOBAR", path)
assert await unix_dg1.receive() == (b"RABOOF", path)
data, addr = await unix_dg1.receive()
assert (data, os.fsdecode(addr)) == (b"RABOOF", path)
await unix_dg1.sendto(b"123456", path)
assert await unix_dg1.receive() == (b"654321", path)
data, addr = await unix_dg1.receive()
assert (data, os.fsdecode(addr)) == (b"654321", path)
tg.cancel_scope.cancel()

async def test_concurrent_receive(
Expand Down

0 comments on commit 16683b6

Please sign in to comment.