diff --git a/src/udp.rs b/src/udp.rs index 240658c..266788a 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -11,22 +11,25 @@ pub struct UdpSocket { } impl UdpSocket { - pub fn new( + pub fn new(socket: Socket) -> Self { + UdpSocket { socket } + } + + pub fn open( + &mut self, bus: &mut SpiBus, - socket: Socket, local_port: u16, - ) -> Result { - socket.command(bus, socketn::Command::Close)?; - socket.reset_interrupt(bus, socketn::Interrupt::All)?; - socket.set_source_port(bus, local_port)?; - socket.set_mode(bus, socketn::Protocol::Udp)?; - socket.set_interrupt_mask( + ) -> Result<(), SpiBus::Error> { + self.socket.command(bus, socketn::Command::Close)?; + self.socket.reset_interrupt(bus, socketn::Interrupt::All)?; + self.socket.set_source_port(bus, local_port)?; + self.socket.set_mode(bus, socketn::Protocol::Udp)?; + self.socket.set_interrupt_mask( bus, socketn::Interrupt::SendOk as u8 & socketn::Interrupt::Timeout as u8, )?; - socket.command(bus, socketn::Command::Open)?; - - Ok(UdpSocket { socket }) + self.socket.command(bus, socketn::Command::Open)?; + Ok(()) } fn set_destination( @@ -182,17 +185,23 @@ where { type UdpSocket = UdpSocket; type Error = UdpSocketError; - fn connect(&self, remote: SocketAddr) -> Result { + + fn socket(&self) -> Result { + let mut device = self.device.borrow_mut(); + if let Some(socket) = device.take_socket() { + Ok(UdpSocket::new(socket)) + } else { + Err(Self::Error::NoMoreSockets) + } + } + + fn connect(&self, socket: &mut Self::UdpSocket, remote: SocketAddr) -> Result<(), Self::Error> { let mut device = self.device.borrow_mut(); if let SocketAddr::V4(remote) = remote { - if let Some(socket) = device.take_socket() { - // TODO find a random port - let mut udp_socket = UdpSocket::new(&mut device.bus, socket, 4000)?; - udp_socket.set_destination(&mut device.bus, remote)?; - Ok(udp_socket) - } else { - Err(Self::Error::NoMoreSockets) - } + // TODO find a random port + socket.open(&mut device.bus, 4000)?; + socket.set_destination(&mut device.bus, remote)?; + Ok(()) } else { Err(Self::Error::UnsupportedAddress) } @@ -221,13 +230,10 @@ where SpiBus: ActiveBus, HostImpl: Host, { - fn bind(&self, local_port: u16) -> Result { + fn bind(&self, socket: &mut Self::UdpSocket, local_port: u16) -> Result<(), Self::Error> { let mut device = self.device.borrow_mut(); - if let Some(socket) = device.take_socket() { - Ok(UdpSocket::new(&mut device.bus, socket, local_port)?) - } else { - Err(Self::Error::NoMoreSockets) - } + socket.open(&mut device.bus, local_port)?; + Ok(()) } fn send_to( &self,