diff --git a/internal/app/auth_server.go b/internal/app/auth_server.go index 7d80907c..5f96d9a5 100644 --- a/internal/app/auth_server.go +++ b/internal/app/auth_server.go @@ -4,11 +4,11 @@ import ( "context" "crypto/sha256" "encoding/json" + "errors" "net/http" "net/url" "os" "strconv" - "strings" "github.com/caos/oidc/pkg/op" "github.com/golang/gddo/httputil/header" @@ -17,6 +17,7 @@ import ( "github.com/reearth/reearth-backend/internal/usecase/interactor" "github.com/reearth/reearth-backend/internal/usecase/interfaces" "github.com/reearth/reearth-backend/pkg/log" + "github.com/reearth/reearth-backend/pkg/user" ) const ( @@ -29,13 +30,9 @@ const ( func authEndPoints(ctx context.Context, e *echo.Echo, r *echo.Group, cfg *ServerConfig) { userUsecase := interactor.NewUser(cfg.Repos, cfg.Gateways, cfg.Config.SignupSecret, cfg.Config.Host_Web) - d := cfg.Config.AuthSrv.Domain - if d == "" { - d = cfg.Config.Host - } - domain, err := url.Parse(d) - if err != nil { - log.Panicf("auth: not valid auth domain: %s", d) + domain := cfg.Config.AuthServeDomainURL() + if domain == nil || domain.String() == "" { + log.Panicf("auth: not valid auth domain: %s", domain) } domain.Path = "/" @@ -95,7 +92,7 @@ func authEndPoints(ctx context.Context, e *echo.Echo, r *echo.Group, cfg *Server } // Actual login endpoint - r.POST(loginEndpoint, login(ctx, cfg, storage, userUsecase)) + r.POST(loginEndpoint, login(ctx, domain, storage, userUsecase)) r.GET(logoutEndpoint, logout()) @@ -191,44 +188,68 @@ type loginForm struct { AuthRequestID string `json:"id" form:"id"` } -func login(ctx context.Context, cfg *ServerConfig, storage op.Storage, userUsecase interfaces.User) func(ctx echo.Context) error { +func login(ctx context.Context, url *url.URL, storage op.Storage, userUsecase interfaces.User) func(ctx echo.Context) error { return func(ec echo.Context) error { request := new(loginForm) err := ec.Bind(request) if err != nil { log.Errorln("auth: filed to parse login request") - return ec.Redirect(http.StatusFound, redirectURL(ec.Request().Referer(), !cfg.Debug, "", "Bad request!")) + return ec.Redirect( + http.StatusFound, + redirectURL(url, "/login", "", "Bad request!"), + ) } - authRequest, err := storage.AuthRequestByID(ctx, request.AuthRequestID) - if err != nil { + if _, err := storage.AuthRequestByID(ctx, request.AuthRequestID); err != nil { log.Errorf("auth: filed to parse login request: %s\n", err) - return ec.Redirect(http.StatusFound, redirectURL(ec.Request().Referer(), !cfg.Debug, "", "Bad request!")) + return ec.Redirect( + http.StatusFound, + redirectURL(url, "/login", "", "Bad request!"), + ) } if len(request.Email) == 0 || len(request.Password) == 0 { log.Errorln("auth: one of credentials are not provided") - return ec.Redirect(http.StatusFound, redirectURL(authRequest.GetRedirectURI(), !cfg.Debug, request.AuthRequestID, "Bad request!")) + return ec.Redirect( + http.StatusFound, + redirectURL(url, "/login", request.AuthRequestID, "Bad request!"), + ) } // check user credentials from db - user, err := userUsecase.GetUserByCredentials(ctx, interfaces.GetUserByCredentials{ + u, err := userUsecase.GetUserByCredentials(ctx, interfaces.GetUserByCredentials{ Email: request.Email, Password: request.Password, }) + var auth *user.Auth + if err == nil { + auth = u.GetAuthByProvider(authProvider) + if auth == nil { + err = errors.New("The account is not signed up with Re:Earth") + } + } if err != nil { log.Errorf("auth: wrong credentials: %s\n", err) - return ec.Redirect(http.StatusFound, redirectURL(authRequest.GetRedirectURI(), !cfg.Debug, request.AuthRequestID, "Login failed; Invalid user ID or password.")) + return ec.Redirect( + http.StatusFound, + redirectURL(url, "/login", request.AuthRequestID, "Login failed; Invalid user ID or password."), + ) } // Complete the auth request && set the subject - err = storage.(*interactor.AuthStorage).CompleteAuthRequest(ctx, request.AuthRequestID, user.GetAuthByProvider(authProvider).Sub) + err = storage.(*interactor.AuthStorage).CompleteAuthRequest(ctx, request.AuthRequestID, auth.Sub) if err != nil { log.Errorf("auth: failed to complete the auth request: %s\n", err) - return ec.Redirect(http.StatusFound, redirectURL(authRequest.GetRedirectURI(), !cfg.Debug, request.AuthRequestID, "Bad request!")) + return ec.Redirect( + http.StatusFound, + redirectURL(url, "/login", request.AuthRequestID, "Bad request!"), + ) } - return ec.Redirect(http.StatusFound, "/authorize/callback?id="+request.AuthRequestID) + return ec.Redirect( + http.StatusFound, + redirectURL(url, "/authorize/callback", request.AuthRequestID, ""), + ) } } @@ -239,25 +260,27 @@ func logout() func(ec echo.Context) error { } } -func redirectURL(domain string, secure bool, requestID string, error string) string { - domain = strings.TrimPrefix(domain, "http://") - domain = strings.TrimPrefix(domain, "https://") - - schema := "http" - if secure { - schema = "https" +func redirectURL(u *url.URL, p string, requestID, err string) string { + v := cloneURL(u) + if p == "" { + p = "/login" } - - u := url.URL{ - Scheme: schema, - Host: domain, - Path: "login", - } - + v.Path = p queryValues := u.Query() queryValues.Set("id", requestID) - queryValues.Set("error", error) - u.RawQuery = queryValues.Encode() + if err != "" { + queryValues.Set("error", err) + } + v.RawQuery = queryValues.Encode() + return v.String() +} - return u.String() +func cloneURL(u *url.URL) *url.URL { + return &url.URL{ + Scheme: u.Scheme, + Opaque: u.Opaque, + User: u.User, + Host: u.Host, + Path: u.Path, + } } diff --git a/internal/app/config.go b/internal/app/config.go index 354f1cc4..56bb8a8a 100644 --- a/internal/app/config.go +++ b/internal/app/config.go @@ -143,6 +143,17 @@ func ReadConfig(debug bool) (*Config, error) { if debug { c.Dev = true } + c.Host = addHTTPScheme(c.Host) + if c.Host_Web == "" { + c.Host_Web = c.Host + } else { + c.Host_Web = addHTTPScheme(c.Host_Web) + } + if c.AuthSrv.Domain == "" { + c.AuthSrv.Domain = c.Host + } else { + c.AuthSrv.Domain = addHTTPScheme(c.AuthSrv.Domain) + } if c.Host_Web == "" { c.Host_Web = c.Host } @@ -242,3 +253,37 @@ func (ipd *AuthConfigs) Decode(value string) error { *ipd = providers return nil } + +func (c Config) HostURL() *url.URL { + u, err := url.Parse(c.Host) + if err != nil { + u = nil + } + return u +} + +func (c Config) HostWebURL() *url.URL { + u, err := url.Parse(c.Host_Web) + if err != nil { + u = nil + } + return u +} + +func (c Config) AuthServeDomainURL() *url.URL { + u, err := url.Parse(c.AuthSrv.Domain) + if err != nil { + u = nil + } + return u +} + +func addHTTPScheme(host string) string { + if host == "" { + return "" + } + if !strings.HasPrefix(host, "https://") && !strings.HasPrefix(host, "http://") { + host = "http://" + host + } + return host +} diff --git a/internal/app/config_test.go b/internal/app/config_test.go index d2405902..40b5a2fa 100644 --- a/internal/app/config_test.go +++ b/internal/app/config_test.go @@ -32,3 +32,9 @@ func TestReadConfig(t *testing.T) { assert.Equal(t, "hoge", cfg.Auth_ISS) assert.Equal(t, "foo", cfg.Auth_AUD) } + +func Test_AddHTTPScheme(t *testing.T) { + assert.Equal(t, "http://a", addHTTPScheme("a")) + assert.Equal(t, "http://a", addHTTPScheme("http://a")) + assert.Equal(t, "https://a", addHTTPScheme("https://a")) +}