Skip to content

Commit

Permalink
refactor: minor tunneling cleanup (#4034)
Browse files Browse the repository at this point in the history
chore: tunneling cleanup
  • Loading branch information
achettyiitr authored Oct 30, 2023
1 parent 2cc9470 commit 877eb70
Show file tree
Hide file tree
Showing 10 changed files with 272 additions and 52 deletions.
1 change: 1 addition & 0 deletions .github/tools/matrixchecker/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ var IgnorePackages = []string{
"warehouse/integrations/testdata",
"warehouse/integrations/config",
"warehouse/integrations/types",
"warehouse/integrations/tunnelling",
}

func main() {
Expand Down
7 changes: 4 additions & 3 deletions warehouse/integrations/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"strings"
"time"

"github.com/rudderlabs/rudder-server/warehouse/integrations/tunnelling"

"github.com/rudderlabs/rudder-go-kit/stats"

sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper"
Expand All @@ -22,7 +24,6 @@ import (
"github.com/rudderlabs/rudder-go-kit/logger"
"github.com/rudderlabs/rudder-server/utils/misc"
"github.com/rudderlabs/rudder-server/warehouse/client"
"github.com/rudderlabs/rudder-server/warehouse/tunnelling"
warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils"
)

Expand Down Expand Up @@ -223,7 +224,7 @@ func (pg *Postgres) connect() (*sqlmiddleware.DB, error) {

if cred.tunnelInfo != nil {

db, err = tunnelling.SQLConnectThroughTunnel(dsn.String(), cred.tunnelInfo.Config)
db, err = tunnelling.Connect(dsn.String(), cred.tunnelInfo.Config)
if err != nil {
return nil, fmt.Errorf("opening connection to postgres through tunnelling: %w", err)
}
Expand All @@ -248,7 +249,7 @@ func (pg *Postgres) getConnectionCredentials() credentials {
sslMode: sslMode,
sslDir: warehouseutils.GetSSLKeyDirPath(pg.Warehouse.Destination.ID),
timeout: pg.connectTimeout,
tunnelInfo: warehouseutils.ExtractTunnelInfoFromDestinationConfig(
tunnelInfo: tunnelling.ExtractTunnelInfoFromDestinationConfig(
pg.Warehouse.Destination.Config,
),
}
Expand Down
10 changes: 5 additions & 5 deletions warehouse/integrations/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"testing"
"time"

"github.com/rudderlabs/rudder-server/warehouse/integrations/tunnelling"

"github.com/golang/mock/gomock"

"github.com/rudderlabs/rudder-go-kit/config"
Expand All @@ -24,14 +26,12 @@ import (

"github.com/rudderlabs/rudder-server/testhelper/workspaceConfig"

backendconfig "github.com/rudderlabs/rudder-server/backend-config"
"github.com/rudderlabs/rudder-server/warehouse/client"
"github.com/rudderlabs/rudder-server/warehouse/tunnelling"

"github.com/rudderlabs/compose-test/testcompose"
kithelper "github.com/rudderlabs/rudder-go-kit/testhelper"
backendconfig "github.com/rudderlabs/rudder-server/backend-config"
"github.com/rudderlabs/rudder-server/runner"
"github.com/rudderlabs/rudder-server/testhelper/health"
"github.com/rudderlabs/rudder-server/warehouse/client"

"github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper"

Expand Down Expand Up @@ -325,7 +325,7 @@ func TestIntegration(t *testing.T) {
},
}

db, err := tunnelling.SQLConnectThroughTunnel(dsn, tunnelInfo.Config)
db, err := tunnelling.Connect(dsn, tunnelInfo.Config)
require.NoError(t, err)
require.NoError(t, db.Ping())

Expand Down
7 changes: 4 additions & 3 deletions warehouse/integrations/redshift/redshift.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"strings"
"time"

"github.com/rudderlabs/rudder-server/warehouse/integrations/tunnelling"

"github.com/samber/lo"

"github.com/rudderlabs/rudder-server/warehouse/integrations/types"
Expand All @@ -31,7 +33,6 @@ import (
sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper"
"github.com/rudderlabs/rudder-server/warehouse/internal/model"
"github.com/rudderlabs/rudder-server/warehouse/logfield"
"github.com/rudderlabs/rudder-server/warehouse/tunnelling"
warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils"
)

Expand Down Expand Up @@ -1002,7 +1003,7 @@ func (rs *Redshift) connect(ctx context.Context) (*sqlmiddleware.DB, error) {
)

if cred.TunnelInfo != nil {
if db, err = tunnelling.SQLConnectThroughTunnel(dsn.String(), cred.TunnelInfo.Config); err != nil {
if db, err = tunnelling.Connect(dsn.String(), cred.TunnelInfo.Config); err != nil {
return nil, fmt.Errorf("connecting to redshift through tunnel: %w", err)
}
} else {
Expand Down Expand Up @@ -1235,7 +1236,7 @@ func (rs *Redshift) getConnectionCredentials() RedshiftCredentials {
Username: warehouseutils.GetConfigValue(RSUserName, rs.Warehouse),
Password: warehouseutils.GetConfigValue(RSPassword, rs.Warehouse),
timeout: rs.connectTimeout,
TunnelInfo: warehouseutils.ExtractTunnelInfoFromDestinationConfig(rs.Warehouse.Destination.Config),
TunnelInfo: tunnelling.ExtractTunnelInfoFromDestinationConfig(rs.Warehouse.Destination.Config),
}

return creds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"fmt"
"strconv"

whutils "github.com/rudderlabs/rudder-server/warehouse/utils"

stunnel "github.com/rudderlabs/sql-tunnels/driver/ssh"
)

Expand All @@ -22,30 +24,56 @@ const (
)

type (
Type string
Config map[string]interface{}
Config map[string]interface{}
TunnelInfo struct {
Config Config
}
)

type TunnelInfo struct {
Config Config
// ExtractTunnelInfoFromDestinationConfig extracts TunnelInfo from destination config if tunnel is enabled for the destination.
func ExtractTunnelInfoFromDestinationConfig(config Config) *TunnelInfo {
if tunnelEnabled := whutils.ReadAsBool("useSSH", config); !tunnelEnabled {
return nil
}

return &TunnelInfo{
Config: config,
}
}

func ReadSSHTunnelConfig(config Config) (conf *stunnel.Config, err error) {
// Connect establishes a database connection over an SSH tunnel.
func Connect(dsn string, config Config) (*sql.DB, error) {
tunnelConfig, err := extractTunnelConfig(config)
if err != nil {
return nil, fmt.Errorf("reading ssh tunnel config: %w", err)
}

encodedDSN, err := tunnelConfig.EncodeWithDSN(dsn)
if err != nil {
return nil, fmt.Errorf("encoding with dsn: %w", err)
}

db, err := sql.Open("sql+ssh", encodedDSN)
if err != nil {
return nil, fmt.Errorf("opening warehouse connection sql+ssh driver: %w", err)
}
return db, nil
}

func extractTunnelConfig(config Config) (*stunnel.Config, error) {
var user, host, port, privateKey *string
var err error

if user, err = ReadString(sshUser, config); err != nil {
if user, err = readString(sshUser, config); err != nil {
return nil, err
}

if host, err = ReadString(sshHost, config); err != nil {
if host, err = readString(sshHost, config); err != nil {
return nil, err
}

if port, err = ReadString(sshPort, config); err != nil {
if port, err = readString(sshPort, config); err != nil {
return nil, err
}

if privateKey, err = ReadString(sshPrivateKey, config); err != nil {
if privateKey, err = readString(sshPrivateKey, config); err != nil {
return nil, err
}

Expand All @@ -62,7 +90,7 @@ func ReadSSHTunnelConfig(config Config) (conf *stunnel.Config, err error) {
}, nil
}

func ReadString(key string, config Config) (*string, error) {
func readString(key string, config Config) (*string, error) {
val, ok := config[key]
if !ok {
return nil, fmt.Errorf("%w: %s", ErrMissingKey, key)
Expand All @@ -72,22 +100,5 @@ func ReadString(key string, config Config) (*string, error) {
if !ok {
return nil, fmt.Errorf("%w: %s expected string", ErrUnexpectedType, key)
}

return &resp, nil
}

func SQLConnectThroughTunnel(dsn string, tunnelConfig Config) (*sql.DB, error) {
conf, err := ReadSSHTunnelConfig(tunnelConfig)
if err != nil {
return nil, fmt.Errorf("reading ssh tunnel config: %w", err)
}
encodedDSN, err := conf.EncodeWithDSN(dsn)
if err != nil {
return nil, fmt.Errorf("encoding with dsn: %w", err)
}
db, err := sql.Open("sql+ssh", encodedDSN)
if err != nil {
return nil, fmt.Errorf("opening warehouse connection sql+ssh driver: %w", err)
}
return db, nil
}
140 changes: 140 additions & 0 deletions warehouse/integrations/tunnelling/connect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package tunnelling

import (
"context"
"errors"
"fmt"
"os"
"testing"

"github.com/rudderlabs/compose-test/compose"
"github.com/rudderlabs/compose-test/testcompose"

"github.com/stretchr/testify/require"
)

func TestConnect(t *testing.T) {
privateKey, err := os.ReadFile("testdata/test_key")
require.Nil(t, err)

ctx := context.Background()

c := testcompose.New(t, compose.FilePaths{"./testdata/docker-compose.yml"})
c.Start(context.Background())

host := "0.0.0.0"
user := c.Env("openssh-server", "USER_NAME")
port := c.Port("openssh-server", 2222)
postgresPort := c.Port("postgres", 5432)

testCases := []struct {
name string
dsn string
config Config
wantError error
}{
{
name: "empty config",
dsn: "dsn",
config: Config{},
wantError: ErrMissingKey,
},
{
name: "invalid config",
dsn: "dsn",
config: Config{
sshUser: "user",
sshHost: "host",
sshPort: 22,
sshPrivateKey: "privateKey",
},
wantError: errors.New("invalid type"),
},
{
name: "missing sshUser",
dsn: "dsn",
config: Config{
sshHost: "host",
sshPort: "port",
sshPrivateKey: "privateKey",
},
wantError: ErrMissingKey,
},
{
name: "missing sshHost",
dsn: "dsn",
config: Config{
sshUser: "user",
sshPort: "port",
sshPrivateKey: "privateKey",
},
wantError: ErrMissingKey,
},
{
name: "missing sshPort",
dsn: "dsn",
config: Config{
sshUser: "user",
sshHost: "host",
sshPrivateKey: "privateKey",
},
wantError: ErrMissingKey,
},
{
name: "missing sshPrivateKey",
dsn: "dsn",
config: Config{
sshUser: "user",
sshHost: "host",
sshPort: "port",
},
wantError: ErrMissingKey,
},
{
name: "invalid sshPort",
dsn: "dsn",
config: Config{
sshUser: "user",
sshHost: "host",
sshPort: "port",
sshPrivateKey: "privateKey",
},
wantError: errors.New("invalid port"),
},
{
name: "invalid dsn",
dsn: "postgres://user:password@host:5439/db?query1=val1&query2=val2",
config: Config{
sshUser: "user",
sshHost: "0.0.0.0",
sshPort: "22",
sshPrivateKey: "privateKey",
},
wantError: errors.New("invalid dsn"),
},
{
name: "valid dsn",
dsn: fmt.Sprintf("postgres://postgres:postgres@db_postgres:%d/postgres?sslmode=disable", postgresPort),
config: Config{
sshUser: user,
sshHost: host,
sshPort: port,
sshPrivateKey: privateKey,
},
wantError: errors.New("invalid dsn"),
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
db, err := Connect(tc.dsn, tc.config)
t.Log(err)
if tc.wantError != nil {
require.Error(t, err, tc.wantError)
return
}
require.NoError(t, err)
require.NoError(t, db.PingContext(ctx))
})
}
}
Loading

0 comments on commit 877eb70

Please sign in to comment.