Skip to content

Commit

Permalink
Introduce database abstraction layer (#132)
Browse files Browse the repository at this point in the history
* Introduce database abstraction layer
  • Loading branch information
kgalieva authored Feb 27, 2020
1 parent 3ce07a7 commit 76555c1
Show file tree
Hide file tree
Showing 32 changed files with 149 additions and 52 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ Kindly take note of following options:
* `TESTOPTS=...`: As detailed under https://docs.ruby-lang.org/en/2.1.0/Rake/TestTask.html

Example:
`bundle exec rake test DEBUG=1 TESTOPTS="-v --name=TrivialIntegrationTests#test_logged_query_omits_columns"`
`bundle exec rake test DEBUG=1 TESTOPTS="-v --name=TrivialIntegrationTests#test_logged_query_omits_columns"`
2 changes: 1 addition & 1 deletion batch_writer.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package ghostferry

import (
"database/sql"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"

"github.com/sirupsen/logrus"
)
Expand Down
5 changes: 3 additions & 2 deletions binlog_streamer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package ghostferry
import (
"context"
"crypto/tls"
"database/sql"
sqlorig "database/sql"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"time"

"github.com/siddontang/go-mysql/mysql"
Expand Down Expand Up @@ -350,7 +351,7 @@ func idsOnServer(db *sql.DB) ([]uint32, error) {
server_ids := make([]uint32, 0)
for rows.Next() {
var server_id uint32
var host, port, master_id, slave_uuid sql.NullString
var host, port, master_id, slave_uuid sqlorig.NullString

err = rows.Scan(&server_id, &host, &port, &master_id, &slave_uuid)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion binlog_writer.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package ghostferry

import (
"database/sql"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"

"github.com/sirupsen/logrus"
)
Expand Down
5 changes: 3 additions & 2 deletions compression_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package ghostferry

import (
"crypto/md5"
"database/sql"
sqlorig "database/sql"
"encoding/hex"
"errors"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"strconv"
"strings"

Expand Down Expand Up @@ -207,7 +208,7 @@ func NewCompressionVerifier(tableColumnCompressions TableColumnCompressionConfig
return compressionVerifier, nil
}

func getRows(db *sql.DB, schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []uint64) (*sql.Rows, error) {
func getRows(db *sql.DB, schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []uint64) (*sqlorig.Rows, error) {
quotedPaginationKey := quoteField(paginationKeyColumn)
sql, args, err := rowSelector(columns, paginationKeyColumn).
From(QuotedTableNameFromString(schema, table)).
Expand Down
2 changes: 1 addition & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package ghostferry
import (
"crypto/tls"
"crypto/x509"
"database/sql"
"errors"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"io/ioutil"
"time"

Expand Down
9 changes: 5 additions & 4 deletions cursor.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package ghostferry

import (
"database/sql"
sqlorig "database/sql"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"strings"

"github.com/Masterminds/squirrel"
Expand All @@ -12,7 +13,7 @@ import (

// both `sql.Tx` and `sql.DB` allow a SQL query to be `Prepare`d
type SqlPreparer interface {
Prepare(string) (*sql.Stmt, error)
Prepare(string) (*sqlorig.Stmt, error)
}

type SqlDBWithFakeRollback struct {
Expand Down Expand Up @@ -255,7 +256,7 @@ func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos uint64
return
}

func ScanGenericRow(rows *sql.Rows, columnCount int) (RowData, error) {
func ScanGenericRow(rows *sqlorig.Rows, columnCount int) (RowData, error) {
values := make(RowData, columnCount)
valuePtrs := make(RowData, columnCount)

Expand All @@ -267,7 +268,7 @@ func ScanGenericRow(rows *sql.Rows, columnCount int) (RowData, error) {
return values, err
}

func ScanByteRow(rows *sql.Rows, columnCount int) ([][]byte, error) {
func ScanByteRow(rows *sqlorig.Rows, columnCount int) ([][]byte, error) {
values := make([][]byte, columnCount)
valuePtrs := make(RowData, columnCount)

Expand Down
2 changes: 1 addition & 1 deletion data_iterator.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package ghostferry

import (
"database/sql"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"math"
"sync"

Expand Down
2 changes: 1 addition & 1 deletion ferry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package ghostferry

import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"math"
"net/http"
"os"
Expand Down
2 changes: 1 addition & 1 deletion inline_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package ghostferry
import (
"bytes"
"context"
"database/sql"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"strconv"
"strings"
"sync"
Expand Down
2 changes: 1 addition & 1 deletion iterative_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package ghostferry

import (
"bytes"
"database/sql"
"errors"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"math"
"strconv"
"strings"
Expand Down
2 changes: 1 addition & 1 deletion sharding/test/trivial_integration_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package test

import (
"database/sql"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"math/rand"
"testing"

Expand Down
2 changes: 1 addition & 1 deletion sharding/testhelpers/unit_test_suite.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package testhelpers

import (
"database/sql"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"net/http"
"net/http/httptest"

Expand Down
88 changes: 88 additions & 0 deletions sqlwrapper/ghostferry_db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package sqlwrapper

import (
"context"
sqlorig "database/sql"
)

type DB struct {
*sqlorig.DB
}

type Tx struct {
*sqlorig.Tx
}

func Open(driverName, dataSourceName string) (*DB, error) {
sqlDB, err := sqlorig.Open(driverName, dataSourceName)
return &DB{sqlDB}, err
}

func (db DB) PrepareContext(ctx context.Context, query string) (*sqlorig.Stmt, error) {
return db.DB.PrepareContext(ctx, query)
}

func (db DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sqlorig.Result, error) {
return db.DB.ExecContext(ctx, query, args...)
}

func (db DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sqlorig.Rows, error) {
return db.DB.QueryContext(ctx, query, args...)
}

func (db DB) Exec(query string, args ...interface{}) (sqlorig.Result, error) {
return db.DB.Exec(query, args...)
}

func (db DB) Prepare(query string) (*sqlorig.Stmt, error) {
return db.DB.Prepare(query)
}

func (db DB) Query(query string, args ...interface{}) (*sqlorig.Rows, error) {
return db.DB.Query(query, args...)
}

func (db DB) QueryRow(query string, args ...interface{}) *sqlorig.Row {
return db.DB.QueryRow(query, args...)
}

func (db DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sqlorig.Row {
return db.DB.QueryRowContext(ctx, query, args...)
}

func (db DB) Begin() (*Tx, error) {
tx, err := db.DB.Begin()
return &Tx{tx}, err
}

func (tx Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sqlorig.Result, error) {
return tx.Tx.ExecContext(ctx, query, args...)
}

func (tx Tx) Exec(query string, args ...interface{}) (sqlorig.Result, error) {
return tx.Tx.Exec(query, args...)
}

func (tx Tx) Prepare(query string) (*sqlorig.Stmt, error) {
return tx.Tx.Prepare(query)
}

func (tx Tx) PrepareContext(ctx context.Context, query string) (*sqlorig.Stmt, error) {
return tx.Tx.PrepareContext(ctx, query)
}

func (tx Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sqlorig.Rows, error) {
return tx.Tx.QueryContext(ctx, query, args)
}

func (tx Tx) Query(query string, args ...interface{}) (*sqlorig.Rows, error) {
return tx.Tx.Query(query, args...)
}

func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sqlorig.Row {
return tx.Tx.QueryRowContext(ctx, query, args...)
}

func (tx Tx) QueryRow(query string, args ...interface{}) *sqlorig.Row {
return tx.Tx.QueryRow(query, args...)
}
7 changes: 4 additions & 3 deletions table_schema_cache.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package ghostferry

import (
"database/sql"
sqlorig "database/sql"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"strings"

sq "github.com/Masterminds/squirrel"
Expand Down Expand Up @@ -172,7 +173,7 @@ func LoadTables(db *sql.DB, tableFilter TableFilter, columnCompressionConfig Col
for _, table := range tableNames {
tableLog := dbLog.WithField("table", table)
tableLog.Debug("fetching table schema")
tableSchema, err := schema.NewTableFromSqlDB(db, dbname, table)
tableSchema, err := schema.NewTableFromSqlDB(db.DB, dbname, table)
if err != nil {
tableLog.WithError(err).Error("cannot fetch table schema from source db")
return tableSchemaCache, err
Expand Down Expand Up @@ -357,7 +358,7 @@ func maxPaginationKey(db *sql.DB, table *TableSchema) (uint64, bool, error) {
err = db.QueryRow(query, args...).Scan(&maxPaginationKey)

switch {
case err == sql.ErrNoRows:
case err == sqlorig.ErrNoRows:
return 0, false, nil
case err != nil:
return 0, false, err
Expand Down
2 changes: 1 addition & 1 deletion test/go/binlog_streamer_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package test

import (
"database/sql"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"strings"
"sync"
"testing"
Expand Down
2 changes: 1 addition & 1 deletion test/go/iterative_verifier_collation_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package test

import (
"database/sql"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"testing"

"github.com/Shopify/ghostferry/testhelpers"
Expand Down
2 changes: 1 addition & 1 deletion test/go/iterative_verifier_integration_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package test

import (
"database/sql"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"testing"

"github.com/Shopify/ghostferry"
Expand Down
2 changes: 1 addition & 1 deletion test/go/iterative_verifier_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package test

import (
"database/sql"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"sort"
"testing"
"time"
Expand Down
2 changes: 1 addition & 1 deletion test/go/lag_throttler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package test

import (
"context"
"database/sql"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"sync"
"testing"
"time"
Expand Down
2 changes: 1 addition & 1 deletion test/go/race_conditions_integration_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package test

import (
"database/sql"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"testing"
"time"

Expand Down
2 changes: 1 addition & 1 deletion test/go/replication_config_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package test

import (
"database/sql"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"math"
"testing"

Expand Down
7 changes: 4 additions & 3 deletions test/go/trivial_integration_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package test

import (
"database/sql"
sqlorig "database/sql"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"math/rand"
"testing"

Expand Down Expand Up @@ -106,8 +107,8 @@ func TestCopyDataWhileRenamingDatabaseAndTable(t *testing.T) {
targetQuery := fmt.Sprintf("CHECKSUM TABLE `%s`.`%s` EXTENDED", targetDatabaseName, targetTableName)

var tablename string
var sourceChecksum sql.NullInt64
var targetChecksum sql.NullInt64
var sourceChecksum sqlorig.NullInt64
var targetChecksum sqlorig.NullInt64

sourceRow := f.SourceDB.QueryRow(sourceQuery)
err := sourceRow.Scan(&tablename, &sourceChecksum)
Expand Down
2 changes: 1 addition & 1 deletion test/go/types_integration_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package test

import (
"database/sql"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"math/rand"
"testing"

Expand Down
Loading

0 comments on commit 76555c1

Please sign in to comment.