Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle overlaps in except and allow CIDRs #344

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 73 additions & 99 deletions pkg/ebpf/bpf_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ var (
CONNTRACK_MAP_PIN_PATH = "/sys/fs/bpf/globals/aws/maps/global_aws_conntrack_map"
POLICY_EVENTS_MAP_PIN_PATH = "/sys/fs/bpf/globals/aws/maps/global_policy_events"
CATCH_ALL_PROTOCOL corev1.Protocol = "ANY_IP_PROTOCOL"
DENY_ALL_PROTOCOL corev1.Protocol = "RESERVED_IP_PROTOCOL_NUMBER"
POD_VETH_PREFIX = "eni"
)

Expand Down Expand Up @@ -856,10 +857,7 @@ func mergeDuplicateL4Info(ports []v1alpha1.Port) []v1alpha1.Port {
func (l *bpfClient) computeMapEntriesFromEndpointRules(firewallRules []EbpfFirewallRules) (map[string][]byte, error) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As an enhancement, could you document the behavior of this function as function doc string itself. We pack a lot inside the function here.


firewallMap := make(map[string][]byte)
ipCIDRs := make(map[string][]v1alpha1.Port)
nonHostCIDRs := make(map[string][]v1alpha1.Port)
isCatchAllIPEntryPresent, allowAll := false, false
var catchAllIPPorts []v1alpha1.Port
cidrsMap := make(map[string]EbpfFirewallRules)

//Traffic from the local node should always be allowed. Add NodeIP by default to map entries.
_, mapKey, _ := net.ParseCIDR(l.nodeIP + l.hostMask)
Expand All @@ -870,16 +868,6 @@ func (l *bpfClient) computeMapEntriesFromEndpointRules(firewallRules []EbpfFirew
//Sort the rules
sortFirewallRulesByPrefixLength(firewallRules, l.hostMask)

//Check and aggregate L4 Port Info for Catch All Entries.
catchAllIPPorts, isCatchAllIPEntryPresent, allowAll = l.checkAndDeriveCatchAllIPPorts(firewallRules)
if isCatchAllIPEntryPresent {
//Add the Catch All IP entry
_, mapKey, _ := net.ParseCIDR("0.0.0.0/0")
key := utils.ComputeTrieKey(*mapKey, l.enableIPv6)
value := utils.ComputeTrieValue(catchAllIPPorts, l.logger, allowAll, false)
firewallMap[string(key)] = value
}

for _, firewallRule := range firewallRules {
var cidrL4Info []v1alpha1.Port

Expand All @@ -901,110 +889,89 @@ func (l *bpfClient) computeMapEntriesFromEndpointRules(firewallRules []EbpfFirew
continue
}

if !utils.IsCatchAllIPEntry(string(firewallRule.IPCidr)) {
if len(firewallRule.L4Info) == 0 {
l.logger.Info("No L4 specified. Add Catch all entry: ", "CIDR: ", firewallRule.IPCidr)
l.addCatchAllL4Entry(&firewallRule)
l.logger.Info("Total L4 entries ", "count: ", len(firewallRule.L4Info))
}
if utils.IsNonHostCIDR(string(firewallRule.IPCidr)) {
existingL4Info, ok := nonHostCIDRs[string(firewallRule.IPCidr)]
if ok {
firewallRule.L4Info = append(firewallRule.L4Info, existingL4Info...)
} else {
// Check if the /m entry is part of any /n CIDRs that we've encountered so far
// If found, we need to include the port and protocol combination against the current entry as well since
// we use LPM TRIE map and the /m will always win out.
cidrL4Info = l.checkAndDeriveL4InfoFromAnyMatchingCIDRs(string(firewallRule.IPCidr), nonHostCIDRs)
if len(cidrL4Info) > 0 {
firewallRule.L4Info = append(firewallRule.L4Info, cidrL4Info...)
}
}
nonHostCIDRs[string(firewallRule.IPCidr)] = firewallRule.L4Info
} else {
if existingL4Info, ok := ipCIDRs[string(firewallRule.IPCidr)]; ok {
firewallRule.L4Info = append(firewallRule.L4Info, existingL4Info...)
}
// Check if the /32 entry is part of any non host CIDRs that we've encountered so far
// If found, we need to include the port and protocol combination against the current entry as well since
// we use LPM TRIE map and the /32 will always win out.
cidrL4Info = l.checkAndDeriveL4InfoFromAnyMatchingCIDRs(string(firewallRule.IPCidr), nonHostCIDRs)
if len(cidrL4Info) > 0 {
firewallRule.L4Info = append(firewallRule.L4Info, cidrL4Info...)
}
ipCIDRs[string(firewallRule.IPCidr)] = firewallRule.L4Info
}
//Include port and protocol combination paired with catch all entries
firewallRule.L4Info = append(firewallRule.L4Info, catchAllIPPorts...)

l.logger.Info("Updating Map with ", "IP Key:", firewallRule.IPCidr)
_, firewallMapKey, _ := net.ParseCIDR(string(firewallRule.IPCidr))
// Key format: Prefix length (4 bytes) followed by 4/16byte IP address
firewallKey := utils.ComputeTrieKey(*firewallMapKey, l.enableIPv6)

if len(firewallRule.L4Info) != 0 {
mergedL4Info := mergeDuplicateL4Info(firewallRule.L4Info)
firewallRule.L4Info = mergedL4Info
// If no L4 specified add catch all entry
if len(firewallRule.L4Info) == 0 {
l.logger.Info("No L4 specified. Add Catch all entry: ", "CIDR: ", firewallRule.IPCidr)
l.addCatchAllL4Entry(&firewallRule)
l.logger.Info("Total L4 entries ", "count: ", len(firewallRule.L4Info))
}

if existingFirewallRuleInfo, ok := cidrsMap[string(firewallRule.IPCidr)]; ok {
firewallRule.L4Info = append(firewallRule.L4Info, existingFirewallRuleInfo.L4Info...)
firewallRule.Except = append(firewallRule.Except, existingFirewallRuleInfo.Except...)
} else {
// Check if the /m entry is part of any /n CIDRs that we've encountered so far
// If found, we need to include the port and protocol combination against the current entry as well since
// we use LPM TRIE map and the /m will always win out.
cidrL4Info = l.checkAndDeriveL4InfoFromAnyMatchingCIDRs(string(firewallRule.IPCidr), cidrsMap)
if len(cidrL4Info) > 0 {
firewallRule.L4Info = append(firewallRule.L4Info, cidrL4Info...)
}
firewallValue := utils.ComputeTrieValue(firewallRule.L4Info, l.logger, allowAll, false)
firewallMap[string(firewallKey)] = firewallValue
}
cidrsMap[string(firewallRule.IPCidr)] = firewallRule
}

// Go through except CIDRs and append DENY all rule to the L4 info
for _, firewallRule := range firewallRules {
if firewallRule.Except != nil {
for _, exceptCIDR := range firewallRule.Except {
_, mapKey, _ := net.ParseCIDR(string(exceptCIDR))
key := utils.ComputeTrieKey(*mapKey, l.enableIPv6)
l.logger.Info("Parsed Except CIDR", "IP Key: ", mapKey)
if len(firewallRule.L4Info) != 0 {
mergedL4Info := mergeDuplicateL4Info(firewallRule.L4Info)
firewallRule.L4Info = mergedL4Info
for _, exceptCidr := range firewallRule.Except {
if _, ok := cidrsMap[string(exceptCidr)]; !ok {
exceptFirewall := EbpfFirewallRules{
IPCidr: exceptCidr,
Except: []v1alpha1.NetworkAddress{},
L4Info: []v1alpha1.Port{},
}
l.addDenyAllL4Entry(&exceptFirewall)
cidrsMap[string(exceptCidr)] = exceptFirewall
}
value := utils.ComputeTrieValue(firewallRule.L4Info, l.logger, false, true)
firewallMap[string(key)] = value
l.logger.Info("Parsed Except CIDR", "IP Key: ", string(exceptCidr))
}
}
}

return firewallMap, nil
}
for key, value := range cidrsMap {
l.logger.Info("Updating Map with ", "IP Key:", key)
_, firewallMapKey, _ := net.ParseCIDR(string(key))
// Key format: Prefix length (4 bytes) followed by 4/16byte IP address
firewallKey := utils.ComputeTrieKey(*firewallMapKey, l.enableIPv6)

if len(value.L4Info) != 0 {
mergedL4Info := mergeDuplicateL4Info(value.L4Info)
value.L4Info = mergedL4Info

func (l *bpfClient) checkAndDeriveCatchAllIPPorts(firewallRules []EbpfFirewallRules) ([]v1alpha1.Port, bool, bool) {
var catchAllL4Info []v1alpha1.Port
isCatchAllIPEntryPresent := false
allowAllPortAndProtocols := false
for _, firewallRule := range firewallRules {
if !strings.Contains(string(firewallRule.IPCidr), "/") {
firewallRule.IPCidr += v1alpha1.NetworkAddress(l.hostMask)
}
if !l.enableIPv6 && strings.Contains(string(firewallRule.IPCidr), "::") {
l.logger.Info("IPv6 catch all entry in IPv4 mode - skip ")
continue
}
if utils.IsCatchAllIPEntry(string(firewallRule.IPCidr)) {
catchAllL4Info = append(catchAllL4Info, firewallRule.L4Info...)
isCatchAllIPEntryPresent = true
if len(firewallRule.L4Info) == 0 {
//All ports and protocols
allowAllPortAndProtocols = true
}
}
l.logger.Info("Current L4 entry count for catch all entry: ", "count: ", len(catchAllL4Info))
firewallValue := utils.ComputeTrieValue(value.L4Info, l.logger, false, false)
firewallMap[string(firewallKey)] = firewallValue
}
l.logger.Info("Total L4 entry count for catch all entry: ", "count: ", len(catchAllL4Info))
return catchAllL4Info, isCatchAllIPEntryPresent, allowAllPortAndProtocols

return firewallMap, nil
}

func (l *bpfClient) checkAndDeriveL4InfoFromAnyMatchingCIDRs(firewallRule string,
nonHostCIDRs map[string][]v1alpha1.Port) []v1alpha1.Port {
cidrsMap map[string]EbpfFirewallRules) []v1alpha1.Port {
var matchingCIDRL4Info []v1alpha1.Port

_, ipToCheck, _ := net.ParseCIDR(firewallRule)
for nonHostCIDR, l4Info := range nonHostCIDRs {
_, cidrEntry, _ := net.ParseCIDR(nonHostCIDR)
l.logger.Info("CIDR match: ", "for IP: ", firewallRule, "in CIDR: ", nonHostCIDR)
for cidr, cidrFirewallInfo := range cidrsMap {
if !utils.IsNonHostCIDR(cidr) {
continue
}
_, cidrEntry, _ := net.ParseCIDR(cidr)
if cidrEntry.Contains(ipToCheck.IP) {
l.logger.Info("Found a CIDR match: ", "for IP: ", firewallRule, "in CIDR: ", nonHostCIDR)
matchingCIDRL4Info = append(matchingCIDRL4Info, l4Info...)
l.logger.Info("Found CIDR match: ", "for IP: ", firewallRule, "in CIDR: ", cidr)
// If CIDR contains IP, check if it is part of any except block under CIDR. If yes, do not include cidrL4Info
foundInExcept := false
for _, except := range cidrFirewallInfo.Except {
_, exceptEntry, _ := net.ParseCIDR(string(except))
if exceptEntry.Contains(ipToCheck.IP) {
foundInExcept = true
l.logger.Info("Found IP: ", firewallRule, " in except block ", string(except), " of CIDR ", cidr, " . Skipping CIDR match")
break
}
}
if !foundInExcept {
matchingCIDRL4Info = append(matchingCIDRL4Info, cidrFirewallInfo.L4Info...)
}
}
}
return matchingCIDRL4Info
Expand All @@ -1017,6 +984,13 @@ func (l *bpfClient) addCatchAllL4Entry(firewallRule *EbpfFirewallRules) {
firewallRule.L4Info = append(firewallRule.L4Info, catchAllL4Entry)
}

func (l *bpfClient) addDenyAllL4Entry(firewallRule *EbpfFirewallRules) {
denyAllL4Entry := v1alpha1.Port{
Protocol: &DENY_ALL_PROTOCOL,
}
firewallRule.L4Info = append(firewallRule.L4Info, denyAllL4Entry)
}

func (l *bpfClient) DeletePodFromIngressProgPodCaches(podName string, podNamespace string) {
podNamespacedName := utils.GetPodNamespacedName(podName, podNamespace)
if progFD, ok := l.IngressPodToProgMap.Load(podNamespacedName); ok {
Expand Down
113 changes: 10 additions & 103 deletions pkg/ebpf/bpf_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,31 +175,18 @@ func TestBpfClient_IsEBPFProbeAttached(t *testing.T) {
}
}

func TestBpfClient_CheckAndDeriveCatchAllIPPorts(t *testing.T) {
func TestBpfClient_CheckAndDeriveL4InfoFromAnyMatchingCIDRs(t *testing.T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need more test cases here that will cover for the scenarios described in the description of the PR are handled by the test case.

192.168.0.0/16 - allow on port 3306. This is part of 0.0.0.0/0, so check if it is part of any except under this cidr. If yes do not add ports of 0.0.0.0/0 to 192.168.0.0/16. If no, add ports. 

Handle all except at end. For every except key, check if entry is already present in cidrsMap. If yes, there is some explicit allow and everything else will be default deny so no change required. If no, we need to add deny all entry to make sure we deny all traffic and do not allow it if it is part of some other superset CIDR

The code change and the logic looks good to me. As noted in the review of the design doc, I wanted to see if there is any simpler way, but handling all except at the end looks alright.

Coverage with test case can help us push this PR to finish line here. Thank you.

protocolTCP := corev1.ProtocolTCP
var port80 int32 = 80

type want struct {
catchAllL4Info []v1alpha1.Port
isCatchAllIPEntryPresent bool
allowAllPortAndProtocols bool
}

l4InfoWithCatchAllEntry := []EbpfFirewallRules{
{
IPCidr: "0.0.0.0/0",
L4Info: []v1alpha1.Port{
{
Protocol: &protocolTCP,
Port: &port80,
},
},
},
matchingCIDRL4Info []v1alpha1.Port
}

l4InfoWithNoCatchAllEntry := []EbpfFirewallRules{
{
IPCidr: "1.1.1.1/32",
sampleCidrsMap := map[string]EbpfFirewallRules{
"1.1.1.0/24": {
IPCidr: "1.1.1.0/24",
Except: []v1alpha1.NetworkAddress{},
L4Info: []v1alpha1.Port{
{
Protocol: &protocolTCP,
Expand All @@ -209,96 +196,16 @@ func TestBpfClient_CheckAndDeriveCatchAllIPPorts(t *testing.T) {
},
}

l4InfoWithCatchAllEntryAndAllProtocols := []EbpfFirewallRules{
{
IPCidr: "0.0.0.0/0",
},
}

tests := []struct {
name string
firewallRules []EbpfFirewallRules
want want
}{
{
name: "Catch All Entry present",
firewallRules: l4InfoWithCatchAllEntry,
want: want{
catchAllL4Info: []v1alpha1.Port{
{
Protocol: &protocolTCP,
Port: &port80,
},
},
isCatchAllIPEntryPresent: true,
allowAllPortAndProtocols: false,
},
},

{
name: "No Catch All Entry present",
firewallRules: l4InfoWithNoCatchAllEntry,
want: want{
isCatchAllIPEntryPresent: false,
allowAllPortAndProtocols: false,
},
},

{
name: "Catch All Entry With no Port info",
firewallRules: l4InfoWithCatchAllEntryAndAllProtocols,
want: want{
isCatchAllIPEntryPresent: true,
allowAllPortAndProtocols: true,
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testBpfClient := &bpfClient{
nodeIP: "10.1.1.1",
logger: logr.New(&log.NullLogSink{}),
enableIPv6: false,
hostMask: "/32",
IngressPodToProgMap: new(sync.Map),
EgressPodToProgMap: new(sync.Map),
}
gotCatchAllL4Info, gotIsCatchAllIPEntryPresent, gotAllowAllPortAndProtocols := testBpfClient.checkAndDeriveCatchAllIPPorts(tt.firewallRules)
assert.Equal(t, tt.want.catchAllL4Info, gotCatchAllL4Info)
assert.Equal(t, tt.want.isCatchAllIPEntryPresent, gotIsCatchAllIPEntryPresent)
assert.Equal(t, tt.want.allowAllPortAndProtocols, gotAllowAllPortAndProtocols)
})
}
}

func TestBpfClient_CheckAndDeriveL4InfoFromAnyMatchingCIDRs(t *testing.T) {
protocolTCP := corev1.ProtocolTCP
var port80 int32 = 80

type want struct {
matchingCIDRL4Info []v1alpha1.Port
}

sampleNonHostCIDRs := map[string][]v1alpha1.Port{
"1.1.1.0/24": {
{
Protocol: &protocolTCP,
Port: &port80,
},
},
}

tests := []struct {
name string
firewallRule string
nonHostCIDRs map[string][]v1alpha1.Port
cidrsMap map[string]EbpfFirewallRules
want want
}{
{
name: "Match Present",
firewallRule: "1.1.1.2/32",
nonHostCIDRs: sampleNonHostCIDRs,
cidrsMap: sampleCidrsMap,
want: want{
matchingCIDRL4Info: []v1alpha1.Port{
{
Expand All @@ -312,7 +219,7 @@ func TestBpfClient_CheckAndDeriveL4InfoFromAnyMatchingCIDRs(t *testing.T) {
{
name: "No Match",
firewallRule: "2.1.1.2/32",
nonHostCIDRs: sampleNonHostCIDRs,
cidrsMap: sampleCidrsMap,
want: want{},
},
}
Expand All @@ -327,7 +234,7 @@ func TestBpfClient_CheckAndDeriveL4InfoFromAnyMatchingCIDRs(t *testing.T) {
IngressPodToProgMap: new(sync.Map),
EgressPodToProgMap: new(sync.Map),
}
gotMatchingCIDRL4Info := testBpfClient.checkAndDeriveL4InfoFromAnyMatchingCIDRs(tt.firewallRule, tt.nonHostCIDRs)
gotMatchingCIDRL4Info := testBpfClient.checkAndDeriveL4InfoFromAnyMatchingCIDRs(tt.firewallRule, tt.cidrsMap)
assert.Equal(t, tt.want.matchingCIDRL4Info, gotMatchingCIDRL4Info)
})
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/ebpf/c/tc.v4egress.bpf.c
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ int handle_egress(struct __sk_buff *skb)
return BPF_DROP;
}

if ((trie_val->protocol == ANY_IP_PROTOCOL) || (trie_val->protocol == ip->protocol &&
if ((trie_val->protocol == ANY_IP_PROTOCOL && ((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enhancement Request: We can format this and explain the behavior of in the comment.
a) match any problem and ensure either destination port matches or falls within the range.

(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port))) || (trie_val->protocol == ip->protocol &&
((trie_val->start_port == ANY_PORT) || (l4_dst_port == trie_val->start_port) ||
(l4_dst_port > trie_val->start_port && l4_dst_port <= trie_val->end_port)))) {
//Inject in to conntrack map
Expand Down
Loading
Loading