diff --git a/conn.go b/conn.go index 7138665..da84fa6 100644 --- a/conn.go +++ b/conn.go @@ -53,6 +53,7 @@ type Socket interface { Send(m Message) error SendMessages(m []Message) error Receive() ([]Message, error) + ReceiveToBuf(buf []byte) ([]Message, error) } // Dial dials a connection to netlink, using the specified netlink family. @@ -135,7 +136,7 @@ func (c *Conn) Execute(m Message) ([]Message, error) { return nil, err } - res, err := c.lockedReceive() + res, err := c.lockedReceive(nil) if err != nil { return nil, err } @@ -228,14 +229,27 @@ func (c *Conn) Receive() ([]Message, error) { c.mu.RLock() defer c.mu.RUnlock() - return c.lockedReceive() + return c.lockedReceive(nil) +} + +// ReceiveToBuf receives one or more messages from netlink to buffer. Multi-part messages are +// handled transparently and returned as a single slice of Messages, with the +// final empty "multi-part done" message removed. +// +// If any of the messages indicate a netlink error, that error will be returned. +func (c *Conn) ReceiveToBuf(buf []byte) ([]Message, error) { + // Wait for any concurrent calls to Execute to finish before proceeding. + c.mu.RLock() + defer c.mu.RUnlock() + + return c.lockedReceive(buf) } // lockedReceive implements Receive, but must be called with c.mu acquired for reading. // We rely on the kernel to deal with concurrent reads and writes to the netlink // socket itself. -func (c *Conn) lockedReceive() ([]Message, error) { - msgs, err := c.receive() +func (c *Conn) lockedReceive(buf []byte) ([]Message, error) { + msgs, err := c.receive(buf) if err != nil { c.debug(func(d *debugger) { d.debugf(1, "recv: err: %v", err) @@ -266,7 +280,7 @@ func (c *Conn) lockedReceive() ([]Message, error) { // receive is the internal implementation of Conn.Receive, which can be called // recursively to handle multi-part messages. -func (c *Conn) receive() ([]Message, error) { +func (c *Conn) receive(buf []byte) ([]Message, error) { // NB: All non-nil errors returned from this function *must* be of type // OpError in order to maintain the appropriate contract with callers of // this package. @@ -276,7 +290,14 @@ func (c *Conn) receive() ([]Message, error) { var res []Message for { - msgs, err := c.sock.Receive() + var msgs []Message + var err error + + if len(buf) == 0 { + msgs, err = c.sock.Receive() + } else { + msgs, err = c.sock.ReceiveToBuf(buf) + } if err != nil { return nil, newOpError("receive", err) } diff --git a/conn_linux.go b/conn_linux.go index 4af18c9..85844f4 100644 --- a/conn_linux.go +++ b/conn_linux.go @@ -165,6 +165,38 @@ func (c *conn) Receive() ([]Message, error) { return msgs, nil } +// ReceiveToBuf receives one or more Messages from netlink to buffer. +// Buffer should be aligned on a 4-byte boundary. +func (c *conn) ReceiveToBuf(buf []byte) ([]Message, error) { + // Read out all available messages + n, _, _, _, err := c.s.Recvmsg(context.Background(), buf, nil, 0) + if err != nil { + return nil, err + } + + alignedLen := nlmsgAlign(n) + if alignedLen > len(buf) { + alignedLen -= nlmsgAlignTo + } + + raw, err := syscall.ParseNetlinkMessage(buf[:alignedLen]) + if err != nil { + return nil, err + } + + msgs := make([]Message, 0, len(raw)) + for _, r := range raw { + m := Message{ + Header: sysToHeader(r.Header), + Data: r.Data, + } + + msgs = append(msgs, m) + } + + return msgs, nil +} + // Close closes the connection. func (c *conn) Close() error { return c.s.Close() } diff --git a/conn_others.go b/conn_others.go index 4c5e739..51a8fab 100644 --- a/conn_others.go +++ b/conn_others.go @@ -24,7 +24,8 @@ type conn struct{} func dial(_ int, _ *Config) (*conn, uint32, error) { return nil, 0, errUnimplemented } func newError(_ int) error { return errUnimplemented } -func (c *conn) Send(_ Message) error { return errUnimplemented } -func (c *conn) SendMessages(_ []Message) error { return errUnimplemented } -func (c *conn) Receive() ([]Message, error) { return nil, errUnimplemented } -func (c *conn) Close() error { return errUnimplemented } +func (c *conn) Send(_ Message) error { return errUnimplemented } +func (c *conn) SendMessages(_ []Message) error { return errUnimplemented } +func (c *conn) Receive() ([]Message, error) { return nil, errUnimplemented } +func (c *conn) ReceiveToBuf(buf []byte) ([]Message, error) { return nil, errUnimplemented } +func (c *conn) Close() error { return errUnimplemented } diff --git a/nltest/nltest.go b/nltest/nltest.go index 2065bab..19fe052 100644 --- a/nltest/nltest.go +++ b/nltest/nltest.go @@ -202,6 +202,10 @@ func (c *socket) Receive() ([]netlink.Message, error) { return msgs, err } +func (c *socket) ReceiveToBuf(buf []byte) ([]netlink.Message, error) { + return nil, nil +} + func panicf(format string, a ...interface{}) { panic(fmt.Sprintf(format, a...)) }