Skip to content
22 changes: 18 additions & 4 deletions pkg/ebpf/c/tc.v4egress.bpf.c
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,24 @@ static inline int evaluateByLookUp(struct keystruct trie_key, struct conntrack_k
bpf_ringbuf_output(&policy_events, &evt, sizeof(evt), 0);
return BPF_DROP;
}

if ((trie_val->protocol == ANY_IP_PROTOCOL) || (trie_val->protocol == ip->protocol &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port)))) {

// 1. ANY_IP_PROTOCOL:
// - If the rule specifies ANY_IP_PROTOCOL (i.e., applies to all L4 protocols),
// - Then match if:
// - start_port is ANY_PORT → rule applies to all ports
// - OR l4_dst_port is exactly the start_port
// - OR l4_dst_port falls within (start_port, end_port] range
//
// 2. Specific Protocol Match:
// - If trie_val->protocol matches the packet's IP protocol (e.g., TCP or UDP),
// - Then apply the same port match logic as above.

if ((trie_val->protocol == ANY_IP_PROTOCOL &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port))) ||
(trie_val->protocol == ip->protocol &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port)))) {
//Inject in to conntrack map
struct conntrack_value new_flow_val = {};
if (pst->state == DEFAULT_ALLOW) {
Expand Down
19 changes: 16 additions & 3 deletions pkg/ebpf/c/tc.v4ingress.bpf.c
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,22 @@ static inline int evaluateByLookUp(struct keystruct trie_key, struct conntrack_k
return BPF_DROP;
}

if ((trie_val->protocol == ANY_IP_PROTOCOL) || (trie_val->protocol == ip->protocol &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port)))) {
// 1. ANY_IP_PROTOCOL:
// - If the rule specifies ANY_IP_PROTOCOL (i.e., applies to all L4 protocols),
// - Then match if:
// - start_port is ANY_PORT → rule applies to all ports
// - OR l4_dst_port is exactly the start_port
// - OR l4_dst_port falls within (start_port, end_port] range
//
// 2. Specific Protocol Match:
// - If trie_val->protocol matches the packet's IP protocol (e.g., TCP or UDP),
// - Then apply the same port match logic as above.
if ((trie_val->protocol == ANY_IP_PROTOCOL &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port))) ||
(trie_val->protocol == ip->protocol &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port)))) {
//Inject in to conntrack map
struct conntrack_value new_flow_val = {};
if (pst->state == DEFAULT_ALLOW) {
Expand Down
19 changes: 16 additions & 3 deletions pkg/ebpf/c/tc.v6egress.bpf.c
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,22 @@ static inline int evaluateByLookUp(struct keystruct trie_key, struct conntrack_k
return BPF_DROP;
}

if ((trie_val->protocol == ANY_IP_PROTOCOL) || (trie_val->protocol == ip->nexthdr &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port)))) {
// 1. ANY_IP_PROTOCOL:
// - If the rule specifies ANY_IP_PROTOCOL (i.e., applies to all L4 protocols),
// - Then match if:
// - start_port is ANY_PORT → rule applies to all ports
// - OR l4_dst_port is exactly the start_port
// - OR l4_dst_port falls within (start_port, end_port] range
//
// 2. Specific Protocol Match:
// - If trie_val->protocol matches the packet's IP protocol (e.g., TCP or UDP),
// - Then apply the same port match logic as above.
if ((trie_val->protocol == ANY_IP_PROTOCOL &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port))) ||
(trie_val->protocol == ip->nexthdr &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port)))) {
//Inject in to conntrack map
struct conntrack_value new_flow_val = {};
if (pst->state == DEFAULT_ALLOW) {
Expand Down
19 changes: 16 additions & 3 deletions pkg/ebpf/c/tc.v6ingress.bpf.c
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,22 @@ static inline int evaluateByLookUp(struct keystruct trie_key, struct conntrack_k
return BPF_DROP;
}

if ((trie_val->protocol == ANY_IP_PROTOCOL) || (trie_val->protocol == ip->nexthdr &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port)))) {
// 1. ANY_IP_PROTOCOL:
// - If the rule specifies ANY_IP_PROTOCOL (i.e., applies to all L4 protocols),
// - Then match if:
// - start_port is ANY_PORT → rule applies to all ports
// - OR l4_dst_port is exactly the start_port
// - OR l4_dst_port falls within (start_port, end_port] range
//
// 2. Specific Protocol Match:
// - If trie_val->protocol matches the packet's IP protocol (e.g., TCP or UDP),
// - Then apply the same port match logic as above.
if ((trie_val->protocol == ANY_IP_PROTOCOL &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port))) ||
(trie_val->protocol == ip->nexthdr &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port)))) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question here, should this be inclusive as well?

