Skip to content

Commit

Permalink
Handle vault redirects
Browse files Browse the repository at this point in the history
  • Loading branch information
Drew MacInnis authored and Drew MacInnis committed Nov 18, 2016
1 parent 3d0c34f commit 8bf28ec
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 23 deletions.
Binary file added gomplate
Binary file not shown.
19 changes: 15 additions & 4 deletions vault/app-id_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,31 @@ func NewAppIDAuthStrategy() *AppIDAuthStrategy {
return nil
}

// GetToken - log in to the app-id auth backend and return the client token
func (a *AppIDAuthStrategy) GetToken(addr *url.URL) (string, error) {
func (a *AppIDAuthStrategy) GetHttpClient() *http.Client {
if a.hc == nil {
a.hc = &http.Client{Timeout: time.Second * 5}
}
client := a.hc
return a.hc
}

func (a *AppIDAuthStrategy) SetToken(req *http.Request) {
// no-op
}

func (a *AppIDAuthStrategy) Do(req *http.Request) (*http.Response, error) {
hc := a.GetHttpClient()
return hc.Do(req)
}

// GetToken - log in to the app-id auth backend and return the client token
func (a *AppIDAuthStrategy) GetToken(addr *url.URL) (string, error) {
buf := new(bytes.Buffer)
json.NewEncoder(buf).Encode(&a)

u := &url.URL{}
*u = *addr
u.Path = "/v1/auth/app-id/login"
res, err := client.Post(u.String(), "application/json; charset=utf-8", buf)
res, err := requestAndFollow(a, "POST", u, buf.Bytes())
if err != nil {
return "", err
}
Expand Down
46 changes: 27 additions & 19 deletions vault/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,28 @@ func getAuthStrategy() AuthStrategy {
return nil
}

func (c *Client) GetHttpClient() *http.Client {
if c.hc == nil {
c.hc = &http.Client{
Timeout: time.Second * 5,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
req.Header.Set("X-Vault-Token", c.token)
return nil
},
}
}
return c.hc
}

func (c *Client) SetToken(req *http.Request) {
req.Header.Set("X-Vault-Token", c.token)
}

func (c *Client) Do(req *http.Request) (*http.Response, error) {
hc := c.GetHttpClient()
return hc.Do(req)
}

// Login - log in to Vault with the discovered auth backend and save the token
func (c *Client) Login() error {
token, err := c.Auth.GetToken(c.Addr)
Expand All @@ -72,17 +94,12 @@ func (c *Client) RevokeToken() {
return
}

if c.hc == nil {
c.hc = &http.Client{Timeout: time.Second * 5}
}

u := &url.URL{}
*u = *c.Addr
u.Path = "/v1/auth/token/revoke-self"
req, _ := http.NewRequest("POST", u.String(), nil)
req.Header.Set("X-Vault-Token", c.token)

res, err := c.hc.Do(req)
res, err := requestAndFollow(c, "POST", u, nil)

if err != nil {
log.Println("Error while revoking Vault Token", err)
}
Expand All @@ -94,32 +111,23 @@ func (c *Client) RevokeToken() {

func (c *Client) Read(path string) ([]byte, error) {
path = normalizeURLPath(path)
if c.hc == nil {
c.hc = &http.Client{Timeout: time.Second * 5}
}

u := &url.URL{}
*u = *c.Addr
u.Path = "/v1" + path
req, err := http.NewRequest("GET", u.String(), nil)
if err != nil {
return nil, err
}
req.Header.Set("X-Vault-Token", c.token)

res, err := c.hc.Do(req)
res, err := requestAndFollow(c, "GET", u, nil)
if err != nil {
return nil, err
}

body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
return nil, err
}

if res.StatusCode != 200 {
err = fmt.Errorf("Unexpected HTTP status %d on Read from %s: %s", res.StatusCode, u, body)
err = fmt.Errorf("Unexpected HTTP status %d on Read from %s: %s", res.StatusCode, path, body)
return nil, err
}

Expand All @@ -131,7 +139,7 @@ func (c *Client) Read(path string) ([]byte, error) {
}

if _, ok := response["data"]; !ok {
return nil, fmt.Errorf("Unexpected HTTP body on Read for %s: %s", u, body)
return nil, fmt.Errorf("Unexpected HTTP body on Read for %s: %s", path, body)
}

return json.Marshal(response["data"])
Expand Down
47 changes: 47 additions & 0 deletions vault/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package vault

import (
"bytes"
"net/http"
"net/url"
)

// httpClient
type httpClient interface {
GetHttpClient() *http.Client
SetToken(req *http.Request)
Do(req *http.Request) (*http.Response, error)
}

func requestAndFollow(hc httpClient, method string, u *url.URL, body []byte) (*http.Response, error) {
var res *http.Response
var err error
for attempts := 0; attempts < 2; attempts++ {
reader := bytes.NewReader(body)
req, err := http.NewRequest(method, u.String(), reader)

if err != nil {
return nil, err
}
hc.SetToken(req)
if method == "POST" {
req.Header.Set("Content-Type", "application/json; charset=utf-8")
}

res, err = hc.Do(req)
if err != nil {
return nil, err
}
if res.StatusCode == 307 {
res.Body.Close()
location, errLocation := res.Location()
if errLocation != nil {
return nil, errLocation
}
u.Host = location.Host
} else {
break
}
}
return res, err
}

0 comments on commit 8bf28ec

Please sign in to comment.