Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sticky session support #97

Merged
merged 11 commits into from
Dec 8, 2017
69 changes: 51 additions & 18 deletions roundrobin/rebalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ type Rebalancer struct {
// creates new meters
newMeter NewMeterFn

// sticky session object
stickySession *StickySession

requestRewriteListener RequestRewriteListener
}

Expand Down Expand Up @@ -80,6 +83,13 @@ func RebalancerErrorHandler(h utils.ErrorHandler) RebalancerOption {
}
}

func RebalancerStickySession(stickySession *StickySession) RebalancerOption {
return func(r *Rebalancer) error {
r.stickySession = stickySession
return nil
}
}

// RebalancerErrorHandler is a functional argument that sets error handler of the server
func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOption {
return func(r *Rebalancer) error {
Expand All @@ -90,8 +100,9 @@ func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOpti

func NewRebalancer(handler balancerHandler, opts ...RebalancerOption) (*Rebalancer, error) {
rb := &Rebalancer{
mtx: &sync.Mutex{},
next: handler,
mtx: &sync.Mutex{},
next: handler,
stickySession: nil,
}
for _, o := range opts {
if err := o(rb); err != nil {
Expand Down Expand Up @@ -139,20 +150,42 @@ func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {

pw := &utils.ProxyWriter{W: w}
start := rb.clock.UtcNow()
url, err := rb.next.NextServer()
if err != nil {
rb.errHandler.ServeHTTP(w, req, err)
return
}

if log.GetLevel() >= log.DebugLevel {
//log which backend URL we're sending this request to
log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": url}).Debugf("vulcand/oxy/roundrobin/rebalancer: Forwarding this request to URL")
}

// make shallow copy of request before changing anything to avoid side effects
newReq := *req
newReq.URL = url
stuck := false

if rb.stickySession != nil {
cookieUrl, present, err := rb.stickySession.GetBackend(&newReq, rb.Servers())

if err != nil {
log.Infof("vulcand/oxy/roundrobin/rebalancer: error using server from cookie: %v", err)
}

if present {
newReq.URL = cookieUrl
stuck = true
}
}

if !stuck {
url, err := rb.next.NextServer()
if err != nil {
rb.errHandler.ServeHTTP(w, req, err)
return
}

if log.GetLevel() >= log.DebugLevel {
//log which backend URL we're sending this request to
log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": url}).Debugf("vulcand/oxy/roundrobin/rebalancer: Forwarding this request to URL")
}

if rb.stickySession != nil {
rb.stickySession.StickBackend(url, &w)
}

newReq.URL = url
}

//Emit event to a listener if one exists
if rb.requestRewriteListener != nil {
Expand All @@ -161,7 +194,7 @@ func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {

rb.next.Next().ServeHTTP(pw, &newReq)

rb.recordMetrics(url, pw.Code, rb.clock.UtcNow().Sub(start))
rb.recordMetrics(newReq.URL, pw.Code, rb.clock.UtcNow().Sub(start))
rb.adjustWeights()
}

Expand Down Expand Up @@ -244,11 +277,11 @@ func (rb *Rebalancer) upsertServer(u *url.URL, weight int) error {
return nil
}

func (r *Rebalancer) findServer(u *url.URL) (*rbServer, int) {
if len(r.servers) == 0 {
func (rb *Rebalancer) findServer(u *url.URL) (*rbServer, int) {
if len(rb.servers) == 0 {
return nil, -1
}
for i, s := range r.servers {
for i, s := range rb.servers {
if sameURL(u, s.url) {
return s, i
}
Expand Down Expand Up @@ -351,7 +384,7 @@ func (rb *Rebalancer) markServers() bool {
}

func (rb *Rebalancer) convergeWeights() bool {
// If we have previoulsy changed servers try to restore weights to the original state
// If we have previously changed servers try to restore weights to the original state
changed := false
for _, s := range rb.servers {
if s.origWeight == s.curWeight {
Expand Down
48 changes: 47 additions & 1 deletion roundrobin/rebalancer_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package roundrobin

import (
"io/ioutil"
"net/http"
"net/http/httptest"
"time"

"github.com/mailgun/timetools"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/testutils"

. "gopkg.in/check.v1"
)

Expand Down Expand Up @@ -339,6 +339,52 @@ func (s *RBSuite) TestRequestRewriteListener(c *C) {
c.Assert(rb.requestRewriteListener, NotNil)
}

func (s *RBSuite) TestRebalancerStickySession(c *C) {
a, b, x := testutils.NewResponder("a"), testutils.NewResponder("b"), testutils.NewResponder("x")
defer a.Close()
defer b.Close()
defer x.Close()

sticky := NewStickySession("test")
c.Assert(sticky, NotNil)

fwd, err := forward.New()
c.Assert(err, IsNil)

lb, err := New(fwd)
c.Assert(err, IsNil)

rb, err := NewRebalancer(lb, RebalancerStickySession(sticky))
c.Assert(err, IsNil)

rb.UpsertServer(testutils.ParseURI(a.URL))
rb.UpsertServer(testutils.ParseURI(b.URL))
rb.UpsertServer(testutils.ParseURI(x.URL))

proxy := httptest.NewServer(rb)
defer proxy.Close()

for i := 0; i < 10; i++ {
req, err := http.NewRequest(http.MethodGet, proxy.URL, nil)
c.Assert(err, IsNil)
req.AddCookie(&http.Cookie{Name: "test", Value: a.URL})

resp, err := http.DefaultClient.Do(req)
c.Assert(err, IsNil)

defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)

c.Assert(err, IsNil)
c.Assert(string(body), Equals, "a")
}

c.Assert(rb.RemoveServer(testutils.ParseURI(a.URL)), IsNil)
c.Assert(seq(c, proxy.URL, 3), DeepEquals, []string{"b", "x", "b"})
c.Assert(rb.RemoveServer(testutils.ParseURI(b.URL)), IsNil)
c.Assert(seq(c, proxy.URL, 3), DeepEquals, []string{"x", "x", "x"})
}

type testMeter struct {
rating float64
notReady bool
Expand Down
54 changes: 41 additions & 13 deletions roundrobin/rr.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ func ErrorHandler(h utils.ErrorHandler) LBOption {
}
}

func EnableStickySession(stickySession *StickySession) LBOption {
return func(s *RoundRobin) error {
s.stickySession = stickySession
return nil
}
}

// ErrorHandler is a functional argument that sets error handler of the server
func RoundRobinRequestRewriteListener(rrl RequestRewriteListener) LBOption {
return func(s *RoundRobin) error {
Expand All @@ -46,15 +53,17 @@ type RoundRobin struct {
index int
servers []*server
currentWeight int
stickySession *StickySession
requestRewriteListener RequestRewriteListener
}

func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) {
rr := &RoundRobin{
next: next,
index: -1,
mutex: &sync.Mutex{},
servers: []*server{},
next: next,
index: -1,
mutex: &sync.Mutex{},
servers: []*server{},
stickySession: nil,
}
for _, o := range opts {
if err := o(rr); err != nil {
Expand All @@ -78,21 +87,40 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) {
defer logEntry.Debug("vulcand/oxy/roundrobin/rr: competed ServeHttp on request")
}

url, err := r.NextServer()
if err != nil {
r.errHandler.ServeHTTP(w, req, err)
return
// make shallow copy of request before chaning anything to avoid side effects
newReq := *req
stuck := false
if r.stickySession != nil {
cookieURL, present, err := r.stickySession.GetBackend(&newReq, r.Servers())

if err != nil {
log.Infof("vulcand/oxy/roundrobin/rr: error using server from cookie: %v", err)
}

if present {
newReq.URL = cookieURL
stuck = true
}
}

if !stuck {
url, err := r.NextServer()
if err != nil {
r.errHandler.ServeHTTP(w, req, err)
return
}

if r.stickySession != nil {
r.stickySession.StickBackend(url, &w)
}
newReq.URL = url
}

if log.GetLevel() >= log.DebugLevel {
//log which backend URL we're sending this request to
log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": url}).Debugf("vulcand/oxy/roundrobin/rr: Forwarding this request to URL")
log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": newReq.URL}).Debugf("vulcand/oxy/roundrobin/rr: Forwarding this request to URL")
}

// make shallow copy of request before chaning anything to avoid side effects
newReq := *req
newReq.URL = url

//Emit event to a listener if one exists
if r.requestRewriteListener != nil {
r.requestRewriteListener(req, &newReq)
Expand Down
56 changes: 56 additions & 0 deletions roundrobin/stickysessions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// package stickysession is a mixin for load balancers that implements layer 7 (http cookie) session affinity
package roundrobin

import (
"net/http"
"net/url"
)

type StickySession struct {
cookieName string
}

func NewStickySession(cookieName string) *StickySession {
return &StickySession{cookieName}
}

// GetBackend returns the backend URL stored in the sticky cookie, iff the backend is still in the valid list of servers.
func (s *StickySession) GetBackend(req *http.Request, servers []*url.URL) (*url.URL, bool, error) {
cookie, err := req.Cookie(s.cookieName)
switch err {
case nil:
case http.ErrNoCookie:
return nil, false, nil
default:
return nil, false, err
}

serverURL, err := url.Parse(cookie.Value)
if err != nil {
return nil, false, err
}

if s.isBackendAlive(serverURL, servers) {
return serverURL, true, nil
} else {
return nil, false, nil
}
}

func (s *StickySession) StickBackend(backend *url.URL, w *http.ResponseWriter) {
cookie := &http.Cookie{Name: s.cookieName, Value: backend.String(), Path: "/"}
http.SetCookie(*w, cookie)
}

func (s *StickySession) isBackendAlive(needle *url.URL, haystack []*url.URL) bool {
if len(haystack) == 0 {
return false
}

for _, serverURL := range haystack {
if sameURL(needle, serverURL) {
return true
}
}
return false
}
Loading