Skip to content

Commit

Permalink
a better fix for #21; track ncs in a bitset which allows for detectin…
Browse files Browse the repository at this point in the history
…g replays even if ncs arrive out of order
  • Loading branch information
Kevin Manley committed Jan 8, 2016
1 parent 538147e commit e837317
Show file tree
Hide file tree
Showing 4 changed files with 260 additions and 54 deletions.
72 changes: 72 additions & 0 deletions bitset.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Package bitset implments a memory efficient bit array of booleans
// Adapted from https://github.com/lazybeaver/bitset

package auth

import "fmt"

type BitSet struct {
bits []uint8
size uint64
}

const (
bitMaskZero = uint8(0)
bitMaskOnes = uint8((1 << 8) - 1)
)

var (
bitMasks = [...]uint8{0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80}
)

func (b *BitSet) getPositionAndMask(index uint64) (uint64, uint8) {
if index < 0 || index >= b.size {
panic(fmt.Errorf("BitSet index (%d) out of bounds (size: %d)", index, b.size))
}
position := index >> 3
mask := bitMasks[index%8]
return position, mask
}

func (b *BitSet) Init(size uint64) {
b.bits = make([]uint8, (size+7)/8)
b.size = size
}

func (b *BitSet) Size() uint64 {
return b.size
}

func (b *BitSet) Get(index uint64) bool {
position, mask := b.getPositionAndMask(index)
return (b.bits[position] & mask) != 0
}

func (b *BitSet) Set(index uint64) {
position, mask := b.getPositionAndMask(index)
b.bits[position] |= mask
}

func (b *BitSet) Clear(index uint64) {
position, mask := b.getPositionAndMask(index)
b.bits[position] &^= mask
}

func (b *BitSet) String() string {
value := make([]byte, b.size)
var i uint64
for i = 0; i < b.size; i++ {
if b.Get(i) {
value[i] = '1'
} else {
value[i] = '0'
}
}
return string(value)
}

func NewBitSet(size uint64) *BitSet {
b := &BitSet{}
b.Init(size)
return b
}
79 changes: 79 additions & 0 deletions bitset_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package auth

import (
"testing"
)

func TestNew(t *testing.T) {
var size uint64 = 101
bs := NewBitSet(size)
if bs.Size() != size {
t.Errorf("Unexpected initialization failure")
}
var i uint64
for i = 0; i < size; i++ {
if bs.Get(i) {
t.Errorf("Newly initialized bitset cannot have true values")
}
}
}

func TestGet(t *testing.T) {
bs := NewBitSet(2)
bs.Set(0)
bs.Clear(1)
if bs.Get(0) != true {
t.Errorf("Actual: false | Expected: true")
}
if bs.Get(1) != false {
t.Errorf("Actual: true | Expected: false")
}
}

func TestSet(t *testing.T) {
bs := NewBitSet(10)
bs.Set(2)
bs.Set(3)
bs.Set(5)
bs.Set(7)
actual := bs.String()
expected := "0011010100"
if actual != expected {
t.Errorf("Actual: %s | Expected: %s", actual, expected)
}
}

func TestClear(t *testing.T) {
bs := NewBitSet(10)
var i uint64
for i = 0; i < 10; i++ {
bs.Set(i)
}
bs.Clear(0)
bs.Clear(3)
bs.Clear(6)
bs.Clear(9)
actual := bs.String()
expected := "0110110110"
if actual != expected {
t.Errorf("Actual: %s | Expected: %s", actual, expected)
}
}

func BenchmarkGet(b *testing.B) {
bn := uint64(b.N)
bs := NewBitSet(bn)
var i uint64
for i = 0; i < bn; i++ {
_ = bs.Get(i)
}
}

func BenchmarkSet(b *testing.B) {
bn := uint64(b.N)
bs := NewBitSet(bn)
var i uint64
for i = 0; i < bn; i++ {
bs.Set(i)
}
}
82 changes: 55 additions & 27 deletions digest.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,17 @@ import (
"golang.org/x/net/context"
)

const DefaultNcCacheSize = 65536

type digest_client struct {
nc uint64
/*
ncs_seen is a bitset used to record the nc values we've seen for a given nonce.
This allows us to identify and deny replay attacks without relying on nc values
always increasing. That's important since in practice a client's use of multiple
server connections, a hierarchy of proxies, and AJAX can cause nc values to arrive
out of order (See https://github.com/abbot/go-http-auth/issues/21)
*/
ncs_seen *BitSet
last_seen int64
}

