diff --git a/neo4j/internal/router/router.go b/neo4j/internal/router/router.go index 46458ca1..25b916dc 100644 --- a/neo4j/internal/router/router.go +++ b/neo4j/internal/router/router.go @@ -88,27 +88,44 @@ func (r *Router) getTable(database string) (*db.RoutingTable, error) { return dbRouter.table, nil } - var routers []string - if dbRouter != nil { - routers = dbRouter.table.Routers + var ( + table *db.RoutingTable + err error + ) + + // Try last known set of routers if there are any + if dbRouter != nil && len(dbRouter.table.Routers) > 0 { + routers := dbRouter.table.Routers + r.log.Infof(r.logId, "Reading routing table for '%s' from previously known routers: %v", database, routers) + table, err = readTable(context.Background(), r.pool, database, routers, r.routerContext) + } + + // Try initial router if no routers or failed + if table == nil || err != nil { + r.log.Infof(r.logId, "Reading routing table from initial router: %s", r.rootRouter) + table, err = readTable(context.Background(), r.pool, database, []string{r.rootRouter}, r.routerContext) } - if len(routers) == 0 { - routers = []string{r.rootRouter} + + // Use hook to retrieve possibly different set of routers and retry + if err != nil && r.getRouters != nil { + routers := r.getRouters() + r.log.Infof(r.logId, "Reading routing table for '%s' from custom routers: %v", routers) + table, err = readTable(context.Background(), r.pool, database, routers, r.routerContext) } - r.log.Infof(r.logId, "Reading routing table for '%s' from any of %v", database, routers) - table, err := readTable(context.Background(), r.pool, database, routers, r.routerContext) if err != nil { - // Use hook to retrieve possibly different set of routers and retry - if r.getRouters != nil { - routers = r.getRouters() - table, err = readTable(context.Background(), r.pool, database, routers, r.routerContext) - } - if err != nil { - r.log.Error(r.logId, err) - return nil, err - } + r.log.Error(r.logId, err) + return nil, err + } + + if table == nil { + // Safe guard for logical error somewhere else + err = errors.New("No error and no table") + r.log.Error(r.logId, err) + return nil, err } + + // Store the routing table r.dbRouters[database] = &databaseRouter{ table: table, dueUnix: now.Add(time.Duration(table.TimeToLive) * time.Second).Unix(), diff --git a/neo4j/internal/router/router_test.go b/neo4j/internal/router/router_test.go index b3d34a1f..4746d7c0 100644 --- a/neo4j/internal/router/router_test.go +++ b/neo4j/internal/router/router_test.go @@ -135,6 +135,70 @@ func TestRespectsTimeToLiveAndInvalidate(t *testing.T) { assertNum(t, numfetch, 3, "Should have have fetched") } +func TestUsesRootRouterWhenPreviousRoutersFails(t *testing.T) { + borrows := [][]string{} + + conn := &connFake{table: &db.RoutingTable{TimeToLive: 1, Routers: []string{"otherRouter"}}} + var err error + pool := &poolFake{ + borrow: func(names []string, cancel context.CancelFunc) (poolpackage.Connection, error) { + //numfetch++ + borrows = append(borrows, names) + return conn, err //&connFake{table: table}, nil + }, + } + nzero := time.Now() + n := nzero + router := New("rootRouter", func() []string { return []string{} }, nil, pool, logger) + router.now = func() time.Time { + return n + } + dbName := "dbname" + + // First access should trigger initial table read from root router + router.Readers(dbName) + if borrows[0][0] != "rootRouter" { + t.Errorf("Should have connected to root upon first router request") + } + // Next access should go to otherRouter + n = n.Add(2 * time.Second) + router.Readers(dbName) + if borrows[1][0] != "otherRouter" { + t.Errorf("Should have queried other router") + } + // Let the next access first fail when requesting otherRouter and then succeed requesting + // rootRouter + requestedOther := false + requestedRoot := false + pool.borrow = func(names []string, cancel context.CancelFunc) (poolpackage.Connection, error) { + if !requestedOther { + if names[0] != "otherRouter" { + t.Errorf("Expected request for otherRouter") + return nil, errors.New("Wrong") + } + requestedOther = true + return nil, errors.New("some err") + } + if names[0] != "rootRouter" { + t.Errorf("Expected request for rootRouter") + return nil, errors.New("oh") + } + requestedRoot = true + return &connFake{table: &db.RoutingTable{TimeToLive: 1, Readers: []string{"aReader"}}}, nil + } + n = n.Add(2 * time.Second) + readers, err := router.Readers(dbName) + if err != nil { + t.Error(err) + } + if readers[0] != "aReader" { + t.Errorf("Didn't get the expected reader") + } + if !requestedOther || !requestedRoot { + t.Errorf("Should have requested both other and root routers") + } +} + // Verify that when the routing table can not be retrieved from the root router, a callback // should be invoked to get backup routers. func TestUseGetRoutersHookWhenInitialRouterFails(t *testing.T) {