Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added heartbeat to cloud download #173

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Config struct {
PollInterval time.Duration
ClientTimeout time.Duration // max time the http request can last
PingTimeout time.Duration // max time allowed for ping
HeartbeatInterval time.Duration
CanUseMultipleCatalogs bool
DriverName string
DriverVersion string
Expand Down Expand Up @@ -70,6 +71,7 @@ func (c *Config) DeepCopy() *Config {
PollInterval: c.PollInterval,
ClientTimeout: c.ClientTimeout,
PingTimeout: c.PingTimeout,
HeartbeatInterval: c.HeartbeatInterval,
CanUseMultipleCatalogs: c.CanUseMultipleCatalogs,
DriverName: c.DriverName,
DriverVersion: c.DriverVersion,
Expand Down Expand Up @@ -189,6 +191,7 @@ func WithDefaults() *Config {
PollInterval: 1 * time.Second,
ClientTimeout: 900 * time.Second,
PingTimeout: 60 * time.Second,
HeartbeatInterval: 30 * time.Second,
CanUseMultipleCatalogs: true,
DriverName: "godatabrickssqlconnector", // important. Do not change
ThriftProtocol: "binary",
Expand Down
20 changes: 19 additions & 1 deletion internal/fetcher/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package fetcher

import (
"context"
"errors"
"sync"

"github.com/databricks/databricks-sql-go/driverctx"
Expand All @@ -17,6 +18,13 @@ type Fetcher[OutputType any] interface {
Start() (<-chan OutputType, context.CancelFunc, error)
}

// Overwatch is a component whose lifecycle is kept in sync with a
// fetcher: Start is invoked when the fetcher begins fetching, and
// Stop is invoked once all worker routines have finished and the
// fetcher's output channel has been closed.
type Overwatch interface {
Start()
Stop()
}

type concurrentFetcher[I FetchableItems[O], O any] struct {
cancelChan chan bool
inputChan <-chan FetchableItems[O]
Expand All @@ -28,6 +36,7 @@ type concurrentFetcher[I FetchableItems[O], O any] struct {
ctx context.Context
cancelFunc context.CancelFunc
*dbsqllog.DBSQLLogger
overWatch Overwatch
}

func (rf *concurrentFetcher[I, O]) Err() error {
Expand All @@ -40,6 +49,9 @@ func (f *concurrentFetcher[I, O]) Start() (<-chan O, context.CancelFunc, error)
f.start.Do(func() {
// wait group for the worker routines
var wg sync.WaitGroup
if f.overWatch != nil {
f.overWatch.Start()
}

for i := 0; i < f.nWorkers; i++ {

Expand All @@ -64,6 +76,9 @@ func (f *concurrentFetcher[I, O]) Start() (<-chan O, context.CancelFunc, error)
wg.Wait()
f.logger().Trace().Msg("concurrent fetcher closing output channel")
close(f.outChan)
if f.overWatch != nil {
f.overWatch.Stop()
}
}()

// We return a cancel function so that the client can
Expand Down Expand Up @@ -98,7 +113,7 @@ func (f *concurrentFetcher[I, O]) logger() *dbsqllog.DBSQLLogger {
return f.DBSQLLogger
}

func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWorkers, maxItemsInMemory int, inputChan <-chan FetchableItems[O]) (Fetcher[O], error) {
func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWorkers, maxItemsInMemory int, inputChan <-chan FetchableItems[O], overWatch Overwatch) (Fetcher[O], error) {
if nWorkers < 1 {
nWorkers = 1
}
Expand All @@ -123,6 +138,7 @@ func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWork
cancelChan: stopChannel,
ctx: ctx,
nWorkers: nWorkers,
overWatch: overWatch,
}

return fetcher, nil
Expand All @@ -133,10 +149,12 @@ func work[I FetchableItems[O], O any](f *concurrentFetcher[I, O], workerIndex in
for {
select {
case <-f.cancelChan:
f.setErr(errors.New("fetcher canceled"))
f.logger().Debug().Msgf("concurrent fetcher worker %d received cancel signal", workerIndex)
return

case <-f.ctx.Done():
f.setErr(f.ctx.Err())
f.logger().Debug().Msgf("concurrent fetcher worker %d context done", workerIndex)
return

Expand Down
126 changes: 116 additions & 10 deletions internal/fetcher/fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"testing"
"time"

"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)

// Create a mock struct for FetchableItems
Expand All @@ -32,6 +32,13 @@ func (m *mockFetchableItem) Fetch(ctx context.Context) ([]*mockOutput, error) {

var _ FetchableItems[[]*mockOutput] = (*mockFetchableItem)(nil)

// testOverWatch is a minimal test spy implementing the fetcher's
// Overwatch contract; it records whether Start and Stop were called.
type testOverWatch struct {
	started bool
	stopped bool
}

// Start marks the overwatch as started.
func (ow *testOverWatch) Start() {
	ow.started = true
}

// Stop marks the overwatch as stopped.
func (ow *testOverWatch) Stop() {
	ow.stopped = true
}

func TestConcurrentFetcher(t *testing.T) {
t.Run("Comprehensively tests the concurrent fetcher", func(t *testing.T) {
ctx := context.Background()
Expand All @@ -43,8 +50,10 @@ func TestConcurrentFetcher(t *testing.T) {
}
close(inputChan)

ow := &testOverWatch{}

// Create a fetcher
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 3, 3, inputChan)
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 3, 3, inputChan, ow)
if err != nil {
t.Fatalf("Error creating fetcher: %v", err)
}
Expand All @@ -60,6 +69,9 @@ func TestConcurrentFetcher(t *testing.T) {
results = append(results, result...)
}

assert.True(t, ow.started)
assert.True(t, ow.stopped)

// Check if the fetcher returned the expected results
expectedLen := 50
if len(results) != expectedLen {
Expand All @@ -83,19 +95,20 @@ func TestConcurrentFetcher(t *testing.T) {

t.Run("Cancel the concurrent fetcher", func(t *testing.T) {
// Create a context with a timeout
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

// Create an input channel
inputChan := make(chan FetchableItems[[]*mockOutput], 3)
for i := 0; i < 3; i++ {
item := mockFetchableItem{item: i, wait: 1 * time.Second}
inputChan := make(chan FetchableItems[[]*mockOutput], 5)
for i := 0; i < 5; i++ {
item := mockFetchableItem{item: i, wait: 2 * time.Second}
inputChan <- &item
}
close(inputChan)
ow := &testOverWatch{}

// Create a new fetcher
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, 2, inputChan)
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, 2, inputChan, ow)
if err != nil {
t.Fatalf("Error creating fetcher: %v", err)
}
Expand All @@ -111,13 +124,106 @@ func TestConcurrentFetcher(t *testing.T) {
cancelFunc()
}()

var count int
for range outChan {
// Just drain the channel
count += 1
}

assert.Less(t, count, 5)

err = fetcher.Err()
assert.EqualError(t, err, "fetcher canceled")

assert.True(t, ow.started)
assert.True(t, ow.stopped)
})

t.Run("timeout the concurrent fetcher", func(t *testing.T) {
// Create a context with a timeout
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

// Create an input channel
inputChan := make(chan FetchableItems[[]*mockOutput], 10)
for i := 0; i < 10; i++ {
item := mockFetchableItem{item: i, wait: 1 * time.Second}
inputChan <- &item
}
close(inputChan)

ow := &testOverWatch{}

// Create a new fetcher
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, 2, inputChan, ow)
if err != nil {
t.Fatalf("Error creating fetcher: %v", err)
}

// Check if an error occurred
if err := fetcher.Err(); err != nil && !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("unexpected error: %v", err)
// Start the fetcher
outChan, _, err := fetcher.Start()
if err != nil {
t.Fatal(err)
}

var count int
for range outChan {
// Just drain the channel
count += 1
}

assert.Less(t, count, 10)

err = fetcher.Err()
assert.EqualError(t, err, "context deadline exceeded")

assert.True(t, ow.started)
assert.True(t, ow.stopped)
})

t.Run("context cancel the concurrent fetcher", func(t *testing.T) {
// Create a context with a timeout
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)

// Create an input channel
inputChan := make(chan FetchableItems[[]*mockOutput], 5)
for i := 0; i < 5; i++ {
item := mockFetchableItem{item: i, wait: 2 * time.Second}
inputChan <- &item
}
close(inputChan)

ow := &testOverWatch{}

// Create a new fetcher
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, 2, inputChan, ow)
if err != nil {
t.Fatalf("Error creating fetcher: %v", err)
}

// Start the fetcher
outChan, _, err := fetcher.Start()
if err != nil {
t.Fatal(err)
}

// Ensure that the fetcher is cancelled successfully
go func() {
cancel()
}()

var count int
for range outChan {
// Just drain the channel
count += 1
}

assert.Less(t, count, 5)

err = fetcher.Err()
assert.EqualError(t, err, "context canceled")

assert.True(t, ow.started)
assert.True(t, ow.stopped)
})
}
18 changes: 16 additions & 2 deletions internal/rows/arrowbased/arrowRecordIterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,35 @@ package arrowbased

import (
"context"
"database/sql/driver"
"io"

"github.com/apache/arrow/go/v12/arrow"
"github.com/databricks/databricks-sql-go/internal/cli_service"
"github.com/databricks/databricks-sql-go/internal/config"
dbsqlerr "github.com/databricks/databricks-sql-go/internal/errors"
"github.com/databricks/databricks-sql-go/internal/rows/rowscanner"
"github.com/databricks/databricks-sql-go/logger"
"github.com/databricks/databricks-sql-go/rows"
)

func NewArrowRecordIterator(ctx context.Context, rpi rowscanner.ResultPageIterator, bi BatchIterator, arrowSchemaBytes []byte, cfg config.Config) rows.ArrowBatchIterator {
func NewArrowRecordIterator(
ctx context.Context,
rpi rowscanner.ResultPageIterator,
bi BatchIterator,
arrowSchemaBytes []byte,
cfg config.Config,
pinger driver.Pinger,
logger *logger.DBSQLLogger) rows.ArrowBatchIterator {

ari := arrowRecordIterator{
cfg: cfg,
batchIterator: bi,
resultPageIterator: rpi,
ctx: ctx,
arrowSchemaBytes: arrowSchemaBytes,
pinger: pinger,
logger: logger,
}

return &ari
Expand All @@ -34,6 +46,8 @@ type arrowRecordIterator struct {
currentBatch SparkArrowBatch
isFinished bool
arrowSchemaBytes []byte
pinger driver.Pinger
logger *logger.DBSQLLogger
}

var _ rows.ArrowBatchIterator = (*arrowRecordIterator)(nil)
Expand Down Expand Up @@ -175,7 +189,7 @@ func (ri *arrowRecordIterator) newBatchLoader(fr *cli_service.TFetchResultsResp)
var bl BatchLoader
var err error
if len(rowSet.ResultLinks) > 0 {
bl, err = NewCloudBatchLoader(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg)
bl, err = NewCloudBatchLoader(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg, ri.pinger, ri.logger)
} else {
bl, err = NewLocalBatchLoader(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg)
}
Expand Down
6 changes: 4 additions & 2 deletions internal/rows/arrowbased/arrowRecordIterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func TestArrowRecordIterator(t *testing.T) {
5000,
nil,
false,
true,
client,
"connectionId",
"correlationId",
Expand All @@ -55,7 +56,7 @@ func TestArrowRecordIterator(t *testing.T) {
assert.Nil(t, err)

cfg := *config.WithDefaults()
rs := NewArrowRecordIterator(context.Background(), rpi, bi, executeStatementResp.DirectResults.ResultSetMetadata.ArrowSchema, cfg)
rs := NewArrowRecordIterator(context.Background(), rpi, bi, executeStatementResp.DirectResults.ResultSetMetadata.ArrowSchema, cfg, nil, nil)
defer rs.Close()

hasNext := rs.HasNext()
Expand Down Expand Up @@ -127,13 +128,14 @@ func TestArrowRecordIterator(t *testing.T) {
5000,
nil,
false,
true,
client,
"connectionId",
"correlationId",
logger)

cfg := *config.WithDefaults()
rs := NewArrowRecordIterator(context.Background(), rpi, nil, nil, cfg)
rs := NewArrowRecordIterator(context.Background(), rpi, nil, nil, cfg, nil, nil)
defer rs.Close()

hasNext := rs.HasNext()
Expand Down
Loading
Loading