Skip to content

Commit

Permalink
refactor: make single ringbuff reader for each client
Browse files Browse the repository at this point in the history
Signed-off-by: Sarthak160 <rocksarthak45@gmail.com>
  • Loading branch information
Sarthak160 committed Nov 4, 2024
1 parent 444f317 commit 40bdf9a
Show file tree
Hide file tree
Showing 13 changed files with 119 additions and 56 deletions.
22 changes: 19 additions & 3 deletions pkg/agent/hooks/conn/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func NewFactory(inactivityThreshold time.Duration, logger *zap.Logger) *Factory

// ProcessActiveTrackers iterates over all conn the trackers and checks if they are complete. If so, it captures the ingress call and
// deletes the tracker. If the tracker is inactive for a long time, it deletes it.
func (factory *Factory) ProcessActiveTrackers(ctx context.Context, t chan *models.TestCase, opts models.IncomingOptions) {
func (factory *Factory) ProcessActiveTrackers(ctx context.Context, testMap *sync.Map, opts models.IncomingOptions) {
factory.mutex.Lock()
defer factory.mutex.Unlock()
var trackersToDelete []ID
Expand All @@ -52,7 +52,7 @@ func (factory *Factory) ProcessActiveTrackers(ctx context.Context, t chan *model
case <-ctx.Done():
return
default:
ok, requestBuf, responseBuf, reqTimestampTest, resTimestampTest := tracker.IsComplete()
ok, requestBuf, responseBuf, reqTimestampTest, resTimestampTest, clientId := tracker.IsComplete()

Check failure on line 55 in pkg/agent/hooks/conn/factory.go

View workflow job for this annotation

GitHub Actions / lint

var-naming: var clientId should be clientID (revive)

Check warning on line 55 in pkg/agent/hooks/conn/factory.go

View workflow job for this annotation

GitHub Actions / lint

var-naming: var clientId should be clientID (revive)
if ok {
fmt.Println("Processing the tracker with key: ", connID)
fmt.Println("Request Buffer::::::::: ", string(requestBuf))
Expand All @@ -71,7 +71,23 @@ func (factory *Factory) ProcessActiveTrackers(ctx context.Context, t chan *model
utils.LogError(factory.logger, err, "failed to parse the http response from byte array", zap.Any("responseBuf", responseBuf))
continue
}
capture(ctx, factory.logger, t, parsedHTTPReq, parsedHTTPRes, reqTimestampTest, resTimestampTest, opts)

//get the channel from the test map
// failed to get the channel from the test map, if the client id is not found
t, ok := testMap.Load(clientId)
if !ok {
factory.logger.Error("failed to get the channel from the test map")
continue
}

// type assert the channel
tc, ok := t.(chan *models.TestCase)
if !ok {
factory.logger.Error("failed to type assert the channel from the test map")
continue
}

capture(ctx, factory.logger, tc, parsedHTTPReq, parsedHTTPRes, reqTimestampTest, resTimestampTest, opts)

} else if tracker.IsInactive(factory.inactivityThreshold) {
trackersToDelete = append(trackersToDelete, connID)
Expand Down
47 changes: 29 additions & 18 deletions pkg/agent/hooks/conn/socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"os"
"sync"
"time"
"unsafe"

Expand All @@ -26,17 +27,17 @@ import (
var eventAttributesSize = int(unsafe.Sizeof(SocketDataEvent{}))

// ListenSocket starts the socket event listeners
func ListenSocket(ctx context.Context, l *zap.Logger, clientID uint64, openMap, dataMap, closeMap *ebpf.Map, opts models.IncomingOptions) (<-chan *models.TestCase, error) {
t := make(chan *models.TestCase, 500)
func ListenSocket(ctx context.Context, l *zap.Logger, clientID uint64, testMap *sync.Map, openMap, dataMap, closeMap *ebpf.Map, opts models.IncomingOptions) error {

err := initRealTimeOffset()
if err != nil {
utils.LogError(l, err, "failed to initialize real time offset")
return nil, errors.New("failed to start socket listeners")
return errors.New("failed to start socket listeners")
}
c := NewFactory(time.Minute, l)
g, ok := ctx.Value(models.ErrGroupKey).(*errgroup.Group)
if !ok {
return nil, errors.New("failed to get the error group from the context")
return errors.New("failed to get the error group from the context")
}
fmt.Println("Starting the socket listener", c.connections)
g.Go(func() error {
Expand All @@ -50,32 +51,44 @@ func ListenSocket(ctx context.Context, l *zap.Logger, clientID uint64, openMap,
return
default:
// TODO refactor this to directly consume the events from the maps
c.ProcessActiveTrackers(ctx, t, opts)
c.ProcessActiveTrackers(ctx, testMap, opts)
time.Sleep(100 * time.Millisecond)
}
}
}()
<-ctx.Done()
close(t)

//get the channel from test map and close it
t, ok := testMap.Load(clientID)
if ok {
tc, ok := t.(chan *models.TestCase)
if ok {
// Close the channel when the context is done
close(tc)
} else {
println("Failed to type assert the channel from the test map")
}
}

return nil
})

err = open(ctx, c, l, openMap)
if err != nil {
utils.LogError(l, err, "failed to start open socket listener")
return nil, errors.New("failed to start socket listeners")
return errors.New("failed to start socket listeners")
}
err = data(ctx, clientID, c, l, dataMap)
if err != nil {
utils.LogError(l, err, "failed to start data socket listener")
return nil, errors.New("failed to start socket listeners")
return errors.New("failed to start socket listeners")
}
err = exit(ctx, c, l, closeMap)
if err != nil {
utils.LogError(l, err, "failed to start close socket listener")
return nil, errors.New("failed to start socket listeners")
return errors.New("failed to start socket listeners")
}
return t, err
return err
}

func open(ctx context.Context, c *Factory, l *zap.Logger, m *ebpf.Map) error {
Expand Down Expand Up @@ -142,10 +155,8 @@ func data(ctx context.Context, id uint64, c *Factory, l *zap.Logger, m *ebpf.Map
return errors.New("failed to get the error group from the context")
}
g.Go(func() error {
fmt.Println("INSIDE GOROUTINE !!")
defer utils.Recover(l)
go func() {
fmt.Println("INSIDE GOROUTINE 222 !!")
defer utils.Recover(l)
for {
record, err := r.Read()
Expand Down Expand Up @@ -180,14 +191,14 @@ func data(ctx context.Context, id uint64, c *Factory, l *zap.Logger, m *ebpf.Map
l.Debug(fmt.Sprintf("Request EntryTimestamp :%v\n", convertUnixNanoToTime(event.EntryTimestampNano)))
}

if event.ClientID != id {
// log the expected client id and the received client id
l.Info(fmt.Sprintf("Expected ClientID: %v, Received ClientID: %v", id, event.ClientID))
continue
}
// if event.ClientID != id {
// // log the expected client id and the received client id
// l.Info(fmt.Sprintf("Expected ClientID: %v, Received ClientID: %v", id, event.ClientID))

// continue
// }

fmt.Println("SocketDataEvent-1: ", event.ClientID)
fmt.Println("SocketDataEvent-2: ", event.ConnID)

c.GetOrCreate(event.ConnID).AddDataEvent(event)
}
Expand Down
18 changes: 16 additions & 2 deletions pkg/agent/hooks/conn/tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ type Tracker struct {

reqTimestamps []time.Time
isNewRequest bool

//Client Id's array
clientIds []uint64

Check failure on line 71 in pkg/agent/hooks/conn/tracker.go

View workflow job for this annotation

GitHub Actions / lint

var-naming: struct field clientIds should be clientIDs (revive)

Check warning on line 71 in pkg/agent/hooks/conn/tracker.go

View workflow job for this annotation

GitHub Actions / lint

var-naming: struct field clientIds should be clientIDs (revive)
}

func NewTracker(connID ID, logger *zap.Logger) *Tracker {
Expand Down Expand Up @@ -107,7 +110,7 @@ func (conn *Tracker) decRecordTestCount() {
}

// IsComplete checks if the current conn has valid request & response info to capture and also returns the request and response data buffer.
func (conn *Tracker) IsComplete() (bool, []byte, []byte, time.Time, time.Time) {
func (conn *Tracker) IsComplete() (bool, []byte, []byte, time.Time, time.Time, uint64) {
conn.mutex.Lock()
defer conn.mutex.Unlock()

Expand Down Expand Up @@ -234,6 +237,7 @@ func (conn *Tracker) IsComplete() (bool, []byte, []byte, time.Time, time.Time) {
conn.logger.Debug("unverified recording", zap.Any("recordTraffic", recordTraffic))
}

var clientID uint64
// Checking if record traffic is recorded and request & response timestamp is captured or not.
if recordTraffic {
if len(conn.reqTimestamps) > 0 {
Expand All @@ -253,9 +257,16 @@ func (conn *Tracker) IsComplete() (bool, []byte, []byte, time.Time, time.Time) {
}

conn.logger.Debug(fmt.Sprintf("TestRequestTimestamp:%v || TestResponseTimestamp:%v", reqTimestamps, respTimestamp))

//popping out the client id
if len(conn.clientIds) > 0 {
clientID = conn.clientIds[0]
conn.clientIds = conn.clientIds[1:]
}

}

return recordTraffic, requestBuf, responseBuf, reqTimestamps, respTimestamp
return recordTraffic, requestBuf, responseBuf, reqTimestamps, respTimestamp, clientID
}

// reset resets the conn's request and response data buffers.
Expand Down Expand Up @@ -302,6 +313,9 @@ func (conn *Tracker) AddDataEvent(event SocketDataEvent) {
// This is to ensure that we capture the response timestamp for the first chunk of the response.
if !conn.isNewRequest {
conn.isNewRequest = true

// set the client id
conn.clientIds = append(conn.clientIds, event.ClientID)
}

// Assign the size of the message to the variable msgLengt
Expand Down
21 changes: 17 additions & 4 deletions pkg/agent/hooks/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func NewHooks(logger *zap.Logger, cfg *config.Config) *Hooks {
proxyIP6: [4]uint32{0000, 0000, 0000, 0001},
proxyPort: cfg.ProxyPort,
dnsPort: cfg.DNSPort,
TestMap: &sync.Map{},
}
}

Expand All @@ -49,8 +50,9 @@ type Hooks struct {
proxyIP6 [4]uint32
proxyPort uint32
dnsPort uint32

m sync.Mutex
TestMap *sync.Map
once sync.Once
m sync.Mutex
// eBPF C shared maps
clientRegistrationMap *ebpf.Map
agentRegistartionMap *ebpf.Map
Expand Down Expand Up @@ -479,8 +481,19 @@ func (h *Hooks) load(opts agent.HookCfg) error {
}

func (h *Hooks) Record(ctx context.Context, clientID uint64, opts models.IncomingOptions) (<-chan *models.TestCase, error) {
fmt.Println("Recording hooks...")
return conn.ListenSocket(ctx, h.logger, clientID, h.objects.SocketOpenEvents, h.objects.SocketDataEvents, h.objects.SocketCloseEvents, opts)
// clientId and <-chan *models.TestCase ka map
tc := make(chan *models.TestCase, 1)
// create a sync map with key clientId and t as value
// this map will be used to store the test cases for each client
h.TestMap.Store(clientID, tc)

err := conn.ListenSocket(ctx, h.logger, clientID, h.TestMap, h.objects.SocketOpenEvents, h.objects.SocketDataEvents, h.objects.SocketCloseEvents, opts)
if err != nil {
return nil, err
}

// return the receiver of the channel
return tc, nil
}

func (h *Hooks) SendKeployClientInfo(clientID uint64, clientInfo structs.ClientInfo) error {
Expand Down
8 changes: 4 additions & 4 deletions pkg/agent/routes/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,11 @@ func (a *AgentRequest) RegisterClients(w http.ResponseWriter, r *http.Request) {
func (a *AgentRequest) DeRegisterClients(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")

var OutgoingReq models.OutgoingReq
err := json.NewDecoder(r.Body).Decode(&OutgoingReq)
var UnregisterReq models.UnregisterReq
err := json.NewDecoder(r.Body).Decode(&UnregisterReq)

mockRes := models.AgentResp{
ClientID: OutgoingReq.ClientID,
ClientID: UnregisterReq.ClientID,
Error: nil,
IsSuccess: true,
}
Expand All @@ -201,7 +201,7 @@ func (a *AgentRequest) DeRegisterClients(w http.ResponseWriter, r *http.Request)
return
}

err = a.agent.DeRegisterClient(r.Context(), OutgoingReq.ClientID)
err = a.agent.DeRegisterClient(r.Context(), UnregisterReq)
if err != nil {
mockRes.Error = err
mockRes.IsSuccess = false
Expand Down
1 change: 1 addition & 0 deletions pkg/models/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ type SetMocksReq struct {
}
type UnregisterReq struct {
ClientID uint64 `json:"clientId"`
Mode Mode `json:"mode"`
}
13 changes: 6 additions & 7 deletions pkg/platform/http/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -581,18 +581,16 @@ func (a *AgentClient) RegisterClient(ctx context.Context, opts models.SetupOptio
return nil
}

func (a *AgentClient) UnregisterClient(ctx context.Context, clientID uint64) error {
func (a *AgentClient) UnregisterClient(ctx context.Context, unregister models.UnregisterReq) error {
// Unregister the client with the server
isAgentRunning := a.isAgentRunning(context.Background())
if !isAgentRunning {
a.logger.Warn("keploy agent is not running, skipping unregister client")
return io.EOF
}
requestBody := models.UnregisterReq{
ClientID: clientID,
}

fmt.Println("Unregistering the client with the server")
requestJSON, err := json.Marshal(requestBody)
requestJSON, err := json.Marshal(unregister)
if err != nil {
utils.LogError(a.logger, err, "failed to marshal request body for unregister client")
return fmt.Errorf("error marshaling request body for unregister client: %s", err.Error())
Expand All @@ -605,9 +603,10 @@ func (a *AgentClient) UnregisterClient(ctx context.Context, clientID uint64) err
}
req.Header.Set("Content-Type", "application/json")

// Make the HTTP request
resp, err := a.client.Do(req)

if err != nil {
return fmt.Errorf("failed to send request for unregister client: %s", err.Error())
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to unregister client: %s", resp.Status)
}
Expand Down
14 changes: 9 additions & 5 deletions pkg/service/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,17 @@ func (a *Agent) GetConsumedMocks(ctx context.Context, id uint64) ([]string, erro
return a.Proxy.GetConsumedMocks(ctx, id)
}

func (a *Agent) DeRegisterClient(ctx context.Context, id uint64) error {
func (a *Agent) DeRegisterClient(ctx context.Context, unregister models.UnregisterReq) error {
fmt.Println("Inside DeRegisterClient of agent binary !!")
err := a.Proxy.MakeClientDeRegisterd(ctx)
if err != nil {
return err
// send the info of the mode if its test mode we dont need to send the last mock

if unregister.Mode != models.MODE_TEST {
err := a.Proxy.MakeClientDeRegisterd(ctx)
if err != nil {
return err
}
}
err = a.Hooks.DeleteKeployClientInfo(id)
err := a.Hooks.DeleteKeployClientInfo(unregister.ClientID)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/service/agent/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type Service interface {
SetMocks(ctx context.Context, id uint64, filtered []*models.Mock, unFiltered []*models.Mock) error
GetConsumedMocks(ctx context.Context, id uint64) ([]string, error)
RegisterClient(ctx context.Context, opts models.SetupOptions) error
DeRegisterClient(ctx context.Context, id uint64) error
DeRegisterClient(ctx context.Context, opts models.UnregisterReq) error
}

type Options struct {
Expand Down
7 changes: 6 additions & 1 deletion pkg/service/record/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,12 @@ func (r *Recorder) Start(ctx context.Context, reRecord bool) error {
}
}

err := r.instrumentation.UnregisterClient(ctx, clientID)
unregister := models.UnregisterReq{
ClientID: clientID,
Mode: models.MODE_RECORD,
}

err := r.instrumentation.UnregisterClient(ctx, unregister)
if err != nil && err != io.EOF {
utils.LogError(r.logger, err, "failed to unregister client")
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/service/record/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type Instrumentation interface {
// Run is blocking call and will execute until error
Run(ctx context.Context, id uint64, opts models.RunOptions) models.AppError
GetContainerIP(ctx context.Context, id uint64) (string, error)
UnregisterClient(ctx context.Context, clientID uint64) error
UnregisterClient(ctx context.Context, opts models.UnregisterReq) error
}

type Service interface {
Expand Down
Loading

0 comments on commit 40bdf9a

Please sign in to comment.