From 957a6b1970eb30074db6d86ad66c2811d2905068 Mon Sep 17 00:00:00 2001
From: brido4125 <hcs4125@gmail.com>
Date: Tue, 8 Oct 2024 15:43:00 +0900
Subject: [PATCH] REFACTOR: methods in MemcachedConnectio invoked with
 updateReplConnection.

---
 .../spy/memcached/MemcachedConnection.java    | 68 ++---------------
 .../spy/memcached/MemcachedReplicaGroup.java  | 48 ++++++++++++
 .../memcached/MemcachedReplicaGroupTest.java  | 76 +++++++++++++++++++
 .../net/spy/memcached/MockMemcachedNode.java  |  6 +-
 4 files changed, 134 insertions(+), 64 deletions(-)
 create mode 100644 src/test/java/net/spy/memcached/MemcachedReplicaGroupTest.java

diff --git a/src/main/java/net/spy/memcached/MemcachedConnection.java b/src/main/java/net/spy/memcached/MemcachedConnection.java
index df4c0071c..fb4b23756 100644
--- a/src/main/java/net/spy/memcached/MemcachedConnection.java
+++ b/src/main/java/net/spy/memcached/MemcachedConnection.java
@@ -32,7 +32,6 @@
 import java.nio.channels.SocketChannel;
 import java.util.ArrayList;
 import java.util.Collection;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
@@ -367,38 +366,6 @@ private void updateConnections(List<InetSocketAddress> addrs) throws IOException
   }
 
   /* ENABLE_REPLICATION if */
