From 401663bb86af9192db2967667d2be2ee91d789bb Mon Sep 17 00:00:00 2001 From: Eric Date: Fri, 26 Jul 2024 18:58:56 +0800 Subject: [PATCH] perf: support max idle timeout for sftp session --- pkg/auth/http.go | 2 +- pkg/handler/direct_handler.go | 1 + pkg/handler/server_ssh.go | 15 +- pkg/handler/sftp.go | 13 -- pkg/httpd/sftpvolume.go | 9 + pkg/httpd/webfolder.go | 6 + pkg/jms-sdk-go/model/session.go | 4 + pkg/proxy/domain_gateway.go | 19 --- pkg/proxy/server.go | 78 +++------ pkg/session/manager.go | 4 + pkg/srvconn/sftp_asset.go | 283 ++++++++++++++++++-------------- pkg/srvconn/sftp_node.go | 21 ++- pkg/srvconn/sftp_session.go | 38 +++++ pkg/srvconn/sftpconn.go | 51 +++++- pkg/srvconn/sftpfile.go | 81 +++++++-- pkg/srvconn/ssh.go | 1 + 16 files changed, 389 insertions(+), 237 deletions(-) create mode 100644 pkg/srvconn/sftp_session.go diff --git a/pkg/auth/http.go b/pkg/auth/http.go index 6c25562cf..f9fa453e3 100644 --- a/pkg/auth/http.go +++ b/pkg/auth/http.go @@ -37,7 +37,7 @@ func HTTPMiddleSessionAuth(jmsService *service.JMService) gin.HandlerFunc { func HTTPMiddleDebugAuth() gin.HandlerFunc { return func(ctx *gin.Context) { switch ctx.ClientIP() { - case "127.0.0.1", "localhost": + case "127.0.0.1", "localhost", "::1": return default: _ = ctx.AbortWithError(http.StatusBadRequest, fmt.Errorf("invalid host %s", ctx.ClientIP())) diff --git a/pkg/handler/direct_handler.go b/pkg/handler/direct_handler.go index 4f6f83815..b7d496806 100644 --- a/pkg/handler/direct_handler.go +++ b/pkg/handler/direct_handler.go @@ -159,6 +159,7 @@ func (d *DirectHandler) NewSFTPHandler() *SftpHandler { opts = append(opts, srvconn.WithUser(d.opts.User)) opts = append(opts, srvconn.WithRemoteAddr(addr)) opts = append(opts, srvconn.WithLoginFrom(model.LoginFromSSH)) + opts = append(opts, srvconn.WithTerminalCfg(d.opts.terminalConf)) if !d.opts.IsTokenConnection() { opts = append(opts, srvconn.WithAssets(d.opts.assets)) if len(d.opts.assets) == 1 { diff --git a/pkg/handler/server_ssh.go b/pkg/handler/server_ssh.go index aeaed8378..0889c82d9 100644 --- a/pkg/handler/server_ssh.go +++ b/pkg/handler/server_ssh.go @@ -85,7 +85,7 @@ func (s *Server) SFTPHandler(sess ssh.Session) { directSrv := NewDirectHandler(sess, s.jmsService, opts...) sftpHandler = directSrv.NewSFTPHandler() } else { - sftpHandler = NewSFTPHandler(s.jmsService, currentUser, addr) + sftpHandler = s.NewSftpHandler(currentUser, addr) } handlers := sftp.Handlers{ FileGet: sftpHandler, @@ -106,6 +106,19 @@ func (s *Server) SFTPHandler(sess ssh.Session) { logger.Infof("SFTP request %s: Handler exit.", reqID) } +func (s *Server) NewSftpHandler(user *model.User, addr string) *SftpHandler { + terminalCfg := s.GetTerminalConfig() + opts := make([]srvconn.UserSftpOption, 0, 5) + opts = append(opts, srvconn.WithUser(user)) + opts = append(opts, srvconn.WithRemoteAddr(addr)) + opts = append(opts, srvconn.WithLoginFrom(model.LoginFromSSH)) + opts = append(opts, srvconn.WithTerminalCfg(&terminalCfg)) + return &SftpHandler{ + UserSftpConn: srvconn.NewUserSftpConn(s.jmsService, opts...), + recorder: proxy.GetFTPFileRecorder(s.jmsService), + } +} + func (s *Server) LocalPortForwardingPermission(ctx ssh.Context, dstHost string, dstPort uint32) bool { logger.Debugf("LocalPortForwardingPermission: %s %s %d", ctx.User(), dstHost, dstPort) return config.GlobalConfig.EnableLocalPortForward diff --git a/pkg/handler/sftp.go b/pkg/handler/sftp.go index 4ac7719fd..14ffc0731 100644 --- a/pkg/handler/sftp.go +++ b/pkg/handler/sftp.go @@ -9,24 +9,11 @@ import ( "github.com/pkg/sftp" - "github.com/jumpserver/koko/pkg/jms-sdk-go/model" - "github.com/jumpserver/koko/pkg/jms-sdk-go/service" "github.com/jumpserver/koko/pkg/logger" "github.com/jumpserver/koko/pkg/proxy" "github.com/jumpserver/koko/pkg/srvconn" ) -func NewSFTPHandler(jmsService *service.JMService, user *model.User, addr string) *SftpHandler { - opts := make([]srvconn.UserSftpOption, 0, 5) - opts = append(opts, srvconn.WithUser(user)) - opts = append(opts, srvconn.WithRemoteAddr(addr)) - opts = append(opts, srvconn.WithLoginFrom(model.LoginFromSSH)) - return &SftpHandler{ - UserSftpConn: srvconn.NewUserSftpConn(jmsService, opts...), - recorder: proxy.GetFTPFileRecorder(jmsService), - } -} - type SftpHandler struct { *srvconn.UserSftpConn diff --git a/pkg/httpd/sftpvolume.go b/pkg/httpd/sftpvolume.go index 8fadf81ea..3028bc393 100644 --- a/pkg/httpd/sftpvolume.go +++ b/pkg/httpd/sftpvolume.go @@ -24,6 +24,7 @@ type volumeOption struct { user *model.User asset *model.PermAsset connectToken *model.ConnectToken + terminalCfg *model.TerminalConfig } type VolumeOption func(*volumeOption) @@ -51,6 +52,13 @@ func WithConnectToken(connectToken *model.ConnectToken) VolumeOption { } } +func WithTerminalCfg(cfg *model.TerminalConfig) VolumeOption { + return func(opts *volumeOption) { + opts.terminalCfg = cfg + } + +} + func NewUserVolume(jmsService *service.JMService, opts ...VolumeOption) *UserVolume { var volOpts volumeOption for _, opt := range opts { @@ -77,6 +85,7 @@ func NewUserVolume(jmsService *service.JMService, opts ...VolumeOption) *UserVol sftpOpts = append(sftpOpts, srvconn.WithUser(volOpts.user)) sftpOpts = append(sftpOpts, srvconn.WithRemoteAddr(volOpts.addr)) sftpOpts = append(sftpOpts, srvconn.WithLoginFrom(model.LoginFromWeb)) + sftpOpts = append(sftpOpts, srvconn.WithTerminalCfg(volOpts.terminalCfg)) userSftp := srvconn.NewUserSftpConn(jmsService, sftpOpts...) rawID := fmt.Sprintf("%s@%s", volOpts.user.Username, volOpts.addr) diff --git a/pkg/httpd/webfolder.go b/pkg/httpd/webfolder.go index cf4ed4d2b..edff76cb7 100644 --- a/pkg/httpd/webfolder.go +++ b/pkg/httpd/webfolder.go @@ -22,9 +22,15 @@ func (h *webFolder) Name() string { func (h *webFolder) CheckValidation() error { apiClient := h.ws.apiClient user := h.ws.CurrentUser() + terminalCfg, err := h.ws.apiClient.GetTerminalConfig() + if err != nil { + logger.Errorf("Get terminal config failed: %s", err) + return err + } volOpts := make([]VolumeOption, 0, 5) volOpts = append(volOpts, WithUser(user)) volOpts = append(volOpts, WithAddr(h.ws.ClientIP())) + volOpts = append(volOpts, WithTerminalCfg(&terminalCfg)) params := h.ws.wsParams targetId := params.TargetId assetId := params.AssetId diff --git a/pkg/jms-sdk-go/model/session.go b/pkg/jms-sdk-go/model/session.go index e56b7e824..2b9a9ff72 100644 --- a/pkg/jms-sdk-go/model/session.go +++ b/pkg/jms-sdk-go/model/session.go @@ -128,6 +128,10 @@ var EmptyLifecycleLog = SessionLifecycleLog{} type SessionLifecycleReasonErr string +func (s SessionLifecycleReasonErr) String() string { + return string(s) +} + const ( ReasonErrConnectFailed SessionLifecycleReasonErr = "connect_failed" ReasonErrConnectDisconnect SessionLifecycleReasonErr = "connect_disconnect" diff --git a/pkg/proxy/domain_gateway.go b/pkg/proxy/domain_gateway.go index f7d9450ea..876543998 100644 --- a/pkg/proxy/domain_gateway.go +++ b/pkg/proxy/domain_gateway.go @@ -16,7 +16,6 @@ import ( ) type domainGateway struct { - domain *model.Domain dstIP string dstPort int @@ -97,24 +96,6 @@ func (d *domainGateway) getAvailableGateway() bool { d.sshClient = sshClient return true } - - for i := range d.domain.Gateways { - gateway := d.domain.Gateways[i] - if !gateway.Protocols.IsSupportProtocol(model.ProtocolSSH) { - continue - } - logger.Debugf("Domain %s try dial gateway %s", d.domain.Name, gateway.Name) - sshClient, err := d.createGatewaySSHClient(&gateway) - if err != nil { - logger.Errorf("Dial gateway %s err: %s ", gateway.Name, err) - continue - } - logger.Infof("Domain %s use gateway %s", d.domain.Name, gateway.Name) - d.sshClient = sshClient - d.selectedGateway = &gateway - return true - } - logger.Errorf("Domain Gateway %s has no available gateway", d.domain.Name) return false } diff --git a/pkg/proxy/server.go b/pkg/proxy/server.go index 719953978..763e052ef 100644 --- a/pkg/proxy/server.go +++ b/pkg/proxy/server.go @@ -37,6 +37,8 @@ var ( ErrNoAuthInfo = errors.New("no auth info") ) +const localIP = "127.0.0.1" + func NewServer(conn UserConnection, jmsService *service.JMService, opts ...ConnectionOption) (*Server, error) { connOpts := &ConnectionOptions{} for _, setter := range opts { @@ -139,9 +141,8 @@ type Server struct { suFromAccount *model.BaseAccount - terminalConf *model.TerminalConfig - domainGateways *model.Domain - gateway *model.Gateway + terminalConf *model.TerminalConfig + gateway *model.Gateway sessionInfo *model.Session @@ -428,31 +429,23 @@ func (s *Server) getCacheSSHConn() (srvConn *srvconn.SSHConnection, ok bool) { return cacheConn, true } -func (s *Server) createAvailableGateWay(domain *model.Domain) (*domainGateway, error) { +func (s *Server) createAvailableGateWay() (*domainGateway, error) { asset := s.connOpts.authInfo.Asset protocol := s.connOpts.authInfo.Protocol - - var dGateway *domainGateway - switch protocol { - case srvconn.ProtocolK8s: - dstHost, dstPort, err := ParseUrlHostAndPort(asset.Address) + dstIP := asset.Address + dstPort := asset.ProtocolPort(protocol) + if protocol == srvconn.ProtocolK8s { + dstHost, dstPort1, err := ParseUrlHostAndPort(asset.Address) if err != nil { return nil, err } - dGateway = &domainGateway{ - domain: domain, - dstIP: dstHost, - dstPort: dstPort, - selectedGateway: s.gateway, - } - default: - port := asset.ProtocolPort(protocol) - dGateway = &domainGateway{ - domain: domain, - dstIP: asset.Address, - dstPort: port, - selectedGateway: s.gateway, - } + dstIP = dstHost + dstPort = dstPort1 + } + dGateway := &domainGateway{ + dstIP: dstIP, + dstPort: dstPort, + selectedGateway: s.gateway, } return dGateway, nil } @@ -462,11 +455,11 @@ func (s *Server) getK8sConConn(localTunnelAddr *net.TCPAddr) (srvConn srvconn.Se asset := s.connOpts.authInfo.Asset clusterServer := asset.Address if localTunnelAddr != nil { - originUrl, err := url.Parse(clusterServer) - if err != nil { - return nil, err + originUrl, err1 := url.Parse(clusterServer) + if err1 != nil { + return nil, err1 } - clusterServer = ReplaceURLHostAndPort(originUrl, "127.0.0.1", localTunnelAddr.Port) + clusterServer = ReplaceURLHostAndPort(originUrl, localIP, localTunnelAddr.Port) } if s.connOpts.k8sContainer != nil { return s.getContainerConn(clusterServer) @@ -515,7 +508,7 @@ func (s *Server) getRedisConn(localTunnelAddr *net.TCPAddr) (srvConn *srvconn.Re host := asset.Address port := asset.ProtocolPort(protocol) if localTunnelAddr != nil { - host = "127.0.0.1" + host = localIP port = localTunnelAddr.Port } username := s.account.Username @@ -547,7 +540,7 @@ func (s *Server) getMongoDBConn(localTunnelAddr *net.TCPAddr) (srvConn *srvconn. host := asset.Address port := asset.ProtocolPort(protocol) if localTunnelAddr != nil { - host = "127.0.0.1" + host = localIP port = localTunnelAddr.Port } platform := s.connOpts.authInfo.Platform @@ -806,29 +799,6 @@ func (s *Server) getGatewayProxyOptions() []srvconn.SSHClientOptions { } return []srvconn.SSHClientOptions{proxyArg} } - // 多个网关的情况 - if s.domainGateways != nil && len(s.domainGateways.Gateways) != 0 { - timeout := config.GlobalConfig.SSHTimeout - proxyArgs := make([]srvconn.SSHClientOptions, 0, len(s.domainGateways.Gateways)) - for i := range s.domainGateways.Gateways { - gateway := s.domainGateways.Gateways[i] - loginAccount := gateway.Account - port := gateway.Protocols.GetProtocolPort(model.ProtocolSSH) - proxyArg := srvconn.SSHClientOptions{ - Host: gateway.Address, - Port: strconv.Itoa(port), - Username: loginAccount.Username, - Timeout: timeout, - } - if loginAccount.IsSSHKey() { - proxyArg.PrivateKey = loginAccount.Secret - } else { - proxyArg.Password = loginAccount.Secret - } - proxyArgs = append(proxyArgs, proxyArg) - } - return proxyArgs - } return nil } @@ -982,13 +952,13 @@ func (s *Server) Proxy() { } }() var proxyAddr *net.TCPAddr - if (s.domainGateways != nil && len(s.domainGateways.Gateways) != 0) || s.gateway != nil { + if s.gateway != nil { protocol := s.connOpts.authInfo.Protocol switch protocol { case srvconn.ProtocolSSH, srvconn.ProtocolTELNET: // ssh 和 telnet 协议不需要本地启动代理 default: - dGateway, err := s.createAvailableGateWay(s.domainGateways) + dGateway, err := s.createAvailableGateWay() if err != nil { msg := lang.T("Start domain gateway failed %s") msg = fmt.Sprintf(msg, err) diff --git a/pkg/session/manager.go b/pkg/session/manager.go index 141352c93..08856ed87 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -25,6 +25,10 @@ func RemoveSession(s *Session) { sessManager.Delete(s.ID) } +func RemoveSessionById(id string) { + sessManager.Delete(id) +} + func newSessionManager() *sessionManager { return &sessionManager{ data: make(map[string]*Session), diff --git a/pkg/srvconn/sftp_asset.go b/pkg/srvconn/sftp_asset.go index 8a0b0e5d4..7a19ae5c0 100644 --- a/pkg/srvconn/sftp_asset.go +++ b/pkg/srvconn/sftp_asset.go @@ -29,19 +29,14 @@ type AssetDir struct { user *model.User detailAsset *model.PermAsset + once sync.Once + suMaps map[string]*model.PermAccount - suMaps map[string]*model.PermAccount + mu sync.Mutex + sftpSessions sync.Map - sftpClients map[string]*SftpConn // Account stringer - sftpTraceSessions map[string]*session.Session - - once sync.Once - - reuse bool ShowHidden bool - mu sync.Mutex - jmsService *service.JMService } @@ -153,6 +148,7 @@ func (ad *AssetDir) Create(path string) (*SftpFile, error) { if con == nil || con.isClosed { return nil, sftp.ErrSshFxConnectionLost } + con.IncreaseRef() for !con.IsOverwriteFile() { if exitFile := IsExistPath(con.client, realPath); !exitFile { break @@ -171,7 +167,7 @@ func (ad *AssetDir) Create(path string) (*SftpFile, error) { isSuccess = true } ftpLog := ad.CreateFTPLog(su, operate, filename, isSuccess) - f := &SftpFile{File: sf, FTPLog: ftpLog} + f := &SftpFile{File: sf, FTPLog: ftpLog, cleanupFunc: con.DecreaseRef} return f, err } @@ -206,6 +202,8 @@ func (ad *AssetDir) MkdirAll(path string) (err error) { strconv.FormatInt(time.Now().Unix(), 10)) logger.Infof("Change duplicate dir path %s to %s", oldPath, realPath) } + con.IncreaseRef() + defer con.DecreaseRef() err = con.client.MkdirAll(realPath) filename := realPath isSuccess := false @@ -238,6 +236,7 @@ func (ad *AssetDir) Open(path string) (*SftpFile, error) { if con == nil { return nil, sftp.ErrSshFxConnectionLost } + con.IncreaseRef() sf, err := con.client.Open(realPath) filename := realPath isSuccess := false @@ -246,7 +245,7 @@ func (ad *AssetDir) Open(path string) (*SftpFile, error) { isSuccess = true } ftpLog := ad.CreateFTPLog(su, operate, filename, isSuccess) - f := &SftpFile{File: sf, FTPLog: ftpLog} + f := &SftpFile{File: sf, FTPLog: ftpLog, cleanupFunc: con.DecreaseRef} return f, err } @@ -272,6 +271,8 @@ func (ad *AssetDir) ReadDir(path string) (res []os.FileInfo, err error) { if con == nil || con.isClosed { return nil, sftp.ErrSshFxConnectionLost } + con.IncreaseRef() + defer con.DecreaseRef() res, err = con.client.ReadDir(realPath) isRootAccount := con.token.Account.Username == "root" fileInfoList := make([]os.FileInfo, 0, len(res)) @@ -314,6 +315,8 @@ func (ad *AssetDir) ReadLink(path string) (res string, err error) { if con == nil || con.isClosed { return "", sftp.ErrSshFxConnectionLost } + con.IncreaseRef() + defer con.DecreaseRef() res, err = con.client.ReadLink(realPath) return } @@ -343,6 +346,8 @@ func (ad *AssetDir) RemoveDirectory(path string) (err error) { logger.Errorf("Diable to remove root setting path %s", realPath) return sftp.ErrSshFxPermissionDenied } + con.IncreaseRef() + defer con.DecreaseRef() err = ad.removeDirectoryAll(con.client, realPath) filename := realPath isSuccess := false @@ -382,6 +387,8 @@ func (ad *AssetDir) Rename(oldNamePath, newNamePath string) (err error) { if conn1 == nil || conn1.isClosed { return sftp.ErrSshFxConnectionLost } + conn1.IncreaseRef() + defer conn1.DecreaseRef() filename := fmt.Sprintf("%s=>%s", oldRealPath, newRealPath) operate := model.OperateRename err = conn1.client.Rename(oldRealPath, newRealPath) @@ -389,7 +396,7 @@ func (ad *AssetDir) Rename(oldNamePath, newNamePath string) (err error) { ad.CreateFTPLog(su, operate, filename, false) return err } - if fileInfo, err := conn2.client.Stat(newRealPath); err == nil && fileInfo.IsDir() { + if fileInfo, err1 := conn2.client.Stat(newRealPath); err1 == nil && fileInfo.IsDir() { operate = model.OperateRenameDir } ad.CreateFTPLog(su, operate, filename, true) @@ -417,8 +424,9 @@ func (ad *AssetDir) Remove(path string) (err error) { if con == nil || con.isClosed { return sftp.ErrSshFxConnectionLost } + con.IncreaseRef() + defer con.DecreaseRef() err = con.client.Remove(realPath) - filename := realPath isSuccess := false operate := model.OperateDelete @@ -447,6 +455,8 @@ func (ad *AssetDir) Stat(path string) (res os.FileInfo, err error) { if con == nil || con.isClosed { return nil, sftp.ErrSshFxConnectionLost } + con.IncreaseRef() + defer con.DecreaseRef() res, err = con.client.Stat(realPath) isRootAccount := con.token.Account.Username == "root" return NewSftpFileInfo(res, isRootAccount), err @@ -477,6 +487,8 @@ func (ad *AssetDir) Symlink(oldNamePath, newNamePath string) (err error) { if conn1 != conn2 { return sftp.ErrSshFxOpUnsupported } + conn1.IncreaseRef() + defer conn1.DecreaseRef() err = conn1.client.Symlink(oldRealPath, newRealPath) filename := fmt.Sprintf("%s=>%s", oldRealPath, newRealPath) isSuccess := false @@ -513,64 +525,75 @@ func (ad *AssetDir) removeDirectoryAll(conn *sftp.Client, path string) error { return conn.RemoveDirectory(path) } +func (ad *AssetDir) checkExpired() { + ad.sftpSessions.Range(func(key, value interface{}) bool { + if value == nil { + return true + } + conn := value.(*SftpSession) + if conn.isClosed { + return true + } + if conn.client == nil { + return true + } + if conn.IsExpired() { + conn.CloseWithReason(model.ReasonErrIdleDisconnect) + logger.Infof("SFTP session %s idle timeout closed", conn.sess.ID) + } + return true + }) +} + func (ad *AssetDir) GetSFTPAndRealPath(su *model.PermAccount, path string) (conn *SftpConn, realPath string) { ad.mu.Lock() defer ad.mu.Unlock() - var ok bool - conn, ok = ad.sftpClients[su.String()] - if !ok { - var err error - conn, err = ad.GetSftpClient(su) - if err != nil { - logger.Errorf("Get Sftp Client err: %s", err.Error()) - return nil, "" - } - ad.sftpClients[su.String()] = conn + key := su.String() + if val, ok := ad.sftpSessions.Load(key); ok { + sftpSess := val.(*SftpSession) + realPath = filepath.Join(sftpSess.rootDirPath, strings.TrimPrefix(path, "/")) + return sftpSess.SftpConn, realPath } - if _, ok1 := ad.sftpTraceSessions[su.String()]; !ok1 { - reqSession := conn.token.CreateSession(ad.opts.RemoteAddr, ad.opts.fromType, model.SFTPType) - respSession, err := ad.jmsService.CreateSession(reqSession) - if err != nil { - logger.Errorf("Create sftp Session err: %s", err.Error()) - return nil, "" - } - terminalFunc := func(task *model.TerminalTask) error { - switch task.Name { - case model.TaskKillSession: - ad.mu.Lock() - defer ad.mu.Unlock() - ad.finishSftpSession(su.String(), conn) - return nil - } - return fmt.Errorf("sftp session not support task: %s", task.Name) - } - traceSession := session.NewSession(&respSession, terminalFunc) - session.AddSession(traceSession) - ad.sftpTraceSessions[su.String()] = traceSession - ad.recordSessionLifecycle(traceSession.ID, model.AssetConnectSuccess, "") - } - if conn.rootDirPath == "" { - platform := conn.token.Platform - sftpRoot := platform.Protocols.GetSftpPath(model.ProtocolSFTP) - accountUsername := su.Username - username := ad.user.Username - switch strings.ToLower(sftpRoot) { - case "home", "~", "": - sftpRoot = conn.HomeDirPath - default: - // ${ACCOUNT} 连接的账号用户名, ${USER} 当前用户用户名, ${HOME} 当前家目录 - homeDir := conn.HomeDirPath - sftpRoot = strings.ReplaceAll(sftpRoot, "${ACCOUNT}", accountUsername) - sftpRoot = strings.ReplaceAll(sftpRoot, "${USER}", username) - sftpRoot = strings.ReplaceAll(sftpRoot, "${HOME}", homeDir) - if strings.Index(sftpRoot, "/") != 0 { - sftpRoot = fmt.Sprintf("/%s", sftpRoot) - } + sftpSession, err := ad.createSftpSession(su) + if err != nil { + logger.Errorf("Create sftp session err: %s", err.Error()) + return nil, "" + } + ad.sftpSessions.Store(key, sftpSession) + realPath = filepath.Join(sftpSession.rootDirPath, strings.TrimPrefix(path, "/")) + return sftpSession.SftpConn, realPath +} + +func (ad *AssetDir) createSftpSession(su *model.PermAccount) (sftpSess *SftpSession, err error) { + conn, err := ad.GetSftpClient(su) + if err != nil { + return nil, err + } + reqSession := conn.token.CreateSession(ad.opts.RemoteAddr, ad.opts.fromType, model.SFTPType) + respSession, err1 := ad.jmsService.CreateSession(reqSession) + if err1 != nil { + logger.Errorf("Create sftp Session err: %s", err1.Error()) + return nil, err1 + } + sftpSession := &SftpSession{SftpConn: conn, sess: &respSession, jmsService: ad.jmsService} + terminalFunc := func(task *model.TerminalTask) error { + switch task.Name { + case model.TaskKillSession: + sftpSession.CloseWithReason(model.ReasonErrAdminTerminate) + return nil } - conn.rootDirPath = sftpRoot + return fmt.Errorf("sftp session not support task: %s", task.Name) } - realPath = filepath.Join(conn.rootDirPath, strings.TrimPrefix(path, "/")) - return + traceSession := session.NewSession(&respSession, terminalFunc) + session.AddSession(traceSession) + ad.recordSessionLifecycle(traceSession.ID, model.AssetConnectSuccess, "") + + go func() { + _ = conn.client.Wait() + sftpSession.Close() + logger.Infof("SFTP session %s closed", sftpSession.sess.ID) + }() + return sftpSession, nil } func (ad *AssetDir) IsUniqueSu() (folderName string, ok bool) { @@ -628,8 +651,67 @@ func (ad *AssetDir) getNewSftpConn(connectToken *model.ConnectToken, return nil, errNoSelectAsset } timeout := config.GlobalConfig.SSHTimeout + sshClient, err := NewSSHClientWithToken(connectToken, timeout) + if err != nil { + logger.Errorf("Get new SSH client err: %s", err) + return nil, err + } + sess, err := sshClient.AcquireSession() + if err != nil { + logger.Errorf("SSH client(%s) start sftp client session err %s", sshClient, err) + _ = sshClient.Close() + return nil, err + } + sftpClient, err := NewSftpConn(sess) + if err != nil { + logger.Errorf("SSH client(%s) start sftp conn err %s", sshClient, err) + _ = sess.Close() + sshClient.ReleaseSession(sess) + _ = sshClient.Close() + return nil, err + } + homeDirPath, err := sftpClient.Getwd() + if err != nil { + logger.Errorf("SSH client sftp (%s) get home dir err %s", sshClient, err) + _ = sftpClient.Close() + sshClient.ReleaseSession(sess) + _ = sshClient.Close() + return nil, err + } + logger.Infof("SSH client %s start sftp client session success", sshClient) - user := connectToken.User + platform := connectToken.Platform + sftpRoot := platform.Protocols.GetSftpPath(model.ProtocolSFTP) + accountUsername := su.Username + username := ad.user.Username + switch strings.ToLower(sftpRoot) { + case "home", "~", "": + sftpRoot = homeDirPath + default: + // ${ACCOUNT} 连接的账号用户名, ${USER} 当前用户用户名, ${HOME} 当前家目录 + homeDir := homeDirPath + sftpRoot = strings.ReplaceAll(sftpRoot, "${ACCOUNT}", accountUsername) + sftpRoot = strings.ReplaceAll(sftpRoot, "${USER}", username) + sftpRoot = strings.ReplaceAll(sftpRoot, "${HOME}", homeDir) + if strings.Index(sftpRoot, "/") != 0 { + sftpRoot = fmt.Sprintf("/%s", sftpRoot) + } + } + maxIdleInt := ad.opts.terminalCfg.MaxIdleTime + conn = &SftpConn{ + sshClient: sshClient, + sshSession: sess, + permAccount: su, + rootDirPath: sftpRoot, + client: sftpClient, + HomeDirPath: homeDirPath, + token: connectToken, + maxIdleTime: time.Duration(maxIdleInt) * time.Minute, + } + return conn, nil +} + +func NewSSHClientWithToken(connectToken *model.ConnectToken, timeout int) (*SSHClient, error) { asset := connectToken.Asset account := connectToken.Account username := account.Username @@ -639,7 +721,6 @@ func (ad *AssetDir) getNewSftpConn(connectToken *model.ConnectToken, sshAuthOpts = append(sshAuthOpts, SSHClientUsername(username)) sshAuthOpts = append(sshAuthOpts, SSHClientHost(asset.Address)) sshAuthOpts = append(sshAuthOpts, SSHClientPort(asset.ProtocolPort(protocol))) - sshAuthOpts = append(sshAuthOpts, SSHClientTimeout(timeout)) if account.IsSSHKey() { if signer, err1 := gossh.ParsePrivateKey([]byte(account.Secret)); err1 == nil { @@ -670,45 +751,7 @@ func (ad *AssetDir) getNewSftpConn(connectToken *model.ConnectToken, proxyArgs = append(proxyArgs, proxyArg) sshAuthOpts = append(sshAuthOpts, SSHClientProxyClient(proxyArgs...)) } - sshClient, err := NewSSHClient(sshAuthOpts...) - if err != nil { - logger.Errorf("Get new SSH client err: %s", err) - return nil, err - } - sess, err := sshClient.AcquireSession() - if err != nil { - logger.Errorf("SSH client(%s) start sftp client session err %s", sshClient, err) - _ = sshClient.Close() - return nil, err - } - sftpClient, err := NewSftpConn(sess) - if err != nil { - logger.Errorf("SSH client(%s) start sftp conn err %s", sshClient, err) - _ = sess.Close() - sshClient.ReleaseSession(sess) - _ = sshClient.Close() - return nil, err - } - go func() { - _ = sftpClient.Wait() - sshClient.ReleaseSession(sess) - _ = sshClient.Close() - logger.Infof("User %s SSH client(%s) for SFTP release", user.String(), sshClient) - if sftpSession, ok := ad.sftpTraceSessions[su.String()]; ok { - sid := sftpSession.ID - reason := string(model.ReasonErrConnectDisconnect) - ad.recordSessionLifecycle(sid, model.AssetConnectFinished, reason) - } - }() - homeDirPath, err := sftpClient.Getwd() - if err != nil { - logger.Errorf("SSH client sftp (%s) get home dir err %s", sshClient, err) - _ = sftpClient.Close() - return nil, err - } - logger.Infof("SSH client %s start sftp client session success", sshClient) - conn = &SftpConn{client: sftpClient, HomeDirPath: homeDirPath, token: connectToken} - return conn, nil + return NewSSHClient(sshAuthOpts...) } func (ad *AssetDir) parsePath(path string) []string { @@ -717,31 +760,19 @@ func (ad *AssetDir) parsePath(path string) []string { } func (ad *AssetDir) close() { - ad.mu.Lock() - defer ad.mu.Unlock() - for key, conn := range ad.sftpClients { - if conn != nil { - ad.finishSftpSession(key, conn) + ad.sftpSessions.Range(func(key, value interface{}) bool { + if conn, ok := value.(*SftpSession); ok { + conn.Close() } - } -} - -func (ad *AssetDir) finishSftpSession(key string, conn *SftpConn) { - sess := ad.sftpTraceSessions[key] - if sess != nil { - session.RemoveSession(sess) - if err := ad.jmsService.SessionFinished(sess.ID, common.NewNowUTCTime()); err != nil { - logger.Errorf("SFTP Session finished err: %s", err) - } - logger.Debugf("SFTP Session finished %s", sess.ID) - } - conn.Close() + return true + }) } func (ad *AssetDir) CreateFTPLog(su *model.PermAccount, operate, filename string, isSuccess bool) *model.FTPLog { sessionId := "" - if traceSession, ok := ad.sftpTraceSessions[su.String()]; ok { - sessionId = traceSession.ID + if val, ok := ad.sftpSessions.Load(su.String()); ok { + traceSession := val.(*SftpSession) + sessionId = traceSession.sess.ID } else { logger.Errorf("Not found sftp session for asset %s account %s", ad.detailAsset.String(), su.String()) diff --git a/pkg/srvconn/sftp_node.go b/pkg/srvconn/sftp_node.go index fc07ae0e9..f91137420 100644 --- a/pkg/srvconn/sftp_node.go +++ b/pkg/srvconn/sftp_node.go @@ -11,7 +11,8 @@ type NodeDir struct { ID string folderName string - subDirs map[string]os.FileInfo + subDirs map[string]os.FileInfo + _subDirs sync.Map modeTime time.Time @@ -29,6 +30,7 @@ func (nd *NodeDir) Size() int64 { return 0 } func (nd *NodeDir) Mode() os.FileMode { return os.FileMode(0444) | os.ModeDir } + func (nd *NodeDir) ModTime() time.Time { return nd.modeTime } func (nd *NodeDir) IsDir() bool { return true } @@ -49,6 +51,9 @@ func (nd *NodeDir) loadSubNodeTree() { if nd.loadSubFunc != nil { nd.subDirs = nd.loadSubFunc() } + for k, v := range nd.subDirs { + nd._subDirs.Store(k, v) + } }) } @@ -61,6 +66,18 @@ func (nd *NodeDir) close() { if assetDir, ok := dir.(*AssetDir); ok { assetDir.close() } - } } + +func (nd *NodeDir) checkExpired() { + nd._subDirs.Range(func(key, value interface{}) bool { + if nodeDir, ok := value.(*NodeDir); ok { + nodeDir.checkExpired() + return true + } + if assetDir, ok := value.(*AssetDir); ok { + assetDir.checkExpired() + } + return true + }) +} diff --git a/pkg/srvconn/sftp_session.go b/pkg/srvconn/sftp_session.go new file mode 100644 index 000000000..b8e10ce01 --- /dev/null +++ b/pkg/srvconn/sftp_session.go @@ -0,0 +1,38 @@ +package srvconn + +import ( + "sync" + + "github.com/jumpserver/koko/pkg/jms-sdk-go/common" + "github.com/jumpserver/koko/pkg/jms-sdk-go/model" + "github.com/jumpserver/koko/pkg/jms-sdk-go/service" + "github.com/jumpserver/koko/pkg/logger" + "github.com/jumpserver/koko/pkg/session" +) + +type SftpSession struct { + *SftpConn + sess *model.Session + once sync.Once + jmsService *service.JMService +} + +func (s *SftpSession) CloseWithReason(reason model.SessionLifecycleReasonErr) { + s.once.Do(func() { + s.SftpConn.Close() + session.RemoveSessionById(s.sess.ID) + if err := s.jmsService.SessionFinished(s.sess.ID, common.NewNowUTCTime()); err != nil { + logger.Errorf("SFTP Session finished err: %s", err) + } + logger.Debugf("SFTP Session finished %s", s.sess.ID) + logObj := model.SessionLifecycleLog{Reason: reason.String()} + if err := s.jmsService.RecordSessionLifecycleLog(s.sess.ID, model.AssetConnectFinished, logObj); err != nil { + logger.Errorf("Update session %s lifecycle asset_connect_finished failed: %s", s.sess.ID, err) + } + }) + +} + +func (s *SftpSession) Close() { + s.CloseWithReason(model.ReasonErrConnectDisconnect) +} diff --git a/pkg/srvconn/sftpconn.go b/pkg/srvconn/sftpconn.go index d9bec9997..99d3d0b41 100644 --- a/pkg/srvconn/sftpconn.go +++ b/pkg/srvconn/sftpconn.go @@ -347,8 +347,11 @@ func (u *UserSftpConn) generateSubFoldersFromNodeTree(nodeTrees model.NodeTreeLi folderName := cleanFolderName(node.Value) folderName = findAvailableKeyByPaddingSuffix(matchFunc, folderName, paddingCharacter) loadFunc := u.LoadNodeSubFoldersByKey(node.Key) - nodeDir := NewNodeDir(WithFolderID(item.ID), - WithFolderName(folderName), WithSubFoldersLoadFunc(loadFunc)) + opts := make([]FolderBuilderOption, 0, 3) + opts = append(opts, WithFolderID(item.ID)) + opts = append(opts, WithFolderName(folderName)) + opts = append(opts, WithSubFoldersLoadFunc(loadFunc)) + nodeDir := NewNodeDir(opts...) dirs[folderName] = &nodeDir case model.TreeTypeAsset: assetMeta := item.Meta.Data @@ -363,8 +366,9 @@ func (u *UserSftpConn) generateSubFoldersFromNodeTree(nodeTrees model.NodeTreeLi opts = append(opts, WithFolderName(folderName)) opts = append(opts, WitRemoteAddr(u.Addr)) opts = append(opts, WithFromType(u.loginFrom)) + opts = append(opts, WithTerminalConfig(u.opts.terminalCfg)) assetDir := NewAssetDir(u.jmsService, u.User, opts...) - dirs[folderName] = &assetDir + dirs[folderName] = assetDir } } return dirs @@ -383,10 +387,11 @@ func (u *UserSftpConn) generateSubFoldersFromToken(token *model.ConnectToken) ma opts = append(opts, WitRemoteAddr(u.Addr)) opts = append(opts, WithToken(token)) opts = append(opts, WithFromType(u.loginFrom)) + opts = append(opts, WithTerminalConfig(u.opts.terminalCfg)) assetDir := NewAssetDir(u.jmsService, u.User, opts...) assetDir.loadSystemUsers() - dirs[folderName] = &assetDir - u.assetDir = &assetDir + dirs[folderName] = assetDir + u.assetDir = assetDir return dirs } @@ -413,9 +418,11 @@ func (u *UserSftpConn) generateSubFoldersFromAssets(assets []model.PermAsset) ma opts = append(opts, WithFolderName(folderName)) opts = append(opts, WitRemoteAddr(u.Addr)) opts = append(opts, WithAsset(assets[i])) + opts = append(opts, WithFromType(u.loginFrom)) + opts = append(opts, WithTerminalConfig(u.opts.terminalCfg)) opts = append(opts, WithFolderUsername(u.opts.accountUsername)) assetDir := NewAssetDir(u.jmsService, u.User, opts...) - dirs[folderName] = &assetDir + dirs[folderName] = assetDir } return dirs } @@ -443,6 +450,8 @@ type userSftpOption struct { token *model.ConnectToken accountUsername string + + terminalCfg *model.TerminalConfig } type UserSftpOption func(*userSftpOption) @@ -483,6 +492,12 @@ func WithAccountUsername(username string) UserSftpOption { } } +func WithTerminalCfg(cfg *model.TerminalConfig) UserSftpOption { + return func(o *userSftpOption) { + o.terminalCfg = cfg + } +} + func NewUserSftpConn(jmsService *service.JMService, opts ...UserSftpOption) *UserSftpConn { var sftpOpts userSftpOption for _, setter := range opts { @@ -507,9 +522,33 @@ func NewUserSftpConn(jmsService *service.JMService, opts ...UserSftpOption) *Use default: u.Dirs = u.generateSubFoldersFromRootTree() } + go u.run() return &u } +func (u *UserSftpConn) run() { + tick := time.NewTicker(time.Minute) + defer tick.Stop() + for { + select { + case <-u.closed: + logger.Infof("User %s sftp conn closed", u.User.String()) + return + case <-tick.C: + logger.Debugf("User %s sftp conn check expired", u.User.String()) + } + for _, dir := range u.Dirs { + if nodeDir, ok := dir.(*NodeDir); ok { + nodeDir.checkExpired() + continue + } + if assetDir, ok := dir.(*AssetDir); ok { + assetDir.checkExpired() + } + } + } +} + func cleanFolderName(folderName string) string { return strings.ReplaceAll(folderName, SFTPPathSeparator, "_") } diff --git a/pkg/srvconn/sftpfile.go b/pkg/srvconn/sftpfile.go index 0b89734d7..9e4483b24 100644 --- a/pkg/srvconn/sftpfile.go +++ b/pkg/srvconn/sftpfile.go @@ -5,15 +5,16 @@ import ( "os" "strings" "sync" + "sync/atomic" "syscall" "time" "github.com/pkg/sftp" + gossh "golang.org/x/crypto/ssh" "github.com/jumpserver/koko/pkg/config" "github.com/jumpserver/koko/pkg/jms-sdk-go/model" "github.com/jumpserver/koko/pkg/jms-sdk-go/service" - "github.com/jumpserver/koko/pkg/session" ) const ( @@ -81,6 +82,7 @@ func NewNodeDir(builders ...FolderBuilderOption) NodeDir { ID: dirConf.ID, folderName: dirConf.Name, subDirs: map[string]os.FileInfo{}, + _subDirs: sync.Map{}, modeTime: time.Now().UTC(), once: new(sync.Once), loadSubFunc: dirConf.loadSubFunc, @@ -103,6 +105,8 @@ type folderOptions struct { token *model.ConnectToken accountUsername string + + terminalCfg *model.TerminalConfig } func WithFolderUsername(username string) FolderBuilderOption { @@ -153,7 +157,13 @@ func WithFromType(fromType model.LabelField) FolderBuilderOption { } } -func NewAssetDir(jmsService *service.JMService, user *model.User, opts ...FolderBuilderOption) AssetDir { +func WithTerminalConfig(cfg *model.TerminalConfig) FolderBuilderOption { + return func(info *folderOptions) { + info.terminalCfg = cfg + } +} + +func NewAssetDir(jmsService *service.JMService, user *model.User, opts ...FolderBuilderOption) *AssetDir { var dirOpts folderOptions for _, setter := range opts { setter(&dirOpts) @@ -173,46 +183,87 @@ func NewAssetDir(jmsService *service.JMService, user *model.User, opts ...Folder permAccounts = append(permAccounts, permAccount) detailAsset = dirOpts.asset } - return AssetDir{ + return &AssetDir{ opts: dirOpts, user: user, detailAsset: detailAsset, modeTime: time.Now().UTC(), suMaps: generateSubAccountsFolderMap(permAccounts), ShowHidden: conf.ShowHiddenFile, - reuse: conf.ReuseConnection, - sftpClients: map[string]*SftpConn{}, - sftpTraceSessions: make(map[string]*session.Session), - jmsService: jmsService, + sftpSessions: sync.Map{}, + jmsService: jmsService, } } type SftpFile struct { *sftp.File FTPLog *model.FTPLog + + cleanupFunc func() +} + +func (s *SftpFile) Close() error { + if s.cleanupFunc != nil { + s.cleanupFunc() + } + return s.File.Close() } type SftpConn struct { + permAccount *model.PermAccount HomeDirPath string client *sftp.Client + sshClient *SSHClient + sshSession *gossh.Session token *model.ConnectToken isClosed bool rootDirPath string + + nextExpiredTime time.Time + refs atomic.Int32 + lock sync.Mutex + maxIdleTime time.Duration } -func (s *SftpConn) IsOverwriteFile() bool { - resolution := s.token.ConnectOptions.FilenameConflictResolution - switch strings.ToLower(resolution) { - case FilenamePolicyReplace: - return true - case FilenamePolicySuffix: +func (s *SftpConn) IsExpired() bool { + if s.Ref() > 0 { + // some client is using return false - default: - return true } + s.lock.Lock() + defer s.lock.Unlock() + now := time.Now() + return now.Sub(s.nextExpiredTime) > 0 || s.token.ExpireAt.IsExpired(now) } +func (s *SftpConn) UpdateExpiredTime() { + s.lock.Lock() + defer s.lock.Unlock() + s.nextExpiredTime = time.Now().Add(s.maxIdleTime) +} + +func (s *SftpConn) IncreaseRef() { + s.refs.Add(1) + s.UpdateExpiredTime() +} + +func (s *SftpConn) DecreaseRef() { + s.refs.Add(-1) + s.UpdateExpiredTime() +} + +func (s *SftpConn) Ref() int32 { + return s.refs.Load() +} + +func (s *SftpConn) IsOverwriteFile() bool { + resolution := s.token.ConnectOptions.FilenameConflictResolution + return !strings.EqualFold(resolution, FilenamePolicySuffix) +} + +// check if the path is root path and disable to remove + func (s *SftpConn) IsRootPath(path string) bool { return s.rootDirPath == path } diff --git a/pkg/srvconn/ssh.go b/pkg/srvconn/ssh.go index 483bd0660..5009fcc9d 100644 --- a/pkg/srvconn/ssh.go +++ b/pkg/srvconn/ssh.go @@ -225,6 +225,7 @@ func (s *SSHClient) decreaseSelfRef() { func (s *SSHClient) selfRef() int32 { return s._selfRef } + func (s *SSHClient) String() string { return fmt.Sprintf("%s@%s:%s", s.Cfg.Username, s.Cfg.Host, s.Cfg.Port)