Skip to content

Commit

Permalink
NOISSUE - Use only sign in state in OAuth 2.0 flows (absmach#2270)
Browse files Browse the repository at this point in the history
  • Loading branch information
Andy Chao committed Sep 6, 2024
1 parent f138a9b commit d9f8359
Show file tree
Hide file tree
Showing 11 changed files with 142 additions and 174 deletions.
34 changes: 0 additions & 34 deletions pkg/oauth2/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,11 @@ package oauth2

import (
"context"
"errors"

mfclients "github.com/andychao217/magistrala/pkg/clients"
"golang.org/x/oauth2"
)

// State is the state of the OAuth2 flow.
type State uint8

const (
// SignIn is the state for the sign-in flow.
SignIn State = iota
// SignUp is the state for the sign-up flow.
SignUp
)

func (s State) String() string {
switch s {
case SignIn:
return "signin"
case SignUp:
return "signup"
default:
return "unknown"
}
}

// ToState converts string value to a valid OAuth2 state.
func ToState(state string) (State, error) {
switch state {
case "signin":
return SignIn, nil
case "signup":
return SignUp, nil
}

return State(0), errors.New("invalid state")
}

// Config is the configuration for the OAuth2 provider.
type Config struct {
ClientID string `env:"CLIENT_ID" envDefault:""`
Expand Down
16 changes: 2 additions & 14 deletions users/api/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -543,19 +543,7 @@ func oauth2CallbackHandler(oauth oauth2.Provider, svc users.Service) http.Handle
http.Redirect(w, r, oauth.ErrorURL()+"?error=oauth%20provider%20is%20disabled", http.StatusSeeOther)
return
}
// state is prefixed with signin- or signup- to indicate which flow we should use
var state string
var flow oauth2.State
var err error
if strings.Contains(r.FormValue("state"), "-") {
state = strings.Split(r.FormValue("state"), "-")[1]
flow, err = oauth2.ToState(strings.Split(r.FormValue("state"), "-")[0])
if err != nil {
http.Redirect(w, r, oauth.ErrorURL()+"?error="+err.Error(), http.StatusSeeOther) //nolint:goconst
return
}
}

state := r.FormValue("state")
if state != oauth.State() {
http.Redirect(w, r, oauth.ErrorURL()+"?error=invalid%20state", http.StatusSeeOther)
return
Expand All @@ -574,7 +562,7 @@ func oauth2CallbackHandler(oauth oauth2.Provider, svc users.Service) http.Handle
return
}

jwt, err := svc.OAuthCallback(r.Context(), flow, client)
jwt, err := svc.OAuthCallback(r.Context(), client)
if err != nil {
http.Redirect(w, r, oauth.ErrorURL()+"?error="+err.Error(), http.StatusSeeOther)
return
Expand Down
6 changes: 2 additions & 4 deletions users/api/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (

"github.com/andychao217/magistrala"
mgclients "github.com/andychao217/magistrala/pkg/clients"
mgoauth2 "github.com/andychao217/magistrala/pkg/oauth2"
"github.com/andychao217/magistrala/users"
)

Expand Down Expand Up @@ -399,11 +398,10 @@ func (lm *loggingMiddleware) Identify(ctx context.Context, token string) (id str
return lm.svc.Identify(ctx, token)
}

func (lm *loggingMiddleware) OAuthCallback(ctx context.Context, state mgoauth2.State, client mgclients.Client) (token *magistrala.Token, err error) {
func (lm *loggingMiddleware) OAuthCallback(ctx context.Context, client mgclients.Client) (token *magistrala.Token, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("state", state.String()),
slog.String("user_id", client.ID),
}
if err != nil {
Expand All @@ -413,7 +411,7 @@ func (lm *loggingMiddleware) OAuthCallback(ctx context.Context, state mgoauth2.S
}
lm.logger.Info("OAuth callback completed successfully", args...)
}(time.Now())
return lm.svc.OAuthCallback(ctx, state, client)
return lm.svc.OAuthCallback(ctx, client)
}

// DeleteClient logs the delete_client request. It logs the client id and token and the time it took to complete the request.
Expand Down
10 changes: 4 additions & 6 deletions users/api/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

"github.com/andychao217/magistrala"
mgclients "github.com/andychao217/magistrala/pkg/clients"
mgoauth2 "github.com/andychao217/magistrala/pkg/oauth2"
"github.com/andychao217/magistrala/users"
"github.com/go-kit/kit/metrics"
)
Expand Down Expand Up @@ -193,13 +192,12 @@ func (ms *metricsMiddleware) Identify(ctx context.Context, token string) (string
return ms.svc.Identify(ctx, token)
}

func (ms *metricsMiddleware) OAuthCallback(ctx context.Context, state mgoauth2.State, client mgclients.Client) (*magistrala.Token, error) {
method := "oauth_callback_" + state.String()
func (ms *metricsMiddleware) OAuthCallback(ctx context.Context, client mgclients.Client) (*magistrala.Token, error) {
defer func(begin time.Time) {
ms.counter.With("method", method).Add(1)
ms.latency.With("method", method).Observe(time.Since(begin).Seconds())
ms.counter.With("method", "oauth_callback").Add(1)
ms.latency.With("method", "oauth_callback").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.OAuthCallback(ctx, state, client)
return ms.svc.OAuthCallback(ctx, client)
}

// DeleteClient instruments DeleteClient method with metrics.
Expand Down
3 changes: 1 addition & 2 deletions users/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/andychao217/magistrala"
"github.com/andychao217/magistrala/pkg/clients"
mgoauth2 "github.com/andychao217/magistrala/pkg/oauth2"
)

// Service specifies an API that must be fullfiled by the domain service
Expand Down Expand Up @@ -80,5 +79,5 @@ type Service interface {

// OAuthCallback handles the callback from any supported OAuth provider.
// It processes the OAuth tokens and either signs in or signs up the user based on the provided state.
OAuthCallback(ctx context.Context, state mgoauth2.State, client clients.Client) (*magistrala.Token, error)
OAuthCallback(ctx context.Context, client clients.Client) (*magistrala.Token, error)
}
2 changes: 0 additions & 2 deletions users/events/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,14 +416,12 @@ func (spre sendPasswordResetEvent) Encode() (map[string]interface{}, error) {
}

type oauthCallbackEvent struct {
state string
clientID string
}

func (oce oauthCallbackEvent) Encode() (map[string]interface{}, error) {
return map[string]interface{}{
"operation": oauthCallback,
"state": oce.state,
"client_id": oce.clientID,
}, nil
}
Expand Down
6 changes: 2 additions & 4 deletions users/events/streams.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
mgclients "github.com/andychao217/magistrala/pkg/clients"
"github.com/andychao217/magistrala/pkg/events"
"github.com/andychao217/magistrala/pkg/events/store"
mgoauth2 "github.com/andychao217/magistrala/pkg/oauth2"
"github.com/andychao217/magistrala/users"
)

Expand Down Expand Up @@ -297,14 +296,13 @@ func (es *eventStore) SendPasswordReset(ctx context.Context, host, email, user,
return es.Publish(ctx, event)
}

func (es *eventStore) OAuthCallback(ctx context.Context, state mgoauth2.State, client mgclients.Client) (*magistrala.Token, error) {
token, err := es.svc.OAuthCallback(ctx, state, client)
func (es *eventStore) OAuthCallback(ctx context.Context, client mgclients.Client) (*magistrala.Token, error) {
token, err := es.svc.OAuthCallback(ctx, client)
if err != nil {
return token, err
}

event := oauthCallbackEvent{
state: state.String(),
clientID: client.ID,
}

Expand Down
20 changes: 9 additions & 11 deletions users/mocks/service.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

49 changes: 21 additions & 28 deletions users/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@ import (
"github.com/andychao217/magistrala/pkg/errors"
repoerr "github.com/andychao217/magistrala/pkg/errors/repository"
svcerr "github.com/andychao217/magistrala/pkg/errors/service"
mgoauth2 "github.com/andychao217/magistrala/pkg/oauth2"
"github.com/andychao217/magistrala/users/postgres"
"github.com/go-redis/redis/v8"
"golang.org/x/sync/errgroup"
)

var (
errIssueToken = errors.New("failed to issue token")
errUserNotSignedUp = errors.New("user not signed up")
errFailedPermissionsList = errors.New("failed to list permissions")
errRecoveryToken = errors.New("failed to generate password recovery token")
errLoginDisableUser = errors.New("failed to login in disabled user")
Expand Down Expand Up @@ -961,37 +959,32 @@ func (svc *service) authorize(ctx context.Context, subjType, subjKind, subj, per
return res.GetId(), nil
}

func (svc service) OAuthCallback(ctx context.Context, state mgoauth2.State, client mgclients.Client) (*magistrala.Token, error) {
switch state {
case mgoauth2.SignIn:
rclient, err := svc.clients.RetrieveByIdentity(ctx, client.Credentials.Identity)
if err != nil {
if errors.Contains(err, repoerr.ErrNotFound) {
return &magistrala.Token{}, errors.Wrap(svcerr.ErrNotFound, errUserNotSignedUp)
}
return &magistrala.Token{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
claims := &magistrala.IssueReq{
UserId: rclient.ID,
Type: uint32(auth.AccessKey),
}
return svc.auth.Issue(ctx, claims)
case mgoauth2.SignUp:
rclient, err := svc.RegisterClient(ctx, "", client)
if err != nil {
if errors.Contains(err, repoerr.ErrConflict) {
return &magistrala.Token{}, errors.Wrap(svcerr.ErrConflict, errors.New("user already exists"))
func (svc service) OAuthCallback(ctx context.Context, client mgclients.Client) (*magistrala.Token, error) {
rclient, err := svc.clients.RetrieveByIdentity(ctx, client.Credentials.Identity)
if err != nil {
switch errors.Contains(err, repoerr.ErrNotFound) {
case true:
rclient, err = svc.RegisterClient(ctx, "", client)
if err != nil {
return &magistrala.Token{}, err
}
default:
return &magistrala.Token{}, err
}
claims := &magistrala.IssueReq{
UserId: rclient.ID,
Type: uint32(auth.AccessKey),
}

if _, err = svc.authorize(ctx, auth.UserType, auth.UsersKind, rclient.ID, auth.MembershipPermission, auth.PlatformType, auth.MagistralaObject); err != nil {
if err := svc.addClientPolicy(ctx, rclient.ID, rclient.Role); err != nil {
return &magistrala.Token{}, err
}
return svc.auth.Issue(ctx, claims)
default:
return &magistrala.Token{}, fmt.Errorf("unknown state %s", state)
}

claims := &magistrala.IssueReq{
UserId: rclient.ID,
Type: uint32(auth.AccessKey),
}

return svc.auth.Issue(ctx, claims)
}

func (svc service) Identify(ctx context.Context, token string) (string, error) {
Expand Down
Loading

0 comments on commit d9f8359

Please sign in to comment.