From 3cc9581ef3032904c1f719e924ea285901c27a62 Mon Sep 17 00:00:00 2001 From: disksing Date: Tue, 6 Dec 2022 16:58:14 +0800 Subject: [PATCH] backend: add context to pass parameters need by handler (#144) --- pkg/proxy/backend/authenticator.go | 6 ++++-- pkg/proxy/backend/backend_conn_mgr.go | 6 +++--- pkg/proxy/backend/backend_conn_mgr_test.go | 2 +- pkg/proxy/backend/handshake_handler.go | 21 +++++++++++++++++---- pkg/proxy/backend/mock_proxy_test.go | 9 +++++---- 5 files changed, 30 insertions(+), 14 deletions(-) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index ef90d6f2..367c6ba0 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -15,6 +15,7 @@ package backend import ( + "context" "crypto/tls" "encoding/binary" "fmt" @@ -145,7 +146,8 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet auth.capability = commonCaps.Uint32() resp := pnet.ParseHandshakeResponse(pkt) - if err = handshakeHandler.HandleHandshakeResp(resp, clientIO.SourceAddr().String()); err != nil { + ctx := context.WithValue(context.Background(), ContextKeyClientAddr, clientIO.SourceAddr().String()) + if err = handshakeHandler.HandleHandshakeResp(ctx, resp); err != nil { return err } auth.user = resp.User @@ -153,7 +155,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet auth.collation = resp.Collation auth.attrs = resp.Attrs - backendIO, err := getBackend(auth, resp) + backendIO, err := getBackend(ctx, auth, resp) if err != nil { return err } diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index a273902f..8107256d 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -54,7 +54,7 @@ type redirectResult struct { to string } -type backendIOGetter func(auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) +type backendIOGetter func(ctx context.Context, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) // BackendConnManager migrates a session from one BackendConnection to another. // @@ -105,8 +105,8 @@ func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler signalReceived: make(chan struct{}, 1), redirectResCh: make(chan *redirectResult, 1), } - mgr.getBackendIO = func(auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) { - router, err := handshakeHandler.GetRouter(resp) + mgr.getBackendIO = func(ctx context.Context, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) { + router, err := handshakeHandler.GetRouter(ctx, resp) if err != nil { return nil, err } diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index f294b6e0..b3a29e0a 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -116,7 +116,7 @@ func newBackendMgrTester(t *testing.T, cfg ...cfgOverrider) *backendMgrTester { return tester } -func (ts *backendMgrTester) getBackendIO(auth *Authenticator, _ *pnet.HandshakeResp) (*pnet.PacketIO, error) { +func (ts *backendMgrTester) getBackendIO(ctx context.Context, auth *Authenticator, _ *pnet.HandshakeResp) (*pnet.PacketIO, error) { addr := ts.tc.backendListener.Addr().String() ts.mp.backendConn = NewBackendConnection(addr) if err := ts.mp.backendConn.Connect(); err != nil { diff --git a/pkg/proxy/backend/handshake_handler.go b/pkg/proxy/backend/handshake_handler.go index 57a8a99f..32be5dfe 100644 --- a/pkg/proxy/backend/handshake_handler.go +++ b/pkg/proxy/backend/handshake_handler.go @@ -15,18 +15,31 @@ package backend import ( + "context" + "github.com/pingcap/TiProxy/lib/util/errors" "github.com/pingcap/TiProxy/pkg/manager/namespace" "github.com/pingcap/TiProxy/pkg/manager/router" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" ) +type contextKey string + +func (k contextKey) String() string { + return "handler context key " + string(k) +} + +// Context keys. +var ( + ContextKeyClientAddr = contextKey("client_addr") +) + var _ HandshakeHandler = (*DefaultHandshakeHandler)(nil) type HandshakeHandler interface { - HandleHandshakeResp(resp *pnet.HandshakeResp, sourceAddr string) error + HandleHandshakeResp(ctx context.Context, resp *pnet.HandshakeResp) error GetCapability() pnet.Capability - GetRouter(resp *pnet.HandshakeResp) (router.Router, error) + GetRouter(ctx context.Context, resp *pnet.HandshakeResp) (router.Router, error) } type DefaultHandshakeHandler struct { @@ -39,7 +52,7 @@ func NewDefaultHandshakeHandler(nsManager *namespace.NamespaceManager) *DefaultH } } -func (handler *DefaultHandshakeHandler) HandleHandshakeResp(*pnet.HandshakeResp, string) error { +func (handler *DefaultHandshakeHandler) HandleHandshakeResp(context.Context, *pnet.HandshakeResp) error { return nil } @@ -47,7 +60,7 @@ func (handler *DefaultHandshakeHandler) GetCapability() pnet.Capability { return SupportedServerCapabilities } -func (handler *DefaultHandshakeHandler) GetRouter(resp *pnet.HandshakeResp) (router.Router, error) { +func (handler *DefaultHandshakeHandler) GetRouter(ctx context.Context, resp *pnet.HandshakeResp) (router.Router, error) { ns, ok := handler.nsManager.GetNamespaceByUser(resp.User) if !ok { ns, ok = handler.nsManager.GetNamespace("default") diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index bcd8339d..706a2369 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -15,6 +15,7 @@ package backend import ( + "context" "crypto/tls" "testing" @@ -64,7 +65,7 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy { } func (mp *mockProxy) authenticateFirstTime(clientIO, backendIO *pnet.PacketIO) error { - if err := mp.authenticator.handshakeFirstTime(mp.logger, clientIO, mp.handshakeHandler, func(*Authenticator, *pnet.HandshakeResp) (*pnet.PacketIO, error) { + if err := mp.authenticator.handshakeFirstTime(mp.logger, clientIO, mp.handshakeHandler, func(context.Context, *Authenticator, *pnet.HandshakeResp) (*pnet.PacketIO, error) { return backendIO, nil }, mp.frontendTLSConfig, mp.backendTLSConfig); err != nil { return err @@ -107,14 +108,14 @@ type CustomHandshakeHandler struct { outAttrs map[string]string } -func (handler *CustomHandshakeHandler) GetRouter(resp *pnet.HandshakeResp) (router.Router, error) { +func (handler *CustomHandshakeHandler) GetRouter(ctx context.Context, resp *pnet.HandshakeResp) (router.Router, error) { return nil, nil } -func (handler *CustomHandshakeHandler) HandleHandshakeResp(resp *pnet.HandshakeResp, addr string) error { +func (handler *CustomHandshakeHandler) HandleHandshakeResp(ctx context.Context, resp *pnet.HandshakeResp) error { handler.inUsername = resp.User resp.User = handler.outUsername - handler.inAddr = addr + handler.inAddr = ctx.Value(ContextKeyClientAddr).(string) resp.Attrs = handler.outAttrs return nil }