Skip to content

Handle overlaps in except and allow CIDRs #344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions pkg/ebpf/c/tc.v4egress.bpf.c
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,24 @@ static inline int evaluateByLookUp(struct keystruct trie_key, struct conntrack_k
bpf_ringbuf_output(&policy_events, &evt, sizeof(evt), 0);
return BPF_DROP;
}

if ((trie_val->protocol == ANY_IP_PROTOCOL) || (trie_val->protocol == ip->protocol &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port)))) {

// 1. ANY_IP_PROTOCOL:
// - If the rule specifies ANY_IP_PROTOCOL (i.e., applies to all L4 protocols),
// - Then match if:
// - start_port is ANY_PORT → rule applies to all ports
// - OR l4_dst_port is exactly the start_port
// - OR l4_dst_port falls within the inclusive [start_port, end_port] range
//
// 2. Specific Protocol Match:
// - If trie_val->protocol matches the packet's IP protocol (e.g., TCP or UDP),
// - Then apply the same port match logic as above.

if ((trie_val->protocol == ANY_IP_PROTOCOL &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are really complicate if check... Any chance we can take them out into a func? I think there is only one small different on ip-> for v4 and v6.

((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port))) ||
(trie_val->protocol == ip->protocol &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port)))) {
//Inject in to conntrack map
struct conntrack_value new_flow_val = {};
if (pst->state == DEFAULT_ALLOW) {
Expand Down
19 changes: 16 additions & 3 deletions pkg/ebpf/c/tc.v4ingress.bpf.c
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,22 @@ static inline int evaluateByLookUp(struct keystruct trie_key, struct conntrack_k
return BPF_DROP;
}

if ((trie_val->protocol == ANY_IP_PROTOCOL) || (trie_val->protocol == ip->protocol &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port)))) {
// 1. ANY_IP_PROTOCOL:
// - If the rule specifies ANY_IP_PROTOCOL (i.e., applies to all L4 protocols),
// - Then match if:
// - start_port is ANY_PORT → rule applies to all ports
// - OR l4_dst_port is exactly the start_port
// - OR l4_dst_port falls within the inclusive [start_port, end_port] range
//
// 2. Specific Protocol Match:
// - If trie_val->protocol matches the packet's IP protocol (e.g., TCP or UDP),
// - Then apply the same port match logic as above.
if ((trie_val->protocol == ANY_IP_PROTOCOL &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port))) ||
(trie_val->protocol == ip->protocol &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port)))) {
//Inject in to conntrack map
struct conntrack_value new_flow_val = {};
if (pst->state == DEFAULT_ALLOW) {
Expand Down
19 changes: 16 additions & 3 deletions pkg/ebpf/c/tc.v6egress.bpf.c
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,22 @@ static inline int evaluateByLookUp(struct keystruct trie_key, struct conntrack_k
return BPF_DROP;
}

if ((trie_val->protocol == ANY_IP_PROTOCOL) || (trie_val->protocol == ip->nexthdr &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port)))) {
// 1. ANY_IP_PROTOCOL:
// - If the rule specifies ANY_IP_PROTOCOL (i.e., applies to all L4 protocols),
// - Then match if:
// - start_port is ANY_PORT → rule applies to all ports
// - OR l4_dst_port is exactly the start_port
// - OR l4_dst_port falls within the inclusive [start_port, end_port] range
//
// 2. Specific Protocol Match:
// - If trie_val->protocol matches the packet's IP protocol (e.g., TCP or UDP),
// - Then apply the same port match logic as above.
if ((trie_val->protocol == ANY_IP_PROTOCOL &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port))) ||
(trie_val->protocol == ip->nexthdr &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port)))) {
//Inject in to conntrack map
struct conntrack_value new_flow_val = {};
if (pst->state == DEFAULT_ALLOW) {
Expand Down
19 changes: 16 additions & 3 deletions pkg/ebpf/c/tc.v6ingress.bpf.c
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,22 @@ static inline int evaluateByLookUp(struct keystruct trie_key, struct conntrack_k
return BPF_DROP;
}

if ((trie_val->protocol == ANY_IP_PROTOCOL) || (trie_val->protocol == ip->nexthdr &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port)))) {
// 1. ANY_IP_PROTOCOL:
// - If the rule specifies ANY_IP_PROTOCOL (i.e., applies to all L4 protocols),
// - Then match if:
// - start_port is ANY_PORT → rule applies to all ports
// - OR l4_dst_port is exactly the start_port
// - OR l4_dst_port falls within the inclusive [start_port, end_port] range
//
// 2. Specific Protocol Match:
// - If trie_val->protocol matches the packet's IP protocol (e.g., TCP or UDP),
// - Then apply the same port match logic as above.
if ((trie_val->protocol == ANY_IP_PROTOCOL &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port))) ||
(trie_val->protocol == ip->nexthdr &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port)))) {
//Inject in to conntrack map
struct conntrack_value new_flow_val = {};
if (pst->state == DEFAULT_ALLOW) {
Expand Down
193 changes: 97 additions & 96 deletions pkg/fwruleprocessor/fw_rule_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ func log() logger.Logger {
return logger.Get()
}

var CATCH_ALL_PROTOCOL corev1.Protocol = "ANY_IP_PROTOCOL"
var (
CATCH_ALL_PROTOCOL corev1.Protocol = "ANY_IP_PROTOCOL"
DENY_ALL_PROTOCOL corev1.Protocol = "RESERVED_IP_PROTOCOL_NUMBER"
)

type EbpfFirewallRules struct {
IPCidr v1alpha1.NetworkAddress
Expand All @@ -42,13 +45,31 @@ func NewFirewallRuleProcessor(nodeIP string, hostMask string, enableIPv6 bool) *
return fwrp
}

// computeMapEntriesFromEndpointRules generates a map of IP prefix keys to encoded L4 rules that will
// be used to update ebpf maps
//
// How it works:
// 1. A default allow-all entry is added for the node IP to ensure local node traffic is always permitted.
// 2. The list of firewall rules is sorted by prefix length in ascending order. This is crucial for
// handling overlapping CIDRs because longest-prefix matches win in LPM TRIE.
// 3. Each rule is normalized:
// - Ensures all entries contain a /mask (using hostMask if omitted).
// - Filters out IPv4 rules in IPv6 clusters and vice versa.
// - For rules without any L4 port info, a catch-all rule is inserted to match all traffic.
// 4. For any rule whose CIDR is more specific (e.g., /24) and falls within a broader one (e.g., /16),
// we check existing rules in the map to see if it matches a prior CIDR. If it does and is not part of
// that CIDR's "except" list, we inherit the broader rule's ports into the current one.
// This ensures that the specific CIDR behaves consistently with the broader scope's intent.
// 5. We then handle all `except` CIDRs at the end.
// - If not already in the map, each `except` CIDR is added explicitly with a deny-all L4 entry.
// - This ensures specific excluded IP ranges override broader allow rules correctly in the LPM match tree.
// 6. Finally, all CIDRs are encoded into trie keys and their corresponding merged/derived L4 info is encoded
// into the values, forming the output map.

func (f *FirewallRuleProcessor) ComputeMapEntriesFromEndpointRules(firewallRules []EbpfFirewallRules) (map[string][]byte, error) {

firewallMap := make(map[string][]byte)
ipCIDRs := make(map[string][]v1alpha1.Port)
nonHostCIDRs := make(map[string][]v1alpha1.Port)
isCatchAllIPEntryPresent, allowAll := false, false
var catchAllIPPorts []v1alpha1.Port
cidrsMap := make(map[string]EbpfFirewallRules)

//Traffic from the local node should always be allowed. Add NodeIP by default to map entries.
_, mapKey, _ := net.ParseCIDR(f.nodeIP + f.hostMask)
Expand All @@ -59,16 +80,6 @@ func (f *FirewallRuleProcessor) ComputeMapEntriesFromEndpointRules(firewallRules
//Sort the rules
sortFirewallRulesByPrefixLength(firewallRules, f.hostMask)

//Check and aggregate L4 Port Info for Catch All Entries.
catchAllIPPorts, isCatchAllIPEntryPresent, allowAll = f.checkAndDeriveCatchAllIPPorts(firewallRules)
if isCatchAllIPEntryPresent {
//Add the Catch All IP entry
_, mapKey, _ := net.ParseCIDR("0.0.0.0/0")
key := utils.ComputeTrieKey(*mapKey, f.enableIPv6)
value := utils.ComputeTrieValue(catchAllIPPorts, allowAll, false)
firewallMap[string(key)] = value
}

for _, firewallRule := range firewallRules {
var cidrL4Info []v1alpha1.Port

Expand All @@ -90,94 +101,62 @@ func (f *FirewallRuleProcessor) ComputeMapEntriesFromEndpointRules(firewallRules
continue
}

if !utils.IsCatchAllIPEntry(string(firewallRule.IPCidr)) {
if len(firewallRule.L4Info) == 0 {
addCatchAllL4Entry(&firewallRule)
}
if utils.IsNonHostCIDR(string(firewallRule.IPCidr)) {
existingL4Info, ok := nonHostCIDRs[string(firewallRule.IPCidr)]
if ok {
firewallRule.L4Info = append(firewallRule.L4Info, existingL4Info...)
} else {
// Check if the /m entry is part of any /n CIDRs that we've encountered so far
// If found, we need to include the port and protocol combination against the current entry as well since
// we use LPM TRIE map and the /m will always win out.
cidrL4Info = checkAndDeriveL4InfoFromAnyMatchingCIDRs(string(firewallRule.IPCidr), nonHostCIDRs)
if len(cidrL4Info) > 0 {
firewallRule.L4Info = append(firewallRule.L4Info, cidrL4Info...)
}
}
nonHostCIDRs[string(firewallRule.IPCidr)] = firewallRule.L4Info
} else {
if existingL4Info, ok := ipCIDRs[string(firewallRule.IPCidr)]; ok {
firewallRule.L4Info = append(firewallRule.L4Info, existingL4Info...)
}
// Check if the /32 entry is part of any non host CIDRs that we've encountered so far
// If found, we need to include the port and protocol combination against the current entry as well since
// we use LPM TRIE map and the /32 will always win out.
cidrL4Info = checkAndDeriveL4InfoFromAnyMatchingCIDRs(string(firewallRule.IPCidr), nonHostCIDRs)
if len(cidrL4Info) > 0 {
firewallRule.L4Info = append(firewallRule.L4Info, cidrL4Info...)
}
ipCIDRs[string(firewallRule.IPCidr)] = firewallRule.L4Info
}
//Include port and protocol combination paired with catch all entries
firewallRule.L4Info = append(firewallRule.L4Info, catchAllIPPorts...)

log().Infof("Updating Map with IP Key: %s", string(firewallRule.IPCidr))
_, firewallMapKey, _ := net.ParseCIDR(string(firewallRule.IPCidr))
// Key format: Prefix length (4 bytes) followed by 4/16byte IP address
firewallKey := utils.ComputeTrieKey(*firewallMapKey, f.enableIPv6)

if len(firewallRule.L4Info) != 0 {
mergedL4Info := mergeDuplicateL4Info(firewallRule.L4Info)
firewallRule.L4Info = mergedL4Info
// If no L4 specified add catch all entry
if len(firewallRule.L4Info) == 0 {
log().Debugf("No L4 specified. Add Catch all entry CIDR: %s", string(firewallRule.IPCidr))
addCatchAllL4Entry(&firewallRule)
log().Debugf("Total L4 entries count: %d", len(firewallRule.L4Info))
}

if existingFirewallRuleInfo, ok := cidrsMap[string(firewallRule.IPCidr)]; ok {
firewallRule.L4Info = append(firewallRule.L4Info, existingFirewallRuleInfo.L4Info...)
firewallRule.Except = append(firewallRule.Except, existingFirewallRuleInfo.Except...)
} else {
// Check if the /m entry is part of any /n CIDRs that we've encountered so far
// If found, we need to include the port and protocol combination against the current entry as well since
// we use LPM TRIE map and the /m will always win out.
cidrL4Info = checkAndDeriveL4InfoFromAnyMatchingCIDRs(string(firewallRule.IPCidr), cidrsMap)
if len(cidrL4Info) > 0 {
firewallRule.L4Info = append(firewallRule.L4Info, cidrL4Info...)
}
firewallValue := utils.ComputeTrieValue(firewallRule.L4Info, allowAll, false)
firewallMap[string(firewallKey)] = firewallValue
}
cidrsMap[string(firewallRule.IPCidr)] = firewallRule
}

// Go through except CIDRs and append DENY all rule to the L4 info
for _, firewallRule := range firewallRules {
if firewallRule.Except != nil {
for _, exceptCIDR := range firewallRule.Except {
_, mapKey, _ := net.ParseCIDR(string(exceptCIDR))
key := utils.ComputeTrieKey(*mapKey, f.enableIPv6)
log().Infof("Parsed Except CIDR IP Key: %s", mapKey.String())
if len(firewallRule.L4Info) != 0 {
mergedL4Info := mergeDuplicateL4Info(firewallRule.L4Info)
firewallRule.L4Info = mergedL4Info
for _, exceptCidr := range firewallRule.Except {
if _, ok := cidrsMap[string(exceptCidr)]; !ok {
exceptFirewall := EbpfFirewallRules{
IPCidr: exceptCidr,
Except: []v1alpha1.NetworkAddress{},
L4Info: []v1alpha1.Port{},
}
addDenyAllL4Entry(&exceptFirewall)
cidrsMap[string(exceptCidr)] = exceptFirewall
}
value := utils.ComputeTrieValue(firewallRule.L4Info, false, true)
firewallMap[string(key)] = value
log().Debugf("Parsed Except CIDR:", string(exceptCidr))
}
}
}

return firewallMap, nil
}
for key, value := range cidrsMap {
log().Infof("Updating Map with IP Key: %s", string(key))
_, firewallMapKey, _ := net.ParseCIDR(string(key))
// Key format: Prefix length (4 bytes) followed by 4/16byte IP address
firewallKey := utils.ComputeTrieKey(*firewallMapKey, f.enableIPv6)

if len(value.L4Info) != 0 {
mergedL4Info := mergeDuplicateL4Info(value.L4Info)
value.L4Info = mergedL4Info

func (f *FirewallRuleProcessor) checkAndDeriveCatchAllIPPorts(firewallRules []EbpfFirewallRules) ([]v1alpha1.Port, bool, bool) {
var catchAllL4Info []v1alpha1.Port
isCatchAllIPEntryPresent := false
allowAllPortAndProtocols := false
for _, firewallRule := range firewallRules {
if !strings.Contains(string(firewallRule.IPCidr), "/") {
firewallRule.IPCidr += v1alpha1.NetworkAddress(f.hostMask)
}
if !f.enableIPv6 && strings.Contains(string(firewallRule.IPCidr), "::") {
log().Debug("IPv6 catch all entry in IPv4 mode - skip ")
continue
}
if utils.IsCatchAllIPEntry(string(firewallRule.IPCidr)) {
catchAllL4Info = append(catchAllL4Info, firewallRule.L4Info...)
isCatchAllIPEntryPresent = true
if len(firewallRule.L4Info) == 0 {
//All ports and protocols
allowAllPortAndProtocols = true
}
}
firewallValue := utils.ComputeTrieValue(value.L4Info, false, false)
firewallMap[string(firewallKey)] = firewallValue
}
log().Debugf("Total L4 entry count for catch all entry: count: %d", len(catchAllL4Info))
return catchAllL4Info, isCatchAllIPEntryPresent, allowAllPortAndProtocols

return firewallMap, nil
}

// sorting Firewall Rules in Ascending Order of Prefix length
Expand Down Expand Up @@ -212,16 +191,38 @@ func addCatchAllL4Entry(firewallRule *EbpfFirewallRules) {
firewallRule.L4Info = append(firewallRule.L4Info, catchAllL4Entry)
}

func addDenyAllL4Entry(firewallRule *EbpfFirewallRules) {
denyAllL4Entry := v1alpha1.Port{
Protocol: &DENY_ALL_PROTOCOL,
}
firewallRule.L4Info = append(firewallRule.L4Info, denyAllL4Entry)
}

func checkAndDeriveL4InfoFromAnyMatchingCIDRs(firewallRule string,
nonHostCIDRs map[string][]v1alpha1.Port) []v1alpha1.Port {
cidrsMap map[string]EbpfFirewallRules) []v1alpha1.Port {
var matchingCIDRL4Info []v1alpha1.Port

_, ipToCheck, _ := net.ParseCIDR(firewallRule)
for nonHostCIDR, l4Info := range nonHostCIDRs {
_, cidrEntry, _ := net.ParseCIDR(nonHostCIDR)
for cidr, cidrFirewallInfo := range cidrsMap {
if !utils.IsNonHostCIDR(cidr) {
continue
}
_, cidrEntry, _ := net.ParseCIDR(cidr)
if cidrEntry.Contains(ipToCheck.IP) {
log().Debugf("Found a CIDR match for IP: %s in CIDR %s ", firewallRule, nonHostCIDR)
matchingCIDRL4Info = append(matchingCIDRL4Info, l4Info...)
log().Debugf("Found CIDR match or IP: %s in CIDR: %s", firewallRule, cidr)
// If CIDR contains IP, check if it is part of any except block under CIDR. If yes, do not include cidrL4Info
foundInExcept := false
for _, except := range cidrFirewallInfo.Except {
_, exceptEntry, _ := net.ParseCIDR(string(except))
if exceptEntry.Contains(ipToCheck.IP) {
foundInExcept = true
log().Debugf("Found IP: %s in except block %s of CIDR %s. Skipping CIDR match", firewallRule, string(except), cidr)
break
}
}
if !foundInExcept {
matchingCIDRL4Info = append(matchingCIDRL4Info, cidrFirewallInfo.L4Info...)
}
}
}
return matchingCIDRL4Info
Expand Down
Loading
Loading