diff --git a/cmd/start/start.go b/cmd/start/start.go index 21f445cfd60..db9c9afc540 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -442,7 +442,7 @@ func startAPIs( if err := apis.RegisterService(ctx, user_v2.CreateServer(commands, queries, keys.User, keys.IDPConfig, idp.CallbackURL(), idp.SAMLRootURL(), assets.AssetAPI(), permissionCheck)); err != nil { return nil, err } - if err := apis.RegisterService(ctx, session_v2beta.CreateServer(commands, queries)); err != nil { + if err := apis.RegisterService(ctx, session_v2beta.CreateServer(commands, queries, permissionCheck)); err != nil { return nil, err } if err := apis.RegisterService(ctx, settings_v2beta.CreateServer(commands, queries)); err != nil { @@ -454,7 +454,7 @@ func startAPIs( if err := apis.RegisterService(ctx, feature_v2beta.CreateServer(commands, queries)); err != nil { return nil, err } - if err := apis.RegisterService(ctx, session_v2.CreateServer(commands, queries)); err != nil { + if err := apis.RegisterService(ctx, session_v2.CreateServer(commands, queries, permissionCheck)); err != nil { return nil, err } if err := apis.RegisterService(ctx, settings_v2.CreateServer(commands, queries)); err != nil { diff --git a/internal/api/authz/context_mock.go b/internal/api/authz/context_mock.go index 6badf158628..6891030bd30 100644 --- a/internal/api/authz/context_mock.go +++ b/internal/api/authz/context_mock.go @@ -7,6 +7,11 @@ func NewMockContext(instanceID, orgID, userID string) context.Context { return context.WithValue(ctx, instanceKey, &instance{id: instanceID}) } +func NewMockContextWithAgent(instanceID, orgID, userID, agentID string) context.Context { + ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID, AgentID: agentID}) + return context.WithValue(ctx, instanceKey, &instance{id: instanceID}) +} + func NewMockContextWithPermissions(instanceID, orgID, userID string, permissions []string) context.Context { ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID}) ctx = context.WithValue(ctx, instanceKey, &instance{id: instanceID}) diff --git a/internal/api/grpc/session/v2/integration_test/query_test.go b/internal/api/grpc/session/v2/integration_test/query_test.go new file mode 100644 index 00000000000..36e412be230 --- /dev/null +++ b/internal/api/grpc/session/v2/integration_test/query_test.go @@ -0,0 +1,714 @@ +//go:build integration + +package session_test + +import ( + "context" + "testing" + "time" + + "github.com/golang/protobuf/ptypes/timestamp" + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/pkg/grpc/object/v2" + "github.com/zitadel/zitadel/pkg/grpc/session/v2" +) + +func TestServer_GetSession(t *testing.T) { + type args struct { + ctx context.Context + req *session.GetSessionRequest + dep func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 + } + tests := []struct { + name string + args args + want *session.GetSessionResponse + wantFactors []wantFactor + wantExpirationWindow time.Duration + wantErr bool + }{ + { + name: "get session, no id provided", + args: args{ + CTX, + &session.GetSessionRequest{ + SessionId: "", + }, + nil, + }, + wantErr: true, + }, + { + name: "get session, not found", + args: args{ + CTX, + &session.GetSessionRequest{ + SessionId: "unknown", + }, + nil, + }, + wantErr: true, + }, + { + name: "get session, no permission", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(ctx, &session.CreateSessionRequest{}) + require.NoError(t, err) + request.SessionId = resp.SessionId + return resp.GetDetails().GetSequence() + }, + }, + wantErr: true, + }, + { + name: "get session, permission, ok", + args: args{ + CTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(ctx, &session.CreateSessionRequest{}) + require.NoError(t, err) + request.SessionId = resp.SessionId + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + { + name: "get session, token, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(ctx, &session.CreateSessionRequest{}) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + { + name: "get session, user agent, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(ctx, &session.CreateSessionRequest{ + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("fingerPrintID"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{ + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("fingerPrintID"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + { + name: "get session, lifetime, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(ctx, &session.CreateSessionRequest{ + Lifetime: durationpb.New(5 * time.Minute), + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + wantExpirationWindow: 5 * time.Minute, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + { + name: "get session, metadata, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(ctx, &session.CreateSessionRequest{ + Metadata: map[string][]byte{"foo": []byte("bar")}, + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{ + Metadata: map[string][]byte{"foo": []byte("bar")}, + }, + }, + }, + { + name: "get session, user, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(ctx, &session.CreateSessionRequest{ + Checks: &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: User.GetUserId(), + }, + }, + }, + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var sequence uint64 + if tt.args.dep != nil { + sequence = tt.args.dep(CTX, t, tt.args.req) + } + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(tt.args.ctx, time.Minute) + require.EventuallyWithT(t, func(ttt *assert.CollectT) { + got, err := Client.GetSession(tt.args.ctx, tt.args.req) + if tt.wantErr { + assert.Error(ttt, err) + return + } + if !assert.NoError(ttt, err) { + return + } + + tt.want.Session.Id = tt.args.req.SessionId + tt.want.Session.Sequence = sequence + verifySession(ttt, got.GetSession(), tt.want.GetSession(), time.Minute, tt.wantExpirationWindow, User.GetUserId(), tt.wantFactors...) + }, retryDuration, tick) + }) + } +} + +type sessionAttr struct { + ID string + UserID string + UserAgent string + CreationDate *timestamp.Timestamp + ChangeDate *timestamppb.Timestamp + Details *object.Details +} + +type sessionAttrs []*sessionAttr + +func (u sessionAttrs) ids() []string { + ids := make([]string, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return ids +} + +func createSessions(ctx context.Context, t *testing.T, count int, userID string, userAgent string, lifetime *durationpb.Duration, metadata map[string][]byte) sessionAttrs { + infos := make([]*sessionAttr, count) + for i := 0; i < count; i++ { + infos[i] = createSession(ctx, t, userID, userAgent, lifetime, metadata) + } + return infos +} + +func createSession(ctx context.Context, t *testing.T, userID string, userAgent string, lifetime *durationpb.Duration, metadata map[string][]byte) *sessionAttr { + req := &session.CreateSessionRequest{} + if userID != "" { + req.Checks = &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: userID, + }, + }, + } + } + if userAgent != "" { + req.UserAgent = &session.UserAgent{ + FingerprintId: gu.Ptr(userAgent), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + } + } + if lifetime != nil { + req.Lifetime = lifetime + } + if metadata != nil { + req.Metadata = metadata + } + resp, err := Client.CreateSession(ctx, req) + require.NoError(t, err) + return &sessionAttr{ + resp.GetSessionId(), + userID, + userAgent, + resp.GetDetails().GetChangeDate(), + resp.GetDetails().GetChangeDate(), + resp.GetDetails(), + } +} + +func TestServer_ListSessions(t *testing.T) { + type args struct { + ctx context.Context + req *session.ListSessionsRequest + dep func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr + } + tests := []struct { + name string + args args + want *session.ListSessionsResponse + wantFactors []wantFactor + wantExpirationWindow time.Duration + wantErr bool + }{ + { + name: "list sessions, not found", + args: args{ + CTX, + &session.ListSessionsRequest{ + Queries: []*session.SearchQuery{ + {Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{"unknown"}}}}, + }, + }, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + return []*sessionAttr{} + }, + }, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 0, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{}, + }, + }, + { + name: "list sessions, no permission", + args: args{ + UserCTX, + &session.ListSessionsRequest{ + Queries: []*session.SearchQuery{}, + }, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, "", "", nil, nil) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}) + return []*sessionAttr{} + }, + }, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{}, + }, + }, + { + name: "list sessions, permission, ok", + args: args{ + CTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, "", "", nil, nil) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}) + return []*sessionAttr{info} + }, + }, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{{}}, + }, + }, + { + name: "list sessions, full, ok", + args: args{ + CTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, User.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}) + return []*sessionAttr{info} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + { + name: "list sessions, multiple, ok", + args: args{ + CTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + infos := createSessions(ctx, t, 3, User.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: infos.ids()}}}) + return infos + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 3, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + { + name: "list sessions, userid, ok", + args: args{ + CTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + createdUser := createFullUser(ctx) + info := createSession(ctx, t, createdUser.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_UserIdQuery{UserIdQuery: &session.UserIDQuery{Id: createdUser.GetUserId()}}}) + return []*sessionAttr{info} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + { + name: "list sessions, own creator, ok", + args: args{ + CTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, User.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, + &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}, + &session.SearchQuery{Query: &session.SearchQuery_CreatorQuery{CreatorQuery: &session.CreatorQuery{}}}) + return []*sessionAttr{info} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + { + name: "list sessions, creator, ok", + args: args{ + IAMOwnerCTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, User.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, + &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}, + &session.SearchQuery{Query: &session.SearchQuery_CreatorQuery{CreatorQuery: &session.CreatorQuery{Id: gu.Ptr(Instance.Users.Get(integration.UserTypeOrgOwner).ID)}}}) + return []*sessionAttr{info} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + { + name: "list sessions, wrong creator", + args: args{ + IAMOwnerCTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, User.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, + &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}, + &session.SearchQuery{Query: &session.SearchQuery_CreatorQuery{CreatorQuery: &session.CreatorQuery{}}}) + return []*sessionAttr{} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 0, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{}, + }, + }, + { + name: "list sessions, empty creator", + args: args{ + IAMOwnerCTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + request.Queries = append(request.Queries, + &session.SearchQuery{Query: &session.SearchQuery_CreatorQuery{CreatorQuery: &session.CreatorQuery{Id: gu.Ptr("")}}}) + return []*sessionAttr{} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + wantErr: true, + }, + { + name: "list sessions, useragent, ok", + args: args{ + IAMOwnerCTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, User.GetUserId(), "useragent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, + &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}, + &session.SearchQuery{Query: &session.SearchQuery_UserAgentQuery{UserAgentQuery: &session.UserAgentQuery{FingerprintId: gu.Ptr("useragent")}}}) + return []*sessionAttr{info} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("useragent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + { + name: "list sessions, wrong useragent", + args: args{ + IAMOwnerCTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, User.GetUserId(), "useragent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, + &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}, + &session.SearchQuery{Query: &session.SearchQuery_UserAgentQuery{UserAgentQuery: &session.UserAgentQuery{FingerprintId: gu.Ptr("wronguseragent")}}}) + return []*sessionAttr{} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 0, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{}, + }, + }, + { + name: "list sessions, empty useragent", + args: args{ + IAMOwnerCTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + request.Queries = append(request.Queries, + &session.SearchQuery{Query: &session.SearchQuery_UserAgentQuery{UserAgentQuery: &session.UserAgentQuery{FingerprintId: gu.Ptr("")}}}) + return []*sessionAttr{} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + infos := tt.args.dep(CTX, t, tt.args.req) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(tt.args.ctx, time.Minute) + require.EventuallyWithT(t, func(ttt *assert.CollectT) { + got, err := Client.ListSessions(tt.args.ctx, tt.args.req) + if tt.wantErr { + assert.Error(ttt, err) + return + } + if !assert.NoError(ttt, err) { + return + } + + if !assert.Equal(ttt, got.Details.TotalResult, tt.want.Details.TotalResult) || !assert.Len(ttt, got.Sessions, len(tt.want.Sessions)) { + return + } + + for i := range infos { + tt.want.Sessions[i].Id = infos[i].ID + tt.want.Sessions[i].Sequence = infos[i].Details.GetSequence() + tt.want.Sessions[i].CreationDate = infos[i].Details.GetChangeDate() + tt.want.Sessions[i].ChangeDate = infos[i].Details.GetChangeDate() + + verifySession(ttt, got.Sessions[i], tt.want.Sessions[i], time.Minute, tt.wantExpirationWindow, infos[i].UserID, tt.wantFactors...) + } + integration.AssertListDetails(ttt, tt.want, got) + }, retryDuration, tick) + }) + } +} diff --git a/internal/api/grpc/session/v2/integration_test/server_test.go b/internal/api/grpc/session/v2/integration_test/server_test.go new file mode 100644 index 00000000000..70e21460697 --- /dev/null +++ b/internal/api/grpc/session/v2/integration_test/server_test.go @@ -0,0 +1,74 @@ +//go:build integration + +package session_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/pkg/grpc/session/v2" + "github.com/zitadel/zitadel/pkg/grpc/user/v2" +) + +var ( + CTX context.Context + IAMOwnerCTX context.Context + UserCTX context.Context + Instance *integration.Instance + Client session.SessionServiceClient + User *user.AddHumanUserResponse + DeactivatedUser *user.AddHumanUserResponse + LockedUser *user.AddHumanUserResponse +) + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + + Instance = integration.NewInstance(ctx) + Client = Instance.Client.SessionV2 + + CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner) + IAMOwnerCTX = Instance.WithAuthorization(ctx, integration.UserTypeIAMOwner) + UserCTX = Instance.WithAuthorization(ctx, integration.UserTypeNoPermission) + User = createFullUser(CTX) + DeactivatedUser = createDeactivatedUser(CTX) + LockedUser = createLockedUser(CTX) + return m.Run() + }()) +} + +func createFullUser(ctx context.Context) *user.AddHumanUserResponse { + userResp := Instance.CreateHumanUser(ctx) + Instance.Client.UserV2.VerifyEmail(ctx, &user.VerifyEmailRequest{ + UserId: userResp.GetUserId(), + VerificationCode: userResp.GetEmailCode(), + }) + Instance.Client.UserV2.VerifyPhone(ctx, &user.VerifyPhoneRequest{ + UserId: userResp.GetUserId(), + VerificationCode: userResp.GetPhoneCode(), + }) + Instance.SetUserPassword(ctx, userResp.GetUserId(), integration.UserPassword, false) + Instance.RegisterUserPasskey(ctx, userResp.GetUserId()) + return userResp +} + +func createDeactivatedUser(ctx context.Context) *user.AddHumanUserResponse { + userResp := Instance.CreateHumanUser(ctx) + _, err := Instance.Client.UserV2.DeactivateUser(ctx, &user.DeactivateUserRequest{UserId: userResp.GetUserId()}) + logging.OnError(err).Fatal("deactivate human user") + return userResp +} + +func createLockedUser(ctx context.Context) *user.AddHumanUserResponse { + userResp := Instance.CreateHumanUser(ctx) + _, err := Instance.Client.UserV2.LockUser(ctx, &user.LockUserRequest{UserId: userResp.GetUserId()}) + logging.OnError(err).Fatal("lock human user") + return userResp +} diff --git a/internal/api/grpc/session/v2/integration_test/session_test.go b/internal/api/grpc/session/v2/integration_test/session_test.go index ccd08f34712..7622550b15c 100644 --- a/internal/api/grpc/session/v2/integration_test/session_test.go +++ b/internal/api/grpc/session/v2/integration_test/session_test.go @@ -5,7 +5,6 @@ package session_test import ( "context" "fmt" - "os" "testing" "time" @@ -14,7 +13,6 @@ import ( "github.com/pquerna/otp/totp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/zitadel/logging" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -29,63 +27,7 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -var ( - CTX context.Context - IAMOwnerCTX context.Context - Instance *integration.Instance - Client session.SessionServiceClient - User *user.AddHumanUserResponse - DeactivatedUser *user.AddHumanUserResponse - LockedUser *user.AddHumanUserResponse -) - -func TestMain(m *testing.M) { - os.Exit(func() int { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) - defer cancel() - - Instance = integration.NewInstance(ctx) - Client = Instance.Client.SessionV2 - - CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner) - IAMOwnerCTX = Instance.WithAuthorization(ctx, integration.UserTypeIAMOwner) - User = createFullUser(CTX) - DeactivatedUser = createDeactivatedUser(CTX) - LockedUser = createLockedUser(CTX) - return m.Run() - }()) -} - -func createFullUser(ctx context.Context) *user.AddHumanUserResponse { - userResp := Instance.CreateHumanUser(ctx) - Instance.Client.UserV2.VerifyEmail(ctx, &user.VerifyEmailRequest{ - UserId: userResp.GetUserId(), - VerificationCode: userResp.GetEmailCode(), - }) - Instance.Client.UserV2.VerifyPhone(ctx, &user.VerifyPhoneRequest{ - UserId: userResp.GetUserId(), - VerificationCode: userResp.GetPhoneCode(), - }) - Instance.SetUserPassword(ctx, userResp.GetUserId(), integration.UserPassword, false) - Instance.RegisterUserPasskey(ctx, userResp.GetUserId()) - return userResp -} - -func createDeactivatedUser(ctx context.Context) *user.AddHumanUserResponse { - userResp := Instance.CreateHumanUser(ctx) - _, err := Instance.Client.UserV2.DeactivateUser(ctx, &user.DeactivateUserRequest{UserId: userResp.GetUserId()}) - logging.OnError(err).Fatal("deactivate human user") - return userResp -} - -func createLockedUser(ctx context.Context) *user.AddHumanUserResponse { - userResp := Instance.CreateHumanUser(ctx) - _, err := Instance.Client.UserV2.LockUser(ctx, &user.LockUserRequest{UserId: userResp.GetUserId()}) - logging.OnError(err).Fatal("lock human user") - return userResp -} - -func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, window time.Duration, metadata map[string][]byte, userAgent *session.UserAgent, expirationWindow time.Duration, userID string, factors ...wantFactor) *session.Session { +func verifyCurrentSession(t *testing.T, id, token string, sequence uint64, window time.Duration, metadata map[string][]byte, userAgent *session.UserAgent, expirationWindow time.Duration, userID string, factors ...wantFactor) *session.Session { t.Helper() require.NotEmpty(t, id) require.NotEmpty(t, token) @@ -96,15 +38,25 @@ func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, windo }) require.NoError(t, err) s := resp.GetSession() + want := &session.Session{ + Id: id, + Sequence: sequence, + Metadata: metadata, + UserAgent: userAgent, + } + verifySession(t, s, want, window, expirationWindow, userID, factors...) + return s +} - assert.Equal(t, id, s.GetId()) +func verifySession(t assert.TestingT, s *session.Session, want *session.Session, window time.Duration, expirationWindow time.Duration, userID string, factors ...wantFactor) { + assert.Equal(t, want.Id, s.GetId()) assert.WithinRange(t, s.GetCreationDate().AsTime(), time.Now().Add(-window), time.Now().Add(window)) assert.WithinRange(t, s.GetChangeDate().AsTime(), time.Now().Add(-window), time.Now().Add(window)) - assert.Equal(t, sequence, s.GetSequence()) - assert.Equal(t, metadata, s.GetMetadata()) + assert.Equal(t, want.Sequence, s.GetSequence()) + assert.Equal(t, want.Metadata, s.GetMetadata()) - if !proto.Equal(userAgent, s.GetUserAgent()) { - t.Errorf("user agent =\n%v\nwant\n%v", s.GetUserAgent(), userAgent) + if !proto.Equal(want.UserAgent, s.GetUserAgent()) { + t.Errorf("user agent =\n%v\nwant\n%v", s.GetUserAgent(), want.UserAgent) } if expirationWindow == 0 { assert.Nil(t, s.GetExpirationDate()) @@ -113,7 +65,6 @@ func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, windo } verifyFactors(t, s.GetFactors(), window, userID, factors) - return s } type wantFactor int @@ -129,7 +80,7 @@ const ( wantOTPEmailFactor ) -func verifyFactors(t testing.TB, factors *session.Factors, window time.Duration, userID string, want []wantFactor) { +func verifyFactors(t assert.TestingT, factors *session.Factors, window time.Duration, userID string, want []wantFactor) { for _, w := range want { switch w { case wantUserFactor: @@ -194,8 +145,15 @@ func TestServer_CreateSession(t *testing.T) { }, }, { - name: "user agent", + name: "full session", req: &session.CreateSessionRequest{ + Checks: &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: User.GetUserId(), + }, + }, + }, Metadata: map[string][]byte{"foo": []byte("bar")}, UserAgent: &session.UserAgent{ FingerprintId: gu.Ptr("fingerPrintID"), @@ -205,6 +163,7 @@ func TestServer_CreateSession(t *testing.T) { "foo": {Values: []string{"foo", "bar"}}, }, }, + Lifetime: durationpb.New(5 * time.Minute), }, want: &session.CreateSessionResponse{ Details: &object.Details{ @@ -212,14 +171,6 @@ func TestServer_CreateSession(t *testing.T) { ResourceOwner: Instance.ID(), }, }, - wantUserAgent: &session.UserAgent{ - FingerprintId: gu.Ptr("fingerPrintID"), - Ip: gu.Ptr("1.2.3.4"), - Description: gu.Ptr("Description"), - Header: map[string]*session.UserAgent_HeaderValues{ - "foo": {Values: []string{"foo", "bar"}}, - }, - }, }, { name: "negative lifetime", @@ -229,40 +180,6 @@ func TestServer_CreateSession(t *testing.T) { }, wantErr: true, }, - { - name: "lifetime", - req: &session.CreateSessionRequest{ - Metadata: map[string][]byte{"foo": []byte("bar")}, - Lifetime: durationpb.New(5 * time.Minute), - }, - want: &session.CreateSessionResponse{ - Details: &object.Details{ - ChangeDate: timestamppb.Now(), - ResourceOwner: Instance.ID(), - }, - }, - wantExpirationWindow: 5 * time.Minute, - }, - { - name: "with user", - req: &session.CreateSessionRequest{ - Checks: &session.Checks{ - User: &session.CheckUser{ - Search: &session.CheckUser_UserId{ - UserId: User.GetUserId(), - }, - }, - }, - Metadata: map[string][]byte{"foo": []byte("bar")}, - }, - want: &session.CreateSessionResponse{ - Details: &object.Details{ - ChangeDate: timestamppb.Now(), - ResourceOwner: Instance.ID(), - }, - }, - wantFactors: []wantFactor{wantUserFactor}, - }, { name: "deactivated user", req: &session.CreateSessionRequest{ @@ -340,8 +257,6 @@ func TestServer_CreateSession(t *testing.T) { } require.NoError(t, err) integration.AssertDetails(t, tt.want, got) - - verifyCurrentSession(t, got.GetSessionId(), got.GetSessionToken(), got.GetDetails().GetSequence(), time.Minute, tt.req.GetMetadata(), tt.wantUserAgent, tt.wantExpirationWindow, User.GetUserId(), tt.wantFactors...) }) } } @@ -946,21 +861,30 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) { require.NoError(t, err) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("Bearer %s", createResp.GetSessionToken())) - sessionResp, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: createResp.GetSessionId()}) - require.Error(t, err) - require.Nil(t, sessionResp) + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + sessionResp, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: createResp.GetSessionId()}) + if !assert.Error(tt, err) { + return + } + assert.Nil(tt, sessionResp) + }, retryDuration, tick) } func Test_ZITADEL_API_success(t *testing.T) { id, token, _, _ := Instance.CreateVerifiedWebAuthNSession(t, CTX, User.GetUserId()) - ctx := integration.WithAuthorizationToken(context.Background(), token) - sessionResp, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.NoError(t, err) - webAuthN := sessionResp.GetSession().GetFactors().GetWebAuthN() - require.NotNil(t, id, webAuthN.GetVerifiedAt().AsTime()) - require.True(t, webAuthN.GetUserVerified()) + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + sessionResp, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) + if !assert.NoError(tt, err) { + return + } + webAuthN := sessionResp.GetSession().GetFactors().GetWebAuthN() + assert.NotNil(tt, id, webAuthN.GetVerifiedAt().AsTime()) + assert.True(tt, webAuthN.GetUserVerified()) + }, retryDuration, tick) } func Test_ZITADEL_API_session_not_found(t *testing.T) { @@ -968,18 +892,30 @@ func Test_ZITADEL_API_session_not_found(t *testing.T) { // test session token works ctx := integration.WithAuthorizationToken(context.Background(), token) - _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.NoError(t, err) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) + if !assert.NoError(tt, err) { + return + } + }, retryDuration, tick) //terminate the session and test it does not work anymore - _, err = Client.DeleteSession(CTX, &session.DeleteSessionRequest{ + _, err := Client.DeleteSession(CTX, &session.DeleteSessionRequest{ SessionId: id, SessionToken: gu.Ptr(token), }) require.NoError(t, err) + ctx = integration.WithAuthorizationToken(context.Background(), token) - _, err = Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.Error(t, err) + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) + if !assert.Error(tt, err) { + return + } + }, retryDuration, tick) } func Test_ZITADEL_API_session_expired(t *testing.T) { @@ -987,8 +923,13 @@ func Test_ZITADEL_API_session_expired(t *testing.T) { // test session token works ctx := integration.WithAuthorizationToken(context.Background(), token) - _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.NoError(t, err) + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) + if !assert.NoError(tt, err) { + return + } + }, retryDuration, tick) // ensure session expires and does not work anymore time.Sleep(20 * time.Second) diff --git a/internal/api/grpc/session/v2/query.go b/internal/api/grpc/session/v2/query.go new file mode 100644 index 00000000000..73303dd9e86 --- /dev/null +++ b/internal/api/grpc/session/v2/query.go @@ -0,0 +1,262 @@ +package session + +import ( + "context" + "time" + + "github.com/muhlemmer/gu" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/grpc/object/v2" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/zerrors" + objpb "github.com/zitadel/zitadel/pkg/grpc/object" + "github.com/zitadel/zitadel/pkg/grpc/session/v2" +) + +var ( + timestampComparisons = map[objpb.TimestampQueryMethod]query.TimestampComparison{ + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_EQUALS: query.TimestampEquals, + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER: query.TimestampGreater, + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS: query.TimestampGreaterOrEquals, + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS: query.TimestampLess, + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS: query.TimestampLessOrEquals, + } +) + +func (s *Server) GetSession(ctx context.Context, req *session.GetSessionRequest) (*session.GetSessionResponse, error) { + res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken(), s.checkPermission) + if err != nil { + return nil, err + } + return &session.GetSessionResponse{ + Session: sessionToPb(res), + }, nil +} + +func (s *Server) ListSessions(ctx context.Context, req *session.ListSessionsRequest) (*session.ListSessionsResponse, error) { + queries, err := listSessionsRequestToQuery(ctx, req) + if err != nil { + return nil, err + } + sessions, err := s.query.SearchSessions(ctx, queries, s.checkPermission) + if err != nil { + return nil, err + } + return &session.ListSessionsResponse{ + Details: object.ToListDetails(sessions.SearchResponse), + Sessions: sessionsToPb(sessions.Sessions), + }, nil +} + +func listSessionsRequestToQuery(ctx context.Context, req *session.ListSessionsRequest) (*query.SessionsSearchQueries, error) { + offset, limit, asc := object.ListQueryToQuery(req.Query) + queries, err := sessionQueriesToQuery(ctx, req.GetQueries()) + if err != nil { + return nil, err + } + return &query.SessionsSearchQueries{ + SearchRequest: query.SearchRequest{ + Offset: offset, + Limit: limit, + Asc: asc, + SortingColumn: fieldNameToSessionColumn(req.GetSortingColumn()), + }, + Queries: queries, + }, nil +} + +func sessionQueriesToQuery(ctx context.Context, queries []*session.SearchQuery) (_ []query.SearchQuery, err error) { + q := make([]query.SearchQuery, len(queries)) + for i, v := range queries { + q[i], err = sessionQueryToQuery(ctx, v) + if err != nil { + return nil, err + } + } + return q, nil +} + +func sessionQueryToQuery(ctx context.Context, sq *session.SearchQuery) (query.SearchQuery, error) { + switch q := sq.Query.(type) { + case *session.SearchQuery_IdsQuery: + return idsQueryToQuery(q.IdsQuery) + case *session.SearchQuery_UserIdQuery: + return query.NewUserIDSearchQuery(q.UserIdQuery.GetId()) + case *session.SearchQuery_CreationDateQuery: + return creationDateQueryToQuery(q.CreationDateQuery) + case *session.SearchQuery_CreatorQuery: + if q.CreatorQuery != nil && q.CreatorQuery.Id != nil { + if q.CreatorQuery.GetId() != "" { + return query.NewSessionCreatorSearchQuery(q.CreatorQuery.GetId()) + } + } else { + if userID := authz.GetCtxData(ctx).UserID; userID != "" { + return query.NewSessionCreatorSearchQuery(userID) + } + } + return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-x8n24uh", "List.Query.Invalid") + case *session.SearchQuery_UserAgentQuery: + if q.UserAgentQuery != nil && q.UserAgentQuery.FingerprintId != nil { + if *q.UserAgentQuery.FingerprintId != "" { + return query.NewSessionUserAgentFingerprintIDSearchQuery(q.UserAgentQuery.GetFingerprintId()) + } + } else { + if agentID := authz.GetCtxData(ctx).AgentID; agentID != "" { + return query.NewSessionUserAgentFingerprintIDSearchQuery(agentID) + } + } + return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-x8n23uh", "List.Query.Invalid") + default: + return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid") + } +} + +func idsQueryToQuery(q *session.IDsQuery) (query.SearchQuery, error) { + return query.NewSessionIDsSearchQuery(q.Ids) +} + +func creationDateQueryToQuery(q *session.CreationDateQuery) (query.SearchQuery, error) { + comparison := timestampComparisons[q.GetMethod()] + return query.NewCreationDateQuery(q.GetCreationDate().AsTime(), comparison) +} + +func fieldNameToSessionColumn(field session.SessionFieldName) query.Column { + switch field { + case session.SessionFieldName_SESSION_FIELD_NAME_CREATION_DATE: + return query.SessionColumnCreationDate + case session.SessionFieldName_SESSION_FIELD_NAME_UNSPECIFIED: + return query.Column{} + default: + return query.Column{} + } +} + +func sessionsToPb(sessions []*query.Session) []*session.Session { + s := make([]*session.Session, len(sessions)) + for i, session := range sessions { + s[i] = sessionToPb(session) + } + return s +} + +func sessionToPb(s *query.Session) *session.Session { + return &session.Session{ + Id: s.ID, + CreationDate: timestamppb.New(s.CreationDate), + ChangeDate: timestamppb.New(s.ChangeDate), + Sequence: s.Sequence, + Factors: factorsToPb(s), + Metadata: s.Metadata, + UserAgent: userAgentToPb(s.UserAgent), + ExpirationDate: expirationToPb(s.Expiration), + } +} + +func userAgentToPb(ua domain.UserAgent) *session.UserAgent { + if ua.IsEmpty() { + return nil + } + + out := &session.UserAgent{ + FingerprintId: ua.FingerprintID, + Description: ua.Description, + } + if ua.IP != nil { + out.Ip = gu.Ptr(ua.IP.String()) + } + if ua.Header == nil { + return out + } + out.Header = make(map[string]*session.UserAgent_HeaderValues, len(ua.Header)) + for k, v := range ua.Header { + out.Header[k] = &session.UserAgent_HeaderValues{ + Values: v, + } + } + return out +} + +func expirationToPb(expiration time.Time) *timestamppb.Timestamp { + if expiration.IsZero() { + return nil + } + return timestamppb.New(expiration) +} + +func factorsToPb(s *query.Session) *session.Factors { + user := userFactorToPb(s.UserFactor) + if user == nil { + return nil + } + return &session.Factors{ + User: user, + Password: passwordFactorToPb(s.PasswordFactor), + WebAuthN: webAuthNFactorToPb(s.WebAuthNFactor), + Intent: intentFactorToPb(s.IntentFactor), + Totp: totpFactorToPb(s.TOTPFactor), + OtpSms: otpFactorToPb(s.OTPSMSFactor), + OtpEmail: otpFactorToPb(s.OTPEmailFactor), + } +} + +func passwordFactorToPb(factor query.SessionPasswordFactor) *session.PasswordFactor { + if factor.PasswordCheckedAt.IsZero() { + return nil + } + return &session.PasswordFactor{ + VerifiedAt: timestamppb.New(factor.PasswordCheckedAt), + } +} + +func intentFactorToPb(factor query.SessionIntentFactor) *session.IntentFactor { + if factor.IntentCheckedAt.IsZero() { + return nil + } + return &session.IntentFactor{ + VerifiedAt: timestamppb.New(factor.IntentCheckedAt), + } +} + +func webAuthNFactorToPb(factor query.SessionWebAuthNFactor) *session.WebAuthNFactor { + if factor.WebAuthNCheckedAt.IsZero() { + return nil + } + return &session.WebAuthNFactor{ + VerifiedAt: timestamppb.New(factor.WebAuthNCheckedAt), + UserVerified: factor.UserVerified, + } +} + +func totpFactorToPb(factor query.SessionTOTPFactor) *session.TOTPFactor { + if factor.TOTPCheckedAt.IsZero() { + return nil + } + return &session.TOTPFactor{ + VerifiedAt: timestamppb.New(factor.TOTPCheckedAt), + } +} + +func otpFactorToPb(factor query.SessionOTPFactor) *session.OTPFactor { + if factor.OTPCheckedAt.IsZero() { + return nil + } + return &session.OTPFactor{ + VerifiedAt: timestamppb.New(factor.OTPCheckedAt), + } +} + +func userFactorToPb(factor query.SessionUserFactor) *session.UserFactor { + if factor.UserID == "" || factor.UserCheckedAt.IsZero() { + return nil + } + return &session.UserFactor{ + VerifiedAt: timestamppb.New(factor.UserCheckedAt), + Id: factor.UserID, + LoginName: factor.LoginName, + DisplayName: factor.DisplayName, + OrganizationId: factor.ResourceOwner, + } +} diff --git a/internal/api/grpc/session/v2/server.go b/internal/api/grpc/session/v2/server.go index e94336bf47e..ee534cb26c5 100644 --- a/internal/api/grpc/session/v2/server.go +++ b/internal/api/grpc/session/v2/server.go @@ -6,6 +6,7 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/pkg/grpc/session/v2" ) @@ -16,6 +17,8 @@ type Server struct { session.UnimplementedSessionServiceServer command *command.Commands query *query.Queries + + checkPermission domain.PermissionCheck } type Config struct{} @@ -23,10 +26,12 @@ type Config struct{} func CreateServer( command *command.Commands, query *query.Queries, + checkPermission domain.PermissionCheck, ) *Server { return &Server{ - command: command, - query: query, + command: command, + query: query, + checkPermission: checkPermission, } } diff --git a/internal/api/grpc/session/v2/session.go b/internal/api/grpc/session/v2/session.go index aa25fa0ae3f..7562d643501 100644 --- a/internal/api/grpc/session/v2/session.go +++ b/internal/api/grpc/session/v2/session.go @@ -6,56 +6,17 @@ import ( "net/http" "time" - "github.com/muhlemmer/gu" "golang.org/x/text/language" "google.golang.org/protobuf/types/known/structpb" - "google.golang.org/protobuf/types/known/timestamppb" - "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/object/v2" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/zerrors" - objpb "github.com/zitadel/zitadel/pkg/grpc/object" "github.com/zitadel/zitadel/pkg/grpc/session/v2" ) -var ( - timestampComparisons = map[objpb.TimestampQueryMethod]query.TimestampComparison{ - objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_EQUALS: query.TimestampEquals, - objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER: query.TimestampGreater, - objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS: query.TimestampGreaterOrEquals, - objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS: query.TimestampLess, - objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS: query.TimestampLessOrEquals, - } -) - -func (s *Server) GetSession(ctx context.Context, req *session.GetSessionRequest) (*session.GetSessionResponse, error) { - res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken()) - if err != nil { - return nil, err - } - return &session.GetSessionResponse{ - Session: sessionToPb(res), - }, nil -} - -func (s *Server) ListSessions(ctx context.Context, req *session.ListSessionsRequest) (*session.ListSessionsResponse, error) { - queries, err := listSessionsRequestToQuery(ctx, req) - if err != nil { - return nil, err - } - sessions, err := s.query.SearchSessions(ctx, queries) - if err != nil { - return nil, err - } - return &session.ListSessionsResponse{ - Details: object.ToListDetails(sessions.SearchResponse), - Sessions: sessionsToPb(sessions.Sessions), - }, nil -} - func (s *Server) CreateSession(ctx context.Context, req *session.CreateSessionRequest) (*session.CreateSessionResponse, error) { checks, metadata, userAgent, lifetime, err := s.createSessionRequestToCommand(ctx, req) if err != nil { @@ -110,197 +71,6 @@ func (s *Server) DeleteSession(ctx context.Context, req *session.DeleteSessionRe }, nil } -func sessionsToPb(sessions []*query.Session) []*session.Session { - s := make([]*session.Session, len(sessions)) - for i, session := range sessions { - s[i] = sessionToPb(session) - } - return s -} - -func sessionToPb(s *query.Session) *session.Session { - return &session.Session{ - Id: s.ID, - CreationDate: timestamppb.New(s.CreationDate), - ChangeDate: timestamppb.New(s.ChangeDate), - Sequence: s.Sequence, - Factors: factorsToPb(s), - Metadata: s.Metadata, - UserAgent: userAgentToPb(s.UserAgent), - ExpirationDate: expirationToPb(s.Expiration), - } -} - -func userAgentToPb(ua domain.UserAgent) *session.UserAgent { - if ua.IsEmpty() { - return nil - } - - out := &session.UserAgent{ - FingerprintId: ua.FingerprintID, - Description: ua.Description, - } - if ua.IP != nil { - out.Ip = gu.Ptr(ua.IP.String()) - } - if ua.Header == nil { - return out - } - out.Header = make(map[string]*session.UserAgent_HeaderValues, len(ua.Header)) - for k, v := range ua.Header { - out.Header[k] = &session.UserAgent_HeaderValues{ - Values: v, - } - } - return out -} - -func expirationToPb(expiration time.Time) *timestamppb.Timestamp { - if expiration.IsZero() { - return nil - } - return timestamppb.New(expiration) -} - -func factorsToPb(s *query.Session) *session.Factors { - user := userFactorToPb(s.UserFactor) - if user == nil { - return nil - } - return &session.Factors{ - User: user, - Password: passwordFactorToPb(s.PasswordFactor), - WebAuthN: webAuthNFactorToPb(s.WebAuthNFactor), - Intent: intentFactorToPb(s.IntentFactor), - Totp: totpFactorToPb(s.TOTPFactor), - OtpSms: otpFactorToPb(s.OTPSMSFactor), - OtpEmail: otpFactorToPb(s.OTPEmailFactor), - } -} - -func passwordFactorToPb(factor query.SessionPasswordFactor) *session.PasswordFactor { - if factor.PasswordCheckedAt.IsZero() { - return nil - } - return &session.PasswordFactor{ - VerifiedAt: timestamppb.New(factor.PasswordCheckedAt), - } -} - -func intentFactorToPb(factor query.SessionIntentFactor) *session.IntentFactor { - if factor.IntentCheckedAt.IsZero() { - return nil - } - return &session.IntentFactor{ - VerifiedAt: timestamppb.New(factor.IntentCheckedAt), - } -} - -func webAuthNFactorToPb(factor query.SessionWebAuthNFactor) *session.WebAuthNFactor { - if factor.WebAuthNCheckedAt.IsZero() { - return nil - } - return &session.WebAuthNFactor{ - VerifiedAt: timestamppb.New(factor.WebAuthNCheckedAt), - UserVerified: factor.UserVerified, - } -} - -func totpFactorToPb(factor query.SessionTOTPFactor) *session.TOTPFactor { - if factor.TOTPCheckedAt.IsZero() { - return nil - } - return &session.TOTPFactor{ - VerifiedAt: timestamppb.New(factor.TOTPCheckedAt), - } -} - -func otpFactorToPb(factor query.SessionOTPFactor) *session.OTPFactor { - if factor.OTPCheckedAt.IsZero() { - return nil - } - return &session.OTPFactor{ - VerifiedAt: timestamppb.New(factor.OTPCheckedAt), - } -} - -func userFactorToPb(factor query.SessionUserFactor) *session.UserFactor { - if factor.UserID == "" || factor.UserCheckedAt.IsZero() { - return nil - } - return &session.UserFactor{ - VerifiedAt: timestamppb.New(factor.UserCheckedAt), - Id: factor.UserID, - LoginName: factor.LoginName, - DisplayName: factor.DisplayName, - OrganizationId: factor.ResourceOwner, - } -} - -func listSessionsRequestToQuery(ctx context.Context, req *session.ListSessionsRequest) (*query.SessionsSearchQueries, error) { - offset, limit, asc := object.ListQueryToQuery(req.Query) - queries, err := sessionQueriesToQuery(ctx, req.GetQueries()) - if err != nil { - return nil, err - } - return &query.SessionsSearchQueries{ - SearchRequest: query.SearchRequest{ - Offset: offset, - Limit: limit, - Asc: asc, - SortingColumn: fieldNameToSessionColumn(req.GetSortingColumn()), - }, - Queries: queries, - }, nil -} - -func sessionQueriesToQuery(ctx context.Context, queries []*session.SearchQuery) (_ []query.SearchQuery, err error) { - q := make([]query.SearchQuery, len(queries)+1) - for i, v := range queries { - q[i], err = sessionQueryToQuery(v) - if err != nil { - return nil, err - } - } - creatorQuery, err := query.NewSessionCreatorSearchQuery(authz.GetCtxData(ctx).UserID) - if err != nil { - return nil, err - } - q[len(queries)] = creatorQuery - return q, nil -} - -func sessionQueryToQuery(sq *session.SearchQuery) (query.SearchQuery, error) { - switch q := sq.Query.(type) { - case *session.SearchQuery_IdsQuery: - return idsQueryToQuery(q.IdsQuery) - case *session.SearchQuery_UserIdQuery: - return query.NewUserIDSearchQuery(q.UserIdQuery.GetId()) - case *session.SearchQuery_CreationDateQuery: - return creationDateQueryToQuery(q.CreationDateQuery) - default: - return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid") - } -} - -func idsQueryToQuery(q *session.IDsQuery) (query.SearchQuery, error) { - return query.NewSessionIDsSearchQuery(q.Ids) -} - -func creationDateQueryToQuery(q *session.CreationDateQuery) (query.SearchQuery, error) { - comparison := timestampComparisons[q.GetMethod()] - return query.NewCreationDateQuery(q.GetCreationDate().AsTime(), comparison) -} - -func fieldNameToSessionColumn(field session.SessionFieldName) query.Column { - switch field { - case session.SessionFieldName_SESSION_FIELD_NAME_CREATION_DATE: - return query.SessionColumnCreationDate - default: - return query.Column{} - } -} - func (s *Server) createSessionRequestToCommand(ctx context.Context, req *session.CreateSessionRequest) ([]command.SessionCommand, map[string][]byte, *domain.UserAgent, time.Duration, error) { checks, err := s.checksToCommand(ctx, req.Checks) if err != nil { diff --git a/internal/api/grpc/session/v2/session_test.go b/internal/api/grpc/session/v2/session_test.go index 917be882f84..ce4f5115f21 100644 --- a/internal/api/grpc/session/v2/session_test.go +++ b/internal/api/grpc/session/v2/session_test.go @@ -339,9 +339,7 @@ func Test_listSessionsRequestToQuery(t *testing.T) { Limit: 0, Asc: false, }, - Queries: []query.SearchQuery{ - mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals), - }, + Queries: []query.SearchQuery{}, }, }, { @@ -359,15 +357,13 @@ func Test_listSessionsRequestToQuery(t *testing.T) { SortingColumn: query.SessionColumnCreationDate, Asc: false, }, - Queries: []query.SearchQuery{ - mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals), - }, + Queries: []query.SearchQuery{}, }, }, { name: "with list query and sessions", args: args{ - ctx: authz.NewMockContext("123", "456", "789"), + ctx: authz.SetCtxData(context.Background(), authz.CtxData{AgentID: "agent", UserID: "789"}), req: &session.ListSessionsRequest{ Query: &object.ListQuery{ Offset: 10, @@ -396,6 +392,12 @@ func Test_listSessionsRequestToQuery(t *testing.T) { Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER, }, }}, + {Query: &session.SearchQuery_CreatorQuery{ + CreatorQuery: &session.CreatorQuery{}, + }}, + {Query: &session.SearchQuery_UserAgentQuery{ + UserAgentQuery: &session.UserAgentQuery{}, + }}, }, }, }, @@ -411,6 +413,7 @@ func Test_listSessionsRequestToQuery(t *testing.T) { mustNewTextQuery(t, query.SessionColumnUserID, "10", query.TextEquals), mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampGreater), mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals), + mustNewTextQuery(t, query.SessionColumnUserAgentFingerprintID, "agent", query.TextEquals), }, }, }, @@ -458,13 +461,11 @@ func Test_sessionQueriesToQuery(t *testing.T) { wantErr error }{ { - name: "creator only", + name: "no queries", args: args{ ctx: authz.NewMockContext("123", "456", "789"), }, - want: []query.SearchQuery{ - mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals), - }, + want: []query.SearchQuery{}, }, { name: "invalid argument", @@ -491,6 +492,9 @@ func Test_sessionQueriesToQuery(t *testing.T) { Ids: []string{"4", "5", "6"}, }, }}, + {Query: &session.SearchQuery_CreatorQuery{ + CreatorQuery: &session.CreatorQuery{}, + }}, }, }, want: []query.SearchQuery{ @@ -511,6 +515,7 @@ func Test_sessionQueriesToQuery(t *testing.T) { func Test_sessionQueryToQuery(t *testing.T) { type args struct { + ctx context.Context query *session.SearchQuery } tests := []struct { @@ -521,60 +526,158 @@ func Test_sessionQueryToQuery(t *testing.T) { }{ { name: "invalid argument", - args: args{&session.SearchQuery{ - Query: nil, - }}, + args: args{ + context.Background(), + &session.SearchQuery{ + Query: nil, + }}, wantErr: zerrors.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid"), }, { name: "ids query", - args: args{&session.SearchQuery{ - Query: &session.SearchQuery_IdsQuery{ - IdsQuery: &session.IDsQuery{ - Ids: []string{"1", "2", "3"}, + args: args{ + context.Background(), + &session.SearchQuery{ + Query: &session.SearchQuery_IdsQuery{ + IdsQuery: &session.IDsQuery{ + Ids: []string{"1", "2", "3"}, + }, }, - }, - }}, + }}, want: mustNewListQuery(t, query.SessionColumnID, []interface{}{"1", "2", "3"}, query.ListIn), }, { name: "user id query", - args: args{&session.SearchQuery{ - Query: &session.SearchQuery_UserIdQuery{ - UserIdQuery: &session.UserIDQuery{ - Id: "10", + args: args{ + context.Background(), + &session.SearchQuery{ + Query: &session.SearchQuery_UserIdQuery{ + UserIdQuery: &session.UserIDQuery{ + Id: "10", + }, }, - }, - }}, + }}, want: mustNewTextQuery(t, query.SessionColumnUserID, "10", query.TextEquals), }, { name: "creation date query", - args: args{&session.SearchQuery{ - Query: &session.SearchQuery_CreationDateQuery{ - CreationDateQuery: &session.CreationDateQuery{ - CreationDate: timestamppb.New(creationDate), - Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS, + args: args{ + context.Background(), + &session.SearchQuery{ + Query: &session.SearchQuery_CreationDateQuery{ + CreationDateQuery: &session.CreationDateQuery{ + CreationDate: timestamppb.New(creationDate), + Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS, + }, }, - }, - }}, + }}, want: mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampLess), }, { name: "creation date query with default method", - args: args{&session.SearchQuery{ - Query: &session.SearchQuery_CreationDateQuery{ - CreationDateQuery: &session.CreationDateQuery{ - CreationDate: timestamppb.New(creationDate), + args: args{ + context.Background(), + &session.SearchQuery{ + Query: &session.SearchQuery_CreationDateQuery{ + CreationDateQuery: &session.CreationDateQuery{ + CreationDate: timestamppb.New(creationDate), + }, }, - }, - }}, + }}, want: mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampEquals), }, + { + name: "own creator", + args: args{ + authz.SetCtxData(context.Background(), authz.CtxData{UserID: "creator"}), + &session.SearchQuery{ + Query: &session.SearchQuery_CreatorQuery{ + CreatorQuery: &session.CreatorQuery{}, + }, + }}, + want: mustNewTextQuery(t, query.SessionColumnCreator, "creator", query.TextEquals), + }, + { + name: "empty own creator, error", + args: args{ + authz.SetCtxData(context.Background(), authz.CtxData{UserID: ""}), + &session.SearchQuery{ + Query: &session.SearchQuery_CreatorQuery{ + CreatorQuery: &session.CreatorQuery{}, + }, + }}, + wantErr: zerrors.ThrowInvalidArgument(nil, "GRPC-x8n24uh", "List.Query.Invalid"), + }, + { + name: "creator", + args: args{ + authz.SetCtxData(context.Background(), authz.CtxData{UserID: "creator1"}), + &session.SearchQuery{ + Query: &session.SearchQuery_CreatorQuery{ + CreatorQuery: &session.CreatorQuery{Id: gu.Ptr("creator2")}, + }, + }}, + want: mustNewTextQuery(t, query.SessionColumnCreator, "creator2", query.TextEquals), + }, + { + name: "empty creator, error", + args: args{ + authz.SetCtxData(context.Background(), authz.CtxData{UserID: "creator1"}), + &session.SearchQuery{ + Query: &session.SearchQuery_CreatorQuery{ + CreatorQuery: &session.CreatorQuery{Id: gu.Ptr("")}, + }, + }}, + wantErr: zerrors.ThrowInvalidArgument(nil, "GRPC-x8n24uh", "List.Query.Invalid"), + }, + { + name: "empty own useragent, error", + args: args{ + authz.SetCtxData(context.Background(), authz.CtxData{AgentID: ""}), + &session.SearchQuery{ + Query: &session.SearchQuery_UserAgentQuery{ + UserAgentQuery: &session.UserAgentQuery{}, + }, + }}, + wantErr: zerrors.ThrowInvalidArgument(nil, "GRPC-x8n23uh", "List.Query.Invalid"), + }, + { + name: "own useragent", + args: args{ + authz.SetCtxData(context.Background(), authz.CtxData{AgentID: "agent"}), + &session.SearchQuery{ + Query: &session.SearchQuery_UserAgentQuery{ + UserAgentQuery: &session.UserAgentQuery{}, + }, + }}, + want: mustNewTextQuery(t, query.SessionColumnUserAgentFingerprintID, "agent", query.TextEquals), + }, + { + name: "empty useragent, error", + args: args{ + authz.SetCtxData(context.Background(), authz.CtxData{AgentID: "agent"}), + &session.SearchQuery{ + Query: &session.SearchQuery_UserAgentQuery{ + UserAgentQuery: &session.UserAgentQuery{FingerprintId: gu.Ptr("")}, + }, + }}, + wantErr: zerrors.ThrowInvalidArgument(nil, "GRPC-x8n23uh", "List.Query.Invalid"), + }, + { + name: "useragent", + args: args{ + authz.SetCtxData(context.Background(), authz.CtxData{AgentID: "agent1"}), + &session.SearchQuery{ + Query: &session.SearchQuery_UserAgentQuery{ + UserAgentQuery: &session.UserAgentQuery{FingerprintId: gu.Ptr("agent2")}, + }, + }}, + want: mustNewTextQuery(t, query.SessionColumnUserAgentFingerprintID, "agent2", query.TextEquals), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := sessionQueryToQuery(tt.args.query) + got, err := sessionQueryToQuery(tt.args.ctx, tt.args.query) require.ErrorIs(t, err, tt.wantErr) assert.Equal(t, tt.want, got) }) diff --git a/internal/api/grpc/session/v2beta/integration_test/query_test.go b/internal/api/grpc/session/v2beta/integration_test/query_test.go new file mode 100644 index 00000000000..b347ba8224c --- /dev/null +++ b/internal/api/grpc/session/v2beta/integration_test/query_test.go @@ -0,0 +1,512 @@ +//go:build integration + +package session_test + +import ( + "context" + "testing" + "time" + + "github.com/golang/protobuf/ptypes/timestamp" + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/integration" + object "github.com/zitadel/zitadel/pkg/grpc/object/v2beta" + session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta" +) + +func TestServer_GetSession(t *testing.T) { + type args struct { + ctx context.Context + req *session.GetSessionRequest + dep func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 + } + tests := []struct { + name string + args args + want *session.GetSessionResponse + wantFactors []wantFactor + wantExpirationWindow time.Duration + wantErr bool + }{ + { + name: "get session, no id provided", + args: args{ + CTX, + &session.GetSessionRequest{ + SessionId: "", + }, + nil, + }, + wantErr: true, + }, + { + name: "get session, not found", + args: args{ + CTX, + &session.GetSessionRequest{ + SessionId: "unknown", + }, + nil, + }, + wantErr: true, + }, + { + name: "get session, no permission", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{}) + require.NoError(t, err) + request.SessionId = resp.SessionId + return resp.GetDetails().GetSequence() + }, + }, + wantErr: true, + }, + { + name: "get session, permission, ok", + args: args{ + CTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{}) + require.NoError(t, err) + request.SessionId = resp.SessionId + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + { + name: "get session, token, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{}) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + { + name: "get session, user agent, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{ + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("fingerPrintID"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{ + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("fingerPrintID"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + { + name: "get session, lifetime, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{ + Lifetime: durationpb.New(5 * time.Minute), + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + wantExpirationWindow: 5 * time.Minute, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + { + name: "get session, metadata, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{ + Metadata: map[string][]byte{"foo": []byte("bar")}, + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{ + Metadata: map[string][]byte{"foo": []byte("bar")}, + }, + }, + }, + { + name: "get session, user, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{ + Checks: &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: User.GetUserId(), + }, + }, + }, + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var sequence uint64 + if tt.args.dep != nil { + sequence = tt.args.dep(tt.args.ctx, t, tt.args.req) + } + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(tt.args.ctx, time.Minute) + require.EventuallyWithT(t, func(ttt *assert.CollectT) { + got, err := Client.GetSession(tt.args.ctx, tt.args.req) + if tt.wantErr { + assert.Error(ttt, err) + return + } + if !assert.NoError(ttt, err) { + return + } + + tt.want.Session.Id = tt.args.req.SessionId + tt.want.Session.Sequence = sequence + verifySession(ttt, got.GetSession(), tt.want.GetSession(), time.Minute, tt.wantExpirationWindow, User.GetUserId(), tt.wantFactors...) + }, retryDuration, tick) + }) + } +} + +type sessionAttr struct { + ID string + UserID string + UserAgent string + CreationDate *timestamp.Timestamp + ChangeDate *timestamppb.Timestamp + Details *object.Details +} + +type sessionAttrs []*sessionAttr + +func (u sessionAttrs) ids() []string { + ids := make([]string, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return ids +} + +func createSessions(ctx context.Context, t *testing.T, count int, userID string, userAgent string, lifetime *durationpb.Duration, metadata map[string][]byte) sessionAttrs { + infos := make([]*sessionAttr, count) + for i := 0; i < count; i++ { + infos[i] = createSession(ctx, t, userID, userAgent, lifetime, metadata) + } + return infos +} + +func createSession(ctx context.Context, t *testing.T, userID string, userAgent string, lifetime *durationpb.Duration, metadata map[string][]byte) *sessionAttr { + req := &session.CreateSessionRequest{} + if userID != "" { + req.Checks = &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: userID, + }, + }, + } + } + if userAgent != "" { + req.UserAgent = &session.UserAgent{ + FingerprintId: gu.Ptr(userAgent), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + } + } + if lifetime != nil { + req.Lifetime = lifetime + } + if metadata != nil { + req.Metadata = metadata + } + resp, err := Client.CreateSession(ctx, req) + require.NoError(t, err) + return &sessionAttr{ + resp.GetSessionId(), + userID, + userAgent, + resp.GetDetails().GetChangeDate(), + resp.GetDetails().GetChangeDate(), + resp.GetDetails(), + } +} + +func TestServer_ListSessions(t *testing.T) { + type args struct { + ctx context.Context + req *session.ListSessionsRequest + dep func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr + } + tests := []struct { + name string + args args + want *session.ListSessionsResponse + wantFactors []wantFactor + wantExpirationWindow time.Duration + wantErr bool + }{ + { + name: "list sessions, not found", + args: args{ + CTX, + &session.ListSessionsRequest{ + Queries: []*session.SearchQuery{ + {Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{"unknown"}}}}, + }, + }, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + return []*sessionAttr{} + }, + }, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 0, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{}, + }, + }, + { + name: "list sessions, wrong creator", + args: args{ + UserCTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, "", "", nil, nil) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}) + return []*sessionAttr{} + }, + }, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 0, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{}, + }, + }, + { + name: "list sessions, full, ok", + args: args{ + CTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, User.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}) + return []*sessionAttr{info} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + { + name: "list sessions, multiple, ok", + args: args{ + CTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + infos := createSessions(ctx, t, 3, User.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: infos.ids()}}}) + return infos + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 3, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + { + name: "list sessions, userid, ok", + args: args{ + CTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + createdUser := createFullUser(ctx) + info := createSession(ctx, t, createdUser.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_UserIdQuery{UserIdQuery: &session.UserIDQuery{Id: createdUser.GetUserId()}}}) + return []*sessionAttr{info} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + infos := tt.args.dep(CTX, t, tt.args.req) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(tt.args.ctx, time.Minute) + require.EventuallyWithT(t, func(ttt *assert.CollectT) { + got, err := Client.ListSessions(tt.args.ctx, tt.args.req) + if tt.wantErr { + assert.Error(ttt, err) + return + } + if !assert.NoError(ttt, err) { + return + } + + if !assert.Equal(ttt, got.Details.TotalResult, tt.want.Details.TotalResult) || !assert.Len(ttt, got.Sessions, len(tt.want.Sessions)) { + return + } + + for i := range infos { + tt.want.Sessions[i].Id = infos[i].ID + tt.want.Sessions[i].Sequence = infos[i].Details.GetSequence() + tt.want.Sessions[i].CreationDate = infos[i].Details.GetChangeDate() + tt.want.Sessions[i].ChangeDate = infos[i].Details.GetChangeDate() + + verifySession(ttt, got.Sessions[i], tt.want.Sessions[i], time.Minute, tt.wantExpirationWindow, infos[i].UserID, tt.wantFactors...) + } + integration.AssertListDetails(ttt, tt.want, got) + }, retryDuration, tick) + }) + } +} diff --git a/internal/api/grpc/session/v2beta/integration_test/server_test.go b/internal/api/grpc/session/v2beta/integration_test/server_test.go new file mode 100644 index 00000000000..4920e6ec353 --- /dev/null +++ b/internal/api/grpc/session/v2beta/integration_test/server_test.go @@ -0,0 +1,74 @@ +//go:build integration + +package session_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/integration" + session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta" + "github.com/zitadel/zitadel/pkg/grpc/user/v2" +) + +var ( + CTX context.Context + IAMOwnerCTX context.Context + UserCTX context.Context + Instance *integration.Instance + Client session.SessionServiceClient + User *user.AddHumanUserResponse + DeactivatedUser *user.AddHumanUserResponse + LockedUser *user.AddHumanUserResponse +) + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + + Instance = integration.NewInstance(ctx) + Client = Instance.Client.SessionV2beta + + CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner) + IAMOwnerCTX = Instance.WithAuthorization(ctx, integration.UserTypeIAMOwner) + UserCTX = Instance.WithAuthorization(ctx, integration.UserTypeNoPermission) + User = createFullUser(CTX) + DeactivatedUser = createDeactivatedUser(CTX) + LockedUser = createLockedUser(CTX) + return m.Run() + }()) +} + +func createFullUser(ctx context.Context) *user.AddHumanUserResponse { + userResp := Instance.CreateHumanUser(ctx) + Instance.Client.UserV2.VerifyEmail(ctx, &user.VerifyEmailRequest{ + UserId: userResp.GetUserId(), + VerificationCode: userResp.GetEmailCode(), + }) + Instance.Client.UserV2.VerifyPhone(ctx, &user.VerifyPhoneRequest{ + UserId: userResp.GetUserId(), + VerificationCode: userResp.GetPhoneCode(), + }) + Instance.SetUserPassword(ctx, userResp.GetUserId(), integration.UserPassword, false) + Instance.RegisterUserPasskey(ctx, userResp.GetUserId()) + return userResp +} + +func createDeactivatedUser(ctx context.Context) *user.AddHumanUserResponse { + userResp := Instance.CreateHumanUser(ctx) + _, err := Instance.Client.UserV2.DeactivateUser(ctx, &user.DeactivateUserRequest{UserId: userResp.GetUserId()}) + logging.OnError(err).Fatal("deactivate human user") + return userResp +} + +func createLockedUser(ctx context.Context) *user.AddHumanUserResponse { + userResp := Instance.CreateHumanUser(ctx) + _, err := Instance.Client.UserV2.LockUser(ctx, &user.LockUserRequest{UserId: userResp.GetUserId()}) + logging.OnError(err).Fatal("lock human user") + return userResp +} diff --git a/internal/api/grpc/session/v2beta/integration_test/session_test.go b/internal/api/grpc/session/v2beta/integration_test/session_test.go index 52e355204dd..26d22916296 100644 --- a/internal/api/grpc/session/v2beta/integration_test/session_test.go +++ b/internal/api/grpc/session/v2beta/integration_test/session_test.go @@ -5,7 +5,6 @@ package session_test import ( "context" "fmt" - "os" "testing" "time" @@ -14,7 +13,6 @@ import ( "github.com/pquerna/otp/totp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/zitadel/logging" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -29,62 +27,6 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -var ( - CTX context.Context - IAMOwnerCTX context.Context - Instance *integration.Instance - Client session.SessionServiceClient - User *user.AddHumanUserResponse - DeactivatedUser *user.AddHumanUserResponse - LockedUser *user.AddHumanUserResponse -) - -func TestMain(m *testing.M) { - os.Exit(func() int { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) - defer cancel() - - Instance = integration.NewInstance(ctx) - Client = Instance.Client.SessionV2beta - - CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner) - IAMOwnerCTX = Instance.WithAuthorization(ctx, integration.UserTypeIAMOwner) - User = createFullUser(CTX) - DeactivatedUser = createDeactivatedUser(CTX) - LockedUser = createLockedUser(CTX) - return m.Run() - }()) -} - -func createFullUser(ctx context.Context) *user.AddHumanUserResponse { - userResp := Instance.CreateHumanUser(ctx) - Instance.Client.UserV2.VerifyEmail(ctx, &user.VerifyEmailRequest{ - UserId: userResp.GetUserId(), - VerificationCode: userResp.GetEmailCode(), - }) - Instance.Client.UserV2.VerifyPhone(ctx, &user.VerifyPhoneRequest{ - UserId: userResp.GetUserId(), - VerificationCode: userResp.GetPhoneCode(), - }) - Instance.SetUserPassword(ctx, userResp.GetUserId(), integration.UserPassword, false) - Instance.RegisterUserPasskey(ctx, userResp.GetUserId()) - return userResp -} - -func createDeactivatedUser(ctx context.Context) *user.AddHumanUserResponse { - userResp := Instance.CreateHumanUser(ctx) - _, err := Instance.Client.UserV2.DeactivateUser(ctx, &user.DeactivateUserRequest{UserId: userResp.GetUserId()}) - logging.OnError(err).Fatal("deactivate human user") - return userResp -} - -func createLockedUser(ctx context.Context) *user.AddHumanUserResponse { - userResp := Instance.CreateHumanUser(ctx) - _, err := Instance.Client.UserV2.LockUser(ctx, &user.LockUserRequest{UserId: userResp.GetUserId()}) - logging.OnError(err).Fatal("lock human user") - return userResp -} - func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, window time.Duration, metadata map[string][]byte, userAgent *session.UserAgent, expirationWindow time.Duration, userID string, factors ...wantFactor) *session.Session { t.Helper() require.NotEmpty(t, id) @@ -96,15 +38,25 @@ func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, windo }) require.NoError(t, err) s := resp.GetSession() + want := &session.Session{ + Id: id, + Sequence: sequence, + Metadata: metadata, + UserAgent: userAgent, + } + verifySession(t, s, want, window, expirationWindow, userID, factors...) + return s +} - assert.Equal(t, id, s.GetId()) +func verifySession(t assert.TestingT, s *session.Session, want *session.Session, window time.Duration, expirationWindow time.Duration, userID string, factors ...wantFactor) { + assert.Equal(t, want.Id, s.GetId()) assert.WithinRange(t, s.GetCreationDate().AsTime(), time.Now().Add(-window), time.Now().Add(window)) assert.WithinRange(t, s.GetChangeDate().AsTime(), time.Now().Add(-window), time.Now().Add(window)) - assert.Equal(t, sequence, s.GetSequence()) - assert.Equal(t, metadata, s.GetMetadata()) + assert.Equal(t, want.Sequence, s.GetSequence()) + assert.Equal(t, want.Metadata, s.GetMetadata()) - if !proto.Equal(userAgent, s.GetUserAgent()) { - t.Errorf("user agent =\n%v\nwant\n%v", s.GetUserAgent(), userAgent) + if !proto.Equal(want.UserAgent, s.GetUserAgent()) { + t.Errorf("user agent =\n%v\nwant\n%v", s.GetUserAgent(), want.UserAgent) } if expirationWindow == 0 { assert.Nil(t, s.GetExpirationDate()) @@ -113,7 +65,6 @@ func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, windo } verifyFactors(t, s.GetFactors(), window, userID, factors) - return s } type wantFactor int @@ -129,7 +80,7 @@ const ( wantOTPEmailFactor ) -func verifyFactors(t testing.TB, factors *session.Factors, window time.Duration, userID string, want []wantFactor) { +func verifyFactors(t assert.TestingT, factors *session.Factors, window time.Duration, userID string, want []wantFactor) { for _, w := range want { switch w { case wantUserFactor: @@ -194,8 +145,15 @@ func TestServer_CreateSession(t *testing.T) { }, }, { - name: "user agent", + name: "full session", req: &session.CreateSessionRequest{ + Checks: &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: User.GetUserId(), + }, + }, + }, Metadata: map[string][]byte{"foo": []byte("bar")}, UserAgent: &session.UserAgent{ FingerprintId: gu.Ptr("fingerPrintID"), @@ -205,6 +163,7 @@ func TestServer_CreateSession(t *testing.T) { "foo": {Values: []string{"foo", "bar"}}, }, }, + Lifetime: durationpb.New(5 * time.Minute), }, want: &session.CreateSessionResponse{ Details: &object.Details{ @@ -212,14 +171,6 @@ func TestServer_CreateSession(t *testing.T) { ResourceOwner: Instance.ID(), }, }, - wantUserAgent: &session.UserAgent{ - FingerprintId: gu.Ptr("fingerPrintID"), - Ip: gu.Ptr("1.2.3.4"), - Description: gu.Ptr("Description"), - Header: map[string]*session.UserAgent_HeaderValues{ - "foo": {Values: []string{"foo", "bar"}}, - }, - }, }, { name: "negative lifetime", @@ -229,40 +180,6 @@ func TestServer_CreateSession(t *testing.T) { }, wantErr: true, }, - { - name: "lifetime", - req: &session.CreateSessionRequest{ - Metadata: map[string][]byte{"foo": []byte("bar")}, - Lifetime: durationpb.New(5 * time.Minute), - }, - want: &session.CreateSessionResponse{ - Details: &object.Details{ - ChangeDate: timestamppb.Now(), - ResourceOwner: Instance.ID(), - }, - }, - wantExpirationWindow: 5 * time.Minute, - }, - { - name: "with user", - req: &session.CreateSessionRequest{ - Checks: &session.Checks{ - User: &session.CheckUser{ - Search: &session.CheckUser_UserId{ - UserId: User.GetUserId(), - }, - }, - }, - Metadata: map[string][]byte{"foo": []byte("bar")}, - }, - want: &session.CreateSessionResponse{ - Details: &object.Details{ - ChangeDate: timestamppb.Now(), - ResourceOwner: Instance.ID(), - }, - }, - wantFactors: []wantFactor{wantUserFactor}, - }, { name: "deactivated user", req: &session.CreateSessionRequest{ @@ -340,8 +257,6 @@ func TestServer_CreateSession(t *testing.T) { } require.NoError(t, err) integration.AssertDetails(t, tt.want, got) - - verifyCurrentSession(t, got.GetSessionId(), got.GetSessionToken(), got.GetDetails().GetSequence(), time.Minute, tt.req.GetMetadata(), tt.wantUserAgent, tt.wantExpirationWindow, User.GetUserId(), tt.wantFactors...) }) } } @@ -946,21 +861,30 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) { require.NoError(t, err) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("Bearer %s", createResp.GetSessionToken())) - sessionResp, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: createResp.GetSessionId()}) - require.Error(t, err) - require.Nil(t, sessionResp) + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + sessionResp, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: createResp.GetSessionId()}) + if !assert.Error(tt, err) { + return + } + assert.Nil(tt, sessionResp) + }, retryDuration, tick) } func Test_ZITADEL_API_success(t *testing.T) { id, token, _, _ := Instance.CreateVerifiedWebAuthNSession(t, CTX, User.GetUserId()) - ctx := integration.WithAuthorizationToken(context.Background(), token) - sessionResp, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.NoError(t, err) - webAuthN := sessionResp.GetSession().GetFactors().GetWebAuthN() - require.NotNil(t, id, webAuthN.GetVerifiedAt().AsTime()) - require.True(t, webAuthN.GetUserVerified()) + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + sessionResp, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) + if !assert.NoError(tt, err) { + return + } + webAuthN := sessionResp.GetSession().GetFactors().GetWebAuthN() + assert.NotNil(tt, id, webAuthN.GetVerifiedAt().AsTime()) + assert.True(tt, webAuthN.GetUserVerified()) + }, retryDuration, tick) } func Test_ZITADEL_API_session_not_found(t *testing.T) { @@ -968,18 +892,30 @@ func Test_ZITADEL_API_session_not_found(t *testing.T) { // test session token works ctx := integration.WithAuthorizationToken(context.Background(), token) - _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.NoError(t, err) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) + if !assert.NoError(tt, err) { + return + } + }, retryDuration, tick) //terminate the session and test it does not work anymore - _, err = Client.DeleteSession(CTX, &session.DeleteSessionRequest{ + _, err := Client.DeleteSession(CTX, &session.DeleteSessionRequest{ SessionId: id, SessionToken: gu.Ptr(token), }) require.NoError(t, err) + ctx = integration.WithAuthorizationToken(context.Background(), token) - _, err = Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.Error(t, err) + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) + if !assert.Error(tt, err) { + return + } + }, retryDuration, tick) } func Test_ZITADEL_API_session_expired(t *testing.T) { @@ -987,8 +923,13 @@ func Test_ZITADEL_API_session_expired(t *testing.T) { // test session token works ctx := integration.WithAuthorizationToken(context.Background(), token) - _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.NoError(t, err) + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) + if !assert.NoError(tt, err) { + return + } + }, retryDuration, tick) // ensure session expires and does not work anymore time.Sleep(20 * time.Second) diff --git a/internal/api/grpc/session/v2beta/server.go b/internal/api/grpc/session/v2beta/server.go index 550d013ad52..cf0d0c27f04 100644 --- a/internal/api/grpc/session/v2beta/server.go +++ b/internal/api/grpc/session/v2beta/server.go @@ -6,6 +6,7 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta" ) @@ -16,6 +17,8 @@ type Server struct { session.UnimplementedSessionServiceServer command *command.Commands query *query.Queries + + checkPermission domain.PermissionCheck } type Config struct{} @@ -23,10 +26,12 @@ type Config struct{} func CreateServer( command *command.Commands, query *query.Queries, + checkPermission domain.PermissionCheck, ) *Server { return &Server{ - command: command, - query: query, + command: command, + query: query, + checkPermission: checkPermission, } } diff --git a/internal/api/grpc/session/v2beta/session.go b/internal/api/grpc/session/v2beta/session.go index 7e67a4b3ff6..3b36b8ba83a 100644 --- a/internal/api/grpc/session/v2beta/session.go +++ b/internal/api/grpc/session/v2beta/session.go @@ -32,7 +32,7 @@ var ( ) func (s *Server) GetSession(ctx context.Context, req *session.GetSessionRequest) (*session.GetSessionResponse, error) { - res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken()) + res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken(), s.checkPermission) if err != nil { return nil, err } @@ -46,7 +46,7 @@ func (s *Server) ListSessions(ctx context.Context, req *session.ListSessionsRequ if err != nil { return nil, err } - sessions, err := s.query.SearchSessions(ctx, queries) + sessions, err := s.query.SearchSessions(ctx, queries, s.checkPermission) if err != nil { return nil, err } diff --git a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go index 9dec3fcf00a..b707631c224 100644 --- a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go +++ b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go @@ -159,7 +159,7 @@ func (repo *TokenVerifierRepo) verifySessionToken(ctx context.Context, sessionID ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - session, err := repo.Query.SessionByID(ctx, true, sessionID, token) + session, err := repo.Query.SessionByID(ctx, true, sessionID, token, nil) if err != nil { return "", "", "", err } diff --git a/internal/notification/handlers/mock/commands.mock.go b/internal/notification/handlers/mock/commands.mock.go index ee6eb3c6b14..de32ce067c0 100644 --- a/internal/notification/handlers/mock/commands.mock.go +++ b/internal/notification/handlers/mock/commands.mock.go @@ -25,7 +25,6 @@ import ( type MockCommands struct { ctrl *gomock.Controller recorder *MockCommandsMockRecorder - isgomock struct{} } // MockCommandsMockRecorder is the mock recorder for MockCommands. @@ -46,253 +45,253 @@ func (m *MockCommands) EXPECT() *MockCommandsMockRecorder { } // HumanEmailVerificationCodeSent mocks base method. -func (m *MockCommands) HumanEmailVerificationCodeSent(ctx context.Context, orgID, userID string) error { +func (m *MockCommands) HumanEmailVerificationCodeSent(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanEmailVerificationCodeSent", ctx, orgID, userID) + ret := m.ctrl.Call(m, "HumanEmailVerificationCodeSent", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // HumanEmailVerificationCodeSent indicates an expected call of HumanEmailVerificationCodeSent. -func (mr *MockCommandsMockRecorder) HumanEmailVerificationCodeSent(ctx, orgID, userID any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanEmailVerificationCodeSent(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanEmailVerificationCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanEmailVerificationCodeSent), ctx, orgID, userID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanEmailVerificationCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanEmailVerificationCodeSent), arg0, arg1, arg2) } // HumanInitCodeSent mocks base method. -func (m *MockCommands) HumanInitCodeSent(ctx context.Context, orgID, userID string) error { +func (m *MockCommands) HumanInitCodeSent(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanInitCodeSent", ctx, orgID, userID) + ret := m.ctrl.Call(m, "HumanInitCodeSent", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // HumanInitCodeSent indicates an expected call of HumanInitCodeSent. -func (mr *MockCommandsMockRecorder) HumanInitCodeSent(ctx, orgID, userID any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanInitCodeSent(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanInitCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanInitCodeSent), ctx, orgID, userID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanInitCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanInitCodeSent), arg0, arg1, arg2) } // HumanOTPEmailCodeSent mocks base method. -func (m *MockCommands) HumanOTPEmailCodeSent(ctx context.Context, userID, resourceOwner string) error { +func (m *MockCommands) HumanOTPEmailCodeSent(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanOTPEmailCodeSent", ctx, userID, resourceOwner) + ret := m.ctrl.Call(m, "HumanOTPEmailCodeSent", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // HumanOTPEmailCodeSent indicates an expected call of HumanOTPEmailCodeSent. -func (mr *MockCommandsMockRecorder) HumanOTPEmailCodeSent(ctx, userID, resourceOwner any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanOTPEmailCodeSent(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanOTPEmailCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanOTPEmailCodeSent), ctx, userID, resourceOwner) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanOTPEmailCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanOTPEmailCodeSent), arg0, arg1, arg2) } // HumanOTPSMSCodeSent mocks base method. -func (m *MockCommands) HumanOTPSMSCodeSent(ctx context.Context, userID, resourceOwner string, generatorInfo *senders.CodeGeneratorInfo) error { +func (m *MockCommands) HumanOTPSMSCodeSent(arg0 context.Context, arg1, arg2 string, arg3 *senders.CodeGeneratorInfo) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanOTPSMSCodeSent", ctx, userID, resourceOwner, generatorInfo) + ret := m.ctrl.Call(m, "HumanOTPSMSCodeSent", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // HumanOTPSMSCodeSent indicates an expected call of HumanOTPSMSCodeSent. -func (mr *MockCommandsMockRecorder) HumanOTPSMSCodeSent(ctx, userID, resourceOwner, generatorInfo any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanOTPSMSCodeSent(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanOTPSMSCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanOTPSMSCodeSent), ctx, userID, resourceOwner, generatorInfo) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanOTPSMSCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanOTPSMSCodeSent), arg0, arg1, arg2, arg3) } // HumanPasswordlessInitCodeSent mocks base method. -func (m *MockCommands) HumanPasswordlessInitCodeSent(ctx context.Context, userID, resourceOwner, codeID string) error { +func (m *MockCommands) HumanPasswordlessInitCodeSent(arg0 context.Context, arg1, arg2, arg3 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanPasswordlessInitCodeSent", ctx, userID, resourceOwner, codeID) + ret := m.ctrl.Call(m, "HumanPasswordlessInitCodeSent", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // HumanPasswordlessInitCodeSent indicates an expected call of HumanPasswordlessInitCodeSent. -func (mr *MockCommandsMockRecorder) HumanPasswordlessInitCodeSent(ctx, userID, resourceOwner, codeID any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanPasswordlessInitCodeSent(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanPasswordlessInitCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanPasswordlessInitCodeSent), ctx, userID, resourceOwner, codeID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanPasswordlessInitCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanPasswordlessInitCodeSent), arg0, arg1, arg2, arg3) } // HumanPhoneVerificationCodeSent mocks base method. -func (m *MockCommands) HumanPhoneVerificationCodeSent(ctx context.Context, orgID, userID string, generatorInfo *senders.CodeGeneratorInfo) error { +func (m *MockCommands) HumanPhoneVerificationCodeSent(arg0 context.Context, arg1, arg2 string, arg3 *senders.CodeGeneratorInfo) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanPhoneVerificationCodeSent", ctx, orgID, userID, generatorInfo) + ret := m.ctrl.Call(m, "HumanPhoneVerificationCodeSent", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // HumanPhoneVerificationCodeSent indicates an expected call of HumanPhoneVerificationCodeSent. -func (mr *MockCommandsMockRecorder) HumanPhoneVerificationCodeSent(ctx, orgID, userID, generatorInfo any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanPhoneVerificationCodeSent(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanPhoneVerificationCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanPhoneVerificationCodeSent), ctx, orgID, userID, generatorInfo) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanPhoneVerificationCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanPhoneVerificationCodeSent), arg0, arg1, arg2, arg3) } // InviteCodeSent mocks base method. -func (m *MockCommands) InviteCodeSent(ctx context.Context, orgID, userID string) error { +func (m *MockCommands) InviteCodeSent(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InviteCodeSent", ctx, orgID, userID) + ret := m.ctrl.Call(m, "InviteCodeSent", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // InviteCodeSent indicates an expected call of InviteCodeSent. -func (mr *MockCommandsMockRecorder) InviteCodeSent(ctx, orgID, userID any) *gomock.Call { +func (mr *MockCommandsMockRecorder) InviteCodeSent(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InviteCodeSent", reflect.TypeOf((*MockCommands)(nil).InviteCodeSent), ctx, orgID, userID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InviteCodeSent", reflect.TypeOf((*MockCommands)(nil).InviteCodeSent), arg0, arg1, arg2) } // MilestonePushed mocks base method. -func (m *MockCommands) MilestonePushed(ctx context.Context, instanceID string, msType milestone.Type, endpoints []string) error { +func (m *MockCommands) MilestonePushed(arg0 context.Context, arg1 string, arg2 milestone.Type, arg3 []string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MilestonePushed", ctx, instanceID, msType, endpoints) + ret := m.ctrl.Call(m, "MilestonePushed", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // MilestonePushed indicates an expected call of MilestonePushed. -func (mr *MockCommandsMockRecorder) MilestonePushed(ctx, instanceID, msType, endpoints any) *gomock.Call { +func (mr *MockCommandsMockRecorder) MilestonePushed(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MilestonePushed", reflect.TypeOf((*MockCommands)(nil).MilestonePushed), ctx, instanceID, msType, endpoints) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MilestonePushed", reflect.TypeOf((*MockCommands)(nil).MilestonePushed), arg0, arg1, arg2, arg3) } // NotificationCanceled mocks base method. -func (m *MockCommands) NotificationCanceled(ctx context.Context, tx *sql.Tx, id, resourceOwner string, err error) error { +func (m *MockCommands) NotificationCanceled(arg0 context.Context, arg1 *sql.Tx, arg2, arg3 string, arg4 error) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NotificationCanceled", ctx, tx, id, resourceOwner, err) + ret := m.ctrl.Call(m, "NotificationCanceled", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(error) return ret0 } // NotificationCanceled indicates an expected call of NotificationCanceled. -func (mr *MockCommandsMockRecorder) NotificationCanceled(ctx, tx, id, resourceOwner, err any) *gomock.Call { +func (mr *MockCommandsMockRecorder) NotificationCanceled(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationCanceled", reflect.TypeOf((*MockCommands)(nil).NotificationCanceled), ctx, tx, id, resourceOwner, err) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationCanceled", reflect.TypeOf((*MockCommands)(nil).NotificationCanceled), arg0, arg1, arg2, arg3, arg4) } // NotificationRetryRequested mocks base method. -func (m *MockCommands) NotificationRetryRequested(ctx context.Context, tx *sql.Tx, id, resourceOwner string, request *command.NotificationRetryRequest, err error) error { +func (m *MockCommands) NotificationRetryRequested(arg0 context.Context, arg1 *sql.Tx, arg2, arg3 string, arg4 *command.NotificationRetryRequest, arg5 error) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NotificationRetryRequested", ctx, tx, id, resourceOwner, request, err) + ret := m.ctrl.Call(m, "NotificationRetryRequested", arg0, arg1, arg2, arg3, arg4, arg5) ret0, _ := ret[0].(error) return ret0 } // NotificationRetryRequested indicates an expected call of NotificationRetryRequested. -func (mr *MockCommandsMockRecorder) NotificationRetryRequested(ctx, tx, id, resourceOwner, request, err any) *gomock.Call { +func (mr *MockCommandsMockRecorder) NotificationRetryRequested(arg0, arg1, arg2, arg3, arg4, arg5 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationRetryRequested", reflect.TypeOf((*MockCommands)(nil).NotificationRetryRequested), ctx, tx, id, resourceOwner, request, err) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationRetryRequested", reflect.TypeOf((*MockCommands)(nil).NotificationRetryRequested), arg0, arg1, arg2, arg3, arg4, arg5) } // NotificationSent mocks base method. -func (m *MockCommands) NotificationSent(ctx context.Context, tx *sql.Tx, id, instanceID string) error { +func (m *MockCommands) NotificationSent(arg0 context.Context, arg1 *sql.Tx, arg2, arg3 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NotificationSent", ctx, tx, id, instanceID) + ret := m.ctrl.Call(m, "NotificationSent", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // NotificationSent indicates an expected call of NotificationSent. -func (mr *MockCommandsMockRecorder) NotificationSent(ctx, tx, id, instanceID any) *gomock.Call { +func (mr *MockCommandsMockRecorder) NotificationSent(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationSent", reflect.TypeOf((*MockCommands)(nil).NotificationSent), ctx, tx, id, instanceID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationSent", reflect.TypeOf((*MockCommands)(nil).NotificationSent), arg0, arg1, arg2, arg3) } // OTPEmailSent mocks base method. -func (m *MockCommands) OTPEmailSent(ctx context.Context, sessionID, resourceOwner string) error { +func (m *MockCommands) OTPEmailSent(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OTPEmailSent", ctx, sessionID, resourceOwner) + ret := m.ctrl.Call(m, "OTPEmailSent", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // OTPEmailSent indicates an expected call of OTPEmailSent. -func (mr *MockCommandsMockRecorder) OTPEmailSent(ctx, sessionID, resourceOwner any) *gomock.Call { +func (mr *MockCommandsMockRecorder) OTPEmailSent(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OTPEmailSent", reflect.TypeOf((*MockCommands)(nil).OTPEmailSent), ctx, sessionID, resourceOwner) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OTPEmailSent", reflect.TypeOf((*MockCommands)(nil).OTPEmailSent), arg0, arg1, arg2) } // OTPSMSSent mocks base method. -func (m *MockCommands) OTPSMSSent(ctx context.Context, sessionID, resourceOwner string, generatorInfo *senders.CodeGeneratorInfo) error { +func (m *MockCommands) OTPSMSSent(arg0 context.Context, arg1, arg2 string, arg3 *senders.CodeGeneratorInfo) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OTPSMSSent", ctx, sessionID, resourceOwner, generatorInfo) + ret := m.ctrl.Call(m, "OTPSMSSent", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // OTPSMSSent indicates an expected call of OTPSMSSent. -func (mr *MockCommandsMockRecorder) OTPSMSSent(ctx, sessionID, resourceOwner, generatorInfo any) *gomock.Call { +func (mr *MockCommandsMockRecorder) OTPSMSSent(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OTPSMSSent", reflect.TypeOf((*MockCommands)(nil).OTPSMSSent), ctx, sessionID, resourceOwner, generatorInfo) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OTPSMSSent", reflect.TypeOf((*MockCommands)(nil).OTPSMSSent), arg0, arg1, arg2, arg3) } // PasswordChangeSent mocks base method. -func (m *MockCommands) PasswordChangeSent(ctx context.Context, orgID, userID string) error { +func (m *MockCommands) PasswordChangeSent(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PasswordChangeSent", ctx, orgID, userID) + ret := m.ctrl.Call(m, "PasswordChangeSent", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // PasswordChangeSent indicates an expected call of PasswordChangeSent. -func (mr *MockCommandsMockRecorder) PasswordChangeSent(ctx, orgID, userID any) *gomock.Call { +func (mr *MockCommandsMockRecorder) PasswordChangeSent(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordChangeSent", reflect.TypeOf((*MockCommands)(nil).PasswordChangeSent), ctx, orgID, userID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordChangeSent", reflect.TypeOf((*MockCommands)(nil).PasswordChangeSent), arg0, arg1, arg2) } // PasswordCodeSent mocks base method. -func (m *MockCommands) PasswordCodeSent(ctx context.Context, orgID, userID string, generatorInfo *senders.CodeGeneratorInfo) error { +func (m *MockCommands) PasswordCodeSent(arg0 context.Context, arg1, arg2 string, arg3 *senders.CodeGeneratorInfo) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PasswordCodeSent", ctx, orgID, userID, generatorInfo) + ret := m.ctrl.Call(m, "PasswordCodeSent", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // PasswordCodeSent indicates an expected call of PasswordCodeSent. -func (mr *MockCommandsMockRecorder) PasswordCodeSent(ctx, orgID, userID, generatorInfo any) *gomock.Call { +func (mr *MockCommandsMockRecorder) PasswordCodeSent(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordCodeSent", reflect.TypeOf((*MockCommands)(nil).PasswordCodeSent), ctx, orgID, userID, generatorInfo) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordCodeSent", reflect.TypeOf((*MockCommands)(nil).PasswordCodeSent), arg0, arg1, arg2, arg3) } // RequestNotification mocks base method. -func (m *MockCommands) RequestNotification(ctx context.Context, instanceID string, request *command.NotificationRequest) error { +func (m *MockCommands) RequestNotification(arg0 context.Context, arg1 string, arg2 *command.NotificationRequest) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RequestNotification", ctx, instanceID, request) + ret := m.ctrl.Call(m, "RequestNotification", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // RequestNotification indicates an expected call of RequestNotification. -func (mr *MockCommandsMockRecorder) RequestNotification(ctx, instanceID, request any) *gomock.Call { +func (mr *MockCommandsMockRecorder) RequestNotification(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestNotification", reflect.TypeOf((*MockCommands)(nil).RequestNotification), ctx, instanceID, request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestNotification", reflect.TypeOf((*MockCommands)(nil).RequestNotification), arg0, arg1, arg2) } // UsageNotificationSent mocks base method. -func (m *MockCommands) UsageNotificationSent(ctx context.Context, dueEvent *quota.NotificationDueEvent) error { +func (m *MockCommands) UsageNotificationSent(arg0 context.Context, arg1 *quota.NotificationDueEvent) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UsageNotificationSent", ctx, dueEvent) + ret := m.ctrl.Call(m, "UsageNotificationSent", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // UsageNotificationSent indicates an expected call of UsageNotificationSent. -func (mr *MockCommandsMockRecorder) UsageNotificationSent(ctx, dueEvent any) *gomock.Call { +func (mr *MockCommandsMockRecorder) UsageNotificationSent(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UsageNotificationSent", reflect.TypeOf((*MockCommands)(nil).UsageNotificationSent), ctx, dueEvent) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UsageNotificationSent", reflect.TypeOf((*MockCommands)(nil).UsageNotificationSent), arg0, arg1) } // UserDomainClaimedSent mocks base method. -func (m *MockCommands) UserDomainClaimedSent(ctx context.Context, orgID, userID string) error { +func (m *MockCommands) UserDomainClaimedSent(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UserDomainClaimedSent", ctx, orgID, userID) + ret := m.ctrl.Call(m, "UserDomainClaimedSent", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // UserDomainClaimedSent indicates an expected call of UserDomainClaimedSent. -func (mr *MockCommandsMockRecorder) UserDomainClaimedSent(ctx, orgID, userID any) *gomock.Call { +func (mr *MockCommandsMockRecorder) UserDomainClaimedSent(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserDomainClaimedSent", reflect.TypeOf((*MockCommands)(nil).UserDomainClaimedSent), ctx, orgID, userID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserDomainClaimedSent", reflect.TypeOf((*MockCommands)(nil).UserDomainClaimedSent), arg0, arg1, arg2) } diff --git a/internal/notification/handlers/mock/queries.mock.go b/internal/notification/handlers/mock/queries.mock.go index 5ead2164376..670d3f38968 100644 --- a/internal/notification/handlers/mock/queries.mock.go +++ b/internal/notification/handlers/mock/queries.mock.go @@ -26,7 +26,6 @@ import ( type MockQueries struct { ctrl *gomock.Controller recorder *MockQueriesMockRecorder - isgomock struct{} } // MockQueriesMockRecorder is the mock recorder for MockQueries. @@ -61,240 +60,240 @@ func (mr *MockQueriesMockRecorder) ActiveInstances() *gomock.Call { } // ActiveLabelPolicyByOrg mocks base method. -func (m *MockQueries) ActiveLabelPolicyByOrg(ctx context.Context, orgID string, withOwnerRemoved bool) (*query.LabelPolicy, error) { +func (m *MockQueries) ActiveLabelPolicyByOrg(arg0 context.Context, arg1 string, arg2 bool) (*query.LabelPolicy, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ActiveLabelPolicyByOrg", ctx, orgID, withOwnerRemoved) + ret := m.ctrl.Call(m, "ActiveLabelPolicyByOrg", arg0, arg1, arg2) ret0, _ := ret[0].(*query.LabelPolicy) ret1, _ := ret[1].(error) return ret0, ret1 } // ActiveLabelPolicyByOrg indicates an expected call of ActiveLabelPolicyByOrg. -func (mr *MockQueriesMockRecorder) ActiveLabelPolicyByOrg(ctx, orgID, withOwnerRemoved any) *gomock.Call { +func (mr *MockQueriesMockRecorder) ActiveLabelPolicyByOrg(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActiveLabelPolicyByOrg", reflect.TypeOf((*MockQueries)(nil).ActiveLabelPolicyByOrg), ctx, orgID, withOwnerRemoved) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActiveLabelPolicyByOrg", reflect.TypeOf((*MockQueries)(nil).ActiveLabelPolicyByOrg), arg0, arg1, arg2) } // ActivePrivateSigningKey mocks base method. -func (m *MockQueries) ActivePrivateSigningKey(ctx context.Context, t time.Time) (*query.PrivateKeys, error) { +func (m *MockQueries) ActivePrivateSigningKey(arg0 context.Context, arg1 time.Time) (*query.PrivateKeys, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ActivePrivateSigningKey", ctx, t) + ret := m.ctrl.Call(m, "ActivePrivateSigningKey", arg0, arg1) ret0, _ := ret[0].(*query.PrivateKeys) ret1, _ := ret[1].(error) return ret0, ret1 } // ActivePrivateSigningKey indicates an expected call of ActivePrivateSigningKey. -func (mr *MockQueriesMockRecorder) ActivePrivateSigningKey(ctx, t any) *gomock.Call { +func (mr *MockQueriesMockRecorder) ActivePrivateSigningKey(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActivePrivateSigningKey", reflect.TypeOf((*MockQueries)(nil).ActivePrivateSigningKey), ctx, t) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActivePrivateSigningKey", reflect.TypeOf((*MockQueries)(nil).ActivePrivateSigningKey), arg0, arg1) } // CustomTextListByTemplate mocks base method. -func (m *MockQueries) CustomTextListByTemplate(ctx context.Context, aggregateID, template string, withOwnerRemoved bool) (*query.CustomTexts, error) { +func (m *MockQueries) CustomTextListByTemplate(arg0 context.Context, arg1, arg2 string, arg3 bool) (*query.CustomTexts, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CustomTextListByTemplate", ctx, aggregateID, template, withOwnerRemoved) + ret := m.ctrl.Call(m, "CustomTextListByTemplate", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(*query.CustomTexts) ret1, _ := ret[1].(error) return ret0, ret1 } // CustomTextListByTemplate indicates an expected call of CustomTextListByTemplate. -func (mr *MockQueriesMockRecorder) CustomTextListByTemplate(ctx, aggregateID, template, withOwnerRemoved any) *gomock.Call { +func (mr *MockQueriesMockRecorder) CustomTextListByTemplate(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CustomTextListByTemplate", reflect.TypeOf((*MockQueries)(nil).CustomTextListByTemplate), ctx, aggregateID, template, withOwnerRemoved) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CustomTextListByTemplate", reflect.TypeOf((*MockQueries)(nil).CustomTextListByTemplate), arg0, arg1, arg2, arg3) } // GetActiveSigningWebKey mocks base method. -func (m *MockQueries) GetActiveSigningWebKey(ctx context.Context) (*jose.JSONWebKey, error) { +func (m *MockQueries) GetActiveSigningWebKey(arg0 context.Context) (*jose.JSONWebKey, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetActiveSigningWebKey", ctx) + ret := m.ctrl.Call(m, "GetActiveSigningWebKey", arg0) ret0, _ := ret[0].(*jose.JSONWebKey) ret1, _ := ret[1].(error) return ret0, ret1 } // GetActiveSigningWebKey indicates an expected call of GetActiveSigningWebKey. -func (mr *MockQueriesMockRecorder) GetActiveSigningWebKey(ctx any) *gomock.Call { +func (mr *MockQueriesMockRecorder) GetActiveSigningWebKey(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveSigningWebKey", reflect.TypeOf((*MockQueries)(nil).GetActiveSigningWebKey), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveSigningWebKey", reflect.TypeOf((*MockQueries)(nil).GetActiveSigningWebKey), arg0) } // GetDefaultLanguage mocks base method. -func (m *MockQueries) GetDefaultLanguage(ctx context.Context) language.Tag { +func (m *MockQueries) GetDefaultLanguage(arg0 context.Context) language.Tag { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetDefaultLanguage", ctx) + ret := m.ctrl.Call(m, "GetDefaultLanguage", arg0) ret0, _ := ret[0].(language.Tag) return ret0 } // GetDefaultLanguage indicates an expected call of GetDefaultLanguage. -func (mr *MockQueriesMockRecorder) GetDefaultLanguage(ctx any) *gomock.Call { +func (mr *MockQueriesMockRecorder) GetDefaultLanguage(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultLanguage", reflect.TypeOf((*MockQueries)(nil).GetDefaultLanguage), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultLanguage", reflect.TypeOf((*MockQueries)(nil).GetDefaultLanguage), arg0) } // GetInstanceRestrictions mocks base method. -func (m *MockQueries) GetInstanceRestrictions(ctx context.Context) (query.Restrictions, error) { +func (m *MockQueries) GetInstanceRestrictions(arg0 context.Context) (query.Restrictions, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetInstanceRestrictions", ctx) + ret := m.ctrl.Call(m, "GetInstanceRestrictions", arg0) ret0, _ := ret[0].(query.Restrictions) ret1, _ := ret[1].(error) return ret0, ret1 } // GetInstanceRestrictions indicates an expected call of GetInstanceRestrictions. -func (mr *MockQueriesMockRecorder) GetInstanceRestrictions(ctx any) *gomock.Call { +func (mr *MockQueriesMockRecorder) GetInstanceRestrictions(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceRestrictions", reflect.TypeOf((*MockQueries)(nil).GetInstanceRestrictions), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceRestrictions", reflect.TypeOf((*MockQueries)(nil).GetInstanceRestrictions), arg0) } // GetNotifyUserByID mocks base method. -func (m *MockQueries) GetNotifyUserByID(ctx context.Context, shouldTriggered bool, userID string) (*query.NotifyUser, error) { +func (m *MockQueries) GetNotifyUserByID(arg0 context.Context, arg1 bool, arg2 string) (*query.NotifyUser, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetNotifyUserByID", ctx, shouldTriggered, userID) + ret := m.ctrl.Call(m, "GetNotifyUserByID", arg0, arg1, arg2) ret0, _ := ret[0].(*query.NotifyUser) ret1, _ := ret[1].(error) return ret0, ret1 } // GetNotifyUserByID indicates an expected call of GetNotifyUserByID. -func (mr *MockQueriesMockRecorder) GetNotifyUserByID(ctx, shouldTriggered, userID any) *gomock.Call { +func (mr *MockQueriesMockRecorder) GetNotifyUserByID(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNotifyUserByID", reflect.TypeOf((*MockQueries)(nil).GetNotifyUserByID), ctx, shouldTriggered, userID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNotifyUserByID", reflect.TypeOf((*MockQueries)(nil).GetNotifyUserByID), arg0, arg1, arg2) } // InstanceByID mocks base method. -func (m *MockQueries) InstanceByID(ctx context.Context, id string) (authz.Instance, error) { +func (m *MockQueries) InstanceByID(arg0 context.Context, arg1 string) (authz.Instance, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InstanceByID", ctx, id) + ret := m.ctrl.Call(m, "InstanceByID", arg0, arg1) ret0, _ := ret[0].(authz.Instance) ret1, _ := ret[1].(error) return ret0, ret1 } // InstanceByID indicates an expected call of InstanceByID. -func (mr *MockQueriesMockRecorder) InstanceByID(ctx, id any) *gomock.Call { +func (mr *MockQueriesMockRecorder) InstanceByID(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InstanceByID", reflect.TypeOf((*MockQueries)(nil).InstanceByID), ctx, id) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InstanceByID", reflect.TypeOf((*MockQueries)(nil).InstanceByID), arg0, arg1) } // MailTemplateByOrg mocks base method. -func (m *MockQueries) MailTemplateByOrg(ctx context.Context, orgID string, withOwnerRemoved bool) (*query.MailTemplate, error) { +func (m *MockQueries) MailTemplateByOrg(arg0 context.Context, arg1 string, arg2 bool) (*query.MailTemplate, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MailTemplateByOrg", ctx, orgID, withOwnerRemoved) + ret := m.ctrl.Call(m, "MailTemplateByOrg", arg0, arg1, arg2) ret0, _ := ret[0].(*query.MailTemplate) ret1, _ := ret[1].(error) return ret0, ret1 } // MailTemplateByOrg indicates an expected call of MailTemplateByOrg. -func (mr *MockQueriesMockRecorder) MailTemplateByOrg(ctx, orgID, withOwnerRemoved any) *gomock.Call { +func (mr *MockQueriesMockRecorder) MailTemplateByOrg(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MailTemplateByOrg", reflect.TypeOf((*MockQueries)(nil).MailTemplateByOrg), ctx, orgID, withOwnerRemoved) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MailTemplateByOrg", reflect.TypeOf((*MockQueries)(nil).MailTemplateByOrg), arg0, arg1, arg2) } // NotificationPolicyByOrg mocks base method. -func (m *MockQueries) NotificationPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (*query.NotificationPolicy, error) { +func (m *MockQueries) NotificationPolicyByOrg(arg0 context.Context, arg1 bool, arg2 string, arg3 bool) (*query.NotificationPolicy, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NotificationPolicyByOrg", ctx, shouldTriggerBulk, orgID, withOwnerRemoved) + ret := m.ctrl.Call(m, "NotificationPolicyByOrg", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(*query.NotificationPolicy) ret1, _ := ret[1].(error) return ret0, ret1 } // NotificationPolicyByOrg indicates an expected call of NotificationPolicyByOrg. -func (mr *MockQueriesMockRecorder) NotificationPolicyByOrg(ctx, shouldTriggerBulk, orgID, withOwnerRemoved any) *gomock.Call { +func (mr *MockQueriesMockRecorder) NotificationPolicyByOrg(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationPolicyByOrg", reflect.TypeOf((*MockQueries)(nil).NotificationPolicyByOrg), ctx, shouldTriggerBulk, orgID, withOwnerRemoved) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationPolicyByOrg", reflect.TypeOf((*MockQueries)(nil).NotificationPolicyByOrg), arg0, arg1, arg2, arg3) } // NotificationProviderByIDAndType mocks base method. -func (m *MockQueries) NotificationProviderByIDAndType(ctx context.Context, aggID string, providerType domain.NotificationProviderType) (*query.DebugNotificationProvider, error) { +func (m *MockQueries) NotificationProviderByIDAndType(arg0 context.Context, arg1 string, arg2 domain.NotificationProviderType) (*query.DebugNotificationProvider, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NotificationProviderByIDAndType", ctx, aggID, providerType) + ret := m.ctrl.Call(m, "NotificationProviderByIDAndType", arg0, arg1, arg2) ret0, _ := ret[0].(*query.DebugNotificationProvider) ret1, _ := ret[1].(error) return ret0, ret1 } // NotificationProviderByIDAndType indicates an expected call of NotificationProviderByIDAndType. -func (mr *MockQueriesMockRecorder) NotificationProviderByIDAndType(ctx, aggID, providerType any) *gomock.Call { +func (mr *MockQueriesMockRecorder) NotificationProviderByIDAndType(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationProviderByIDAndType", reflect.TypeOf((*MockQueries)(nil).NotificationProviderByIDAndType), ctx, aggID, providerType) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationProviderByIDAndType", reflect.TypeOf((*MockQueries)(nil).NotificationProviderByIDAndType), arg0, arg1, arg2) } // SMSProviderConfigActive mocks base method. -func (m *MockQueries) SMSProviderConfigActive(ctx context.Context, resourceOwner string) (*query.SMSConfig, error) { +func (m *MockQueries) SMSProviderConfigActive(arg0 context.Context, arg1 string) (*query.SMSConfig, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SMSProviderConfigActive", ctx, resourceOwner) + ret := m.ctrl.Call(m, "SMSProviderConfigActive", arg0, arg1) ret0, _ := ret[0].(*query.SMSConfig) ret1, _ := ret[1].(error) return ret0, ret1 } // SMSProviderConfigActive indicates an expected call of SMSProviderConfigActive. -func (mr *MockQueriesMockRecorder) SMSProviderConfigActive(ctx, resourceOwner any) *gomock.Call { +func (mr *MockQueriesMockRecorder) SMSProviderConfigActive(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMSProviderConfigActive", reflect.TypeOf((*MockQueries)(nil).SMSProviderConfigActive), ctx, resourceOwner) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMSProviderConfigActive", reflect.TypeOf((*MockQueries)(nil).SMSProviderConfigActive), arg0, arg1) } // SMTPConfigActive mocks base method. -func (m *MockQueries) SMTPConfigActive(ctx context.Context, resourceOwner string) (*query.SMTPConfig, error) { +func (m *MockQueries) SMTPConfigActive(arg0 context.Context, arg1 string) (*query.SMTPConfig, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SMTPConfigActive", ctx, resourceOwner) + ret := m.ctrl.Call(m, "SMTPConfigActive", arg0, arg1) ret0, _ := ret[0].(*query.SMTPConfig) ret1, _ := ret[1].(error) return ret0, ret1 } // SMTPConfigActive indicates an expected call of SMTPConfigActive. -func (mr *MockQueriesMockRecorder) SMTPConfigActive(ctx, resourceOwner any) *gomock.Call { +func (mr *MockQueriesMockRecorder) SMTPConfigActive(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMTPConfigActive", reflect.TypeOf((*MockQueries)(nil).SMTPConfigActive), ctx, resourceOwner) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMTPConfigActive", reflect.TypeOf((*MockQueries)(nil).SMTPConfigActive), arg0, arg1) } // SearchInstanceDomains mocks base method. -func (m *MockQueries) SearchInstanceDomains(ctx context.Context, queries *query.InstanceDomainSearchQueries) (*query.InstanceDomains, error) { +func (m *MockQueries) SearchInstanceDomains(arg0 context.Context, arg1 *query.InstanceDomainSearchQueries) (*query.InstanceDomains, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SearchInstanceDomains", ctx, queries) + ret := m.ctrl.Call(m, "SearchInstanceDomains", arg0, arg1) ret0, _ := ret[0].(*query.InstanceDomains) ret1, _ := ret[1].(error) return ret0, ret1 } // SearchInstanceDomains indicates an expected call of SearchInstanceDomains. -func (mr *MockQueriesMockRecorder) SearchInstanceDomains(ctx, queries any) *gomock.Call { +func (mr *MockQueriesMockRecorder) SearchInstanceDomains(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchInstanceDomains", reflect.TypeOf((*MockQueries)(nil).SearchInstanceDomains), ctx, queries) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchInstanceDomains", reflect.TypeOf((*MockQueries)(nil).SearchInstanceDomains), arg0, arg1) } // SearchMilestones mocks base method. -func (m *MockQueries) SearchMilestones(ctx context.Context, instanceIDs []string, queries *query.MilestonesSearchQueries) (*query.Milestones, error) { +func (m *MockQueries) SearchMilestones(arg0 context.Context, arg1 []string, arg2 *query.MilestonesSearchQueries) (*query.Milestones, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SearchMilestones", ctx, instanceIDs, queries) + ret := m.ctrl.Call(m, "SearchMilestones", arg0, arg1, arg2) ret0, _ := ret[0].(*query.Milestones) ret1, _ := ret[1].(error) return ret0, ret1 } // SearchMilestones indicates an expected call of SearchMilestones. -func (mr *MockQueriesMockRecorder) SearchMilestones(ctx, instanceIDs, queries any) *gomock.Call { +func (mr *MockQueriesMockRecorder) SearchMilestones(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchMilestones", reflect.TypeOf((*MockQueries)(nil).SearchMilestones), ctx, instanceIDs, queries) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchMilestones", reflect.TypeOf((*MockQueries)(nil).SearchMilestones), arg0, arg1, arg2) } // SessionByID mocks base method. -func (m *MockQueries) SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string) (*query.Session, error) { +func (m *MockQueries) SessionByID(arg0 context.Context, arg1 bool, arg2, arg3 string, arg4 domain.PermissionCheck) (*query.Session, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SessionByID", ctx, shouldTriggerBulk, id, sessionToken) + ret := m.ctrl.Call(m, "SessionByID", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(*query.Session) ret1, _ := ret[1].(error) return ret0, ret1 } // SessionByID indicates an expected call of SessionByID. -func (mr *MockQueriesMockRecorder) SessionByID(ctx, shouldTriggerBulk, id, sessionToken any) *gomock.Call { +func (mr *MockQueriesMockRecorder) SessionByID(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SessionByID", reflect.TypeOf((*MockQueries)(nil).SessionByID), ctx, shouldTriggerBulk, id, sessionToken) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SessionByID", reflect.TypeOf((*MockQueries)(nil).SessionByID), arg0, arg1, arg2, arg3, arg4) } diff --git a/internal/notification/handlers/queries.go b/internal/notification/handlers/queries.go index 1c8d37598e9..a3d68e47978 100644 --- a/internal/notification/handlers/queries.go +++ b/internal/notification/handlers/queries.go @@ -20,7 +20,7 @@ type Queries interface { GetNotifyUserByID(ctx context.Context, shouldTriggered bool, userID string) (*query.NotifyUser, error) CustomTextListByTemplate(ctx context.Context, aggregateID, template string, withOwnerRemoved bool) (*query.CustomTexts, error) SearchInstanceDomains(ctx context.Context, queries *query.InstanceDomainSearchQueries) (*query.InstanceDomains, error) - SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string) (*query.Session, error) + SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string, check domain.PermissionCheck) (*query.Session, error) NotificationPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (*query.NotificationPolicy, error) SearchMilestones(ctx context.Context, instanceIDs []string, queries *query.MilestonesSearchQueries) (*query.Milestones, error) NotificationProviderByIDAndType(ctx context.Context, aggID string, providerType domain.NotificationProviderType) (*query.DebugNotificationProvider, error) diff --git a/internal/notification/handlers/user_notifier.go b/internal/notification/handlers/user_notifier.go index ec30ab476f2..c24b87c2f69 100644 --- a/internal/notification/handlers/user_notifier.go +++ b/internal/notification/handlers/user_notifier.go @@ -400,7 +400,7 @@ func (u *userNotifier) reduceSessionOTPSMSChallenged(event eventstore.Event) (*h if alreadyHandled { return nil } - s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "") + s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "", nil) if err != nil { return err } @@ -496,7 +496,7 @@ func (u *userNotifier) reduceSessionOTPEmailChallenged(event eventstore.Event) ( if alreadyHandled { return nil } - s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "") + s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "", nil) if err != nil { return err } diff --git a/internal/notification/handlers/user_notifier_legacy.go b/internal/notification/handlers/user_notifier_legacy.go index 7df31cdf912..4bfa1a796e6 100644 --- a/internal/notification/handlers/user_notifier_legacy.go +++ b/internal/notification/handlers/user_notifier_legacy.go @@ -324,7 +324,7 @@ func (u *userNotifierLegacy) reduceSessionOTPSMSChallenged(event eventstore.Even return handler.NewNoOpStatement(e), nil } ctx := HandlerContext(event.Aggregate()) - s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "") + s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "", nil) if err != nil { return nil, err } @@ -428,7 +428,7 @@ func (u *userNotifierLegacy) reduceSessionOTPEmailChallenged(event eventstore.Ev return handler.NewNoOpStatement(e), nil } ctx := HandlerContext(event.Aggregate()) - s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "") + s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "", nil) if err != nil { return nil, err } diff --git a/internal/notification/handlers/user_notifier_legacy_test.go b/internal/notification/handlers/user_notifier_legacy_test.go index fe99eaa572e..02f21670f5b 100644 --- a/internal/notification/handlers/user_notifier_legacy_test.go +++ b/internal/notification/handlers/user_notifier_legacy_test.go @@ -1228,7 +1228,7 @@ func Test_userNotifierLegacy_reduceOTPEmailChallenged(t *testing.T) { } codeAlg, code := cryptoValue(t, ctrl, "testcode") expectTemplateWithNotifyUserQueries(queries, givenTemplate) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any()).Return(&query.Session{}, nil) + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil) commands.EXPECT().OTPEmailSent(gomock.Any(), userID, orgID).Return(nil) return fields{ queries: queries, @@ -1264,7 +1264,7 @@ func Test_userNotifierLegacy_reduceOTPEmailChallenged(t *testing.T) { } codeAlg, code := cryptoValue(t, ctrl, "testcode") expectTemplateWithNotifyUserQueries(queries, givenTemplate) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any()).Return(&query.Session{}, nil) + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil) queries.EXPECT().SearchInstanceDomains(gomock.Any(), gomock.Any()).Return(&query.InstanceDomains{ Domains: []*query.InstanceDomain{{ Domain: instancePrimaryDomain, @@ -1306,7 +1306,7 @@ func Test_userNotifierLegacy_reduceOTPEmailChallenged(t *testing.T) { } codeAlg, code := cryptoValue(t, ctrl, testCode) expectTemplateWithNotifyUserQueries(queries, givenTemplate) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any()).Return(&query.Session{}, nil) + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil) commands.EXPECT().OTPEmailSent(gomock.Any(), userID, orgID).Return(nil) return fields{ queries: queries, @@ -1350,7 +1350,7 @@ func Test_userNotifierLegacy_reduceOTPEmailChallenged(t *testing.T) { }}, }, nil) expectTemplateWithNotifyUserQueries(queries, givenTemplate) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any()).Return(&query.Session{}, nil) + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil) commands.EXPECT().OTPEmailSent(gomock.Any(), userID, orgID).Return(nil) return fields{ queries: queries, @@ -1386,7 +1386,7 @@ func Test_userNotifierLegacy_reduceOTPEmailChallenged(t *testing.T) { } codeAlg, code := cryptoValue(t, ctrl, testCode) expectTemplateWithNotifyUserQueries(queries, givenTemplate) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any()).Return(&query.Session{}, nil) + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil) commands.EXPECT().OTPEmailSent(gomock.Any(), userID, orgID).Return(nil) return fields{ queries: queries, @@ -1445,7 +1445,7 @@ func Test_userNotifierLegacy_reduceOTPSMSChallenged(t *testing.T) { Content: expectContent, } expectTemplateWithNotifyUserQueriesSMS(queries) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any()).Return(&query.Session{}, nil) + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil) commands.EXPECT().OTPSMSSent(gomock.Any(), userID, orgID, &senders.CodeGeneratorInfo{ID: smsProviderID, VerificationID: verificationID}).Return(nil) return fields{ queries: queries, @@ -1481,7 +1481,7 @@ func Test_userNotifierLegacy_reduceOTPSMSChallenged(t *testing.T) { Content: expectContent, } expectTemplateWithNotifyUserQueriesSMS(queries) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any()).Return(&query.Session{}, nil) + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil) queries.EXPECT().SearchInstanceDomains(gomock.Any(), gomock.Any()).Return(&query.InstanceDomains{ Domains: []*query.InstanceDomain{{ Domain: instancePrimaryDomain, diff --git a/internal/notification/handlers/user_notifier_test.go b/internal/notification/handlers/user_notifier_test.go index b57edcc57c8..b7b7ceb4460 100644 --- a/internal/notification/handlers/user_notifier_test.go +++ b/internal/notification/handlers/user_notifier_test.go @@ -980,7 +980,7 @@ func Test_userNotifier_reduceOTPEmailChallenged(t *testing.T) { name: "url with event trigger", test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) { _, code := cryptoValue(t, ctrl, "testCode") - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any()).Return(&query.Session{ + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any(), nil).Return(&query.Session{ ID: sessionID, ResourceOwner: instanceID, UserFactor: query.SessionUserFactor{ @@ -1044,7 +1044,7 @@ func Test_userNotifier_reduceOTPEmailChallenged(t *testing.T) { IsPrimary: true, }}, }, nil) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any()).Return(&query.Session{ + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any(), nil).Return(&query.Session{ ID: sessionID, ResourceOwner: instanceID, UserFactor: query.SessionUserFactor{ @@ -1129,7 +1129,7 @@ func Test_userNotifier_reduceOTPEmailChallenged(t *testing.T) { name: "url template", test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) { _, code := cryptoValue(t, ctrl, "testCode") - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any()).Return(&query.Session{ + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any(), nil).Return(&query.Session{ ID: sessionID, ResourceOwner: instanceID, UserFactor: query.SessionUserFactor{ @@ -1220,7 +1220,7 @@ func Test_userNotifier_reduceOTPSMSChallenged(t *testing.T) { test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) { testCode := "testcode" _, code := cryptoValue(t, ctrl, testCode) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any()).Return(&query.Session{ + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any(), nil).Return(&query.Session{ ID: sessionID, ResourceOwner: instanceID, UserFactor: query.SessionUserFactor{ @@ -1284,7 +1284,7 @@ func Test_userNotifier_reduceOTPSMSChallenged(t *testing.T) { IsPrimary: true, }}, }, nil) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any()).Return(&query.Session{ + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any(), nil).Return(&query.Session{ ID: sessionID, ResourceOwner: instanceID, UserFactor: query.SessionUserFactor{ @@ -1339,7 +1339,7 @@ func Test_userNotifier_reduceOTPSMSChallenged(t *testing.T) { { name: "external code", test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) { - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any()).Return(&query.Session{ + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any(), nil).Return(&query.Session{ ID: sessionID, ResourceOwner: instanceID, UserFactor: query.SessionUserFactor{ diff --git a/internal/query/session.go b/internal/query/session.go index 54afbde064e..d30fe4cda92 100644 --- a/internal/query/session.go +++ b/internal/query/session.go @@ -6,6 +6,7 @@ import ( "errors" "net" "net/http" + "slices" "time" sq "github.com/Masterminds/squirrel" @@ -80,6 +81,39 @@ type SessionsSearchQueries struct { Queries []SearchQuery } +func sessionsCheckPermission(ctx context.Context, sessions *Sessions, permissionCheck domain.PermissionCheck) { + sessions.Sessions = slices.DeleteFunc(sessions.Sessions, + func(session *Session) bool { + return sessionCheckPermission(ctx, session.ResourceOwner, session.Creator, session.UserAgent, session.UserFactor, permissionCheck) != nil + }, + ) +} + +func sessionCheckPermission(ctx context.Context, resourceOwner string, creator string, useragent domain.UserAgent, userFactor SessionUserFactor, permissionCheck domain.PermissionCheck) error { + data := authz.GetCtxData(ctx) + // no permission check necessary if user is creator + if data.UserID == creator { + return nil + } + // no permission check necessary if session belongs to the user + if userFactor.UserID != "" && data.UserID == userFactor.UserID { + return nil + } + // no permission check necessary if session belongs to the same useragent as used + if data.AgentID != "" && useragent.FingerprintID != nil && *useragent.FingerprintID != "" && data.AgentID == *useragent.FingerprintID { + return nil + } + // if session belongs to a user, check for permission on the user resource + if userFactor.ResourceOwner != "" { + if err := permissionCheck(ctx, domain.PermissionSessionRead, userFactor.ResourceOwner, userFactor.UserID); err != nil { + return err + } + return nil + } + // default, check for permission on instance + return permissionCheck(ctx, domain.PermissionSessionRead, resourceOwner, "") +} + func (q *SessionsSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder { query = q.SearchRequest.toQuery(query) for _, q := range q.Queries { @@ -195,7 +229,24 @@ var ( } ) -func (q *Queries) SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string) (session *Session, err error) { +func (q *Queries) SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string, permissionCheck domain.PermissionCheck) (session *Session, err error) { + session, tokenID, err := q.sessionByID(ctx, shouldTriggerBulk, id) + if err != nil { + return nil, err + } + if sessionToken == "" { + if err := sessionCheckPermission(ctx, session.ResourceOwner, session.Creator, session.UserAgent, session.UserFactor, permissionCheck); err != nil { + return nil, err + } + return session, nil + } + if err := q.sessionTokenVerifier(ctx, sessionToken, session.ID, tokenID); err != nil { + return nil, zerrors.ThrowPermissionDenied(nil, "QUERY-dsfr3", "Errors.PermissionDenied") + } + return session, nil +} + +func (q *Queries) sessionByID(ctx context.Context, shouldTriggerBulk bool, id string) (session *Session, tokenID string, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -214,27 +265,31 @@ func (q *Queries) SessionByID(ctx context.Context, shouldTriggerBulk bool, id, s }, ).ToSql() if err != nil { - return nil, zerrors.ThrowInternal(err, "QUERY-dn9JW", "Errors.Query.SQLStatement") + return nil, "", zerrors.ThrowInternal(err, "QUERY-dn9JW", "Errors.Query.SQLStatement") } - var tokenID string err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { session, tokenID, err = scan(row) return err }, stmt, args...) if err != nil { - return nil, err + return nil, "", err } - if sessionToken == "" { - return session, nil + return session, tokenID, nil +} + +func (q *Queries) SearchSessions(ctx context.Context, queries *SessionsSearchQueries, permissionCheck domain.PermissionCheck) (*Sessions, error) { + sessions, err := q.searchSessions(ctx, queries) + if err != nil { + return nil, err } - if err := q.sessionTokenVerifier(ctx, sessionToken, session.ID, tokenID); err != nil { - return nil, zerrors.ThrowPermissionDenied(nil, "QUERY-dsfr3", "Errors.PermissionDenied") + if permissionCheck != nil { + sessionsCheckPermission(ctx, sessions, permissionCheck) } - return session, nil + return sessions, nil } -func (q *Queries) SearchSessions(ctx context.Context, queries *SessionsSearchQueries) (sessions *Sessions, err error) { +func (q *Queries) searchSessions(ctx context.Context, queries *SessionsSearchQueries) (sessions *Sessions, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -272,6 +327,10 @@ func NewSessionCreatorSearchQuery(creator string) (SearchQuery, error) { return NewTextQuery(SessionColumnCreator, creator, TextEquals) } +func NewSessionUserAgentFingerprintIDSearchQuery(fingerprintID string) (SearchQuery, error) { + return NewTextQuery(SessionColumnUserAgentFingerprintID, fingerprintID, TextEquals) +} + func NewUserIDSearchQuery(id string) (SearchQuery, error) { return NewTextQuery(SessionColumnUserID, id, TextEquals) } @@ -415,6 +474,10 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui SessionColumnOTPSMSCheckedAt.identifier(), SessionColumnOTPEmailCheckedAt.identifier(), SessionColumnMetadata.identifier(), + SessionColumnUserAgentFingerprintID.identifier(), + SessionColumnUserAgentIP.identifier(), + SessionColumnUserAgentDescription.identifier(), + SessionColumnUserAgentHeader.identifier(), SessionColumnExpiration.identifier(), countColumn.identifier(), ).From(sessionsTable.identifier()). @@ -441,6 +504,8 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui otpSMSCheckedAt sql.NullTime otpEmailCheckedAt sql.NullTime metadata database.Map[[]byte] + userAgentIP sql.NullString + userAgentHeader database.Map[[]string] expiration sql.NullTime ) @@ -465,6 +530,10 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui &otpSMSCheckedAt, &otpEmailCheckedAt, &metadata, + &session.UserAgent.FingerprintID, + &userAgentIP, + &session.UserAgent.Description, + &userAgentHeader, &expiration, &sessions.Count, ) @@ -485,6 +554,10 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui session.OTPSMSFactor.OTPCheckedAt = otpSMSCheckedAt.Time session.OTPEmailFactor.OTPCheckedAt = otpEmailCheckedAt.Time session.Metadata = metadata + session.UserAgent.Header = http.Header(userAgentHeader) + if userAgentIP.Valid { + session.UserAgent.IP = net.ParseIP(userAgentIP.String) + } session.Expiration = expiration.Time sessions.Sessions = append(sessions.Sessions, session) diff --git a/internal/query/sessions_test.go b/internal/query/sessions_test.go index c7929a98a8b..4109969262c 100644 --- a/internal/query/sessions_test.go +++ b/internal/query/sessions_test.go @@ -15,6 +15,7 @@ import ( "github.com/muhlemmer/gu" "github.com/stretchr/testify/require" + "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -71,6 +72,10 @@ var ( ` projections.sessions8.otp_sms_checked_at,` + ` projections.sessions8.otp_email_checked_at,` + ` projections.sessions8.metadata,` + + ` projections.sessions8.user_agent_fingerprint_id,` + + ` projections.sessions8.user_agent_ip,` + + ` projections.sessions8.user_agent_description,` + + ` projections.sessions8.user_agent_header,` + ` projections.sessions8.expiration,` + ` COUNT(*) OVER ()` + ` FROM projections.sessions8` + @@ -129,6 +134,10 @@ var ( "otp_sms_checked_at", "otp_email_checked_at", "metadata", + "user_agent_fingerprint_id", + "user_agent_ip", + "user_agent_description", + "user_agent_header", "expiration", "count", } @@ -186,6 +195,10 @@ func Test_SessionsPrepare(t *testing.T) { testNow, testNow, []byte(`{"key": "dmFsdWU="}`), + "fingerPrintID", + "1.2.3.4", + "agentDescription", + []byte(`{"foo":["foo","bar"]}`), testNow, }, }, @@ -233,6 +246,12 @@ func Test_SessionsPrepare(t *testing.T) { Metadata: map[string][]byte{ "key": []byte("value"), }, + UserAgent: domain.UserAgent{ + FingerprintID: gu.Ptr("fingerPrintID"), + IP: net.IPv4(1, 2, 3, 4), + Description: gu.Ptr("agentDescription"), + Header: http.Header{"foo": []string{"foo", "bar"}}, + }, Expiration: testNow, }, }, @@ -267,6 +286,10 @@ func Test_SessionsPrepare(t *testing.T) { testNow, testNow, []byte(`{"key": "dmFsdWU="}`), + "fingerPrintID", + "1.2.3.4", + "agentDescription", + []byte(`{"foo":["foo","bar"]}`), testNow, }, { @@ -290,6 +313,10 @@ func Test_SessionsPrepare(t *testing.T) { testNow, testNow, []byte(`{"key": "dmFsdWU="}`), + "fingerPrintID", + "1.2.3.4", + "agentDescription", + []byte(`{"foo":["foo","bar"]}`), testNow, }, }, @@ -337,6 +364,12 @@ func Test_SessionsPrepare(t *testing.T) { Metadata: map[string][]byte{ "key": []byte("value"), }, + UserAgent: domain.UserAgent{ + FingerprintID: gu.Ptr("fingerPrintID"), + IP: net.IPv4(1, 2, 3, 4), + Description: gu.Ptr("agentDescription"), + Header: http.Header{"foo": []string{"foo", "bar"}}, + }, Expiration: testNow, }, { @@ -376,6 +409,12 @@ func Test_SessionsPrepare(t *testing.T) { Metadata: map[string][]byte{ "key": []byte("value"), }, + UserAgent: domain.UserAgent{ + FingerprintID: gu.Ptr("fingerPrintID"), + IP: net.IPv4(1, 2, 3, 4), + Description: gu.Ptr("agentDescription"), + Header: http.Header{"foo": []string{"foo", "bar"}}, + }, Expiration: testNow, }, }, @@ -553,3 +592,157 @@ func prepareSessionQueryTesting(t *testing.T, token string) func(context.Context } } } + +func Test_sessionCheckPermission(t *testing.T) { + type args struct { + ctx context.Context + resourceOwner string + creator string + useragent domain.UserAgent + userFactor SessionUserFactor + permissionCheck domain.PermissionCheck + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "permission check, no user in context", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "", ""), + resourceOwner: "instance", + creator: "creator", + permissionCheck: expectedFailedPermissionCheck("instance", ""), + }, + wantErr: true, + }, + { + name: "permission check, factor, no user in context", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "", ""), + resourceOwner: "instance", + creator: "creator", + userFactor: SessionUserFactor{ResourceOwner: "resourceowner", UserID: "user"}, + permissionCheck: expectedFailedPermissionCheck("resourceowner", "user"), + }, + wantErr: true, + }, + { + name: "no permission check, creator", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "user", "agent"), + resourceOwner: "instance", + creator: "user", + }, + wantErr: false, + }, + { + name: "no permission check, same user", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "user", "agent"), + resourceOwner: "instance", + creator: "creator", + userFactor: SessionUserFactor{UserID: "user"}, + }, + wantErr: false, + }, + { + name: "no permission check, same useragent", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "user1", "agent"), + resourceOwner: "instance", + creator: "creator", + userFactor: SessionUserFactor{UserID: "user2"}, + useragent: domain.UserAgent{ + FingerprintID: gu.Ptr("agent"), + }, + }, + wantErr: false, + }, + { + name: "permission check, factor", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "user", "agent"), + resourceOwner: "instance", + creator: "not-user", + useragent: domain.UserAgent{ + FingerprintID: gu.Ptr("not-agent"), + }, + userFactor: SessionUserFactor{UserID: "user2", ResourceOwner: "resourceowner2"}, + permissionCheck: expectedSuccessfulPermissionCheck("resourceowner2", "user2"), + }, + wantErr: false, + }, + { + name: "permission check, factor, error", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "user", "agent"), + resourceOwner: "instance", + creator: "not-user", + useragent: domain.UserAgent{ + FingerprintID: gu.Ptr("not-agent"), + }, + userFactor: SessionUserFactor{UserID: "user2", ResourceOwner: "resourceowner2"}, + permissionCheck: expectedFailedPermissionCheck("resourceowner2", "user2"), + }, + wantErr: true, + }, + { + name: "permission check", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "user", "agent"), + resourceOwner: "instance", + creator: "not-user", + useragent: domain.UserAgent{ + FingerprintID: gu.Ptr("not-agent"), + }, + userFactor: SessionUserFactor{}, + permissionCheck: expectedSuccessfulPermissionCheck("instance", ""), + }, + wantErr: false, + }, + { + name: "permission check, error", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "user", "agent"), + resourceOwner: "instance", + creator: "not-user", + useragent: domain.UserAgent{ + FingerprintID: gu.Ptr("not-agent"), + }, + userFactor: SessionUserFactor{}, + permissionCheck: expectedFailedPermissionCheck("instance", ""), + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := sessionCheckPermission(tt.args.ctx, tt.args.resourceOwner, tt.args.creator, tt.args.useragent, tt.args.userFactor, tt.args.permissionCheck) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} + +func expectedSuccessfulPermissionCheck(resourceOwner, userID string) func(ctx context.Context, permission, orgID, resourceID string) (err error) { + return func(ctx context.Context, permission, orgID, resourceID string) (err error) { + if orgID == resourceOwner && resourceID == userID { + return nil + } + return fmt.Errorf("permission check failed: %s %s", orgID, resourceID) + } +} + +func expectedFailedPermissionCheck(resourceOwner, userID string) func(ctx context.Context, permission, orgID, resourceID string) (err error) { + return func(ctx context.Context, permission, orgID, resourceID string) (err error) { + if orgID == resourceOwner && resourceID == userID { + return fmt.Errorf("permission check failed: %s %s", orgID, resourceID) + } + return nil + } +} diff --git a/proto/zitadel/session/v2/session.proto b/proto/zitadel/session/v2/session.proto index 2c17d81f994..7ab6b77610f 100644 --- a/proto/zitadel/session/v2/session.proto +++ b/proto/zitadel/session/v2/session.proto @@ -136,6 +136,8 @@ message SearchQuery { IDsQuery ids_query = 1; UserIDQuery user_id_query = 2; CreationDateQuery creation_date_query = 3; + CreatorQuery creator_query = 4; + UserAgentQuery user_agent_query = 5; } } @@ -157,9 +159,33 @@ message CreationDateQuery { ]; } +message CreatorQuery { + // ID of the user who created the session. If empty, the calling user's ID is used. + optional string id = 1 [ + (validate.rules).string = {max_len: 200}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + max_length: 200; + example: "\"69629023906488334\""; + } + ]; +} + +message UserAgentQuery { + // Finger print id of the user agent used for the session. + // Set an empty fingerprint_id to use the user agent from the call. + // If the user agent is not available from the current token, an error will be returned. + optional string fingerprint_id = 1 [ + (validate.rules).string = {max_len: 200}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + max_length: 200; + example: "\"69629023906488334\""; + } + ]; +} + message UserAgent { optional string fingerprint_id = 1; - optional string ip = 2; + optional string ip = 2; optional string description = 3; // A header may have multiple values. @@ -169,7 +195,7 @@ message UserAgent { message HeaderValues { repeated string values = 1; } - map header = 4; + map header = 4; } enum SessionFieldName {