Skip to content

Commit

Permalink
Enhance Neo4j integration test configuration (#544)
Browse files Browse the repository at this point in the history
- Add `getVersionFromDB` to `DbServer` for dynamic version detection
- Remove default version in `VersionOf` if empty
- Default `GetDbServer` to localhost when `TEST_NEO4J_HOST` unset
  • Loading branch information
lucapirolo authored Nov 10, 2023
1 parent 4d8aa8f commit e3e27ee
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 23 deletions.
70 changes: 60 additions & 10 deletions neo4j/test-integration/dbserver/dbserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ package dbserver
import (
"context"
"fmt"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/config"
"os"
"strconv"
"sync"

"github.com/neo4j/neo4j-go-driver/v5/neo4j/config"

"github.com/neo4j/neo4j-go-driver/v5/neo4j"
)

Expand All @@ -53,7 +54,7 @@ func GetDbServer(ctx context.Context) DbServer {

if server == nil {
vars := map[string]string{
"TEST_NEO4J_HOST": "",
"TEST_NEO4J_HOST": "localhost",
"TEST_NEO4J_USER": "neo4j",
"TEST_NEO4J_PASS": "password",
"TEST_NEO4J_SCHEME": "neo4j",
Expand All @@ -62,15 +63,14 @@ func GetDbServer(ctx context.Context) DbServer {
"TEST_NEO4J_IS_CLUSTER": "0",
"TEST_NEO4J_VERSION": "",
}
for k1, v1 := range vars {
v2, e2 := os.LookupEnv(k1)
if !e2 && v1 == "" {
panic(fmt.Sprintf("Required environment variable %s is missing", k1))
}
if e2 {
vars[k1] = v2

for k := range vars {
if envVal, exists := os.LookupEnv(k); exists {
vars[k] = envVal
fmt.Printf("Using %s=%s from environment\n", k, envVal)
}
}

key := "TEST_NEO4J_PORT"
port, err := strconv.ParseUint(vars[key], 10, 16)
if err != nil {
Expand All @@ -81,6 +81,7 @@ func GetDbServer(ctx context.Context) DbServer {
if err != nil {
panic(fmt.Sprintf("Unable to parse %s:%s to bool", key, vars[key]))
}

server = &DbServer{
Username: vars["TEST_NEO4J_USER"],
Password: vars["TEST_NEO4J_PASS"],
Expand All @@ -89,13 +90,62 @@ func GetDbServer(ctx context.Context) DbServer {
Port: int(port),
IsCluster: isCluster,
IsEnterprise: vars["TEST_NEO4J_EDITION"] == "enterprise",
Version: VersionOf(vars["TEST_NEO4J_VERSION"]),
}

envVersion := VersionOf(vars["TEST_NEO4J_VERSION"])
setServerVersion(ctx, server, envVersion)

server.deleteData(ctx)
}
return *server
}

// setServerVersion assigns a specific Neo4j version to the DbServer instance.
// It prefers the environment variable TEST_NEO4J_VERSION when available.
// Otherwise, it queries the database to determine the version.
// Panics if neither method can set the version, signaling the need for TEST_NEO4J_VERSION.
func setServerVersion(ctx context.Context, server *DbServer, envVersion Version) {
if envVersion != noVersion {
server.Version = envVersion
} else {
version, err := server.getVersionFromDB(ctx)
if err != nil {
panic("Unable to determine version from database, please set the TEST_NEO4J_VERSION environment variable")
}
server.Version = version
}
}

// getVersionFromDB fetches the Neo4j database version.
// this is used when the TEST_NEO4J_VERSION environment variable is not set.
func (s *DbServer) getVersionFromDB(ctx context.Context) (Version, error) {
driver := s.Driver()
session := driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeRead})
defer session.Close(ctx)

result, err := session.Run(ctx, "CALL dbms.components() YIELD versions UNWIND versions AS version RETURN version;", nil)
if err != nil {
return defaultVersion, err
}

record, err := result.Single(ctx)
if err != nil {
return defaultVersion, err
}

versionValue, found := record.Get("version")
if !found {
return defaultVersion, fmt.Errorf("version not found in record")
}

versionString, ok := versionValue.(string)
if !ok {
return defaultVersion, fmt.Errorf("version is not a string")
}

return VersionOf(versionString), nil
}

func (s DbServer) deleteData(ctx context.Context) {
driver := s.Driver()
session := driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
Expand Down
22 changes: 9 additions & 13 deletions neo4j/test-integration/dbserver/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,15 @@ var (
)

func VersionOf(server string) Version {
if server == "" {
return defaultVersion
} else {
if versionMatcher == nil {
versionMatcher = regexp.MustCompile(versionPattern)
}
matches := versionMatcher.FindStringSubmatch(server)
if matches != nil {
major, _ := strconv.Atoi(matches[2])
minor := parseMinor(matches[3])
patch, _ := strconv.Atoi(matches[4])
return Version{major, minor, patch}
}
if versionMatcher == nil {
versionMatcher = regexp.MustCompile(versionPattern)
}
matches := versionMatcher.FindStringSubmatch(server)
if matches != nil {
major, _ := strconv.Atoi(matches[2])
minor := parseMinor(matches[3])
patch, _ := strconv.Atoi(matches[4])
return Version{major, minor, patch}
}

return noVersion
Expand Down

0 comments on commit e3e27ee

Please sign in to comment.