diff --git a/Makefile b/Makefile index 0b3d0c9..1116846 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ # Image URL to use all building/pushing image targets -IMAGE ?= amazon/aws-network-policy-agent +IMAGE ?= public.ecr.aws/q1l2n4k8/npa VERSION ?= $(shell git describe --tags --always --dirty || echo "unknown") IMAGE_NAME = $(IMAGE)$(IMAGE_ARCH_SUFFIX):$(VERSION) GOLANG_VERSION ?= $(shell cat .go-version) diff --git a/api/v1alpha1/policyendpoints_types.go b/api/v1alpha1/policyendpoints_types.go index eadeecd..000f4da 100644 --- a/api/v1alpha1/policyendpoints_types.go +++ b/api/v1alpha1/policyendpoints_types.go @@ -48,6 +48,9 @@ type Port struct { // EndpointInfo defines the network endpoint information for the policy ingress/egress type EndpointInfo struct { + // Action is the action to enforce on an IP/CIDR (Allow, Deny, Pass) + Action string `json:"action"` + // CIDR is the network address(s) of the endpoint CIDR NetworkAddress `json:"cidr"` @@ -72,6 +75,15 @@ type PodEndpoint struct { // PolicyEndpointSpec defines the desired state of PolicyEndpoint type PolicyEndpointSpec struct { + // IsGlobal specifies whether the parent policy is an admin policy + IsGlobal bool `json:"isGlobal"` + + // Namespaces of the pod selector, will be empty for cluster wide + Namespaces []string `json:"namespaces"` + + // Priority of the policy, lower value is higher priority + Priority int `json:"priority"` + // PodSelector is the podSelector from the policy resource PodSelector *metav1.LabelSelector `json:"podSelector,omitempty"` diff --git a/controllers/policyendpoints_controller.go b/controllers/policyendpoints_controller.go index 014ea59..d5973c2 100644 --- a/controllers/policyendpoints_controller.go +++ b/controllers/policyendpoints_controller.go @@ -19,14 +19,21 @@ package controllers import ( "context" "errors" + "fmt" "net" + "sort" + "strconv" + "strings" "sync" "time" + merge "github.com/aws/aws-network-policy-agent/pkg/utils/mergerules" + policyk8sawsv1 "github.com/aws/aws-network-policy-agent/api/v1alpha1" "github.com/aws/aws-network-policy-agent/pkg/ebpf" "github.com/aws/aws-network-policy-agent/pkg/utils" "github.com/aws/aws-network-policy-agent/pkg/utils/imds" + "github.com/prometheus/client_golang/prometheus" apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/runtime" @@ -35,6 +42,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "github.com/go-logr/logr" + v1 "k8s.io/api/core/v1" networking "k8s.io/api/networking/v1" ) @@ -107,7 +115,9 @@ type PolicyEndpointsReconciler struct { policyEndpointeBPFContext sync.Map // Maps pod Identifier to list of PolicyEndpoint resources podIdentifierToPolicyEndpointMap sync.Map - // Mutex for operations on PodIdentifierToPolicyEndpointMap + // Maps pod Identifier to list of global PolicyEndpoint resources + podIdentifierToGlobalPolicyEndpointMap sync.Map + // Mutex for operations on PodIdentifierToPolicyEndpointMap and PodIdentifierToGlobalPolicyEndpointMap podIdentifierToPolicyEndpointMapMutex sync.Mutex // Maps PolicyEndpoint resource with a list of local pods policyEndpointSelectorMap sync.Map @@ -249,7 +259,7 @@ func (r *PolicyEndpointsReconciler) reconcilePolicyEndpoint(ctx context.Context, return err } - for podIdentifier, _ := range podIdentifiers { + for podIdentifier := range podIdentifiers { // Derive Ingress IPs from the PolicyEndpoint ingressRules, egressRules, isIngressIsolated, isEgressIsolated, err := r.deriveIngressAndEgressFirewallRules(ctx, podIdentifier, policyEndpoint.Namespace, policyEndpoint.Name, false) @@ -326,7 +336,10 @@ func (r *PolicyEndpointsReconciler) cleanupeBPFProbes(ctx context.Context, targe // Detach eBPF probes attached to the local pods (if required). We should detach eBPF probes if this // is the only PolicyEndpoint resource that applies to this pod. If not, just update the Ingress/Egress Map contents - if _, ok := r.podIdentifierToPolicyEndpointMap.Load(podIdentifier); ok { + _, foundPE := r.podIdentifierToPolicyEndpointMap.Load(podIdentifier) + _, foundGlobalPE := r.podIdentifierToGlobalPolicyEndpointMap.Load(podIdentifier) + ok := foundPE || foundGlobalPE + if ok { ingressRules, egressRules, isIngressIsolated, isEgressIsolated, err = r.deriveIngressAndEgressFirewallRules(ctx, podIdentifier, targetPod.Namespace, policyEndpoint, isDeleteFlow) if err != nil { @@ -377,23 +390,274 @@ func (r *PolicyEndpointsReconciler) cleanupeBPFProbes(ctx context.Context, targe return nil } +func mergeGlobalRulesHelper(rules []policyk8sawsv1.EndpointInfo, ipPorts map[string][]string, logger logr.Logger) map[string][]string { + allProtocol := "ALL" + for _, rule := range rules { + action := rule.Action + if rule.Ports == nil { + rule.Ports = append(rule.Ports, policyk8sawsv1.Port{ + Protocol: (*v1.Protocol)(&allProtocol), + Port: nil, + EndPort: nil, + }) + } + if val, ok := ipPorts[string(rule.CIDR)]; ok { + for _, prt := range rule.Ports { + modified := false + for i, port := range val { + split := strings.Split(port, "-") + action2 := split[0] + protocol := split[1] + portInt, _ := strconv.Atoi(split[2]) + port := int32(portInt) + portFin := &port + if portInt == 0 { + portFin = nil + } + endportInt, _ := strconv.Atoi(split[3]) + endport := int32(endportInt) + endPortFin := &endport + if endportInt == 0 { + endPortFin = nil + } + tempPort := policyk8sawsv1.Port{ + Protocol: (*v1.Protocol)(&protocol), + Port: portFin, + EndPort: endPortFin, + } + if protocol == "ALL" || string(*prt.Protocol) == "ALL" || checkPortOverlap(prt, tempPort) { + modified = true + mergedPort := merge.MergePorts(tempPort, prt, action2, action, logger) + val[i] = mergedPort[0] + if len(mergedPort) > 1 { + val = append(val, mergedPort[1]) + } + ipPorts[string(rule.CIDR)] = val + } else { + continue + } + } + if !modified { + zero := int32(0) + if prt.Port == nil { + prt.Port = &zero + } + if prt.EndPort == nil { + prt.EndPort = &zero + } + ipPorts[string(rule.CIDR)] = append(val, fmt.Sprintf("%s-%s-%d-%d", rule.Action, *prt.Protocol, *prt.Port, *prt.EndPort)) + } + } + } else { + arr := []string{} + if rule.Ports == nil { + arr = append(arr, fmt.Sprintf("%s-%s-%d-%d", rule.Action, "ALL", 0, 0)) + } else { + for _, port := range rule.Ports { + if port.EndPort == nil { + temp := int32(0) + port.EndPort = &temp + } + arr = append(arr, fmt.Sprintf("%s-%s-%d-%d", rule.Action, *port.Protocol, *port.Port, *port.EndPort)) + } + } + ipPorts[string(rule.CIDR)] = arr + } + } + return ipPorts +} + +func (r *PolicyEndpointsReconciler) mergeGlobalRules(currentPE *policyk8sawsv1.PolicyEndpoint, ipIngressPorts map[string][]string, ipEgressPorts map[string][]string) (map[string][]string, map[string][]string) { + ipIngressPorts = mergeGlobalRulesHelper(currentPE.Spec.Ingress, ipIngressPorts, r.log) + ipEgressPorts = mergeGlobalRulesHelper(currentPE.Spec.Egress, ipEgressPorts, r.log) + return ipIngressPorts, ipEgressPorts +} + +func mergeLocalRulesHelper(rules []policyk8sawsv1.EndpointInfo, ipPorts map[string][]string) map[string][]string { + tempAll := "ALL" + portZero := int32(0) + for _, rule := range rules { + if rule.Ports == nil { + rule.Ports = append(rule.Ports, policyk8sawsv1.Port{ + Protocol: (*v1.Protocol)(&tempAll), + Port: &portZero, + EndPort: &portZero, + }) + } + for _, port := range rule.Ports { + endport := port.EndPort + if endport == nil { + endport = &portZero + } + if val, ok := ipPorts[string(rule.CIDR)]; ok { + val = append(val, fmt.Sprintf("%s-%s-%d-%d", "Allow", *port.Protocol, *port.Port, *endport)) + ipPorts[string(rule.CIDR)] = val + } else { + val := []string{} + val = append(val, fmt.Sprintf("%s-%s-%d-%d", "Allow", *port.Protocol, *port.Port, *endport)) + ipPorts[string(rule.CIDR)] = val + } + } + for _, except := range rule.Except { + if val, ok := ipPorts[string(except)]; ok { + val = append(val, fmt.Sprintf("%s-%s-%d-%d", "Deny", "ALL", 0, 0)) + ipPorts[string(except)] = val + } else { + val := []string{} + val = append(val, fmt.Sprintf("%s-%s-%d-%d", "Deny", "ALL", 0, 0)) + ipPorts[string(except)] = val + } + } + } + return ipPorts +} + +func (r *PolicyEndpointsReconciler) mergeLocalRules(currentPE *policyk8sawsv1.PolicyEndpoint, ipIngressPorts map[string][]string, ipEgressPorts map[string][]string) (map[string][]string, map[string][]string) { + ipIngressPorts = mergeLocalRulesHelper(currentPE.Spec.Ingress, ipIngressPorts) + ipEgressPorts = mergeLocalRulesHelper(currentPE.Spec.Egress, ipEgressPorts) + return ipIngressPorts, ipEgressPorts +} + +// If an IP or CIDR matches a NP but not an ANP, we will validate the packet using NP +func (r *PolicyEndpointsReconciler) mergeGlobalLocalRules(ingress map[string][]string, egress map[string][]string, ingressLocal map[string][]string, egressLocal map[string][]string) (map[string][]string, map[string][]string) { + tempIngress := make(map[string][]string) + tempEgress := make(map[string][]string) + for localIP, localVal := range ingressLocal { + merged := false + for ip, val := range ingress { + if ip == localIP { + merged = true + val = merge.MergeGlobalLocalPorts(val, localVal) + ingress[ip] = val + } else { + continue + } + } + if !merged { + tempIngress[localIP] = localVal + } + } + for localIP, localVal := range egressLocal { + merged := false + for ip, val := range egress { + if ip == localIP { + merged = true + val = merge.MergeGlobalLocalPorts(val, localVal) + ingress[ip] = val + } else { + continue + } + } + if !merged { + tempEgress[localIP] = localVal + } + } + for ip, val := range tempIngress { + ingress[ip] = val + } + + for ip, val := range tempEgress { + egress[ip] = val + } + return ingress, egress +} + +func checkPortOverlap(ports, ports2 policyk8sawsv1.Port) bool { + if *ports.Protocol != *ports2.Protocol { + return false + } + if ports.EndPort != nil && ports2.EndPort != nil { + if *ports.EndPort < *ports2.Port || *ports.Port > *ports2.EndPort { + return false + } + return true + } else if ports.EndPort != nil { + if *ports2.Port >= *ports.Port && *ports2.Port <= *ports.EndPort { + return true + } + return false + } else if ports2.EndPort != nil { + if *ports.Port >= *ports2.Port && *ports.Port <= *ports2.EndPort { + return true + } + return false + } + return ports.Port == ports2.Port +} + func (r *PolicyEndpointsReconciler) deriveIngressAndEgressFirewallRules(ctx context.Context, podIdentifier string, resourceNamespace string, resourceName string, isDeleteFlow bool) ([]ebpf.EbpfFirewallRules, []ebpf.EbpfFirewallRules, bool, bool, error) { var ingressRules, egressRules []ebpf.EbpfFirewallRules isIngressIsolated, isEgressIsolated := false, false currentPE := &policyk8sawsv1.PolicyEndpoint{} + globalRules, localRules := []policyk8sawsv1.PolicyEndpoint{}, []policyk8sawsv1.PolicyEndpoint{} + found := false + + if policyEndpointList, ok := r.podIdentifierToGlobalPolicyEndpointMap.Load(podIdentifier); ok { + found = true + for _, policyEndpointResource := range policyEndpointList.([]string) { + peNamespacedName := types.NamespacedName{ + Name: policyEndpointResource, + Namespace: "kube-system", + } + if err := r.k8sClient.Get(ctx, peNamespacedName, currentPE); err != nil { + if apierrors.IsNotFound(err) { + continue + } + return nil, nil, isIngressIsolated, isEgressIsolated, err + } + globalRules = append(globalRules, *currentPE) + } + + // Sort globalRules by priority for ease in merging + sort.SliceStable(globalRules, func(i, j int) bool { + return globalRules[i].Spec.Priority < globalRules[j].Spec.Priority + }) + } + + // Pod has global rules that apply to it, which means we must replace resourceNamespace since it might not be in kube-system + if len(globalRules) > 0 { + namespaces := globalRules[0].Spec.Namespaces + // Longest Prefix Match + sort.Sort(sort.Reverse(sort.StringSlice(namespaces))) + for _, ns := range namespaces { + if strings.HasSuffix(podIdentifier, ns) { + r.log.Info("Found correct namespace", "namespace", ns) + resourceNamespace = ns + break + } + } + } + if policyEndpointList, ok := r.podIdentifierToPolicyEndpointMap.Load(podIdentifier); ok { - r.log.Info("Total number of PolicyEndpoint resources for", "podIdentifier ", podIdentifier, " are ", len(policyEndpointList.([]string))) + found = true for _, policyEndpointResource := range policyEndpointList.([]string) { peNamespacedName := types.NamespacedName{ Name: policyEndpointResource, Namespace: resourceNamespace, } + if err := r.k8sClient.Get(ctx, peNamespacedName, currentPE); err != nil { + if apierrors.IsNotFound(err) { + continue + } + return nil, nil, isIngressIsolated, isEgressIsolated, err + } + localRules = append(localRules, *currentPE) + } + } + + ipIngressPorts := make(map[string][]string) + ipEgressPorts := make(map[string][]string) + ipLocalIngressPorts := make(map[string][]string) + ipLocalEgressPorts := make(map[string][]string) + if found { + r.log.Info("Total number of global PolicyEndpoint resources for", "podIdentifier ", podIdentifier, " are ", len(globalRules)) + for _, policyEndpointResource := range globalRules { if isDeleteFlow { deletedPEParentNPName := utils.GetParentNPNameFromPEName(resourceName) - currentPEParentNPName := utils.GetParentNPNameFromPEName(policyEndpointResource) + currentPEParentNPName := utils.GetParentNPNameFromPEName(policyEndpointResource.Name) if deletedPEParentNPName == currentPEParentNPName { r.log.Info("PE belongs to same NP. Ignore and move on since it's a delete flow", "deletedPE", resourceName, "currentPE", policyEndpointResource) @@ -401,37 +665,122 @@ func (r *PolicyEndpointsReconciler) deriveIngressAndEgressFirewallRules(ctx cont } } - if err := r.k8sClient.Get(ctx, peNamespacedName, currentPE); err != nil { - if apierrors.IsNotFound(err) { + r.log.Info("Deriving Firewall rules for global PolicyEndpoint:", "Name: ", policyEndpointResource.Name) + + ipIngressPorts, ipEgressPorts = r.mergeGlobalRules(&policyEndpointResource, ipIngressPorts, ipEgressPorts) + } + + // Once global rules are processed, "Pass" rules are no longer needed + ipIngressPorts = removePassRules(ipIngressPorts) + ipEgressPorts = removePassRules(ipEgressPorts) + r.log.Info("Merged global ingress rules", "ipIngressPorts", ipIngressPorts) + r.log.Info("Merged global egress rules", "ipEgressPorts", ipEgressPorts) + + r.log.Info("Total number of PolicyEndpoint resources for", "podIdentifier ", podIdentifier, " are ", len(localRules)) + + for _, policyEndpointResource := range localRules { + if isDeleteFlow { + deletedPEParentNPName := utils.GetParentNPNameFromPEName(resourceName) + currentPEParentNPName := utils.GetParentNPNameFromPEName(policyEndpointResource.Name) + if deletedPEParentNPName == currentPEParentNPName { + r.log.Info("PE belongs to same NP. Ignore and move on since it's a delete flow", + "deletedPE", resourceName, "currentPE", policyEndpointResource) continue } - return nil, nil, isIngressIsolated, isEgressIsolated, err } r.log.Info("Deriving Firewall rules for PolicyEndpoint:", "Name: ", currentPE.Name) - for _, endPointInfo := range currentPE.Spec.Ingress { - ingressRules = append(ingressRules, - ebpf.EbpfFirewallRules{ - IPCidr: endPointInfo.CIDR, - Except: endPointInfo.Except, - L4Info: endPointInfo.Ports, - }) - } + ipLocalIngressPorts, ipLocalEgressPorts = r.mergeLocalRules(&policyEndpointResource, ipLocalIngressPorts, ipLocalEgressPorts) - for _, endPointInfo := range currentPE.Spec.Egress { - egressRules = append(egressRules, - ebpf.EbpfFirewallRules{ - IPCidr: endPointInfo.CIDR, - Except: endPointInfo.Except, - L4Info: endPointInfo.Ports, - }) - } - r.log.Info("Total no.of - ", "ingressRules", len(ingressRules), "egressRules", len(egressRules)) - ingressIsolated, egressIsolated := r.deriveDefaultPodIsolation(ctx, currentPE, len(ingressRules), len(egressRules)) + ingressIsolated, egressIsolated := r.deriveDefaultPodIsolation(ctx, &policyEndpointResource, len(ipLocalIngressPorts), len(ipLocalEgressPorts)) isIngressIsolated = isIngressIsolated || ingressIsolated isEgressIsolated = isEgressIsolated || egressIsolated } } + r.log.Info("Merged local ingress rules", "ipLocalIngressPorts", ipLocalIngressPorts) + r.log.Info("Merged local egress rules", "ipLocalEgressPorts", ipLocalEgressPorts) + + ipIngressPorts, ipEgressPorts = r.mergeGlobalLocalRules(ipIngressPorts, ipEgressPorts, ipLocalIngressPorts, ipLocalEgressPorts) + r.log.Info("Merged total ingress rules", "ipIngressPorts", ipIngressPorts) + r.log.Info("Merged total egress rules", "ipEgressPortss", ipEgressPorts) + + ingressAll := ebpf.EbpfFirewallRules{ + IPCidr: "0.0.0.0/0", + Except: nil, + L4Info: nil, + } + + egressAll := ebpf.EbpfFirewallRules{ + IPCidr: "0.0.0.0/0", + Except: nil, + L4Info: nil, + } + + for ip, ports := range ipIngressPorts { + if !strings.Contains(ip, "/") { + ip += "/32" + } + exceptDeny := false + portList := []policyk8sawsv1.Port{} + denyList := []policyk8sawsv1.Port{} + ports = dedupPorts(ports) + for _, port := range ports { + split := strings.Split(port, "-") + action := split[0] + protocol := split[1] + portInt, _ := strconv.Atoi(split[2]) + port := int32(portInt) + endportInt, _ := strconv.Atoi(split[3]) + endport := int32(endportInt) + + if protocol == "ALL" { + if action == "Deny" { + exceptDeny = true + break + } + continue + } + if action == "Deny" { + denyList = append(denyList, policyk8sawsv1.Port{ + Protocol: (*v1.Protocol)(&protocol), + Port: &port, + EndPort: &endport, + }) + } else if action == "Allow" { + portList = append(portList, policyk8sawsv1.Port{ + Protocol: (*v1.Protocol)(&protocol), + Port: &port, + EndPort: &endport, + }) + } + + } + if exceptDeny { + ingressRules = append(ingressRules, + ebpf.EbpfFirewallRules{ + IPCidr: policyk8sawsv1.NetworkAddress(ip), + Except: []policyk8sawsv1.NetworkAddress{policyk8sawsv1.NetworkAddress(ip)}, + L4Info: []policyk8sawsv1.Port{}, + L4Deny: []policyk8sawsv1.Port{}, + }) + } else { + ingressRules = append(ingressRules, + ebpf.EbpfFirewallRules{ + IPCidr: policyk8sawsv1.NetworkAddress(ip), + Except: []policyk8sawsv1.NetworkAddress{}, + L4Info: portList, + L4Deny: denyList, + }) + } + } + + // Only append "all" traffic rule if global rules exist but local rules don't + if len(globalRules) != 0 && len(localRules) == 0 { + ingressRules = append(ingressRules, ingressAll) + egressRules = append(egressRules, egressAll) + } + + r.log.Info("Total no.of - ", "ingressRules", len(ingressRules), "egressRules", len(egressRules)) if len(ingressRules) > 0 { isIngressIsolated = false } @@ -441,6 +790,40 @@ func (r *PolicyEndpointsReconciler) deriveIngressAndEgressFirewallRules(ctx cont return ingressRules, egressRules, isIngressIsolated, isEgressIsolated, nil } +func removePassRules(ipPorts map[string][]string) map[string][]string { + for ip, ports := range ipPorts { + tempPorts := []string{} + for _, port := range ports { + portSplit := strings.Split(port, "-") + if portSplit[0] != "Pass" { + tempPorts = append(tempPorts, port) + } + } + if len(tempPorts) == 0 { + delete(ipPorts, ip) + } else { + ipPorts[ip] = tempPorts + } + } + return ipPorts +} + +func dedupPorts(ports []string) []string { + dedup := make(map[string]bool) + for _, port := range ports { + if _, ok := dedup[port]; ok { + continue + } + dedup[port] = true + } + + dedupedPorts := []string{} + for port := range dedup { + dedupedPorts = append(dedupedPorts, port) + } + return dedupedPorts +} + func (r *PolicyEndpointsReconciler) deriveDefaultPodIsolation(ctx context.Context, policyEndpoint *policyk8sawsv1.PolicyEndpoint, ingressRulesCount, egressRulesCount int) (bool, bool) { isIngressIsolated, isEgressIsolated := false, false @@ -514,7 +897,7 @@ func (r *PolicyEndpointsReconciler) deriveTargetPodsForParentNP(ctx context.Cont currentTargetPods, currentPodIdentifiers := r.deriveTargetPods(ctx, currentPE, parentPEList) r.log.Info("Adding to current targetPods", "Total pods: ", len(currentTargetPods)) targetPods = append(targetPods, currentTargetPods...) - for podIdentifier, _ := range currentPodIdentifiers { + for podIdentifier := range currentPodIdentifiers { podIdentifiers[podIdentifier] = true targetPodIdentifiers = append(targetPodIdentifiers, podIdentifier) } @@ -566,7 +949,7 @@ func (r *PolicyEndpointsReconciler) deriveTargetPods(ctx context.Context, podIdentifiers[podIdentifier] = true r.log.Info("Derived ", "Pod identifier: ", podIdentifier) } - r.updatePodIdentifierToPEMap(ctx, podIdentifier, parentPEList) + r.updatePodIdentifierToPEMap(ctx, podIdentifier, parentPEList, policyEndpoint.Spec.IsGlobal) } return targetPods, podIdentifiers } @@ -594,11 +977,37 @@ func (r *PolicyEndpointsReconciler) getPodListToBeCleanedUp(oldPodSet []types.Na } func (r *PolicyEndpointsReconciler) updatePodIdentifierToPEMap(ctx context.Context, podIdentifier string, - parentPEList []string) { + parentPEList []string, isGlobal bool) { r.podIdentifierToPolicyEndpointMapMutex.Lock() defer r.podIdentifierToPolicyEndpointMapMutex.Unlock() var policyEndpoints []string + if isGlobal { + r.log.Info("Current Global PE Count for Parent NP:", "Count: ", len(parentPEList)) + if currentPESet, ok := r.podIdentifierToGlobalPolicyEndpointMap.Load(podIdentifier); ok { + policyEndpoints = currentPESet.([]string) + for _, policyEndpointResourceName := range parentPEList { + r.log.Info("Global PE for parent NP", "name", policyEndpointResourceName) + addPEResource := true + for _, pe := range currentPESet.([]string) { + if pe == policyEndpointResourceName { + //Nothing to do if this PE is already tracked against this podIdentifier + addPEResource = false + break + } + } + if addPEResource { + r.log.Info("Adding PE", "name", policyEndpointResourceName, "for podIdentifier", podIdentifier) + policyEndpoints = append(policyEndpoints, policyEndpointResourceName) + } + } + } else { + policyEndpoints = append(policyEndpoints, parentPEList...) + } + r.podIdentifierToGlobalPolicyEndpointMap.Store(podIdentifier, policyEndpoints) + return + } + r.log.Info("Current PE Count for Parent NP:", "Count: ", len(parentPEList)) if currentPESet, ok := r.podIdentifierToPolicyEndpointMap.Load(podIdentifier); ok { policyEndpoints = currentPESet.([]string) @@ -621,7 +1030,6 @@ func (r *PolicyEndpointsReconciler) updatePodIdentifierToPEMap(ctx context.Conte policyEndpoints = append(policyEndpoints, parentPEList...) } r.podIdentifierToPolicyEndpointMap.Store(podIdentifier, policyEndpoints) - return } func (r *PolicyEndpointsReconciler) deriveStalePodIdentifiers(ctx context.Context, resourceName string, @@ -662,6 +1070,17 @@ func (r *PolicyEndpointsReconciler) deletePolicyEndpointFromPodIdentifierMap(ctx } r.podIdentifierToPolicyEndpointMap.Store(podIdentifier, currentPEList) } + + var currentGlobalPEList []string + if policyEndpointList, ok := r.podIdentifierToGlobalPolicyEndpointMap.Load(podIdentifier); ok { + for _, policyEndpointName := range policyEndpointList.([]string) { + if policyEndpointName == policyEndpoint { + continue + } + currentGlobalPEList = append(currentGlobalPEList, policyEndpointName) + } + r.podIdentifierToGlobalPolicyEndpointMap.Store(podIdentifier, currentGlobalPEList) + } } func (r *PolicyEndpointsReconciler) addCatchAllEntry(ctx context.Context, firewallRules *[]ebpf.EbpfFirewallRules) { @@ -674,8 +1093,6 @@ func (r *PolicyEndpointsReconciler) addCatchAllEntry(ctx context.Context, firewa IPCidr: catchAllRule.CIDR, L4Info: catchAllRule.Ports, }) - - return } // SetupWithManager sets up the controller with the Manager. @@ -743,5 +1160,11 @@ func (r *PolicyEndpointsReconciler) ArePoliciesAvailableInLocalCache(podIdentifi return true } } + if policyEndpointList, ok := r.podIdentifierToGlobalPolicyEndpointMap.Load(podIdentifier); ok { + if len(policyEndpointList.([]string)) > 0 { + r.log.Info("Active policies available against", "podIdentifier", podIdentifier) + return true + } + } return false } diff --git a/pkg/clihelper/show.go b/pkg/clihelper/show.go index 3216565..37e59e9 100644 --- a/pkg/clihelper/show.go +++ b/pkg/clihelper/show.go @@ -127,6 +127,7 @@ func MapWalk(mapID int) error { fmt.Println("Protocol - ", utils.GetProtocol(int(iterValue[i].Protocol))) fmt.Println("StartPort - ", iterValue[i].StartPort) fmt.Println("Endport - ", iterValue[i].EndPort) + fmt.Println("Allow - ", iterValue[i].Allow) fmt.Println("-------------------") } fmt.Println("*******************************") @@ -238,6 +239,7 @@ func MapWalkv6(mapID int) error { fmt.Println("Protocol - ", utils.GetProtocol(int(iterValue[i].Protocol))) fmt.Println("StartPort - ", iterValue[i].StartPort) fmt.Println("Endport - ", iterValue[i].EndPort) + fmt.Println("Allow - ", iterValue[i].Allow) fmt.Println("-------------------") } fmt.Println("*******************************") diff --git a/pkg/ebpf/bpf_client.go b/pkg/ebpf/bpf_client.go index de8a557..2006528 100644 --- a/pkg/ebpf/bpf_client.go +++ b/pkg/ebpf/bpf_client.go @@ -106,6 +106,7 @@ type EbpfFirewallRules struct { IPCidr v1alpha1.NetworkAddress Except []v1alpha1.NetworkAddress L4Info []v1alpha1.Port + L4Deny []v1alpha1.Port } func NewBpfClient(policyEndpointeBPFContext *sync.Map, nodeIP string, enablePolicyEventLogs, enableCloudWatchLogs bool, @@ -744,7 +745,6 @@ func (l *bpfClient) updateEbpfMap(mapToUpdate goebpfmaps.BpfMap, firewallRules [ func sortFirewallRulesByPrefixLength(rules []EbpfFirewallRules, prefixLenStr string) { sort.Slice(rules, func(i, j int) bool { - prefixSplit := strings.Split(prefixLenStr, "/") prefixLen, _ := strconv.Atoi(prefixSplit[1]) prefixLenIp1 := prefixLen @@ -815,7 +815,7 @@ func (l *bpfClient) computeMapEntriesFromEndpointRules(firewallRules []EbpfFirew //Traffic from the local node should always be allowed. Add NodeIP by default to map entries. _, mapKey, _ := net.ParseCIDR(l.nodeIP + l.hostMask) key := utils.ComputeTrieKey(*mapKey, l.enableIPv6) - value := utils.ComputeTrieValue([]v1alpha1.Port{}, l.logger, true, false) + value := utils.ComputeTrieValue([]v1alpha1.Port{}, []v1alpha1.Port{}, l.logger, true, false) firewallMap[string(key)] = value //Sort the rules @@ -827,10 +827,9 @@ func (l *bpfClient) computeMapEntriesFromEndpointRules(firewallRules []EbpfFirew //Add the Catch All IP entry _, mapKey, _ := net.ParseCIDR("0.0.0.0/0") key := utils.ComputeTrieKey(*mapKey, l.enableIPv6) - value := utils.ComputeTrieValue(catchAllIPPorts, l.logger, allowAll, false) + value := utils.ComputeTrieValue(catchAllIPPorts, []v1alpha1.Port{}, l.logger, allowAll, false) firewallMap[string(key)] = value } - for _, firewallRule := range firewallRules { var cidrL4Info []v1alpha1.Port @@ -842,8 +841,8 @@ func (l *bpfClient) computeMapEntriesFromEndpointRules(firewallRules []EbpfFirew continue } - if !utils.IsCatchAllIPEntry(string(firewallRule.IPCidr)) { - if len(firewallRule.L4Info) == 0 { + if !utils.IsCatchAllIPEntry(string(firewallRule.IPCidr)) && len(firewallRule.Except) == 0 { + if len(firewallRule.L4Info) == 0 && len(firewallRule.L4Deny) == 0 { l.logger.Info("No L4 specified. Add Catch all entry: ", "CIDR: ", firewallRule.IPCidr) l.addCatchAllL4Entry(&firewallRule) l.logger.Info("Total L4 entries ", "count: ", len(firewallRule.L4Info)) @@ -888,7 +887,7 @@ func (l *bpfClient) computeMapEntriesFromEndpointRules(firewallRules []EbpfFirew firewallRule.L4Info = mergedL4Info } - firewallValue := utils.ComputeTrieValue(firewallRule.L4Info, l.logger, allowAll, false) + firewallValue := utils.ComputeTrieValue(firewallRule.L4Info, firewallRule.L4Deny, l.logger, allowAll, false) firewallMap[string(firewallKey)] = firewallValue } if firewallRule.Except != nil { @@ -900,7 +899,7 @@ func (l *bpfClient) computeMapEntriesFromEndpointRules(firewallRules []EbpfFirew mergedL4Info := mergeDuplicateL4Info(firewallRule.L4Info) firewallRule.L4Info = mergedL4Info } - value := utils.ComputeTrieValue(firewallRule.L4Info, l.logger, false, true) + value := utils.ComputeTrieValue(firewallRule.L4Info, []v1alpha1.Port{}, l.logger, false, true) firewallMap[string(key)] = value } } diff --git a/pkg/ebpf/c/tc.v4egress.bpf.c b/pkg/ebpf/c/tc.v4egress.bpf.c index 0fce1a2..f87971b 100644 --- a/pkg/ebpf/c/tc.v4egress.bpf.c +++ b/pkg/ebpf/c/tc.v4egress.bpf.c @@ -39,6 +39,7 @@ struct lpm_trie_val { __u32 protocol; __u32 start_port; __u32 end_port; + __u32 allow; }; struct conntrack_key { diff --git a/pkg/ebpf/c/tc.v4ingress.bpf.c b/pkg/ebpf/c/tc.v4ingress.bpf.c index a6d8312..e722ba0 100644 --- a/pkg/ebpf/c/tc.v4ingress.bpf.c +++ b/pkg/ebpf/c/tc.v4ingress.bpf.c @@ -39,6 +39,7 @@ struct lpm_trie_val { __u32 protocol; __u32 start_port; __u32 end_port; + __u32 allow; }; struct conntrack_key { diff --git a/pkg/ebpf/c/tc.v6egress.bpf.c b/pkg/ebpf/c/tc.v6egress.bpf.c index c8550fd..a82a8e8 100644 --- a/pkg/ebpf/c/tc.v6egress.bpf.c +++ b/pkg/ebpf/c/tc.v6egress.bpf.c @@ -39,6 +39,7 @@ struct lpm_trie_val { __u32 protocol; __u32 start_port; __u32 end_port; + __u32 allow; }; diff --git a/pkg/ebpf/c/tc.v6ingress.bpf.c b/pkg/ebpf/c/tc.v6ingress.bpf.c index 013ea54..9c4cb47 100644 --- a/pkg/ebpf/c/tc.v6ingress.bpf.c +++ b/pkg/ebpf/c/tc.v6ingress.bpf.c @@ -42,6 +42,7 @@ struct lpm_trie_val { __u32 protocol; __u32 start_port; __u32 end_port; + __u32 allow; }; struct conntrack_key { diff --git a/pkg/ebpf/conntrack/conntrack_client.go b/pkg/ebpf/conntrack/conntrack_client.go index d6398c4..c487956 100644 --- a/pkg/ebpf/conntrack/conntrack_client.go +++ b/pkg/ebpf/conntrack/conntrack_client.go @@ -167,7 +167,7 @@ func (c *conntrackClient) CleanupConntrackMap() { } } // Check if the local cache and kernel cache is in sync - for localConntrackEntry, _ := range c.localConntrackV4Cache { + for localConntrackEntry := range c.localConntrackV4Cache { newKey := utils.ConntrackKey{} newKey.Source_ip = utils.ConvIPv4ToInt(utils.ConvIntToIPv4(localConntrackEntry.Source_ip)) newKey.Source_port = localConntrackEntry.Source_port @@ -189,7 +189,6 @@ func (c *conntrackClient) CleanupConntrackMap() { c.logger.Info("Done cleanup of conntrack map") c.hydratelocalConntrack = true } - return } func (c *conntrackClient) Cleanupv6ConntrackMap() { @@ -324,7 +323,7 @@ func (c *conntrackClient) Cleanupv6ConntrackMap() { } // Check if the local cache and kernel cache is in sync - for localConntrackEntry, _ := range c.localConntrackV6Cache { + for localConntrackEntry := range c.localConntrackV6Cache { _, ok := kernelConntrackV6Cache[localConntrackEntry] if !ok { // Delete the entry in local cache since kernel entry is still missing so expired case @@ -341,7 +340,6 @@ func (c *conntrackClient) Cleanupv6ConntrackMap() { c.logger.Info("Done cleanup of conntrack map") c.hydratelocalConntrack = true } - return } func (c *conntrackClient) printByteArray(byteArray []byte) { diff --git a/pkg/utils/mergerules/mergerules.go b/pkg/utils/mergerules/mergerules.go new file mode 100644 index 0000000..ec5e367 --- /dev/null +++ b/pkg/utils/mergerules/mergerules.go @@ -0,0 +1,269 @@ +package mergerules + +import ( + "fmt" + "math" + "strconv" + "strings" + + policyk8sawsv1 "github.com/aws/aws-network-policy-agent/api/v1alpha1" + "github.com/go-logr/logr" + v1 "k8s.io/api/core/v1" +) + +// NOTE: First parameter will always be the higher priority. ports/action is higher prioirty than ports2/action2. + +func MergePorts(ports, ports2 policyk8sawsv1.Port, action, action2 string, logger logr.Logger) []string { + logger.Info("Merging ports") + if string(*ports.Protocol) == "ALL" && string(*ports2.Protocol) == "ALL" { + return mergeGlobalIPs(ports, ports2, action, action2) + } + if string(*ports.Protocol) == "ALL" { + return mergeGlobalIPPorts(ports, ports2, action, action2) + } + if string(*ports2.Protocol) == "ALL" { + return mergeGlobalPortsIP(ports, ports2, action, action2) + } + return mergeGlobalPorts(ports, ports2, action, action2, logger) +} + +func mergeGlobalIPs(ports, ports2 policyk8sawsv1.Port, action, action2 string) []string { + if action == action2 || action == "Allow" { + return []string{fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, 0, 0)} + } + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, 0, 0)} +} + +func mergeGlobalIPPorts(ports, ports2 policyk8sawsv1.Port, action, action2 string) []string { + if ports2.EndPort == nil { + zero := int32(0) + ports2.EndPort = &zero + } + if action == action2 { + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, 0, 0)} + } + if action == "Allow" { + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, 0, 0), + fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port, *ports2.EndPort)} + } + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, 0, 0)} +} + +func mergeGlobalPortsIP(ports, ports2 policyk8sawsv1.Port, action, action2 string) []string { + zero := int32(0) + if ports.EndPort == nil { + ports.EndPort = &zero + } + if action == action2 || action == "Allow" { + return []string{fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, 0, 0)} + } + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort), + fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, 0, 0)} +} + +func mergeGlobalPorts(ports, ports2 policyk8sawsv1.Port, action, action2 string, logger logr.Logger) []string { + portRange := isRange(ports) + portRange2 := isRange(ports2) + if portRange && portRange2 { + if action == action2 { + startPort := math.Min(float64(*ports.Port), float64(*ports2.Port)) + endPort := math.Max(float64(*ports.EndPort), float64(*ports2.EndPort)) + return []string{fmt.Sprintf("%s-%s-%f-%f", action, *ports.Protocol, startPort, endPort)} + } else if action == "Allow" { + if *ports2.Port <= *ports.Port && *ports2.EndPort >= *ports.EndPort { + return []string{fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port, *ports2.EndPort)} + } + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort), + fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port, *ports2.EndPort)} + } else { + if *ports.Port > *ports2.Port && *ports.EndPort < *ports2.EndPort { + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort), + fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port, *ports.EndPort-1), + fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports.EndPort+1, *ports2.EndPort)} + } else if *ports.Port <= *ports2.Port && *ports.EndPort >= *ports2.EndPort { + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort)} + } else if *ports.Port <= *ports2.Port && *ports.EndPort < *ports2.EndPort { + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort), + fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports.EndPort+1, *ports2.EndPort)} + } + // Case: *ports.Port >= *ports2.Port && *ports.EndPort > *ports2.EndPort + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort), + fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port, *ports.Port-1)} + } + } else if portRange { + if action == action2 { + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort)} + } else if action == "Allow" { + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort), + fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port, 0)} + } + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort)} + } else if portRange2 { + if action == action2 { + // return the portrange portrange2 + return []string{fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port, 0)} + } else if action == "Allow" { + return []string{fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port, 0)} + } else if action == "Pass" { + if ports.Port == ports2.Port { + return []string{fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port+1, *ports2.EndPort)} + } else if ports.Port == ports2.EndPort { + return []string{fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port, *ports2.EndPort-1)} + } + return []string{fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port, *ports.Port-1), + fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports.Port+1, *ports2.EndPort)} + } else { + return []string{fmt.Sprintf("%s-%s-%d-%d", action2, *ports.Protocol, *ports.Port, 0), + fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port, *ports2.EndPort)} + } + } + // Case where neither are port ranges + if action == action2 { + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, 0)} + } else if action == "Allow" { + return []string{fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port, 0)} + } + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, 0)} +} + +func MergeGlobalLocalPorts(global []string, local []string) []string { + for _, loc := range local { + localSplit := strings.Split(loc, "-") + localPortInt, _ := strconv.Atoi(localSplit[2]) + localPrt := int32(localPortInt) + localEndportInt, _ := strconv.Atoi(localSplit[3]) + localEndport := int32(localEndportInt) + localPort := policyk8sawsv1.Port{ + Protocol: (*v1.Protocol)(&localSplit[1]), + Port: &localPrt, + EndPort: &localEndport, + } + for _, glo := range global { + globalSplit := strings.Split(glo, "-") + globalPortInt, _ := strconv.Atoi(globalSplit[2]) + globalPrt := int32(globalPortInt) + globalEndportInt, _ := strconv.Atoi(globalSplit[3]) + globalEndport := int32(globalEndportInt) + globalPort := policyk8sawsv1.Port{ + Protocol: (*v1.Protocol)(&globalSplit[1]), + Port: &globalPrt, + EndPort: &globalEndport, + } + if localSplit[1] == "ALL" && globalSplit[1] == "ALL" { + return mergeGlobalIPLocalIPs(globalPort, localPort, globalSplit[0], localSplit[0]) + } else if localSplit[1] == "ALL" { + return mergeGlobalPortsLocalIPs(globalPort, localPort, globalSplit[0], localSplit[0]) + } else if globalSplit[1] == "ALL" { + return mergeGlobalIPLocalPorts(globalPort, localPort, globalSplit[0], localSplit[0]) + } else { + return mergeGlobalPortsLocalPorts(globalPort, localPort, globalSplit[0], localSplit[0]) + } + } + } + return nil +} + +func mergeGlobalIPLocalIPs(ports, ports2 policyk8sawsv1.Port, action, action2 string) []string { + if action == action2 { + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort)} + } + if action == "Allow" { + if action2 == "Deny" { + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort)} + } + } else if action == "Deny" { + if action2 == "Allow" { + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort)} + } + } + return nil +} + +func mergeGlobalIPLocalPorts(ports, ports2 policyk8sawsv1.Port, action, action2 string) []string { + if action == action2 { + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort)} + } + if action == "Allow" { + if action2 == "Deny" { + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort)} + } + } else if action == "Deny" { + if action2 == "Allow" { + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort)} + } + } + return nil +} + +func mergeGlobalPortsLocalIPs(ports, ports2 policyk8sawsv1.Port, action, action2 string) []string { + zero := int32(0) + if ports2.EndPort == nil { + ports2.EndPort = &zero + } + if ports.EndPort == nil { + ports.EndPort = &zero + } + if action == action2 { + return []string{fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port, *ports2.EndPort)} + } + if action == "Allow" { + if action2 == "Deny" { + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort), + fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port, *ports2.EndPort)} + } + } else if action == "Deny" { + if action2 == "Allow" { + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort), + fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port, *ports2.EndPort)} + } + } + return nil +} + +func mergeGlobalPortsLocalPorts(ports, ports2 policyk8sawsv1.Port, action, action2 string) []string { + portRange := isRange(ports) + portRange2 := isRange(ports2) + if portRange && portRange2 { + //TODO + if action == action2 { + + } + } else if portRange { + if action == action2 { + // return the portrange portrange + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort)} + } else if action == "Allow" { + // append deny port + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort)} + } else { + if action2 == "Allow" { + //return deny port range + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, *ports.EndPort)} + } + } + } else if portRange2 { + if action == action2 { + // return the portrange portrange2 + return []string{fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port, 0)} + } else if action == "Deny" { + if action2 == "Allow" { + // append deny + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, 0), + fmt.Sprintf("%s-%s-%d-%d", action2, *ports2.Protocol, *ports2.Port, *ports2.EndPort)} + } + } + } else { + if action == action2 { + //do nothing + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, 0)} + } else if action == "Deny" { + // Do nothing + return []string{fmt.Sprintf("%s-%s-%d-%d", action, *ports.Protocol, *ports.Port, 0)} + } + } + return nil +} + +func isRange(ports policyk8sawsv1.Port) bool { + return ports.EndPort != nil +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 9fb2cea..eb63026 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -23,7 +23,7 @@ var ( ANY_IP_PROTOCOL = 254 TRIE_KEY_LENGTH = 8 TRIE_V6_KEY_LENGTH = 20 - TRIE_VALUE_LENGTH = 288 + TRIE_VALUE_LENGTH = 384 BPF_PROGRAMS_PIN_PATH_DIRECTORY = "/sys/fs/bpf/globals/aws/programs/" BPF_MAPS_PIN_PATH_DIRECTORY = "/sys/fs/bpf/globals/aws/maps/" TC_INGRESS_PROG = "handle_ingress" @@ -142,7 +142,7 @@ func ComputeTrieKey(n net.IPNet, isIPv6Enabled bool) []byte { return key } -func ComputeTrieValue(l4Info []v1alpha1.Port, log logr.Logger, allowAll, denyAll bool) []byte { +func ComputeTrieValue(l4Info []v1alpha1.Port, denyL4Info []v1alpha1.Port, log logr.Logger, allowAll, denyAll bool) []byte { var startPort, endPort, protocol int value := make([]byte, TRIE_VALUE_LENGTH) @@ -160,7 +160,39 @@ func ComputeTrieValue(l4Info []v1alpha1.Port, log logr.Logger, allowAll, denyAll startOffset += 4 binary.LittleEndian.PutUint32(value[startOffset:startOffset+4], uint32(endPort)) startOffset += 4 - log.Info("L4 values: ", "protocol: ", protocol, "startPort: ", startPort, "endPort: ", endPort) + val := 1 + if denyAll { + val = 0 + } + binary.LittleEndian.PutUint32(value[startOffset:startOffset+4], uint32(val)) + startOffset += 4 + log.Info("L4 all values: ", "protocol: ", protocol, "startPort: ", startPort, "endPort: ", endPort) + } + + for _, denyL4 := range denyL4Info { + if startOffset >= TRIE_VALUE_LENGTH { + return value + } + endPort = 0 + startPort = 0 + + protocol = deriveProtocolValue(denyL4, allowAll, denyAll) + if denyL4.Port != nil { + startPort = int(*denyL4.Port) + } + + if denyL4.EndPort != nil { + endPort = int(*denyL4.EndPort) + } + log.Info("L4 deny values: ", "protocol: ", protocol, "startPort: ", startPort, "endPort: ", endPort) + binary.LittleEndian.PutUint32(value[startOffset:startOffset+4], uint32(protocol)) + startOffset += 4 + binary.LittleEndian.PutUint32(value[startOffset:startOffset+4], uint32(startPort)) + startOffset += 4 + binary.LittleEndian.PutUint32(value[startOffset:startOffset+4], uint32(endPort)) + startOffset += 4 + binary.LittleEndian.PutUint32(value[startOffset:startOffset+4], uint32(0)) + startOffset += 4 } for _, l4Entry := range l4Info { @@ -186,6 +218,8 @@ func ComputeTrieValue(l4Info []v1alpha1.Port, log logr.Logger, allowAll, denyAll startOffset += 4 binary.LittleEndian.PutUint32(value[startOffset:startOffset+4], uint32(endPort)) startOffset += 4 + binary.LittleEndian.PutUint32(value[startOffset:startOffset+4], uint32(1)) + startOffset += 4 } return value @@ -371,6 +405,7 @@ type BPFTrieVal struct { Protocol uint32 StartPort uint32 EndPort uint32 + Allow uint32 } func ConvTrieV6ToByte(key BPFTrieKeyV6) []byte { diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go index 24b2879..353ae89 100644 --- a/pkg/utils/utils_test.go +++ b/pkg/utils/utils_test.go @@ -330,7 +330,7 @@ func TestComputeTrieValue(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := ComputeTrieValue(tt.args.Ports, test_utilsLogger, tt.args.allowAll, tt.args.denyAll) + got := ComputeTrieValue(tt.args.Ports, nil, test_utilsLogger, tt.args.allowAll, tt.args.denyAll) assert.Equal(t, tt.want, got) }) }