Skip to content

Commit

Permalink
Merge pull request #54 from planetscale/filter-columns-before-sync
Browse files Browse the repository at this point in the history
Filter columns before sync
  • Loading branch information
notfelineit authored Aug 30, 2024
2 parents 485abac + e7d1cc7 commit b17d475
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 16 deletions.
40 changes: 35 additions & 5 deletions lib/connect_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net/http"
"strings"
"time"

"vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -107,7 +108,15 @@ func (p connectClient) Read(ctx context.Context, logger DatabaseLogger, ps Plane

currentPosition := lastKnownPosition
readDuration := 1 * time.Minute
preamble := fmt.Sprintf("[%v:%v shard : %v] ", ps.Database, tableName, currentPosition.Shard)
preamble := fmt.Sprintf("[%v:%v shard:%v tabletType:%s] ", ps.Database, tableName, currentPosition.Shard, tabletType)

existingColumns, err := p.filterExistingColumns(ctx, ps, tableName, columns)
if err != nil {
logger.Info(fmt.Sprintf("%sCouldn't fetch existing columns, falling back to requested columns: %s", preamble, err.Error()))
}

logger.Info(fmt.Sprintf("%sFiltering with columns %s", preamble, strings.Join(existingColumns, ",")))

for {
logger.Info(preamble + "peeking to see if there's any new rows")
latestCursorPosition, lcErr := p.getLatestCursorPosition(ctx, currentPosition.Shard, currentPosition.Keyspace, tableName, ps, tabletType)
Expand All @@ -123,7 +132,7 @@ func (p connectClient) Read(ctx context.Context, logger DatabaseLogger, ps Plane
logger.Info(fmt.Sprintf(preamble+"new rows found, syncing rows for %v", readDuration))
logger.Info(fmt.Sprintf(preamble+"syncing rows with cursor [%v]", currentPosition))

currentPosition, err = p.sync(ctx, logger, tableName, columns, currentPosition, latestCursorPosition, ps, tabletType, readDuration, onResult, onCursor, onUpdate)
currentPosition, err = p.sync(ctx, logger, tableName, existingColumns, currentPosition, latestCursorPosition, ps, tabletType, readDuration, onResult, onCursor, onUpdate)
if currentPosition.Position != "" {
currentSerializedCursor, sErr = TableCursorToSerializedCursor(currentPosition)
if sErr != nil {
Expand All @@ -135,7 +144,7 @@ func (p connectClient) Read(ctx context.Context, logger DatabaseLogger, ps Plane
if s, ok := status.FromError(err); ok {
// if the error is anything other than server timeout, keep going
if s.Code() != codes.DeadlineExceeded {
logger.Info(fmt.Sprintf("%v Got error [%v] with message [%q], Returning with cursor :[%v] after server timeout", preamble, s.Code(), err, currentPosition))
logger.Info(fmt.Sprintf("%vGot error [%v] with message [%q], Returning with cursor :[%v] after server timeout", preamble, s.Code(), err, currentPosition))
return currentSerializedCursor, nil
} else {
logger.Info(preamble + "Continuing with cursor after server timeout")
Expand All @@ -160,7 +169,7 @@ func (p connectClient) sync(ctx context.Context, logger DatabaseLogger, tableNam
client psdbconnect.ConnectClient
)

preamble := fmt.Sprintf("[%v:%v shard : %v] ", ps.Database, tableName, tc.Shard)
preamble := fmt.Sprintf("[%v:%v shard:%v tabletType:%s] ", ps.Database, tableName, tc.Shard, tabletType)

if p.clientFn == nil {
conn, err := grpcclient.Dial(ctx, ps.Host,
Expand All @@ -187,7 +196,7 @@ func (p connectClient) sync(ctx context.Context, logger DatabaseLogger, tableNam
tc.Position = ""
}

logger.Info(fmt.Sprintf("%s Syncing with cursor position : [%v], using last known PK : %v, stop cursor is : [%v]", preamble, tc.Position, tc.LastKnownPk != nil, stopPosition))
logger.Info(fmt.Sprintf("%sSyncing with cursor position : [%v], using last known PK : %v, stop cursor is : [%v]", preamble, tc.Position, tc.LastKnownPk != nil, stopPosition))

sReq := &psdbconnect.SyncRequest{
TableName: tableName,
Expand Down Expand Up @@ -274,6 +283,27 @@ func (p connectClient) sync(ctx context.Context, logger DatabaseLogger, tableNam
}
}

// filterExistingColumns narrows the requested column list down to the columns
// that GetKeyspaceTableColumns reports for tableName in the ps.Database
// keyspace. The order of the request is preserved and matching is by exact
// column name.
//
// When the table's columns cannot be looked up — the MySQL client is not
// configured, or GetKeyspaceTableColumns fails — the requested columns are
// returned unchanged TOGETHER WITH the error, so the caller can report the
// failure and still proceed with the unfiltered list.
//
// On success the returned slice is always non-nil, even when no requested
// column matched.
func (p connectClient) filterExistingColumns(ctx context.Context, ps PlanetScaleSource, tableName string, columns []string) ([]string, error) {
	// Guard against an unset client: dereferencing a nil p.Mysql would panic
	// instead of taking the fallback path described above.
	if p.Mysql == nil || *p.Mysql == nil {
		return columns, fmt.Errorf("mysql client is not configured")
	}

	results, err := (*p.Mysql).GetKeyspaceTableColumns(ctx, ps.Database, tableName)
	if err != nil {
		return columns, err
	}

	// Set of the column names the table actually has.
	columnSet := make(map[string]struct{}, len(results))
	for _, result := range results {
		columnSet[result.Name] = struct{}{}
	}

	existingColumns := make([]string, 0, len(columns))
	for _, c := range columns {
		if _, ok := columnSet[c]; ok {
			existingColumns = append(existingColumns, c)
		}
	}

	return existingColumns, nil
}

func serializeQueryResult(result *query.QueryResult) *sqltypes.Result {
qr := sqltypes.Proto3ToResult(result)
var sqlResult *sqltypes.Result
Expand Down
137 changes: 136 additions & 1 deletion lib/connect_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"vitess.io/vitess/go/vt/proto/query"

"github.com/pkg/errors"
psdbconnect "github.com/planetscale/airbyte-source/proto/psdbconnect/v1alpha1"

"github.com/stretchr/testify/assert"
Expand All @@ -18,6 +19,11 @@ import (
func TestRead_CanPeekBeforeRead(t *testing.T) {
dbl := &dbLogger{}
ped := connectClient{}
getKeyspaceTableColumnsFunc := func(ctx context.Context, keyspaceName string, tableName string) ([]MysqlColumn, error) {
return []MysqlColumn{{Name: "id", Type: "bigint", IsPrimaryKey: true}, {Name: "email", Type: "varchar(256)", IsPrimaryKey: false}}, nil
}
mysqlClient := NewTestMysqlClient(getKeyspaceTableColumnsFunc)
ped.Mysql = &mysqlClient
tc := &psdbconnect.TableCursor{
Shard: "-",
Position: "THIS_IS_A_SHARD_GTID",
Expand Down Expand Up @@ -62,6 +68,11 @@ func TestRead_CanPeekBeforeRead(t *testing.T) {
func TestRead_CanEarlyExitIfNoNewVGtidInPeek(t *testing.T) {
dbl := &dbLogger{}
ped := connectClient{}
getKeyspaceTableColumnsFunc := func(ctx context.Context, keyspaceName string, tableName string) ([]MysqlColumn, error) {
return []MysqlColumn{{Name: "id", Type: "bigint", IsPrimaryKey: true}, {Name: "email", Type: "varchar(256)", IsPrimaryKey: false}}, nil
}
mysqlClient := NewTestMysqlClient(getKeyspaceTableColumnsFunc)
ped.Mysql = &mysqlClient
tc := &psdbconnect.TableCursor{
Shard: "-",
Position: "THIS_IS_A_SHARD_GTID",
Expand Down Expand Up @@ -102,6 +113,11 @@ func TestRead_CanEarlyExitIfNoNewVGtidInPeek(t *testing.T) {
func TestRead_CanPickPrimaryForShardedKeyspaces(t *testing.T) {
dbl := &dbLogger{}
ped := connectClient{}
getKeyspaceTableColumnsFunc := func(ctx context.Context, keyspaceName string, tableName string) ([]MysqlColumn, error) {
return []MysqlColumn{{Name: "id", Type: "bigint", IsPrimaryKey: true}, {Name: "email", Type: "varchar(256)", IsPrimaryKey: false}}, nil
}
mysqlClient := NewTestMysqlClient(getKeyspaceTableColumnsFunc)
ped.Mysql = &mysqlClient
tc := &psdbconnect.TableCursor{
Shard: "40-80",
Position: "THIS_IS_A_SHARD_GTID",
Expand Down Expand Up @@ -144,6 +160,11 @@ func TestRead_CanPickPrimaryForShardedKeyspaces(t *testing.T) {
func TestRead_CanPickReplicaForShardedKeyspaces(t *testing.T) {
dbl := &dbLogger{}
ped := connectClient{}
getKeyspaceTableColumnsFunc := func(ctx context.Context, keyspaceName string, tableName string) ([]MysqlColumn, error) {
return []MysqlColumn{{Name: "id", Type: "bigint", IsPrimaryKey: true}, {Name: "email", Type: "varchar(256)", IsPrimaryKey: false}}, nil
}
mysqlClient := NewTestMysqlClient(getKeyspaceTableColumnsFunc)
ped.Mysql = &mysqlClient
tc := &psdbconnect.TableCursor{
Shard: "40-80",
Position: "THIS_IS_A_SHARD_GTID",
Expand Down Expand Up @@ -223,6 +244,13 @@ func TestRead_CanReturnNewCursorIfNewFound(t *testing.T) {
onCursor := func(*psdbconnect.TableCursor) error {
return nil
}

getKeyspaceTableColumnsFunc := func(ctx context.Context, keyspaceName string, tableName string) ([]MysqlColumn, error) {
return []MysqlColumn{{Name: "id", Type: "bigint", IsPrimaryKey: true}, {Name: "email", Type: "varchar(256)", IsPrimaryKey: false}}, nil
}
mysqlClient := NewTestMysqlClient(getKeyspaceTableColumnsFunc)
ped.Mysql = &mysqlClient

sc, err := ped.Read(context.Background(), dbl, ps, "customers", nil, tc, onRow, onCursor, nil)
assert.NoError(t, err)
esc, err := TableCursorToSerializedCursor(newTC)
Expand Down Expand Up @@ -327,6 +355,13 @@ func TestRead_CanStopAtWellKnownCursor(t *testing.T) {
onCursor := func(*psdbconnect.TableCursor) error {
return nil
}

getKeyspaceTableColumnsFunc := func(ctx context.Context, keyspaceName string, tableName string) ([]MysqlColumn, error) {
return []MysqlColumn{{Name: "id", Type: "bigint", IsPrimaryKey: true}, {Name: "email", Type: "varchar(256)", IsPrimaryKey: false}}, nil
}
mysqlClient := NewTestMysqlClient(getKeyspaceTableColumnsFunc)
ped.Mysql = &mysqlClient

sc, err := ped.Read(context.Background(), dbl, ps, "customers", nil, responses[0].Cursor, onRow, onCursor, nil)

assert.NoError(t, err)
Expand All @@ -336,7 +371,107 @@ func TestRead_CanStopAtWellKnownCursor(t *testing.T) {
assert.Equal(t, esc, sc)
assert.Equal(t, 2, cc.syncFnInvokedCount)

assert.Equal(t, "[connect-test:customers shard : -] Finished reading all rows for table [customers]", dbl.messages[len(dbl.messages)-1].message)
assert.Equal(t, "[connect-test:customers shard:- tabletType:primary] Finished reading all rows for table [customers]", dbl.messages[len(dbl.messages)-1].message)
assert.Equal(t, 2*(nextVGtidPosition/3), insertedRowCounter)
assert.Equal(t, 2*(nextVGtidPosition/3), deletedRowCounter)
}

// TestRead_FiltersNonExistentColumns verifies which column list Read forwards
// in its Sync requests. Two scenarios are covered: the table lookup succeeds
// and a requested column that the table lacks is dropped, and the table lookup
// fails and the requested columns are forwarded untouched.
//
// The first Sync request observed by the mock is expected to carry no columns
// (presumably the "peek" issued before the real sync — confirm against
// getLatestCursorPosition); every later request must carry the scenario's
// expected columns.
func TestRead_FiltersNonExistentColumns(t *testing.T) {
	type scenario struct {
		name             string
		tableColumns     []MysqlColumn
		requestedColumns []string
		expectedColumns  []string
		err              error
	}

	scenarios := []scenario{
		{
			name: "filters nonexistent columns",
			tableColumns: []MysqlColumn{
				{Name: "id", Type: "bigint", IsPrimaryKey: true},
				{Name: "email", Type: "varchar(256)", IsPrimaryKey: false},
				{Name: "name", Type: "varchar(256)", IsPrimaryKey: false},
			},
			requestedColumns: []string{"id", "email", "nonexistent_column"},
			expectedColumns:  []string{"id", "email"},
		},
		{
			name:             "uses requested columns on error",
			tableColumns:     nil,
			requestedColumns: []string{"id", "email", "nonexistent_column"},
			expectedColumns:  []string{"id", "email", "nonexistent_column"},
			err:              errors.New("error fetching columns"),
		},
	}

	ctx := context.Background()

	for _, want := range scenarios {
		t.Run(want.name, func(t *testing.T) {
			logger := &dbLogger{}

			// Cursor the read starts from, and the further-along cursor the
			// mocked stream reports.
			startCursor := &psdbconnect.TableCursor{
				Shard:    "-",
				Position: "THIS_IS_A_SHARD_GTID",
				Keyspace: "connect-test",
			}
			advancedCursor := &psdbconnect.TableCursor{
				Shard:    "-",
				Position: "I_AM_FARTHER_IN_THE_BINLOG",
				Keyspace: "connect-test",
			}

			stream := &connectSyncClientMock{
				syncResponses: []*psdbconnect.SyncResponse{
					{Cursor: advancedCursor},
					{Cursor: advancedCursor},
				},
			}

			// Typed nil slice: the very first Sync request must carry no columns.
			var columnsOnFirstCall []string
			syncCalls := 0
			syncFn := func(ctx context.Context, in *psdbconnect.SyncRequest, opts ...grpc.CallOption) (psdbconnect.Connect_SyncClient, error) {
				syncCalls++
				if syncCalls > 1 {
					assert.Equal(t, want.expectedColumns, in.Columns)
				} else {
					assert.Equal(t, columnsOnFirstCall, in.Columns)
				}
				return stream, nil
			}
			cc := clientConnectionMock{
				syncFn: syncFn,
			}

			// MySQL double that answers the column lookup with the scenario's
			// canned columns and error.
			mysqlClient := NewTestMysqlClient(func(ctx context.Context, keyspaceName string, tableName string) ([]MysqlColumn, error) {
				return want.tableColumns, want.err
			})

			ped := connectClient{}
			ped.Mysql = &mysqlClient
			ped.clientFn = func(ctx context.Context, ps PlanetScaleSource) (psdbconnect.ConnectClient, error) {
				return &cc, nil
			}

			noopRow := func(*sqltypes.Result, Operation) error { return nil }
			noopCursor := func(*psdbconnect.TableCursor) error { return nil }

			got, err := ped.Read(ctx, logger, PlanetScaleSource{}, "customers", want.requestedColumns, startCursor, noopRow, noopCursor, nil)
			assert.NoError(t, err)

			expected, err := TableCursorToSerializedCursor(advancedCursor)
			assert.NoError(t, err)
			assert.Equal(t, expected, got)
			assert.Equal(t, 2, cc.syncFnInvokedCount)
		})
	}
}
5 changes: 3 additions & 2 deletions lib/mysql_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type MysqlClient interface {
BuildSchema(ctx context.Context, psc PlanetScaleSource, schemaBuilder SchemaBuilder) error
PingContext(context.Context, PlanetScaleSource) error
GetVitessShards(ctx context.Context, psc PlanetScaleSource) ([]string, error)
GetKeyspaceTableColumns(ctx context.Context, keyspaceName string, tableName string) ([]MysqlColumn, error)
Close() error
}

Expand Down Expand Up @@ -54,7 +55,7 @@ func (p mysqlClient) BuildSchema(ctx context.Context, psc PlanetScaleSource, sch
for _, tableName := range tableNames {
schemaBuilder.OnTable(keyspaceName, tableName)

columns, err := p.getKeyspaceTableColumns(ctx, keyspaceName, tableName)
columns, err := p.GetKeyspaceTableColumns(ctx, keyspaceName, tableName)
if err != nil {
return errors.Wrap(err, "Unable to build schema for database")
}
Expand All @@ -70,7 +71,7 @@ func (p mysqlClient) Close() error {
return p.db.Close()
}

func (p mysqlClient) getKeyspaceTableColumns(ctx context.Context, keyspaceName string, tableName string) ([]MysqlColumn, error) {
func (p mysqlClient) GetKeyspaceTableColumns(ctx context.Context, keyspaceName string, tableName string) ([]MysqlColumn, error) {
var columns []MysqlColumn
columnNamesQR, err := p.db.QueryContext(
ctx,
Expand Down
32 changes: 24 additions & 8 deletions lib/test_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,26 @@ func (c *clientConnectionMock) Sync(ctx context.Context, in *psdbconnect.SyncReq
}

type (
BuildSchemaFunc func(ctx context.Context, psc PlanetScaleSource, schemaBuilder SchemaBuilder) error
PingContextFunc func(context.Context, PlanetScaleSource) error
GetVitessShardsFunc func(ctx context.Context, psc PlanetScaleSource) ([]string, error)
TestMysqlClient struct {
BuildSchemaFn BuildSchemaFunc
PingContextFn PingContextFunc
GetVitessShardsFn GetVitessShardsFunc
BuildSchemaFunc func(ctx context.Context, psc PlanetScaleSource, schemaBuilder SchemaBuilder) error
PingContextFunc func(context.Context, PlanetScaleSource) error
GetVitessShardsFunc func(ctx context.Context, psc PlanetScaleSource) ([]string, error)
GetKeyspaceTableColumnsFunc func(ctx context.Context, keyspaceName string, tableName string) ([]MysqlColumn, error)
TestMysqlClient struct {
BuildSchemaFn BuildSchemaFunc
PingContextFn PingContextFunc
GetVitessShardsFn GetVitessShardsFunc
GetKeyspaceTableColumnsFn GetKeyspaceTableColumnsFunc
}
)

// GetKeyspaceTableColumns forwards the call to the configured
// GetKeyspaceTableColumnsFn hook and returns its result. It panics when no
// hook has been set on the test client.
func (t TestMysqlClient) GetKeyspaceTableColumns(ctx context.Context, keyspaceName string, tableName string) ([]MysqlColumn, error) {
	hook := t.GetKeyspaceTableColumnsFn
	if hook == nil {
		panic("GetKeyspaceTableColumnsFunc is not implemented")
	}

	return hook(ctx, keyspaceName, tableName)
}

func (t TestMysqlClient) BuildSchema(ctx context.Context, psc PlanetScaleSource, schemaBuilder SchemaBuilder) error {
if t.BuildSchemaFn != nil {
return t.BuildSchemaFn(ctx, psc, schemaBuilder)
Expand All @@ -80,13 +90,19 @@ func (t TestMysqlClient) GetVitessShards(ctx context.Context, psc PlanetScaleSou
if t.GetVitessShardsFn != nil {
return t.GetVitessShardsFn(ctx, psc)
}
panic("GetvitessShards is not implemented")
panic("GetVitessShards is not implemented")
}

// Close is the test client's implementation of the MysqlClient Close method.
// It does nothing and always returns nil.
func (t TestMysqlClient) Close() error {
return nil
}

// NewTestMysqlClient returns a MysqlClient test double whose
// GetKeyspaceTableColumns method is served by gktc. All other function hooks
// on the underlying TestMysqlClient are left unset.
func NewTestMysqlClient(gktc GetKeyspaceTableColumnsFunc) MysqlClient {
	var stub TestMysqlClient
	stub.GetKeyspaceTableColumnsFn = gktc
	return &stub
}

type (
ReadFunc func(ctx context.Context, logger DatabaseLogger, ps PlanetScaleSource, tableName string, columns []string, tc *psdbconnect.TableCursor, onResult OnResult, onCursor OnCursor, onUpdate OnUpdate) (*SerializedCursor, error)
CanConnectFunc func(ctx context.Context, ps PlanetScaleSource) error
Expand Down

0 comments on commit b17d475

Please sign in to comment.