diff --git a/agent/consul/session_endpoint.go b/agent/consul/session_endpoint.go index 05e2b5c43539..353635768dbd 100644 --- a/agent/consul/session_endpoint.go +++ b/agent/consul/session_endpoint.go @@ -19,6 +19,16 @@ type Session struct { logger hclog.Logger } +// in v1.7.0 we renamed Session -> SessionID. While its more descriptive of what +// we actually expect, it did break the RPC API for the SessionSpecificRequest. Now +// we have to put back the original name and support both with the new name being +// the canonical name and the other being considered only when the main one is empty. +func fixupSessionSpecificRequest(args *structs.SessionSpecificRequest) { + if args.SessionID == "" { + args.SessionID = args.Session + } +} + // Apply is used to apply a modifying request to the data store. This should // only be used for operations that modify the data func (s *Session) Apply(args *structs.SessionRequest, reply *string) error { @@ -156,6 +166,8 @@ func (s *Session) Get(args *structs.SessionSpecificRequest, return err } + fixupSessionSpecificRequest(args) + var authzContext acl.AuthorizerContext authz, err := s.srv.ResolveTokenAndDefaultMeta(args.Token, &args.EnterpriseMeta, &authzContext) if err != nil { @@ -262,6 +274,9 @@ func (s *Session) Renew(args *structs.SessionSpecificRequest, if done, err := s.srv.forward("Session.Renew", args, args, reply); done { return err } + + fixupSessionSpecificRequest(args) + defer metrics.MeasureSince([]string{"session", "renew"}, time.Now()) // Fetch the ACL token, if any, and apply the policy. diff --git a/agent/consul/session_endpoint_test.go b/agent/consul/session_endpoint_test.go index bd42febc153f..1bacb4a35e62 100644 --- a/agent/consul/session_endpoint_test.go +++ b/agent/consul/session_endpoint_test.go @@ -9,7 +9,7 @@ import ( "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/testrpc" - "github.com/hashicorp/net-rpc-msgpackrpc" + msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" ) func TestSession_Apply(t *testing.T) { @@ -280,6 +280,52 @@ func TestSession_Get(t *testing.T) { } } +func TestSession_Get_Compat(t *testing.T) { + t.Parallel() + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) + arg := structs.SessionRequest{ + Datacenter: "dc1", + Op: structs.SessionCreate, + Session: structs.Session{ + Node: "foo", + }, + } + var out string + if err := msgpackrpc.CallWithCodec(codec, "Session.Apply", &arg, &out); err != nil { + t.Fatalf("err: %v", err) + } + + getR := structs.SessionSpecificRequest{ + Datacenter: "dc1", + // this should get converted to the SessionID field internally + Session: out, + } + var sessions structs.IndexedSessions + if err := msgpackrpc.CallWithCodec(codec, "Session.Get", &getR, &sessions); err != nil { + t.Fatalf("err: %v", err) + } + + if sessions.Index == 0 { + t.Fatalf("Bad: %v", sessions) + } + if len(sessions.Sessions) != 1 { + t.Fatalf("Bad: %v", sessions) + } + s := sessions.Sessions[0] + if s.ID != out { + t.Fatalf("bad: %v", s) + } +} + func TestSession_List(t *testing.T) { t.Parallel() dir1, s1 := testServer(t) @@ -793,6 +839,63 @@ session "foo" { } } +func TestSession_Renew_Compat(t *testing.T) { + // This method is timing sensitive, disable Parallel + //t.Parallel() + ttl := 5 * time.Second + TTL := ttl.String() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.SessionTTLMin = ttl + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") + + codec := rpcClient(t, s1) + defer codec.Close() + + s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) + var id string + arg := structs.SessionRequest{ + Datacenter: "dc1", + Op: structs.SessionCreate, + Session: structs.Session{ + Node: "foo", + TTL: TTL, + }, + } + if err := msgpackrpc.CallWithCodec(codec, "Session.Apply", &arg, &id); err != nil { + t.Fatalf("err: %v", err) + } + + // renew the session + renewR := structs.SessionSpecificRequest{ + Datacenter: "dc1", + // this will get ranslated internally to the SessionID field + Session: id, + } + var session structs.IndexedSessions + if err := msgpackrpc.CallWithCodec(codec, "Session.Renew", &renewR, &session); err != nil { + t.Fatalf("err: %v", err) + } + + if session.Index == 0 { + t.Fatalf("Bad: %v", session) + } + if len(session.Sessions) != 1 { + t.Fatalf("Bad: %v", session.Sessions) + } + + s := session.Sessions[0] + if id != s.ID { + t.Fatalf("bad: %v", s) + } + if s.Node != "foo" { + t.Fatalf("bad: %v", s) + } +} + func TestSession_NodeSessions(t *testing.T) { t.Parallel() dir1, s1 := testServer(t) diff --git a/agent/session_endpoint.go b/agent/session_endpoint.go index 46a512f48d5a..f13f7a37665d 100644 --- a/agent/session_endpoint.go +++ b/agent/session_endpoint.go @@ -98,6 +98,7 @@ func (s *HTTPServer) SessionRenew(resp http.ResponseWriter, req *http.Request) ( // Pull out the session id args.SessionID = strings.TrimPrefix(req.URL.Path, "/v1/session/renew/") + args.Session = args.SessionID if args.SessionID == "" { resp.WriteHeader(http.StatusBadRequest) fmt.Fprint(resp, "Missing session") @@ -128,6 +129,7 @@ func (s *HTTPServer) SessionGet(resp http.ResponseWriter, req *http.Request) (in // Pull out the session id args.SessionID = strings.TrimPrefix(req.URL.Path, "/v1/session/info/") + args.Session = args.SessionID if args.SessionID == "" { resp.WriteHeader(http.StatusBadRequest) fmt.Fprint(resp, "Missing session") diff --git a/agent/structs/structs.go b/agent/structs/structs.go index b386cc4b6d0f..6d20fe3f8835 100644 --- a/agent/structs/structs.go +++ b/agent/structs/structs.go @@ -2036,6 +2036,8 @@ func (r *SessionRequest) RequestDatacenter() string { type SessionSpecificRequest struct { Datacenter string SessionID string + // DEPRECATED in 1.7.0 + Session string EnterpriseMeta QueryOptions }