Skip to content

Commit

Permalink
Rewrite udpclient to not use channels
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesDunne committed Jun 9, 2021
1 parent 4508104 commit 358d911
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 167 deletions.
8 changes: 7 additions & 1 deletion cmd/sni/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package main
import (
"context"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"net/url"
"sni/protos/sni"
"sni/snes"
Expand Down Expand Up @@ -75,6 +78,8 @@ func (s *memoryUnaryService) ReadMemory(ctx context.Context, request *sni.ReadMe

complete := make(chan error)

peer.FromContext(ctx)

addr := request.GetAddress()
size := int32(request.GetSize())
reads := make([]snes.Read, 0, 8)
Expand Down Expand Up @@ -114,7 +119,7 @@ func (s *memoryUnaryService) ReadMemory(ctx context.Context, request *sni.ReadMe
return nil, ctx.Err()
case err = <-complete:
if err != nil {
// TODO: handle terminal error
err = status.Error(codes.Unavailable, err.Error())
return
}
break
Expand All @@ -138,6 +143,7 @@ func (s *memoryUnaryService) AcquireDevice(uri string) (dev snes.Queue, err erro
dev, ok = s.devices[uri]
s.devicesRw.RUnlock()
if ok {
// TODO: detect if device is closed and destroy/recreate it if so
return
}

Expand Down
42 changes: 19 additions & 23 deletions snes/retroarch/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"log"
"net"
"sni/snes"
"sni/udpclient"
"sni/util"
"sni/util/env"
"strings"
Expand All @@ -17,21 +16,21 @@ var logDetector = false

type Driver struct {
detectors []*RAClient
addresses []*net.UDPAddr

devices []snes.DeviceDescriptor
opened *Queue
}

func NewDriver(addresses []*net.UDPAddr) *Driver {
d := &Driver{
addresses: addresses,
detectors: make([]*RAClient, len(addresses)),
}

for i, addr := range addresses {
c := &RAClient{}
c := NewRAClient(addr, fmt.Sprintf("retroarch[%d]", i))
d.detectors[i] = c
udpclient.MakeUDPClient(fmt.Sprintf("retroarch[%d]", i), &c.UDPClient)
c.addr = addr
}

return d
Expand All @@ -55,19 +54,29 @@ func (d *Driver) Open(desc snes.DeviceDescriptor) (q snes.Queue, err error) {
return nil, fmt.Errorf("retroarch: open: descriptor is not of expected type")
}

// find detector with same id:
// create a new device with its own connection:
var addr *net.UDPAddr
addr, err = net.ResolveUDPAddr("udp", descriptor.GetId())
if err != nil {
return
}

var c *RAClient
c = NewRAClient(addr, addr.String())
err = c.Connect(addr)
if err != nil {
return
}

// if we already detected the version, copy it in:
for _, detector := range d.detectors {
if descriptor.GetId() == detector.GetId() {
c = detector
c.version = detector.version
c.useRCR = detector.useRCR
break
}
}

if c == nil {
return nil, fmt.Errorf("retroarch: open: could not find socket by device='%s'\n", descriptor.GetId())
}

// fill back in the addr for the descriptor:
descriptor.addr = c.addr

Expand All @@ -78,23 +87,10 @@ func (d *Driver) Open(desc snes.DeviceDescriptor) (q snes.Queue, err error) {

q = qu

// record that this device is opened:
d.opened = qu
go func() {
<-q.Closed()
d.opened = nil
}()

return
}

func (d *Driver) Detect() (devices []snes.DeviceDescriptor, err error) {
// stop auto-detection if connected already:
if d.opened != nil {
devices = d.devices
return
}

devices = make([]snes.DeviceDescriptor, 0, len(d.detectors))
for i, detector := range d.detectors {
detector.MuteLog(true)
Expand Down
17 changes: 7 additions & 10 deletions snes/retroarch/queue.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package retroarch

import (
"errors"
"fmt"
"sni/snes"
"sni/udpclient"
"sync"
)

Expand All @@ -22,12 +20,12 @@ var (
)

func (q *Queue) IsTerminalError(err error) bool {
if errors.Is(err, udpclient.ErrTimeout) {
return true
}
if errors.Is(err, ErrClosed) {
return true
}
//if errors.Is(err, udpclient.ErrTimeout) {
// return true
//}
//if errors.Is(err, ErrClosed) {
// return true
//}
return false
}

Expand All @@ -39,12 +37,11 @@ func (q *Queue) Close() error {
defer q.lock.Unlock()
q.lock.Lock()

// don't close the underlying connection since it is reused for detection.

if q.c == nil {
return nil
}

q.c.Close()
q.c = nil
close(q.closed)

Expand Down
32 changes: 16 additions & 16 deletions snes/retroarch/raclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"time"
)

const readWriteTimeout = time.Second * 5

type RAClient struct {
udpclient.UDPClient

Expand All @@ -21,13 +23,21 @@ type RAClient struct {
useRCR bool
}

func NewRAClient(addr *net.UDPAddr, name string) *RAClient {
c := &RAClient{
addr: addr,
}
udpclient.MakeUDPClient(name, &c.UDPClient)
return c
}

func (c *RAClient) GetId() string {
return c.addr.String()
}

func (c *RAClient) Version() (err error) {
var rsp []byte
rsp, err = c.WriteThenReadTimeout([]byte("VERSION\n"), time.Second*5)
rsp, err = c.WriteThenReadTimeout([]byte("VERSION\n"), readWriteTimeout)
if err != nil {
return
}
Expand Down Expand Up @@ -91,17 +101,7 @@ func (c *RAClient) ReadMemory(busAddr uint32, size uint8) (data []byte, err erro
reqStr := sb.String()
var rsp []byte

defer func() {
c.Unlock()
}()
c.Lock()

err = c.WriteTimeout([]byte(reqStr), time.Second*5)
if err != nil {
return
}

rsp, err = c.ReadTimeout(time.Second * 5)
rsp, err = c.WriteThenReadTimeout([]byte(reqStr), readWriteTimeout)
if err != nil {
return
}
Expand Down Expand Up @@ -143,7 +143,7 @@ func (c *RAClient) ReadMemoryBatch(batch []snes.Read, keepAlive snes.KeepAlive)
c.Lock()

// send all commands up front in one packet:
err = c.WriteTimeout([]byte(reqStr), time.Second*5)
err = c.WriteTimeout([]byte(reqStr), readWriteTimeout)
if err != nil {
return
}
Expand All @@ -159,7 +159,7 @@ func (c *RAClient) ReadMemoryBatch(batch []snes.Read, keepAlive snes.KeepAlive)
continue
}

rsp, err = c.ReadTimeout(time.Second * 5)
rsp, err = c.ReadTimeout(readWriteTimeout)
if err != nil {
return
}
Expand Down Expand Up @@ -248,7 +248,7 @@ func (c *RAClient) WriteMemoryBatch(batch []snes.Write, keepAlive snes.KeepAlive
reqStr := sb.String()

log.Printf("retroarch: > %s", reqStr)
err = c.WriteTimeout([]byte(reqStr), time.Second*5)
err = c.WriteTimeout([]byte(reqStr), readWriteTimeout)
if err != nil {
return
}
Expand All @@ -263,7 +263,7 @@ func (c *RAClient) WriteMemoryBatch(batch []snes.Write, keepAlive snes.KeepAlive

// expect a response from WRITE_CORE_MEMORY
var rsp []byte
rsp, err = c.ReadTimeout(time.Second * 5)
rsp, err = c.ReadTimeout(readWriteTimeout)
if err != nil {
return
}
Expand Down
Loading

0 comments on commit 358d911

Please sign in to comment.