diff --git a/Godeps b/Godeps index 6a77c0f7a..b6bc08dd5 100644 --- a/Godeps +++ b/Godeps @@ -1,5 +1,9 @@ -github.com/BurntSushi/toml 3883ac1ce943878302255f538fce319d23226223 -github.com/bitly/go-simplejson 3378bdcb5cebedcbf8b5750edee28010f128fe24 -github.com/mreiferson/go-options ee94b57f2fbf116075426f853e5abbcdfeca8b3d -github.com/bmizerany/assert e17e99893cb6509f428e1728281c2ad60a6b31e3 -gopkg.in/fsnotify.v1 v1.2.0 +github.com/18F/hmacauth 1.0.1 +github.com/BurntSushi/toml 3883ac1ce943878302255f538fce319d23226223 +github.com/bitly/go-simplejson 3378bdcb5cebedcbf8b5750edee28010f128fe24 +github.com/mreiferson/go-options ee94b57f2fbf116075426f853e5abbcdfeca8b3d +github.com/bmizerany/assert e17e99893cb6509f428e1728281c2ad60a6b31e3 +gopkg.in/fsnotify.v1 v1.2.0 +golang.org/x/oauth2 397fe7649477ff2e8ced8fc0b2696f781e53745a +golang.org/x/oauth2/google 397fe7649477ff2e8ced8fc0b2696f781e53745a +google.golang.org/api/admin/directory/v1 a5c3e2a4792aff40e59840d9ecdff0542a202a80 diff --git a/README.md b/README.md index 370f10016..da71e1be1 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ oauth2_proxy (This project was renamed from Google Auth Proxy - May 2015) -A reverse proxy that provides authentication using Providers (Google, Github, and others) +A reverse proxy and static file server that provides authentication using Providers (Google, Github, and others) to validate accounts by email, domain or group. [![Build Status](https://secure.travis-ci.org/bitly/oauth2_proxy.png?branch=master)](http://travis-ci.org/bitly/oauth2_proxy) @@ -40,18 +40,42 @@ The provider can be selected using the `provider` configuration value. For Google, the registration steps are: 1. Create a new project: https://console.developers.google.com/project -2. Under "APIs & Auth", choose "Credentials" -3. Now, choose "Create new Client ID" - * The Application Type should be **Web application** - * Enter your domain in the Authorized Javascript Origins `https://internal.yourcompany.com` - * Enter the correct Authorized Redirect URL `https://internal.yourcompany.com/oauth2/callback` - * NOTE: `oauth2_proxy` will _only_ callback on the path `/oauth2/callback` -4. Under "APIs & Auth" choose "Consent Screen" - * Fill in the necessary fields and Save (this is _required_) -5. Take note of the **Client ID** and **Client Secret** +2. Choose the new project from the top right project dropdown (only if another project is selected) +3. In the project Dashboard center pane, choose **"Enable and manage APIs"** +4. In the left Nav pane, choose **"Credentials"** +5. In the center pane, choose **"OAuth consent screen"** tab. Fill in **"Product name shown to users"** and hit save. +6. In the center pane, choose **"Credentials"** tab. + * Open the **"New credentials"** drop down + * Choose **"OAuth client ID"** + * Choose **"Web application"** + * Application name is freeform, choose something appropriate + * Authorized JavaScript origins is your domain ex: `https://internal.yourcompany.com` + * Authorized redirect URIs is the location of oath2/callback ex: `https://internal.yourcompany.com/oauth2/callback` + * Choose **"Create"** +4. Take note of the **Client ID** and **Client Secret** It's recommended to refresh sessions on a short interval (1h) with `cookie-refresh` setting which validates that the account is still authorized. +#### Restrict auth to specific Google groups on your domain. (optional) + +1. Create a service account: https://developers.google.com/identity/protocols/OAuth2ServiceAccount and make sure to download the json file. +2. Make note of the Client ID for a future step. +3. Under "APIs & Auth", choose APIs. +4. Click on Admin SDK and then Enable API. +5. Follow the steps on https://developers.google.com/admin-sdk/directory/v1/guides/delegation#delegate_domain-wide_authority_to_your_service_account and give the client id from step 2 the following oauth scopes: +``` +https://www.googleapis.com/auth/admin.directory.group.readonly +https://www.googleapis.com/auth/admin.directory.user.readonly +``` +6. Follow the steps on https://support.google.com/a/answer/60757 to enable Admin API access. +7. Create or choose an existing administrative email address on the Gmail domain to assign to the ```google-admin-email``` flag. This email will be impersonated by this client to make calls to the Admin SDK. See the note on the link from step 5 for the reason why. +8. Create or choose an existing email group and set that email to the ```google-group``` flag. You can pass multiple instances of this flag with different groups +and the user will be checked against all the provided groups. +9. Lock down the permissions on the json file downloaded from step 1 so only oauth2_proxy is able to read the file and set the path to the file in the ```google-service-account-json``` flag. +10. Restart oauth2_proxy. + +Note: The user is checked against the group members list on initial authentication and every time the token is refreshed ( about once an hour ). + ### GitHub Auth Provider 1. Create a new project: https://github.com/settings/developers @@ -94,14 +118,16 @@ An example [oauth2_proxy.cfg](contrib/oauth2_proxy.cfg.example) config file is i ``` Usage of oauth2_proxy: + -approval-prompt="force": Oauth approval_prompt -authenticated-emails-file="": authenticate against emails via file (one per line) + -basic-auth-password="": the password to set when passing the HTTP Basic Auth header -client-id="": the OAuth Client ID: ie: "123456.apps.googleusercontent.com" -client-secret="": the OAuth Client Secret -config="": path to config file -cookie-domain="": an optional cookie domain to force cookies to (ie: .yourcompany.com)* -cookie-expire=168h0m0s: expire timeframe for cookie -cookie-httponly=true: set HttpOnly cookie flag - -cookie-key="_oauth2_proxy": the name of the cookie that the oauth_proxy creates + -cookie-name="_oauth2_proxy": the name of the cookie that the oauth_proxy creates -cookie-refresh=0: refresh the cookie after this duration; 0 to disable -cookie-secret="": the seed string for secure cookies -cookie-secure=true: set secure (HTTPS) cookie flag @@ -110,13 +136,15 @@ Usage of oauth2_proxy: -email-domain=: authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email -github-org="": restrict logins to members of this organisation -github-team="": restrict logins to members of this team + -google-admin-email="": the google admin to impersonate for api calls + -google-group=: restrict logins to members of this google group (may be given multiple times). + -google-service-account-json="": the path to the service account json credentials -htpasswd-file="": additionally authenticate against a htpasswd file. Entries must be created with "htpasswd -s" for SHA encryption -http-address="127.0.0.1:4180": [http://]: or unix:// to listen on for HTTP clients -https-address=":443": : to listen on for HTTPS clients -login-url="": Authentication endpoint -pass-access-token=false: pass OAuth access_token to upstream via X-Forwarded-Access-Token header -pass-basic-auth=true: pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream - -basic-auth-password="": the password to set when passing the HTTP Basic Auth header -pass-host-header=true: pass the request Host Header to upstream -profile-url="": Profile access endpoint -provider="google": OAuth provider @@ -125,23 +153,32 @@ Usage of oauth2_proxy: -redirect-url="": the OAuth Redirect URL. ie: "https://internalapp.yourcompany.com/oauth2/callback" -request-logging=true: Log requests to stdout -scope="": Oauth scope specification + -signature-key="": GAP-Signature request signature key (algorithm:secretkey) -skip-auth-regex=: bypass authentication for requests path's that match (may be given multiple times) -tls-cert="": path to certificate file -tls-key="": path to private key file - -upstream=: the http url(s) of the upstream endpoint. If multiple, routing is based on path + -upstream=: the http url(s) of the upstream endpoint or file:// paths for static files. Routing is based on the path -validate-url="": Access token validation endpoint -version=false: print version string ``` See below for provider specific options +### Upstreams Configuration + +`oauth2_proxy` supports having multiple upstreams, and has the option to pass requests on to HTTP(S) servers or serve static files from the file system. HTTP and HTTPS upstreams are configured by providing a URL such as `http://127.0.0.1:8080/` for the upstream parameter, that will forward all authenticated requests to be forwarded to the upstream server. If you instead provide `http://127.0.0.1:8080/some/path/` then it will only be requests that start with `/some/path/` which are forwarded to the upstream. + +Static file paths are configured as a file:// URL. `file:///var/www/static/` will serve the files from that directory at `http://[oauth2_proxy url]/var/www/static/`, which may not be what you want. You can provide the path to where the files should be available by adding a fragment to the configured URL. The value of the fragment will then be used to specify which path the files are available at. `file:///var/www/static/#/static/` will ie. make `/var/www/static/` available at `http://[oauth2_proxy url]/static/`. + +Multiple upstreams can either be configured by supplying a comma separated list to the `-upstream` parameter, supplying the parameter multiple times or provinding a list in the [config file](#config-file). When multiple upstreams are used routing to them will be based on the path they are set up with. + ### Environment variables The environment variables `OAUTH2_PROXY_CLIENT_ID`, `OAUTH2_PROXY_CLIENT_SECRET`, `OAUTH2_PROXY_COOKIE_SECRET`, `OAUTH2_PROXY_COOKIE_DOMAIN` and `OAUTH2_PROXY_COOKIE_EXPIRE` can be used in place of the corresponding command-line arguments. ## SSL Configuration -There are two recommended configurations. +There are two recommended configurations. 1) Configure SSL Terminiation with OAuth2 Proxy by providing a `--tls-cert=/path/to/cert.pem` and `--tls-key=/path/to/cert.key`. @@ -171,7 +208,7 @@ Nginx will listen on port `443` and handle SSL connections while proxying to `oa `oauth2_proxy` will then authenticate requests for an upstream application. The external endpoint for this example would be `https://internal.yourcompany.com/`. -An example Nginx config follows. Note the use of `Strict-Transport-Security` header to pin requests to SSL +An example Nginx config follows. Note the use of `Strict-Transport-Security` header to pin requests to SSL via [HSTS](http://en.wikipedia.org/wiki/HTTP_Strict_Transport_Security): ``` @@ -207,7 +244,6 @@ The command line to run `oauth2_proxy` in this configuration would look like thi --client-secret=... ``` - ## Endpoint Documentation OAuth2 Proxy responds directly to the following endpoints. All other endpoints will be proxied upstream when authenticated. The `/oauth2` prefix can be changed with the `--proxy-prefix` config variable. @@ -217,6 +253,25 @@ OAuth2 Proxy responds directly to the following endpoints. All other endpoints w * /oauth2/sign_in - the login page, which also doubles as a sign out page (it clears cookies) * /oauth2/start - a URL that will redirect to start the OAuth cycle * /oauth2/callback - the URL used at the end of the OAuth cycle. The oauth app will be configured with this as the callback url. +* /oauth2/auth - only returns a 202 Accepted response or a 401 Unauthorized response; for use with the [Nginx `auth_request` directive](#nginx-auth-request) + +## Request signatures + +If `signature_key` is defined, proxied requests will be signed with the +`GAP-Signature` header, which is a [Hash-based Message Authentication Code +(HMAC)](https://en.wikipedia.org/wiki/Hash-based_message_authentication_code) +of selected request information and the request body [see `SIGNATURE_HEADERS` +in `oauthproxy.go`](./oauthproxy.go). + +`signature_key` must be of the form `algorithm:secretkey`, (ie: `signature_key = "sha1:secret0"`) + +For more information about HMAC request signature validation, read the +following: + +* [Amazon Web Services: Signing and Authenticating REST + Requests](https://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html) +* [rc3.org: Using HMAC to authenticate Web service + requests](http://rc3.org/2011/12/02/using-hmac-to-authenticate-web-service-requests/) ## Logging Format @@ -226,7 +281,6 @@ OAuth2 Proxy logs requests to stdout in a format similar to Apache Combined Log. - [19/Mar/2015:17:20:19 -0400] GET "/path/" HTTP/1.1 "" ``` - ## Adding a new Provider Follow the examples in the [`providers` package](providers/) to define a new @@ -234,3 +288,29 @@ Follow the examples in the [`providers` package](providers/) to define a new [`providers.New()`](providers/providers.go) to allow `oauth2_proxy` to use the new `Provider`. +## Configuring for use with the Nginx `auth_request` directive + +The [Nginx `auth_request` directive](http://nginx.org/en/docs/http/ngx_http_auth_request_module.html) allows Nginx to authenticate requests via the oauth2_proxy's `/auth` endpoint, which only returns a 202 Accepted response or a 401 Unauthorized response without proxying the request through. For example: + +```nginx +server { + listen 443 ssl spdy; + server_name ...; + include ssl/ssl.conf; + + location = /auth { + internal; + proxy_pass http://127.0.0.1:4180; + } + + location / { + auth_request /auth; + error_page 401 = ...; + + root /path/to/the/site; + default_type text/html; + charset utf-8; + charset_types application/json utf-8; + } +} +``` diff --git a/main.go b/main.go index e75240910..a8d3f1b6f 100644 --- a/main.go +++ b/main.go @@ -20,6 +20,7 @@ func main() { emailDomains := StringArray{} upstreams := StringArray{} skipAuthRegex := StringArray{} + googleGroups := StringArray{} config := flagSet.String("config", "", "path to config file") showVersion := flagSet.Bool("version", false, "print version string") @@ -29,7 +30,7 @@ func main() { flagSet.String("tls-cert", "", "path to certificate file") flagSet.String("tls-key", "", "path to private key file") flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"") - flagSet.Var(&upstreams, "upstream", "the http url(s) of the upstream endpoint. If multiple, routing is based on path") + flagSet.Var(&upstreams, "upstream", "the http url(s) of the upstream endpoint or file:// paths for static files. Routing is based on the path") flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream") flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header") flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header") @@ -39,6 +40,9 @@ func main() { flagSet.Var(&emailDomains, "email-domain", "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email") flagSet.String("github-org", "", "restrict logins to members of this organisation") flagSet.String("github-team", "", "restrict logins to members of this team") + flagSet.Var(&googleGroups, "google-group", "restrict logins to members of this google group (may be given multiple times).") + flagSet.String("google-admin-email", "", "the google admin to impersonate for api calls") + flagSet.String("google-service-account-json", "", "the path to the service account json credentials") flagSet.String("client-id", "", "the OAuth Client ID: ie: \"123456.apps.googleusercontent.com\"") flagSet.String("client-secret", "", "the OAuth Client Secret") flagSet.String("authenticated-emails-file", "", "authenticate against emails via file (one per line)") @@ -62,7 +66,10 @@ func main() { flagSet.String("redeem-url", "", "Token redemption endpoint") flagSet.String("profile-url", "", "Profile access endpoint") flagSet.String("validate-url", "", "Access token validation endpoint") - flagSet.String("scope", "", "Oauth scope specification") + flagSet.String("scope", "", "OAuth scope specification") + flagSet.String("approval-prompt", "force", "OAuth approval_prompt") + + flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)") flagSet.Parse(os.Args[1:]) @@ -90,7 +97,7 @@ func main() { } validator := NewValidator(opts.EmailDomains, opts.AuthenticatedEmailsFile) - oauthproxy := NewOauthProxy(opts, validator) + oauthproxy := NewOAuthProxy(opts, validator) if len(opts.EmailDomains) != 0 && opts.AuthenticatedEmailsFile == "" { if len(opts.EmailDomains) > 1 { diff --git a/oauthproxy.go b/oauthproxy.go index 07c3ec95e..dd69d6a41 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -14,11 +14,27 @@ import ( "strings" "time" + "github.com/18F/hmacauth" "github.com/bitly/oauth2_proxy/cookie" "github.com/bitly/oauth2_proxy/providers" ) -type OauthProxy struct { +const SignatureHeader = "GAP-Signature" + +var SignatureHeaders []string = []string{ + "Content-Length", + "Content-Md5", + "Content-Type", + "Date", + "Authorization", + "X-Forwarded-User", + "X-Forwarded-Email", + "X-Forwarded-Access-Token", + "Cookie", + "Gap-Auth", +} + +type OAuthProxy struct { CookieSeed string CookieName string CookieDomain string @@ -31,10 +47,11 @@ type OauthProxy struct { RobotsPath string PingPath string SignInPath string - OauthStartPath string - OauthCallbackPath string + OAuthStartPath string + OAuthCallbackPath string + AuthOnlyPath string - redirectUrl *url.URL // the url to receive requests at + redirectURL *url.URL // the url to receive requests at provider providers.Provider ProxyPrefix string SignInMessage string @@ -53,10 +70,15 @@ type OauthProxy struct { type UpstreamProxy struct { upstream string handler http.Handler + auth hmacauth.HmacAuth } func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("GAP-Upstream-Address", u.upstream) + if u.auth != nil { + r.Header.Set("GAP-Auth", w.Header().Get("GAP-Auth")) + u.auth.SignRequest(r) + } u.handler.ServeHTTP(w, r) } @@ -82,29 +104,50 @@ func setProxyDirector(proxy *httputil.ReverseProxy) { req.URL.RawQuery = "" } } +func NewFileServer(path string, filesystemPath string) (proxy http.Handler) { + return http.StripPrefix(path, http.FileServer(http.Dir(filesystemPath))) +} -func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { +func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { serveMux := http.NewServeMux() - for _, u := range opts.proxyUrls { + var auth hmacauth.HmacAuth + if sigData := opts.signatureData; sigData != nil { + auth = hmacauth.NewHmacAuth(sigData.hash, []byte(sigData.key), + SignatureHeader, SignatureHeaders) + } + for _, u := range opts.proxyURLs { path := u.Path - u.Path = "" - log.Printf("mapping path %q => upstream %q", path, u) - proxy := NewReverseProxy(u) - if !opts.PassHostHeader { - setProxyUpstreamHostHeader(proxy, u) - } else { - setProxyDirector(proxy) + switch u.Scheme { + case "http", "https": + u.Path = "" + log.Printf("mapping path %q => upstream %q", path, u) + proxy := NewReverseProxy(u) + if !opts.PassHostHeader { + setProxyUpstreamHostHeader(proxy, u) + } else { + setProxyDirector(proxy) + } + serveMux.Handle(path, + &UpstreamProxy{u.Host, proxy, auth}) + case "file": + if u.Fragment != "" { + path = u.Fragment + } + log.Printf("mapping path %q => file system %q", path, u.Path) + proxy := NewFileServer(path, u.Path) + serveMux.Handle(path, &UpstreamProxy{path, proxy, nil}) + default: + panic(fmt.Sprintf("unknown upstream protocol %s", u.Scheme)) } - serveMux.Handle(path, &UpstreamProxy{u.Host, proxy}) } for _, u := range opts.CompiledRegex { log.Printf("compiled skip-auth-regex => %q", u) } - redirectUrl := opts.redirectUrl - redirectUrl.Path = fmt.Sprintf("%s/callback", opts.ProxyPrefix) + redirectURL := opts.redirectURL + redirectURL.Path = fmt.Sprintf("%s/callback", opts.ProxyPrefix) - log.Printf("OauthProxy configured for %s Client ID: %s", opts.provider.Data().ProviderName, opts.ClientID) + log.Printf("OAuthProxy configured for %s Client ID: %s", opts.provider.Data().ProviderName, opts.ClientID) domain := opts.CookieDomain if domain == "" { domain = "" @@ -126,7 +169,7 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { } } - return &OauthProxy{ + return &OAuthProxy{ CookieName: opts.CookieName, CookieSeed: opts.CookieSecret, CookieDomain: opts.CookieDomain, @@ -139,13 +182,14 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { RobotsPath: "/robots.txt", PingPath: "/ping", SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix), - OauthStartPath: fmt.Sprintf("%s/start", opts.ProxyPrefix), - OauthCallbackPath: fmt.Sprintf("%s/callback", opts.ProxyPrefix), + OAuthStartPath: fmt.Sprintf("%s/start", opts.ProxyPrefix), + OAuthCallbackPath: fmt.Sprintf("%s/callback", opts.ProxyPrefix), + AuthOnlyPath: fmt.Sprintf("%s/auth", opts.ProxyPrefix), ProxyPrefix: opts.ProxyPrefix, provider: opts.provider, serveMux: serveMux, - redirectUrl: redirectUrl, + redirectURL: redirectURL, skipAuthRegex: opts.SkipAuthRegex, compiledRegex: opts.CompiledRegex, PassBasicAuth: opts.PassBasicAuth, @@ -156,13 +200,13 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { } } -func (p *OauthProxy) GetRedirectURI(host string) string { +func (p *OAuthProxy) GetRedirectURI(host string) string { // default to the request Host if not set - if p.redirectUrl.Host != "" { - return p.redirectUrl.String() + if p.redirectURL.Host != "" { + return p.redirectURL.String() } var u url.URL - u = *p.redirectUrl + u = *p.redirectURL if u.Scheme == "" { if p.CookieSecure { u.Scheme = "https" @@ -174,16 +218,16 @@ func (p *OauthProxy) GetRedirectURI(host string) string { return u.String() } -func (p *OauthProxy) displayCustomLoginForm() bool { +func (p *OAuthProxy) displayCustomLoginForm() bool { return p.HtpasswdFile != nil && p.DisplayHtpasswdForm } -func (p *OauthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) { +func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) { if code == "" { return nil, errors.New("missing code") } - redirectUri := p.GetRedirectURI(host) - s, err = p.provider.Redeem(redirectUri, code) + redirectURI := p.GetRedirectURI(host) + s, err = p.provider.Redeem(redirectURI, code) if err != nil { return } @@ -194,7 +238,7 @@ func (p *OauthProxy) redeemCode(host, code string) (s *providers.SessionState, e return } -func (p *OauthProxy) MakeCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { +func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { domain := req.Host if h, _, err := net.SplitHostPort(domain); err == nil { domain = h @@ -220,15 +264,15 @@ func (p *OauthProxy) MakeCookie(req *http.Request, value string, expiration time } } -func (p *OauthProxy) ClearCookie(rw http.ResponseWriter, req *http.Request) { +func (p *OAuthProxy) ClearCookie(rw http.ResponseWriter, req *http.Request) { http.SetCookie(rw, p.MakeCookie(req, "", time.Hour*-1, time.Now())) } -func (p *OauthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val string) { +func (p *OAuthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val string) { http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire, time.Now())) } -func (p *OauthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { +func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { var age time.Duration c, err := req.Cookie(p.CookieName) if err != nil { @@ -249,7 +293,7 @@ func (p *OauthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionSt return session, age, nil } -func (p *OauthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error { +func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error { value, err := p.provider.CookieForSession(s, p.CookieCipher) if err != nil { return err @@ -258,30 +302,32 @@ func (p *OauthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *p return nil } -func (p *OauthProxy) RobotsTxt(rw http.ResponseWriter) { +func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) { rw.WriteHeader(http.StatusOK) fmt.Fprintf(rw, "User-agent: *\nDisallow: /") } -func (p *OauthProxy) PingPage(rw http.ResponseWriter) { +func (p *OAuthProxy) PingPage(rw http.ResponseWriter) { rw.WriteHeader(http.StatusOK) fmt.Fprintf(rw, "OK") } -func (p *OauthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) { +func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) { log.Printf("ErrorPage %d %s %s", code, title, message) rw.WriteHeader(code) t := struct { - Title string - Message string + Title string + Message string + ProxyPrefix string }{ - Title: fmt.Sprintf("%d %s", code, title), - Message: message, + Title: fmt.Sprintf("%d %s", code, title), + Message: message, + ProxyPrefix: p.ProxyPrefix, } p.templates.ExecuteTemplate(rw, "error.html", t) } -func (p *OauthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) { +func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) { p.ClearCookie(rw, req) rw.WriteHeader(code) @@ -308,7 +354,7 @@ func (p *OauthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code p.templates.ExecuteTemplate(rw, "sign_in.html", t) } -func (p *OauthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (string, bool) { +func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (string, bool) { if req.Method != "POST" || p.HtpasswdFile == nil { return "", false } @@ -325,7 +371,7 @@ func (p *OauthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (st return "", false } -func (p *OauthProxy) GetRedirect(req *http.Request) (string, error) { +func (p *OAuthProxy) GetRedirect(req *http.Request) (string, error) { err := req.ParseForm() if err != nil { @@ -341,7 +387,7 @@ func (p *OauthProxy) GetRedirect(req *http.Request) (string, error) { return redirect, err } -func (p *OauthProxy) IsWhitelistedPath(path string) (ok bool) { +func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) { for _, u := range p.compiledRegex { ok = u.MatchString(path) if ok { @@ -359,7 +405,7 @@ func getRemoteAddr(req *http.Request) (s string) { return } -func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { +func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { switch path := req.URL.Path; { case path == p.RobotsPath: p.RobotsTxt(rw) @@ -369,16 +415,18 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { p.serveMux.ServeHTTP(rw, req) case path == p.SignInPath: p.SignIn(rw, req) - case path == p.OauthStartPath: - p.OauthStart(rw, req) - case path == p.OauthCallbackPath: - p.OauthCallback(rw, req) + case path == p.OAuthStartPath: + p.OAuthStart(rw, req) + case path == p.OAuthCallbackPath: + p.OAuthCallback(rw, req) + case path == p.AuthOnlyPath: + p.AuthenticateOnly(rw, req) default: p.Proxy(rw, req) } } -func (p *OauthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { +func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { redirect, err := p.GetRedirect(req) if err != nil { p.ErrorPage(rw, 500, "Internal Error", err.Error()) @@ -395,7 +443,7 @@ func (p *OauthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { } } -func (p *OauthProxy) OauthStart(rw http.ResponseWriter, req *http.Request) { +func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { redirect, err := p.GetRedirect(req) if err != nil { p.ErrorPage(rw, 500, "Internal Error", err.Error()) @@ -405,7 +453,7 @@ func (p *OauthProxy) OauthStart(rw http.ResponseWriter, req *http.Request) { http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, redirect), 302) } -func (p *OauthProxy) OauthCallback(rw http.ResponseWriter, req *http.Request) { +func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { remoteAddr := getRemoteAddr(req) // finish the oauth cycle @@ -433,7 +481,7 @@ func (p *OauthProxy) OauthCallback(rw http.ResponseWriter, req *http.Request) { } // set cookie, or deny - if p.Validator(session.Email) { + if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) { log.Printf("%s authentication complete %s", remoteAddr, session) err := p.SaveSession(rw, req, session) if err != nil { @@ -448,7 +496,28 @@ func (p *OauthProxy) OauthCallback(rw http.ResponseWriter, req *http.Request) { } } -func (p *OauthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { +func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) { + status := p.Authenticate(rw, req) + if status == http.StatusAccepted { + rw.WriteHeader(http.StatusAccepted) + } else { + http.Error(rw, "unauthorized request", http.StatusUnauthorized) + } +} + +func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { + status := p.Authenticate(rw, req) + if status == http.StatusInternalServerError { + p.ErrorPage(rw, http.StatusInternalServerError, + "Internal Error", "Internal Error") + } else if status == http.StatusForbidden { + p.SignInPage(rw, req, http.StatusForbidden) + } else { + p.serveMux.ServeHTTP(rw, req) + } +} + +func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int { var saveSession, clearSession, revalidated bool remoteAddr := getRemoteAddr(req) @@ -477,7 +546,7 @@ func (p *OauthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { clearSession = true } - if saveSession && !revalidated && session.AccessToken != "" { + if saveSession && !revalidated && session != nil && session.AccessToken != "" { if !p.provider.ValidateSessionState(session) { log.Printf("%s removing session. error validating %s", remoteAddr, session) saveSession = false @@ -493,12 +562,11 @@ func (p *OauthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { clearSession = true } - if saveSession { + if saveSession && session != nil { err := p.SaveSession(rw, req, session) if err != nil { log.Printf("%s %s", remoteAddr, err) - p.ErrorPage(rw, 500, "Internal Error", "Internal Error") - return + return http.StatusInternalServerError } } @@ -514,8 +582,7 @@ func (p *OauthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { } if session == nil { - p.SignInPage(rw, req, 403) - return + return http.StatusForbidden } // At this point, the user is authenticated. proxy normally @@ -534,11 +601,10 @@ func (p *OauthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { } else { rw.Header().Set("GAP-Auth", session.Email) } - - p.serveMux.ServeHTTP(rw, req) + return http.StatusAccepted } -func (p *OauthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) { +func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) { if p.HtpasswdFile == nil { return nil, nil } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 52b48bb93..7af1de18b 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1,9 +1,12 @@ package main import ( + "crypto" "encoding/base64" + "github.com/18F/hmacauth" "github.com/bitly/oauth2_proxy/providers" "github.com/bmizerany/assert" + "io" "io/ioutil" "log" "net" @@ -75,13 +78,12 @@ func TestEncodedSlashes(t *testing.T) { func TestRobotsTxt(t *testing.T) { opts := NewOptions() - opts.Upstreams = append(opts.Upstreams, "unused") opts.ClientID = "bazquux" opts.ClientSecret = "foobar" opts.CookieSecret = "xyzzyplugh" opts.Validate() - proxy := NewOauthProxy(opts, func(string) bool { return true }) + proxy := NewOAuthProxy(opts, func(string) bool { return true }) rw := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/robots.txt", nil) proxy.ServeHTTP(rw, req) @@ -89,6 +91,45 @@ func TestRobotsTxt(t *testing.T) { assert.Equal(t, "User-agent: *\nDisallow: /", rw.Body.String()) } +type TestProvider struct { + *providers.ProviderData + EmailAddress string + ValidToken bool +} + +func NewTestProvider(provider_url *url.URL, email_address string) *TestProvider { + return &TestProvider{ + ProviderData: &providers.ProviderData{ + ProviderName: "Test Provider", + LoginURL: &url.URL{ + Scheme: "http", + Host: provider_url.Host, + Path: "/oauth/authorize", + }, + RedeemURL: &url.URL{ + Scheme: "http", + Host: provider_url.Host, + Path: "/oauth/token", + }, + ProfileURL: &url.URL{ + Scheme: "http", + Host: provider_url.Host, + Path: "/api/v1/profile", + }, + Scope: "profile.email", + }, + EmailAddress: email_address, + } +} + +func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) { + return tp.EmailAddress, nil +} + +func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool { + return tp.ValidToken +} + func TestBasicAuthPassword(t *testing.T) { provider_server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Printf("%#v", r) @@ -122,30 +163,8 @@ func TestBasicAuthPassword(t *testing.T) { const email_address = "michael.bland@gsa.gov" const user_name = "michael.bland" - opts.provider = &TestProvider{ - ProviderData: &providers.ProviderData{ - ProviderName: "Test Provider", - LoginUrl: &url.URL{ - Scheme: "http", - Host: provider_url.Host, - Path: "/oauth/authorize", - }, - RedeemUrl: &url.URL{ - Scheme: "http", - Host: provider_url.Host, - Path: "/oauth/token", - }, - ProfileUrl: &url.URL{ - Scheme: "http", - Host: provider_url.Host, - Path: "/api/v1/profile", - }, - Scope: "profile.email", - }, - EmailAddress: email_address, - } - - proxy := NewOauthProxy(opts, func(email string) bool { + opts.provider = NewTestProvider(provider_url, email_address) + proxy := NewOAuthProxy(opts, func(email string) bool { return email == email_address }) @@ -184,23 +203,9 @@ func TestBasicAuthPassword(t *testing.T) { provider_server.Close() } -type TestProvider struct { - *providers.ProviderData - EmailAddress string - ValidToken bool -} - -func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) { - return tp.EmailAddress, nil -} - -func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool { - return tp.ValidToken -} - type PassAccessTokenTest struct { provider_server *httptest.Server - proxy *OauthProxy + proxy *OAuthProxy opts *Options } @@ -243,30 +248,8 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes provider_url, _ := url.Parse(t.provider_server.URL) const email_address = "michael.bland@gsa.gov" - t.opts.provider = &TestProvider{ - ProviderData: &providers.ProviderData{ - ProviderName: "Test Provider", - LoginUrl: &url.URL{ - Scheme: "http", - Host: provider_url.Host, - Path: "/oauth/authorize", - }, - RedeemUrl: &url.URL{ - Scheme: "http", - Host: provider_url.Host, - Path: "/oauth/token", - }, - ProfileUrl: &url.URL{ - Scheme: "http", - Host: provider_url.Host, - Path: "/api/v1/profile", - }, - Scope: "profile.email", - }, - EmailAddress: email_address, - } - - t.proxy = NewOauthProxy(t.opts, func(email string) bool { + t.opts.provider = NewTestProvider(provider_url, email_address) + t.proxy = NewOAuthProxy(t.opts, func(email string) bool { return email == email_address }) return t @@ -361,7 +344,7 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) { type SignInPageTest struct { opts *Options - proxy *OauthProxy + proxy *OAuthProxy sign_in_regexp *regexp.Regexp } @@ -371,13 +354,12 @@ func NewSignInPageTest() *SignInPageTest { var sip_test SignInPageTest sip_test.opts = NewOptions() - sip_test.opts.Upstreams = append(sip_test.opts.Upstreams, "unused") sip_test.opts.CookieSecret = "foobar" sip_test.opts.ClientID = "bazquux" sip_test.opts.ClientSecret = "xyzzyplugh" sip_test.opts.Validate() - sip_test.proxy = NewOauthProxy(sip_test.opts, func(email string) bool { + sip_test.proxy = NewOAuthProxy(sip_test.opts, func(email string) bool { return true }) sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern) @@ -427,7 +409,7 @@ func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { type ProcessCookieTest struct { opts *Options - proxy *OauthProxy + proxy *OAuthProxy rw *httptest.ResponseRecorder req *http.Request provider TestProvider @@ -443,7 +425,6 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { var pc_test ProcessCookieTest pc_test.opts = NewOptions() - pc_test.opts.Upstreams = append(pc_test.opts.Upstreams, "unused") pc_test.opts.ClientID = "bazquux" pc_test.opts.ClientSecret = "xyzzyplugh" pc_test.opts.CookieSecret = "0123456789abcdef" @@ -452,7 +433,7 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { pc_test.opts.CookieRefresh = time.Hour pc_test.opts.Validate() - pc_test.proxy = NewOauthProxy(pc_test.opts, func(email string) bool { + pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool { return pc_test.validate_user }) pc_test.proxy.provider = &TestProvider{ @@ -558,3 +539,198 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { t.Errorf("expected nil session %#v", session) } } + +func NewAuthOnlyEndpointTest() *ProcessCookieTest { + pc_test := NewProcessCookieTestWithDefaults() + pc_test.req, _ = http.NewRequest("GET", + pc_test.opts.ProxyPrefix+"/auth", nil) + return pc_test +} + +func TestAuthOnlyEndpointAccepted(t *testing.T) { + test := NewAuthOnlyEndpointTest() + startSession := &providers.SessionState{ + Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} + test.SaveSession(startSession, time.Now()) + + test.proxy.ServeHTTP(test.rw, test.req) + assert.Equal(t, http.StatusAccepted, test.rw.Code) + bodyBytes, _ := ioutil.ReadAll(test.rw.Body) + assert.Equal(t, "", string(bodyBytes)) +} + +func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { + test := NewAuthOnlyEndpointTest() + + test.proxy.ServeHTTP(test.rw, test.req) + assert.Equal(t, http.StatusUnauthorized, test.rw.Code) + bodyBytes, _ := ioutil.ReadAll(test.rw.Body) + assert.Equal(t, "unauthorized request\n", string(bodyBytes)) +} + +func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { + test := NewAuthOnlyEndpointTest() + test.proxy.CookieExpire = time.Duration(24) * time.Hour + reference := time.Now().Add(time.Duration(25) * time.Hour * -1) + startSession := &providers.SessionState{ + Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} + test.SaveSession(startSession, reference) + + test.proxy.ServeHTTP(test.rw, test.req) + assert.Equal(t, http.StatusUnauthorized, test.rw.Code) + bodyBytes, _ := ioutil.ReadAll(test.rw.Body) + assert.Equal(t, "unauthorized request\n", string(bodyBytes)) +} + +func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { + test := NewAuthOnlyEndpointTest() + startSession := &providers.SessionState{ + Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} + test.SaveSession(startSession, time.Now()) + test.validate_user = false + + test.proxy.ServeHTTP(test.rw, test.req) + assert.Equal(t, http.StatusUnauthorized, test.rw.Code) + bodyBytes, _ := ioutil.ReadAll(test.rw.Body) + assert.Equal(t, "unauthorized request\n", string(bodyBytes)) +} + +type SignatureAuthenticator struct { + auth hmacauth.HmacAuth +} + +func (v *SignatureAuthenticator) Authenticate( + w http.ResponseWriter, r *http.Request) { + result, headerSig, computedSig := v.auth.AuthenticateRequest(r) + if result == hmacauth.ResultNoSignature { + w.Write([]byte("no signature received")) + } else if result == hmacauth.ResultMatch { + w.Write([]byte("signatures match")) + } else if result == hmacauth.ResultMismatch { + w.Write([]byte("signatures do not match:" + + "\n received: " + headerSig + + "\n computed: " + computedSig)) + } else { + panic("Unknown result value: " + result.String()) + } +} + +type SignatureTest struct { + opts *Options + upstream *httptest.Server + upstream_host string + provider *httptest.Server + header http.Header + rw *httptest.ResponseRecorder + authenticator *SignatureAuthenticator +} + +func NewSignatureTest() *SignatureTest { + opts := NewOptions() + opts.CookieSecret = "cookie secret" + opts.ClientID = "client ID" + opts.ClientSecret = "client secret" + opts.EmailDomains = []string{"acm.org"} + + authenticator := &SignatureAuthenticator{} + upstream := httptest.NewServer( + http.HandlerFunc(authenticator.Authenticate)) + upstream_url, _ := url.Parse(upstream.URL) + opts.Upstreams = append(opts.Upstreams, upstream.URL) + + providerHandler := func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(`{"access_token": "my_auth_token"}`)) + } + provider := httptest.NewServer(http.HandlerFunc(providerHandler)) + provider_url, _ := url.Parse(provider.URL) + opts.provider = NewTestProvider(provider_url, "mbland@acm.org") + + return &SignatureTest{ + opts, + upstream, + upstream_url.Host, + provider, + make(http.Header), + httptest.NewRecorder(), + authenticator, + } +} + +func (st *SignatureTest) Close() { + st.provider.Close() + st.upstream.Close() +} + +// fakeNetConn simulates an http.Request.Body buffer that will be consumed +// when it is read by the hmacauth.HmacAuth if not handled properly. See: +// https://github.com/18F/hmacauth/pull/4 +type fakeNetConn struct { + reqBody string +} + +func (fnc *fakeNetConn) Read(p []byte) (n int, err error) { + if bodyLen := len(fnc.reqBody); bodyLen != 0 { + copy(p, fnc.reqBody) + fnc.reqBody = "" + return bodyLen, io.EOF + } + return 0, io.EOF +} + +func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { + err := st.opts.Validate() + if err != nil { + panic(err) + } + proxy := NewOAuthProxy(st.opts, func(email string) bool { return true }) + + var bodyBuf io.ReadCloser + if body != "" { + bodyBuf = ioutil.NopCloser(&fakeNetConn{reqBody: body}) + } + req, err := http.NewRequest(method, "/foo/bar", bodyBuf) + if err != nil { + panic(err) + } + req.Header = st.header + + state := &providers.SessionState{ + Email: "mbland@acm.org", AccessToken: "my_access_token"} + value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher) + if err != nil { + panic(err) + } + cookie := proxy.MakeCookie(req, value, proxy.CookieExpire, time.Now()) + req.AddCookie(cookie) + // This is used by the upstream to validate the signature. + st.authenticator.auth = hmacauth.NewHmacAuth( + crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders) + proxy.ServeHTTP(st.rw, req) +} + +func TestNoRequestSignature(t *testing.T) { + st := NewSignatureTest() + defer st.Close() + st.MakeRequestWithExpectedKey("GET", "", "") + assert.Equal(t, 200, st.rw.Code) + assert.Equal(t, st.rw.Body.String(), "no signature received") +} + +func TestRequestSignatureGetRequest(t *testing.T) { + st := NewSignatureTest() + defer st.Close() + st.opts.SignatureKey = "sha1:foobar" + st.MakeRequestWithExpectedKey("GET", "", "foobar") + assert.Equal(t, 200, st.rw.Code) + assert.Equal(t, st.rw.Body.String(), "signatures match") +} + +func TestRequestSignaturePostRequest(t *testing.T) { + st := NewSignatureTest() + defer st.Close() + st.opts.SignatureKey = "sha1:foobar" + payload := `{ "hello": "world!" }` + st.MakeRequestWithExpectedKey("POST", payload, "foobar") + assert.Equal(t, 200, st.rw.Code) + assert.Equal(t, st.rw.Body.String(), "signatures match") +} diff --git a/options.go b/options.go index bcf7d29cd..b64396cbc 100644 --- a/options.go +++ b/options.go @@ -1,12 +1,15 @@ package main import ( + "crypto" "fmt" "net/url" + "os" "regexp" "strings" "time" + "github.com/18F/hmacauth" "github.com/bitly/oauth2_proxy/providers" ) @@ -15,19 +18,22 @@ type Options struct { ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy-prefix"` HttpAddress string `flag:"http-address" cfg:"http_address"` HttpsAddress string `flag:"https-address" cfg:"https_address"` - RedirectUrl string `flag:"redirect-url" cfg:"redirect_url"` + RedirectURL string `flag:"redirect-url" cfg:"redirect_url"` ClientID string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"` ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"` TLSCertFile string `flag:"tls-cert" cfg:"tls_cert_file"` TLSKeyFile string `flag:"tls-key" cfg:"tls_key_file"` - AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` - EmailDomains []string `flag:"email-domain" cfg:"email_domains"` - GitHubOrg string `flag:"github-org" cfg:"github_org"` - GitHubTeam string `flag:"github-team" cfg:"github_team"` - HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"` - DisplayHtpasswdForm bool `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"` - CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir"` + AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` + EmailDomains []string `flag:"email-domain" cfg:"email_domains"` + GitHubOrg string `flag:"github-org" cfg:"github_org"` + GitHubTeam string `flag:"github-team" cfg:"github_team"` + GoogleGroups []string `flag:"google-group" cfg:"google_group"` + GoogleAdminEmail string `flag:"google-admin-email" cfg:"google_admin_email"` + GoogleServiceAccountJSON string `flag:"google-service-account-json" cfg:"google_service_account_json"` + HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"` + DisplayHtpasswdForm bool `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"` + CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir"` CookieName string `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"` CookieSecret string `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"` @@ -46,20 +52,29 @@ type Options struct { // These options allow for other providers besides Google, with // potential overrides. - Provider string `flag:"provider" cfg:"provider"` - LoginUrl string `flag:"login-url" cfg:"login_url"` - RedeemUrl string `flag:"redeem-url" cfg:"redeem_url"` - ProfileUrl string `flag:"profile-url" cfg:"profile_url"` - ValidateUrl string `flag:"validate-url" cfg:"validate_url"` - Scope string `flag:"scope" cfg:"scope"` + Provider string `flag:"provider" cfg:"provider"` + LoginURL string `flag:"login-url" cfg:"login_url"` + RedeemURL string `flag:"redeem-url" cfg:"redeem_url"` + ProfileURL string `flag:"profile-url" cfg:"profile_url"` + ValidateURL string `flag:"validate-url" cfg:"validate_url"` + Scope string `flag:"scope" cfg:"scope"` + ApprovalPrompt string `flag:"approval-prompt" cfg:"approval_prompt"` RequestLogging bool `flag:"request-logging" cfg:"request_logging"` + SignatureKey string `flag:"signature-key" cfg:"signature_key"` + // internal values that are set after config validation - redirectUrl *url.URL - proxyUrls []*url.URL + redirectURL *url.URL + proxyURLs []*url.URL CompiledRegex []*regexp.Regexp provider providers.Provider + signatureData *SignatureData +} + +type SignatureData struct { + hash crypto.Hash + key string } func NewOptions() *Options { @@ -76,11 +91,12 @@ func NewOptions() *Options { PassBasicAuth: true, PassAccessToken: false, PassHostHeader: true, + ApprovalPrompt: "force", RequestLogging: true, } } -func parseUrl(to_parse string, urltype string, msgs []string) (*url.URL, []string) { +func parseURL(to_parse string, urltype string, msgs []string) (*url.URL, []string) { parsed, err := url.Parse(to_parse) if err != nil { return nil, append(msgs, fmt.Sprintf( @@ -103,20 +119,23 @@ func (o *Options) Validate() error { if o.ClientSecret == "" { msgs = append(msgs, "missing setting: client-secret") } + if o.AuthenticatedEmailsFile == "" && len(o.EmailDomains) == 0 && o.HtpasswdFile == "" { + msgs = append(msgs, "missing setting for email validation: email-domain or authenticated-emails-file required.\n use email-domain=* to authorize all email addresses") + } - o.redirectUrl, msgs = parseUrl(o.RedirectUrl, "redirect", msgs) + o.redirectURL, msgs = parseURL(o.RedirectURL, "redirect", msgs) for _, u := range o.Upstreams { - upstreamUrl, err := url.Parse(u) + upstreamURL, err := url.Parse(u) if err != nil { msgs = append(msgs, fmt.Sprintf( "error parsing upstream=%q %s", - upstreamUrl, err)) + upstreamURL, err)) } - if upstreamUrl.Path == "" { - upstreamUrl.Path = "/" + if upstreamURL.Path == "" { + upstreamURL.Path = "/" } - o.proxyUrls = append(o.proxyUrls, upstreamUrl) + o.proxyURLs = append(o.proxyURLs, upstreamURL) } for _, u := range o.SkipAuthRegex { @@ -154,6 +173,20 @@ func (o *Options) Validate() error { o.CookieExpire.String())) } + if len(o.GoogleGroups) > 0 || o.GoogleAdminEmail != "" || o.GoogleServiceAccountJSON != "" { + if len(o.GoogleGroups) < 1 { + msgs = append(msgs, "missing setting: google-group") + } + if o.GoogleAdminEmail == "" { + msgs = append(msgs, "missing setting: google-admin-email") + } + if o.GoogleServiceAccountJSON == "" { + msgs = append(msgs, "missing setting: google-service-account-json") + } + } + + msgs = parseSignatureKey(o, msgs) + if len(msgs) != 0 { return fmt.Errorf("Invalid configuration:\n %s", strings.Join(msgs, "\n ")) @@ -162,16 +195,51 @@ func (o *Options) Validate() error { } func parseProviderInfo(o *Options, msgs []string) []string { - p := &providers.ProviderData{Scope: o.Scope, ClientID: o.ClientID, ClientSecret: o.ClientSecret} - p.LoginUrl, msgs = parseUrl(o.LoginUrl, "login", msgs) - p.RedeemUrl, msgs = parseUrl(o.RedeemUrl, "redeem", msgs) - p.ProfileUrl, msgs = parseUrl(o.ProfileUrl, "profile", msgs) - p.ValidateUrl, msgs = parseUrl(o.ValidateUrl, "validate", msgs) + p := &providers.ProviderData{ + Scope: o.Scope, + ClientID: o.ClientID, + ClientSecret: o.ClientSecret, + ApprovalPrompt: o.ApprovalPrompt, + } + p.LoginURL, msgs = parseURL(o.LoginURL, "login", msgs) + p.RedeemURL, msgs = parseURL(o.RedeemURL, "redeem", msgs) + p.ProfileURL, msgs = parseURL(o.ProfileURL, "profile", msgs) + p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) o.provider = providers.New(o.Provider, p) switch p := o.provider.(type) { case *providers.GitHubProvider: p.SetOrgTeam(o.GitHubOrg, o.GitHubTeam) + case *providers.GoogleProvider: + if o.GoogleServiceAccountJSON != "" { + file, err := os.Open(o.GoogleServiceAccountJSON) + if err != nil { + msgs = append(msgs, "invalid Google credentials file: "+o.GoogleServiceAccountJSON) + } else { + p.SetGroupRestriction(o.GoogleGroups, o.GoogleAdminEmail, file) + } + } + } + return msgs +} + +func parseSignatureKey(o *Options, msgs []string) []string { + if o.SignatureKey == "" { + return msgs + } + + components := strings.Split(o.SignatureKey, ":") + if len(components) != 2 { + return append(msgs, "invalid signature hash:key spec: "+ + o.SignatureKey) + } + + algorithm, secretKey := components[0], components[1] + if hash, err := hmacauth.DigestNameToCryptoHash(algorithm); err != nil { + return append(msgs, "unsupported signature hash algorithm: "+ + o.SignatureKey) + } else { + o.signatureData = &SignatureData{hash, secretKey} } return msgs } diff --git a/options_test.go b/options_test.go index 8d8fdf875..8a8b6a776 100644 --- a/options_test.go +++ b/options_test.go @@ -1,6 +1,7 @@ package main import ( + "crypto" "net/url" "strings" "testing" @@ -15,6 +16,7 @@ func testOptions() *Options { o.CookieSecret = "foobar" o.ClientID = "bazquux" o.ClientSecret = "xyzzyplugh" + o.EmailDomains = []string{"*"} return o } @@ -27,6 +29,7 @@ func errorMsg(msgs []string) string { func TestNewOptions(t *testing.T) { o := NewOptions() + o.EmailDomains = []string{"*"} err := o.Validate() assert.NotEqual(t, nil, err) @@ -38,6 +41,32 @@ func TestNewOptions(t *testing.T) { assert.Equal(t, expected, err.Error()) } +func TestGoogleGroupOptions(t *testing.T) { + o := testOptions() + o.GoogleGroups = []string{"googlegroup"} + err := o.Validate() + assert.NotEqual(t, nil, err) + + expected := errorMsg([]string{ + "missing setting: google-admin-email", + "missing setting: google-service-account-json"}) + assert.Equal(t, expected, err.Error()) +} + +func TestGoogleGroupInvalidFile(t *testing.T) { + o := testOptions() + o.GoogleGroups = []string{"test_group"} + o.GoogleAdminEmail = "admin@example.com" + o.GoogleServiceAccountJSON = "file_doesnt_exist.json" + err := o.Validate() + assert.NotEqual(t, nil, err) + + expected := errorMsg([]string{ + "invalid Google credentials file: file_doesnt_exist.json", + }) + assert.Equal(t, expected, err.Error()) +} + func TestInitializedOptions(t *testing.T) { o := testOptions() assert.Equal(t, nil, o.Validate()) @@ -45,16 +74,16 @@ func TestInitializedOptions(t *testing.T) { // Note that it's not worth testing nonparseable URLs, since url.Parse() // seems to parse damn near anything. -func TestRedirectUrl(t *testing.T) { +func TestRedirectURL(t *testing.T) { o := testOptions() - o.RedirectUrl = "https://myhost.com/oauth2/callback" + o.RedirectURL = "https://myhost.com/oauth2/callback" assert.Equal(t, nil, o.Validate()) expected := &url.URL{ Scheme: "https", Host: "myhost.com", Path: "/oauth2/callback"} - assert.Equal(t, expected, o.redirectUrl) + assert.Equal(t, expected, o.redirectURL) } -func TestProxyUrls(t *testing.T) { +func TestProxyURLs(t *testing.T) { o := testOptions() o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8081") assert.Equal(t, nil, o.Validate()) @@ -63,7 +92,7 @@ func TestProxyUrls(t *testing.T) { // note the '/' was added &url.URL{Scheme: "http", Host: "127.0.0.1:8081", Path: "/"}, } - assert.Equal(t, expected, o.proxyUrls) + assert.Equal(t, expected, o.proxyURLs) } func TestCompiledRegex(t *testing.T) { @@ -97,10 +126,10 @@ func TestDefaultProviderApiSettings(t *testing.T) { assert.Equal(t, nil, o.Validate()) p := o.provider.Data() assert.Equal(t, "https://accounts.google.com/o/oauth2/auth?access_type=offline", - p.LoginUrl.String()) + p.LoginURL.String()) assert.Equal(t, "https://www.googleapis.com/oauth2/v3/token", - p.RedeemUrl.String()) - assert.Equal(t, "", p.ProfileUrl.String()) + p.RedeemURL.String()) + assert.Equal(t, "", p.ProfileURL.String()) assert.Equal(t, "profile email", p.Scope) } @@ -138,3 +167,27 @@ func TestCookieRefreshMustBeLessThanCookieExpire(t *testing.T) { o.CookieRefresh -= time.Duration(1) assert.Equal(t, nil, o.Validate()) } + +func TestValidateSignatureKey(t *testing.T) { + o := testOptions() + o.SignatureKey = "sha1:secret" + assert.Equal(t, nil, o.Validate()) + assert.Equal(t, o.signatureData.hash, crypto.SHA1) + assert.Equal(t, o.signatureData.key, "secret") +} + +func TestValidateSignatureKeyInvalidSpec(t *testing.T) { + o := testOptions() + o.SignatureKey = "invalid spec" + err := o.Validate() + assert.Equal(t, err.Error(), "Invalid configuration:\n"+ + " invalid signature hash:key spec: "+o.SignatureKey) +} + +func TestValidateSignatureKeyUnsupportedAlgorithm(t *testing.T) { + o := testOptions() + o.SignatureKey = "unsupported:default secret" + err := o.Validate() + assert.Equal(t, err.Error(), "Invalid configuration:\n"+ + " unsupported signature hash algorithm: "+o.SignatureKey) +} diff --git a/providers/github.go b/providers/github.go index 4f2a988f2..cf0cfcbe2 100644 --- a/providers/github.go +++ b/providers/github.go @@ -2,7 +2,6 @@ package providers import ( "encoding/json" - "errors" "fmt" "io/ioutil" "log" @@ -18,22 +17,22 @@ type GitHubProvider struct { func NewGitHubProvider(p *ProviderData) *GitHubProvider { p.ProviderName = "GitHub" - if p.LoginUrl.String() == "" { - p.LoginUrl = &url.URL{ + if p.LoginURL == nil || p.LoginURL.String() == "" { + p.LoginURL = &url.URL{ Scheme: "https", Host: "github.com", Path: "/login/oauth/authorize", } } - if p.RedeemUrl.String() == "" { - p.RedeemUrl = &url.URL{ + if p.RedeemURL == nil || p.RedeemURL.String() == "" { + p.RedeemURL = &url.URL{ Scheme: "https", Host: "github.com", Path: "/login/oauth/access_token", } } - if p.ValidateUrl.String() == "" { - p.ValidateUrl = &url.URL{ + if p.ValidateURL == nil || p.ValidateURL.String() == "" { + p.ValidateURL = &url.URL{ Scheme: "https", Host: "api.github.com", Path: "/user/emails", @@ -66,7 +65,7 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) { endpoint := "https://api.github.com/user/orgs?" + params.Encode() req, _ := http.NewRequest("GET", endpoint, nil) - req.Header.Set("Accept", "application/vnd.github.moondragon+json") + req.Header.Set("Accept", "application/vnd.github.v3+json") resp, err := http.DefaultClient.Do(req) if err != nil { return false, err @@ -85,11 +84,16 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) { return false, err } + var presentOrgs []string for _, org := range orgs { if p.Org == org.Login { + log.Printf("Found Github Organization: %q", org.Login) return true, nil } + presentOrgs = append(presentOrgs, org.Login) } + + log.Printf("Missing Organization:%q in %v", p.Org, presentOrgs) return false, nil } @@ -111,7 +115,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { endpoint := "https://api.github.com/user/teams?" + params.Encode() req, _ := http.NewRequest("GET", endpoint, nil) - req.Header.Set("Accept", "application/vnd.github.moondragon+json") + req.Header.Set("Accept", "application/vnd.github.v3+json") resp, err := http.DefaultClient.Do(req) if err != nil { return false, err @@ -130,12 +134,28 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { return false, fmt.Errorf("%s unmarshaling %s", err, body) } + var hasOrg bool + presentOrgs := make(map[string]bool) + var presentTeams []string for _, team := range teams { + presentOrgs[team.Org.Login] = true if p.Org == team.Org.Login { - if p.Team == "" || p.Team == team.Slug { + hasOrg = true + if p.Team == team.Slug { + log.Printf("Found Github Organization:%q Team:%q (Name:%q)", team.Org.Login, team.Slug, team.Name) return true, nil } + presentTeams = append(presentTeams, team.Slug) + } + } + if hasOrg { + log.Printf("Missing Team:%q from Org:%q in teams: %v", p.Team, p.Org, presentTeams) + } else { + var allOrgs []string + for org, _ := range presentOrgs { + allOrgs = append(allOrgs, org) } + log.Printf("Missing Organization:%q in %#v", p.Org, allOrgs) } return false, nil } @@ -190,5 +210,5 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { } } - return "", errors.New("no email address found") + return "", nil } diff --git a/providers/google.go b/providers/google.go index 8c0a0ccf9..539657b03 100644 --- a/providers/google.go +++ b/providers/google.go @@ -6,43 +6,59 @@ import ( "encoding/json" "errors" "fmt" + "io" "io/ioutil" "log" "net/http" "net/url" "strings" "time" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" + "google.golang.org/api/admin/directory/v1" ) type GoogleProvider struct { *ProviderData - RedeemRefreshUrl *url.URL + RedeemRefreshURL *url.URL + // GroupValidator is a function that determines if the passed email is in + // the configured Google group. + GroupValidator func(string) bool } func NewGoogleProvider(p *ProviderData) *GoogleProvider { p.ProviderName = "Google" - if p.LoginUrl.String() == "" { - p.LoginUrl = &url.URL{Scheme: "https", + if p.LoginURL.String() == "" { + p.LoginURL = &url.URL{Scheme: "https", Host: "accounts.google.com", Path: "/o/oauth2/auth", // to get a refresh token. see https://developers.google.com/identity/protocols/OAuth2WebServer#offline RawQuery: "access_type=offline", } } - if p.RedeemUrl.String() == "" { - p.RedeemUrl = &url.URL{Scheme: "https", + if p.RedeemURL.String() == "" { + p.RedeemURL = &url.URL{Scheme: "https", Host: "www.googleapis.com", Path: "/oauth2/v3/token"} } - if p.ValidateUrl.String() == "" { - p.ValidateUrl = &url.URL{Scheme: "https", + if p.ValidateURL.String() == "" { + p.ValidateURL = &url.URL{Scheme: "https", Host: "www.googleapis.com", Path: "/oauth2/v1/tokeninfo"} } if p.Scope == "" { p.Scope = "profile email" } - return &GoogleProvider{ProviderData: p} + + return &GoogleProvider{ + ProviderData: p, + // Set a default GroupValidator to just always return valid (true), it will + // be overwritten if we configured a Google group restriction. + GroupValidator: func(email string) bool { + return true + }, + } } func emailFromIdToken(idToken string) (string, error) { @@ -80,20 +96,20 @@ func jwtDecodeSegment(seg string) ([]byte, error) { return base64.URLEncoding.DecodeString(seg) } -func (p *GoogleProvider) Redeem(redirectUrl, code string) (s *SessionState, err error) { +func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { if code == "" { err = errors.New("missing code") return } params := url.Values{} - params.Add("redirect_uri", redirectUrl) + params.Add("redirect_uri", redirectURL) params.Add("client_id", p.ClientID) params.Add("client_secret", p.ClientSecret) params.Add("code", code) params.Add("grant_type", "authorization_code") var req *http.Request - req, err = http.NewRequest("POST", p.RedeemUrl.String(), bytes.NewBufferString(params.Encode())) + req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) if err != nil { return } @@ -111,7 +127,7 @@ func (p *GoogleProvider) Redeem(redirectUrl, code string) (s *SessionState, err } if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemUrl.String(), body) + err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) return } @@ -139,6 +155,102 @@ func (p *GoogleProvider) Redeem(redirectUrl, code string) (s *SessionState, err return } +// SetGroupRestriction configures the GoogleProvider to restrict access to the +// specified group(s). AdminEmail has to be an administrative email on the domain that is +// checked. CredentialsFile is the path to a json file containing a Google service +// account credentials. +func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) { + adminService := getAdminService(adminEmail, credentialsReader) + p.GroupValidator = func(email string) bool { + return userInGroup(adminService, groups, email) + } +} + +func getAdminService(adminEmail string, credentialsReader io.Reader) *admin.Service { + data, err := ioutil.ReadAll(credentialsReader) + if err != nil { + log.Fatal("can't read Google credentials file:", err) + } + conf, err := google.JWTConfigFromJSON(data, admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope) + if err != nil { + log.Fatal("can't load Google credentials file:", err) + } + conf.Subject = adminEmail + + client := conf.Client(oauth2.NoContext) + adminService, err := admin.New(client) + if err != nil { + log.Fatal(err) + } + return adminService +} + +func userInGroup(service *admin.Service, groups []string, email string) bool { + user, err := fetchUser(service, email) + if err != nil { + log.Printf("error fetching user: %v", err) + return false + } + id := user.Id + custID := user.CustomerId + + for _, group := range groups { + members, err := fetchGroupMembers(service, group) + if err != nil { + log.Printf("error fetching group members: %v", err) + return false + } + + for _, member := range members { + switch member.Type { + case "CUSTOMER": + if member.Id == custID { + return true + } + case "USER": + if member.Id == id { + return true + } + } + } + } + return false +} + +func fetchUser(service *admin.Service, email string) (*admin.User, error) { + user, err := service.Users.Get(email).Do() + return user, err +} + +func fetchGroupMembers(service *admin.Service, group string) ([]*admin.Member, error) { + members := []*admin.Member{} + pageToken := "" + for { + req := service.Members.List(group) + if pageToken != "" { + req.PageToken(pageToken) + } + r, err := req.Do() + if err != nil { + return nil, err + } + for _, member := range r.Members { + members = append(members, member) + } + if r.NextPageToken == "" { + break + } + pageToken = r.NextPageToken + } + return members, nil +} + +// ValidateGroup validates that the provided email exists in the configured Google +// group(s). +func (p *GoogleProvider) ValidateGroup(email string) bool { + return p.GroupValidator(email) +} + func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { return false, nil @@ -148,6 +260,12 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { if err != nil { return false, err } + + // re-check that the user is in the proper google group(s) + if !p.ValidateGroup(s.Email) { + return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) + } + origExpiration := s.ExpiresOn s.AccessToken = newToken s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second) @@ -163,7 +281,7 @@ func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, params.Add("refresh_token", refreshToken) params.Add("grant_type", "refresh_token") var req *http.Request - req, err = http.NewRequest("POST", p.RedeemUrl.String(), bytes.NewBufferString(params.Encode())) + req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) if err != nil { return } @@ -181,7 +299,7 @@ func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, } if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemUrl.String(), body) + err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) return } diff --git a/providers/google_test.go b/providers/google_test.go index 0da80f457..8f9b0542e 100644 --- a/providers/google_test.go +++ b/providers/google_test.go @@ -23,10 +23,10 @@ func newGoogleProvider() *GoogleProvider { return NewGoogleProvider( &ProviderData{ ProviderName: "", - LoginUrl: &url.URL{}, - RedeemUrl: &url.URL{}, - ProfileUrl: &url.URL{}, - ValidateUrl: &url.URL{}, + LoginURL: &url.URL{}, + RedeemURL: &url.URL{}, + ProfileURL: &url.URL{}, + ValidateURL: &url.URL{}, Scope: ""}) } @@ -35,31 +35,31 @@ func TestGoogleProviderDefaults(t *testing.T) { assert.NotEqual(t, nil, p) assert.Equal(t, "Google", p.Data().ProviderName) assert.Equal(t, "https://accounts.google.com/o/oauth2/auth?access_type=offline", - p.Data().LoginUrl.String()) + p.Data().LoginURL.String()) assert.Equal(t, "https://www.googleapis.com/oauth2/v3/token", - p.Data().RedeemUrl.String()) + p.Data().RedeemURL.String()) assert.Equal(t, "https://www.googleapis.com/oauth2/v1/tokeninfo", - p.Data().ValidateUrl.String()) - assert.Equal(t, "", p.Data().ProfileUrl.String()) + p.Data().ValidateURL.String()) + assert.Equal(t, "", p.Data().ProfileURL.String()) assert.Equal(t, "profile email", p.Data().Scope) } func TestGoogleProviderOverrides(t *testing.T) { p := NewGoogleProvider( &ProviderData{ - LoginUrl: &url.URL{ + LoginURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/auth"}, - RedeemUrl: &url.URL{ + RedeemURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/token"}, - ProfileUrl: &url.URL{ + ProfileURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/profile"}, - ValidateUrl: &url.URL{ + ValidateURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/tokeninfo"}, @@ -67,13 +67,13 @@ func TestGoogleProviderOverrides(t *testing.T) { assert.NotEqual(t, nil, p) assert.Equal(t, "Google", p.Data().ProviderName) assert.Equal(t, "https://example.com/oauth/auth", - p.Data().LoginUrl.String()) + p.Data().LoginURL.String()) assert.Equal(t, "https://example.com/oauth/token", - p.Data().RedeemUrl.String()) + p.Data().RedeemURL.String()) assert.Equal(t, "https://example.com/oauth/profile", - p.Data().ProfileUrl.String()) + p.Data().ProfileURL.String()) assert.Equal(t, "https://example.com/oauth/tokeninfo", - p.Data().ValidateUrl.String()) + p.Data().ValidateURL.String()) assert.Equal(t, "profile", p.Data().Scope) } @@ -94,7 +94,7 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) { }) assert.Equal(t, nil, err) var server *httptest.Server - p.RedeemUrl, server = newRedeemServer(body) + p.RedeemURL, server = newRedeemServer(body) defer server.Close() session, err := p.Redeem("http://redirect/", "code1234") @@ -105,6 +105,23 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) { assert.Equal(t, "refresh12345", session.RefreshToken) } +func TestGoogleProviderValidateGroup(t *testing.T) { + p := newGoogleProvider() + p.GroupValidator = func(email string) bool { + return email == "michael.bland@gsa.gov" + } + assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov")) + p.GroupValidator = func(email string) bool { + return email != "michael.bland@gsa.gov" + } + assert.Equal(t, false, p.ValidateGroup("michael.bland@gsa.gov")) +} + +func TestGoogleProviderWithoutValidateGroup(t *testing.T) { + p := newGoogleProvider() + assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov")) +} + // func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) { p := newGoogleProvider() @@ -114,7 +131,7 @@ func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) { }) assert.Equal(t, nil, err) var server *httptest.Server - p.RedeemUrl, server = newRedeemServer(body) + p.RedeemURL, server = newRedeemServer(body) defer server.Close() session, err := p.Redeem("http://redirect/", "code1234") @@ -133,7 +150,7 @@ func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) { }) assert.Equal(t, nil, err) var server *httptest.Server - p.RedeemUrl, server = newRedeemServer(body) + p.RedeemURL, server = newRedeemServer(body) defer server.Close() session, err := p.Redeem("http://redirect/", "code1234") @@ -152,7 +169,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { }) assert.Equal(t, nil, err) var server *httptest.Server - p.RedeemUrl, server = newRedeemServer(body) + p.RedeemURL, server = newRedeemServer(body) defer server.Close() session, err := p.Redeem("http://redirect/", "code1234") diff --git a/providers/internal_util.go b/providers/internal_util.go index ff0cafa77..436744cb6 100644 --- a/providers/internal_util.go +++ b/providers/internal_util.go @@ -11,10 +11,10 @@ import ( // validateToken returns true if token is valid func validateToken(p Provider, access_token string, header http.Header) bool { - if access_token == "" || p.Data().ValidateUrl == nil { + if access_token == "" || p.Data().ValidateURL == nil { return false } - endpoint := p.Data().ValidateUrl.String() + endpoint := p.Data().ValidateURL.String() if len(header) == 0 { params := url.Values{"access_token": {access_token}} endpoint = endpoint + "?" + params.Encode() diff --git a/providers/internal_util_test.go b/providers/internal_util_test.go index bace76d5f..ad42bf10b 100644 --- a/providers/internal_util_test.go +++ b/providers/internal_util_test.go @@ -63,7 +63,7 @@ func NewValidateSessionStateTest() *ValidateSessionStateTest { backend_url, _ := url.Parse(vt_test.backend.URL) vt_test.provider = &ValidateSessionStateTestProvider{ ProviderData: &ProviderData{ - ValidateUrl: &url.URL{ + ValidateURL: &url.URL{ Scheme: "http", Host: backend_url.Host, Path: "/oauth/tokeninfo", @@ -99,10 +99,10 @@ func TestValidateSessionStateEmptyToken(t *testing.T) { assert.Equal(t, false, validateToken(vt_test.provider, "", nil)) } -func TestValidateSessionStateEmptyValidateUrl(t *testing.T) { +func TestValidateSessionStateEmptyValidateURL(t *testing.T) { vt_test := NewValidateSessionStateTest() defer vt_test.Close() - vt_test.provider.Data().ValidateUrl = nil + vt_test.provider.Data().ValidateURL = nil assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) } diff --git a/providers/linkedin.go b/providers/linkedin.go index 78ad3c9c3..971734c44 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -3,7 +3,6 @@ package providers import ( "errors" "fmt" - "log" "net/http" "net/url" @@ -16,23 +15,23 @@ type LinkedInProvider struct { func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { p.ProviderName = "LinkedIn" - if p.LoginUrl.String() == "" { - p.LoginUrl = &url.URL{Scheme: "https", + if p.LoginURL.String() == "" { + p.LoginURL = &url.URL{Scheme: "https", Host: "www.linkedin.com", Path: "/uas/oauth2/authorization"} } - if p.RedeemUrl.String() == "" { - p.RedeemUrl = &url.URL{Scheme: "https", + if p.RedeemURL.String() == "" { + p.RedeemURL = &url.URL{Scheme: "https", Host: "www.linkedin.com", Path: "/uas/oauth2/accessToken"} } - if p.ProfileUrl.String() == "" { - p.ProfileUrl = &url.URL{Scheme: "https", + if p.ProfileURL.String() == "" { + p.ProfileURL = &url.URL{Scheme: "https", Host: "www.linkedin.com", Path: "/v1/people/~/email-address"} } - if p.ValidateUrl.String() == "" { - p.ValidateUrl = p.ProfileUrl + if p.ValidateURL.String() == "" { + p.ValidateURL = p.ProfileURL } if p.Scope == "" { p.Scope = "r_emailaddress r_basicprofile" @@ -52,7 +51,7 @@ func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { if s.AccessToken == "" { return "", errors.New("missing access token") } - req, err := http.NewRequest("GET", p.ProfileUrl.String()+"?format=json", nil) + req, err := http.NewRequest("GET", p.ProfileURL.String()+"?format=json", nil) if err != nil { return "", err } @@ -60,13 +59,11 @@ func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { json, err := api.Request(req) if err != nil { - log.Printf("failed making request %s", err) return "", err } email, err := json.String() if err != nil { - log.Printf("failed making request %s", err) return "", err } return email, nil diff --git a/providers/linkedin_test.go b/providers/linkedin_test.go index c75a4a8d4..f43c96bf5 100644 --- a/providers/linkedin_test.go +++ b/providers/linkedin_test.go @@ -12,15 +12,15 @@ func testLinkedInProvider(hostname string) *LinkedInProvider { p := NewLinkedInProvider( &ProviderData{ ProviderName: "", - LoginUrl: &url.URL{}, - RedeemUrl: &url.URL{}, - ProfileUrl: &url.URL{}, - ValidateUrl: &url.URL{}, + LoginURL: &url.URL{}, + RedeemURL: &url.URL{}, + ProfileURL: &url.URL{}, + ValidateURL: &url.URL{}, Scope: ""}) if hostname != "" { - updateUrl(p.Data().LoginUrl, hostname) - updateUrl(p.Data().RedeemUrl, hostname) - updateUrl(p.Data().ProfileUrl, hostname) + updateURL(p.Data().LoginURL, hostname) + updateURL(p.Data().RedeemURL, hostname) + updateURL(p.Data().ProfileURL, hostname) } return p } @@ -47,32 +47,32 @@ func TestLinkedInProviderDefaults(t *testing.T) { assert.NotEqual(t, nil, p) assert.Equal(t, "LinkedIn", p.Data().ProviderName) assert.Equal(t, "https://www.linkedin.com/uas/oauth2/authorization", - p.Data().LoginUrl.String()) + p.Data().LoginURL.String()) assert.Equal(t, "https://www.linkedin.com/uas/oauth2/accessToken", - p.Data().RedeemUrl.String()) + p.Data().RedeemURL.String()) assert.Equal(t, "https://www.linkedin.com/v1/people/~/email-address", - p.Data().ProfileUrl.String()) + p.Data().ProfileURL.String()) assert.Equal(t, "https://www.linkedin.com/v1/people/~/email-address", - p.Data().ValidateUrl.String()) + p.Data().ValidateURL.String()) assert.Equal(t, "r_emailaddress r_basicprofile", p.Data().Scope) } func TestLinkedInProviderOverrides(t *testing.T) { p := NewLinkedInProvider( &ProviderData{ - LoginUrl: &url.URL{ + LoginURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/auth"}, - RedeemUrl: &url.URL{ + RedeemURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/token"}, - ProfileUrl: &url.URL{ + ProfileURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/profile"}, - ValidateUrl: &url.URL{ + ValidateURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/tokeninfo"}, @@ -80,13 +80,13 @@ func TestLinkedInProviderOverrides(t *testing.T) { assert.NotEqual(t, nil, p) assert.Equal(t, "LinkedIn", p.Data().ProviderName) assert.Equal(t, "https://example.com/oauth/auth", - p.Data().LoginUrl.String()) + p.Data().LoginURL.String()) assert.Equal(t, "https://example.com/oauth/token", - p.Data().RedeemUrl.String()) + p.Data().RedeemURL.String()) assert.Equal(t, "https://example.com/oauth/profile", - p.Data().ProfileUrl.String()) + p.Data().ProfileURL.String()) assert.Equal(t, "https://example.com/oauth/tokeninfo", - p.Data().ValidateUrl.String()) + p.Data().ValidateURL.String()) assert.Equal(t, "profile", p.Data().Scope) } diff --git a/providers/myusa.go b/providers/myusa.go index c244ed04a..ae76d3436 100644 --- a/providers/myusa.go +++ b/providers/myusa.go @@ -16,23 +16,23 @@ func NewMyUsaProvider(p *ProviderData) *MyUsaProvider { const myUsaHost string = "alpha.my.usa.gov" p.ProviderName = "MyUSA" - if p.LoginUrl.String() == "" { - p.LoginUrl = &url.URL{Scheme: "https", + if p.LoginURL.String() == "" { + p.LoginURL = &url.URL{Scheme: "https", Host: myUsaHost, Path: "/oauth/authorize"} } - if p.RedeemUrl.String() == "" { - p.RedeemUrl = &url.URL{Scheme: "https", + if p.RedeemURL.String() == "" { + p.RedeemURL = &url.URL{Scheme: "https", Host: myUsaHost, Path: "/oauth/token"} } - if p.ProfileUrl.String() == "" { - p.ProfileUrl = &url.URL{Scheme: "https", + if p.ProfileURL.String() == "" { + p.ProfileURL = &url.URL{Scheme: "https", Host: myUsaHost, Path: "/api/v1/profile"} } - if p.ValidateUrl.String() == "" { - p.ValidateUrl = &url.URL{Scheme: "https", + if p.ValidateURL.String() == "" { + p.ValidateURL = &url.URL{Scheme: "https", Host: myUsaHost, Path: "/api/v1/tokeninfo"} } @@ -44,7 +44,7 @@ func NewMyUsaProvider(p *ProviderData) *MyUsaProvider { func (p *MyUsaProvider) GetEmailAddress(s *SessionState) (string, error) { req, err := http.NewRequest("GET", - p.ProfileUrl.String()+"?access_token="+s.AccessToken, nil) + p.ProfileURL.String()+"?access_token="+s.AccessToken, nil) if err != nil { log.Printf("failed building request %s", err) return "", err diff --git a/providers/myusa_test.go b/providers/myusa_test.go index b4bdb30f9..d058845c7 100644 --- a/providers/myusa_test.go +++ b/providers/myusa_test.go @@ -9,7 +9,7 @@ import ( "github.com/bmizerany/assert" ) -func updateUrl(url *url.URL, hostname string) { +func updateURL(url *url.URL, hostname string) { url.Scheme = "http" url.Host = hostname } @@ -18,16 +18,16 @@ func testMyUsaProvider(hostname string) *MyUsaProvider { p := NewMyUsaProvider( &ProviderData{ ProviderName: "", - LoginUrl: &url.URL{}, - RedeemUrl: &url.URL{}, - ProfileUrl: &url.URL{}, - ValidateUrl: &url.URL{}, + LoginURL: &url.URL{}, + RedeemURL: &url.URL{}, + ProfileURL: &url.URL{}, + ValidateURL: &url.URL{}, Scope: ""}) if hostname != "" { - updateUrl(p.Data().LoginUrl, hostname) - updateUrl(p.Data().RedeemUrl, hostname) - updateUrl(p.Data().ProfileUrl, hostname) - updateUrl(p.Data().ValidateUrl, hostname) + updateURL(p.Data().LoginURL, hostname) + updateURL(p.Data().RedeemURL, hostname) + updateURL(p.Data().ProfileURL, hostname) + updateURL(p.Data().ValidateURL, hostname) } return p } @@ -53,32 +53,32 @@ func TestMyUsaProviderDefaults(t *testing.T) { assert.NotEqual(t, nil, p) assert.Equal(t, "MyUSA", p.Data().ProviderName) assert.Equal(t, "https://alpha.my.usa.gov/oauth/authorize", - p.Data().LoginUrl.String()) + p.Data().LoginURL.String()) assert.Equal(t, "https://alpha.my.usa.gov/oauth/token", - p.Data().RedeemUrl.String()) + p.Data().RedeemURL.String()) assert.Equal(t, "https://alpha.my.usa.gov/api/v1/profile", - p.Data().ProfileUrl.String()) + p.Data().ProfileURL.String()) assert.Equal(t, "https://alpha.my.usa.gov/api/v1/tokeninfo", - p.Data().ValidateUrl.String()) + p.Data().ValidateURL.String()) assert.Equal(t, "profile.email", p.Data().Scope) } func TestMyUsaProviderOverrides(t *testing.T) { p := NewMyUsaProvider( &ProviderData{ - LoginUrl: &url.URL{ + LoginURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/auth"}, - RedeemUrl: &url.URL{ + RedeemURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/token"}, - ProfileUrl: &url.URL{ + ProfileURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/profile"}, - ValidateUrl: &url.URL{ + ValidateURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/tokeninfo"}, @@ -86,13 +86,13 @@ func TestMyUsaProviderOverrides(t *testing.T) { assert.NotEqual(t, nil, p) assert.Equal(t, "MyUSA", p.Data().ProviderName) assert.Equal(t, "https://example.com/oauth/auth", - p.Data().LoginUrl.String()) + p.Data().LoginURL.String()) assert.Equal(t, "https://example.com/oauth/token", - p.Data().RedeemUrl.String()) + p.Data().RedeemURL.String()) assert.Equal(t, "https://example.com/oauth/profile", - p.Data().ProfileUrl.String()) + p.Data().ProfileURL.String()) assert.Equal(t, "https://example.com/oauth/tokeninfo", - p.Data().ValidateUrl.String()) + p.Data().ValidateURL.String()) assert.Equal(t, "profile", p.Data().Scope) } diff --git a/providers/provider_data.go b/providers/provider_data.go index 40cda0412..a13ed8e52 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -5,14 +5,15 @@ import ( ) type ProviderData struct { - ProviderName string - ClientID string - ClientSecret string - LoginUrl *url.URL - RedeemUrl *url.URL - ProfileUrl *url.URL - ValidateUrl *url.URL - Scope string + ProviderName string + ClientID string + ClientSecret string + LoginURL *url.URL + RedeemURL *url.URL + ProfileURL *url.URL + ValidateURL *url.URL + Scope string + ApprovalPrompt string } func (p *ProviderData) Data() *ProviderData { return p } diff --git a/providers/provider_default.go b/providers/provider_default.go index b18212fd6..77b3dfdf0 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -13,20 +13,20 @@ import ( "github.com/bitly/oauth2_proxy/cookie" ) -func (p *ProviderData) Redeem(redirectUrl, code string) (s *SessionState, err error) { +func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) { if code == "" { err = errors.New("missing code") return } params := url.Values{} - params.Add("redirect_uri", redirectUrl) + params.Add("redirect_uri", redirectURL) params.Add("client_id", p.ClientID) params.Add("client_secret", p.ClientSecret) params.Add("code", code) params.Add("grant_type", "authorization_code") var req *http.Request - req, err = http.NewRequest("POST", p.RedeemUrl.String(), bytes.NewBufferString(params.Encode())) + req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) if err != nil { return } @@ -45,7 +45,7 @@ func (p *ProviderData) Redeem(redirectUrl, code string) (s *SessionState, err er } if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemUrl.String(), body) + err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) return } @@ -77,10 +77,10 @@ func (p *ProviderData) Redeem(redirectUrl, code string) (s *SessionState, err er // GetLoginURL with typical oauth parameters func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string { var a url.URL - a = *p.LoginUrl + a = *p.LoginURL params, _ := url.ParseQuery(a.RawQuery) params.Set("redirect_uri", redirectURI) - params.Set("approval_prompt", "force") + params.Set("approval_prompt", p.ApprovalPrompt) params.Add("scope", p.Scope) params.Set("client_id", p.ClientID) params.Set("response_type", "code") @@ -105,6 +105,12 @@ func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { return "", errors.New("not implemented") } +// ValidateGroup validates that the provided email exists in the configured provider +// email group(s). +func (p *ProviderData) ValidateGroup(email string) bool { + return true +} + func (p *ProviderData) ValidateSessionState(s *SessionState) bool { return validateToken(p, s.AccessToken, nil) } diff --git a/providers/providers.go b/providers/providers.go index 3192011e4..59e5f9a2c 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -8,6 +8,7 @@ type Provider interface { Data() *ProviderData GetEmailAddress(*SessionState) (string, error) Redeem(string, string) (*SessionState, error) + ValidateGroup(string) bool ValidateSessionState(*SessionState) bool GetLoginURL(redirectURI, finalRedirect string) string RefreshSessionIfNeeded(*SessionState) (bool, error) diff --git a/validator.go b/validator.go index 396e6055c..e3c0a542b 100644 --- a/validator.go +++ b/validator.go @@ -71,9 +71,11 @@ func newValidatorImpl(domains []string, usersFile string, domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain)) } - validator := func(email string) bool { + validator := func(email string) (valid bool) { + if email == "" { + return + } email = strings.ToLower(email) - valid := false for _, domain := range domains { valid = valid || strings.HasSuffix(email, domain) }