Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

netlink: add ExecuteFunc method to process large responses with low mem overhead #214

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 95 additions & 30 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,46 @@ func (c *Conn) Execute(m Message) ([]Message, error) {
return res, nil
}

// ExecuteFunc sends a single Message to netlink using Send and passes one or more
// replies obtained by Receive to the provided callback func after checking each
// reply for validity using Validate.
//
// ExecuteFunc acquires a lock for the duration of the function call which blocks
// concurrent calls to Send, SendMessages, and Receive, in order to ensure
// consistency between netlink request/reply messages - do not call methods on this
// Conn from the provided callback func or your application may deadlock.
//
// See the documentation of Send, Receive, and Validate for details about
// each function.
func (c *Conn) ExecuteFunc(m Message, cb func(Message)) error {
// Acquire the write lock and invoke the internal implementations of Send
// and Receive which require the lock already be held.
c.mu.Lock()
defer c.mu.Unlock()

req, err := c.lockedSend(m)
if err != nil {
return err
}

var validateErr error
err = c.lockedReceiveEach(func(m Message) {
if err := Validate(req, []Message{m}); err != nil {
validateErr = err
return
}
cb(m)
})
if err != nil {
return err
}
if validateErr != nil {
return validateErr
}

return nil
}

// SendMessages sends multiple Messages to netlink. The handling of
// a Header's Length, Sequence and PID fields is the same as when
// calling Send.
Expand Down Expand Up @@ -231,11 +271,27 @@ func (c *Conn) Receive() ([]Message, error) {
return c.lockedReceive()
}

// ReceiveEach receives one or more messages from netlink. Multi-part messages are
// handled transparently and a callback invoked for ech message, with the
// final empty "multi-part done" message removed.
//
// If any of the messages indicate a netlink error, that error will be returned.
func (c *Conn) ReceiveEach(cb func(Message)) error {
// Wait for any concurrent calls to Execute to finish before proceeding.
c.mu.RLock()
defer c.mu.RUnlock()

return c.lockedReceiveEach(cb)
}

