Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Unix forwarding server implementations #196

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestPasswordAuth(t *testing.T) {

func TestPasswordAuthBadPass(t *testing.T) {
t.Parallel()
l := newLocalListener()
l := newLocalTCPListener()
srv := &Server{Handler: func(s Session) {}}
srv.SetOption(PasswordAuth(func(ctx Context, password string) bool {
return false
Expand Down
2 changes: 2 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ type Server struct {
ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil
LocalUnixForwardingCallback LocalUnixForwardingCallback // callback for allowing local unix forwarding (direct-streamlocal@openssh.com), denies all if nil
ReverseUnixForwardingCallback ReverseUnixForwardingCallback // callback for allowing reverse unix forwarding (streamlocal-forward@openssh.com), denies all if nil
ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options
SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions

Expand Down
4 changes: 2 additions & 2 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestAddHostKey(t *testing.T) {
}

func TestServerShutdown(t *testing.T) {
l := newLocalListener()
l := newLocalTCPListener()
testBytes := []byte("Hello world\n")
s := &Server{
Handler: func(s Session) {
Expand Down Expand Up @@ -80,7 +80,7 @@ func TestServerShutdown(t *testing.T) {
}

func TestServerClose(t *testing.T) {
l := newLocalListener()
l := newLocalTCPListener()
s := &Server{
Handler: func(s Session) {
time.Sleep(5 * time.Second)
Expand Down
19 changes: 15 additions & 4 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,25 @@ func (srv *Server) serveOnce(l net.Listener) error {
return e
}
srv.ChannelHandlers = map[string]ChannelHandler{
"session": DefaultSessionHandler,
"direct-tcpip": DirectTCPIPHandler,
"session": DefaultSessionHandler,
"direct-tcpip": DirectTCPIPHandler,
"direct-streamlocal@openssh.com": DirectStreamLocalHandler,
}

forwardedTCPHandler := &ForwardedTCPHandler{}
forwardedUnixHandler := &ForwardedUnixHandler{}
srv.RequestHandlers = map[string]RequestHandler{
"tcpip-forward": forwardedTCPHandler.HandleSSHRequest,
"cancel-tcpip-forward": forwardedTCPHandler.HandleSSHRequest,
"streamlocal-forward@openssh.com": forwardedUnixHandler.HandleSSHRequest,
"cancel-streamlocal-forward@openssh.com": forwardedUnixHandler.HandleSSHRequest,
}

srv.HandleConn(conn)
return nil
}

func newLocalListener() net.Listener {
func newLocalTCPListener() net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
Expand Down Expand Up @@ -64,7 +75,7 @@ func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*g
}

func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) {
l := newLocalListener()
l := newLocalTCPListener()
go srv.serveOnce(l)
return newClientSession(t, l.Addr().String(), cfg)
}
Expand Down
13 changes: 13 additions & 0 deletions ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ssh

import (
"crypto/subtle"
"errors"
"net"

gossh "golang.org/x/crypto/ssh"
Expand Down Expand Up @@ -29,6 +30,9 @@ const (
// DefaultHandler is the default Handler used by Serve.
var DefaultHandler Handler

// ErrReject is returned by some callbacks to reject a request.
var ErrRejected = errors.New("rejected")

// Option is a functional option handler for Server.
type Option func(*Server) error

Expand Down Expand Up @@ -64,6 +68,15 @@ type LocalPortForwardingCallback func(ctx Context, destinationHost string, desti
// ReversePortForwardingCallback is a hook for allowing reverse port forwarding
type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool

// LocalUnixForwardingCallback is a hook for allowing unix forwarding
// (direct-streamlocal@openssh.com)
type LocalUnixForwardingCallback func(ctx Context, socketPath string) bool

// ReverseUnixForwardingCallback is a hook for allowing reverse unix forwarding
// (streamlocal-forward@openssh.com). Returning ErrRejected will reject the
// request.
type ReverseUnixForwardingCallback func(ctx Context, socketPath string) (net.Listener, error)

// ServerConfigCallback is a hook for creating custom default server configs
type ServerConfigCallback func(ctx Context) *gossh.ServerConfig

Expand Down
242 changes: 242 additions & 0 deletions streamlocal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
package ssh

import (
"context"
"errors"
"fmt"
"io/fs"
"net"
"os"
"path/filepath"
"sync"
"syscall"

gossh "golang.org/x/crypto/ssh"
)

const (
forwardedUnixChannelType = "forwarded-streamlocal@openssh.com"
)

// directStreamLocalChannelData data struct as specified in OpenSSH's protocol
// extensions document, Section 2.4.
// https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL?annotate=HEAD
type directStreamLocalChannelData struct {
SocketPath string

Reserved1 string
Reserved2 uint32
}

// DirectStreamLocalHandler provides Unix forwarding from client -> server. It
// can be enabled by adding it to the server's ChannelHandlers under
// `direct-streamlocal@openssh.com`.
//
// Unix socket support on Windows is not widely available, so this handler may
// not work on all Windows installations and is not tested on Windows.
func DirectStreamLocalHandler(srv *Server, _ *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
var d directStreamLocalChannelData
err := gossh.Unmarshal(newChan.ExtraData(), &d)
if err != nil {
_ = newChan.Reject(gossh.ConnectionFailed, "error parsing direct-streamlocal data: "+err.Error())
return
}

if srv.LocalUnixForwardingCallback == nil || !srv.LocalUnixForwardingCallback(ctx, d.SocketPath) {
newChan.Reject(gossh.Prohibited, "unix forwarding is disabled")
return
}

var dialer net.Dialer
dconn, err := dialer.DialContext(ctx, "unix", d.SocketPath)
if err != nil {
_ = newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("dial unix socket %q: %+v", d.SocketPath, err.Error()))
return
}

ch, reqs, err := newChan.Accept()
if err != nil {
_ = dconn.Close()
return
}
go gossh.DiscardRequests(reqs)

bicopy(ctx, ch, dconn)
}

// remoteUnixForwardRequest describes the extra data sent in a
// streamlocal-forward@openssh.com containing the socket path to bind to.
type remoteUnixForwardRequest struct {
SocketPath string
}

// remoteUnixForwardChannelData describes the data sent as the payload in the new
// channel request when a Unix connection is accepted by the listener.
type remoteUnixForwardChannelData struct {
SocketPath string
Reserved uint32
}

// ForwardedUnixHandler can be enabled by creating a ForwardedUnixHandler and
// adding the HandleSSHRequest callback to the server's RequestHandlers under
// `streamlocal-forward@openssh.com` and
// `cancel-streamlocal-forward@openssh.com`
//
// Unix socket support on Windows is not widely available, so this handler may
// not work on all Windows installations and is not tested on Windows.
type ForwardedUnixHandler struct {
sync.Mutex
forwards map[string]net.Listener
}

func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) {
h.Lock()
if h.forwards == nil {
h.forwards = make(map[string]net.Listener)
}
h.Unlock()
conn, ok := ctx.Value(ContextKeyConn).(*gossh.ServerConn)
if !ok {
// TODO: log cast failure
return false, nil
}

switch req.Type {
case "streamlocal-forward@openssh.com":
var reqPayload remoteUnixForwardRequest
err := gossh.Unmarshal(req.Payload, &reqPayload)
if err != nil {
// TODO: log parse failure
return false, nil
}

if srv.ReverseUnixForwardingCallback == nil {
return false, []byte("unix forwarding is disabled")
}

addr := reqPayload.SocketPath
h.Lock()
_, ok := h.forwards[addr]
h.Unlock()
if ok {
// TODO: log failure
return false, nil
}

ln, err := srv.ReverseUnixForwardingCallback(ctx, addr)
if err != nil {
if errors.Is(err, ErrRejected) {
return false, []byte("unix forwarding is disabled")
}
// TODO: log unix listen failure
return false, nil
}

// The listener needs to successfully start before it can be added to
// the map, so we don't have to worry about checking for an existing
// listener as you can't listen on the same socket twice.
//
// This is also what the TCP version of this code does.
h.Lock()
h.forwards[addr] = ln
h.Unlock()

ctx, cancel := context.WithCancel(ctx)
go func() {
<-ctx.Done()
_ = ln.Close()
}()
go func() {
defer cancel()

for {
c, err := ln.Accept()
if err != nil {
// closed below
break
}
payload := gossh.Marshal(&remoteUnixForwardChannelData{
SocketPath: addr,
})

go func() {
ch, reqs, err := conn.OpenChannel(forwardedUnixChannelType, payload)
if err != nil {
_ = c.Close()
return
}
go gossh.DiscardRequests(reqs)
bicopy(ctx, ch, c)
}()
}

h.Lock()
ln2, ok := h.forwards[addr]
if ok && ln2 == ln {
delete(h.forwards, addr)
}
h.Unlock()
_ = ln.Close()
}()

return true, nil

case "cancel-streamlocal-forward@openssh.com":
var reqPayload remoteUnixForwardRequest
err := gossh.Unmarshal(req.Payload, &reqPayload)
if err != nil {
// TODO: log parse failure
return false, nil
}
h.Lock()
ln, ok := h.forwards[reqPayload.SocketPath]
h.Unlock()
if ok {
_ = ln.Close()
}
return true, nil

default:
return false, nil
}
}

// unlink removes files and unlike os.Remove, directories are kept.
func unlink(path string) error {
// Ignore EINTR like os.Remove, see ignoringEINTR in os/file_posix.go
// for more details.
for {
err := syscall.Unlink(path)
if !errors.Is(err, syscall.EINTR) {
return err
}
}
}

// SimpleUnixReverseForwardingCallback provides a basic implementation for
// ReverseUnixForwardingCallback. The parent directory will be created (with
// os.MkdirAll), and existing files with the same name will be removed.
func SimpleUnixReverseForwardingCallback(_ Context, socketPath string) (net.Listener, error) {
// Create socket parent dir if not exists.
parentDir := filepath.Dir(socketPath)
err := os.MkdirAll(parentDir, 0700)
if err != nil {
return nil, fmt.Errorf("failed to create parent directory %q for socket %q: %w", parentDir, socketPath, err)
}

// Remove existing socket if it exists. We do not use os.Remove() here
// so that directories are kept. Note that it's possible that we will
// overwrite a regular file here. Both of these behaviors match OpenSSH,
// however, which is why we unlink.
err = unlink(socketPath)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("failed to remove existing file in socket path %q: %w", socketPath, err)
}

ln, err := net.Listen("unix", socketPath)
if err != nil {
return nil, fmt.Errorf("failed to listen on unix socket %q: %w", socketPath, err)
}

return ln, err
}
Loading