Skip to content

Commit

Permalink
Add environment variable for rate limiting.
Browse files Browse the repository at this point in the history
  • Loading branch information
Robbie McKinstry committed Apr 24, 2018
1 parent 8d35777 commit b1456f5
Showing 1 changed file with 46 additions and 2 deletions.
48 changes: 46 additions & 2 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ const EnvVaultWrapTTL = "VAULT_WRAP_TTL"
const EnvVaultMaxRetries = "VAULT_MAX_RETRIES"
const EnvVaultToken = "VAULT_TOKEN"
const EnvVaultMFA = "VAULT_MFA"
const EnvRateLimit = "VAULT_RATE_LIMIT"

// WrappingLookupFunc is a function that, given an HTTP verb and a path,
// returns an optional string duration to be used for response wrapping (e.g.
Expand Down Expand Up @@ -214,6 +215,7 @@ func (c *Config) ReadEnvironment() error {
var envInsecure bool
var envTLSServerName string
var envMaxRetries *uint64
var limit *rate.Limiter

// Parse the environment variables
if v := os.Getenv(EnvVaultAddress); v != "" {
Expand All @@ -238,6 +240,13 @@ func (c *Config) ReadEnvironment() error {
if v := os.Getenv(EnvVaultClientKey); v != "" {
envClientKey = v
}
if v := os.Getenv(EnvRateLimit); v != "" {
rateLimit, burstLimit, err := parseRateLimit(v)
if err != nil {
return err
}
limit = rate.NewLimiter(rate.Limit(rateLimit), burstLimit)
}
if t := os.Getenv(EnvVaultClientTimeout); t != "" {
clientTimeout, err := parseutil.ParseDurationSecond(t)
if err != nil {
Expand Down Expand Up @@ -269,6 +278,8 @@ func (c *Config) ReadEnvironment() error {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()

c.Limiter = limit

if err := c.ConfigureTLS(t); err != nil {
return err
}
Expand All @@ -288,6 +299,38 @@ func (c *Config) ReadEnvironment() error {
return nil
}

func parseRateLimit(val string) (rate float64, burst int, err error) {
// First, check to see if the limit has a colon in it.
const delimiter = ':'
var position = strings.IndexRune(val, delimiter)
if position == -1 {
rate, err = strconv.ParseFloat(val, 64)
burst = int(rate)
if err != nil {
err = fmt.Errorf("%v was provided by incorrectly formatted", EnvRateLimit)
}
return rate, burst, err
}
// Env variable contains both a rate and a burst.
// the rate segment of val is up to and excluding
// the index of the delimiter
rateStr := string(val[position])
burstStr := string(val[burst+1:])

rate, err = strconv.ParseFloat(rateStr, 64)
if err != nil {
err = fmt.Errorf("%v was provided by incorrectly formatted", EnvRateLimit)
return 0.0, 0, err
}
burst64, err := strconv.ParseInt(burstStr, 10, 0)
burst = int(burst64)
if err != nil {
err = fmt.Errorf("%v was provided by incorrectly formatted", EnvRateLimit)
}

return rate, burst, err
}

// Client is the client to the Vault API. Create a client with NewClient.
type Client struct {
modifyLock sync.RWMutex
Expand Down Expand Up @@ -379,10 +422,11 @@ func (c *Client) Address() string {

// SetLimiter will set the rate limiter for this client.
// This method is thread-safe.
func (c *Client) SetLimiter(limit *rate.Limiter) {
// rateLimit and burst are specified according to https://godoc.org/golang.org/x/time/rate#NewLimiter
func (c *Client) SetLimiter(rateLimit float64, burst int) {
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
c.config.Limiter = limit
c.config.Limiter = rate.NewLimiter(rate.Limit(rateLimit), burst)
}

// SetMaxRetries sets the number of retries that will be used in the case of certain errors
Expand Down

0 comments on commit b1456f5

Please sign in to comment.