Skip to content

Commit

Permalink
add GetInitMessage and WriteBeforeMessage to output_tcp.go (#1193)
Browse files Browse the repository at this point in the history
* add GetInitMessage and WriteBeforeMessage to output_tcp.go

* try to fix code duplication
  • Loading branch information
ivan-stankov-salt-security authored Aug 7, 2023
1 parent 3062446 commit fd62d69
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 23 deletions.
31 changes: 25 additions & 6 deletions output_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ type TCPOutputConfig struct {
Sticky bool `json:"output-tcp-sticky"`
SkipVerify bool `json:"output-tcp-skip-verify"`
Workers int `json:"output-tcp-workers"`

GetInitMessage func() *Message `json:"-"`
WriteBeforeMessage func(conn net.Conn, msg *Message) error `json:"-"`
}

// NewTCPOutput constructor for TCPOutput
Expand Down Expand Up @@ -78,14 +81,14 @@ func (o *TCPOutput) worker(bufferIndex int) {

defer conn.Close()

if o.config.GetInitMessage != nil {
msg := o.config.GetInitMessage()
_ = o.writeToConnection(conn, msg)
}

for {
msg := <-o.buf[bufferIndex]
if _, err = conn.Write(msg.Meta); err == nil {
if _, err = conn.Write(msg.Data); err == nil {
_, err = conn.Write(payloadSeparatorAsBytes)
}
}

err = o.writeToConnection(conn, msg)
if err != nil {
Debug(2, "INFO: TCP output connection closed, reconnecting")
go o.worker(bufferIndex)
Expand All @@ -95,6 +98,22 @@ func (o *TCPOutput) worker(bufferIndex int) {
}
}

func (o *TCPOutput) writeToConnection(conn net.Conn, msg *Message) (err error) {
if o.config.WriteBeforeMessage != nil {
err = o.config.WriteBeforeMessage(conn, msg)
}

if err == nil {
if _, err = conn.Write(msg.Meta); err == nil {
if _, err = conn.Write(msg.Data); err == nil {
_, err = conn.Write(payloadSeparatorAsBytes)
}
}
}

return err
}

func (o *TCPOutput) getBufferIndex(msg *Message) int {
if !o.config.Sticky {
o.workerIndex++
Expand Down
94 changes: 77 additions & 17 deletions output_tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import (
"bufio"
"log"
"net"
"strings"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestTCPOutput(t *testing.T) {
Expand All @@ -15,24 +18,8 @@ func TestTCPOutput(t *testing.T) {
listener := startTCP(func(data []byte) {
wg.Done()
})
input := NewTestInput()
output := NewTCPOutput(listener.Addr().String(), &TCPOutputConfig{Workers: 10})

plugins := &InOutPlugins{
Inputs: []PluginReader{input},
Outputs: []PluginWriter{output},
}

emitter := NewEmitter()
go emitter.Start(plugins, Settings.Middleware)

for i := 0; i < 10; i++ {
wg.Add(1)
input.EmitGET()
}

wg.Wait()
emitter.Close()
runTCPOutput(wg, output, 10, false)
}

func startTCP(cb func([]byte)) net.Listener {
Expand Down Expand Up @@ -131,3 +118,76 @@ func getTestBytes() *Message {
Data: []byte("GET / HTTP/1.1\r\nHost: www.w3.org\r\nUser-Agent: Go 1.1 package http\r\nAccept-Encoding: gzip\r\n\r\n"),
}
}

func TestTCPOutputGetInitMessage(t *testing.T) {
wg := new(sync.WaitGroup)

var dataList [][]byte
listener := startTCP(func(data []byte) {
dataList = append(dataList, data)
wg.Done()
})
getInitMessage := func() *Message {
return &Message{
Meta: []byte{},
Data: []byte("test1"),
}
}
output := NewTCPOutput(listener.Addr().String(), &TCPOutputConfig{Workers: 1, GetInitMessage: getInitMessage})

runTCPOutput(wg, output, 1, true)

if assert.Equal(t, 2, len(dataList)) {
assert.Equal(t, "test1", string(dataList[0]))
}
}

func TestTCPOutputGetInitMessageAndWriteBeforeMessage(t *testing.T) {
wg := new(sync.WaitGroup)

var dataList [][]byte
listener := startTCP(func(data []byte) {
dataList = append(dataList, data)
wg.Done()
})
getInitMessage := func() *Message {
return &Message{
Meta: []byte{},
Data: []byte("test2"),
}
}
writeBeforeMessage := func(conn net.Conn, _ *Message) error {
_, err := conn.Write([]byte("before"))
return err
}
output := NewTCPOutput(listener.Addr().String(), &TCPOutputConfig{Workers: 1, GetInitMessage: getInitMessage, WriteBeforeMessage: writeBeforeMessage})

runTCPOutput(wg, output, 1, true)

if assert.Equal(t, 2, len(dataList)) {
assert.Equal(t, "beforetest2", string(dataList[0]))
assert.True(t, strings.HasPrefix(string(dataList[1]), "before"))
}
}

func runTCPOutput(wg *sync.WaitGroup, output PluginWriter, repeat int, initMessage bool) {
input := NewTestInput()
plugins := &InOutPlugins{
Inputs: []PluginReader{input},
Outputs: []PluginWriter{output},
}

emitter := NewEmitter()
go emitter.Start(plugins, Settings.Middleware)

if initMessage {
wg.Add(1)
}
for i := 0; i < repeat; i++ {
wg.Add(1)
input.EmitGET()
}

wg.Wait()
emitter.Close()
}

0 comments on commit fd62d69

Please sign in to comment.