Skip to content

Commit

Permalink
Support SesionTTLMin configuration
Browse files Browse the repository at this point in the history
- Allow setting SessionTTLMin
- Validate on the Server
  • Loading branch information
Michael Fraenkel committed Mar 27, 2015
1 parent 6e26162 commit 8c26836
Show file tree
Hide file tree
Showing 11 changed files with 114 additions and 116 deletions.
3 changes: 3 additions & 0 deletions command/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,9 @@ func (a *Agent) consulConfig() *consul.Config {
if a.config.ACLDownPolicy != "" {
base.ACLDownPolicy = a.config.ACLDownPolicy
}
if a.config.SessionTTLMinRaw != "" {
base.SessionTTLMin = a.config.SessionTTLMin
}

// Format the build string
revision := a.config.Revision
Expand Down
17 changes: 16 additions & 1 deletion command/agent/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,10 @@ type Config struct {

// UnixSockets is a map of socket configuration data
UnixSockets UnixSocketConfig `mapstructure:"unix_sockets"`

// Minimum Session TTL
SessionTTLMin time.Duration `mapstructure:"-"`
SessionTTLMinRaw string `mapstructure:"session_ttl_min"`
}

// UnixSocketPermissions contains information about a unix socket, and
Expand Down Expand Up @@ -609,6 +613,14 @@ func DecodeConfig(r io.Reader) (*Config, error) {
result.DNSRecursors = append(result.DNSRecursors, result.DNSRecursor)
}

if raw := result.SessionTTLMinRaw; raw != "" {
dur, err := time.ParseDuration(raw)
if err != nil {
return nil, fmt.Errorf("Session TTL Min invalid: %v", err)
}
result.SessionTTLMin = dur
}

return &result, nil
}

