Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sticky Sessions: Tolerate ClickHouse Session ID Mechanism #117

Merged
merged 12 commits into from
May 24, 2021
18 changes: 18 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,24 @@ func TestServe(t *testing.T) {
},
startHTTP,
},
{
"http POST request with session id",
"testdata/http-session-id.yml",
func(t *testing.T) {
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/?query_id=45395792-a432-4b92-8cc9-536c14e1e1a9&extremes=0&session_id=default-session-id233", bytes.NewBufferString("SELECT * FROM system.numbers LIMIT 10"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: would be nice to keep the line length consistent with the rest of the code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adjusted

req.Header.Set("Content-Type", "application/x-www-form-urlencoded;") // This makes it work

checkErr(t, err)
resp, err := http.DefaultClient.Do(req)
checkErr(t, err)

if resp.StatusCode != http.StatusOK || resp.Header.Get("X-Clickhouse-Server-Session-Id") == "" {
t.Fatalf("unexpected status code: %d; expected: %d", resp.StatusCode, http.StatusOK)
pavelnemirovsky marked this conversation as resolved.
Show resolved Hide resolved
}
resp.Body.Close()
},
startHTTP,
},
{
"http request",
"testdata/http.yml",
Expand Down
16 changes: 10 additions & 6 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ func newReverseProxy() *reverseProxy {

func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
startTime := time.Now()

s, status, err := rp.getScope(req)
if err != nil {
q := getQuerySnippet(req)
Expand Down Expand Up @@ -99,6 +98,11 @@ func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
ReadCloser: req.Body,
}

// publish session_id if needed
if s.sessionId != "" {
rw.Header().Set("X-ClickHouse-Server-Session-Id", s.sessionId)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick check didn't show any results for "X-ClickHouse-Server-Session-Id" header in CH docs. Could you pls add a comment what exactly this header is used for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this header just to make sure ChProxy acknowledged and processed correctly the value was set to session_id

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was confused while reading it. I think, it requires a proper comment to explain your intention.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hagen1778 treat this header as ECHO Service, you pass the value of session_id and you'll want to assure that that value was accepted by server-side correctly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hagen1778 let's merge it already and let's enable GitHub actions to release versions properly for multiple platforms, what do u think?

}

if s.user.cache == nil {
rp.proxyRequest(s, srw, srw, req)
} else {
Expand All @@ -110,9 +114,9 @@ func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
q := getQuerySnippet(req)
if srw.statusCode == http.StatusOK {
requestSuccess.With(s.labels).Inc()
log.Debugf("%s: request success; query: %q; URL: %q", s, q, req.URL.String())
log.Debugf("%s: request success; query: %q; Method: %s; URL: %q", s, q, req.Method, req.URL.String())
} else {
log.Debugf("%s: request failure: non-200 status code %d; query: %q; URL: %q", s, srw.statusCode, q, req.URL.String())
log.Debugf("%s: request failure: non-200 status code %d; query: %q; Method: %s; URL: %q", s, srw.statusCode, q, req.Method, req.URL.String())
}

statusCodes.With(
Expand Down Expand Up @@ -435,7 +439,7 @@ func (rp *reverseProxy) applyConfig(cfg *config.Config) error {
return nil
}

// refreshCacheMetrics refresehs cacheSize and cacheItems metrics.
// refreshCacheMetrics refreshes cacheSize and cacheItems metrics.
func (rp *reverseProxy) refreshCacheMetrics() {
rp.lock.RLock()
defer rp.lock.RUnlock()
Expand All @@ -452,7 +456,7 @@ func (rp *reverseProxy) refreshCacheMetrics() {

func (rp *reverseProxy) getScope(req *http.Request) (*scope, int, error) {
name, password := getAuth(req)

sessionId := getSessionId(req)
var (
u *user
c *cluster
Expand Down Expand Up @@ -489,6 +493,6 @@ func (rp *reverseProxy) getScope(req *http.Request) (*scope, int, error) {
return nil, http.StatusForbidden, fmt.Errorf("cluster user %q is not allowed to access", cu.name)
}

s := newScope(req, u, c, cu)
s := newScope(req, u, c, cu, sessionId)
return s, 0, nil
}
60 changes: 58 additions & 2 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"github.com/prometheus/client_golang/prometheus"
)

// var route = make(map[int]*host)
pavelnemirovsky marked this conversation as resolved.
Show resolved Hide resolved

type scopeID uint64

func (sid scopeID) String() string {
Expand All @@ -40,6 +42,8 @@ type scope struct {
user *user
clusterUser *clusterUser

sessionId string

remoteAddr string
localAddr string

Expand All @@ -49,9 +53,12 @@ type scope struct {
labels prometheus.Labels
}

func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser) *scope {
h := c.getHost()
func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId string) *scope {

h := c.getHost()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we use if-else here to avoid unnecessary counter's increase on line 58 when session_id isn't empty?

if sessionId != "" {
h = c.getHostSticky(sessionId)
}
var localAddr string
if addr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok {
localAddr = addr.String()
Expand All @@ -63,6 +70,7 @@ func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser) *scope {
cluster: c,
user: u,
clusterUser: cu,
sessionId: sessionId,

remoteAddr: req.RemoteAddr,
localAddr: localAddr,
Expand Down Expand Up @@ -305,6 +313,8 @@ var allowedParams = []string{
"extremes",
// what to do if the volume of the result exceeds one of the limits
"result_overflow_mode",
// session stickiness
"session_id",
pavelnemirovsky marked this conversation as resolved.
Show resolved Hide resolved
}

// This regexp must match params needed to describe a way to use external data
Expand Down Expand Up @@ -810,6 +820,44 @@ func (c *cluster) getReplica() *replica {
return r
}

// getHostSticky returns host by stickiness from replica.
//
// Always returns non-nil.
func (r *replica) getHostSticky(sessionId string) *host {
idx := atomic.AddUint32(&r.nextHostIdx, 1)
n := uint32(len(r.hosts))
if n == 1 {
return r.hosts[0]
}

idx %= n
h := r.hosts[idx]

// Scan all the hosts for the least loaded host.
for i := uint32(1); i < n; i++ {
tmpIdx := (idx + i) % n

// handling sticky session
if sessionId != "" {
pavelnemirovsky marked this conversation as resolved.
Show resolved Hide resolved
sessionId := hash(sessionId)
tmpIdx = (sessionId) % n
tmpHSticky := r.hosts[tmpIdx]
log.Debugf("Sticky server candidate is: %s", tmpHSticky.addr)
if !tmpHSticky.isActive() {
log.Debugf("Sticky session server has been picked up, but it is not available")
continue
}
log.Debugf("Sticky session server is: %s, session_id: %d, server_idx: %d, max nodes in pool: %d", tmpHSticky.addr, sessionId, tmpIdx, n)
return tmpHSticky
}
}

// The returned host may be inactive. This is OK,
// since this means all the hosts are inactive,
// so let's try proxying the request to any host.
return h
}

// getHost returns least loaded + round-robin host from replica.
//
// Always returns non-nil.
Expand Down Expand Up @@ -856,6 +904,14 @@ func (r *replica) getHost() *host {
return h
}

// getHostSticky returns host based on stickiness from cluster.
//
// Always returns non-nil.
func (c *cluster) getHostSticky(sessionId string) *host {
r := c.getReplica()
return r.getHostSticky(sessionId)
}

// getHost returns least loaded + round-robin host from cluster.
//
// Always returns non-nil.
Expand Down
15 changes: 15 additions & 0 deletions testdata/http-session-id.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
log_debug: true
server:
http:
listen_addr: ":9090"
allowed_networks: ["127.0.0.1/24"]

users:
- name: "default"
to_cluster: "default"
to_user: "default"

clusters:
- name: "default"
nodes:
- 127.0.0.1:8124
14 changes: 14 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"compress/gzip"
"fmt"
"hash/fnv"
"io"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -43,6 +44,13 @@ func getAuth(req *http.Request) (string, string) {
return "default", ""
}

// getSessionId retrieves session id
func getSessionId(req *http.Request) string {
params := req.URL.Query()
sessionId := params.Get("session_id")
return sessionId
}

// getQuerySnippet returns query snippet.
//
// getQuerySnippet must be called only for error reporting.
Expand All @@ -57,6 +65,12 @@ func getQuerySnippet(req *http.Request) string {
return query + body
}

func hash(s string) uint32 {
h := fnv.New32a()
h.Write([]byte(s))
return h.Sum32()
}

func getQuerySnippetFromBody(req *http.Request) string {
if req.Body == nil {
return ""
Expand Down