// lockedReceive implements Receive, but must be called with c.mu acquired for reading.
// We rely on the kernel to deal with concurrent reads and writes to the netlink
// socket itself.
func (c *Conn) lockedReceive() ([]Message, error) {
msgs, err := c.receive()
var msgs []Message
err := c.receiveEach(func(m Message) {
msgs = append(msgs, m)
})
if err != nil {
c.debug(func(d *debugger) {
d.debugf(1, "recv: err: %v", err)
Expand All @@ -250,64 +306,73 @@ func (c *Conn) lockedReceive() ([]Message, error) {
}
})

// When using nltest, it's possible for zero messages to be returned by receive.
if len(msgs) == 0 {
return msgs, nil
}
return msgs, nil
}

// lockedReceive implements Receive, but must be called with c.mu acquired for reading.
// We rely on the kernel to deal with concurrent reads and writes to the netlink
// socket itself.
func (c *Conn) lockedReceiveEach(cb func(Message)) error {
err := c.receiveEach(func(m Message) {
cb(m)

c.debug(func(d *debugger) {
d.debugf(1, "recv: %+v", m)
})
})
if err != nil {
c.debug(func(d *debugger) {
d.debugf(1, "recv: err: %v", err)
})

// Trim the final message with multi-part done indicator if
// present.
if m := msgs[len(msgs)-1]; m.Header.Flags&Multi != 0 && m.Header.Type == Done {
return msgs[:len(msgs)-1], nil
return err
}

return msgs, nil
return nil
}

// receive is the internal implementation of Conn.Receive, which can be called
// recursively to handle multi-part messages.
func (c *Conn) receive() ([]Message, error) {
func (c *Conn) receiveEach(cb func(Message)) error {
// NB: All non-nil errors returned from this function *must* be of type
// OpError in order to maintain the appropriate contract with callers of
// this package.
//
// This contract also applies to functions called within this function,
// such as checkMessage.

var res []Message
for {
for more := true; more; {
msgs, err := c.sock.Receive()
if err != nil {
return nil, newOpError("receive", err)
return newOpError("receive", err)
}

// If this message is multi-part, we will need to continue looping to
// drain all the messages from the socket.
var multi bool

more = false
multipartDone := false
for _, m := range msgs {
if err := checkMessage(m); err != nil {
return nil, err
return err
}

// Does this message indicate a multi-part message?
if m.Header.Flags&Multi == 0 {
// No, check the next messages.
continue
if m.Header.Flags&Multi != 0 {
multipartDone = m.Header.Type == Done
more = !multipartDone
}

// Does this message indicate the last message in a series of
// multi-part messages from a single read?
multi = m.Header.Type != Done
}

res = append(res, msgs...)
// Trim the final message with multi-part done indicator if
// present.
if multipartDone {
msgs = msgs[:len(msgs)-1]
}

if !multi {
// No more messages coming.
return res, nil
for _, m := range msgs {
cb(m)
}
}

return nil
}

// A groupJoinLeaver is a Socket that supports joining and leaving
Expand Down
70 changes: 70 additions & 0 deletions conn_linux_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,76 @@ func TestIntegrationConn(t *testing.T) {
}
}

func TestIntegrationConnFunc(t *testing.T) {
t.Parallel()

c, err := netlink.Dial(unix.NETLINK_GENERIC, nil)
if err != nil {
t.Fatalf("failed to dial netlink: %v", err)
}

// Ask to send us an acknowledgement, which will contain an
// error code (or success) and a copy of the payload we sent in
req := netlink.Message{
Header: netlink.Header{
Flags: netlink.Request | netlink.Acknowledge,
},
}

// Perform a request using ExecuteFunc, receive replies, and validate the replies
var msgs []netlink.Message
err = c.ExecuteFunc(req, func(m netlink.Message) {
msgs = append(msgs, m)
})
if err != nil {
t.Fatalf("failed to execute request: %v", err)
}
if want, got := 1, len(msgs); want != got {
t.Fatalf("unexpected message count from netlink:\n- want: %v\n- got: %v",
want, got)
}

if err := c.Close(); err != nil {
t.Fatalf("error closing netlink connection: %v", err)
}

m := msgs[0]

if want, got := 0, int(nlenc.Uint32(m.Data[0:4])); want != got {
t.Fatalf("unexpected error code:\n- want: %v\n- got: %v", want, got)
}

if want, got := 36, int(m.Header.Length); want != got {
t.Fatalf("unexpected header length:\n- want: %v\n- got: %v", want, got)
}
if want, got := netlink.Error, m.Header.Type; want != got {
t.Fatalf("unexpected header type:\n- want: %v\n- got: %v", want, got)
}
// Recent kernel versions (> 4.14) return a 256 here instead of a 0
if want, wantAlt, got := 0, 256, int(m.Header.Flags); want != got && wantAlt != got {
t.Fatalf("unexpected header flags:\n- want: %v or %v\n- got: %v", want, wantAlt, got)
}

// Sequence number is not checked because we assign one at random when
// a Conn is created. PID is not checked because running tests in parallel
// results in only the first socket getting assigned the process's PID as
// its netlink PID.

// Skip error code and unmarshal the copy of request sent back by
// skipping the success code at bytes 0-4
var reply netlink.Message
if err := (&reply).UnmarshalBinary(m.Data[4:]); err != nil {
t.Fatalf("failed to unmarshal reply: %v", err)
}

if want, got := req.Header.Flags, reply.Header.Flags; want != got {
t.Fatalf("unexpected copy header flags:\n- want: %v\n- got: %v", want, got)
}
if want, got := len(req.Data), len(reply.Data); want != got {
t.Fatalf("unexpected copy header data length:\n- want: %v\n- got: %v", want, got)
}
}

func TestIntegrationConnConcurrentManyConns(t *testing.T) {
t.Parallel()
skipShort(t)
Expand Down