Skip to content

Commit

Permalink
Convert lib/sshutils to use slog (#49829)
Browse files Browse the repository at this point in the history
In addition to the conversion, the reverstunnel server now has
a slog.Logger provided so that it can be passed into the
sshutils.Server, is not used otherwise. A similar migration will
be performed in the future for the reversetunnel package.
  • Loading branch information
rosstimothy authored Dec 10, 2024
1 parent 13d077f commit c80bdf8
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 120 deletions.
20 changes: 14 additions & 6 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"context"
"fmt"
"io"
"log/slog"
"net"
"strings"
"sync"
Expand Down Expand Up @@ -186,8 +187,12 @@ type Config struct {
Component string

// Log specifies the logger
// TODO(tross): remove this once Logger is used everywhere
Log log.FieldLogger

// Logger specifies the logger
Logger *slog.Logger

// FIPS means Teleport was started in a FedRAMP/FIPS 140-2 compliant
// configuration.
FIPS bool
Expand Down Expand Up @@ -260,13 +265,16 @@ func (cfg *Config) CheckAndSetDefaults() error {
if cfg.Component == "" {
cfg.Component = teleport.Component(teleport.ComponentProxy, teleport.ComponentServer)
}
logger := cfg.Log
if cfg.Log == nil {
logger = log.StandardLogger()
cfg.Log = log.StandardLogger()
}
cfg.Log = logger.WithFields(log.Fields{
teleport.ComponentKey: cfg.Component,
})
cfg.Log = cfg.Log.WithField(teleport.ComponentKey, cfg.Component)

if cfg.Logger == nil {
cfg.Logger = slog.Default()
}
cfg.Logger = cfg.Logger.With(teleport.ComponentKey, cfg.Component)

if cfg.LockWatcher == nil {
return trace.BadParameter("missing parameter LockWatcher")
}
Expand Down Expand Up @@ -345,7 +353,7 @@ func NewServer(cfg Config) (reversetunnelclient.Server, error) {
sshutils.AuthMethods{
PublicKey: srv.keyAuth,
},
sshutils.SetLogger(cfg.Log),
sshutils.SetLogger(cfg.Logger),
sshutils.SetLimiter(cfg.Limiter),
sshutils.SetCiphers(cfg.Ciphers),
sshutils.SetKEXAlgorithms(cfg.KEXAlgorithms),
Expand Down
1 change: 1 addition & 0 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4408,6 +4408,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
FIPS: cfg.FIPS,
Emitter: streamEmitter,
Log: process.log,
Logger: process.logger,
LockWatcher: lockWatcher,
PeerClient: peerClient,
NodeWatcher: nodeWatcher,
Expand Down
54 changes: 27 additions & 27 deletions lib/sshutils/scp/scp.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,18 @@ package scp

import (
"bufio"
"context"
"errors"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strconv"
"strings"
"time"

"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/events"
Expand Down Expand Up @@ -94,7 +95,7 @@ type Config struct {
// this command will be run on the server
RunOnServer bool
// Log optionally specifies the logger
Log log.FieldLogger
Log *slog.Logger
}

// Command is an API that describes command operations
Expand Down Expand Up @@ -174,20 +175,18 @@ func CreateUploadCommand(cfg Config) (Command, error) {
func (c *Config) CheckAndSetDefaults() error {
logger := c.Log
if logger == nil {
logger = log.StandardLogger()
}
c.Log = logger.WithFields(log.Fields{
teleport.ComponentKey: "SCP",
teleport.ComponentFields: log.Fields{
"LocalAddr": c.Flags.LocalAddr,
"RemoteAddr": c.Flags.RemoteAddr,
"Target": c.Flags.Target,
"PreserveAttrs": c.Flags.PreserveAttrs,
"User": c.User,
"RunOnServer": c.RunOnServer,
"RemoteLocation": c.RemoteLocation,
},
})
logger = slog.Default()
}
c.Log = logger.With(
teleport.ComponentKey, "SCP",
"local_addr", c.Flags.LocalAddr,
"remote_addr", c.Flags.RemoteAddr,
"target", c.Flags.Target,
"preserve_attrs", c.Flags.PreserveAttrs,
"user", c.User,
"run_on_server", c.RunOnServer,
"remote_location", c.RemoteLocation,
)
if c.FileSystem == nil {
c.FileSystem = &localFileSystem{}
}
Expand Down Expand Up @@ -216,7 +215,7 @@ func CreateCommand(cfg Config) (Command, error) {
// to teleport can pretend it launches real SCP behind the scenes
type command struct {
Config
log log.FieldLogger
log *slog.Logger
}

// Execute implements SSH file copy (SCP). It is called on both tsh (client)
Expand Down Expand Up @@ -301,7 +300,7 @@ func (cmd *command) serveSource(ch io.ReadWriter) (retErr error) {
}
}

cmd.log.Debug("Send completed.")
cmd.log.DebugContext(context.Background(), "Send completed")
return nil
}

Expand All @@ -315,7 +314,7 @@ func (cmd *command) sendDir(r *reader, ch io.ReadWriter, fileInfo FileInfo) erro
return trace.Wrap(err)
}

cmd.log.Debug("sendDir got OK")
cmd.log.DebugContext(context.Background(), "sendDir got OK")

fileInfos, err := fileInfo.ReadDir()
if err != nil {
Expand Down Expand Up @@ -381,7 +380,7 @@ func (cmd *command) sendFile(r *reader, ch io.ReadWriter, fileInfo FileInfo) err
func (cmd *command) sendErr(ch io.Writer, err error) {
out := fmt.Sprintf("%c%s\n", byte(ErrByte), err)
if _, err := ch.Write([]byte(out)); err != nil {
cmd.log.Debugf("Failed sending SCP error message to the remote side: %v.", err)
cmd.log.DebugContext(context.Background(), "Failed sending SCP error message to the remote side", "error", err)
}
}

Expand Down Expand Up @@ -447,7 +446,7 @@ func (cmd *command) serveSink(ch io.ReadWriter) error {
}

func (cmd *command) processCommand(ch io.ReadWriter, st *state, b byte, line string) error {
cmd.log.Debugf("<- %v %v", string(b), line)
cmd.log.DebugContext(context.Background(), "processing command", "b", string(b), "line", line)
switch b {
case WarnByte, ErrByte:
return trace.Errorf("error from sender: %q", line)
Expand Down Expand Up @@ -487,7 +486,8 @@ func (cmd *command) processCommand(ch io.ReadWriter, st *state, b byte, line str
}

func (cmd *command) receiveFile(st *state, fc newFileCmd, ch io.ReadWriter) error {
cmd.log.Debugf("scp.receiveFile(%v): %v", cmd.Flags.Target, fc.Name)
ctx := context.Background()
cmd.log.DebugContext(ctx, "processing file copy request", "targets", cmd.Flags.Target, "file_name", fc.Name)

// Unless target specifies a file, use the file name from the command
path := cmd.Flags.Target[0]
Expand Down Expand Up @@ -534,12 +534,12 @@ func (cmd *command) receiveFile(st *state, fc newFileCmd, ch io.ReadWriter) erro
}
}

cmd.log.Debugf("File %v(%v) copied to %v.", fc.Name, fc.Length, path)
cmd.log.DebugContext(ctx, "File successfully copied", "file", fc.Name, "size", fc.Length, "destination", path)
return nil
}

func (cmd *command) receiveDir(st *state, fc newFileCmd, ch io.ReadWriter) error {
cmd.log.Debugf("scp.receiveDir(%v): %v", cmd.Flags.Target, fc.Name)
cmd.log.DebugContext(context.Background(), "processing directory copy request", "targets", cmd.Flags.Target, "name", fc.Name)

if cmd.FileSystem.IsDir(cmd.Flags.Target[0]) {
// Copying into an existing directory? append to it:
Expand All @@ -561,7 +561,7 @@ func (cmd *command) receiveDir(st *state, fc newFileCmd, ch io.ReadWriter) error

func (cmd *command) sendDirMode(r *reader, ch io.Writer, fileInfo FileInfo) error {
out := fmt.Sprintf("D%04o 0 %s\n", fileInfo.GetModePerm(), fileInfo.GetName())
cmd.log.WithField("cmd", out).Debug("Send directory mode.")
cmd.log.DebugContext(context.Background(), "Sending directory mode", "cmd", out)
_, err := io.WriteString(ch, out)
if err != nil {
return trace.Wrap(err)
Expand All @@ -582,7 +582,7 @@ func (cmd *command) sendFileTimes(r *reader, ch io.Writer, fileInfo FileInfo) er
fileInfo.GetModTime().Unix(),
fileInfo.GetAccessTime().Unix(),
)
cmd.log.WithField("cmd", out).Debug("Send file times.")
cmd.log.DebugContext(context.Background(), "Sending file times", "cmd", out)
_, err := io.WriteString(ch, out)
if err != nil {
return trace.Wrap(err)
Expand All @@ -596,7 +596,7 @@ func (cmd *command) sendFileMode(r *reader, ch io.Writer, fileInfo FileInfo) err
fileInfo.GetSize(),
fileInfo.GetName(),
)
cmd.log.WithField("cmd", out).Debug("Send file mode.")
cmd.log.DebugContext(context.Background(), "Sending file mode", "cmd", out)
_, err := io.WriteString(ch, out)
if err != nil {
return trace.Wrap(err)
Expand Down
39 changes: 18 additions & 21 deletions lib/sshutils/scp/scp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ package scp

import (
"bytes"
"context"
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"path/filepath"
Expand All @@ -30,7 +32,6 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport"
Expand All @@ -48,7 +49,7 @@ func TestSend(t *testing.T) {
atime := testNow.Add(1 * time.Second)
dirModtime := testNow.Add(2 * time.Second)
dirAtime := testNow.Add(3 * time.Second)
logger := logrus.WithField(teleport.ComponentKey, "t:send")
logger := utils.NewSlogLoggerForTests().With(teleport.ComponentKey, "send")
testCases := []struct {
desc string
config Config
Expand Down Expand Up @@ -112,7 +113,7 @@ func TestReceive(t *testing.T) {
atime := testNow.Add(1 * time.Second)
dirModtime := testNow.Add(2 * time.Second)
dirAtime := testNow.Add(3 * time.Second)
logger := logrus.WithField(teleport.ComponentKey, "t:recv")
logger := utils.NewSlogLoggerForTests().With(teleport.ComponentKey, "recv")
testCases := []struct {
desc string
config Config
Expand Down Expand Up @@ -172,7 +173,7 @@ func TestReceive(t *testing.T) {
for _, tt := range testCases {
tt := tt
t.Run(tt.desc, func(t *testing.T) {
logger := logger.WithField("test", tt.desc)
logger := logger.With("test", tt.desc)
t.Parallel()

sourceDir := t.TempDir()
Expand Down Expand Up @@ -235,7 +236,7 @@ func TestSCPFailsIfNoSource(t *testing.T) {
//
// See https://github.com/gravitational/teleport/issues/5497
func TestReceiveIntoExistingDirectory(t *testing.T) {
logger := logrus.WithField("test", t.Name())
logger := utils.NewSlogLoggerForTests().With("test", t.Name())
config := newTargetConfigWithFS("dir",
Flags{PreserveAttrs: true, Recursive: true},
newTestFS(logger, newDir("dir")),
Expand Down Expand Up @@ -278,7 +279,7 @@ func TestReceiveIntoExistingDirectory(t *testing.T) {
//
// See https://github.com/gravitational/teleport/issues/5695
func TestReceiveIntoNonExistingDirectoryFailsWithCorrectMessage(t *testing.T) {
logger := logrus.WithField("test", t.Name())
logger := utils.NewSlogLoggerForTests().With("test", t.Name())
// Target configuration with no existing directory
root := t.TempDir()
config := newTargetConfigWithFS(filepath.Join(root, "dir"),
Expand Down Expand Up @@ -306,7 +307,7 @@ func TestReceiveIntoNonExistingDirectoryFailsWithCorrectMessage(t *testing.T) {
// TestCopyIntoNestedNonExistingDirectoriesDoesNotCreateIntermediateDirectories validates that copying a directory
// into a remote '/path/to/remote' where '/path/to' does not exist causes an error.
func TestCopyIntoNestedNonExistingDirectoriesDoesNotCreateIntermediateDirectories(t *testing.T) {
logger := logrus.WithField("test", t.Name())
logger := utils.NewSlogLoggerForTests().With("test", t.Name())

config := newTargetConfig("non-existing/remote_dir", Flags{Recursive: true})
sourceFS := newTestFS(logger, newDir("dir"))
Expand Down Expand Up @@ -631,30 +632,30 @@ func newCmd(name string, args ...string) (cmd *exec.Cmd, stdin io.WriteCloser, s

// newTestFS creates a new test FileSystem using the specified logger
// and the set of top-level files
func newTestFS(logger logrus.FieldLogger, files ...*testFileInfo) *testFS {
func newTestFS(logger *slog.Logger, files ...*testFileInfo) *testFS {
fs := newEmptyTestFS(logger)
addFiles(fs.fs, files...)
return fs
}

// newEmptyTestFS creates a new test FileSystem without content
func newEmptyTestFS(logger logrus.FieldLogger) *testFS {
func newEmptyTestFS(logger *slog.Logger) *testFS {
return &testFS{
fs: make(map[string]*testFileInfo),
l: logger,
}
}

func (r *testFS) IsDir(path string) bool {
r.l.WithField("path", path).Debug("IsDir.")
r.l.DebugContext(context.Background(), "IsDir", "path", path)
if fi, exists := r.fs[path]; exists {
return fi.IsDir()
}
return false
}

func (r *testFS) GetFileInfo(path string) (FileInfo, error) {
r.l.WithField("path", path).Debug("GetFileInfo.")
r.l.DebugContext(context.Background(), "GetFileInfo", "path", path)
fi, exists := r.fs[path]
if !exists {
return nil, newErrMissingFile(path)
Expand All @@ -663,7 +664,7 @@ func (r *testFS) GetFileInfo(path string) (FileInfo, error) {
}

func (r *testFS) MkDir(path string, mode int) error {
r.l.WithFields(logrus.Fields{"path": path, "mode": mode}).Debug("MkDir.")
r.l.DebugContext(context.Background(), "MkDir", "path", path, "mode", mode)
_, exists := r.fs[path]
if exists {
return trace.AlreadyExists("directory %v already exists", path)
Expand All @@ -677,7 +678,7 @@ func (r *testFS) MkDir(path string, mode int) error {
}

func (r *testFS) OpenFile(path string) (io.ReadCloser, error) {
r.l.WithField("path", path).Debug("OpenFile.")
r.l.DebugContext(context.Background(), "OpenFile", "path", path)
fi, exists := r.fs[path]
if !exists {
return nil, newErrMissingFile(path)
Expand All @@ -687,7 +688,7 @@ func (r *testFS) OpenFile(path string) (io.ReadCloser, error) {
}

func (r *testFS) CreateFile(path string, length uint64) (io.WriteCloser, error) {
r.l.WithFields(logrus.Fields{"path": path, "len": length}).Debug("CreateFile.")
r.l.DebugContext(context.Background(), "CreateFile", "path", path, "len", length)
baseDir := filepath.Dir(path)
if _, exists := r.fs[baseDir]; baseDir != "." && !exists {
return nil, newErrMissingFile(baseDir)
Expand All @@ -704,7 +705,7 @@ func (r *testFS) CreateFile(path string, length uint64) (io.WriteCloser, error)
}

func (r *testFS) Chmod(path string, mode int) error {
r.l.WithFields(logrus.Fields{"path": path, "mode": mode}).Debug("Chmod.")
r.l.DebugContext(context.Background(), "Chmod", "path", path, "mode", mode)
fi, exists := r.fs[path]
if !exists {
return newErrMissingFile(path)
Expand All @@ -714,11 +715,7 @@ func (r *testFS) Chmod(path string, mode int) error {
}

func (r *testFS) Chtimes(path string, atime, mtime time.Time) error {
r.l.WithFields(logrus.Fields{
"path": path,
"atime": atime,
"mtime": mtime,
}).Debug("Chtimes.")
r.l.DebugContext(context.Background(), "Chtimes", "path", path, "atime", atime, "mtime", mtime)
fi, exists := r.fs[path]
if !exists {
return newErrMissingFile(path)
Expand All @@ -730,7 +727,7 @@ func (r *testFS) Chtimes(path string, atime, mtime time.Time) error {

// testFS implements a fake FileSystem
type testFS struct {
l logrus.FieldLogger
l *slog.Logger
fs map[string]*testFileInfo
}

Expand Down
Loading

0 comments on commit c80bdf8

Please sign in to comment.