Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions adafruit_esp32spi/adafruit_esp32spi.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,8 @@ def socket_connected(self, socket_num):
return self.socket_status(socket_num) == SOCKET_ESTABLISHED

def socket_write(self, socket_num, buffer, conn_mode=TCP_MODE):
"""Write the bytearray buffer to a socket"""
"""Write the bytearray buffer to a socket.
Returns the number of bytes written"""
if self._debug:
print("Writing:", buffer)
self._socknum_ll[0][0] = socket_num
Expand Down Expand Up @@ -853,7 +854,7 @@ def socket_write(self, socket_num, buffer, conn_mode=TCP_MODE):
resp = self._send_command_get_response(_SEND_UDP_DATA_CMD, self._socknum_ll)
if resp[0][0] != 1:
raise ConnectionError("Failed to send UDP data")
return
return sent

if sent != len(buffer):
self.socket_close(socket_num)
Expand All @@ -863,6 +864,8 @@ def socket_write(self, socket_num, buffer, conn_mode=TCP_MODE):
if resp[0][0] != 1:
raise ConnectionError("Failed to verify data sent")

return sent

def socket_available(self, socket_num):
"""Determine how many bytes are waiting to be read on the socket"""
self._socknum_ll[0][0] = socket_num
Expand Down
100 changes: 85 additions & 15 deletions adafruit_esp32spi/adafruit_esp32spi_socketpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

try:
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Tuple

if TYPE_CHECKING:
from esp32spi.adafruit_esp32spi import ESP_SPIcontrol # noqa: UP007
Expand All @@ -36,11 +36,15 @@
class SocketPool:
"""ESP32SPI SocketPool library"""

SOCK_STREAM = const(0)
SOCK_DGRAM = const(1)
# socketpool constants
SOCK_STREAM = const(1)
SOCK_DGRAM = const(2)
AF_INET = const(2)
NO_SOCKET_AVAIL = const(255)
SOL_SOCKET = const(0xFFF)
SO_REUSEADDR = const(0x0004)

# implementation specific constants
NO_SOCKET_AVAIL = const(255)
MAX_PACKET = const(4000)

