Skip to content

Commit

Permalink
Add FQDN TCP DNS support
Browse files Browse the repository at this point in the history
Signed-off-by: graysonwu <wgrayson@vmware.com>
  • Loading branch information
GraysonWu committed Feb 14, 2023
1 parent a63314f commit dc613c8
Show file tree
Hide file tree
Showing 11 changed files with 235 additions and 19 deletions.
67 changes: 59 additions & 8 deletions pkg/agent/controller/networkpolicy/fqdn.go
Original file line number Diff line number Diff line change
Expand Up @@ -746,8 +746,7 @@ func (f *fqdnController) HandlePacketIn(pktIn *ofctrl.PacketIn) error {
func (f *fqdnController) handlePacketIn(pktIn *ofctrl.PacketIn) error {
klog.V(4).InfoS("Received a packetIn for DNS response")
waitCh := make(chan error, 1)
handleUDPData := func(dnsPkt *protocol.UDP) {
dnsData := dnsPkt.Data
handleDNSData := func(dnsData []byte) {
dnsMsg := dns.Msg{}
if err := dnsMsg.Unpack(dnsData); err != nil {
waitCh <- err
Expand All @@ -762,14 +761,30 @@ func (f *fqdnController) handlePacketIn(pktIn *ofctrl.PacketIn) error {
}
switch ipPkt := ethernetPkt.Data.(type) {
case *protocol.IPv4:
switch dnsPkt := ipPkt.Data.(type) {
case *protocol.UDP:
handleUDPData(dnsPkt)
proto := ipPkt.Protocol
switch proto {
case protocol.Type_UDP:
udpPkt := ipPkt.Data.(*protocol.UDP)
handleDNSData(udpPkt.Data)
case protocol.Type_TCP:
tcpPkt, err := binding.GetTCPPacketFromIPMessage(ipPkt)
if err != nil {
return
}
handleDNSData(tcpPkt.Data)
}
case *protocol.IPv6:
switch dnsPkt := ipPkt.Data.(type) {
case *protocol.UDP:
handleUDPData(dnsPkt)
proto := ipPkt.NextHeader
switch proto {
case protocol.Type_UDP:
udpPkt := ipPkt.Data.(*protocol.UDP)
handleDNSData(udpPkt.Data)
case protocol.Type_TCP:
tcpPkt, err := binding.GetTCPPacketFromIPMessage(ipPkt)
if err != nil {
return
}
handleDNSData(tcpPkt.Data)
}
}
}()
Expand Down Expand Up @@ -806,6 +821,8 @@ func (f *fqdnController) sendDNSPacketout(pktIn *ofctrl.PacketIn) error {
switch dnsPkt := ipPkt.Data.(type) {
case *protocol.UDP:
packetData = dnsPkt.Data
case *protocol.TCP:
packetData = dnsPkt.Data
}
case *protocol.IPv6:
srcIP = ipPkt.NWSrc.String()
Expand All @@ -815,6 +832,8 @@ func (f *fqdnController) sendDNSPacketout(pktIn *ofctrl.PacketIn) error {
switch dnsPkt := ipPkt.Data.(type) {
case *protocol.UDP:
packetData = dnsPkt.Data
case *protocol.TCP:
packetData = dnsPkt.Data
}
}
if prot == protocol.Type_UDP {
Expand Down Expand Up @@ -848,6 +867,38 @@ func (f *fqdnController) sendDNSPacketout(pktIn *ofctrl.PacketIn) error {
udpDstPort,
packetData,
mutatePacketOut)
} else if prot == protocol.Type_TCP {
inPort := f.gwPort
if inPort == 0 {
// Use the original in_port number in the packetIn message to avoid an invalid input port number. Note that,
// this should not happen in container case as antrea-gw0 always exists. This check is for security purpose.
matches := pktIn.GetMatches()
inPortField := matches.GetMatchByName("OXM_OF_IN_PORT")
if inPortField != nil {
inPort = inPortField.GetValue().(uint32)
}
}
tcpSrcPort, tcpDstPort, tcpSeqNum, _, tcpFlag, err := binding.GetTCPHeaderData(ethernetPkt.Data)
if err != nil {
klog.ErrorS(err, "Failed to get TCP header data")
return err
}
mutatePacketOut := func(packetOutBuilder binding.PacketOutBuilder) binding.PacketOutBuilder {
return packetOutBuilder.AddLoadRegMark(openflow.CustomReasonDNSRegMark)
}
return f.ofClient.SendTCPPacketOut(
ethernetPkt.HWSrc.String(),
ethernetPkt.HWDst.String(),
srcIP,
dstIP,
inPort,
0,
isIPv6,
tcpSrcPort,
tcpDstPort,
tcpSeqNum+1,
tcpFlag,
mutatePacketOut)
}
return nil
}
39 changes: 36 additions & 3 deletions pkg/agent/openflow/network_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ var (
MatchServiceGroupID = types.NewMatchKey(binding.ProtocolIP, types.ServiceGroupIDAddr, "reg7[0..31]")
MatchIGMPProtocol = types.NewMatchKey(binding.ProtocolIGMP, types.IGMPAddr, "igmp")
MatchLabelID = types.NewMatchKey(binding.ProtocolIP, types.LabelIDAddr, "tun_id")
MatchTCPFlag = types.NewMatchKey(binding.ProtocolTCP, types.TCPFlagAddr, "tcp_flags")
Unsupported = types.NewMatchKey(binding.ProtocolIP, types.UnSupported, "unknown")

// metricFlowIdentifier is used to identify metric flows in metric table.
Expand All @@ -79,9 +80,15 @@ var (
metricFlowIdentifier = fmt.Sprintf("priority=%d,", priorityNormal)

protocolUDP = v1beta2.ProtocolUDP
protocolTCP = v1beta2.ProtocolTCP
dnsPort = intstr.FromInt(53)
)

type TCPFlag struct {
Flag uint16
Mask uint16
}

// IP address calculated from Pod's address.
type IPAddress net.IP

Expand Down Expand Up @@ -699,17 +706,43 @@ func (c *client) NewDNSpacketInConjunction(id uint32) error {
if err := c.ofEntryOperations.AddAll(conj.actionFlows); err != nil {
return fmt.Errorf("error when adding action flows for the DNS conjunction: %w", err)
}
dnsPriority := priorityDNSIntercept
conj.serviceClause = conj.newClause(1, 2, getTableByID(conj.ruleTableID), nil)
conj.toClause = conj.newClause(2, 2, getTableByID(conj.ruleTableID), nil)
udpService := v1beta2.Service{
Protocol: &protocolUDP,
Port: &dnsPort,
}
dnsPriority := priorityDNSIntercept
conj.serviceClause = conj.newClause(1, 2, getTableByID(conj.ruleTableID), nil)
conj.toClause = conj.newClause(2, 2, getTableByID(conj.ruleTableID), nil)
tcpService := v1beta2.Service{
Protocol: &protocolTCP,
Port: &dnsPort,
}
tcpServiceMatch := &conjunctiveMatch{
tableID: conj.serviceClause.ruleTable.GetID(),
matchPairs: []matchPair{
getServiceMatchPairs(tcpService, c.featureNetworkPolicy.ipProtocols, true)[0][0],
{
matchKey: MatchTCPFlag,
matchValue: TCPFlag{
// URG|ACK|PSH|RST|SYN|FIN|
Flag: 0b011000,
Mask: 0b111111,
},
},
},
priority: &dnsPriority,
}

c.featureNetworkPolicy.conjMatchFlowLock.Lock()
defer c.featureNetworkPolicy.conjMatchFlowLock.Unlock()
ctxChanges := conj.serviceClause.addServiceFlows(c.featureNetworkPolicy, []v1beta2.Service{udpService}, &dnsPriority, true, false)
ctxChange := conj.serviceClause.addConjunctiveMatchFlow(c.featureNetworkPolicy, tcpServiceMatch, false, false)
ctxChanges = append(ctxChanges, ctxChange)
for _, change := range ctxChanges {
for _, pa := range change.context.matchPairs {
klog.Infof("%s:%s", pa.matchKey.GetKeyString(), pa.matchValue)
}
}
if err := c.featureNetworkPolicy.applyConjunctiveMatchFlows(ctxChanges); err != nil {
return err
}
Expand Down
1 change: 0 additions & 1 deletion pkg/agent/openflow/network_policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ var (
actionAllow = crdv1alpha1.RuleActionAllow
actionDrop = crdv1alpha1.RuleActionDrop
port8080 = intstr.FromInt(8080)
protocolTCP = v1beta2.ProtocolTCP
protocolICMP = v1beta2.ProtocolICMP
priority100 = uint16(100)
priority200 = uint16(200)
Expand Down
6 changes: 6 additions & 0 deletions pkg/agent/openflow/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -2010,6 +2010,12 @@ func (f *featureNetworkPolicy) addFlowMatch(fb binding.FlowBuilder, matchKey *ty
fb = fb.MatchProtocol(matchKey.GetOFProtocol())
case MatchLabelID:
fb = fb.MatchTunnelID(uint64(matchValue.(uint32)))
case MatchTCPFlag:
fb = fb.MatchProtocol(matchKey.GetOFProtocol())
if matchValue != nil {
tcpFlag := matchValue.(TCPFlag)
fb = fb.MatchTCPFlag(tcpFlag.Flag, tcpFlag.Mask)
}
}
return fb
}
Expand Down
1 change: 1 addition & 0 deletions pkg/agent/types/networkpolicy.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ const (
ServiceGroupIDAddr
IGMPAddr
LabelIDAddr
TCPFlagAddr
UnSupported
)

Expand Down
1 change: 1 addition & 0 deletions pkg/ovs/openflow/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ type FlowBuilder interface {
MatchConjID(value uint32) FlowBuilder
MatchDstPort(port uint16, portMask *uint16) FlowBuilder
MatchSrcPort(port uint16, portMask *uint16) FlowBuilder
MatchTCPFlag(flag, mask uint16) FlowBuilder
MatchICMPType(icmpType byte) FlowBuilder
MatchICMPCode(icmpCode byte) FlowBuilder
MatchICMPv6Type(icmp6Type byte) FlowBuilder
Expand Down
7 changes: 7 additions & 0 deletions pkg/ovs/openflow/ofctrl_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,13 @@ func (b *ofFlowBuilder) MatchSrcPort(port uint16, portMask *uint16) FlowBuilder
return b
}

func (b *ofFlowBuilder) MatchTCPFlag(flag, mask uint16) FlowBuilder {
b.matchers = append(b.matchers, fmt.Sprintf("tcp_flags=%b/%b", uint8(flag), uint8(mask)))
b.Match.TcpFlags = &flag
b.Match.TcpFlagsMask = &mask
return b
}

// MatchCTSrcIP matches the source IPv4 address of the connection tracker original direction tuple. This match requires
// a match to valid connection tracking state as a prerequisite, and valid connection tracking state matches include
// "+new", "+est", "+rel" and "+trk-inv".
Expand Down
21 changes: 15 additions & 6 deletions pkg/ovs/openflow/ofctrl_packetin.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,17 @@ const (
icmp6EchoRequestType uint8 = 128
)

// GetTCPHeaderData gets TCP header data from IP packet.
func GetTCPHeaderData(ipPkt util.Message) (tcpSrcPort, tcpDstPort uint16, tcpSeqNum, tcpAckNum uint32, tcpFlags uint8, err error) {
tcpIn, err := GetTCPPacketFromIPMessage(ipPkt)
if err != nil {
return 0, 0, 0, 0, 0, err
}

return tcpIn.PortSrc, tcpIn.PortDst, tcpIn.SeqNum, tcpIn.AckNum, tcpIn.Code, nil
}

// GetTCPPacketFromIPMessage gets a TCP struct from an IP message.
func GetTCPPacketFromIPMessage(ipPkt util.Message) (tcpPkt *protocol.TCP, err error) {
var tcpBytes []byte

// Transfer Buffer to TCP
Expand All @@ -40,15 +49,15 @@ func GetTCPHeaderData(ipPkt util.Message) (tcpSrcPort, tcpDstPort uint16, tcpSeq
tcpBytes, err = typedIPPkt.Data.(*util.Buffer).MarshalBinary()
}
if err != nil {
return 0, 0, 0, 0, 0, err
return nil, err
}
tcpIn := new(protocol.TCP)
err = tcpIn.UnmarshalBinary(tcpBytes)
tcpPkt = new(protocol.TCP)
err = tcpPkt.UnmarshalBinary(tcpBytes)
if err != nil {
return 0, 0, 0, 0, 0, err
return nil, err
}

return tcpIn.PortSrc, tcpIn.PortDst, tcpIn.SeqNum, tcpIn.AckNum, tcpIn.Code, nil
return tcpPkt, nil
}

func GetUDPHeaderData(ipPkt util.Message) (udpSrcPort, udpDstPort uint16, err error) {
Expand Down
14 changes: 14 additions & 0 deletions pkg/ovs/openflow/testing/mock_openflow.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 44 additions & 1 deletion test/e2e/antreapolicy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3320,6 +3320,48 @@ func testFQDNPolicyInClusterService(t *testing.T) {
failOnError(k8sUtils.DeleteACNP(builder.Name), t)
}

// testFQDNPolicyTCP
func testFQDNPolicyTCP(t *testing.T) {
// The ipv6-only test env doesn't have IPv6 access to the web.
skipIfNotIPv4Cluster(t)
// It is convenient to have higher log verbosity for FQDNtests for troubleshooting failures.
logLevel := log.GetLevel()
log.SetLevel(log.TraceLevel)
defer log.SetLevel(logLevel)
builder := &ClusterNetworkPolicySpecBuilder{}
builder = builder.SetName("test-acnp-fqdn-tcp").
SetTier("application").
SetPriority(1.0).
SetAppliedToGroup([]ACNPAppliedToSpec{{NSSelector: map[string]string{}}})
builder.AddFQDNRule("github.com", ProtocolTCP, nil, nil, nil, "", nil, crdv1alpha1.RuleActionDrop)

testcases := []podToAddrTestStep{
{
Pod(namespaces["y"] + "/a"),
"github.com",
80,
Dropped,
},
}
acnp, err := k8sUtils.CreateOrUpdateACNP(builder.Get())
failOnError(err, t)
failOnError(waitForResourceReady(t, timeout, acnp), t)
for _, tc := range testcases {
log.Tracef("Probing: %s -> %s", tc.clientPod.PodName(), tc.destAddr)
destIP := k8sUtils.digDNS(tc.clientPod.PodName(), tc.clientPod.Namespace(), tc.destAddr, true)
connectivity, err := k8sUtils.ProbeAddr(tc.clientPod.Namespace(), "pod", tc.clientPod.PodName(), destIP, tc.destPort, ProtocolTCP)
if err != nil {
t.Errorf("failure -- could not complete probe: %v", err)
}
if connectivity != tc.expectedConnectivity {
t.Errorf("failure -- wrong results for probe: Source %s/%s --> Dest %s:%d connectivity: %v, expected: %v",
tc.clientPod.Namespace(), tc.clientPod.PodName(), tc.destAddr, tc.destPort, connectivity, tc.expectedConnectivity)
}
}
// cleanup test resources
failOnError(k8sUtils.DeleteACNP(builder.Name), t)
}

func testToServices(t *testing.T) {
skipIfProxyDisabled(t)
var services []*v1.Service
Expand Down Expand Up @@ -4286,7 +4328,8 @@ func TestAntreaPolicy(t *testing.T) {
t.Run("Case=ANPGroupRefRuleIPBlocks", func(t *testing.T) { testANPGroupRefRuleIPBlocks(t) })
t.Run("Case=ANPNestedGroup", func(t *testing.T) { testANPNestedGroupCreateAndUpdate(t, data) })
t.Run("Case=ACNPFQDNPolicy", func(t *testing.T) { testFQDNPolicy(t) })
t.Run("Case=FQDNPolicyInCluster", func(t *testing.T) { testFQDNPolicyInClusterService(t) })
t.Run("Case=ACNPFQDNPolicyInCluster", func(t *testing.T) { testFQDNPolicyInClusterService(t) })
t.Run("Case=ACNPFQDNPolicyTCP", func(t *testing.T) { testFQDNPolicyTCP(t) })
t.Run("Case=ACNPToServices", func(t *testing.T) { testToServices(t) })
t.Run("Case=ACNPServiceAccountSelector", func(t *testing.T) { testServiceAccountSelector(t, data) })
t.Run("Case=ACNPNodeSelectorEgress", func(t *testing.T) { testACNPNodeSelectorEgress(t) })
Expand Down
Loading

0 comments on commit dc613c8

Please sign in to comment.