diff --git a/api/go1.99999.txt b/api/go1.99999.txt index dd6cb2c37d580e..752e6be3447d0c 100644 --- a/api/go1.99999.txt +++ b/api/go1.99999.txt @@ -1,2 +1,8 @@ pkg net, func SetDialEnforcer(func(context.Context, []Addr) error) #55 pkg net, func SetResolveEnforcer(func(context.Context, string, string, string, Addr) error) #55 +pkg net, func WithSockTrace(context.Context, *SockTrace) context.Context #58 +pkg net, func ContextSockTrace(context.Context) *SockTrace #58 +pkg net, type SockTrace struct #58 +pkg net, type SockTrace struct, DidRead func(int) #58 +pkg net, type SockTrace struct, DidWrite func(int) #58 +pkg net, type SockTrace struct, WillOverwrite func(*SockTrace) #58 diff --git a/src/net/fd_posix.go b/src/net/fd_posix.go index ffb9bcf8b9e8ca..7f3aeff580c95d 100644 --- a/src/net/fd_posix.go +++ b/src/net/fd_posix.go @@ -24,6 +24,11 @@ type netFD struct { net string laddr Addr raddr Addr + + // hooks (if provided) are called after successful reads or writes with the + // number of bytes transferred. + readHook func(int) + writeHook func(int) } func (fd *netFD) setAddr(laddr, raddr Addr) { @@ -53,83 +58,125 @@ func (fd *netFD) closeWrite() error { func (fd *netFD) Read(p []byte) (n int, err error) { n, err = fd.pfd.Read(p) + if fd.readHook != nil && err == nil { + fd.readHook(n) + } runtime.KeepAlive(fd) return n, wrapSyscallError(readSyscallName, err) } func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { n, sa, err = fd.pfd.ReadFrom(p) + if fd.readHook != nil && err == nil { + fd.readHook(n) + } runtime.KeepAlive(fd) return n, sa, wrapSyscallError(readFromSyscallName, err) } func (fd *netFD) readFromInet4(p []byte, from *syscall.SockaddrInet4) (n int, err error) { n, err = fd.pfd.ReadFromInet4(p, from) + if fd.readHook != nil && err == nil { + fd.readHook(n) + } runtime.KeepAlive(fd) return n, wrapSyscallError(readFromSyscallName, err) } func (fd *netFD) readFromInet6(p []byte, from *syscall.SockaddrInet6) (n int, err error) { n, err = fd.pfd.ReadFromInet6(p, from) + if fd.readHook != nil && err == nil { + fd.readHook(n) + } runtime.KeepAlive(fd) return n, wrapSyscallError(readFromSyscallName, err) } func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) { n, oobn, retflags, sa, err = fd.pfd.ReadMsg(p, oob, flags) + if fd.readHook != nil && err == nil { + fd.readHook(n + oobn) + } runtime.KeepAlive(fd) return n, oobn, retflags, sa, wrapSyscallError(readMsgSyscallName, err) } func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) { n, oobn, retflags, err = fd.pfd.ReadMsgInet4(p, oob, flags, sa) + if fd.readHook != nil && err == nil { + fd.readHook(n + oobn) + } runtime.KeepAlive(fd) return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err) } func (fd *netFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) { n, oobn, retflags, err = fd.pfd.ReadMsgInet6(p, oob, flags, sa) + if fd.readHook != nil && err == nil { + fd.readHook(n + oobn) + } runtime.KeepAlive(fd) return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err) } func (fd *netFD) Write(p []byte) (nn int, err error) { nn, err = fd.pfd.Write(p) + if fd.writeHook != nil && err == nil { + fd.writeHook(nn) + } runtime.KeepAlive(fd) return nn, wrapSyscallError(writeSyscallName, err) } func (fd *netFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) { n, err = fd.pfd.WriteTo(p, sa) + if fd.writeHook != nil && err == nil { + fd.writeHook(n) + } runtime.KeepAlive(fd) return n, wrapSyscallError(writeToSyscallName, err) } func (fd *netFD) writeToInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) { n, err = fd.pfd.WriteToInet4(p, sa) + if fd.writeHook != nil && err == nil { + fd.writeHook(n) + } runtime.KeepAlive(fd) return n, wrapSyscallError(writeToSyscallName, err) } func (fd *netFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) { n, err = fd.pfd.WriteToInet6(p, sa) + if fd.writeHook != nil && err == nil { + fd.writeHook(n) + } runtime.KeepAlive(fd) return n, wrapSyscallError(writeToSyscallName, err) } func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { n, oobn, err = fd.pfd.WriteMsg(p, oob, sa) + if fd.writeHook != nil && err == nil { + fd.writeHook(n + oobn) + } runtime.KeepAlive(fd) return n, oobn, wrapSyscallError(writeMsgSyscallName, err) } func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) { n, oobn, err = fd.pfd.WriteMsgInet4(p, oob, sa) + if fd.writeHook != nil && err == nil { + fd.writeHook(n + oobn) + } runtime.KeepAlive(fd) return n, oobn, wrapSyscallError(writeMsgSyscallName, err) } func (fd *netFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) { n, oobn, err = fd.pfd.WriteMsgInet6(p, oob, sa) + if fd.writeHook != nil && err == nil { + fd.writeHook(n + oobn) + } runtime.KeepAlive(fd) return n, oobn, wrapSyscallError(writeMsgSyscallName, err) } diff --git a/src/net/sock_posix.go b/src/net/sock_posix.go index b3e1806ba9f48c..5c5340c1b756ff 100644 --- a/src/net/sock_posix.go +++ b/src/net/sock_posix.go @@ -28,6 +28,10 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only poll.CloseFunc(s) return nil, err } + if trace := ContextSockTrace(ctx); trace != nil { + fd.readHook = trace.DidRead + fd.writeHook = trace.DidWrite + } // This function makes a network file descriptor for the // following applications: diff --git a/src/net/socktrace.go b/src/net/socktrace.go new file mode 100644 index 00000000000000..57af2be8b75e3e --- /dev/null +++ b/src/net/socktrace.go @@ -0,0 +1,46 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "context" +) + +// SockTrace is a set of hooks to run at various operations on a network socket. +// Any particular hook may be nil. Functions may be called concurrently from +// different goroutines. +type SockTrace struct { + // DidRead is called after a successful read from the socket, where n bytes + // were read. + DidRead func(n int) + // DidWrite is called after a successful write to the socket, where n bytes + // were written. + DidWrite func(n int) + // WillOverwrite is called when the registered trace is overwritten by a + // subsequent call to WithSockTrace. The provided trace is the new trace + // that will be used. + WillOverwrite func(trace *SockTrace) +} + +// WithSockTrace returns a new context based on the provided parent +// ctx. Socket reads and writes made with the returned context will use +// the provided trace hooks. Any previous hooks registered with ctx are +// ovewritten (their WillOverwrite hook will be called). +func WithSockTrace(ctx context.Context, trace *SockTrace) context.Context { + if previous := ContextSockTrace(ctx); previous != nil && previous.WillOverwrite != nil { + previous.WillOverwrite(trace) + } + return context.WithValue(ctx, sockTraceKey{}, trace) +} + +// ContextSockTrace returns the SockTrace associated with the +// provided context. If none, it returns nil. +func ContextSockTrace(ctx context.Context) *SockTrace { + trace, _ := ctx.Value(sockTraceKey{}).(*SockTrace) + return trace +} + +// unique type to prevent assignment. +type sockTraceKey struct{}