diff --git a/identity/handler.go b/identity/handler.go index c4f90cec7619..58578d56e72c 100644 --- a/identity/handler.go +++ b/identity/handler.go @@ -11,6 +11,8 @@ import ( "strings" "time" + "github.com/gofrs/uuid" + "github.com/ory/x/crdbx" "github.com/ory/x/pagination/keysetpagination" @@ -193,6 +195,7 @@ type listIdentitiesParameters struct { // default: errorGeneric func (h *Handler) list(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { includeCredentials := r.URL.Query()["include_credential"] + var err error var declassify []CredentialsType for _, v := range includeCredentials { tc, ok := ParseCredentialsType(v) @@ -204,17 +207,24 @@ func (h *Handler) list(w http.ResponseWriter, r *http.Request, _ httprouter.Para } } - var ( - err error - params = ListIdentityParameters{ - Expand: ExpandDefault, - IdsFilter: r.URL.Query()["ids"], - CredentialsIdentifier: r.URL.Query().Get("credentials_identifier"), - CredentialsIdentifierSimilar: r.URL.Query().Get("preview_credentials_identifier_similar"), - ConsistencyLevel: crdbx.ConsistencyLevelFromRequest(r), - DeclassifyCredentials: declassify, + var idsFilter []uuid.UUID + for _, v := range r.URL.Query()["ids"] { + id, err := uuid.FromString(v) + if err != nil { + h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Invalid UUID value `%s` for parameter `ids`.", v))) + return } - ) + idsFilter = append(idsFilter, id) + } + + params := ListIdentityParameters{ + Expand: ExpandDefault, + IdsFilter: idsFilter, + CredentialsIdentifier: r.URL.Query().Get("credentials_identifier"), + CredentialsIdentifierSimilar: r.URL.Query().Get("preview_credentials_identifier_similar"), + ConsistencyLevel: crdbx.ConsistencyLevelFromRequest(r), + DeclassifyCredentials: declassify, + } if params.CredentialsIdentifier != "" && params.CredentialsIdentifierSimilar != "" { h.r.Writer().WriteError(w, r, herodot.ErrBadRequest.WithReason("Cannot pass both credentials_identifier and preview_credentials_identifier_similar.")) return diff --git a/identity/handler_test.go b/identity/handler_test.go index 33d6a46adf87..12bbde6cfa80 100644 --- a/identity/handler_test.go +++ b/identity/handler_test.go @@ -372,18 +372,23 @@ func TestHandler(t *testing.T) { require.Equal(t, len(ids), identitiesAmount) }) - t.Run("case= list few identities", func(t *testing.T) { + t.Run("case=list few identities", func(t *testing.T) { url := "/identities?ids=" + ids[0].String() for i := 1; i < listAmount; i++ { url += "&ids=" + ids[i].String() } - res := get(t, adminTS, url, 200) + res := get(t, adminTS, url, http.StatusOK) identities := res.Array() require.Equal(t, len(identities), listAmount) }) }) + t.Run("case=malformed ids should return an error", func(t *testing.T) { + res := get(t, adminTS, "/identities?ids=not-a-uuid", http.StatusBadRequest) + assert.Contains(t, res.Get("error.reason").String(), "Invalid UUID value `not-a-uuid` for parameter `ids`.", "%s", res.Raw) + }) + t.Run("suite=create and update", func(t *testing.T) { var i identity.Identity createOidcIdentity := func(t *testing.T, identifier, accessToken, refreshToken, idToken string, encrypt bool) string { diff --git a/identity/pool.go b/identity/pool.go index 86559f0a8a3f..30a7308245b4 100644 --- a/identity/pool.go +++ b/identity/pool.go @@ -18,7 +18,7 @@ import ( type ( ListIdentityParameters struct { Expand Expandables - IdsFilter []string + IdsFilter []uuid.UUID CredentialsIdentifier string CredentialsIdentifierSimilar string DeclassifyCredentials []CredentialsType diff --git a/identity/test/pool.go b/identity/test/pool.go index 20aa96462eb2..4d9f4c440910 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -676,10 +676,7 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager, }) t.Run("list some using ids filter", func(t *testing.T) { - var filterIds []string - for _, id := range createdIDs[:2] { - filterIds = append(filterIds, id.String()) - } + filterIds := createdIDs[:2] is, _, err := p.ListIdentities(ctx, identity.ListIdentityParameters{Expand: identity.ExpandDefault, IdsFilter: filterIds}) require.NoError(t, err)