Skip to content

Commit

Permalink
Add a quota handler callback
Browse files Browse the repository at this point in the history
  • Loading branch information
rg0now committed Dec 5, 2024
1 parent 37ba0b5 commit 14ea578
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 1 deletion.
3 changes: 2 additions & 1 deletion internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ type Request struct {
NonceHash *NonceHash

// User Configuration
AuthHandler func(username string, realm string, srcAddr net.Addr) (key []byte, ok bool)
AuthHandler func(username string, realm string, srcAddr net.Addr) (key []byte, ok bool)
QuotaHandler func(username string, realm string, srcAddr net.Addr) (ok bool)

Log logging.LeveledLogger
Realm string
Expand Down
4 changes: 4 additions & 0 deletions internal/server/turn.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ func handleAllocateRequest(r Request, m *stun.Message) error {
// server is free to define this allocation quota any way it wishes,
// but SHOULD define it based on the username used to authenticate
// the request, and not on the client's transport address.
if r.QuotaHandler != nil && !r.QuotaHandler(usernameAttr.String(), realmAttr.String(), r.SrcAddr) {
quotaReachedMsg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeAllocQuotaReached})
return buildAndSend(r.Conn, r.SrcAddr, quotaReachedMsg...)
}

// 8. Also at any point, the server MAY choose to reject the request
// with a 300 (Try Alternate) error if it wishes to redirect the
Expand Down
3 changes: 3 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
type Server struct {
log logging.LeveledLogger
authHandler AuthHandler
quotaHandler QuotaHandler
realm string
channelBindTimeout time.Duration
nonceHash *server.NonceHash
Expand Down Expand Up @@ -61,6 +62,7 @@ func NewServer(config ServerConfig) (*Server, error) {
s := &Server{
log: loggerFactory.NewLogger("turn"),
authHandler: config.AuthHandler,
quotaHandler: config.QuotaHandler,
realm: config.Realm,
channelBindTimeout: config.ChannelBindTimeout,
packetConnConfigs: config.PacketConnConfigs,
Expand Down Expand Up @@ -224,6 +226,7 @@ func (s *Server) readLoop(p net.PacketConn, allocationManager *allocation.Manage
Buff: buf[:n],
Log: s.log,
AuthHandler: s.authHandler,
QuotaHandler: s.quotaHandler,
Realm: s.realm,
AllocationManager: allocationManager,
ChannelBindTimeout: s.channelBindTimeout,
Expand Down
6 changes: 6 additions & 0 deletions server_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ func genericEventHandler(handlers EventHandlers) allocation.EventHandler {
}
}

// QuotaHandler is a callback allows allocations to be rejected when a per-user quota is exceeded. If the callback returns true the allocation request is accepted, otherwise it is rejected and a 486 (Allocation Quota Reached) error is returned to the user.
type QuotaHandler func(username, realm string, srcAddr net.Addr) (ok bool)

// ServerConfig configures the Pion TURN Server
type ServerConfig struct {
// PacketConnConfigs and ListenerConfigs are a list of all the turn listeners
Expand All @@ -197,6 +200,9 @@ type ServerConfig struct {
// AuthHandler is a callback used to handle incoming auth requests, allowing users to customize Pion TURN with custom behavior
AuthHandler AuthHandler

// AuthHandler is a callback used to handle incoming auth requests, allowing users to customize Pion TURN with custom behavior
QuotaHandler QuotaHandler

// EventHandlers is a set of callbacks for tracking allocation lifecycle.
EventHandlers EventHandlers

Expand Down
52 changes: 52 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,58 @@ func TestSTUNOnly(t *testing.T) {
assert.Equal(t, err.Error(), "Allocate error response (error 400: )")
}

func TestQuotaReached(t *testing.T) {
serverAddr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:3478")
assert.NoError(t, err)

serverConn, err := net.ListenPacket(serverAddr.Network(), serverAddr.String())
assert.NoError(t, err)

defer serverConn.Close() //nolint:errcheck

credMap := map[string][]byte{"user": GenerateAuthKey("user", "pion.ly", "pass")}
server, err := NewServer(ServerConfig{
AuthHandler: func(username, _ string, _ net.Addr) (key []byte, ok bool) {
if pw, ok := credMap[username]; ok {
return pw, true
}
return nil, false
},
QuotaHandler: func(_, _ string, _ net.Addr) (ok bool) { return false },
Realm: "pion.ly",
PacketConnConfigs: []PacketConnConfig{{
PacketConn: serverConn,
RelayAddressGenerator: &RelayAddressGeneratorStatic{
RelayAddress: net.ParseIP("127.0.0.1"),
Address: "0.0.0.0",
},
}},
LoggerFactory: logging.NewDefaultLoggerFactory(),
})
assert.NoError(t, err)

defer server.Close() //nolint:errcheck

conn, err := net.ListenPacket("udp4", "0.0.0.0:0")
assert.NoError(t, err)

client, err := NewClient(&ClientConfig{
Conn: conn,
STUNServerAddr: "127.0.0.1:3478",
TURNServerAddr: "127.0.0.1:3478",
Username: "user",
Password: "pass",
Realm: "pion.ly",
LoggerFactory: logging.NewDefaultLoggerFactory(),
})
assert.NoError(t, err)
assert.NoError(t, client.Listen())
defer client.Close()

_, err = client.Allocate()
assert.Equal(t, err.Error(), "Allocate error response (error 486: )")
}

func RunBenchmarkServer(b *testing.B, clientNum int) {
loggerFactory := logging.NewDefaultLoggerFactory()
credMap := map[string][]byte{
Expand Down

0 comments on commit 14ea578

Please sign in to comment.