diff --git a/pkg/kwok/controllers/controller.go b/pkg/kwok/controllers/controller.go index 2f7066fe9..6ef517422 100644 --- a/pkg/kwok/controllers/controller.go +++ b/pkg/kwok/controllers/controller.go @@ -134,6 +134,7 @@ func NewController(conf Config) (*Controller, error) { LockPodParallelism: 16, DeletePodParallelism: 16, NodeHasFunc: nodes.Has, // just handle pods that are on nodes we have + NodeInfoGetFunc: nodes.Get, // get node info from the node controller Logger: conf.Logger, FuncMap: funcMap, }) diff --git a/pkg/kwok/controllers/node_controller.go b/pkg/kwok/controllers/node_controller.go index 033fe3410..be6f39343 100644 --- a/pkg/kwok/controllers/node_controller.go +++ b/pkg/kwok/controllers/node_controller.go @@ -47,7 +47,7 @@ type NodeController struct { manageNodesWithLabelSelector string nodeSelectorFunc func(node *corev1.Node) bool lockPodsOnNodeFunc func(ctx context.Context, nodeName string) error - nodesSets *stringSets + nodesSets *nodeSets nodeHeartbeatTemplate string nodeStatusTemplate string renderer *renderer @@ -102,7 +102,7 @@ func NewNodeController(conf NodeControllerConfig) (*NodeController, error) { manageNodesWithLabelSelector: conf.ManageNodesWithLabelSelector, lockPodsOnNodeFunc: conf.LockPodsOnNodeFunc, nodeIP: conf.NodeIP, - nodesSets: newStringSets(), + nodesSets: newNodeSets(), logger: log, nodeHeartbeatTemplate: conf.NodeHeartbeatTemplate, nodeStatusTemplate: conf.NodeStatusTemplate + "\n" + conf.NodeHeartbeatTemplate, @@ -254,7 +254,7 @@ func (c *NodeController) WatchNodes(ctx context.Context, ch chan<- string, opt m case watch.Added, watch.Modified: node := event.Object.(*corev1.Node) if c.needHeartbeat(node) { - c.nodesSets.Put(node.Name) + c.nodesSets.Put(node.Name, node.DeepCopy()) if c.needLockNode(node) { ch <- node.Name } @@ -283,7 +283,7 @@ func (c *NodeController) ListNodes(ctx context.Context, ch chan<- string, opt me return listPager.EachListItem(ctx, opt, func(obj runtime.Object) error { node := obj.(*corev1.Node) if 
c.needHeartbeat(node) { - c.nodesSets.Put(node.Name) + c.nodesSets.Put(node.Name, node.DeepCopy()) if c.needLockNode(node) { ch <- node.Name } @@ -393,6 +393,10 @@ func (c *NodeController) Has(nodeName string) bool { return c.nodesSets.Has(nodeName) } +func (c *NodeController) Get(nodeName string) *nodeInfo { + return c.nodesSets.Get(nodeName) +} + func (c *NodeController) Size() int { return c.nodesSets.Size() } diff --git a/pkg/kwok/controllers/pod_controller.go b/pkg/kwok/controllers/pod_controller.go index e657d62bd..f07e8ecc8 100644 --- a/pkg/kwok/controllers/pod_controller.go +++ b/pkg/kwok/controllers/pod_controller.go @@ -54,6 +54,7 @@ type PodController struct { nodeIP string cidrIPNet *net.IPNet nodeHasFunc func(nodeName string) bool + nodeInfoGetFunc func(nodeName string) *nodeInfo ipPool *ipPool podStatusTemplate string logger logger.Logger @@ -72,6 +73,7 @@ type PodControllerConfig struct { NodeIP string CIDR string NodeHasFunc func(nodeName string) bool + NodeInfoGetFunc func(nodeName string) *nodeInfo PodStatusTemplate string Logger logger.Logger LockPodParallelism int @@ -109,6 +111,7 @@ func NewPodController(conf PodControllerConfig) (*PodController, error) { cidrIPNet: cidrIPNet, ipPool: newIPPool(cidrIPNet), nodeHasFunc: conf.NodeHasFunc, + nodeInfoGetFunc: conf.NodeInfoGetFunc, logger: log, podStatusTemplate: conf.PodStatusTemplate, lockPodChan: make(chan *corev1.Pod), @@ -120,7 +123,11 @@ func NewPodController(conf PodControllerConfig) (*PodController, error) { "NodeIP": func() string { return n.nodeIP }, - "PodIP": func() string { + "PodIP": func(nodeName string) string { + nodeInfo := n.nodeInfoGetFunc(nodeName) + if nodeInfo != nil && nodeInfo.IPPool != nil { + return nodeInfo.IPPool.Get() + } return n.ipPool.Get() }, } @@ -305,12 +312,10 @@ func (c *PodController) WatchPods(ctx context.Context, lockChan, deleteChan chan } case watch.Deleted: pod := event.Object.(*corev1.Pod) - if c.nodeHasFunc(pod.Spec.NodeName) { - // Recycling PodIP - if 
pod.Status.PodIP != "" && c.cidrIPNet.Contains(net.ParseIP(pod.Status.PodIP)) { - c.ipPool.Put(pod.Status.PodIP) - } - } + if c.nodeHasFunc(pod.Spec.NodeName) { + // Recycling PodIP + c.reclaimPodIP(pod) + } } case <-ctx.Done(): watcher.Stop() @@ -347,9 +352,7 @@ func (c *PodController) LockPodsOnNode(ctx context.Context, nodeName string) err func (c *PodController) configurePod(pod *corev1.Pod) ([]byte, error) { // Mark the pod IP that existed before the kubelet was started - if c.cidrIPNet.Contains(net.ParseIP(pod.Status.PodIP)) { - c.ipPool.Use(pod.Status.PodIP) - } + c.markPodIP(pod) patch, err := c.computePatchData(pod, c.podStatusTemplate) if err != nil { @@ -400,3 +403,36 @@ func (c *PodController) computePatchData(pod *corev1.Pod, temp string) ([]byte, return patch, nil } +func (c *PodController) markPodIP(pod *corev1.Pod) { + if c.cidrIPNet.Contains(net.ParseIP(pod.Status.PodIP)) { + c.ipPool.Use(pod.Status.PodIP) + } + + nodeInfo := c.nodeInfoGetFunc(pod.Spec.NodeName) + if nodeInfo == nil || nodeInfo.CidrIPNet == nil || nodeInfo.IPPool == nil { + return + } + + if nodeInfo.CidrIPNet.Contains(net.ParseIP(pod.Status.PodIP)) { + nodeInfo.IPPool.Use(pod.Status.PodIP) + } +} + +func (c *PodController) reclaimPodIP(pod *corev1.Pod) { + nodeInfo := c.nodeInfoGetFunc(pod.Spec.NodeName) + if nodeInfo == nil || pod.Status.PodIP == "" { + return + } + + if c.cidrIPNet.Contains(net.ParseIP(pod.Status.PodIP)) { + c.ipPool.Put(pod.Status.PodIP) + } + + if nodeInfo.CidrIPNet == nil || nodeInfo.IPPool == nil { + return + } + + if nodeInfo.CidrIPNet.Contains(net.ParseIP(pod.Status.PodIP)) { + nodeInfo.IPPool.Put(pod.Status.PodIP) + } +} diff --git a/pkg/kwok/controllers/pod_controller_test.go b/pkg/kwok/controllers/pod_controller_test.go index 107e66ba6..6d6d6de64 100644 --- a/pkg/kwok/controllers/pod_controller_test.go +++ b/pkg/kwok/controllers/pod_controller_test.go @@ -70,6 +70,14 @@ func TestPodController(t *testing.T) { nodeHasFunc := func(nodeName string) bool { 
return strings.HasPrefix(nodeName, "node") } + + nodeInfoGetFunc := func(nodeName string) *nodeInfo { + if nodeHasFunc(nodeName) { + return &nodeInfo{} + } + return nil + } + annotationSelector, _ := labels.Parse("fake=custom") pods, err := NewPodController(PodControllerConfig{ ClientSet: clientset, @@ -78,6 +86,7 @@ func TestPodController(t *testing.T) { DisregardStatusWithAnnotationSelector: annotationSelector.String(), PodStatusTemplate: templates.DefaultPodStatusTemplate, NodeHasFunc: nodeHasFunc, + NodeInfoGetFunc: nodeInfoGetFunc, FuncMap: funcMap, LockPodParallelism: 2, DeletePodParallelism: 2, @@ -174,3 +183,124 @@ func TestPodController(t *testing.T) { } } } + +func TestPodControllerIPPool(t *testing.T) { + clientset := fake.NewSimpleClientset() + + nodeHasFunc := func(nodeName string) bool { + return strings.HasPrefix(nodeName, "node") + } + + node1PodCIDR := "10.0.1.1/24" + node1PodNet, _ := parseCIDR(node1PodCIDR) + node1Info := &nodeInfo{ + CidrIPNet: node1PodNet, + IPPool: newIPPool(node1PodNet), + } + nodeInfoGetFunc := func(nodeName string) *nodeInfo { + if nodeName == "node0" { + return &nodeInfo{} + } + return node1Info + } + + podCIDR := "10.0.0.1/24" + pods, err := NewPodController(PodControllerConfig{ + ClientSet: clientset, + NodeIP: "10.0.0.1", + CIDR: podCIDR, + PodStatusTemplate: templates.DefaultPodStatusTemplate, + NodeHasFunc: nodeHasFunc, + NodeInfoGetFunc: nodeInfoGetFunc, + FuncMap: funcMap, + LockPodParallelism: 2, + DeletePodParallelism: 2, + Logger: testingLogger{t}, + }) + if err != nil { + t.Fatal(fmt.Errorf("new pods controller error: %w", err)) + } + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + t.Cleanup(func() { + cancel() + time.Sleep(time.Second) + }) + + err = pods.Start(ctx) + if err != nil { + t.Fatal(fmt.Errorf("start pods controller error: %w", err)) + } + + var genPod = func(podName, nodeName string) *corev1.Pod { + return &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + 
Name: podName, + Namespace: "default", + CreationTimestamp: metav1.Now(), + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "test-container", + Image: "test-image", + }, + }, + NodeName: nodeName, + }, + } + } + + clientset.CoreV1().Pods("default").Create(ctx, genPod("pod0", "node0"), metav1.CreateOptions{}) + + // sleep 2 seconds to wait for pod0 to be assigned an IP + time.Sleep(2 * time.Second) + + pod0, err := clientset.CoreV1().Pods("default").Get(ctx, "pod0", metav1.GetOptions{}) + if err != nil { + t.Fatal(fmt.Errorf("get pod0 error: %w", err)) + } + + // check if pod0 ip is in default ip cidr + pod0IP := pod0.Status.PodIP + if pod0IP == "" { + t.Fatal(fmt.Errorf("want pod %s to be assign an IP, but got nothing", pod0.Name)) + } + if !pods.ipPool.InUsed(pod0IP) { + t.Fatal(fmt.Errorf("want pod %s ip in %s, but got %s", pod0.Name, podCIDR, pod0IP)) + } + + clientset.CoreV1().Pods("default").Create(ctx, genPod("pod1", "node1"), metav1.CreateOptions{}) + + // sleep 2 seconds to wait for pod1 to be assigned an IP + time.Sleep(2 * time.Second) + + pod1, err := clientset.CoreV1().Pods("default").Get(ctx, "pod1", metav1.GetOptions{}) + if err != nil { + t.Fatal(fmt.Errorf("get pod1 error: %w", err)) + } + + // check if pod1 ip is in node pod cidr + pod1IP := pod1.Status.PodIP + if pod1IP == "" { + t.Fatal(fmt.Errorf("want pod %s to be assign an IP, but got nothing", pod1.Name)) + } + if !node1Info.IPPool.InUsed(pod1IP) { + t.Fatal(fmt.Errorf("want pod %s ip in %s, but got %s", pod1.Name, node1PodCIDR, pod1IP)) + } + + clientset.CoreV1().Pods("default").Delete(ctx, "pod0", metav1.DeleteOptions{}) + // sleep 2 seconds to wait for pod0 to be deleted + time.Sleep(2 * time.Second) + if pods.ipPool.InUsed(pod0IP) { + t.Fatal(fmt.Errorf("want pod0 ip to be reclaimed, but got %s in use", pod0IP)) + } + + clientset.CoreV1().Pods("default").Delete(ctx, "pod1", metav1.DeleteOptions{}) + // sleep 2 seconds to wait for pod1 to be deleted + time.Sleep(2 
* time.Second) + if node1Info.IPPool.InUsed(pod1IP) { + t.Fatal(fmt.Errorf("want pod1 ip to be reclaimed, but got %s in use", pod1IP)) + } +} diff --git a/pkg/kwok/controllers/templates/pod.status.tpl b/pkg/kwok/controllers/templates/pod.status.tpl index 253abfe09..3b56e1c4f 100644 --- a/pkg/kwok/controllers/templates/pod.status.tpl +++ b/pkg/kwok/controllers/templates/pod.status.tpl @@ -1,4 +1,5 @@ {{ $startTime := .metadata.creationTimestamp }} +{{ $nodeName := .spec.nodeName }} conditions: - lastTransitionTime: {{ $startTime }} @@ -46,7 +47,7 @@ initContainerStatuses: {{ with .status }} hostIP: {{ with .hostIP }} {{ . }} {{ else }} {{ NodeIP }} {{ end }} -podIP: {{ with .podIP }} {{ . }} {{ else }} {{ PodIP }} {{ end }} +podIP: {{ with .podIP }} {{ . }} {{ else }} {{ PodIP $nodeName }} {{ end }} {{ end }} phase: Running diff --git a/pkg/kwok/controllers/utils.go b/pkg/kwok/controllers/utils.go index 8fe531983..caa131121 100644 --- a/pkg/kwok/controllers/utils.go +++ b/pkg/kwok/controllers/utils.go @@ -22,6 +22,7 @@ import ( "sync" "time" + corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/labels" ) @@ -116,6 +117,13 @@ func (i *ipPool) Use(ip string) { i.used[ip] = struct{}{} } +func (i *ipPool) InUsed(ip string) bool { + i.mut.Lock() + defer i.mut.Unlock() + _, ok := i.used[ip] + return ok +} + type parallelTasks struct { wg sync.WaitGroup bucket chan struct{} @@ -160,43 +168,71 @@ func (p *parallelTasks) Wait() { p.wg.Wait() } -type stringSets struct { +// nodeInfo holds information about a node +type nodeInfo struct { + CidrIPNet *net.IPNet + IPPool *ipPool +} + +type nodeSets struct { mut sync.RWMutex - sets map[string]struct{} + sets map[string]*nodeInfo } -func newStringSets() *stringSets { - return &stringSets{ - sets: make(map[string]struct{}), +func newNodeSets() *nodeSets { + return &nodeSets{ + sets: make(map[string]*nodeInfo), } } -func (s *stringSets) Size() int { +func (s *nodeSets) Size() int { s.mut.RLock() defer s.mut.RUnlock() return 
len(s.sets) } -func (s *stringSets) Put(key string) { +func (s *nodeSets) Get(key string) *nodeInfo { + s.mut.RLock() + defer s.mut.RUnlock() + return s.sets[key] +} + +// Put registers the node under key. An entry that already owns an IP pool is kept untouched, so pod IPs handed out earlier survive repeated node Added/Modified events. +func (s *nodeSets) Put(key string, node *corev1.Node) { s.mut.Lock() defer s.mut.Unlock() - s.sets[key] = struct{}{} + + existing, ok := s.sets[key] + if ok && existing.IPPool != nil { + return + } + s.sets[key] = &nodeInfo{} + + if node.Spec.PodCIDR != "" { + cidrIPNet, err := parseCIDR(node.Spec.PodCIDR) + if err != nil { + return + } + s.sets[key] = &nodeInfo{ + CidrIPNet: cidrIPNet, + IPPool: newIPPool(cidrIPNet), + } + } } -func (s *stringSets) Delete(key string) { +func (s *nodeSets) Delete(key string) { s.mut.Lock() defer s.mut.Unlock() delete(s.sets, key) } -func (s *stringSets) Has(key string) bool { +func (s *nodeSets) Has(key string) bool { s.mut.RLock() defer s.mut.RUnlock() _, ok := s.sets[key] return ok } -func (s *stringSets) Foreach(f func(string)) { +func (s *nodeSets) Foreach(f func(string)) { s.mut.RLock() defer s.mut.RUnlock() for k := range s.sets {