Skip to content

Cherry-pick from tailscale #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 71 additions & 7 deletions ssh/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,25 @@ type ServerConfig struct {

hostKeys []Signer

// ImplictAuthMethod is sent to the client in the list of acceptable
// authentication methods. To make an authentication decision based on
// connection metadata use NoClientAuthCallback. If NoClientAuthCallback is
// nil, the value is unused.
ImplictAuthMethod string

// NoClientAuth is true if clients are allowed to connect without
// authenticating.
// To determine NoClientAuth at runtime, set NoClientAuth to true
// and the optional NoClientAuthCallback to a non-nil value.
NoClientAuth bool

// NoClientAuthCallback, if non-nil, is called when a user
// attempts to authenticate with auth method "none".
// NoClientAuth must also be set to true for this be used, or
// this func is unused.
// If the function returns ErrDenied, the connection is terminated.
NoClientAuthCallback func(ConnMetadata) (*Permissions, error)

// MaxAuthTries specifies the maximum number of authentication attempts
// permitted per connection. If set to a negative number, the number of
// attempts are unlimited. If set to zero, the number of attempts are limited
Expand All @@ -78,6 +93,7 @@ type ServerConfig struct {

// PasswordCallback, if non-nil, is called when a user
// attempts to authenticate using a password.
// If the function returns ErrDenied, the connection is terminated.
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)

// PublicKeyCallback, if non-nil, is called when a client
Expand All @@ -88,6 +104,7 @@ type ServerConfig struct {
// offered is in fact used to authenticate. To record any data
// depending on the public key, store it inside a
// Permissions.Extensions entry.
// If the function returns ErrDenied, the connection is terminated.
PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)

// KeyboardInteractiveCallback, if non-nil, is called when
Expand All @@ -97,6 +114,7 @@ type ServerConfig struct {
// Challenge rounds. To avoid information leaks, the client
// should be presented a challenge even if the user is
// unknown.
// If the function returns ErrDenied, the connection is terminated.
KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)

// AuthLogCallback, if non-nil, is called to log all authentication
Expand Down Expand Up @@ -292,6 +310,19 @@ func isAcceptableAlgo(algo string) bool {
return false
}

// WithBannerError is an error wrapper type that can be returned from an authentication
// function to additionally write out a banner error message.
type WithBannerError struct {
Err error
Message string
}

func (e WithBannerError) Unwrap() error {
return e.Err
}

func (e WithBannerError) Error() string { return e.Err.Error() }

func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
if addr == nil {
return errors.New("ssh: no address known for client, but source-address match required")
Expand Down Expand Up @@ -389,12 +420,19 @@ func (l ServerAuthError) Error() string {
return "[" + strings.Join(errs, ", ") + "]"
}

// ErrNoAuth is the error value returned if no
// authentication method has been passed yet. This happens as a normal
// part of the authentication loop, since the client first tries
// 'none' authentication to discover available methods.
// It is returned in ServerAuthError.Errors from NewServerConn.
var ErrNoAuth = errors.New("ssh: no auth passed yet")
var (
// ErrDenied can be returned from an authentication callback to inform the
// client that access is denied and that no further attempt will be accepted
// on the connection.
ErrDenied = errors.New("ssh: access denied")

// ErrNoAuth is the error value returned if no
// authentication method has been passed yet. This happens as a normal
// part of the authentication loop, since the client first tries
// 'none' authentication to discover available methods.
// It is returned in ServerAuthError.Errors from NewServerConn.
ErrNoAuth = errors.New("ssh: no auth passed yet")
)

func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
sessionID := s.transport.getSessionID()
Expand Down Expand Up @@ -455,7 +493,11 @@ userAuthLoop:
switch userAuthReq.Method {
case "none":
if config.NoClientAuth {
authErr = nil
if config.NoClientAuthCallback != nil {
perms, authErr = config.NoClientAuthCallback(s)
} else {
authErr = nil
}
}

// allow initial attempt of 'none' without penalty
Expand Down Expand Up @@ -639,6 +681,25 @@ userAuthLoop:
break userAuthLoop
}

var w WithBannerError
if errors.As(authErr, &w) && w.Message != "" {
bannerMsg := &userAuthBannerMsg{Message: w.Message}
if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
return nil, err
}
}
if errors.Is(authErr, ErrDenied) {
var failureMsg userAuthFailureMsg
if config.ImplictAuthMethod != "" {
failureMsg.Methods = []string{config.ImplictAuthMethod}
}
if err := s.transport.writePacket(Marshal(failureMsg)); err != nil {
return nil, err
}

return nil, authErr
}

authFailures++
if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries {
// If we have hit the max attempts, don't bother sending the
Expand Down Expand Up @@ -666,6 +727,9 @@ userAuthLoop:
}

var failureMsg userAuthFailureMsg
if config.NoClientAuthCallback != nil && config.ImplictAuthMethod != "" {
failureMsg.Methods = append(failureMsg.Methods, config.ImplictAuthMethod)
}
if config.PasswordCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "password")
}
Expand Down
51 changes: 51 additions & 0 deletions ssh/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,3 +780,54 @@ func TestHostKeyAlgorithms(t *testing.T) {
t.Fatal("succeeded connecting with unknown hostkey algorithm")
}
}

func TestServerClientAuthCallback(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()

userCh := make(chan string, 1)

serverConf := &ServerConfig{
NoClientAuth: true,
NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) {
userCh <- conn.User()
return nil, nil
},
}
const someUsername = "some-username"

serverConf.AddHostKey(testSigners["ecdsa"])
clientConf := &ClientConfig{
HostKeyCallback: InsecureIgnoreHostKey(),
User: someUsername,
}

go func() {
_, chans, reqs, err := NewServerConn(c1, serverConf)
if err != nil {
t.Errorf("server handshake: %v", err)
userCh <- "error"
return
}
go DiscardRequests(reqs)
for ch := range chans {
ch.Reject(Prohibited, "")
}
}()

conn, _, _, err := NewClientConn(c2, "", clientConf)
if err != nil {
t.Fatalf("client handshake: %v", err)
return
}
conn.Close()

got := <-userCh
if got != someUsername {
t.Errorf("username = %q; want %q", got, someUsername)
}
}