Skip to content

Commit

Permalink
webrtc, hls: support passing JWT through Authorization header (#3248) (
Browse files Browse the repository at this point in the history
  • Loading branch information
aler9 authored Jun 11, 2024
1 parent 80a133a commit caa9fa6
Show file tree
Hide file tree
Showing 6 changed files with 368 additions and 53 deletions.
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1180,12 +1180,20 @@ The JWT is expected to contain the `mediamtx_permissions` scope, with a list of
}
```
Clients are expected to pass the JWT in query parameters, for instance:
Clients are expected to pass the JWT in the Authorization header (in case of HLS and WebRTC) or in query parameters (in case of any other protocol), for instance (RTSP):
```
ffmpeg -re -stream_loop -1 -i file.ts -c copy -f rtsp rtsp://localhost:8554/mystream?jwt=MY_JWT
```
For instance (HLS):
```
GET /mypath/index.m3u8 HTTP/1.1
Host: example.com
Authorization: Bearer MY_JWT
```
Here's a tutorial on how to setup the [Keycloak identity server](https://www.keycloak.org/) in order to provide such JWTs:
1. Start Keycloak:
Expand Down
19 changes: 18 additions & 1 deletion internal/servers/hls/http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"net"
"net/http"
"net/url"
gopath "path"
"strings"
"time"
Expand Down Expand Up @@ -36,6 +37,17 @@ func mergePathAndQuery(path string, rawQuery string) string {
return res
}

func addJWTFromAuthorization(rawQuery string, auth string) string {
jwt := strings.TrimPrefix(auth, "Bearer ")
if rawQuery != "" {
if v, err := url.ParseQuery(rawQuery); err == nil && v.Get("jwt") == "" {
v.Set("jwt", jwt)
return v.Encode()
}
}
return url.Values{"jwt": []string{jwt}}.Encode()
}

type httpServer struct {
address string
encryption bool
Expand Down Expand Up @@ -145,10 +157,15 @@ func (s *httpServer) onRequest(ctx *gin.Context) {

user, pass, hasCredentials := ctx.Request.BasicAuth()

q := ctx.Request.URL.RawQuery
if h := ctx.Request.Header.Get("Authorization"); strings.HasPrefix(h, "Bearer ") {
q = addJWTFromAuthorization(q, h)
}

pathConf, err := s.pathManager.FindPathConf(defs.PathFindPathConfReq{
AccessRequest: defs.PathAccessRequest{
Name: dir,
Query: ctx.Request.URL.RawQuery,
Query: q,
Publish: false,
IP: net.ParseIP(ctx.ClientIP()),
User: user,
Expand Down
173 changes: 151 additions & 22 deletions internal/servers/hls/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/bluenviron/gohlslib"
"github.com/bluenviron/gohlslib/pkg/codecs"
"github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/mediamtx/internal/auth"
"github.com/bluenviron/mediamtx/internal/conf"
"github.com/bluenviron/mediamtx/internal/defs"
"github.com/bluenviron/mediamtx/internal/externalcmd"
Expand Down Expand Up @@ -49,21 +48,16 @@ func (pa *dummyPath) RemoveReader(_ defs.PathRemoveReaderReq) {
}

type dummyPathManager struct {
stream *stream.Stream
findPathConf func(req defs.PathFindPathConfReq) (*conf.Path, error)
addReader func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error)
}

func (pm *dummyPathManager) FindPathConf(req defs.PathFindPathConfReq) (*conf.Path, error) {
if req.AccessRequest.User != "myuser" || req.AccessRequest.Pass != "mypass" {
return nil, auth.Error{}
}
return &conf.Path{}, nil
return pm.findPathConf(req)
}

func (pm *dummyPathManager) AddReader(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) {
if req.AccessRequest.Name == "nonexisting" {
return nil, nil, fmt.Errorf("not found")
}
return &dummyPath{}, pm.stream, nil
return pm.addReader(req)
}

func TestServerNotFound(t *testing.T) {
Expand All @@ -72,6 +66,19 @@ func TestServerNotFound(t *testing.T) {
"always remux on",
} {
t.Run(ca, func(t *testing.T) {
pm := &dummyPathManager{
findPathConf: func(req defs.PathFindPathConfReq) (*conf.Path, error) {
require.Equal(t, "nonexisting", req.AccessRequest.Name)
require.Equal(t, "myuser", req.AccessRequest.User)
require.Equal(t, "mypass", req.AccessRequest.Pass)
return &conf.Path{}, nil
},
addReader: func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) {
require.Equal(t, "nonexisting", req.AccessRequest.Name)
return nil, nil, fmt.Errorf("not found")
},
}

s := &Server{
Address: "127.0.0.1:8888",
Encryption: false,
Expand All @@ -88,7 +95,7 @@ func TestServerNotFound(t *testing.T) {
Directory: "",
ReadTimeout: conf.StringDuration(10 * time.Second),
WriteQueueSize: 512,
PathManager: &dummyPathManager{},
PathManager: pm,
Parent: test.NilLogger,
}
err := s.Initialize()
Expand Down Expand Up @@ -126,15 +133,26 @@ func TestServerRead(t *testing.T) {
t.Run("always remux off", func(t *testing.T) {
desc := &description.Session{Medias: []*description.Media{test.MediaH264}}

stream, err := stream.New(
str, err := stream.New(
1460,
desc,
true,
test.NilLogger,
)
require.NoError(t, err)

pathManager := &dummyPathManager{stream: stream}
pm := &dummyPathManager{
findPathConf: func(req defs.PathFindPathConfReq) (*conf.Path, error) {
require.Equal(t, "mystream", req.AccessRequest.Name)
require.Equal(t, "myuser", req.AccessRequest.User)
require.Equal(t, "mypass", req.AccessRequest.Pass)
return &conf.Path{}, nil
},
addReader: func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) {
require.Equal(t, "mystream", req.AccessRequest.Name)
return &dummyPath{}, str, nil
},
}

s := &Server{
Address: "127.0.0.1:8888",
Expand All @@ -152,7 +170,7 @@ func TestServerRead(t *testing.T) {
Directory: "",
ReadTimeout: conf.StringDuration(10 * time.Second),
WriteQueueSize: 512,
PathManager: pathManager,
PathManager: pm,
Parent: test.NilLogger,
}
err = s.Initialize()
Expand Down Expand Up @@ -192,7 +210,7 @@ func TestServerRead(t *testing.T) {
go func() {
time.Sleep(100 * time.Millisecond)
for i := 0; i < 4; i++ {
stream.WriteUnit(test.MediaH264, test.FormatH264, &unit.H264{
str.WriteUnit(test.MediaH264, test.FormatH264, &unit.H264{
Base: unit.Base{
NTP: time.Time{},
PTS: time.Duration(i) * time.Second,
Expand All @@ -210,15 +228,26 @@ func TestServerRead(t *testing.T) {
t.Run("always remux on", func(t *testing.T) {
desc := &description.Session{Medias: []*description.Media{test.MediaH264}}

stream, err := stream.New(
str, err := stream.New(
1460,
desc,
true,
test.NilLogger,
)
require.NoError(t, err)

pathManager := &dummyPathManager{stream: stream}
pm := &dummyPathManager{
findPathConf: func(req defs.PathFindPathConfReq) (*conf.Path, error) {
require.Equal(t, "mystream", req.AccessRequest.Name)
require.Equal(t, "myuser", req.AccessRequest.User)
require.Equal(t, "mypass", req.AccessRequest.Pass)
return &conf.Path{}, nil
},
addReader: func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) {
require.Equal(t, "mystream", req.AccessRequest.Name)
return &dummyPath{}, str, nil
},
}

s := &Server{
Address: "127.0.0.1:8888",
Expand All @@ -236,7 +265,7 @@ func TestServerRead(t *testing.T) {
Directory: "",
ReadTimeout: conf.StringDuration(10 * time.Second),
WriteQueueSize: 512,
PathManager: pathManager,
PathManager: pm,
Parent: test.NilLogger,
}
err = s.Initialize()
Expand All @@ -248,7 +277,7 @@ func TestServerRead(t *testing.T) {
time.Sleep(100 * time.Millisecond)

for i := 0; i < 4; i++ {
stream.WriteUnit(test.MediaH264, test.FormatH264, &unit.H264{
str.WriteUnit(test.MediaH264, test.FormatH264, &unit.H264{
Base: unit.Base{
NTP: time.Time{},
PTS: time.Duration(i) * time.Second,
Expand Down Expand Up @@ -293,22 +322,122 @@ func TestServerRead(t *testing.T) {
})
}

func TestServerReadAuthorizationHeader(t *testing.T) {
desc := &description.Session{Medias: []*description.Media{test.MediaH264}}

str, err := stream.New(
1460,
desc,
true,
test.NilLogger,
)
require.NoError(t, err)

pm := &dummyPathManager{
findPathConf: func(req defs.PathFindPathConfReq) (*conf.Path, error) {
require.Equal(t, "jwt=testing", req.AccessRequest.Query)
return &conf.Path{}, nil
},
addReader: func(_ defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) {
return &dummyPath{}, str, nil
},
}

s := &Server{
Address: "127.0.0.1:8888",
Encryption: false,
ServerKey: "",
ServerCert: "",
AlwaysRemux: true,
Variant: conf.HLSVariant(gohlslib.MuxerVariantMPEGTS),
SegmentCount: 7,
SegmentDuration: conf.StringDuration(1 * time.Second),
PartDuration: conf.StringDuration(200 * time.Millisecond),
SegmentMaxSize: 50 * 1024 * 1024,
AllowOrigin: "",
TrustedProxies: conf.IPNetworks{},
Directory: "",
ReadTimeout: conf.StringDuration(10 * time.Second),
WriteQueueSize: 512,
PathManager: pm,
Parent: test.NilLogger,
}
err = s.Initialize()
require.NoError(t, err)
defer s.Close()

s.PathReady(&dummyPath{})

time.Sleep(100 * time.Millisecond)

for i := 0; i < 4; i++ {
str.WriteUnit(test.MediaH264, test.FormatH264, &unit.H264{
Base: unit.Base{
NTP: time.Time{},
PTS: time.Duration(i) * time.Second,
},
AU: [][]byte{
{5, 1}, // IDR
},
})
}

c := &gohlslib.Client{
URI: "http://127.0.0.1:8888/mystream/index.m3u8",
OnRequest: func(r *http.Request) {
r.Header.Set("Authorization", "Bearer testing")
},
}

recv := make(chan struct{})

c.OnTracks = func(tracks []*gohlslib.Track) error {
require.Equal(t, []*gohlslib.Track{{
Codec: &codecs.H264{},
}}, tracks)

c.OnDataH26x(tracks[0], func(pts, dts time.Duration, au [][]byte) {
require.Equal(t, time.Duration(0), pts)
require.Equal(t, time.Duration(0), dts)
require.Equal(t, [][]byte{
test.FormatH264.SPS,
test.FormatH264.PPS,
{5, 1},
}, au)
close(recv)
})

return nil
}

err = c.Start()
require.NoError(t, err)
defer func() { <-c.Wait() }()
defer c.Close()

<-recv
}

func TestDirectory(t *testing.T) {
dir, err := os.MkdirTemp("", "mediamtx-playback")
require.NoError(t, err)
defer os.RemoveAll(dir)

desc := &description.Session{Medias: []*description.Media{test.MediaH264}}

stream, err := stream.New(
str, err := stream.New(
1460,
desc,
true,
test.NilLogger,
)
require.NoError(t, err)

pathManager := &dummyPathManager{stream: stream}
pm := &dummyPathManager{
addReader: func(_ defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) {
return &dummyPath{}, str, nil
},
}

s := &Server{
Address: "127.0.0.1:8888",
Expand All @@ -326,7 +455,7 @@ func TestDirectory(t *testing.T) {
Directory: filepath.Join(dir, "mydir"),
ReadTimeout: conf.StringDuration(10 * time.Second),
WriteQueueSize: 512,
PathManager: pathManager,
PathManager: pm,
Parent: test.NilLogger,
}
err = s.Initialize()
Expand Down
Loading

0 comments on commit caa9fa6

Please sign in to comment.