Skip to content

Commit

Permalink
Implement Result.PeekRecord + TestKit messages
Browse files Browse the repository at this point in the history
  • Loading branch information
robsdedude committed Jan 13, 2022
1 parent 5906663 commit f172f82
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 28 deletions.
55 changes: 43 additions & 12 deletions neo4j/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ type Result interface {
// NextRecord returns true if there is a record to be processed, record parameter is set
// to point to current record.
NextRecord(record **Record) bool
// PeekRecord returns true if there is a record after the current one to be processed without advancing the record
// stream, record parameter is set to point to that record if present.
PeekRecord(record **Record) bool
// Err returns the latest error that caused this Next to return false.
Err() error
// Record returns the current record.
Expand All @@ -46,13 +49,16 @@ type Result interface {
}

type result struct {
conn db.Connection
streamHandle db.StreamHandle
cypher string
params map[string]interface{}
record *Record
summary *db.Summary
err error
conn db.Connection
streamHandle db.StreamHandle
cypher string
params map[string]interface{}
record *Record
summary *db.Summary
err error
peekedRecord *Record
peekedSummary *db.Summary
peeked bool
}

func newResult(conn db.Connection, str db.StreamHandle, cypher string, params map[string]interface{}) *result {
Expand All @@ -64,23 +70,48 @@ func newResult(conn db.Connection, str db.StreamHandle, cypher string, params ma
}
}

func (r *result) advance() {
if r.peeked {
r.record = r.peekedRecord
r.summary = r.peekedSummary
r.peeked = false
} else {
r.record, r.summary, r.err = r.conn.Next(r.streamHandle)
}
}

func (r *result) peek() {
if !r.peeked {
r.peekedRecord, r.peekedSummary, r.err = r.conn.Next(r.streamHandle)
r.peeked = true
}
}

func (r *result) Keys() ([]string, error) {
return r.conn.Keys(r.streamHandle)
}

func (r *result) Next() bool {
r.record, r.summary, r.err = r.conn.Next(r.streamHandle)
r.advance()
return r.record != nil
}

func (r *result) NextRecord(out **Record) bool {
r.record, r.summary, r.err = r.conn.Next(r.streamHandle)
r.advance()
if out != nil {
*out = r.record
}
return r.record != nil
}

func (r *result) PeekRecord(out **Record) bool {
r.peek()
if out != nil {
*out = r.peekedRecord
}
return r.peekedRecord != nil
}

func (r *result) Record() *Record {
return r.record
}
Expand All @@ -92,7 +123,7 @@ func (r *result) Err() error {
func (r *result) Collect() ([]*Record, error) {
recs := make([]*Record, 0, 1024)
for r.summary == nil && r.err == nil {
r.record, r.summary, r.err = r.conn.Next(r.streamHandle)
r.advance()
if r.record != nil {
recs = append(recs, r.record)
}
Expand All @@ -109,7 +140,7 @@ func (r *result) buffer() {

func (r *result) Single() (*Record, error) {
// Try retrieving the single record
r.record, r.summary, r.err = r.conn.Next(r.streamHandle)
r.advance()
if r.err != nil {
return nil, wrapError(r.err)
}
Expand All @@ -122,7 +153,7 @@ func (r *result) Single() (*Record, error) {
single := r.record

// Probe connection for more records
r.record, r.summary, r.err = r.conn.Next(r.streamHandle)
r.advance()
if r.record != nil {
// There were more records, consume the stream since the user didn't
// expect more records and should therefore not use them.
Expand Down
58 changes: 42 additions & 16 deletions testkit-backend/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,11 @@ func (b *backend) writeError(err error) {

if isDriverError {
id := b.setError(err)
b.writeResponse("DriverError", map[string]interface{}{"id": id, "msg": err.Error(), "code": code})
b.writeResponse("DriverError", map[string]interface{}{
"id": id,
"errorType": strings.Split(err.Error(), ":")[0],
"msg": err.Error(),
"code": code})
return
}

Expand Down Expand Up @@ -341,6 +345,35 @@ func (s serverAddress) Port() string {
return s.port
}

func (b *backend) writeRecord(result neo4j.Result, record *neo4j.Record, expectRecord *bool) {
if expectRecord != nil {
if *expectRecord && record == nil {
b.writeResponse("BackendError", map[string]interface{}{
"msg": "Found no record where one was expected.",
})
} else if !*expectRecord && record != nil {
b.writeResponse("BackendError", map[string]interface{}{
"msg": "Found a record where none was expected.",
})
}
}
if record != nil {
values := record.Values
cypherValues := make([]interface{}, len(values))
for i, v := range values {
cypherValues[i] = nativeToCypher(v)
}
b.writeResponse("Record", map[string]interface{}{"values": cypherValues})
} else {
err := result.Err()
if err != nil {
b.writeError(err)
return
}
b.writeResponse("NullRecord", nil)
}
}

func (b *backend) handleRequest(req map[string]interface{}) {
name := req["name"].(string)
data := req["data"].(map[string]interface{})
Expand Down Expand Up @@ -557,21 +590,12 @@ func (b *backend) handleRequest(req map[string]interface{}) {
case "ResultNext":
result := b.results[data["resultId"].(string)]
more := result.Next()
if more {
values := result.Record().Values
cypherValues := make([]interface{}, len(values))
for i, v := range values {
cypherValues[i] = nativeToCypher(v)
}
b.writeResponse("Record", map[string]interface{}{"values": cypherValues})
} else {
err := result.Err()
if err != nil {
b.writeError(err)
return
}
b.writeResponse("NullRecord", nil)
}
b.writeRecord(result, result.Record(), &more)
case "ResultPeek":
result := b.results[data["resultId"].(string)]
var record *db.Record = nil
more := result.PeekRecord(&record)
b.writeRecord(result, record, &more)
case "ResultConsume":
result := b.results[data["resultId"].(string)]
summary, err := result.Consume()
Expand All @@ -592,6 +616,7 @@ func (b *backend) handleRequest(req map[string]interface{}) {
b.writeResponse("FeatureList", map[string]interface{}{
"features": []string{
"ConfHint:connection.recv_timeout_seconds",
"Feature:API:Result.Peek",
"Feature:Auth:Custom",
"Feature:Auth:Bearer",
"Feature:Auth:Kerberos",
Expand Down Expand Up @@ -669,5 +694,6 @@ func testSkips() map[string]string {
"neo4j.test_authentication.TestAuthenticationBasic.test_error_on_incorrect_credentials_tx": "Driver retries tx on failed authentication.",
"stub.*.test_0_timeout": "Driver omits 0 as tx timeout value",
"stub.*.test_negative_timeout": "Driver omits negative tx timeout values",
"stub.iteration.test_result_peek.TestResultPeek.test_result_peek_with_failure_tx_run": "Driver does not reset failed connection but raises error on Session.Close()",
}
}

0 comments on commit f172f82

Please sign in to comment.