Skip to content

Commit

Permalink
(jwk.Set).Clone, more robust test, fix another off by one (#415)
Browse files Browse the repository at this point in the history
* (jwk.Set).Clone, more robust test, fix another off by one

* make tests even more robust

* Update Changes
  • Loading branch information
lestrrat committed Jul 15, 2021
1 parent 5e2de6e commit d99f783
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 26 deletions.
11 changes: 10 additions & 1 deletion Changes
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
Changes
=======

v1.2.4
[Bug fixes]
* We had the same off-by-one in another place and jumped the gun on
releasing a new version. At least we were making mistakes uniformally :/
`(jwk.Set).Remove` should finally be fixed.

[New features]
* `(jwk.Set).Clone()` has been added.

v1.2.3 15 Jul 2021
[Buf fixes]
[Bug fixes]
* jwk.Set incorrectly removed 2 elements instead of one.

[Miscellaneous]
Expand Down
3 changes: 3 additions & 0 deletions jwk/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ type Set interface {

// Iterate creates an iterator to iterate through all keys in the set.
Iterate(context.Context) KeyIterator

// Clone create a new set with identical keys. Keys themselves are not cloned.
Clone() (Set, error)
}

type set struct {
Expand Down
75 changes: 52 additions & 23 deletions jwk/jwk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1570,42 +1570,71 @@ func TestTypedFields(t *testing.T) {
}

func TestGH412(t *testing.T) {
set := jwk.NewSet()
base := jwk.NewSet()

const max = 5
kids := make(map[string]struct{})
for i := 0; i < max; i++ {
k, err := jwxtest.GenerateRsaJwk()
if !assert.NoError(t, err, `jwxttest.GenerateRsaJwk() should succeed`) {
return
}

k.Set(jwk.KeyIDKey, strconv.Itoa(i))
set.Add(k)
kid := "key-" + strconv.Itoa(i)
k.Set(jwk.KeyIDKey, kid)
base.Add(k)
kids[kid] = struct{}{}
}

if !assert.Equal(t, max, set.Len(), `set.Len should be %d`, max) {
return
}
for i := 0; i < max; i++ {
idx := i
currentKid := "key-" + strconv.Itoa(i)
t.Run(fmt.Sprintf("Remove at position %d", i), func(t *testing.T) {
set, err := base.Clone()
if !assert.NoError(t, err, `base.Clone() should succeed`) {
return
}

k, ok := set.Get(max / 2)
if !assert.True(t, ok, `set.Get should succeed`) {
return
}
if !assert.Equal(t, max, set.Len(), `set.Len should be %d`, max) {
return
}

if !assert.True(t, set.Remove(k), `set.Remove should succeed`) {
return
}
k, ok := set.Get(idx)
if !assert.True(t, ok, `set.Get should succeed`) {
return
}

if !assert.Equal(t, max-1, set.Len(), `set.Len should be %d`, max-1) {
return
}
if !assert.True(t, set.Remove(k), `set.Remove should succeed`) {
return
}
t.Logf("deleted key %s", k.KeyID())

ctx := context.Background()
for iter := set.Iterate(ctx); iter.Next(ctx); {
pair := iter.Pair()
key := pair.Value.(jwk.Key)
if !assert.NotEqual(t, k.KeyID(), key.KeyID(), `key id should not match`) {
return
}
if !assert.Equal(t, max-1, set.Len(), `set.Len should be %d`, max-1) {
return
}

expected := make(map[string]struct{})
for k := range kids {
if k == currentKid {
continue
}
expected[k] = struct{}{}
}

ctx := context.Background()
for iter := set.Iterate(ctx); iter.Next(ctx); {
pair := iter.Pair()
key := pair.Value.(jwk.Key)
if !assert.NotEqual(t, k.KeyID(), key.KeyID(), `key id should not match`) {
return
}
t.Logf("%s found", key.KeyID())
delete(expected, key.KeyID())
}

if !assert.Len(t, expected, 0, `expected map should be empty`) {
return
}
})
}
}
18 changes: 16 additions & 2 deletions jwk/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func (s *set) Get(idx int) (Key, bool) {
s.mu.RLock()
defer s.mu.RUnlock()

if idx >= 0 && idx < s.Len() {
if idx >= 0 && idx < len(s.keys) {
return s.keys[idx], true
}
return nil, false
Expand Down Expand Up @@ -69,7 +69,7 @@ func (s *set) Remove(key Key) bool {
case 0:
s.keys = s.keys[1:]
case len(s.keys) - 1:
s.keys = s.keys[:i-1]
s.keys = s.keys[:i]
default:
s.keys = append(s.keys[:i], s.keys[i+1:]...)
}
Expand Down Expand Up @@ -192,3 +192,17 @@ func (s *set) SetDecodeCtx(dc DecodeCtx) {
defer s.mu.Unlock()
s.dc = dc
}

func (s *set) Clone() (Set, error) {
s2 := &set{}

s.mu.RLock()
defer s.mu.RUnlock()

s2.keys = make([]Key, len(s.keys))

for i := 0; i < len(s.keys); i++ {
s2.keys[i] = s.keys[i]
}
return s2, nil
}

0 comments on commit d99f783

Please sign in to comment.