Skip to content

Commit

Permalink
Proxy line support for mysql (#6594)
Browse files Browse the repository at this point in the history
  • Loading branch information
r0mant authored Apr 30, 2021
1 parent 5dca072 commit d0cfc8a
Show file tree
Hide file tree
Showing 12 changed files with 341 additions and 30 deletions.
9 changes: 6 additions & 3 deletions lib/multiplexer/multiplexer.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2017 Gravitational, Inc.
Copyright 2017-2021 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -319,9 +319,12 @@ func detect(conn net.Conn, enableProxyProtocol bool) (*Conn, error) {
return nil, trace.BadParameter("unknown protocol")
}

// Protocol defines detected protocol type.
type Protocol int

const (
// ProtoUnknown is for unknown protocol
ProtoUnknown = iota
ProtoUnknown Protocol = iota
// ProtoTLS is TLS protocol
ProtoTLS
// ProtoSSH is SSH protocol
Expand Down Expand Up @@ -379,7 +382,7 @@ func isHTTP(in []byte) bool {
return false
}

func detectProto(in []byte) (int, error) {
func detectProto(in []byte) (Protocol, error) {
switch {
// reader peeks only 3 bytes, slice the longer proxy prefix
case bytes.HasPrefix(in, proxyPrefix[:3]):
Expand Down
140 changes: 140 additions & 0 deletions lib/multiplexer/testproxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
Copyright 2021 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package multiplexer

import (
"io"
"net"

"github.com/gravitational/teleport/lib/utils"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
)

// TestProxy is tcp passthrough proxy that sends a proxy-line when connecting
// to the target server.
type TestProxy struct {
listener net.Listener
target string
closeCh chan (struct{})
log logrus.FieldLogger
}

// NewTestProxy creates a new test proxy that sends a proxy-line when
// proxying connections to the provided target address.
func NewTestProxy(target string) (*TestProxy, error) {
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, trace.Wrap(err)
}
return &TestProxy{
listener: listener,
target: target,
closeCh: make(chan struct{}),
log: logrus.WithField(trace.Component, "test:proxy"),
}, nil
}

// Address returns the proxy listen address.
func (p *TestProxy) Address() string {
return p.listener.Addr().String()
}

// Serve starts accepting client connections and proxying them to the target.
func (p *TestProxy) Serve() error {
for {
clientConn, err := p.listener.Accept()
if err != nil {
return trace.Wrap(err)
}
go func() {
if err := p.handleConnection(clientConn); err != nil {
p.log.WithError(err).Error("Failed to handle connection.")
}
}()
}
}

// handleConnection dials the target address, sends a proxy line to it and
// then starts proxying all traffic b/w client and target.
func (p *TestProxy) handleConnection(clientConn net.Conn) error {
serverConn, err := net.Dial("tcp", p.target)
if err != nil {
clientConn.Close()
return trace.Wrap(err)
}
defer serverConn.Close()
errCh := make(chan error, 2)
go func() { // Client -> server.
defer clientConn.Close()
defer serverConn.Close()
// Write proxy-line first and then start proxying from client.
err := p.sendProxyLine(clientConn, serverConn)
if err == nil {
_, err = io.Copy(serverConn, clientConn)
}
errCh <- trace.Wrap(err)
}()
go func() { // Server -> client.
defer clientConn.Close()
defer serverConn.Close()
_, err := io.Copy(clientConn, serverConn)
errCh <- trace.Wrap(err)
}()
var errs []error
for i := 0; i < 2; i++ {
select {
case err := <-errCh:
if err != nil && !utils.IsOKNetworkError(err) {
errs = append(errs, err)
}
case <-p.closeCh:
p.log.Debug("Closing.")
return trace.NewAggregate(errs...)
}
}
return trace.NewAggregate(errs...)
}

// sendProxyLine sends proxy-line to the server.
func (p *TestProxy) sendProxyLine(clientConn, serverConn net.Conn) error {
clientAddr, err := utils.ParseAddr(clientConn.RemoteAddr().String())
if err != nil {
return trace.Wrap(err)
}
serverAddr, err := utils.ParseAddr(serverConn.RemoteAddr().String())
if err != nil {
return trace.Wrap(err)
}
proxyLine := &ProxyLine{
Protocol: TCP4,
Source: net.TCPAddr{IP: net.ParseIP(clientAddr.Host()), Port: clientAddr.Port(0)},
Destination: net.TCPAddr{IP: net.ParseIP(serverAddr.Host()), Port: serverAddr.Port(0)},
}
_, err = serverConn.Write([]byte(proxyLine.String()))
if err != nil {
return trace.Wrap(err)
}
return nil
}

// Close closes the proxy listener.
func (p *TestProxy) Close() error {
close(p.closeCh)
return p.listener.Close()
}
37 changes: 34 additions & 3 deletions lib/multiplexer/wrappers.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2017 Gravitational, Inc.
Copyright 2017-2021 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,11 +30,19 @@ import (
// protocol detection
type Conn struct {
net.Conn
protocol int
protocol Protocol
proxyLine *ProxyLine
reader *bufio.Reader
}

// NewConn returns a net.Conn wrapper that supports peeking into the connection.
func NewConn(conn net.Conn) *Conn {
return &Conn{
Conn: conn,
reader: bufio.NewReader(conn),
}
}

// Read reads from connection
func (c *Conn) Read(p []byte) (int, error) {
return c.reader.Read(p)
Expand All @@ -57,10 +65,33 @@ func (c *Conn) RemoteAddr() net.Addr {
}

// Protocol returns the detected connection protocol
func (c *Conn) Protocol() int {
func (c *Conn) Protocol() Protocol {
return c.protocol
}

// Detect detects the connection protocol by peeking into the first few bytes.
func (c *Conn) Detect() (Protocol, error) {
bytes, err := c.reader.Peek(8)
if err != nil {
return ProtoUnknown, trace.Wrap(err)
}
proto, err := detectProto(bytes)
if err != nil && !trace.IsBadParameter(err) {
return ProtoUnknown, trace.Wrap(err)
}
return proto, nil
}

// ReadProxyLine reads proxy-line from the connection.
func (c *Conn) ReadProxyLine() (*ProxyLine, error) {
proxyLine, err := ReadProxyLine(c.reader)
if err != nil {
return nil, trace.Wrap(err)
}
c.proxyLine = proxyLine
return proxyLine, nil
}

func newListener(parent context.Context, addr net.Addr) *Listener {
context, cancel := context.WithCancel(parent)
return &Listener{
Expand Down
24 changes: 18 additions & 6 deletions lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ func TestMain(m *testing.M) {
func TestAccessPostgres(t *testing.T) {
ctx := context.Background()
testCtx := setupTestContext(ctx, t, withSelfHostedPostgres("postgres"))
t.Cleanup(func() { testCtx.Close() })
go testCtx.startHandlingConnections()

tests := []struct {
Expand Down Expand Up @@ -162,7 +161,6 @@ func TestAccessPostgres(t *testing.T) {
func TestAccessMySQL(t *testing.T) {
ctx := context.Background()
testCtx := setupTestContext(ctx, t, withSelfHostedMySQL("mysql"))
t.Cleanup(func() { testCtx.Close() })
go testCtx.startHandlingConnections()

tests := []struct {
Expand Down Expand Up @@ -256,7 +254,6 @@ func TestAccessDisabled(t *testing.T) {

ctx := context.Background()
testCtx := setupTestContext(ctx, t, withSelfHostedPostgres("postgres"))
t.Cleanup(func() { testCtx.Close() })
go testCtx.startHandlingConnections()

userName := "alice"
Expand Down Expand Up @@ -327,10 +324,15 @@ func (c *testContext) startHandlingConnections() {
// postgresClient connects to test Postgres through database access as a
// specified Teleport user and database account.
func (c *testContext) postgresClient(ctx context.Context, teleportUser, dbService, dbUser, dbName string) (*pgconn.PgConn, error) {
return c.postgresClientWithAddr(ctx, c.mux.DB().Addr().String(), teleportUser, dbService, dbUser, dbName)
}

// postgresClientWithAddr like postgresClient but allows to override connection address.
func (c *testContext) postgresClientWithAddr(ctx context.Context, address, teleportUser, dbService, dbUser, dbName string) (*pgconn.PgConn, error) {
return postgres.MakeTestClient(ctx, common.TestClientConfig{
AuthClient: c.authClient,
AuthServer: c.authServer,
Address: c.mux.DB().Addr().String(),
Address: address,
Cluster: c.clusterName,
Username: teleportUser,
RouteToDatabase: tlsca.RouteToDatabase{
Expand All @@ -345,10 +347,15 @@ func (c *testContext) postgresClient(ctx context.Context, teleportUser, dbServic
// mysqlClient connects to test MySQL through database access as a specified
// Teleport user and database account.
func (c *testContext) mysqlClient(teleportUser, dbService, dbUser string) (*client.Conn, error) {
return c.mysqlClientWithAddr(c.mysqlListener.Addr().String(), teleportUser, dbService, dbUser)
}

// mysqlClientWithAddr like mysqlClient but allows to override connection address.
func (c *testContext) mysqlClientWithAddr(address, teleportUser, dbService, dbUser string) (*client.Conn, error) {
return mysql.MakeTestClient(common.TestClientConfig{
AuthClient: c.authClient,
AuthServer: c.authServer,
Address: c.mysqlListener.Addr().String(),
Address: address,
Cluster: c.clusterName,
Username: teleportUser,
RouteToDatabase: tlsca.RouteToDatabase{
Expand Down Expand Up @@ -393,11 +400,16 @@ func setupTestContext(ctx context.Context, t *testing.T, withDatabases ...withDa
postgres: make(map[string]testPostgres),
mysql: make(map[string]testMySQL),
}
t.Cleanup(func() { testCtx.Close() })

// Create multiplexer.
listener, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
testCtx.mux, err = multiplexer.New(multiplexer.Config{ID: "test", Listener: listener})
testCtx.mux, err = multiplexer.New(multiplexer.Config{
ID: "test",
Listener: listener,
EnableProxyProtocol: true,
})
require.NoError(t, err)

// Create MySQL proxy listener.
Expand Down
2 changes: 0 additions & 2 deletions lib/srv/db/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ import (
func TestAuditPostgres(t *testing.T) {
ctx := context.Background()
testCtx := setupTestContext(ctx, t, withSelfHostedPostgres("postgres"))
t.Cleanup(func() { testCtx.Close() })
go testCtx.startHandlingConnections()

testCtx.createUserAndRole(ctx, t, "alice", "admin", []string{"postgres"}, []string{"postgres"})
Expand Down Expand Up @@ -71,7 +70,6 @@ func TestAuditPostgres(t *testing.T) {
func TestAuditMySQL(t *testing.T) {
ctx := context.Background()
testCtx := setupTestContext(ctx, t, withSelfHostedMySQL("mysql"))
t.Cleanup(func() { testCtx.Close() })
go testCtx.startHandlingConnections()

testCtx.createUserAndRole(ctx, t, "alice", "admin", []string{"root"}, []string{types.Wildcard})
Expand Down
1 change: 0 additions & 1 deletion lib/srv/db/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ func TestAuthTokens(t *testing.T) {
withCloudSQLPostgres("postgres-cloudsql-incorrect-token", "qwe123"),
withRDSMySQL("mysql-rds-correct-token", "root", rdsAuthToken),
withRDSMySQL("mysql-rds-incorrect-token", "root", "qwe123"))
t.Cleanup(func() { testCtx.Close() })
go testCtx.startHandlingConnections()

testCtx.createUserAndRole(ctx, t, "alice", "admin", []string{types.Wildcard}, []string{types.Wildcard})
Expand Down
Loading

0 comments on commit d0cfc8a

Please sign in to comment.