Skip to content

Commit

Permalink
Close underlying connection when context has terminated
Browse files Browse the repository at this point in the history
This allows the Bolt server to ditch the transaction closely after
the context terminates.

Before this, the only way to have the Bolt server release
transactions was to close the entire connection pool (which is
done when closing the driver). That's because the connection is
marked as dead, which means the pool will not act on it upon
return.
  • Loading branch information
fbiville authored Jan 27, 2023
1 parent c2ca69d commit 78c1c75
Show file tree
Hide file tree
Showing 9 changed files with 305 additions and 0 deletions.
3 changes: 3 additions & 0 deletions neo4j/internal/bolt/bolt3.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ func NewBolt3(serverName string, conn net.Conn, logger log.Logger, boltLog log.B
if b.err == nil {
b.err = err
}
if ctxErr := handleTerminatedContextError(err, b.conn); ctxErr != nil {
b.err = ctxErr
}
b.state = bolt3_dead
},
boltLogger: boltLog,
Expand Down
53 changes: 53 additions & 0 deletions neo4j/internal/bolt/bolt3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ package bolt
import (
"context"
idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db"
"io"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -770,4 +772,55 @@ func TestBolt3(outer *testing.T) {
}

})

outer.Run("closes underlying socket when context has terminated", func(inner *testing.T) {
ctx := context.Background()
pastDeadline := time.Now().Add(-6 * time.Hour)
pastCtx, pastCtxCancel := context.WithDeadline(context.Background(), pastDeadline)
defer pastCtxCancel()
canceledCtx, cancelFunc := context.WithCancel(ctx)
cancelFunc() // cancel it now
type testCase struct {
description string
ctx context.Context
errorMatch string
}

testCases := []testCase{
{
description: "due to a past deadline",
ctx: pastCtx,
errorMatch: "Timeout while writing to connection",
},
{
description: "because of cancelation",
ctx: canceledCtx,
errorMatch: "Writing to connection has been canceled",
},
}
for _, test := range testCases {
inner.Run(test.description, func(t *testing.T) {
var latch sync.WaitGroup
latch.Add(1)
bolt, cleanup := connectToServer(t, func(srv *bolt3server) {
srv.accept(3)
defer func() {
// test server reaches EOF since the client closes the socket
// this happens before being able to dechunk the run message
AssertDeepEquals(t, recover(), io.EOF)
latch.Done()
}()
srv.waitForRun()
})
defer cleanup()
defer bolt.Close(ctx)

_, err := bolt.Run(test.ctx, idb.Command{Cypher: "UNWIND [1,2] AS k RETURN k"}, idb.TxConfig{Mode: idb.ReadMode})

latch.Wait()
AssertErrorMessageContains(t, err, test.errorMatch)
})
}
})

}
3 changes: 3 additions & 0 deletions neo4j/internal/bolt/bolt4.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ func (b *bolt4) setError(err error, fatal bool) {

// Increase severity even if it was a previous error
if fatal {
if ctxErr := handleTerminatedContextError(err, b.conn); ctxErr != nil {
b.err = ctxErr
}
b.state = bolt4_dead
}

Expand Down
53 changes: 53 additions & 0 deletions neo4j/internal/bolt/bolt4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import (
"context"
"fmt"
idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db"
"io"
"reflect"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -1264,4 +1266,55 @@ func TestBolt4(outer *testing.T) {
}

})

outer.Run("closes underlying socket when context has terminated", func(inner *testing.T) {
ctx := context.Background()
pastDeadline := time.Now().Add(-6 * time.Hour)
pastCtx, pastCtxCancel := context.WithDeadline(context.Background(), pastDeadline)
defer pastCtxCancel()
canceledCtx, cancelFunc := context.WithCancel(ctx)
cancelFunc() // cancel it now
type testCase struct {
description string
ctx context.Context
errorMatch string
}

testCases := []testCase{
{
description: "due to a past deadline",
ctx: pastCtx,
errorMatch: "Timeout while writing to connection",
},
{
description: "because of cancelation",
ctx: canceledCtx,
errorMatch: "Writing to connection has been canceled",
},
}
for _, test := range testCases {
inner.Run(test.description, func(t *testing.T) {
var latch sync.WaitGroup
latch.Add(1)
bolt, cleanup := connectToServer(t, func(srv *bolt4server) {
srv.accept(4)
defer func() {
// test server reaches EOF since the client closes the socket
// this happens before being able to dechunk the run message
AssertDeepEquals(t, recover(), io.EOF)
latch.Done()
}()
srv.waitForRun(nil)
})
defer cleanup()
defer bolt.Close(ctx)

_, err := bolt.Run(test.ctx, idb.Command{Cypher: "UNWIND [1,2] AS k RETURN k"}, idb.TxConfig{Mode: idb.ReadMode})

latch.Wait()
AssertErrorMessageContains(t, err, test.errorMatch)
})
}
})

}
3 changes: 3 additions & 0 deletions neo4j/internal/bolt/bolt5.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ func (b *bolt5) setError(err error, fatal bool) {

// Increase severity even if it was a previous error
if fatal {
if ctxErr := handleTerminatedContextError(err, b.conn); ctxErr != nil {
b.err = ctxErr
}
b.state = bolt5Dead
}

Expand Down
52 changes: 52 additions & 0 deletions neo4j/internal/bolt/bolt5_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import (
"context"
"fmt"
idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db"
"io"
"reflect"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -1183,4 +1185,54 @@ func TestBolt5(outer *testing.T) {
}

})

