Skip to content

Commit

Permalink
fix(selftest): use channel for sync (#364)
Browse files Browse the repository at this point in the history
attachgenericfd and perfbuffers selftests were using a goroutine without
any syncronization mechanism.
  • Loading branch information
geyslan authored Aug 8, 2023
1 parent d3f72ae commit 6ccba02
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 16 deletions.
32 changes: 29 additions & 3 deletions selftest/attachgenericfd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package main
import "C"

import (
"bytes"
"fmt"
"net"
"os"
"time"
"unsafe"

bpf "github.com/aquasecurity/libbpfgo"
Expand All @@ -32,6 +34,7 @@ func main() {
os.Exit(-1)
}
defer unix.Close(serverFD)

serverAddr := &unix.SockaddrInet4{
Port: 22345,
Addr: [4]byte{127, 0, 0, 1},
Expand All @@ -40,6 +43,7 @@ func main() {
fmt.Fprintln(os.Stderr, err)
os.Exit(-1)
}

if err := unix.Listen(serverFD, 100); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(-1)
Expand All @@ -60,6 +64,7 @@ func main() {
os.Exit(-1)
}
}()

prog2, err := bpfModule.GetProgram("bpf_prog_verdict")
prog2.AttachGenericFD(sockMapRx.FileDescriptor(),
bpf.BPFAttachTypeSKSKBStreamVerdict, bpf.BPFFNone)
Expand All @@ -70,6 +75,8 @@ func main() {
}
}()

mapUpdateChan := make(chan struct{}, 1)

go func() {
acceptedFD, _, err := unix.Accept(serverFD)
if err != nil {
Expand All @@ -82,6 +89,8 @@ func main() {
fmt.Fprintln(os.Stderr, err)
os.Exit(-1)
}

mapUpdateChan <- struct{}{}
}()

c, err := net.Dial("tcp", "127.0.0.1:22345")
Expand All @@ -90,13 +99,30 @@ func main() {
os.Exit(-1)
}
defer c.Close()
if _, err = c.Write([]byte("foobar")); err != nil {

// wait for the bpf map to be updated
select {
case <-mapUpdateChan:
// continue with write/read
case <-time.After(15 * time.Second): // Same of the selftest
fmt.Fprintln(os.Stderr, "bpf map timeout")
os.Exit(-1)
}

input := []byte("foobar")
if _, err = c.Write(input); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(-1)
}
data := make([]byte, 10)
if _, err = c.Read(data); err != nil {

output := make([]byte, 6)
if _, err = c.Read(output); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(-1)
}

if !bytes.Equal(output, input) {
fmt.Fprintln(os.Stderr, "data mismatch")
os.Exit(-1)
}
}
33 changes: 20 additions & 13 deletions selftest/perfbuffers/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import "C"
import (
"os"
"runtime"
"time"

"encoding/binary"
"fmt"
Expand Down Expand Up @@ -67,26 +68,32 @@ func main() {

pb.Poll(300)

numberOfEventsReceived := 0
stop := make(chan struct{})

go func() {
for {
syscall.Mmap(999, 999, 999, 1, 1)
select {
case <-stop:
return
case b := <-eventsChannel:
if binary.LittleEndian.Uint32(b) != 2021 {
fmt.Fprintf(os.Stderr, "invalid data retrieved\n")
os.Exit(-1)
}
}
}
}()
recvLoop:
for {
b := <-eventsChannel
if binary.LittleEndian.Uint32(b) != 2021 {
fmt.Fprintf(os.Stderr, "invalid data retrieved\n")
os.Exit(-1)
}
numberOfEventsReceived++
if numberOfEventsReceived > 5 {
break recvLoop
}

// give some time for the upper goroutine to start
time.Sleep(100 * time.Millisecond)

for sent := 0; sent < 5; sent++ {
syscall.Mmap(999, 999, 999, 1, 1)
time.Sleep(100 * time.Millisecond)
}

close(stop)

// Test that it won't cause a panic or block if Stop or Close called multiple times
pb.Stop()
pb.Stop()
Expand Down

0 comments on commit 6ccba02

Please sign in to comment.