Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(jwk.Set).Clone, more robust test, fix another off by one #415

Merged
merged 3 commits into from
Jul 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion Changes
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
Changes
=======

v1.2.4
[Bug fixes]
* We had the same off-by-one in another place and jumped the gun on
releasing a new version. At least we were making mistakes uniformally :/
`(jwk.Set).Remove` should finally be fixed.

[New features]
* `(jwk.Set).Clone()` has been added.

v1.2.3 15 Jul 2021
[Buf fixes]
[Bug fixes]
* jwk.Set incorrectly removed 2 elements instead of one.

[Miscellaneous]
Expand Down
3 changes: 3 additions & 0 deletions jwk/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ type Set interface {

// Iterate creates an iterator to iterate through all keys in the set.
Iterate(context.Context) KeyIterator

// Clone create a new set with identical keys. Keys themselves are not cloned.
Clone() (Set, error)
}

type set struct {
Expand Down
75 changes: 52 additions & 23 deletions jwk/jwk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1570,42 +1570,71 @@ func TestTypedFields(t *testing.T) {
}

func TestGH412(t *testing.T) {
set := jwk.NewSet()
base := jwk.NewSet()

const max = 5
kids := make(map[string]struct{})
for i := 0; i < max; i++ {
k, err := jwxtest.GenerateRsaJwk()
if !assert.NoError(t, err, `jwxttest.GenerateRsaJwk() should succeed`) {
return
}

k.Set(jwk.KeyIDKey, strconv.Itoa(i))
set.Add(k)
kid := "key-" + strconv.Itoa(i)
k.Set(jwk.KeyIDKey, kid)
base.Add(k)
kids[kid] = struct{}{}
}

if !assert.Equal(t, max, set.Len(), `set.Len should be %d`, max) {
return
}
for i := 0; i < max; i++ {
idx := i
currentKid := "key-" + strconv.Itoa(i)
t.Run(fmt.Sprintf("Remove at position %d", i), func(t *testing.T) {
set, err := base.Clone()
if !assert.NoError(t, err, `base.Clone() should succeed`) {
return
}

k, ok := set.Get(max / 2)
if !assert.True(t, ok, `set.Get should succeed`) {
return
}
if !assert.Equal(t, max, set.Len(), `set.Len should be %d`, max) {
return
}

if !assert.True(t, set.Remove(k), `set.Remove should succeed`) {
return
}
k, ok := set.Get(idx)
if !assert.True(t, ok, `set.Get should succeed`) {
return
}

if !assert.Equal(t, max-1, set.Len(), `set.Len should be %d`, max-1) {
return
}
if !assert.True(t, set.Remove(k), `set.Remove should succeed`) {
return
}
t.Logf("deleted key %s", k.KeyID())

ctx := context.Background()
for iter := set.Iterate(ctx); iter.Next(ctx); {
pair := iter.Pair()
key := pair.Value.(jwk.Key)
if !assert.NotEqual(t, k.KeyID(), key.KeyID(), `key id should not match`) {
return
}
if !assert.Equal(t, max-1, set.Len(), `set.Len should be %d`, max-1) {
return
}

expected := make(map[string]struct{})
for k := range kids {
if k == currentKid {
continue
}
expected[k] = struct{}{}
}

ctx := context.Background()
for iter := set.Iterate(ctx); iter.Next(ctx); {
pair := iter.Pair()
key := pair.Value.(jwk.Key)
if !assert.NotEqual(t, k.KeyID(), key.KeyID(), `key id should not match`) {
return
}
t.Logf("%s found", key.KeyID())
delete(expected, key.KeyID())
}

if !assert.Len(t, expected, 0, `expected map should be empty`) {
return
}
})
}
}
18 changes: 16 additions & 2 deletions jwk/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func (s *set) Get(idx int) (Key, bool) {
s.mu.RLock()
defer s.mu.RUnlock()

if idx >= 0 && idx < s.Len() {
if idx >= 0 && idx < len(s.keys) {
return s.keys[idx], true
}
return nil, false
Expand Down Expand Up @@ -69,7 +69,7 @@ func (s *set) Remove(key Key) bool {
case 0:
s.keys = s.keys[1:]
case len(s.keys) - 1:
s.keys = s.keys[:i-1]
s.keys = s.keys[:i]
default:
s.keys = append(s.keys[:i], s.keys[i+1:]...)
}
Expand Down Expand Up @@ -192,3 +192,17 @@ func (s *set) SetDecodeCtx(dc DecodeCtx) {
defer s.mu.Unlock()
s.dc = dc
}

func (s *set) Clone() (Set, error) {
s2 := &set{}

s.mu.RLock()
defer s.mu.RUnlock()

s2.keys = make([]Key, len(s.keys))

for i := 0; i < len(s.keys); i++ {
s2.keys[i] = s.keys[i]
}
return s2, nil
}