diff --git a/Changes b/Changes index 82b1354ae..ebab0769c 100644 --- a/Changes +++ b/Changes @@ -1,8 +1,17 @@ Changes ======= +v1.2.4 +[Bug fixes] + * We had the same off-by-one in another place and jumped the gun on + releasing a new version. At least we were making mistakes uniformally :/ + `(jwk.Set).Remove` should finally be fixed. + +[New features] + * `(jwk.Set).Clone()` has been added. + v1.2.3 15 Jul 2021 -[Buf fixes] +[Bug fixes] * jwk.Set incorrectly removed 2 elements instead of one. [Miscellaneous] diff --git a/jwk/interface.go b/jwk/interface.go index 90ee48430..54d71496d 100644 --- a/jwk/interface.go +++ b/jwk/interface.go @@ -78,6 +78,9 @@ type Set interface { // Iterate creates an iterator to iterate through all keys in the set. Iterate(context.Context) KeyIterator + + // Clone create a new set with identical keys. Keys themselves are not cloned. + Clone() (Set, error) } type set struct { diff --git a/jwk/jwk_test.go b/jwk/jwk_test.go index 7ac028ba6..4aee0dc61 100644 --- a/jwk/jwk_test.go +++ b/jwk/jwk_test.go @@ -1570,42 +1570,71 @@ func TestTypedFields(t *testing.T) { } func TestGH412(t *testing.T) { - set := jwk.NewSet() + base := jwk.NewSet() const max = 5 + kids := make(map[string]struct{}) for i := 0; i < max; i++ { k, err := jwxtest.GenerateRsaJwk() if !assert.NoError(t, err, `jwxttest.GenerateRsaJwk() should succeed`) { return } - k.Set(jwk.KeyIDKey, strconv.Itoa(i)) - set.Add(k) + kid := "key-" + strconv.Itoa(i) + k.Set(jwk.KeyIDKey, kid) + base.Add(k) + kids[kid] = struct{}{} } - if !assert.Equal(t, max, set.Len(), `set.Len should be %d`, max) { - return - } + for i := 0; i < max; i++ { + idx := i + currentKid := "key-" + strconv.Itoa(i) + t.Run(fmt.Sprintf("Remove at position %d", i), func(t *testing.T) { + set, err := base.Clone() + if !assert.NoError(t, err, `base.Clone() should succeed`) { + return + } - k, ok := set.Get(max / 2) - if !assert.True(t, ok, `set.Get should succeed`) { - return - } + if !assert.Equal(t, max, set.Len(), `set.Len should be %d`, max) { + return + } - if !assert.True(t, set.Remove(k), `set.Remove should succeed`) { - return - } + k, ok := set.Get(idx) + if !assert.True(t, ok, `set.Get should succeed`) { + return + } - if !assert.Equal(t, max-1, set.Len(), `set.Len should be %d`, max-1) { - return - } + if !assert.True(t, set.Remove(k), `set.Remove should succeed`) { + return + } + t.Logf("deleted key %s", k.KeyID()) - ctx := context.Background() - for iter := set.Iterate(ctx); iter.Next(ctx); { - pair := iter.Pair() - key := pair.Value.(jwk.Key) - if !assert.NotEqual(t, k.KeyID(), key.KeyID(), `key id should not match`) { - return - } + if !assert.Equal(t, max-1, set.Len(), `set.Len should be %d`, max-1) { + return + } + + expected := make(map[string]struct{}) + for k := range kids { + if k == currentKid { + continue + } + expected[k] = struct{}{} + } + + ctx := context.Background() + for iter := set.Iterate(ctx); iter.Next(ctx); { + pair := iter.Pair() + key := pair.Value.(jwk.Key) + if !assert.NotEqual(t, k.KeyID(), key.KeyID(), `key id should not match`) { + return + } + t.Logf("%s found", key.KeyID()) + delete(expected, key.KeyID()) + } + + if !assert.Len(t, expected, 0, `expected map should be empty`) { + return + } + }) } } diff --git a/jwk/set.go b/jwk/set.go index 547393a7b..caea02f71 100644 --- a/jwk/set.go +++ b/jwk/set.go @@ -18,7 +18,7 @@ func (s *set) Get(idx int) (Key, bool) { s.mu.RLock() defer s.mu.RUnlock() - if idx >= 0 && idx < s.Len() { + if idx >= 0 && idx < len(s.keys) { return s.keys[idx], true } return nil, false @@ -69,7 +69,7 @@ func (s *set) Remove(key Key) bool { case 0: s.keys = s.keys[1:] case len(s.keys) - 1: - s.keys = s.keys[:i-1] + s.keys = s.keys[:i] default: s.keys = append(s.keys[:i], s.keys[i+1:]...) } @@ -192,3 +192,17 @@ func (s *set) SetDecodeCtx(dc DecodeCtx) { defer s.mu.Unlock() s.dc = dc } + +func (s *set) Clone() (Set, error) { + s2 := &set{} + + s.mu.RLock() + defer s.mu.RUnlock() + + s2.keys = make([]Key, len(s.keys)) + + for i := 0; i < len(s.keys); i++ { + s2.keys[i] = s.keys[i] + } + return s2, nil +}