From 075d4dea5e993fdb0485364d7eab986df333b358 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sun, 22 Sep 2024 23:31:04 +0200 Subject: [PATCH] support using JWT in Authorization header with API, Metrics, PProf (#3630) --- README.md | 2 +- internal/api/api.go | 12 +- internal/api/api_test.go | 35 ----- internal/auth/manager.go | 68 +++++++--- internal/auth/manager_test.go | 171 +++++++++++++++--------- internal/defs/path.go | 19 ++- internal/metrics/metrics.go | 12 +- internal/playback/on_get_test.go | 17 +-- internal/playback/on_list_test.go | 17 +-- internal/playback/server.go | 21 +-- internal/pprof/pprof.go | 12 +- internal/servers/hls/http_server.go | 35 +---- internal/servers/hls/server_test.go | 9 +- internal/servers/rtmp/conn.go | 4 +- internal/servers/rtmp/server_test.go | 4 +- internal/servers/rtsp/conn.go | 2 +- internal/servers/rtsp/session.go | 4 +- internal/servers/srt/conn.go | 4 +- internal/servers/srt/server_test.go | 4 +- internal/servers/webrtc/http_server.go | 83 ++++-------- internal/servers/webrtc/server.go | 14 +- internal/servers/webrtc/server_test.go | 176 +------------------------ internal/servers/webrtc/session.go | 45 ++----- internal/test/auth_manager.go | 6 +- 24 files changed, 268 insertions(+), 508 deletions(-) diff --git a/README.md b/README.md index b0453517e6c..bbeb4d1cbba 100644 --- a/README.md +++ b/README.md @@ -1188,7 +1188,7 @@ The JWT is expected to contain a claim, with a list of permissions in the same f } ``` -Clients are expected to pass the JWT in the Authorization header (in case of HLS and WebRTC) or in query parameters (in case of all other protocols), for instance: +Clients are expected to pass the JWT in the Authorization header (in case of HLS, WebRTC and all web-based features) or in query parameters (in case of all other protocols), for instance: ``` ffmpeg -re -stream_loop -1 -i file.ts -c copy -f rtsp rtsp://localhost:8554/mystream?jwt=MY_JWT diff --git a/internal/api/api.go b/internal/api/api.go index b3c77d6ab8d..15045224238 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -284,17 +284,13 @@ func (a *API) middlewareOrigin(ctx *gin.Context) { } func (a *API) middlewareAuth(ctx *gin.Context) { - user, pass, hasCredentials := ctx.Request.BasicAuth() - err := a.AuthManager.Authenticate(&auth.Request{ - User: user, - Pass: pass, - Query: ctx.Request.URL.RawQuery, - IP: net.ParseIP(ctx.ClientIP()), - Action: conf.AuthActionAPI, + IP: net.ParseIP(ctx.ClientIP()), + Action: conf.AuthActionAPI, + HTTPRequest: ctx.Request, }) if err != nil { - if !hasCredentials { + if err.(*auth.Error).AskCredentials { //nolint:errorlint ctx.Header("WWW-Authenticate", `Basic realm="mediamtx"`) ctx.AbortWithStatus(http.StatusUnauthorized) return diff --git a/internal/api/api_test.go b/internal/api/api_test.go index d7d026e0bb5..85ec632c965 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -11,7 +11,6 @@ import ( "testing" "time" - "github.com/bluenviron/mediamtx/internal/auth" "github.com/bluenviron/mediamtx/internal/conf" "github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/test" @@ -111,40 +110,6 @@ func TestPreflightRequest(t *testing.T) { require.Equal(t, byts, []byte{}) } -func TestConfigAuth(t *testing.T) { - cnf := tempConf(t, "api: yes\n") - - api := API{ - Address: "localhost:9997", - ReadTimeout: conf.StringDuration(10 * time.Second), - Conf: cnf, - AuthManager: &test.AuthManager{ - Func: func(req *auth.Request) error { - require.Equal(t, &auth.Request{ - User: "myuser", - Pass: "mypass", - IP: req.IP, - Action: "api", - Query: "key=val", - }, req) - return nil - }, - }, - Parent: &testParent{}, - } - err := api.Initialize() - require.NoError(t, err) - defer api.Close() - - tr := &http.Transport{} - defer tr.CloseIdleConnections() - hc := &http.Client{Transport: tr} - - var out map[string]interface{} - httpRequest(t, hc, http.MethodGet, "http://myuser:mypass@localhost:9997/v3/config/global/get?key=val", nil, &out) - require.Equal(t, true, out["api"]) -} - func TestConfigGlobalGet(t *testing.T) { cnf := tempConf(t, "api: yes\n") diff --git a/internal/auth/manager.go b/internal/auth/manager.go index a8da42a4ddf..c51541dc469 100644 --- a/internal/auth/manager.go +++ b/internal/auth/manager.go @@ -31,6 +31,17 @@ const ( jwtRefreshPeriod = 60 * 60 * time.Second ) +func addJWTFromAuthorization(rawQuery string, auth string) string { + jwt := strings.TrimPrefix(auth, "Bearer ") + if rawQuery != "" { + if v, err := url.ParseQuery(rawQuery); err == nil && v.Get("jwt") == "" { + v.Set("jwt", jwt) + return v.Encode() + } + } + return url.Values{"jwt": []string{jwt}}.Encode() +} + // Protocol is a protocol. type Protocol string @@ -51,21 +62,27 @@ type Request struct { Action conf.AuthAction // only for ActionPublish, ActionRead, ActionPlayback - Path string - Protocol Protocol - ID *uuid.UUID - Query string + Path string + Protocol Protocol + ID *uuid.UUID + Query string + + // RTSP only RTSPRequest *base.Request RTSPNonce string + + // HTTP only + HTTPRequest *http.Request } // Error is a authentication error. type Error struct { - Message string + Message string + AskCredentials bool } // Error implements the error interface. -func (e Error) Error() string { +func (e *Error) Error() string { return "authentication failed: " + e.Message } @@ -154,15 +171,6 @@ func (m *Manager) ReloadInternalUsers(u []conf.AuthInternalUser) { // Authenticate authenticates a request. func (m *Manager) Authenticate(req *Request) error { - err := m.authenticateInner(req) - if err != nil { - return Error{Message: err.Error()} - } - return nil -} - -func (m *Manager) authenticateInner(req *Request) error { - // if this is a RTSP request, fill username and password var rtspAuthHeader headers.Authorization if req.RTSPRequest != nil { @@ -175,18 +183,42 @@ func (m *Manager) authenticateInner(req *Request) error { req.User = rtspAuthHeader.Username } } + } else if req.HTTPRequest != nil { + req.User, req.Pass, _ = req.HTTPRequest.BasicAuth() + req.Query = req.HTTPRequest.URL.RawQuery + + if h := req.HTTPRequest.Header.Get("Authorization"); strings.HasPrefix(h, "Bearer ") { + // support passing username and password through Authorization header + if parts := strings.Split(strings.TrimPrefix(h, "Bearer "), ":"); len(parts) == 2 { + req.User = parts[0] + req.Pass = parts[1] + } else { + req.Query = addJWTFromAuthorization(req.Query, h) + } + } } + var err error + switch m.Method { case conf.AuthMethodInternal: - return m.authenticateInternal(req, &rtspAuthHeader) + err = m.authenticateInternal(req, &rtspAuthHeader) case conf.AuthMethodHTTP: - return m.authenticateHTTP(req) + err = m.authenticateHTTP(req) default: - return m.authenticateJWT(req) + err = m.authenticateJWT(req) + } + + if err != nil { + return &Error{ + Message: err.Error(), + AskCredentials: (req.User == "" && req.Pass == ""), + } } + + return nil } func (m *Manager) authenticateInternal(req *Request, rtspAuthHeader *headers.Authorization) error { diff --git a/internal/auth/manager_test.go b/internal/auth/manager_test.go index abb604c7d38..cdc9b718cd5 100644 --- a/internal/auth/manager_test.go +++ b/internal/auth/manager_test.go @@ -7,6 +7,7 @@ import ( "encoding/json" "net" "net/http" + "net/url" "testing" "time" @@ -186,6 +187,37 @@ func TestAuthInternalRTSPDigest(t *testing.T) { require.NoError(t, err) } +func TestAuthInternalCredentialsInBearer(t *testing.T) { + m := Manager{ + Method: conf.AuthMethodInternal, + InternalUsers: []conf.AuthInternalUser{ + { + User: "myuser", + Pass: "mypass", + IPs: conf.IPNetworks{mustParseCIDR("127.1.1.1/32")}, + Permissions: []conf.AuthInternalUserPermission{{ + Action: conf.AuthActionPublish, + Path: "mypath", + }}, + }, + }, + HTTPAddress: "", + RTSPAuthMethods: []auth.ValidateMethod{auth.ValidateMethodDigestMD5}, + } + + err := m.Authenticate(&Request{ + IP: net.ParseIP("127.1.1.1"), + Action: conf.AuthActionPublish, + Path: "mypath", + Protocol: ProtocolRTSP, + HTTPRequest: &http.Request{ + Header: http.Header{"Authorization": []string{"Bearer myuser:mypass"}}, + URL: &url.URL{}, + }, + }) + require.NoError(t, err) +} + func TestAuthHTTP(t *testing.T) { for _, outcome := range []string{"ok", "fail"} { t.Run(outcome, func(t *testing.T) { @@ -292,78 +324,93 @@ func TestAuthJWT(t *testing.T) { // taken from // https://github.com/MicahParks/jwkset/blob/master/examples/http_server/main.go - key, err := rsa.GenerateKey(rand.Reader, 1024) - require.NoError(t, err) + for _, ca := range []string{"query", "auth header"} { + t.Run(ca, func(t *testing.T) { + key, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) - jwk, err := jwkset.NewJWKFromKey(key, jwkset.JWKOptions{ - Metadata: jwkset.JWKMetadataOptions{ - KID: "test-key-id", - }, - }) - require.NoError(t, err) + jwk, err := jwkset.NewJWKFromKey(key, jwkset.JWKOptions{ + Metadata: jwkset.JWKMetadataOptions{ + KID: "test-key-id", + }, + }) + require.NoError(t, err) - jwkSet := jwkset.NewMemoryStorage() - err = jwkSet.KeyWrite(context.Background(), jwk) - require.NoError(t, err) + jwkSet := jwkset.NewMemoryStorage() + err = jwkSet.KeyWrite(context.Background(), jwk) + require.NoError(t, err) + + httpServ := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response, err2 := jwkSet.JSONPublic(r.Context()) + if err2 != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } - httpServ := &http.Server{ - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - response, err2 := jwkSet.JSONPublic(r.Context()) - if err2 != nil { - w.WriteHeader(http.StatusInternalServerError) - return + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(response) + }), } - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write(response) - }), - } + ln, err := net.Listen("tcp", "localhost:4567") + require.NoError(t, err) - ln, err := net.Listen("tcp", "localhost:4567") - require.NoError(t, err) + go httpServ.Serve(ln) + defer httpServ.Shutdown(context.Background()) - go httpServ.Serve(ln) - defer httpServ.Shutdown(context.Background()) + type customClaims struct { + jwt.RegisteredClaims + MediaMTXPermissions []conf.AuthInternalUserPermission `json:"my_permission_key"` + } - type customClaims struct { - jwt.RegisteredClaims - MediaMTXPermissions []conf.AuthInternalUserPermission `json:"my_permission_key"` - } + claims := customClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + Issuer: "test", + Subject: "somebody", + ID: "1", + }, + MediaMTXPermissions: []conf.AuthInternalUserPermission{{ + Action: conf.AuthActionPublish, + Path: "mypath", + }}, + } - claims := customClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - Issuer: "test", - Subject: "somebody", - ID: "1", - }, - MediaMTXPermissions: []conf.AuthInternalUserPermission{{ - Action: conf.AuthActionPublish, - Path: "mypath", - }}, - } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header[jwkset.HeaderKID] = "test-key-id" + ss, err := token.SignedString(key) + require.NoError(t, err) - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - token.Header[jwkset.HeaderKID] = "test-key-id" - ss, err := token.SignedString(key) - require.NoError(t, err) + m := Manager{ + Method: conf.AuthMethodJWT, + JWTJWKS: "http://localhost:4567/jwks", + JWTClaimKey: "my_permission_key", + } - m := Manager{ - Method: conf.AuthMethodJWT, - JWTJWKS: "http://localhost:4567/jwks", - JWTClaimKey: "my_permission_key", + if ca == "query" { + err = m.Authenticate(&Request{ + IP: net.ParseIP("127.0.0.1"), + Action: conf.AuthActionPublish, + Path: "mypath", + Protocol: ProtocolRTSP, + Query: "param=value&jwt=" + ss, + }) + } else { + err = m.Authenticate(&Request{ + IP: net.ParseIP("127.0.0.1"), + Action: conf.AuthActionPublish, + Path: "mypath", + Protocol: ProtocolWebRTC, + HTTPRequest: &http.Request{ + Header: http.Header{"Authorization": []string{"Bearer " + ss}}, + URL: &url.URL{}, + }, + }) + } + require.NoError(t, err) + }) } - - err = m.Authenticate(&Request{ - User: "", - Pass: "", - IP: net.ParseIP("127.0.0.1"), - Action: conf.AuthActionPublish, - Path: "mypath", - Protocol: ProtocolRTSP, - Query: "param=value&jwt=" + ss, - }) - require.NoError(t, err) } diff --git a/internal/defs/path.go b/internal/defs/path.go index fa56a791836..28568583015 100644 --- a/internal/defs/path.go +++ b/internal/defs/path.go @@ -3,6 +3,7 @@ package defs import ( "fmt" "net" + "net/http" "github.com/bluenviron/gortsplib/v4/pkg/base" "github.com/bluenviron/gortsplib/v4/pkg/description" @@ -35,7 +36,7 @@ type Path interface { RemoveReader(req PathRemoveReaderReq) } -// PathAccessRequest is an access request. +// PathAccessRequest is a path access request. type PathAccessRequest struct { Name string Query string @@ -43,13 +44,18 @@ type PathAccessRequest struct { SkipAuth bool // only if skipAuth = false - IP net.IP - User string - Pass string - Proto auth.Protocol - ID *uuid.UUID + User string + Pass string + IP net.IP + Proto auth.Protocol + ID *uuid.UUID + + // RTSP only RTSPRequest *base.Request RTSPNonce string + + // HTTP only + HTTPRequest *http.Request } // ToAuthRequest converts a path access request into an authentication request. @@ -70,6 +76,7 @@ func (r *PathAccessRequest) ToAuthRequest() *auth.Request { Query: r.Query, RTSPRequest: r.RTSPRequest, RTSPNonce: r.RTSPNonce, + HTTPRequest: r.HTTPRequest, } } diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 11a5af72ec6..9332e956828 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -120,17 +120,13 @@ func (m *Metrics) onRequest(ctx *gin.Context) { return } - user, pass, hasCredentials := ctx.Request.BasicAuth() - err := m.AuthManager.Authenticate(&auth.Request{ - User: user, - Pass: pass, - Query: ctx.Request.URL.RawQuery, - IP: net.ParseIP(ctx.ClientIP()), - Action: conf.AuthActionMetrics, + IP: net.ParseIP(ctx.ClientIP()), + Action: conf.AuthActionMetrics, + HTTPRequest: ctx.Request, }) if err != nil { - if !hasCredentials { + if err.(*auth.Error).AskCredentials { //nolint:errorlint ctx.Header("WWW-Authenticate", `Basic realm="mediamtx"`) ctx.AbortWithStatus(http.StatusUnauthorized) return diff --git a/internal/playback/on_get_test.go b/internal/playback/on_get_test.go index 528fc72440a..38a39797bdb 100644 --- a/internal/playback/on_get_test.go +++ b/internal/playback/on_get_test.go @@ -12,7 +12,6 @@ import ( "github.com/bluenviron/mediacommon/pkg/codecs/mpeg4audio" "github.com/bluenviron/mediacommon/pkg/formats/fmp4" "github.com/bluenviron/mediacommon/pkg/formats/fmp4/seekablebuffer" - "github.com/bluenviron/mediamtx/internal/auth" "github.com/bluenviron/mediamtx/internal/conf" "github.com/bluenviron/mediamtx/internal/test" "github.com/stretchr/testify/require" @@ -239,20 +238,8 @@ func TestOnGet(t *testing.T) { RecordPath: filepath.Join(dir, "%path/%Y-%m-%d_%H-%M-%S-%f"), }, }, - AuthManager: &test.AuthManager{ - Func: func(req *auth.Request) error { - require.Equal(t, &auth.Request{ - User: "myuser", - Pass: "mypass", - IP: req.IP, - Action: "playback", - Path: "mypath", - Query: req.Query, - }, req) - return nil - }, - }, - Parent: test.NilLogger, + AuthManager: test.NilAuthManager, + Parent: test.NilLogger, } err = s.Initialize() require.NoError(t, err) diff --git a/internal/playback/on_list_test.go b/internal/playback/on_list_test.go index a05843c6e3b..8d230c5d6f1 100644 --- a/internal/playback/on_list_test.go +++ b/internal/playback/on_list_test.go @@ -9,7 +9,6 @@ import ( "testing" "time" - "github.com/bluenviron/mediamtx/internal/auth" "github.com/bluenviron/mediamtx/internal/conf" "github.com/bluenviron/mediamtx/internal/test" "github.com/stretchr/testify/require" @@ -36,20 +35,8 @@ func TestOnList(t *testing.T) { RecordPath: filepath.Join(dir, "%path/%Y-%m-%d_%H-%M-%S-%f"), }, }, - AuthManager: &test.AuthManager{ - Func: func(req *auth.Request) error { - require.Equal(t, &auth.Request{ - User: "myuser", - Pass: "mypass", - IP: req.IP, - Action: "playback", - Query: "path=mypath", - Path: "mypath", - }, req) - return nil - }, - }, - Parent: test.NilLogger, + AuthManager: test.NilAuthManager, + Parent: test.NilLogger, } err = s.Initialize() require.NoError(t, err) diff --git a/internal/playback/server.go b/internal/playback/server.go index dca99feb41f..1550116eb97 100644 --- a/internal/playback/server.go +++ b/internal/playback/server.go @@ -2,7 +2,6 @@ package playback import ( - "errors" "net" "net/http" "sync" @@ -119,27 +118,21 @@ func (s *Server) middlewareOrigin(ctx *gin.Context) { } func (s *Server) doAuth(ctx *gin.Context, pathName string) bool { - user, pass, hasCredentials := ctx.Request.BasicAuth() - err := s.AuthManager.Authenticate(&auth.Request{ - User: user, - Pass: pass, - Query: ctx.Request.URL.RawQuery, - IP: net.ParseIP(ctx.ClientIP()), - Action: conf.AuthActionPlayback, - Path: pathName, + IP: net.ParseIP(ctx.ClientIP()), + Action: conf.AuthActionPlayback, + Path: pathName, + HTTPRequest: ctx.Request, }) if err != nil { - if !hasCredentials { + if err.(*auth.Error).AskCredentials { //nolint:errorlint ctx.Header("WWW-Authenticate", `Basic realm="mediamtx"`) ctx.Writer.WriteHeader(http.StatusUnauthorized) return false } - var terr auth.Error - errors.As(err, &terr) - - s.Log(logger.Info, "connection %v failed to authenticate: %v", httpp.RemoteAddr(ctx), terr.Message) + s.Log(logger.Info, "connection %v failed to authenticate: %v", + httpp.RemoteAddr(ctx), err.(*auth.Error).Message) //nolint:errorlint // wait some seconds to mitigate brute force attacks <-time.After(auth.PauseAfterError) diff --git a/internal/pprof/pprof.go b/internal/pprof/pprof.go index 2cb0164f85e..a4ac5c821fc 100644 --- a/internal/pprof/pprof.go +++ b/internal/pprof/pprof.go @@ -92,17 +92,13 @@ func (pp *PPROF) onRequest(ctx *gin.Context) { return } - user, pass, hasCredentials := ctx.Request.BasicAuth() - err := pp.AuthManager.Authenticate(&auth.Request{ - User: user, - Pass: pass, - Query: ctx.Request.URL.RawQuery, - IP: net.ParseIP(ctx.ClientIP()), - Action: conf.AuthActionMetrics, + IP: net.ParseIP(ctx.ClientIP()), + Action: conf.AuthActionMetrics, + HTTPRequest: ctx.Request, }) if err != nil { - if !hasCredentials { + if err.(*auth.Error).AskCredentials { //nolint:errorlint ctx.Writer.Header().Set("WWW-Authenticate", `Basic realm="mediamtx"`) ctx.Writer.WriteHeader(http.StatusUnauthorized) return diff --git a/internal/servers/hls/http_server.go b/internal/servers/hls/http_server.go index 882e3443307..ed07eb822a3 100644 --- a/internal/servers/hls/http_server.go +++ b/internal/servers/hls/http_server.go @@ -5,7 +5,6 @@ import ( "errors" "net" "net/http" - "net/url" gopath "path" "strings" "time" @@ -37,17 +36,6 @@ func mergePathAndQuery(path string, rawQuery string) string { return res } -func addJWTFromAuthorization(rawQuery string, auth string) string { - jwt := strings.TrimPrefix(auth, "Bearer ") - if rawQuery != "" { - if v, err := url.ParseQuery(rawQuery); err == nil && v.Get("jwt") == "" { - v.Set("jwt", jwt) - return v.Encode() - } - } - return url.Values{"jwt": []string{jwt}}.Encode() -} - type httpServer struct { address string encryption bool @@ -157,28 +145,19 @@ func (s *httpServer) onRequest(ctx *gin.Context) { return } - user, pass, hasCredentials := ctx.Request.BasicAuth() - - q := ctx.Request.URL.RawQuery - if h := ctx.Request.Header.Get("Authorization"); strings.HasPrefix(h, "Bearer ") { - q = addJWTFromAuthorization(q, h) - } - pathConf, err := s.pathManager.FindPathConf(defs.PathFindPathConfReq{ AccessRequest: defs.PathAccessRequest{ - Name: dir, - Query: q, - Publish: false, - IP: net.ParseIP(ctx.ClientIP()), - User: user, - Pass: pass, - Proto: auth.ProtocolHLS, + Name: dir, + Publish: false, + IP: net.ParseIP(ctx.ClientIP()), + Proto: auth.ProtocolHLS, + HTTPRequest: ctx.Request, }, }) if err != nil { - var terr auth.Error + var terr *auth.Error if errors.As(err, &terr) { - if !hasCredentials { + if terr.AskCredentials { ctx.Header("WWW-Authenticate", `Basic realm="mediamtx"`) ctx.Writer.WriteHeader(http.StatusUnauthorized) return diff --git a/internal/servers/hls/server_test.go b/internal/servers/hls/server_test.go index ba8cca80a36..e00ce5b0d50 100644 --- a/internal/servers/hls/server_test.go +++ b/internal/servers/hls/server_test.go @@ -106,8 +106,6 @@ func TestServerNotFound(t *testing.T) { pm := &dummyPathManager{ findPathConf: func(req defs.PathFindPathConfReq) (*conf.Path, error) { require.Equal(t, "nonexisting", req.AccessRequest.Name) - require.Equal(t, "myuser", req.AccessRequest.User) - require.Equal(t, "mypass", req.AccessRequest.Pass) return &conf.Path{}, nil }, addReader: func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { @@ -181,8 +179,6 @@ func TestServerRead(t *testing.T) { pm := &dummyPathManager{ findPathConf: func(req defs.PathFindPathConfReq) (*conf.Path, error) { require.Equal(t, "mystream", req.AccessRequest.Name) - require.Equal(t, "myuser", req.AccessRequest.User) - require.Equal(t, "mypass", req.AccessRequest.Pass) return &conf.Path{}, nil }, addReader: func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { @@ -277,8 +273,6 @@ func TestServerRead(t *testing.T) { pm := &dummyPathManager{ findPathConf: func(req defs.PathFindPathConfReq) (*conf.Path, error) { require.Equal(t, "mystream", req.AccessRequest.Name) - require.Equal(t, "myuser", req.AccessRequest.User) - require.Equal(t, "mypass", req.AccessRequest.Pass) return &conf.Path{}, nil }, addReader: func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { @@ -372,8 +366,7 @@ func TestServerReadAuthorizationHeader(t *testing.T) { require.NoError(t, err) pm := &dummyPathManager{ - findPathConf: func(req defs.PathFindPathConfReq) (*conf.Path, error) { - require.Equal(t, "jwt=testing", req.AccessRequest.Query) + findPathConf: func(_ defs.PathFindPathConfReq) (*conf.Path, error) { return &conf.Path{}, nil }, addReader: func(_ defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { diff --git a/internal/servers/rtmp/conn.go b/internal/servers/rtmp/conn.go index b9c9f0d3fc2..310bddcb0d9 100644 --- a/internal/servers/rtmp/conn.go +++ b/internal/servers/rtmp/conn.go @@ -168,7 +168,7 @@ func (c *conn) runRead(conn *rtmp.Conn, u *url.URL) error { }, }) if err != nil { - var terr auth.Error + var terr *auth.Error if errors.As(err, &terr) { // wait some seconds to mitigate brute force attacks <-time.After(auth.PauseAfterError) @@ -235,7 +235,7 @@ func (c *conn) runPublish(conn *rtmp.Conn, u *url.URL) error { }, }) if err != nil { - var terr auth.Error + var terr *auth.Error if errors.As(err, &terr) { // wait some seconds to mitigate brute force attacks <-time.After(auth.PauseAfterError) diff --git a/internal/servers/rtmp/server_test.go b/internal/servers/rtmp/server_test.go index 78c0292e659..de7c7edcf72 100644 --- a/internal/servers/rtmp/server_test.go +++ b/internal/servers/rtmp/server_test.go @@ -68,14 +68,14 @@ type dummyPathManager struct { func (pm *dummyPathManager) AddPublisher(req defs.PathAddPublisherReq) (defs.Path, error) { if req.AccessRequest.User != "myuser" || req.AccessRequest.Pass != "mypass" { - return nil, auth.Error{} + return nil, &auth.Error{} } return pm.path, nil } func (pm *dummyPathManager) AddReader(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { if req.AccessRequest.User != "myuser" || req.AccessRequest.Pass != "mypass" { - return nil, nil, auth.Error{} + return nil, nil, &auth.Error{} } return pm.path, pm.path.stream, nil } diff --git a/internal/servers/rtsp/conn.go b/internal/servers/rtsp/conn.go index 25843eabca4..4c65b4160be 100644 --- a/internal/servers/rtsp/conn.go +++ b/internal/servers/rtsp/conn.go @@ -139,7 +139,7 @@ func (c *conn) onDescribe(ctx *gortsplib.ServerHandlerOnDescribeCtx, }) if res.Err != nil { - var terr auth.Error + var terr *auth.Error if errors.As(res.Err, &terr) { res, err := c.handleAuthError(terr) return res, nil, err diff --git a/internal/servers/rtsp/session.go b/internal/servers/rtsp/session.go index a3e7a90238f..0a1b3d76a45 100644 --- a/internal/servers/rtsp/session.go +++ b/internal/servers/rtsp/session.go @@ -125,7 +125,7 @@ func (s *session) onAnnounce(c *conn, ctx *gortsplib.ServerHandlerOnAnnounceCtx) }, }) if err != nil { - var terr auth.Error + var terr *auth.Error if errors.As(err, &terr) { return c.handleAuthError(terr) } @@ -195,7 +195,7 @@ func (s *session) onSetup(c *conn, ctx *gortsplib.ServerHandlerOnSetupCtx, }, }) if err != nil { - var terr auth.Error + var terr *auth.Error if errors.As(err, &terr) { res, err2 := c.handleAuthError(terr) return res, nil, err2 diff --git a/internal/servers/srt/conn.go b/internal/servers/srt/conn.go index 3c062975ee9..bc665e5b06a 100644 --- a/internal/servers/srt/conn.go +++ b/internal/servers/srt/conn.go @@ -151,7 +151,7 @@ func (c *conn) runPublish(streamID *streamID) error { }, }) if err != nil { - var terr auth.Error + var terr *auth.Error if errors.As(err, &terr) { // wait some seconds to mitigate brute force attacks <-time.After(auth.PauseAfterError) @@ -250,7 +250,7 @@ func (c *conn) runRead(streamID *streamID) error { }, }) if err != nil { - var terr auth.Error + var terr *auth.Error if errors.As(err, &terr) { // wait some seconds to mitigate brute force attacks <-time.After(auth.PauseAfterError) diff --git a/internal/servers/srt/server_test.go b/internal/servers/srt/server_test.go index 21bfea8291c..f6461375f7b 100644 --- a/internal/servers/srt/server_test.go +++ b/internal/servers/srt/server_test.go @@ -66,14 +66,14 @@ type dummyPathManager struct { func (pm *dummyPathManager) AddPublisher(req defs.PathAddPublisherReq) (defs.Path, error) { if req.AccessRequest.User != "myuser" || req.AccessRequest.Pass != "mypass" { - return nil, auth.Error{} + return nil, &auth.Error{} } return pm.path, nil } func (pm *dummyPathManager) AddReader(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { if req.AccessRequest.User != "myuser" || req.AccessRequest.Pass != "mypass" { - return nil, nil, auth.Error{} + return nil, nil, &auth.Error{} } return pm.path, pm.path.stream, nil } diff --git a/internal/servers/webrtc/http_server.go b/internal/servers/webrtc/http_server.go index 7872de05f85..5ba0b698fbc 100644 --- a/internal/servers/webrtc/http_server.go +++ b/internal/servers/webrtc/http_server.go @@ -7,7 +7,6 @@ import ( "io" "net" "net/http" - "net/url" "regexp" "strings" "time" @@ -60,17 +59,6 @@ func sessionLocation(publish bool, path string, secret uuid.UUID) string { return ret } -func addJWTFromAuthorization(rawQuery string, auth string) string { - jwt := strings.TrimPrefix(auth, "Bearer ") - if rawQuery != "" { - if v, err := url.ParseQuery(rawQuery); err == nil && v.Get("jwt") == "" { - v.Set("jwt", jwt) - return v.Encode() - } - } - return url.Values{"jwt": []string{jwt}}.Encode() -} - type httpServer struct { address string encryption bool @@ -120,35 +108,19 @@ func (s *httpServer) close() { } func (s *httpServer) checkAuthOutsideSession(ctx *gin.Context, pathName string, publish bool) bool { - user, pass, hasCredentials := ctx.Request.BasicAuth() - q := ctx.Request.URL.RawQuery - - if h := ctx.Request.Header.Get("Authorization"); strings.HasPrefix(h, "Bearer ") { - // JWT in authorization bearer -> JWT in query parameters - q = addJWTFromAuthorization(q, h) - - // credentials in authorization bearer -> credentials in authorization basic - if parts := strings.Split(strings.TrimPrefix(h, "Bearer "), ":"); len(parts) == 2 { - user = parts[0] - pass = parts[1] - } - } - _, err := s.pathManager.FindPathConf(defs.PathFindPathConfReq{ AccessRequest: defs.PathAccessRequest{ - Name: pathName, - Query: q, - Publish: publish, - IP: net.ParseIP(ctx.ClientIP()), - User: user, - Pass: pass, - Proto: auth.ProtocolWebRTC, + Name: pathName, + Publish: publish, + IP: net.ParseIP(ctx.ClientIP()), + Proto: auth.ProtocolWebRTC, + HTTPRequest: ctx.Request, }, }) if err != nil { - var terr auth.Error + var terr *auth.Error if errors.As(err, &terr) { - if !hasCredentials { + if terr.AskCredentials { ctx.Header("WWW-Authenticate", `Basic realm="mediamtx"`) ctx.Writer.WriteHeader(http.StatusUnauthorized) return false @@ -200,30 +172,31 @@ func (s *httpServer) onWHIPPost(ctx *gin.Context, pathName string, publish bool) return } - user, pass, _ := ctx.Request.BasicAuth() - q := ctx.Request.URL.RawQuery + res := s.parent.newSession(webRTCNewSessionReq{ + pathName: pathName, + remoteAddr: httpp.RemoteAddr(ctx), + offer: offer, + publish: publish, + httpRequest: ctx.Request, + }) + if res.err != nil { + var terr *auth.Error + if errors.As(err, &terr) { + if terr.AskCredentials { + ctx.Header("WWW-Authenticate", `Basic realm="mediamtx"`) + ctx.AbortWithStatus(http.StatusUnauthorized) + return + } + + s.Log(logger.Info, "connection %v failed to authenticate: %v", httpp.RemoteAddr(ctx), terr.Message) - if h := ctx.Request.Header.Get("Authorization"); strings.HasPrefix(h, "Bearer ") { - // JWT in authorization bearer -> JWT in query parameters - q = addJWTFromAuthorization(q, h) + // wait some seconds to mitigate brute force attacks + <-time.After(auth.PauseAfterError) - // credentials in authorization bearer -> credentials in authorization basic - if parts := strings.Split(strings.TrimPrefix(h, "Bearer "), ":"); len(parts) == 2 { - user = parts[0] - pass = parts[1] + writeError(ctx, http.StatusUnauthorized, terr) + return } - } - res := s.parent.newSession(webRTCNewSessionReq{ - pathName: pathName, - remoteAddr: httpp.RemoteAddr(ctx), - query: q, - user: user, - pass: pass, - offer: offer, - publish: publish, - }) - if res.err != nil { writeError(ctx, res.errStatusCode, res.err) return } diff --git a/internal/servers/webrtc/server.go b/internal/servers/webrtc/server.go index a04fb3fdd62..64c1fc478f9 100644 --- a/internal/servers/webrtc/server.go +++ b/internal/servers/webrtc/server.go @@ -133,14 +133,12 @@ type webRTCNewSessionRes struct { } type webRTCNewSessionReq struct { - pathName string - remoteAddr string - query string - user string - pass string - offer []byte - publish bool - res chan webRTCNewSessionRes + pathName string + remoteAddr string + offer []byte + publish bool + httpRequest *http.Request + res chan webRTCNewSessionRes } type webRTCAddSessionCandidatesRes struct { diff --git a/internal/servers/webrtc/server_test.go b/internal/servers/webrtc/server_test.go index f421d0c947a..19f37648a34 100644 --- a/internal/servers/webrtc/server_test.go +++ b/internal/servers/webrtc/server_test.go @@ -96,9 +96,7 @@ func (pm *dummyPathManager) AddReader(req defs.PathAddReaderReq) (defs.Path, *st func initializeTestServer(t *testing.T) *Server { pm := &dummyPathManager{ - findPathConf: func(req defs.PathFindPathConfReq) (*conf.Path, error) { - require.Equal(t, "myuser", req.AccessRequest.User) - require.Equal(t, "mypass", req.AccessRequest.Pass) + findPathConf: func(_ defs.PathFindPathConfReq) (*conf.Path, error) { return &conf.Path{}, nil }, } @@ -182,9 +180,7 @@ func TestPreflightRequest(t *testing.T) { func TestServerOptionsICEServer(t *testing.T) { pathManager := &dummyPathManager{ - findPathConf: func(req defs.PathFindPathConfReq) (*conf.Path, error) { - require.Equal(t, "myuser", req.AccessRequest.User) - require.Equal(t, "mypass", req.AccessRequest.Pass) + findPathConf: func(_ defs.PathFindPathConfReq) (*conf.Path, error) { return &conf.Path{}, nil }, } @@ -249,14 +245,10 @@ func TestServerPublish(t *testing.T) { pathManager := &dummyPathManager{ findPathConf: func(req defs.PathFindPathConfReq) (*conf.Path, error) { require.Equal(t, "teststream", req.AccessRequest.Name) - require.Equal(t, "myuser", req.AccessRequest.User) - require.Equal(t, "mypass", req.AccessRequest.Pass) return &conf.Path{}, nil }, addPublisher: func(req defs.PathAddPublisherReq) (defs.Path, error) { require.Equal(t, "teststream", req.AccessRequest.Name) - require.Equal(t, "myuser", req.AccessRequest.User) - require.Equal(t, "mypass", req.AccessRequest.Pass) return path, nil }, } @@ -534,14 +526,10 @@ func TestServerRead(t *testing.T) { pathManager := &dummyPathManager{ findPathConf: func(req defs.PathFindPathConfReq) (*conf.Path, error) { require.Equal(t, "teststream", req.AccessRequest.Name) - require.Equal(t, "myuser", req.AccessRequest.User) - require.Equal(t, "mypass", req.AccessRequest.Pass) return &conf.Path{}, nil }, addReader: func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { require.Equal(t, "teststream", req.AccessRequest.Name) - require.Equal(t, "myuser", req.AccessRequest.User) - require.Equal(t, "mypass", req.AccessRequest.Pass) return path, str, nil }, } @@ -632,167 +620,9 @@ func TestServerRead(t *testing.T) { } } -func TestServerReadAuthorizationBearerJWT(t *testing.T) { - desc := &description.Session{Medias: []*description.Media{test.MediaH264}} - - str, err := stream.New( - 512, - 1460, - desc, - true, - test.NilLogger, - ) - require.NoError(t, err) - - path := &dummyPath{stream: str} - - pm := &dummyPathManager{ - findPathConf: func(req defs.PathFindPathConfReq) (*conf.Path, error) { - require.Equal(t, "jwt=testing", req.AccessRequest.Query) - return &conf.Path{}, nil - }, - addReader: func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { - require.Equal(t, "jwt=testing", req.AccessRequest.Query) - return path, str, nil - }, - } - - s := &Server{ - Address: "127.0.0.1:8886", - Encryption: false, - ServerKey: "", - ServerCert: "", - AllowOrigin: "", - TrustedProxies: conf.IPNetworks{}, - ReadTimeout: conf.StringDuration(10 * time.Second), - LocalUDPAddress: "127.0.0.1:8887", - LocalTCPAddress: "127.0.0.1:8887", - IPsFromInterfaces: true, - IPsFromInterfacesList: []string{}, - AdditionalHosts: []string{}, - ICEServers: []conf.WebRTCICEServer{}, - HandshakeTimeout: conf.StringDuration(10 * time.Second), - TrackGatherTimeout: conf.StringDuration(2 * time.Second), - ExternalCmdPool: nil, - PathManager: pm, - Parent: test.NilLogger, - } - err = s.Initialize() - require.NoError(t, err) - defer s.Close() - - tr := &http.Transport{} - defer tr.CloseIdleConnections() - hc := &http.Client{Transport: tr} - - pc, err := pwebrtc.NewPeerConnection(pwebrtc.Configuration{}) - require.NoError(t, err) - defer pc.Close() //nolint:errcheck - - _, err = pc.AddTransceiverFromKind(pwebrtc.RTPCodecTypeVideo) - require.NoError(t, err) - - offer, err := pc.CreateOffer(nil) - require.NoError(t, err) - - req, err := http.NewRequest(http.MethodPost, - "http://localhost:8886/teststream/whep", bytes.NewReader([]byte(offer.SDP))) - require.NoError(t, err) - - req.Header.Set("Content-Type", "application/sdp") - req.Header.Set("Authorization", "Bearer testing") - - res, err := hc.Do(req) - require.NoError(t, err) - defer res.Body.Close() - - require.Equal(t, http.StatusCreated, res.StatusCode) -} - -func TestServerReadAuthorizationUserPass(t *testing.T) { - desc := &description.Session{Medias: []*description.Media{test.MediaH264}} - - str, err := stream.New( - 512, - 1460, - desc, - true, - test.NilLogger, - ) - require.NoError(t, err) - - path := &dummyPath{stream: str} - - pm := &dummyPathManager{ - findPathConf: func(req defs.PathFindPathConfReq) (*conf.Path, error) { - require.Equal(t, "myuser", req.AccessRequest.User) - require.Equal(t, "mypass", req.AccessRequest.Pass) - return &conf.Path{}, nil - }, - addReader: func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { - require.Equal(t, "myuser", req.AccessRequest.User) - require.Equal(t, "mypass", req.AccessRequest.Pass) - return path, str, nil - }, - } - - s := &Server{ - Address: "127.0.0.1:8886", - Encryption: false, - ServerKey: "", - ServerCert: "", - AllowOrigin: "", - TrustedProxies: conf.IPNetworks{}, - ReadTimeout: conf.StringDuration(10 * time.Second), - LocalUDPAddress: "127.0.0.1:8887", - LocalTCPAddress: "127.0.0.1:8887", - IPsFromInterfaces: true, - IPsFromInterfacesList: []string{}, - AdditionalHosts: []string{}, - ICEServers: []conf.WebRTCICEServer{}, - HandshakeTimeout: conf.StringDuration(10 * time.Second), - TrackGatherTimeout: conf.StringDuration(2 * time.Second), - ExternalCmdPool: nil, - PathManager: pm, - Parent: test.NilLogger, - } - err = s.Initialize() - require.NoError(t, err) - defer s.Close() - - tr := &http.Transport{} - defer tr.CloseIdleConnections() - hc := &http.Client{Transport: tr} - - pc, err := pwebrtc.NewPeerConnection(pwebrtc.Configuration{}) - require.NoError(t, err) - defer pc.Close() //nolint:errcheck - - _, err = pc.AddTransceiverFromKind(pwebrtc.RTPCodecTypeVideo) - require.NoError(t, err) - - offer, err := pc.CreateOffer(nil) - require.NoError(t, err) - - req, err := http.NewRequest(http.MethodPost, - "http://localhost:8886/teststream/whep", bytes.NewReader([]byte(offer.SDP))) - require.NoError(t, err) - - req.Header.Set("Content-Type", "application/sdp") - req.Header.Set("Authorization", "Bearer myuser:mypass") - - res, err := hc.Do(req) - require.NoError(t, err) - defer res.Body.Close() - - require.Equal(t, http.StatusCreated, res.StatusCode) -} - func TestServerReadNotFound(t *testing.T) { pm := &dummyPathManager{ - findPathConf: func(req defs.PathFindPathConfReq) (*conf.Path, error) { - require.Equal(t, "myuser", req.AccessRequest.User) - require.Equal(t, "mypass", req.AccessRequest.Pass) + findPathConf: func(_ defs.PathFindPathConfReq) (*conf.Path, error) { return &conf.Path{}, nil }, addReader: func(_ defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { diff --git a/internal/servers/webrtc/session.go b/internal/servers/webrtc/session.go index 6bee1909407..9298282dae4 100644 --- a/internal/servers/webrtc/session.go +++ b/internal/servers/webrtc/session.go @@ -129,25 +129,15 @@ func (s *session) runPublish() (int, error) { path, err := s.pathManager.AddPublisher(defs.PathAddPublisherReq{ Author: s, AccessRequest: defs.PathAccessRequest{ - Name: s.req.pathName, - Query: s.req.query, - Publish: true, - IP: net.ParseIP(ip), - User: s.req.user, - Pass: s.req.pass, - Proto: auth.ProtocolWebRTC, - ID: &s.uuid, + Name: s.req.pathName, + Publish: true, + IP: net.ParseIP(ip), + Proto: auth.ProtocolWebRTC, + ID: &s.uuid, + HTTPRequest: s.req.httpRequest, }, }) if err != nil { - var terr auth.Error - if errors.As(err, &terr) { - // wait some seconds to mitigate brute force attacks - <-time.After(auth.PauseAfterError) - - return http.StatusUnauthorized, err - } - return http.StatusBadRequest, err } @@ -250,23 +240,14 @@ func (s *session) runRead() (int, error) { path, stream, err := s.pathManager.AddReader(defs.PathAddReaderReq{ Author: s, AccessRequest: defs.PathAccessRequest{ - Name: s.req.pathName, - Query: s.req.query, - IP: net.ParseIP(ip), - User: s.req.user, - Pass: s.req.pass, - Proto: auth.ProtocolWebRTC, - ID: &s.uuid, + Name: s.req.pathName, + IP: net.ParseIP(ip), + Proto: auth.ProtocolWebRTC, + ID: &s.uuid, + HTTPRequest: s.req.httpRequest, }, }) if err != nil { - var terr1 auth.Error - if errors.As(err, &terr1) { - // wait some seconds to mitigate brute force attacks - <-time.After(auth.PauseAfterError) - return http.StatusUnauthorized, err - } - var terr2 defs.PathNoOnePublishingError if errors.As(err, &terr2) { return http.StatusNotFound, err @@ -338,7 +319,7 @@ func (s *session) runRead() (int, error) { Conf: path.SafeConf(), ExternalCmdEnv: path.ExternalCmdEnv(), Reader: s.APIReaderDescribe(), - Query: s.req.query, + Query: s.req.httpRequest.URL.RawQuery, }) defer onUnreadHook() @@ -451,7 +432,7 @@ func (s *session) apiItem() *defs.APIWebRTCSession { return defs.APIWebRTCSessionStateRead }(), Path: s.req.pathName, - Query: s.req.query, + Query: s.req.httpRequest.URL.RawQuery, BytesReceived: bytesReceived, BytesSent: bytesSent, } diff --git a/internal/test/auth_manager.go b/internal/test/auth_manager.go index be23ede8459..f59666b78a8 100644 --- a/internal/test/auth_manager.go +++ b/internal/test/auth_manager.go @@ -4,17 +4,17 @@ import "github.com/bluenviron/mediamtx/internal/auth" // AuthManager is a test auth manager. type AuthManager struct { - Func func(req *auth.Request) error + fnc func(req *auth.Request) error } // Authenticate replicates auth.Manager.Replicate func (m *AuthManager) Authenticate(req *auth.Request) error { - return m.Func(req) + return m.fnc(req) } // NilAuthManager is an auth manager that accepts everything. var NilAuthManager = &AuthManager{ - Func: func(_ *auth.Request) error { + fnc: func(_ *auth.Request) error { return nil }, }