diff --git a/pkg/headers/transport.go b/pkg/headers/transport.go index 9ad23f36..dc714bf5 100644 --- a/pkg/headers/transport.go +++ b/pkg/headers/transport.go @@ -47,6 +47,14 @@ const ( TransportProtocolTCP ) +// String implements fmt.Stringer. +func (p TransportProtocol) String() string { + if p == TransportProtocolUDP { + return "RTP/AVP" + } + return "RTP/AVP/TCP" +} + // TransportDelivery is a delivery method. type TransportDelivery int @@ -56,6 +64,14 @@ const ( TransportDeliveryMulticast ) +// String implements fmt.Stringer. +func (d TransportDelivery) String() string { + if d == TransportDeliveryUnicast { + return "unicast" + } + return "multicast" +} + // TransportMode is a transport mode. type TransportMode int @@ -67,6 +83,33 @@ const ( TransportModeRecord ) +func (m *TransportMode) unmarshal(v string) error { + str := strings.ToLower(v) + + switch str { + case "play": + *m = TransportModePlay + return nil + + // receive is an old alias for record, used by ffmpeg with the + // -listen flag, and by Darwin Streaming Server + case "record", "receive": + *m = TransportModeRecord + return nil + + default: + return fmt.Errorf("invalid transport mode: '%s'", str) + } +} + +// String implements fmt.Stringer. +func (m TransportMode) String() string { + if m == TransportModePlay { + return "play" + } + return "record" +} + // Transport is a Transport header. type Transport struct { // protocol of the stream @@ -218,24 +261,12 @@ func (h *Transport) Unmarshal(v base.HeaderValue) error { } case "mode": - str := strings.ToLower(v) - str = strings.TrimPrefix(str, "\"") - str = strings.TrimSuffix(str, "\"") - - switch str { - case "play": - v := TransportModePlay - h.Mode = &v - - // receive is an old alias for record, used by ffmpeg with the - // -listen flag, and by Darwin Streaming Server - case "record", "receive": - v := TransportModeRecord - h.Mode = &v - - default: - return fmt.Errorf("invalid transport mode: '%s'", str) + var m TransportMode + err = m.unmarshal(v) + if err != nil { + return err } + h.Mode = &m default: // ignore non-standard keys @@ -253,18 +284,10 @@ func (h *Transport) Unmarshal(v base.HeaderValue) error { func (h Transport) Marshal() base.HeaderValue { var rets []string - if h.Protocol == TransportProtocolUDP { - rets = append(rets, "RTP/AVP") - } else { - rets = append(rets, "RTP/AVP/TCP") - } + rets = append(rets, h.Protocol.String()) if h.Delivery != nil { - if *h.Delivery == TransportDeliveryUnicast { - rets = append(rets, "unicast") - } else { - rets = append(rets, "multicast") - } + rets = append(rets, h.Delivery.String()) } if h.Source != nil { @@ -309,11 +332,7 @@ func (h Transport) Marshal() base.HeaderValue { } if h.Mode != nil { - if *h.Mode == TransportModePlay { - rets = append(rets, "mode=play") - } else { - rets = append(rets, "mode=record") - } + rets = append(rets, "mode="+h.Mode.String()) } return base.HeaderValue{strings.Join(rets, ";")} diff --git a/pkg/liberrors/server.go b/pkg/liberrors/server.go index bcd345d4..22ab635b 100644 --- a/pkg/liberrors/server.go +++ b/pkg/liberrors/server.go @@ -84,12 +84,16 @@ func (e ErrServerMediaNotFound) Error() string { // ErrServerTransportHeaderInvalidMode is an error that can be returned by a server. type ErrServerTransportHeaderInvalidMode struct { - Mode headers.TransportMode + Mode *headers.TransportMode } // Error implements the error interface. func (e ErrServerTransportHeaderInvalidMode) Error() string { - return fmt.Sprintf("transport header contains a invalid mode (%v)", e.Mode) + m := "null" + if e.Mode != nil { + m = e.Mode.String() + } + return fmt.Sprintf("transport header contains a invalid mode (%v)", m) } // ErrServerTransportHeaderNoClientPorts is an error that can be returned by a server. diff --git a/server_record_test.go b/server_record_test.go index f82771d5..9cc5acd6 100644 --- a/server_record_test.go +++ b/server_record_test.go @@ -110,7 +110,7 @@ func TestServerRecordErrorAnnounce(t *testing.T) { "unsupported Content-Type header '[aa]'", }, { - "invalid medias", + "invalid sdp", base.Request{ Method: base.Announce, URL: mustParseURL("rtsp://localhost:8554/teststream"), @@ -122,6 +122,29 @@ func TestServerRecordErrorAnnounce(t *testing.T) { }, "invalid SDP: invalid line: (\x01\x02\x03\x04)", }, + { + "invalid session", + base.Request{ + Method: base.Announce, + URL: mustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Content-Type": base.HeaderValue{"application/sdp"}, + }, + Body: []byte("v=0\r\n" + + "o=- 0 0 IN IP4 127.0.0.1\r\n" + + "s=-\r\n" + + "c=IN IP4 0.0.0.0\r\n" + + "t=0 0\r\n" + + "m=video 0 RTP/AVP 96\r\n" + + "a=control\r\n" + + "a=rtpmap:97 H264/90000\r\n" + + "a=fmtp:aa packetization-mode=1; profile-level-id=4D002A; " + + "sprop-parameter-sets=Z00AKp2oHgCJ+WbgICAgQA==,aO48gA==\r\n", + ), + }, + "invalid SDP: media 1 is invalid: clock rate not found", + }, { "invalid URL 1", invalidURLAnnounceReq(t, "rtsp:// aaaaa"), @@ -168,6 +191,87 @@ func TestServerRecordErrorAnnounce(t *testing.T) { } } +func TestServerRecordErrorSetup(t *testing.T) { + for _, ca := range []struct { + name string + err string + }{ + { + "invalid transport", + "transport header contains a invalid mode (null)", + }, + } { + t.Run(ca.name, func(t *testing.T) { + s := &Server{ + Handler: &testServerHandler{ + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { + require.EqualError(t, ctx.Error, ca.err) + }, + onAnnounce: func(_ *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil, nil + }, + onRecord: func(_ *ServerHandlerOnRecordCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onPause: func(_ *ServerHandlerOnPauseCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + RTSPAddress: "localhost:8554", + UDPRTPAddress: "127.0.0.1:8000", + UDPRTCPAddress: "127.0.0.1:8001", + } + + err := s.Start() + require.NoError(t, err) + defer s.Close() + + nconn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer nconn.Close() + conn := conn.NewConn(nconn) + + medias := []*description.Media{testH264Media} + + doAnnounce(t, conn, "rtsp://localhost:8554/teststream", medias) + + var inTH *headers.Transport + + switch ca.name { + case "invalid transport": + inTH = &headers.Transport{ + Delivery: deliveryPtr(headers.TransportDeliveryUnicast), + Mode: nil, + Protocol: headers.TransportProtocolUDP, + ClientPorts: &[2]int{35466, 35467}, + } + } + + res, err := writeReqReadRes(conn, base.Request{ + Method: base.Setup, + URL: mustParseURL("rtsp://localhost:8554/teststream/" + medias[0].Control), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Transport": inTH.Marshal(), + }, + }) + require.NoError(t, err) + require.NotEqual(t, base.StatusOK, res.StatusCode) + }) + } +} + func TestServerRecordPath(t *testing.T) { for _, ca := range []struct { name string diff --git a/server_session.go b/server_session.go index affc7433..c8760388 100644 --- a/server_session.go +++ b/server_session.go @@ -728,7 +728,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( if inTH.Mode != nil && *inTH.Mode != headers.TransportModePlay { return &base.Response{ StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerTransportHeaderInvalidMode{Mode: *inTH.Mode} + }, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode} } default: // record @@ -741,7 +741,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( if inTH.Mode == nil || *inTH.Mode != headers.TransportModeRecord { return &base.Response{ StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerTransportHeaderInvalidMode{Mode: *inTH.Mode} + }, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode} } }