Skip to content

Commit

Permalink
Added checksum for PacketOut to avoid invalid packet (#967)
Browse files Browse the repository at this point in the history
  • Loading branch information
gran-vmv authored Jul 21, 2020
1 parent 4d09ed5 commit 940b6c4
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 2 deletions.
10 changes: 8 additions & 2 deletions pkg/agent/controller/traceflow/traceflow_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ import (
"github.com/vmware-tanzu/antrea/pkg/ovs/ovsconfig"
)

type icmpType uint8
type icmpCode uint8

const (
controllerName = "AntreaAgentTraceflowController"
// Set resyncPeriod to 0 to disable resyncing.
Expand All @@ -53,6 +56,9 @@ const (
// Seconds delay before injecting packet into OVS. The time of different nodes may not be completely
// synchronized, which requires a delay before inject packet.
injectPacketDelay = 5
// ICMP Echo Request type and code.
icmpEchoRequestType icmpType = 8
icmpEchoRequestCode icmpCode = 0
)

// Controller is responsible for setting up Openflow entries and injecting traceflow packet into
Expand Down Expand Up @@ -341,8 +347,8 @@ func (c *Controller) injectPacket(tf *opsv1alpha1.Traceflow) error {
TCPFlags,
UDPSrcPort,
UDPDstPort,
0,
0,
uint8(icmpEchoRequestType),
uint8(icmpEchoRequestCode),
ICMPID,
ICMPSequence,
uint32(podInterfaces[0].OFPort),
Expand Down
62 changes: 62 additions & 0 deletions pkg/ovs/openflow/ofctrl_packetout.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,17 @@ func (b *ofPacketOutBuilder) AddLoadAction(name string, data uint64, rng Range)
func (b *ofPacketOutBuilder) Done() *ofctrl.PacketOut {
if b.pktOut.ICMPHeader != nil {
b.setICMPData()
b.pktOut.ICMPHeader.Checksum = b.icmpHeaderChecksum()
b.pktOut.IPHeader.Length = 20 + b.pktOut.ICMPHeader.Len()
} else if b.pktOut.TCPHeader != nil {
b.pktOut.TCPHeader.HdrLen = 5
b.pktOut.TCPHeader.SeqNum = rand.Uint32()
b.pktOut.TCPHeader.AckNum = rand.Uint32()
b.pktOut.TCPHeader.Checksum = b.tcpHeaderChecksum()
b.pktOut.IPHeader.Length = 20 + b.pktOut.TCPHeader.Len()
} else if b.pktOut.UDPHeader != nil {
b.pktOut.UDPHeader.Length = b.pktOut.UDPHeader.Len()
b.pktOut.UDPHeader.Checksum = b.udpHeaderChecksum()
b.pktOut.IPHeader.Length = 20 + b.pktOut.UDPHeader.Len()
}
b.pktOut.IPHeader.Id = uint16(rand.Uint32())
Expand All @@ -217,6 +221,7 @@ func (b *ofPacketOutBuilder) Done() *ofctrl.PacketOut {
} else {
b.pktOut.IPHeader.Version = 0x6
}
b.pktOut.IPHeader.Checksum = b.ipHeaderChecksum()
return b.pktOut
}

Expand All @@ -230,3 +235,60 @@ func (b *ofPacketOutBuilder) setICMPData() {
}
b.pktOut.ICMPHeader.Data = data
}

func (b *ofPacketOutBuilder) ipHeaderChecksum() uint16 {
ipHeader := *b.pktOut.IPHeader
ipHeader.Checksum = 0
ipHeader.Data = nil
data, _ := ipHeader.MarshalBinary()
return checksum(data)
}

func (b *ofPacketOutBuilder) icmpHeaderChecksum() uint16 {
icmpHeader := *b.pktOut.ICMPHeader
icmpHeader.Checksum = 0
data, _ := icmpHeader.MarshalBinary()
return checksum(data)
}

func (b *ofPacketOutBuilder) tcpHeaderChecksum() uint16 {
tcpHeader := *b.pktOut.TCPHeader
tcpHeader.Checksum = 0
data, _ := tcpHeader.MarshalBinary()
checksumData := append(b.generatePseudoHeader(uint16(len(data))), data...)
return checksum(checksumData)
}

func (b *ofPacketOutBuilder) udpHeaderChecksum() uint16 {
udpHeader := *b.pktOut.UDPHeader
udpHeader.Checksum = 0
data, _ := udpHeader.MarshalBinary()
checksumData := append(b.generatePseudoHeader(uint16(len(data))), data...)
return checksum(checksumData)
}

func (b *ofPacketOutBuilder) generatePseudoHeader(length uint16) []byte {
pseudoHeader := make([]byte, 12)
copy(pseudoHeader[0:4], b.pktOut.IPHeader.NWSrc.To4())
copy(pseudoHeader[4:8], b.pktOut.IPHeader.NWDst.To4())
pseudoHeader[8] = 0x0
pseudoHeader[9] = b.pktOut.IPHeader.Protocol
binary.BigEndian.PutUint16(pseudoHeader[10:12], length)
return pseudoHeader
}

func checksum(data []byte) uint16 {
var sum uint32
var index int
length := len(data)
for length > 1 {
sum += uint32(data[index])<<8 + uint32(data[index+1])
index += 2
length -= 2
}
if length > 0 {
sum += uint32(data[index])
}
sum += (sum >> 16)
return uint16(^sum)
}

0 comments on commit 940b6c4

Please sign in to comment.