def __new__(cls, iface: ESP_SPIcontrol):
Expand Down Expand Up @@ -73,7 +77,13 @@ def socket(

class Socket:
"""A simplified implementation of the Python 'socket' class, for connecting
through an interface to a remote device"""
through an interface to a remote device. Has properties specific to the
implementation.

:param SocketPool socket_pool: The underlying socket pool.
:param Optional[int] socknum: Allows wrapping a Socket instance around a socket
number returned by the nina firmware. Used internally.
"""

def __init__(
self,
Expand All @@ -82,14 +92,16 @@ def __init__(
type: int = SocketPool.SOCK_STREAM,
proto: int = 0,
fileno: Optional[int] = None, # noqa: UP007
socknum: Optional[int] = None, # noqa: UP007
):
if family != SocketPool.AF_INET:
raise ValueError("Only AF_INET family supported")
self._socket_pool = socket_pool
self._interface = self._socket_pool._interface
self._type = type
self._buffer = b""
self._socknum = self._interface.get_socket()
self._socknum = socknum if socknum is not None else self._interface.get_socket()
self._bound = ()
self.settimeout(0)

def __enter__(self):
Expand Down Expand Up @@ -121,13 +133,14 @@ def send(self, data):
conntype = self._interface.UDP_MODE
else:
conntype = self._interface.TCP_MODE
self._interface.socket_write(self._socknum, data, conn_mode=conntype)
sent = self._interface.socket_write(self._socknum, data, conn_mode=conntype)
gc.collect()
return sent

def sendto(self, data, address):
"""Connect and send some data to the socket."""
self.connect(address)
self.send(data)
return self.send(data)

def recv(self, bufsize: int) -> bytes:
"""Reads some bytes from the connected remote address. Will only return
Expand All @@ -150,12 +163,12 @@ def recv_into(self, buffer, nbytes: int = 0):
if not 0 <= nbytes <= len(buffer):
raise ValueError("nbytes must be 0 to len(buffer)")

last_read_time = time.monotonic()
last_read_time = time.monotonic_ns()
num_to_read = len(buffer) if nbytes == 0 else nbytes
num_read = 0
while num_to_read > 0:
# we might have read socket data into the self._buffer with:
# esp32spi_wsgiserver: socket_readline
# adafruit_wsgi.esp32spi_wsgiserver: socket_readline
if len(self._buffer) > 0:
bytes_to_read = min(num_to_read, len(self._buffer))
buffer[num_read : num_read + bytes_to_read] = self._buffer[:bytes_to_read]
Expand All @@ -167,7 +180,7 @@ def recv_into(self, buffer, nbytes: int = 0):

num_avail = self._available()
if num_avail > 0:
last_read_time = time.monotonic()
last_read_time = time.monotonic_ns()
bytes_read = self._interface.socket_read(self._socknum, min(num_to_read, num_avail))
buffer[num_read : num_read + len(bytes_read)] = bytes_read
num_read += len(bytes_read)
Expand All @@ -176,15 +189,27 @@ def recv_into(self, buffer, nbytes: int = 0):
# We got a message, but there are no more bytes to read, so we can stop.
break
# No bytes yet, or more bytes requested.
if self._timeout > 0 and time.monotonic() - last_read_time > self._timeout:

if self._timeout == 0: # if in non-blocking mode, stop now.
break

# Time out if there's a positive timeout set.
delta = (time.monotonic_ns() - last_read_time) // 1_000_000
if self._timeout > 0 and delta > self._timeout:
raise OSError(errno.ETIMEDOUT)
return num_read

def settimeout(self, value):
"""Set the read timeout for sockets.
If value is 0 socket reads will block until a message is available.
"""Set the read timeout for sockets in seconds.
``0`` means non-blocking. ``None`` means block indefinitely.
"""
self._timeout = value
if value is None:
self._timeout = -1
else:
if value < 0:
raise ValueError("Timeout cannot be a negative number")
# internally in milliseconds as an int
self._timeout = int(value * 1000)

def _available(self):
"""Returns how many bytes of data are available to be read (up to the MAX_PACKET length)"""
Expand Down Expand Up @@ -217,3 +242,48 @@ def _connected(self):
def close(self):
"""Close the socket, after reading whatever remains"""
self._interface.socket_close(self._socknum)

####################################################################
# WORK IN PROGRESS
####################################################################

def accept(self):
"""Accept a connection on a listening socket of type SOCK_STREAM,
creating a new socket of type SOCK_STREAM. Returns a tuple of
(new_socket, remote_address)
"""
client_sock_num = self._interface.socket_available(self._socknum)
if client_sock_num != SocketPool.NO_SOCKET_AVAIL:
sock = Socket(self._socket_pool, socknum=client_sock_num)
# get remote information (addr and port)
remote = self._interface.get_remote_data(client_sock_num)
ip_address = "{}.{}.{}.{}".format(*remote["ip_addr"])
port = remote["port"]
client_address = (ip_address, port)
return sock, client_address
raise OSError(errno.ECONNRESET)

def bind(self, address: tuple[str, int]):
"""Bind a socket to an address"""
self._bound = address

def listen(self, backlog: int): # pylint: disable=unused-argument
"""Set socket to listen for incoming connections.
:param int backlog: length of backlog queue for waiting connections (ignored)
"""
if not self._bound:
self._bound = (self._interface.ip_address, 80)
port = self._bound[1]
self._interface.start_server(port, self._socknum)

def setblocking(self, flag: bool):
"""Set the blocking behaviour of this socket.
:param bool flag: False means non-blocking, True means block indefinitely.
"""
if flag:
self.settimeout(None)
else:
self.settimeout(0)

def setsockopt(self, *opts, **kwopts):
"""Dummy call for compatibility."""