//Inject in to conntrack map
struct conntrack_value new_flow_val = {};
if (pst->state == DEFAULT_ALLOW) {
Expand Down
196 changes: 98 additions & 98 deletions pkg/fwruleprocessor/fw_rule_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ func log() logger.Logger {
return logger.Get()
}

var CATCH_ALL_PROTOCOL corev1.Protocol = "ANY_IP_PROTOCOL"
var (
CATCH_ALL_PROTOCOL corev1.Protocol = "ANY_IP_PROTOCOL"
DENY_ALL_PROTOCOL corev1.Protocol = "RESERVED_IP_PROTOCOL_NUMBER"
)

type EbpfFirewallRules struct {
IPCidr v1alpha1.NetworkAddress
Expand All @@ -42,13 +45,33 @@ func NewFirewallRuleProcessor(nodeIP string, hostMask string, enableIPv6 bool) *
return fwrp
}

// computeMapEntriesFromEndpointRules generates a map of IP prefix keys to encoded L4 rules that will
// be used to update ebpf maps
//
// How it works:
// 1. A default allow-all entry is added for the node IP to ensure local node traffic is always permitted.
// 2. The list of firewall rules is sorted by prefix length in ascending order. This is crucial for
// handling overlapping CIDRs because longest-prefix matches win in LPM TRIE.
// 3. Each rule is normalized:
// - Ensures all entries contain a /mask (using hostMask if omitted).
// - Filters out IPv4 rules in IPv6 clusters and vice versa.
// - For rules without any L4 port info, a catch-all rule is inserted to match all traffic.
// 4. For any rule whose CIDR is more specific (e.g., /24) and falls within a broader one (e.g., /16),
// we check existing rules in the map to see if it matches a prior CIDR. If it does and is not part of
// that CIDR's "except" list, we inherit the broader rule's ports into the current one.
// This ensures that the specific CIDR behaves consistently with the broader scope's intent.
// 5. We then handle all `except` CIDRs at the end.
// - If not already in the map, each `except` CIDR is added explicitly with a deny-all L4 entry.
// - This ensures specific excluded IP ranges override broader allow rules correctly in the LPM match tree.
// 6. Finally, all CIDRs are encoded into trie keys and their corresponding merged/derived L4 info is encoded
// into the values, forming the output map.

func (f *FirewallRuleProcessor) ComputeMapEntriesFromEndpointRules(firewallRules []EbpfFirewallRules) (map[string][]byte, error) {

firewallMap := make(map[string][]byte)
ipCIDRs := make(map[string][]v1alpha1.Port)
nonHostCIDRs := make(map[string][]v1alpha1.Port)
isCatchAllIPEntryPresent, allowAll := false, false
var catchAllIPPorts []v1alpha1.Port
cidrsMap := make(map[string]EbpfFirewallRules)
exceptCidrs := make(map[string]struct{})
nonHostCIDRs := make(map[string]EbpfFirewallRules)

//Traffic from the local node should always be allowed. Add NodeIP by default to map entries.
_, mapKey, _ := net.ParseCIDR(f.nodeIP + f.hostMask)
Expand All @@ -59,17 +82,12 @@ func (f *FirewallRuleProcessor) ComputeMapEntriesFromEndpointRules(firewallRules
//Sort the rules
sortFirewallRulesByPrefixLength(firewallRules, f.hostMask)

//Check and aggregate L4 Port Info for Catch All Entries.
catchAllIPPorts, isCatchAllIPEntryPresent, allowAll = f.checkAndDeriveCatchAllIPPorts(firewallRules)
if isCatchAllIPEntryPresent {
//Add the Catch All IP entry
_, mapKey, _ := net.ParseCIDR("0.0.0.0/0")
key := utils.ComputeTrieKey(*mapKey, f.enableIPv6)
value := utils.ComputeTrieValue(catchAllIPPorts, allowAll, false)
firewallMap[string(key)] = value
}

for _, firewallRule := range firewallRules {
// Keep track of except CIDRs to handle later
for _, exceptCidr := range firewallRule.Except {
exceptCidrs[string(exceptCidr)] = struct{}{}
}

var cidrL4Info []v1alpha1.Port

if !strings.Contains(string(firewallRule.IPCidr), "/") {
Expand All @@ -90,94 +108,57 @@ func (f *FirewallRuleProcessor) ComputeMapEntriesFromEndpointRules(firewallRules
continue
}

if !utils.IsCatchAllIPEntry(string(firewallRule.IPCidr)) {
if len(firewallRule.L4Info) == 0 {
addCatchAllL4Entry(&firewallRule)
}
if utils.IsNonHostCIDR(string(firewallRule.IPCidr)) {
existingL4Info, ok := nonHostCIDRs[string(firewallRule.IPCidr)]
if ok {
firewallRule.L4Info = append(firewallRule.L4Info, existingL4Info...)
} else {
// Check if the /m entry is part of any /n CIDRs that we've encountered so far
// If found, we need to include the port and protocol combination against the current entry as well since
// we use LPM TRIE map and the /m will always win out.
cidrL4Info = checkAndDeriveL4InfoFromAnyMatchingCIDRs(string(firewallRule.IPCidr), nonHostCIDRs)
if len(cidrL4Info) > 0 {
firewallRule.L4Info = append(firewallRule.L4Info, cidrL4Info...)
}
}
nonHostCIDRs[string(firewallRule.IPCidr)] = firewallRule.L4Info
} else {
if existingL4Info, ok := ipCIDRs[string(firewallRule.IPCidr)]; ok {
firewallRule.L4Info = append(firewallRule.L4Info, existingL4Info...)
}
// Check if the /32 entry is part of any non host CIDRs that we've encountered so far
// If found, we need to include the port and protocol combination against the current entry as well since
// we use LPM TRIE map and the /32 will always win out.
cidrL4Info = checkAndDeriveL4InfoFromAnyMatchingCIDRs(string(firewallRule.IPCidr), nonHostCIDRs)
if len(cidrL4Info) > 0 {
firewallRule.L4Info = append(firewallRule.L4Info, cidrL4Info...)
}
ipCIDRs[string(firewallRule.IPCidr)] = firewallRule.L4Info
}
//Include port and protocol combination paired with catch all entries
firewallRule.L4Info = append(firewallRule.L4Info, catchAllIPPorts...)

log().Debugf("Updating Map with IP Key: %s", string(firewallRule.IPCidr))
_, firewallMapKey, _ := net.ParseCIDR(string(firewallRule.IPCidr))
// Key format: Prefix length (4 bytes) followed by 4/16byte IP address
firewallKey := utils.ComputeTrieKey(*firewallMapKey, f.enableIPv6)

if len(firewallRule.L4Info) != 0 {
mergedL4Info := mergeDuplicateL4Info(firewallRule.L4Info)
firewallRule.L4Info = mergedL4Info
// If no L4 specified add catch all entry
if len(firewallRule.L4Info) == 0 {
log().Debugf("No L4 specified. Add Catch all entry CIDR: %s", string(firewallRule.IPCidr))
addCatchAllL4Entry(&firewallRule)
}

if existingFirewallRuleInfo, ok := cidrsMap[string(firewallRule.IPCidr)]; ok {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Post merge, lets review if we will ever enter into this if block. Since the new check added to the function will be executed only for the new IPCidr key but not for duplicate in the main for loop..

firewallRule.L4Info = append(firewallRule.L4Info, existingFirewallRuleInfo.L4Info...)
firewallRule.Except = append(firewallRule.Except, existingFirewallRuleInfo.Except...)
} else {
// Check if the /m entry is part of any /n CIDRs that we've encountered so far
// If found, we need to include the port and protocol combination against the current entry as well since
// we use LPM TRIE map and the /m will always win out.
cidrL4Info = checkAndDeriveL4InfoFromAnyMatchingCIDRs(string(firewallRule.IPCidr), nonHostCIDRs)
if len(cidrL4Info) > 0 {
firewallRule.L4Info = append(firewallRule.L4Info, cidrL4Info...)
}
firewallValue := utils.ComputeTrieValue(firewallRule.L4Info, allowAll, false)
firewallMap[string(firewallKey)] = firewallValue
}
if firewallRule.Except != nil {
for _, exceptCIDR := range firewallRule.Except {
_, mapKey, _ := net.ParseCIDR(string(exceptCIDR))
key := utils.ComputeTrieKey(*mapKey, f.enableIPv6)
log().Debugf("Parsed Except CIDR IP Key: %s", mapKey.String())
if len(firewallRule.L4Info) != 0 {
mergedL4Info := mergeDuplicateL4Info(firewallRule.L4Info)
firewallRule.L4Info = mergedL4Info
}
value := utils.ComputeTrieValue(firewallRule.L4Info, false, true)
firewallMap[string(key)] = value
cidrsMap[string(firewallRule.IPCidr)] = firewallRule
if utils.IsNonHostCIDR(string(firewallRule.IPCidr)) {
nonHostCIDRs[string(firewallRule.IPCidr)] = firewallRule
}
}

// Go through except CIDRs and append DENY all rule to the L4 info
for exceptCidr := range exceptCidrs {
if _, ok := cidrsMap[exceptCidr]; !ok {
exceptFirewall := EbpfFirewallRules{
IPCidr: v1alpha1.NetworkAddress(exceptCidr),
Except: []v1alpha1.NetworkAddress{},
L4Info: []v1alpha1.Port{},
}
addDenyAllL4Entry(&exceptFirewall)
cidrsMap[exceptCidr] = exceptFirewall
}
log().Debugf("Parsed Except CIDR: %s", exceptCidr)
}

return firewallMap, nil
}
for key, value := range cidrsMap {
log().Infof("Updating Map with IP Key: %s", string(key))
_, firewallMapKey, _ := net.ParseCIDR(string(key))
// Key format: Prefix length (4 bytes) followed by 4/16byte IP address
firewallKey := utils.ComputeTrieKey(*firewallMapKey, f.enableIPv6)

func (f *FirewallRuleProcessor) checkAndDeriveCatchAllIPPorts(firewallRules []EbpfFirewallRules) ([]v1alpha1.Port, bool, bool) {
var catchAllL4Info []v1alpha1.Port
isCatchAllIPEntryPresent := false
allowAllPortAndProtocols := false
for _, firewallRule := range firewallRules {
if !strings.Contains(string(firewallRule.IPCidr), "/") {
firewallRule.IPCidr += v1alpha1.NetworkAddress(f.hostMask)
}
if !f.enableIPv6 && strings.Contains(string(firewallRule.IPCidr), "::") {
log().Debug("IPv6 catch all entry in IPv4 mode - skip ")
continue
}
if utils.IsCatchAllIPEntry(string(firewallRule.IPCidr)) {
catchAllL4Info = append(catchAllL4Info, firewallRule.L4Info...)
isCatchAllIPEntryPresent = true
if len(firewallRule.L4Info) == 0 {
//All ports and protocols
allowAllPortAndProtocols = true
}
if len(value.L4Info) != 0 {
value.L4Info = mergeDuplicateL4Info(value.L4Info)
}
firewallMap[string(firewallKey)] = utils.ComputeTrieValue(value.L4Info, false, false)
}
log().Debugf("Total L4 entry count for catch all entry: count: %d", len(catchAllL4Info))
return catchAllL4Info, isCatchAllIPEntryPresent, allowAllPortAndProtocols

return firewallMap, nil
}

// sorting Firewall Rules in Ascending Order of Prefix length
Expand Down Expand Up @@ -212,16 +193,35 @@ func addCatchAllL4Entry(firewallRule *EbpfFirewallRules) {
firewallRule.L4Info = append(firewallRule.L4Info, catchAllL4Entry)
}

func addDenyAllL4Entry(firewallRule *EbpfFirewallRules) {
denyAllL4Entry := v1alpha1.Port{
Protocol: &DENY_ALL_PROTOCOL,
}
firewallRule.L4Info = append(firewallRule.L4Info, denyAllL4Entry)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: combine both lines to one line?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will keep it 2 lines here for readability if that works

}

func checkAndDeriveL4InfoFromAnyMatchingCIDRs(firewallRule string,
nonHostCIDRs map[string][]v1alpha1.Port) []v1alpha1.Port {
cidrsMap map[string]EbpfFirewallRules) []v1alpha1.Port {
var matchingCIDRL4Info []v1alpha1.Port

_, ipToCheck, _ := net.ParseCIDR(firewallRule)
for nonHostCIDR, l4Info := range nonHostCIDRs {
_, cidrEntry, _ := net.ParseCIDR(nonHostCIDR)
for cidr, cidrFirewallInfo := range cidrsMap {
_, cidrEntry, _ := net.ParseCIDR(cidr)
if cidrEntry.Contains(ipToCheck.IP) {
log().Debugf("Found a CIDR match for IP: %s in CIDR %s ", firewallRule, nonHostCIDR)
matchingCIDRL4Info = append(matchingCIDRL4Info, l4Info...)
log().Debugf("Found CIDR match or IP: %s in CIDR: %s", firewallRule, cidr)
// If CIDR contains IP, check if it is part of any except block under CIDR. If yes, do not include cidrL4Info
foundInExcept := false
for _, except := range cidrFirewallInfo.Except {
_, exceptEntry, _ := net.ParseCIDR(string(except))
if exceptEntry.Contains(ipToCheck.IP) {
foundInExcept = true
log().Debugf("Found IP: %s in except block %s of CIDR %s. Skipping CIDR match", firewallRule, string(except), cidr)
break
}
}
if !foundInExcept {
matchingCIDRL4Info = append(matchingCIDRL4Info, cidrFirewallInfo.L4Info...)
}
}
}
return matchingCIDRL4Info
Expand Down
Loading
Loading