outer.Run("closes underlying socket when context has terminated", func(inner *testing.T) {
ctx := context.Background()
pastDeadline := time.Now().Add(-6 * time.Hour)
pastCtx, pastCtxCancel := context.WithDeadline(context.Background(), pastDeadline)
defer pastCtxCancel()
canceledCtx, cancelFunc := context.WithCancel(ctx)
cancelFunc() // cancel it now
type testCase struct {
description string
ctx context.Context
errorMatch string
}

testCases := []testCase{
{
description: "due to a past deadline",
ctx: pastCtx,
errorMatch: "Timeout while writing to connection",
},
{
description: "because of cancelation",
ctx: canceledCtx,
errorMatch: "Writing to connection has been canceled",
},
}
for _, test := range testCases {
inner.Run(test.description, func(t *testing.T) {
var latch sync.WaitGroup
latch.Add(1)
bolt, cleanup := connectToServer(t, func(srv *bolt5server) {
srv.accept(5)
defer func() {
// test server reaches EOF since the client closes the socket
// this happens before being able to dechunk the run message
AssertDeepEquals(t, recover(), io.EOF)
latch.Done()
}()
srv.waitForRun(nil)
})
defer cleanup()
defer bolt.Close(ctx)

_, err := bolt.Run(test.ctx, idb.Command{Cypher: "UNWIND [1,2] AS k RETURN k"}, idb.TxConfig{Mode: idb.ReadMode})

latch.Wait()
AssertErrorMessageContains(t, err, test.errorMatch)
})
}
})
}
50 changes: 50 additions & 0 deletions neo4j/internal/bolt/connections.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [https://neo4j.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package bolt

import (
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil"
"net"
)

func handleTerminatedContextError(err error, connection net.Conn) error {
if !contextTerminatedErr(err) {
return nil
}
closeErr := connection.Close()
if closeErr == nil {
return nil
}
return errorutil.CombineErrors(err, closeErr)
}

func contextTerminatedErr(err error) bool {
switch err.(type) {
case *ConnectionWriteTimeout:
return true
case *ConnectionReadTimeout:
return true
case *ConnectionWriteCanceled:
return true
case *ConnectionReadCanceled:
return true
}
return false
}
80 changes: 80 additions & 0 deletions neo4j/test-integration/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [https://neo4j.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package test_integration

import (
"context"
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/test-integration/dbserver"
"testing"
)

func TestContext(outer *testing.T) {
if testing.Short() {
outer.Skip()
}

ctx := context.Background()
server := dbserver.GetDbServer(ctx)

outer.Run("server does not hold on transaction when driver cancels context", func(t *testing.T) {
driver := server.Driver()
defer driver.Close(ctx)
session := driver.NewSession(ctx, neo4j.SessionConfig{FetchSize: 1})
defer session.Close(ctx)
tx, err := session.BeginTransaction(ctx)
assertNil(t, err)
defer tx.Close(ctx)
results, err := tx.Run(ctx, "UNWIND [1,2,3] AS x RETURN x", nil)
assertNil(t, err)
canceledCtx, cancel := context.WithCancel(ctx)
cancel()

_, err = results.Consume(canceledCtx)

assertStringContains(t, err.Error(), "context canceled")
workloads := listTransactionWorkloads(ctx, driver, server)
// TODO: replace length assertion with query assertion when https://trello.com/c/G14xMoBG is fixed
assertEquals(t, len(workloads), 0)
})
}

func listTransactionWorkloads(ctx context.Context, driver neo4j.DriverWithContext, server dbserver.DbServer) []string {
session := driver.NewSession(ctx, neo4j.SessionConfig{})
defer session.Close(ctx)
transactionQuery := server.GetTransactionWorkloadsQuery()
results, err := session.Run(ctx, transactionQuery, nil)
if err != nil {
panic(err)
}
records, err := results.Collect(ctx)
if err != nil {
panic(err)
}
workloads := make([]string, 0, len(records)-1)
for _, record := range records {
rawQuery, _ := record.Get("query")
query := rawQuery.(string)
if query != transactionQuery {
workloads = append(workloads, query)
}
}
return workloads
}
8 changes: 8 additions & 0 deletions neo4j/test-integration/dbserver/dbserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,11 @@ func (s DbServer) DropDatabaseQuery(db string) string {
func (s DbServer) isV42OrLater(v Version) bool {
return (v.major == 4 && v.minor >= 2) || v.major > 4
}

func (s DbServer) GetTransactionWorkloadsQuery() string {
version := s.Version
if version.LessThan(VersionOf("4.4.0")) {
return "CALL dbms.listTransactions() YIELD status, currentQuery WHERE status = 'Running' RETURN currentQuery AS query"
}
return "SHOW TRANSACTIONS YIELD status, currentQuery WHERE status = 'Running' RETURN currentQuery AS query"
}

0 comments on commit 78c1c75

Please sign in to comment.