Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Close TCP connection in case of error #384

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/collector/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func (cp *CollectingProcess) Stop() {
close(cp.stopChan)
// wait for all connections to be safely deleted and returned
cp.wg.Wait()
klog.Info("Stopping the collecting process")
klog.Info("Stopped the collecting process")
}

func (cp *CollectingProcess) GetAddress() net.Addr {
Expand Down
45 changes: 40 additions & 5 deletions pkg/collector/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"crypto/x509"
"encoding/binary"
"fmt"
"io"
"net"
"runtime"
"sync"
Expand All @@ -38,10 +39,14 @@ import (
testcerts "github.com/vmware/go-ipfix/pkg/test/certs"
)

var validTemplatePacket = []byte{0, 10, 0, 40, 95, 154, 107, 127, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 24, 1, 0, 0, 3, 0, 8, 0, 4, 0, 12, 0, 4, 128, 101, 255, 255, 0, 0, 220, 186}
var validTemplatePacketIPv6 = []byte{0, 10, 0, 32, 96, 27, 70, 6, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 16, 1, 0, 0, 2, 0, 27, 0, 16, 0, 28, 0, 16}
var validDataPacket = []byte{0, 10, 0, 33, 95, 154, 108, 18, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 17, 1, 2, 3, 4, 5, 6, 7, 8, 4, 112, 111, 100, 49}
var validDataPacketIPv6 = []byte{0, 10, 0, 52, 96, 27, 75, 252, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 36, 32, 1, 0, 0, 50, 56, 223, 225, 0, 99, 0, 0, 0, 0, 254, 251, 32, 1, 0, 0, 50, 56, 223, 225, 0, 99, 0, 0, 0, 0, 254, 251}
// Raw message fixtures shared by the collector tests in this file. Each of them
// carries its 16-bit message length in bytes 2-3, which matches the length of
// the slice (40, 32, 33 and 52 bytes for the valid packets below).
var (
// Valid packets start with the version bytes 0, 10.
validTemplatePacket = []byte{0, 10, 0, 40, 95, 154, 107, 127, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 24, 1, 0, 0, 3, 0, 8, 0, 4, 0, 12, 0, 4, 128, 101, 255, 255, 0, 0, 220, 186}
validTemplatePacketIPv6 = []byte{0, 10, 0, 32, 96, 27, 70, 6, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 16, 1, 0, 0, 2, 0, 27, 0, 16, 0, 28, 0, 16}
validDataPacket = []byte{0, 10, 0, 33, 95, 154, 108, 18, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 17, 1, 2, 3, 4, 5, 6, 7, 8, 4, 112, 111, 100, 49}
validDataPacketIPv6 = []byte{0, 10, 0, 52, 96, 27, 75, 252, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 36, 32, 1, 0, 0, 50, 56, 223, 225, 0, 99, 0, 0, 0, 0, 254, 251, 32, 1, 0, 0, 50, 56, 223, 225, 0, 99, 0, 0, 0, 0, 254, 251}

// Template packet whose version bytes are 0, 9 instead of 0, 10. The decode
// test in this file expects the error "collector only supports IPFIX (v10)"
// for it, and the TCP test uses it to provoke a decoding failure.
invalidTemplatePacketWrongVersion = []byte{0, 9, 0, 40, 95, 40, 211, 236, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 24, 1, 0, 0, 3, 0, 8, 0, 4, 0, 12, 0, 4, 128, 105, 255, 255, 0, 0, 218, 21}
)

const (
tcpTransport = "tcp"
Expand Down Expand Up @@ -85,6 +90,36 @@ func TestTCPCollectingProcess_ReceiveTemplateRecord(t *testing.T) {
assert.Equal(t, int64(1), cp.GetNumRecordsReceived())
}

// TestTCPCollectingProcess_ReceiveInvalidTemplateRecord checks that the TCP
// collector closes the client connection after receiving a message that cannot
// be decoded (a template packet with the wrong version). The client observes
// the close as io.EOF on its next read.
func TestTCPCollectingProcess_ReceiveInvalidTemplateRecord(t *testing.T) {
	input := getCollectorInput(tcpTransport, false, false)
	cp, err := InitCollectingProcess(input)
	if err != nil {
		t.Fatalf("TCP Collecting Process does not start correctly: %v", err)
	}
	go cp.Start()
	// wait until collector is ready
	waitForCollectorReady(t, cp)
	// Stop the collector on every exit path, including early t.Fatalf aborts.
	defer cp.Stop()
	go func() {
		// consume all messages to avoid blocking
		ch := cp.GetMsgChan()
		for range ch {
		}
	}()
	collectorAddr := cp.GetAddress()
	// client
	conn, err := net.Dial(collectorAddr.Network(), collectorAddr.String())
	if err != nil {
		// Abort here: conn is nil when Dial fails, so continuing would panic on
		// the deferred Close and on SetReadDeadline.
		t.Fatalf("Cannot establish connection to %s: %v", collectorAddr.String(), err)
	}
	defer conn.Close()
	// Bound the read below so the test fails quickly if the collector keeps the
	// connection open instead of closing it.
	if err := conn.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil {
		t.Fatalf("Cannot set read deadline: %v", err)
	}
	if _, err := conn.Write(invalidTemplatePacketWrongVersion); err != nil {
		t.Fatalf("Cannot send packet to %s: %v", collectorAddr.String(), err)
	}
	readBuffer := make([]byte, 100)
	_, err = conn.Read(readBuffer)
	// The collector never writes to the client, so EOF means it closed the
	// connection; a timeout error would mean it did not.
	assert.ErrorIs(t, err, io.EOF)
}

func TestUDPCollectingProcess_ReceiveTemplateRecord(t *testing.T) {
input := getCollectorInput(udpTransport, false, false)
cp, err := InitCollectingProcess(input)
Expand Down Expand Up @@ -375,7 +410,7 @@ func TestCollectingProcess_DecodeTemplateRecord(t *testing.T) {
templateID: &template{},
},
},
templateRecord: []byte{0, 9, 0, 40, 95, 40, 211, 236, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 24, 1, 0, 0, 3, 0, 8, 0, 4, 0, 12, 0, 4, 128, 105, 255, 255, 0, 0, 218, 21},
templateRecord: invalidTemplatePacketWrongVersion,
expectedErr: "collector only supports IPFIX (v10)",
// Invalid version means we stop decoding the packet right away, so we will not modify the existing template map
isTemplateExpected: true,
Expand Down
21 changes: 16 additions & 5 deletions pkg/collector/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,14 @@ func (cp *CollectingProcess) handleTCPClient(conn net.Conn) {
}()
defer conn.Close()
reader := bufio.NewReader(conn)
doneCh := make(chan struct{})
cp.wg.Add(1)
// We read from the connection in a separate goroutine, so we can stop immediately when
// cp.stopChan is closed. An alternative would be to use a read deadline, and check
// cp.stopChan at every iteration.
go func() {
defer cp.wg.Done()
defer close(doneCh)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could also be considered a bug fix. Prior to this commit, a return statement in the for loop would cause this goroutine to return, but the parent goroutine (the one that invoked handleTCPClient) would keep blocking on cp.stopChan indefinitely. I think that was a goroutine leak, since the parent goroutine would not return until the collecting process was stopped.

for {
length, err := getMessageLength(reader)
if errors.Is(err, io.EOF) {
Expand All @@ -102,16 +107,22 @@ func (cp *CollectingProcess) handleTCPClient(conn net.Conn) {
}
message, err := cp.decodePacket(bytes.NewBuffer(buff), address)
if err != nil {
// TODO: should we close the connection instead and force the client to
// re-open it?
klog.ErrorS(err, "Error when decoding packet")
continue
// This can be an invalid template record, or invalid data record.
// We close the connection, which is the best way to let the client
// (exporter) know that something is wrong.
klog.ErrorS(err, "Error when decoding packet, closing connection")
return
}
klog.V(4).InfoS("Processed message from exporter",
"observationDomainID", message.GetObsDomainID(), "setType", message.GetSet().GetSetType(), "numRecords", message.GetSet().GetNumberOfRecords())
}
}()
<-cp.stopChan
select {
case <-cp.stopChan:
break
case <-doneCh:
break
}
}

func (cp *CollectingProcess) createServerConfig() (*tls.Config, error) {
Expand Down
Loading