Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

srt: process connection requests in parallel (#3382) #3534

Merged
merged 2 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,5 @@ replace code.cloudfoundry.org/bytefmt => github.com/cloudfoundry/bytefmt v0.0.0-
replace github.com/pion/ice/v2 => github.com/aler9/ice/v2 v2.0.0-20240608212222-2eebc68350c9

replace github.com/pion/webrtc/v3 => github.com/aler9/webrtc/v3 v3.0.0-20240610104456-eaec24056d06

replace github.com/datarhei/gosrt => github.com/aler9/gosrt v0.0.0-20240705192040-d4bc5eaa3ee7
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ github.com/alecthomas/kong v0.9.0 h1:G5diXxc85KvoV2f0ZRVuMsi45IrBgx9zDNGNj165aPA
github.com/alecthomas/kong v0.9.0/go.mod h1:Y47y5gKfHp1hDc7CH7OeXgLIpp+Q2m1Ni0L5s3bI8Os=
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/aler9/gosrt v0.0.0-20240705192040-d4bc5eaa3ee7 h1:4WE1Nez3YyD1CgJfWlnyp+uLLPZOKD5ywWPvwbf/Jp4=
github.com/aler9/gosrt v0.0.0-20240705192040-d4bc5eaa3ee7/go.mod h1:fsOWdLSHUHShHjgi/46h6wjtdQrtnSdAQFnlas8ONxs=
github.com/aler9/ice/v2 v2.0.0-20240608212222-2eebc68350c9 h1:Vax9SzYE68ZYLwFaK7lnCV2ZhX9/YqAJX6xxROPRqEM=
github.com/aler9/ice/v2 v2.0.0-20240608212222-2eebc68350c9/go.mod h1:KXJJcZK7E8WzrBEYnV4UtqEZsGeWfHxsNqhVcVvgjxw=
github.com/aler9/webrtc/v3 v3.0.0-20240610104456-eaec24056d06 h1:WtKhXOpd8lgTeXF3RQVOzkNRuy83ygvWEpMYD2aoY3Q=
Expand Down Expand Up @@ -37,8 +39,6 @@ github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJ
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
github.com/datarhei/gosrt v0.6.0 h1:HrrXAw90V78ok4WMIhX6se1aTHPCn82Sg2hj+PhdmGc=
github.com/datarhei/gosrt v0.6.0/go.mod h1:fsOWdLSHUHShHjgi/46h6wjtdQrtnSdAQFnlas8ONxs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
106 changes: 30 additions & 76 deletions internal/servers/srt/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,13 @@ type conn struct {
pathName string
query string
sconn srt.Conn

chNew chan srtNewConnReq
chSetConn chan srt.Conn
}

func (c *conn) initialize() {
c.ctx, c.ctxCancel = context.WithCancel(c.parentCtx)

c.created = time.Now()
c.uuid = uuid.New()
c.chNew = make(chan srtNewConnReq)
c.chSetConn = make(chan srt.Conn)

c.Log(logger.Info, "opened")

Expand Down Expand Up @@ -130,36 +125,20 @@ func (c *conn) run() { //nolint:dupl
}

func (c *conn) runInner() error {
var req srtNewConnReq
select {
case req = <-c.chNew:
case <-c.ctx.Done():
return errors.New("terminated")
}

answerSent, err := c.runInner2(req)

if !answerSent {
req.res <- nil
}

return err
}

func (c *conn) runInner2(req srtNewConnReq) (bool, error) {
var streamID streamID
err := streamID.unmarshal(req.connReq.StreamId())
err := streamID.unmarshal(c.connReq.StreamId())
if err != nil {
return false, fmt.Errorf("invalid stream ID '%s': %w", req.connReq.StreamId(), err)
c.connReq.Reject(srt.REJ_PEER)
return fmt.Errorf("invalid stream ID '%s': %w", c.connReq.StreamId(), err)
}

if streamID.mode == streamIDModePublish {
return c.runPublish(req, &streamID)
return c.runPublish(&streamID)
}
return c.runRead(req, &streamID)
return c.runRead(&streamID)
}

func (c *conn) runPublish(req srtNewConnReq, streamID *streamID) (bool, error) {
func (c *conn) runPublish(streamID *streamID) error {
path, err := c.pathManager.AddPublisher(defs.PathAddPublisherReq{
Author: c,
AccessRequest: defs.PathAccessRequest{
Expand All @@ -178,21 +157,24 @@ func (c *conn) runPublish(req srtNewConnReq, streamID *streamID) (bool, error) {
if errors.As(err, &terr) {
// wait some seconds to mitigate brute force attacks
<-time.After(auth.PauseAfterError)
return false, terr
c.connReq.Reject(srt.REJ_PEER)
return terr
}
return false, err
c.connReq.Reject(srt.REJ_PEER)
return err
}

defer path.RemovePublisher(defs.PathRemovePublisherReq{Author: c})

err = srtCheckPassphrase(req.connReq, path.SafeConf().SRTPublishPassphrase)
err = srtCheckPassphrase(c.connReq, path.SafeConf().SRTPublishPassphrase)
if err != nil {
return false, err
c.connReq.Reject(srt.REJ_PEER)
return err
}

sconn, err := c.exchangeRequestWithConn(req)
sconn, err := c.connReq.Accept()
if err != nil {
return true, err
return err
}

c.mutex.Lock()
Expand All @@ -210,12 +192,12 @@ func (c *conn) runPublish(req srtNewConnReq, streamID *streamID) (bool, error) {
select {
case err := <-readerErr:
sconn.Close()
return true, err
return err

case <-c.ctx.Done():
sconn.Close()
<-readerErr
return true, errors.New("terminated")
return errors.New("terminated")
}
}

Expand Down Expand Up @@ -256,7 +238,7 @@ func (c *conn) runPublishReader(sconn srt.Conn, path defs.Path) error {
}
}

func (c *conn) runRead(req srtNewConnReq, streamID *streamID) (bool, error) {
func (c *conn) runRead(streamID *streamID) error {
path, stream, err := c.pathManager.AddReader(defs.PathAddReaderReq{
Author: c,
AccessRequest: defs.PathAccessRequest{
Expand All @@ -274,21 +256,24 @@ func (c *conn) runRead(req srtNewConnReq, streamID *streamID) (bool, error) {
if errors.As(err, &terr) {
// wait some seconds to mitigate brute force attacks
<-time.After(auth.PauseAfterError)
return false, err
c.connReq.Reject(srt.REJ_PEER)
return terr
}
return false, err
c.connReq.Reject(srt.REJ_PEER)
return err
}

defer path.RemoveReader(defs.PathRemoveReaderReq{Author: c})

err = srtCheckPassphrase(req.connReq, path.SafeConf().SRTReadPassphrase)
err = srtCheckPassphrase(c.connReq, path.SafeConf().SRTReadPassphrase)
if err != nil {
return false, err
c.connReq.Reject(srt.REJ_PEER)
return err
}

sconn, err := c.exchangeRequestWithConn(req)
sconn, err := c.connReq.Accept()
if err != nil {
return true, err
return err
}
defer sconn.Close()

Expand All @@ -307,7 +292,7 @@ func (c *conn) runRead(req srtNewConnReq, streamID *streamID) (bool, error) {

err = mpegts.FromStream(stream, writer, bw, sconn, time.Duration(c.writeTimeout))
if err != nil {
return true, err
return err
}

c.Log(logger.Info, "is reading from path '%s', %s",
Expand All @@ -331,41 +316,10 @@ func (c *conn) runRead(req srtNewConnReq, streamID *streamID) (bool, error) {

select {
case <-c.ctx.Done():
return true, fmt.Errorf("terminated")
return fmt.Errorf("terminated")

case err := <-writer.Error():
return true, err
}
}

func (c *conn) exchangeRequestWithConn(req srtNewConnReq) (srt.Conn, error) {
req.res <- c

select {
case sconn := <-c.chSetConn:
return sconn, nil

case <-c.ctx.Done():
return nil, errors.New("terminated")
}
}

// new is called by srtListener through srtServer.
func (c *conn) new(req srtNewConnReq) *conn {
select {
case c.chNew <- req:
return <-req.res

case <-c.ctx.Done():
return nil
}
}

// setConn is called by srtListener .
func (c *conn) setConn(sconn srt.Conn) {
select {
case c.chSetConn <- sconn:
case <-c.ctx.Done():
return err
}
}

Expand Down
17 changes: 2 additions & 15 deletions internal/servers/srt/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,11 @@ func (l *listener) run() {

func (l *listener) runInner() error {
for {
var sconn *conn
conn, _, err := l.ln.Accept(func(req srt.ConnRequest) srt.ConnType {
sconn = l.parent.newConnRequest(req)
if sconn == nil {
return srt.REJECT
}

// currently it's the same to return SUBSCRIBE or PUBLISH
return srt.SUBSCRIBE
})
req, err := l.ln.Accept2()
if err != nil {
return err
}

if conn == nil {
continue
}

sconn.setConn(conn)
l.parent.newConnRequest(req)
}
}
27 changes: 6 additions & 21 deletions internal/servers/srt/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@ func srtMaxPayloadSize(u int) int {
return ((u - 16) / 188) * 188 // 16 = SRT header, 188 = MPEG-TS packet
}

type srtNewConnReq struct {
connReq srt.ConnRequest
res chan *conn
}

type serverAPIConnsListRes struct {
data *defs.APISRTConnList
err error
Expand Down Expand Up @@ -90,7 +85,7 @@ type Server struct {
conns map[*conn]struct{}

// in
chNewConnRequest chan srtNewConnReq
chNewConnRequest chan srt.ConnRequest
chAcceptErr chan error
chCloseConn chan *conn
chAPIConnsList chan serverAPIConnsListReq
Expand All @@ -113,7 +108,7 @@ func (s *Server) Initialize() error {
s.ctx, s.ctxCancel = context.WithCancel(context.Background())

s.conns = make(map[*conn]struct{})
s.chNewConnRequest = make(chan srtNewConnReq)
s.chNewConnRequest = make(chan srt.ConnRequest)
s.chAcceptErr = make(chan error)
s.chCloseConn = make(chan *conn)
s.chAPIConnsList = make(chan serverAPIConnsListReq)
Expand Down Expand Up @@ -165,7 +160,7 @@ outer:
writeTimeout: s.WriteTimeout,
writeQueueSize: s.WriteQueueSize,
udpMaxPayloadSize: s.UDPMaxPayloadSize,
connReq: req.connReq,
connReq: req,
runOnConnect: s.RunOnConnect,
runOnConnectRestart: s.RunOnConnectRestart,
runOnDisconnect: s.RunOnDisconnect,
Expand All @@ -176,7 +171,6 @@ outer:
}
c.initialize()
s.conns[c] = struct{}{}
req.res <- c

case c := <-s.chCloseConn:
delete(s.conns, c)
Expand Down Expand Up @@ -236,20 +230,11 @@ func (s *Server) findConnByUUID(uuid uuid.UUID) *conn {
}

// newConnRequest is called by srtListener.
func (s *Server) newConnRequest(connReq srt.ConnRequest) *conn {
req := srtNewConnReq{
connReq: connReq,
res: make(chan *conn),
}

func (s *Server) newConnRequest(connReq srt.ConnRequest) {
select {
case s.chNewConnRequest <- req:
c := <-req.res

return c.new(req)

case s.chNewConnRequest <- connReq:
case <-s.ctx.Done():
return nil
connReq.Reject(srt.REJ_CLOSE)
}
}

Expand Down
17 changes: 8 additions & 9 deletions internal/staticsources/srt/source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,15 @@ func TestSource(t *testing.T) {
defer ln.Close()

go func() {
conn, _, err := ln.Accept(func(req srt.ConnRequest) srt.ConnType {
require.Equal(t, "sidname", req.StreamId())
err := req.SetPassphrase("ttest1234567")
if err != nil {
return srt.REJECT
}
return srt.SUBSCRIBE
})
req, err := ln.Accept2()
require.NoError(t, err)

require.Equal(t, "sidname", req.StreamId())
err = req.SetPassphrase("ttest1234567")
require.NoError(t, err)

conn, err := req.Accept()
require.NoError(t, err)
require.NotNil(t, conn)
defer conn.Close()

track := &mpegts.Track{
Expand Down
Loading