Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support packet and process hook #178

Merged
merged 13 commits into from
Aug 16, 2024
2 changes: 1 addition & 1 deletion ext/pkg/control/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (n *CallNode) forward(proc *process.Process) {
}
n.tracer.Read(inReader, inPck)

n.tracer.AddHandler(inPck, packet.HandlerFunc(func(backPck *packet.Packet) {
n.tracer.AddHook(inPck, packet.HookFunc(func(backPck *packet.Packet) {
n.tracer.Transform(inPck, backPck)
if _, ok := backPck.Payload().(types.Error); ok {
n.tracer.Write(errWriter, backPck)
Expand Down
2 changes: 1 addition & 1 deletion ext/pkg/control/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (n *LoopNode) forward(proc *process.Process) {
n.tracer.Transform(inPck, outPck)
}

n.tracer.AddHandler(inPck, packet.HandlerFunc(func(backPck *packet.Packet) {
n.tracer.AddHook(inPck, packet.HookFunc(func(backPck *packet.Packet) {
n.tracer.Transform(inPck, backPck)
if _, ok := backPck.Payload().(types.Error); ok {
n.tracer.Write(errWriter, backPck)
Expand Down
2 changes: 1 addition & 1 deletion ext/pkg/control/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func (n *SessionNode) forward(proc *process.Process) {
}
}

n.tracer.AddHandler(inPck, packet.HandlerFunc(func(backPck *packet.Packet) {
n.tracer.AddHook(inPck, packet.HookFunc(func(backPck *packet.Packet) {
var err error
if v, ok := backPck.Payload().(types.Error); ok {
err = v.Unwrap()
Expand Down
16 changes: 0 additions & 16 deletions pkg/packet/handler.go

This file was deleted.

16 changes: 16 additions & 0 deletions pkg/packet/hook.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package packet

// Hook defines an interface for handling packets.
type Hook interface {
Handle(*Packet)
}

// HookFunc is a function type that implements the Handler interface.
type HookFunc func(*Packet)

var _ Hook = HookFunc(nil)

// Handle calls the underlying function represented by HandlerFunc with the provided packet.
func (f HookFunc) Handle(pck *Packet) {
f(pck)
}
60 changes: 53 additions & 7 deletions pkg/packet/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ import (

// Reader represents a packet reader that manages incoming packets from multiple writers.
type Reader struct {
writers []*Writer
in chan *Packet
out chan *Packet
done chan struct{}
mu sync.Mutex
writers []*Writer
in chan *Packet
out chan *Packet
done chan struct{}
inboundHooks []Hook
outboundHooks []Hook
mu sync.Mutex
}

// NewReader creates a new Reader instance and starts its processing loop.
Expand All @@ -37,7 +39,12 @@ func NewReader() *Reader {
if w := r.writer(); w == nil {
break
} else {
w.receive(New(types.NewError(ErrDroppedPacket)), r)
pck := New(types.NewError(ErrDroppedPacket))
if ok := w.receive(pck, r); ok {
for _, hook := range r.outboundHooks {
hook.Handle(pck)
}
}
}
}
return
Expand All @@ -63,6 +70,34 @@ func NewReader() *Reader {
return r
}

// AddInboundHook adds a handler to process inbound packets.
func (r *Reader) AddInboundHook(hook Hook) {
r.mu.Lock()
defer r.mu.Unlock()

for _, h := range r.inboundHooks {
if h == hook {
return
}
}

r.inboundHooks = append(r.inboundHooks, hook)
}

// AddOutboundHook adds a handler to process outbound packets.
func (r *Reader) AddOutboundHook(hook Hook) {
r.mu.Lock()
defer r.mu.Unlock()

for _, h := range r.outboundHooks {
if h == hook {
return
}
}

r.outboundHooks = append(r.outboundHooks, hook)
}

// Read returns the channel for reading packets from the reader.
func (r *Reader) Read() <-chan *Packet {
return r.out
Expand All @@ -73,7 +108,13 @@ func (r *Reader) Receive(pck *Packet) bool {
if w := r.writer(); w == nil {
return false
} else {
return w.receive(pck, r)
ok := w.receive(pck, r)
if ok {
for _, hook := range r.outboundHooks {
hook.Handle(pck)
}
}
return ok
}
}

Expand All @@ -99,6 +140,11 @@ func (r *Reader) write(pck *Packet, writer *Writer) bool {
default:
r.writers = append(r.writers, writer)
r.in <- pck

for _, hook := range r.inboundHooks {
hook.Handle(pck)
}

return true
}
}
Expand Down
32 changes: 32 additions & 0 deletions pkg/packet/reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,38 @@ import (
"github.com/stretchr/testify/assert"
)

func TestReader_AddHook(t *testing.T) {
w := NewWriter()
defer w.Close()

r := NewReader()
defer r.Close()

count := 0
r.AddInboundHook(HookFunc(func(_ *Packet) {
count += 1
}))
r.AddOutboundHook(HookFunc(func(_ *Packet) {
count += 1
}))

w.Link(r)

out := New(nil)

w.Write(out)

in := <-r.Read()
assert.Equal(t, 1, count)

r.Receive(in)
assert.Equal(t, 2, count)

back, ok := <-w.Receive()
assert.True(t, ok)
assert.Equal(t, in, back)
}

func TestReader_Receive(t *testing.T) {
w := NewWriter()
defer w.Close()
Expand Down
20 changes: 10 additions & 10 deletions pkg/packet/tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

// Tracer tracks the lifecycle and transformations of packets as they pass through readers and writers.
type Tracer struct {
handlers map[*Packet][]Handler
hooks map[*Packet][]Hook
sources map[*Packet][]*Packet
targets map[*Packet][]*Packet
receives map[*Packet][]*Packet
Expand All @@ -22,7 +22,7 @@ type Tracer struct {
// NewTracer initializes a new Tracer instance.
func NewTracer() *Tracer {
return &Tracer{
handlers: make(map[*Packet][]Handler),
hooks: make(map[*Packet][]Hook),
sources: make(map[*Packet][]*Packet),
targets: make(map[*Packet][]*Packet),
receives: make(map[*Packet][]*Packet),
Expand All @@ -32,12 +32,12 @@ func NewTracer() *Tracer {
}
}

// AddHandler adds a Handler to be invoked when a packet completes processing.
func (t *Tracer) AddHandler(pck *Packet, handler Handler) {
// AddHook adds a Handler to be invoked when a packet completes processing.
func (t *Tracer) AddHook(pck *Packet, hook Hook) {
t.mu.Lock()
defer t.mu.Unlock()

t.handlers[pck] = append(t.handlers[pck], handler)
t.hooks[pck] = append(t.hooks[pck], hook)
}

// Transform tracks the transformation of a source packet into a target packet.
Expand Down Expand Up @@ -132,7 +132,7 @@ func (t *Tracer) Close() {
reader.Receive(New(types.NewError(ErrDroppedPacket)))
}

t.handlers = make(map[*Packet][]Handler)
t.hooks = make(map[*Packet][]Hook)
t.sources = make(map[*Packet][]*Packet)
t.targets = make(map[*Packet][]*Packet)
t.receives = make(map[*Packet][]*Packet)
Expand Down Expand Up @@ -218,15 +218,15 @@ func (t *Tracer) handle(pck *Packet) {
return
}

if handlers := t.handlers[pck]; len(handlers) > 0 {
if hooks := t.hooks[pck]; len(hooks) > 0 {
merged := Merge(receives)

delete(t.handlers, pck)
delete(t.hooks, pck)
delete(t.receives, pck)

t.mu.Unlock()
for _, handler := range handlers {
handler.Handle(merged)
for _, hook := range hooks {
hook.Handle(merged)
}
t.mu.Lock()
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/packet/tracer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/assert"
)

func TestTracer_AddHandler(t *testing.T) {
func TestTracer_AddHook(t *testing.T) {
w1 := NewWriter()
defer w1.Close()

Expand Down Expand Up @@ -39,7 +39,7 @@ func TestTracer_AddHandler(t *testing.T) {
w2.Receive()

count := 0
tr.AddHandler(pck1, HandlerFunc(func(pck *Packet) {
tr.AddHook(pck1, HookFunc(func(pck *Packet) {
count++
}))

Expand Down
67 changes: 58 additions & 9 deletions pkg/packet/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ import (

// Writer represents a packet writer that sends packets to linked readers.
type Writer struct {
readers []*Reader
receives [][]*Packet
in chan *Packet
out chan *Packet
done chan struct{}
mu sync.Mutex
readers []*Reader
receives [][]*Packet
in chan *Packet
out chan *Packet
done chan struct{}
inboundHooks []Hook
outboundHooks []Hook
mu sync.Mutex
}

// Write sends a packet to the writer and returns the received packet or None if the write fails.
Expand Down Expand Up @@ -56,13 +58,21 @@ func NewWriter() *Writer {
case pck = <-w.in:
case <-w.done:
w.mu.Lock()

receives := w.receives
w.readers = nil
w.receives = nil

w.mu.Unlock()

for range receives {
w.out <- New(types.NewError(ErrDroppedPacket))
pck := New(types.NewError(ErrDroppedPacket))

for _, hook := range w.inboundHooks {
hook.Handle(pck)
}

w.out <- pck
}
return
}
Expand All @@ -87,6 +97,34 @@ func NewWriter() *Writer {
return w
}

// AddInboundHook adds a handler to process inbound packets.
func (w *Writer) AddInboundHook(hook Hook) {
w.mu.Lock()
defer w.mu.Unlock()

for _, h := range w.inboundHooks {
if h == hook {
return
}
}

w.inboundHooks = append(w.inboundHooks, hook)
}

// AddOutboundHook adds a handler to process outbound packets.
func (w *Writer) AddOutboundHook(hook Hook) {
w.mu.Lock()
defer w.mu.Unlock()

for _, h := range w.outboundHooks {
if h == hook {
return
}
}

w.outboundHooks = append(w.outboundHooks, hook)
}

// Link connects a reader to the writer.
func (w *Writer) Link(reader *Reader) {
w.mu.Lock()
Expand Down Expand Up @@ -118,7 +156,13 @@ func (w *Writer) Write(pck *Packet) int {
}
}

w.receives = append(w.receives, receives)
if count > 0 {
w.receives = append(w.receives, receives)

for _, hook := range w.outboundHooks {
hook.Handle(pck)
}
}

return count
}
Expand Down Expand Up @@ -170,8 +214,13 @@ func (w *Writer) receive(pck *Packet, reader *Reader) bool {
}

w.receives = w.receives[1:]
pck := Merge(receives)

w.in <- pck

w.in <- Merge(receives)
for _, hook := range w.inboundHooks {
hook.Handle(pck)
}
}

return true
Expand Down
Loading
Loading