Skip to content

Commit

Permalink
Remove defer and test parseRateLimiter (#4540)
Browse files Browse the repository at this point in the history
Remove defer (#3) and add three tests to check if parsing ratelimiting env vars works correctly
  • Loading branch information
RobbieMcKinstry authored May 10, 2018
1 parent b1456f5 commit f8a2e68
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 7 deletions.
11 changes: 4 additions & 7 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,6 @@ func NewClient(c *Config) (*Client, error) {
c = def
}

// Check to see if the Limiter has been set.
// If is has not been set, then set it to have no limit.
if c.Limiter == nil {
c.Limiter = rate.NewLimiter(rate.Inf, 0)
}

c.modifyLock.Lock()
defer c.modifyLock.Unlock()

Expand Down Expand Up @@ -596,7 +590,10 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
// a Vault server not configured with this client. This is an advanced operation
// that generally won't need to be called externally.
func (c *Client) RawRequest(r *Request) (*Response, error) {
defer c.config.Limiter.Wait(context.Background())
if c.config.Limiter != nil {
c.config.Limiter.Wait(context.Background())
}

c.modifyLock.RLock()
c.config.modifyLock.RLock()
defer c.config.modifyLock.RUnlock()
Expand Down
42 changes: 42 additions & 0 deletions api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,48 @@ func TestClientEnvSettings(t *testing.T) {
}
}

func TestParsingRateAndBurst(t *testing.T) {
var (
correctFormat = "400:400"
observedRate, observedBurst, err = parseRateLimit(correctFormat)
expectedRate, expectedBurst = float(400), 400
)
if err != nil {
t.Error(err)
}
if expectedRate != observedRate {
t.Errorf("Expected rate %v but found %v", expectedRate, observedRate)
}
if expectedBurst != observedBurst {
t.Errorf("Expected burst %v but found %v", expectedRate, observedRate)
}
}

func TestParsingRateOnly(t *testing.T) {
var (
correctFormat = "400"
observedRate, observedBurst, err = parseRateLimit(correctFormat)
expectedRate, expectedBurst = float(400), 400
)
if err != nil {
t.Error(err)
}
if expectedRate != observedRate {
t.Errorf("Expected rate %v but found %v", expectedRate, observedRate)
}
if expectedBurst != observedBurst {
t.Errorf("Expected burst %v but found %v", expectedRate, observedRate)
}
}

func TestParsingErrorCase(t *testing.T) {
var incorrectFormat = "foobar"
var _, _, err = parseRateLimit(correctFormat)
if err == nil {
t.Error("Expected error, found no error")
}
}

func TestClientTimeoutSetting(t *testing.T) {
oldClientTimeout := os.Getenv(EnvVaultClientTimeout)
os.Setenv(EnvVaultClientTimeout, "10")
Expand Down

0 comments on commit f8a2e68

Please sign in to comment.