-  private Set<String> findChangedGroups(List<InetSocketAddress> addrs,
-                                        Collection<MemcachedNode> nodes) {
-    Map<String, InetSocketAddress> addrMap = new HashMap<>();
-    for (InetSocketAddress each : addrs) {
-      addrMap.put(each.toString(), each);
-    }
-
-    Set<String> changedGroupSet = new HashSet<>();
-    for (MemcachedNode node : nodes) {
-      String nodeAddr = ((InetSocketAddress) node.getSocketAddress()).toString();
-      if (addrMap.remove(nodeAddr) == null) { // removed node
-        changedGroupSet.add(node.getReplicaGroup().getGroupName());
-      }
-    }
-    for (String addr : addrMap.keySet()) { // newly added node
-      ArcusReplNodeAddress a = (ArcusReplNodeAddress) addrMap.get(addr);
-      changedGroupSet.add(a.getGroupName());
-    }
-    return changedGroupSet;
-  }
-
-  private List<InetSocketAddress> findAddrsOfChangedGroups(List<InetSocketAddress> addrs,
-                                                           Set<String> changedGroups) {
-    List<InetSocketAddress> changedGroupAddrs = new ArrayList<>();
-    for (InetSocketAddress addr : addrs) {
-      if (changedGroups.contains(((ArcusReplNodeAddress) addr).getGroupName())) {
-        changedGroupAddrs.add(addr);
-      }
-    }
-    return changedGroupAddrs;
-  }
-
   private void updateReplConnections(List<InetSocketAddress> addrs) throws IOException {
     List<MemcachedNode> attachNodes = new ArrayList<>();
     List<MemcachedNode> removeNodes = new ArrayList<>();
@@ -416,10 +383,11 @@ private void updateReplConnections(List<InetSocketAddress> addrs) throws IOExcep
      * we find out the changed groups with the comparison of previous and current znode list,
      * and update the state of groups based on them.
      */
-    Set<String> changedGroups = findChangedGroups(addrs, locator.getAll());
+    Set<String> changedGroups = MemcachedReplicaGroup.findChangedGroups(addrs, locator.getAll());
 
     Map<String, List<ArcusReplNodeAddress>> newAllGroups =
-            ArcusReplNodeAddress.makeGroupAddrsList(findAddrsOfChangedGroups(addrs, changedGroups));
+            ArcusReplNodeAddress.makeGroupAddrsList(
+                    MemcachedReplicaGroup.findAddrsOfChangedGroups(addrs, changedGroups));
 
     // remove invalidated groups in changedGroups
     for (Map.Entry<String, List<ArcusReplNodeAddress>> entry : newAllGroups.entrySet()) {
@@ -467,8 +435,10 @@ private void updateReplConnections(List<InetSocketAddress> addrs) throws IOExcep
       assert oldMasterAddr != null : "invalid old rgroup";
       assert newMasterAddr != null : "invalid new rgroup";
 
-      Set<ArcusReplNodeAddress> oldSlaveAddrs = getAddrsFromNodes(oldSlaveNodes);
-      Set<ArcusReplNodeAddress> newSlaveAddrs = getSlaveAddrsFromGroupAddrs(newGroupAddrs);
+      Set<ArcusReplNodeAddress> oldSlaveAddrs
+              = MemcachedReplicaGroup.getAddrsFromNodes(oldSlaveNodes);
+      Set<ArcusReplNodeAddress> newSlaveAddrs
+              = MemcachedReplicaGroup.getSlaveAddrsFromGroupAddrs(newGroupAddrs);
 
       if (oldMasterAddr.isSameAddress(newMasterAddr)) {
         // add newly added slave node
@@ -560,30 +530,6 @@ private void updateReplConnections(List<InetSocketAddress> addrs) throws IOExcep
     // Remove the unavailable nodes.
     handleNodesToRemove(removeNodes);
   }
-
-  private Set<ArcusReplNodeAddress> getAddrsFromNodes(List<MemcachedNode> nodes) {
-    Set<ArcusReplNodeAddress> addrs = Collections.emptySet();
-    if (!nodes.isEmpty()) {
-      addrs = new HashSet<>((int) (nodes.size() / .75f) + 1);
-      for (MemcachedNode node : nodes) {
-        addrs.add((ArcusReplNodeAddress) node.getSocketAddress());
-      }
-    }
-    return addrs;
-  }
-
-  private Set<ArcusReplNodeAddress> getSlaveAddrsFromGroupAddrs(
-          List<ArcusReplNodeAddress> groupAddrs) {
-    Set<ArcusReplNodeAddress> slaveAddrs = Collections.emptySet();
-    int groupSize = groupAddrs.size();
-    if (groupSize > 1) {
-      slaveAddrs = new HashSet<>((int) ((groupSize - 1) / .75f) + 1);
-      for (int i = 1; i < groupSize; i++) {
-        slaveAddrs.add(groupAddrs.get(i));
-      }
-    }
-    return slaveAddrs;
-  }
   /* ENABLE_REPLICATION end */
 
   /* ENABLE_REPLICATION if */
diff --git a/src/main/java/net/spy/memcached/MemcachedReplicaGroup.java b/src/main/java/net/spy/memcached/MemcachedReplicaGroup.java
index 5c27dbb8a..c7214a56d 100644
--- a/src/main/java/net/spy/memcached/MemcachedReplicaGroup.java
+++ b/src/main/java/net/spy/memcached/MemcachedReplicaGroup.java
@@ -18,9 +18,15 @@
 /* ENABLE_REPLICATION if */
 package net.spy.memcached;
 
+import java.net.InetSocketAddress;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
 
 import net.spy.memcached.compat.SpyObject;
 
@@ -190,5 +196,47 @@ private MemcachedNode getNextActiveSlaveNodeNoRotate() {
   public static String getGroupNameFromNode(final MemcachedNode node) {
     return ((ArcusReplNodeAddress) node.getSocketAddress()).getGroupName();
   }
+
+  static Set<String> findChangedGroups(List<InetSocketAddress> newAddrs,
+                                       Collection<MemcachedNode> oldNodes) {
+    Set<String> changedGroupSet = new HashSet<>();
+    Map<String, InetSocketAddress> addrMap = newAddrs.stream()
+            .collect(Collectors.toMap(InetSocketAddress::toString, addr -> addr));
+
+    for (MemcachedNode node : oldNodes) {
+      if (addrMap.remove(node.getSocketAddress().toString()) == null) {
+        changedGroupSet.add(node.getReplicaGroup().getGroupName());
+      }
+    }
+
+    addrMap.values().stream()
+            .map(addr -> ((ArcusReplNodeAddress) addr).getGroupName())
+            .forEach(changedGroupSet::add);
+
+    return changedGroupSet;
+  }
+
+  static List<InetSocketAddress> findAddrsOfChangedGroups(List<InetSocketAddress> newAddrs,
+                                                          Set<String> changedGroups) {
+    List<InetSocketAddress> changedGroupAddrs = new ArrayList<>();
+    newAddrs.stream()
+            .filter(addr -> changedGroups.contains(((ArcusReplNodeAddress) addr).getGroupName()))
+            .forEach(changedGroupAddrs::add);
+    return changedGroupAddrs;
+  }
+
+  static Set<ArcusReplNodeAddress> getAddrsFromNodes(List<MemcachedNode> nodes) {
+    return nodes.stream()
+            .map(node -> (ArcusReplNodeAddress) node.getSocketAddress())
+            .collect(Collectors.toSet());
+  }
+
+  static Set<ArcusReplNodeAddress> getSlaveAddrsFromGroupAddrs(
+          List<ArcusReplNodeAddress> groupAddrs) {
+    if (groupAddrs.size() <= 1) {
+      return Collections.emptySet();
+    }
+    return new HashSet<>(groupAddrs.subList(1, groupAddrs.size()));
+  }
 }
 /* ENABLE_REPLICATION end */
diff --git a/src/test/java/net/spy/memcached/MemcachedReplicaGroupTest.java b/src/test/java/net/spy/memcached/MemcachedReplicaGroupTest.java
new file mode 100644
index 000000000..977150c52
--- /dev/null
+++ b/src/test/java/net/spy/memcached/MemcachedReplicaGroupTest.java
@@ -0,0 +1,76 @@
+package net.spy.memcached;
+
+
+import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+
+class MemcachedReplicaGroupTest {
+
+  @Test
+  void findChangedGroupsTest() {
+    List<ArcusReplNodeAddress> g0 = createReplList("g0", "192.168.0.1");
+    List<ArcusReplNodeAddress> g1 = createReplList("g1", "192.168.0.2");
+    List<MemcachedNode> old = new ArrayList<>();
+    setReplGroup(g0, old);
+    setReplGroup(g1, old);
+
+    List<InetSocketAddress> update = new ArrayList<>(g0);
+
+    Set<String> changedGroups = MemcachedReplicaGroup.findChangedGroups(update, old);
+    Assertions.assertEquals(1, changedGroups.size());
+    Assertions.assertTrue(changedGroups.contains("g1"));
+  }
+
+  @Test
+  void findAddrsOfChangedGroupsTest() {
+    List<ArcusReplNodeAddress> g0 = createReplList("g0", "192.168.0.1");
+    List<ArcusReplNodeAddress> g1 = createReplList("g1", "192.168.0.2");
+    List<MemcachedNode> old = new ArrayList<>();
+    setReplGroup(g0, old);
+    setReplGroup(g1, old);
+
+    List<InetSocketAddress> update = new ArrayList<>();
+    update.addAll(g0.subList(0, 2));
+    update.addAll(g1.subList(0, 2));
+
+    Set<String> changedGroups = MemcachedReplicaGroup.findChangedGroups(update, old);
+    List<InetSocketAddress> result
+            = MemcachedReplicaGroup.findAddrsOfChangedGroups(update, changedGroups);
+
+    Assertions.assertEquals(4, result.size());
+    Assertions.assertTrue(result.contains(g0.get(0)));
+    Assertions.assertTrue(result.contains(g0.get(1)));
+    Assertions.assertTrue(result.contains(g1.get(0)));
+    Assertions.assertTrue(result.contains(g1.get(1)));
+  }
+
+  private void setReplGroup(List<ArcusReplNodeAddress> group, List<MemcachedNode> old) {
+    List<MockMemcachedNode> collect = group.stream()
+            .map(MockMemcachedNode::new)
+            .collect(Collectors.toList());
+    MemcachedReplicaGroupImpl impl = null;
+    for (MockMemcachedNode node : collect) {
+      if (impl == null) {
+        impl = new MemcachedReplicaGroupImpl(node);
+      } else {
+        node.setReplicaGroup(impl);
+      }
+    }
+    old.addAll(collect);
+  }
+
+  private List<ArcusReplNodeAddress> createReplList(String group, String ip) {
+    List<ArcusReplNodeAddress> replList = new ArrayList<>();
+    replList.add(ArcusReplNodeAddress.create(group, true, ip + ":" + 11211));
+    replList.add(ArcusReplNodeAddress.create(group, false, ip + ":" + (11211 + 1)));
+    replList.add(ArcusReplNodeAddress.create(group, false, ip + ":" + (11211 + 2)));
+    return replList;
+  }
+}
diff --git a/src/test/java/net/spy/memcached/MockMemcachedNode.java b/src/test/java/net/spy/memcached/MockMemcachedNode.java
index 5f354764a..2474288b0 100644
--- a/src/test/java/net/spy/memcached/MockMemcachedNode.java
+++ b/src/test/java/net/spy/memcached/MockMemcachedNode.java
@@ -29,6 +29,7 @@
 
 public class MockMemcachedNode implements MemcachedNode {
   private final InetSocketAddress socketAddress;
+  private MemcachedReplicaGroup memcachedReplicaGroup;
 
   public SocketAddress getSocketAddress() {
     return socketAddress;
@@ -260,13 +261,12 @@ public String getOpQueueStatus() {
 
   @Override
   public void setReplicaGroup(MemcachedReplicaGroup g) {
-    // noop
+    this.memcachedReplicaGroup = g;
   }
 
   @Override
   public MemcachedReplicaGroup getReplicaGroup() {
-    // noop
-    return null;
+    return memcachedReplicaGroup;
   }
 
   @Override