Expand Down Expand Up @@ -970,7 +982,10 @@ func MergeConfig(a, b *Config) *Config {
if b.AtlasJoin {
result.AtlasJoin = true
}

if b.SessionTTLMinRaw != "" {
result.SessionTTLMin = b.SessionTTLMin
result.SessionTTLMinRaw = b.SessionTTLMinRaw
}
if len(b.HTTPAPIResponseHeaders) != 0 {
if result.HTTPAPIResponseHeaders == nil {
result.HTTPAPIResponseHeaders = make(map[string]string)
Expand Down
13 changes: 13 additions & 0 deletions command/agent/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,17 @@ func TestDecodeConfig(t *testing.T) {
if !config.AtlasJoin {
t.Fatalf("bad: %#v", config)
}

// SessionTTLMin
input = `{"session_ttl_min": "5s"}`
config, err = DecodeConfig(bytes.NewReader([]byte(input)))
if err != nil {
t.Fatalf("err: %s", err)
}

if config.SessionTTLMin != 5*time.Second {
t.Fatalf("bad: %s %#v", config.SessionTTLMin.String(), config)
}
}

func TestDecodeConfig_invalidKeys(t *testing.T) {
Expand Down Expand Up @@ -1120,6 +1131,8 @@ func TestMergeConfig(t *testing.T) {
AtlasToken: "123456789",
AtlasACLToken: "abcdefgh",
AtlasJoin: true,
SessionTTLMinRaw: "1000s",
SessionTTLMin: 1000 * time.Second,
}

c := MergeConfig(a, b)
Expand Down
6 changes: 5 additions & 1 deletion command/agent/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,11 @@ func getIndex(t *testing.T, resp *httptest.ResponseRecorder) uint64 {
}

func httpTest(t *testing.T, f func(srv *HTTPServer)) {
dir, srv := makeHTTPServer(t)
httpTestWithConfig(t, f, nil)
}

func httpTestWithConfig(t *testing.T, f func(srv *HTTPServer), cb func(c *Config)) {
dir, srv := makeHTTPServerWithConfig(t, cb)
defer os.RemoveAll(dir)
defer srv.Shutdown()
defer srv.agent.Shutdown()
Expand Down
15 changes: 0 additions & 15 deletions command/agent/session_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,6 @@ func (s *HTTPServer) SessionCreate(resp http.ResponseWriter, req *http.Request)
resp.Write([]byte(fmt.Sprintf("Request decode failed: %v", err)))
return nil, nil
}

if args.Session.TTL != "" {
ttl, err := time.ParseDuration(args.Session.TTL)
if err != nil {
resp.WriteHeader(400)
resp.Write([]byte(fmt.Sprintf("Request TTL decode failed: %v", err)))
return nil, nil
}

if ttl != 0 && (ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax) {
resp.WriteHeader(400)
resp.Write([]byte(fmt.Sprintf("Request TTL '%s', must be between [%v-%v]", args.Session.TTL, structs.SessionTTLMin, structs.SessionTTLMax)))
return nil, nil
}
}
}

// Create the session, get the ID
Expand Down
102 changes: 18 additions & 84 deletions command/agent/session_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package agent
import (
"bytes"
"encoding/json"
"github.com/hashicorp/consul/consul"
"github.com/hashicorp/consul/consul/structs"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/hashicorp/consul/consul"
"github.com/hashicorp/consul/consul/structs"
)

func TestSessionCreate(t *testing.T) {
Expand Down Expand Up @@ -215,9 +216,20 @@ func TestSessionDestroy(t *testing.T) {
}

func TestSessionTTL(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
TTL := "10s" // use the minimum legal ttl
ttl := 10 * time.Second
// use the minimum legal ttl
testSessionTTL(t, 10*time.Second, nil)
}

func TestSessionTTLConfig(t *testing.T) {
testSessionTTL(t, 1*time.Second, func(c *Config) {
c.SessionTTLMinRaw = "1s"
c.SessionTTLMin = 1 * time.Second
})
}

func testSessionTTL(t *testing.T, ttl time.Duration, cb func(c *Config)) {
httpTestWithConfig(t, func(srv *HTTPServer) {
TTL := ttl.String()

id := makeTestSessionTTL(t, srv, TTL)

Expand Down Expand Up @@ -252,85 +264,7 @@ func TestSessionTTL(t *testing.T) {
if len(respObj) != 0 {
t.Fatalf("session '%s' should have been destroyed", id)
}
})
}

func TestSessionBadTTL(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
badTTL := "10z"

// Create Session with illegal TTL
body := bytes.NewBuffer(nil)
enc := json.NewEncoder(body)
raw := map[string]interface{}{
"TTL": badTTL,
}
enc.Encode(raw)

req, err := http.NewRequest("PUT", "/v1/session/create", body)
if err != nil {
t.Fatalf("err: %v", err)
}
resp := httptest.NewRecorder()
obj, err := srv.SessionCreate(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
if obj != nil {
t.Fatalf("illegal TTL '%s' allowed", badTTL)
}
if resp.Code != 400 {
t.Fatalf("Bad response code, should be 400")
}

// less than SessionTTLMin
body = bytes.NewBuffer(nil)
enc = json.NewEncoder(body)
raw = map[string]interface{}{
"TTL": "5s",
}
enc.Encode(raw)

req, err = http.NewRequest("PUT", "/v1/session/create", body)
if err != nil {
t.Fatalf("err: %v", err)
}
resp = httptest.NewRecorder()
obj, err = srv.SessionCreate(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
if obj != nil {
t.Fatalf("illegal TTL '%s' allowed", badTTL)
}
if resp.Code != 400 {
t.Fatalf("Bad response code, should be 400")
}

// more than SessionTTLMax
body = bytes.NewBuffer(nil)
enc = json.NewEncoder(body)
raw = map[string]interface{}{
"TTL": "4000s",
}
enc.Encode(raw)

req, err = http.NewRequest("PUT", "/v1/session/create", body)
if err != nil {
t.Fatalf("err: %v", err)
}
resp = httptest.NewRecorder()
obj, err = srv.SessionCreate(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
if obj != nil {
t.Fatalf("illegal TTL '%s' allowed", badTTL)
}
if resp.Code != 400 {
t.Fatalf("Bad response code, should be 400")
}
})
}, cb)
}

func TestSessionTTLRenew(t *testing.T) {
Expand Down
4 changes: 4 additions & 0 deletions consul/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ type Config struct {
// to reduce overhead. It is unlikely a user would ever need to tune this.
TombstoneTTLGranularity time.Duration

// Minimum Session TTL
SessionTTLMin time.Duration

// ServerUp callback can be used to trigger a notification that
// a Consul server is now up and known about.
ServerUp func()
Expand Down Expand Up @@ -241,6 +244,7 @@ func DefaultConfig() *Config {
ACLDownPolicy: "extend-cache",
TombstoneTTL: 15 * time.Minute,
TombstoneTTLGranularity: 30 * time.Second,
SessionTTLMin: 10 * time.Second,
}

// Increase our reap interval to 3 days instead of 24h.
Expand Down
4 changes: 2 additions & 2 deletions consul/session_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
return fmt.Errorf("Session TTL '%s' invalid: %v", args.Session.TTL, err)
}

if ttl != 0 && (ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax) {
if ttl != 0 && (ttl < s.srv.config.SessionTTLMin || ttl > structs.SessionTTLMax) {
return fmt.Errorf("Invalid Session TTL '%d', must be between [%v=%v]",
ttl, structs.SessionTTLMin, structs.SessionTTLMax)
ttl, s.srv.config.SessionTTLMin, structs.SessionTTLMax)
}
}

Expand Down
53 changes: 53 additions & 0 deletions consul/session_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,3 +483,56 @@ func TestSessionEndpoint_NodeSessions(t *testing.T) {
}
}
}

func TestSessionEndpoint_Apply_BadTTL(t *testing.T) {
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
client := rpcClient(t, s1)
defer client.Close()

testutil.WaitForLeader(t, client.Call, "dc1")

arg := structs.SessionRequest{
Datacenter: "dc1",
Op: structs.SessionCreate,
Session: structs.Session{
Node: "foo",
Name: "my-session",
},
}

// Session with illegal TTL
arg.Session.TTL = "10z"

var out string
err := client.Call("Session.Apply", &arg, &out)
if err == nil {
t.Fatal("expected error")
}
if err.Error() != "Session TTL '10z' invalid: time: unknown unit z in duration 10z" {
t.Fatalf("incorrect error message: %s", err.Error())
}

// less than SessionTTLMin
arg.Session.TTL = "5s"

err = client.Call("Session.Apply", &arg, &out)
if err == nil {
t.Fatal("expected error")
}
if err.Error() != "Invalid Session TTL '5000000000', must be between [10s=1h0m0s]" {
t.Fatalf("incorrect error message: %s", err.Error())
}

// more than SessionTTLMax
arg.Session.TTL = "4000s"

err = client.Call("Session.Apply", &arg, &out)
if err == nil {
t.Fatal("expected error")
}
if err.Error() != "Invalid Session TTL '4000000000000', must be between [10s=1h0m0s]" {
t.Fatalf("incorrect error message: %s", err.Error())
}
}
12 changes: 0 additions & 12 deletions consul/state_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -1625,18 +1625,6 @@ func (s *StateStore) SessionCreate(index uint64, session *structs.Session) error
return fmt.Errorf("Invalid Session Behavior setting '%s'", session.Behavior)
}

if session.TTL != "" {
ttl, err := time.ParseDuration(session.TTL)
if err != nil {
return fmt.Errorf("Invalid Session TTL '%s': %v", session.TTL, err)
}

if ttl != 0 && (ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax) {
return fmt.Errorf("Invalid Session TTL '%s', must be between [%v-%v]",
session.TTL, structs.SessionTTLMin, structs.SessionTTLMax)
}
}

// Assign the create index
session.CreateIndex = index

Expand Down
1 change: 0 additions & 1 deletion consul/structs/structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,6 @@ const (
)

const (
SessionTTLMin = 10 * time.Second
SessionTTLMax = 3600 * time.Second
SessionTTLMultiplier = 2
)
Expand Down

0 comments on commit 8c26836

Please sign in to comment.