Expand All @@ -25,7 +34,7 @@ type DigestAuth struct {
Opaque string
Secrets SecretProvider
PlainTextSecrets bool
IgnoreNonceCount bool
NcCacheSize uint64 // The max number of nc values we remember before issuing a new nonce

/*
Approximate size of Client's Cache. When actual number of
Expand Down Expand Up @@ -81,18 +90,21 @@ func (a *DigestAuth) Purge(count int) {
http.Handler for DigestAuth which initiates the authentication process
(or requires reauthentication).
*/
func (a *DigestAuth) RequireAuth(w http.ResponseWriter, r *http.Request) {
func (a *DigestAuth) RequireAuth(w http.ResponseWriter, r *http.Request, stale bool) {
a.mutex.Lock()
defer a.mutex.Unlock()

if len(a.clients) > a.ClientCacheSize+a.ClientCacheTolerance {
a.Purge(a.ClientCacheTolerance * 2)
}
nonce := RandomKey()
a.clients[nonce] = &digest_client{nc: 0, last_seen: time.Now().UnixNano()}
w.Header().Set(AuthenticateHeaderName(a.IsProxy),
fmt.Sprintf(`Digest realm="%s", nonce="%s", opaque="%s", algorithm="MD5", qop="auth"`,
a.Realm, nonce, a.Opaque))
a.clients[nonce] = &digest_client{ncs_seen: NewBitSet(a.NcCacheSize),
last_seen: time.Now().UnixNano()}
value := fmt.Sprintf(`Digest realm="%s", nonce="%s", opaque="%s", algorithm="MD5", qop="auth"`, a.Realm, nonce, a.Opaque)
if stale {
value += ", stale=true"
}
w.Header().Set(AuthenticateHeaderName(a.IsProxy), value)
http.Error(w, UnauthorizedStatusText(a.IsProxy), UnauthorizedStatusCode(a.IsProxy))
}

Expand All @@ -111,16 +123,18 @@ func (a *DigestAuth) DigestAuthParams(r *http.Request) map[string]string {
}

/*
Check if request contains valid authentication data. Returns a pair
of username, authinfo where username is the name of the authenticated
user or an empty string and authinfo is the contents for the optional
Authentication-Info response header.
Check if request contains valid authentication data. Returns a triplet
of username, authinfo, stale where username is the name of the authenticated
user or an empty string, authinfo is the contents for the optional Authentication-Info
response header, and stale indicates whether the server-returned Authenticate header
should specify stale=true (see https://www.ietf.org/rfc/rfc2617.txt Section 3.3)
*/
func (da *DigestAuth) CheckAuth(r *http.Request) (username string, authinfo *string) {
func (da *DigestAuth) CheckAuth(r *http.Request) (username string, authinfo *string, stale bool) {
da.mutex.Lock()
defer da.mutex.Unlock()
username = ""
authinfo = nil
stale = false
auth := da.DigestAuthParams(r)
if auth == nil || da.Opaque != auth["opaque"] || auth["algorithm"] != "MD5" || auth["qop"] != "auth" {
return
Expand Down Expand Up @@ -182,21 +196,30 @@ func (da *DigestAuth) CheckAuth(r *http.Request) (username string, authinfo *str
return
}

if client, ok := da.clients[auth["nonce"]]; !ok {
client, ok := da.clients[auth["nonce"]]
if !ok {
stale = true
return
}

// Check the nonce-count
if nc >= client.ncs_seen.Size() {
// nc exceeds the size of our bitset. We can just treat this the
// same as a stale nonce
stale = true
return
} else if client.ncs_seen.Get(nc) {
// We've already seen this nc! Possible replay attack!
return
} else {
if client.nc != 0 && client.nc >= nc && !da.IgnoreNonceCount {
return
}
client.nc = nc
client.last_seen = time.Now().UnixNano()
}
client.ncs_seen.Set(nc)
client.last_seen = time.Now().UnixNano()

resp_HA2 := H(":" + auth["uri"])
rspauth := H(strings.Join([]string{HA1, auth["nonce"], auth["nc"], auth["cnonce"], auth["qop"], resp_HA2}, ":"))

info := fmt.Sprintf(`qop="auth", rspauth="%s", cnonce="%s", nc="%s"`, rspauth, auth["cnonce"], auth["nc"])
return auth["username"], &info
return auth["username"], &info, stale
}

/*
Expand All @@ -216,8 +239,8 @@ const DefaultClientCacheTolerance = 100
*/
func (a *DigestAuth) Wrap(wrapped AuthenticatedHandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if username, authinfo := a.CheckAuth(r); username == "" {
a.RequireAuth(w, r)
if username, authinfo, stale := a.CheckAuth(r); username == "" {
a.RequireAuth(w, r, stale)
} else {
ar := &AuthenticatedRequest{Request: *r, Username: username}
if authinfo != nil {
Expand All @@ -242,7 +265,7 @@ func (a *DigestAuth) JustCheck(wrapped http.HandlerFunc) http.HandlerFunc {

// NewContext returns a context carrying authentication information for the request.
func (a *DigestAuth) NewContext(ctx context.Context, r *http.Request) context.Context {
username, authinfo := a.CheckAuth(r)
username, authinfo, stale := a.CheckAuth(r)
info := &Info{Username: username, ResponseHeaders: make(http.Header)}
if username != "" {
info.Authenticated = true
Expand All @@ -253,10 +276,14 @@ func (a *DigestAuth) NewContext(ctx context.Context, r *http.Request) context.Co
a.Purge(a.ClientCacheTolerance * 2)
}
nonce := RandomKey()
a.clients[nonce] = &digest_client{nc: 0, last_seen: time.Now().UnixNano()}
info.ResponseHeaders.Set(AuthenticateHeaderName(a.IsProxy),
fmt.Sprintf(`Digest realm="%s", nonce="%s", opaque="%s", algorithm="MD5", qop="auth"`,
a.Realm, nonce, a.Opaque))
a.clients[nonce] = &digest_client{ncs_seen: NewBitSet(a.NcCacheSize),
last_seen: time.Now().UnixNano()}
value := fmt.Sprintf(`Digest realm="%s", nonce="%s", opaque="%s", algorithm="MD5", qop="auth"`,
a.Realm, nonce, a.Opaque)
if stale {
value += ", stale=true"
}
info.ResponseHeaders.Set(AuthenticateHeaderName(a.IsProxy), value)
}
return context.WithValue(ctx, infoKey, info)
}
Expand All @@ -267,6 +294,7 @@ func NewDigestAuthenticator(realm string, secrets SecretProvider) *DigestAuth {
Realm: realm,
Secrets: secrets,
PlainTextSecrets: false,
NcCacheSize: DefaultNcCacheSize,
ClientCacheSize: DefaultClientCacheSize,
ClientCacheTolerance: DefaultClientCacheTolerance,
clients: map[string]*digest_client{}}
Expand Down
Loading

0 comments on commit e837317

Please sign in to comment.