Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sticky Sessions: Tolerate ClickHouse Session ID Mechanism #117

Merged
merged 12 commits into from
May 24, 2021
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ RUN go get golang.org/x/lint/golint
RUN mkdir -p /go/src/github.com/Vertamedia/chproxy
WORKDIR /go/src/github.com/Vertamedia/chproxy
COPY . ./
ARG EXT_BUILD_TAG
ENV EXT_BUILD_TAG ${EXT_BUILD_TAG}
RUN make release-build

FROM alpine
Expand Down
12 changes: 11 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
current_dir = $(pwd)
pkgs = $(shell go list ./...)
gofiles := $(shell find . -name "*.go" -type f -not -path "./vendor/*")

BUILD_TAG = $(shell git tag --points-at HEAD)
BUILD_TAG = $(or $(shell git tag --points-at HEAD), $(EXT_BUILD_TAG), latest)

BUILD_CONSTS = \
-X main.buildTime=`date -u '+%Y-%m-%d_%H:%M:%S'` \
Expand Down Expand Up @@ -36,7 +37,16 @@ clean:
rm -f chproxy

# release-build prints the build tag/options, compiles a linux/amd64 chproxy
# binary and packs it into chproxy-linux-amd64-$(BUILD_TAG).tar.gz.
#
# Tarballs left over from earlier builds are removed first. `rm -f` is used so
# that the target does not fail when no previous tarball exists (for example
# in a fresh checkout or inside the Docker build stage).
release-build:
	@echo "Ver: $(BUILD_TAG), OPTS: $(BUILD_OPTS)"
	GOOS=linux GOARCH=amd64 go build $(BUILD_OPTS)
	rm -f chproxy-linux-amd64-*.tar.gz
	tar czf chproxy-linux-amd64-$(BUILD_TAG).tar.gz chproxy

# release runs the format, lint, test, clean and release-build prerequisites,
# then prints the build tag/options and packs the chproxy binary into
# chproxy-linux-amd64-$(BUILD_TAG).tar.gz.
# NOTE(review): release-build already produces a tarball with the same name,
# so the tar step here looks redundant — confirm before removing.
release: format lint test clean release-build
@echo "Ver: $(BUILD_TAG), OPTS: $(BUILD_OPTS)"
tar czf chproxy-linux-amd64-$(BUILD_TAG).tar.gz chproxy

# release-build-docker produces the release tarball inside Docker.
#
# It builds the `build` stage of the Dockerfile as image `chproxy-build`,
# forwarding the current BUILD_TAG through the EXT_BUILD_TAG build argument,
# then runs that image with /bin/sh as entrypoint and $(CURDIR) mounted at
# /host in order to copy the generated *.tar.gz files back into the working
# directory.
release-build-docker:
@echo "Ver: $(BUILD_TAG)"
@DOCKER_BUILDKIT=1 docker build --target build --build-arg EXT_BUILD_TAG=$(BUILD_TAG) --progress plain -t chproxy-build .
@docker run --rm --entrypoint "/bin/sh" -v $(CURDIR):/host chproxy-build -c "/bin/cp /go/src/github.com/Vertamedia/chproxy/*.tar.gz /host"
20 changes: 20 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,26 @@ func TestServe(t *testing.T) {
},
startHTTP,
},
{
"http POST request with session id",
"testdata/http-session-id.yml",
func(t *testing.T) {
req, err := http.NewRequest("POST",
"http://127.0.0.1:9090/?query_id=45395792-a432-4b92-8cc9-536c14e1e1a9&extremes=0&session_id=default-session-id233",
bytes.NewBufferString("SELECT * FROM system.numbers LIMIT 10"))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded;") // This makes it work

checkErr(t, err)
resp, err := http.DefaultClient.Do(req)
checkErr(t, err)

if resp.StatusCode != http.StatusOK || resp.StatusCode != http.StatusOK && resp.Header.Get("X-Clickhouse-Server-Session-Id") == "" {
t.Fatalf("unexpected status code: %d; expected: %d", resp.StatusCode, http.StatusOK)
pavelnemirovsky marked this conversation as resolved.
Show resolved Hide resolved
}
resp.Body.Close()
},
startHTTP,
},
{
"http request",
"testdata/http.yml",
Expand Down
17 changes: 11 additions & 6 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ func newReverseProxy() *reverseProxy {

func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
startTime := time.Now()

s, status, err := rp.getScope(req)
if err != nil {
q := getQuerySnippet(req)
Expand Down Expand Up @@ -99,6 +98,11 @@ func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
ReadCloser: req.Body,
}

// publish session_id if needed
if s.sessionId != "" {
rw.Header().Set("X-ClickHouse-Server-Session-Id", s.sessionId)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A quick check didn't show any results for the "X-ClickHouse-Server-Session-Id" header in the ClickHouse docs. Could you please add a comment explaining what exactly this header is used for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this header just to confirm that ChProxy acknowledged and correctly processed the value passed as session_id.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was confused while reading it. I think, it requires a proper comment to explain your intention.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hagen1778 treat this header as an echo service: you pass a value for session_id, and the header lets you verify that the value was accepted correctly on the server side.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hagen1778 let's merge it already and let's enable GitHub actions to release versions properly for multiple platforms, what do u think?

}

if s.user.cache == nil {
rp.proxyRequest(s, srw, srw, req)
} else {
Expand All @@ -110,9 +114,9 @@ func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
q := getQuerySnippet(req)
if srw.statusCode == http.StatusOK {
requestSuccess.With(s.labels).Inc()
log.Debugf("%s: request success; query: %q; URL: %q", s, q, req.URL.String())
log.Debugf("%s: request success; query: %q; Method: %s; URL: %q", s, q, req.Method, req.URL.String())
} else {
log.Debugf("%s: request failure: non-200 status code %d; query: %q; URL: %q", s, srw.statusCode, q, req.URL.String())
log.Debugf("%s: request failure: non-200 status code %d; query: %q; Method: %s; URL: %q", s, srw.statusCode, q, req.Method, req.URL.String())
}

statusCodes.With(
Expand Down Expand Up @@ -435,7 +439,7 @@ func (rp *reverseProxy) applyConfig(cfg *config.Config) error {
return nil
}

// refreshCacheMetrics refresehs cacheSize and cacheItems metrics.
// refreshCacheMetrics refreshes cacheSize and cacheItems metrics.
func (rp *reverseProxy) refreshCacheMetrics() {
rp.lock.RLock()
defer rp.lock.RUnlock()
Expand All @@ -452,7 +456,8 @@ func (rp *reverseProxy) refreshCacheMetrics() {

func (rp *reverseProxy) getScope(req *http.Request) (*scope, int, error) {
name, password := getAuth(req)

sessionId := getSessionId(req)
sessionTimeout := getSessionTimeout(req)
var (
u *user
c *cluster
Expand Down Expand Up @@ -489,6 +494,6 @@ func (rp *reverseProxy) getScope(req *http.Request) (*scope, int, error) {
return nil, http.StatusForbidden, fmt.Errorf("cluster user %q is not allowed to access", cu.name)
}

s := newScope(req, u, c, cu)
s := newScope(req, u, c, cu, sessionId, sessionTimeout)
return s, 0, nil
}
74 changes: 66 additions & 8 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"sync/atomic"
"time"
Expand Down Expand Up @@ -40,6 +41,9 @@ type scope struct {
user *user
clusterUser *clusterUser

sessionId string
sessionTimeout int

remoteAddr string
localAddr string

Expand All @@ -49,20 +53,24 @@ type scope struct {
labels prometheus.Labels
}

func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser) *scope {
func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId string, sessionTimeout int) *scope {
h := c.getHost()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use if-else here to avoid an unnecessary counter increment on line 58 when session_id isn't empty?


if sessionId != "" {
h = c.getHostSticky(sessionId)
}
var localAddr string
if addr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok {
localAddr = addr.String()
}
s := &scope{
startTime: time.Now(),
id: newScopeID(),
host: h,
cluster: c,
user: u,
clusterUser: cu,
startTime: time.Now(),
id: newScopeID(),
host: h,
cluster: c,
user: u,
clusterUser: cu,
sessionId: sessionId,
sessionTimeout: sessionTimeout,

remoteAddr: req.RemoteAddr,
localAddr: localAddr,
Expand Down Expand Up @@ -305,6 +313,10 @@ var allowedParams = []string{
"extremes",
// what to do if the volume of the result exceeds one of the limits
"result_overflow_mode",
// session stickiness
"session_id",
pavelnemirovsky marked this conversation as resolved.
Show resolved Hide resolved
// session timeout
"session_timeout",
}

// This regexp must match params needed to describe a way to use external data
Expand Down Expand Up @@ -349,6 +361,8 @@ func (s *scope) decorateRequest(req *http.Request) (*http.Request, url.Values) {

// Set query_id as scope_id to have possibility to kill query if needed.
params.Set("query_id", s.id.String())
// Set session_timeout, the idle timeout for the session.
params.Set("session_timeout", strconv.Itoa(s.sessionTimeout))

req.URL.RawQuery = params.Encode()

Expand Down Expand Up @@ -810,6 +824,42 @@ func (c *cluster) getReplica() *replica {
return r
}

// getHostSticky returns the host of the replica that sessionId is pinned to.
//
// The sticky host is selected by hashing sessionId and taking the result
// modulo the number of hosts, so the same session id maps to the same host
// for as long as the host list is unchanged.
//
// If the sticky host is not active, an active host is searched for starting
// from the next round-robin position. When no host is active, the host at the
// round-robin position is returned anyway so the request can still be tried.
//
// Always returns non-nil.
// NOTE(review): assumes r.hosts is non-empty (an empty list would panic on
// the modulo below) — presumably guaranteed by config validation; confirm.
func (r *replica) getHostSticky(sessionId string) *host {
	n := uint32(len(r.hosts))
	if n == 1 {
		return r.hosts[0]
	}

	// The sticky index depends only on the session id and the pool size,
	// so it is computed exactly once.
	sessionHash := hash(sessionId)
	stickyIdx := sessionHash % n
	sticky := r.hosts[stickyIdx]
	log.Debugf("Sticky server candidate is: %s", sticky.addr)
	if sticky.isActive() {
		log.Debugf("Sticky session server is: %s, session_id: %d, server_idx: %d, max nodes in pool: %d", sticky.addr, sessionHash, stickyIdx, n)
		return sticky
	}
	log.Debugf("Sticky session server has been picked up, but it is not available")

	// The round-robin counter is advanced only when the fallback is needed,
	// so sticky hits do not disturb regular host rotation.
	idx := atomic.AddUint32(&r.nextHostIdx, 1) % n
	for i := uint32(0); i < n; i++ {
		if h := r.hosts[(idx+i)%n]; h.isActive() {
			return h
		}
	}

	// The returned host is inactive. This is OK,
	// since this means all the hosts are inactive,
	// so let's try proxying the request to any host.
	return r.hosts[idx]
}

// getHost returns least loaded + round-robin host from replica.
//
// Always returns non-nil.
Expand Down Expand Up @@ -856,6 +906,14 @@ func (r *replica) getHost() *host {
return h
}

// getHostSticky selects a replica of the cluster via getReplica and asks it
// for the host that the given session id is pinned to.
//
// Always returns non-nil.
func (c *cluster) getHostSticky(sessionId string) *host {
	return c.getReplica().getHostSticky(sessionId)
}

// getHost returns least loaded + round-robin host from cluster.
//
// Always returns non-nil.
Expand Down
12 changes: 6 additions & 6 deletions scope_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,14 +330,14 @@ func TestDecorateRequest(t *testing.T) {
"text/plain",
"GET",
nil,
[]string{"query_id", "query"},
[]string{"query_id", "session_timeout", "query"},
},
{
"http://127.0.0.1?user=default&password=default&query=SELECT&database=default&wait_end_of_query=1",
"text/plain",
"GET",
nil,
[]string{"query_id", "query", "database"},
[]string{"query_id", "session_timeout", "query", "database"},
},
{
"http://127.0.0.1?user=default&password=default&query=SELECT&testdata_structure=id+UInt32&testdata_format=TSV",
Expand All @@ -352,7 +352,7 @@ func TestDecorateRequest(t *testing.T) {
},
},
},
[]string{"query_id", "query", "max_threads"},
[]string{"query_id", "session_timeout", "query", "max_threads"},
},
{
"http://127.0.0.1?user=default&password=default&query=SELECT&testdata_structure=id+UInt32&testdata_format=TSV",
Expand All @@ -367,7 +367,7 @@ func TestDecorateRequest(t *testing.T) {
},
},
},
[]string{"query_id", "query"},
[]string{"query_id", "session_timeout", "query"},
},
{
"http://127.0.0.1?user=default&password=default&query=SELECT&testdata_type_buzz=1&testdata_structure_foo=id+UInt32&testdata_format-bar=TSV",
Expand All @@ -386,14 +386,14 @@ func TestDecorateRequest(t *testing.T) {
},
},
},
[]string{"query_id", "query", "max_threads", "background_pool_size"},
[]string{"query_id", "session_timeout", "query", "max_threads", "background_pool_size"},
},
{
"http://127.0.0.1?user=default&password=default&query=SELECT&testdata_structure=id+UInt32&testdata_format=TSV",
"multipart/form-data; boundary=foobar",
"POST",
nil,
[]string{"query_id", "testdata_structure", "testdata_format", "query"},
[]string{"query_id", "session_timeout", "testdata_structure", "testdata_format", "query"},
},
}

Expand Down
15 changes: 15 additions & 0 deletions testdata/http-session-id.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
log_debug: true
server:
http:
listen_addr: ":9090"
allowed_networks: ["127.0.0.1/24"]

users:
- name: "default"
to_cluster: "default"
to_user: "default"

clusters:
- name: "default"
nodes:
- 127.0.0.1:8124
25 changes: 25 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"bytes"
"compress/gzip"
"fmt"
"hash/fnv"
"io"
"io/ioutil"
"net/http"
"sort"
"strconv"
"strings"

"github.com/Vertamedia/chproxy/chdecompressor"
Expand Down Expand Up @@ -43,6 +45,23 @@ func getAuth(req *http.Request) (string, string) {
return "default", ""
}

// getSessionId extracts the value of the `session_id` query parameter from
// the request URL.
//
// An empty string is returned when the parameter is absent.
func getSessionId(req *http.Request) string {
	return req.URL.Query().Get("session_id")
}

// getSessionTimeout retrieves the value of the `session_timeout` query
// parameter from the request URL.
//
// The parsed value is returned only when the parameter is a valid positive
// integer. In every other case (parameter missing, not a number, zero or
// negative) the default of 60 is returned.
// NOTE(review): the unit is presumably seconds, as for ClickHouse's
// session_timeout setting — confirm.
func getSessionTimeout(req *http.Request) int {
	const defaultSessionTimeout = 60

	sessionTimeout, err := strconv.Atoi(req.URL.Query().Get("session_timeout"))
	if err != nil || sessionTimeout <= 0 {
		return defaultSessionTimeout
	}
	return sessionTimeout
}

// getQuerySnippet returns query snippet.
//
// getQuerySnippet must be called only for error reporting.
Expand All @@ -57,6 +76,12 @@ func getQuerySnippet(req *http.Request) string {
return query + body
}

// hash returns the 32-bit FNV-1a checksum of s.
func hash(s string) uint32 {
	hasher := fnv.New32a()
	// Write on a hash.Hash never returns an error, so the result is ignored.
	_, _ = hasher.Write([]byte(s))
	return hasher.Sum32()
}

func getQuerySnippetFromBody(req *http.Request) string {
if req.Body == nil {
return ""
Expand Down