Skip to content

Commit

Permalink
perf: support max idle timeout for sftp session
Browse files Browse the repository at this point in the history
  • Loading branch information
LeeEirc committed Jul 31, 2024
1 parent 1faa22a commit 401663b
Show file tree
Hide file tree
Showing 16 changed files with 389 additions and 237 deletions.
2 changes: 1 addition & 1 deletion pkg/auth/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func HTTPMiddleSessionAuth(jmsService *service.JMService) gin.HandlerFunc {
func HTTPMiddleDebugAuth() gin.HandlerFunc {
return func(ctx *gin.Context) {
switch ctx.ClientIP() {
case "127.0.0.1", "localhost":
case "127.0.0.1", "localhost", "::1":
return
default:
_ = ctx.AbortWithError(http.StatusBadRequest, fmt.Errorf("invalid host %s", ctx.ClientIP()))
Expand Down
1 change: 1 addition & 0 deletions pkg/handler/direct_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ func (d *DirectHandler) NewSFTPHandler() *SftpHandler {
opts = append(opts, srvconn.WithUser(d.opts.User))
opts = append(opts, srvconn.WithRemoteAddr(addr))
opts = append(opts, srvconn.WithLoginFrom(model.LoginFromSSH))
opts = append(opts, srvconn.WithTerminalCfg(d.opts.terminalConf))
if !d.opts.IsTokenConnection() {
opts = append(opts, srvconn.WithAssets(d.opts.assets))
if len(d.opts.assets) == 1 {
Expand Down
15 changes: 14 additions & 1 deletion pkg/handler/server_ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (s *Server) SFTPHandler(sess ssh.Session) {
directSrv := NewDirectHandler(sess, s.jmsService, opts...)
sftpHandler = directSrv.NewSFTPHandler()
} else {
sftpHandler = NewSFTPHandler(s.jmsService, currentUser, addr)
sftpHandler = s.NewSftpHandler(currentUser, addr)
}
handlers := sftp.Handlers{
FileGet: sftpHandler,
Expand All @@ -106,6 +106,19 @@ func (s *Server) SFTPHandler(sess ssh.Session) {
logger.Infof("SFTP request %s: Handler exit.", reqID)
}

func (s *Server) NewSftpHandler(user *model.User, addr string) *SftpHandler {
terminalCfg := s.GetTerminalConfig()
opts := make([]srvconn.UserSftpOption, 0, 5)
opts = append(opts, srvconn.WithUser(user))
opts = append(opts, srvconn.WithRemoteAddr(addr))
opts = append(opts, srvconn.WithLoginFrom(model.LoginFromSSH))
opts = append(opts, srvconn.WithTerminalCfg(&terminalCfg))
return &SftpHandler{
UserSftpConn: srvconn.NewUserSftpConn(s.jmsService, opts...),
recorder: proxy.GetFTPFileRecorder(s.jmsService),
}
}

func (s *Server) LocalPortForwardingPermission(ctx ssh.Context, dstHost string, dstPort uint32) bool {
logger.Debugf("LocalPortForwardingPermission: %s %s %d", ctx.User(), dstHost, dstPort)
return config.GlobalConfig.EnableLocalPortForward
Expand Down
13 changes: 0 additions & 13 deletions pkg/handler/sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,11 @@ import (

"github.com/pkg/sftp"

"github.com/jumpserver/koko/pkg/jms-sdk-go/model"
"github.com/jumpserver/koko/pkg/jms-sdk-go/service"
"github.com/jumpserver/koko/pkg/logger"
"github.com/jumpserver/koko/pkg/proxy"
"github.com/jumpserver/koko/pkg/srvconn"
)

func NewSFTPHandler(jmsService *service.JMService, user *model.User, addr string) *SftpHandler {
opts := make([]srvconn.UserSftpOption, 0, 5)
opts = append(opts, srvconn.WithUser(user))
opts = append(opts, srvconn.WithRemoteAddr(addr))
opts = append(opts, srvconn.WithLoginFrom(model.LoginFromSSH))
return &SftpHandler{
UserSftpConn: srvconn.NewUserSftpConn(jmsService, opts...),
recorder: proxy.GetFTPFileRecorder(jmsService),
}
}

type SftpHandler struct {
*srvconn.UserSftpConn

Expand Down
9 changes: 9 additions & 0 deletions pkg/httpd/sftpvolume.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type volumeOption struct {
user *model.User
asset *model.PermAsset
connectToken *model.ConnectToken
terminalCfg *model.TerminalConfig
}
type VolumeOption func(*volumeOption)

Expand Down Expand Up @@ -51,6 +52,13 @@ func WithConnectToken(connectToken *model.ConnectToken) VolumeOption {
}
}

func WithTerminalCfg(cfg *model.TerminalConfig) VolumeOption {
return func(opts *volumeOption) {
opts.terminalCfg = cfg
}

}

func NewUserVolume(jmsService *service.JMService, opts ...VolumeOption) *UserVolume {
var volOpts volumeOption
for _, opt := range opts {
Expand All @@ -77,6 +85,7 @@ func NewUserVolume(jmsService *service.JMService, opts ...VolumeOption) *UserVol
sftpOpts = append(sftpOpts, srvconn.WithUser(volOpts.user))
sftpOpts = append(sftpOpts, srvconn.WithRemoteAddr(volOpts.addr))
sftpOpts = append(sftpOpts, srvconn.WithLoginFrom(model.LoginFromWeb))
sftpOpts = append(sftpOpts, srvconn.WithTerminalCfg(volOpts.terminalCfg))
userSftp := srvconn.NewUserSftpConn(jmsService, sftpOpts...)
rawID := fmt.Sprintf("%s@%s", volOpts.user.Username, volOpts.addr)

Expand Down
6 changes: 6 additions & 0 deletions pkg/httpd/webfolder.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,15 @@ func (h *webFolder) Name() string {
func (h *webFolder) CheckValidation() error {
apiClient := h.ws.apiClient
user := h.ws.CurrentUser()
terminalCfg, err := h.ws.apiClient.GetTerminalConfig()
if err != nil {
logger.Errorf("Get terminal config failed: %s", err)
return err
}
volOpts := make([]VolumeOption, 0, 5)
volOpts = append(volOpts, WithUser(user))
volOpts = append(volOpts, WithAddr(h.ws.ClientIP()))
volOpts = append(volOpts, WithTerminalCfg(&terminalCfg))
params := h.ws.wsParams
targetId := params.TargetId
assetId := params.AssetId
Expand Down
4 changes: 4 additions & 0 deletions pkg/jms-sdk-go/model/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ var EmptyLifecycleLog = SessionLifecycleLog{}

type SessionLifecycleReasonErr string

func (s SessionLifecycleReasonErr) String() string {
return string(s)
}

const (
ReasonErrConnectFailed SessionLifecycleReasonErr = "connect_failed"
ReasonErrConnectDisconnect SessionLifecycleReasonErr = "connect_disconnect"
Expand Down
19 changes: 0 additions & 19 deletions pkg/proxy/domain_gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
)

type domainGateway struct {
domain *model.Domain
dstIP string
dstPort int

Expand Down Expand Up @@ -97,24 +96,6 @@ func (d *domainGateway) getAvailableGateway() bool {
d.sshClient = sshClient
return true
}

for i := range d.domain.Gateways {
gateway := d.domain.Gateways[i]
if !gateway.Protocols.IsSupportProtocol(model.ProtocolSSH) {
continue
}
logger.Debugf("Domain %s try dial gateway %s", d.domain.Name, gateway.Name)
sshClient, err := d.createGatewaySSHClient(&gateway)
if err != nil {
logger.Errorf("Dial gateway %s err: %s ", gateway.Name, err)
continue
}
logger.Infof("Domain %s use gateway %s", d.domain.Name, gateway.Name)
d.sshClient = sshClient
d.selectedGateway = &gateway
return true
}
logger.Errorf("Domain Gateway %s has no available gateway", d.domain.Name)
return false
}

Expand Down
78 changes: 24 additions & 54 deletions pkg/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ var (
ErrNoAuthInfo = errors.New("no auth info")
)

const localIP = "127.0.0.1"

func NewServer(conn UserConnection, jmsService *service.JMService, opts ...ConnectionOption) (*Server, error) {
connOpts := &ConnectionOptions{}
for _, setter := range opts {
Expand Down Expand Up @@ -139,9 +141,8 @@ type Server struct {

suFromAccount *model.BaseAccount

terminalConf *model.TerminalConfig
domainGateways *model.Domain
gateway *model.Gateway
terminalConf *model.TerminalConfig
gateway *model.Gateway

sessionInfo *model.Session

Expand Down Expand Up @@ -428,31 +429,23 @@ func (s *Server) getCacheSSHConn() (srvConn *srvconn.SSHConnection, ok bool) {
return cacheConn, true
}

func (s *Server) createAvailableGateWay(domain *model.Domain) (*domainGateway, error) {
func (s *Server) createAvailableGateWay() (*domainGateway, error) {
asset := s.connOpts.authInfo.Asset
protocol := s.connOpts.authInfo.Protocol

var dGateway *domainGateway
switch protocol {
case srvconn.ProtocolK8s:
dstHost, dstPort, err := ParseUrlHostAndPort(asset.Address)
dstIP := asset.Address
dstPort := asset.ProtocolPort(protocol)
if protocol == srvconn.ProtocolK8s {
dstHost, dstPort1, err := ParseUrlHostAndPort(asset.Address)
if err != nil {
return nil, err
}
dGateway = &domainGateway{
domain: domain,
dstIP: dstHost,
dstPort: dstPort,
selectedGateway: s.gateway,
}
default:
port := asset.ProtocolPort(protocol)
dGateway = &domainGateway{
domain: domain,
dstIP: asset.Address,
dstPort: port,
selectedGateway: s.gateway,
}
dstIP = dstHost
dstPort = dstPort1
}
dGateway := &domainGateway{
dstIP: dstIP,
dstPort: dstPort,
selectedGateway: s.gateway,
}
return dGateway, nil
}
Expand All @@ -462,11 +455,11 @@ func (s *Server) getK8sConConn(localTunnelAddr *net.TCPAddr) (srvConn srvconn.Se
asset := s.connOpts.authInfo.Asset
clusterServer := asset.Address
if localTunnelAddr != nil {
originUrl, err := url.Parse(clusterServer)
if err != nil {
return nil, err
originUrl, err1 := url.Parse(clusterServer)
if err1 != nil {
return nil, err1
}
clusterServer = ReplaceURLHostAndPort(originUrl, "127.0.0.1", localTunnelAddr.Port)
clusterServer = ReplaceURLHostAndPort(originUrl, localIP, localTunnelAddr.Port)
}
if s.connOpts.k8sContainer != nil {
return s.getContainerConn(clusterServer)
Expand Down Expand Up @@ -515,7 +508,7 @@ func (s *Server) getRedisConn(localTunnelAddr *net.TCPAddr) (srvConn *srvconn.Re
host := asset.Address
port := asset.ProtocolPort(protocol)
if localTunnelAddr != nil {
host = "127.0.0.1"
host = localIP
port = localTunnelAddr.Port
}
username := s.account.Username
Expand Down Expand Up @@ -547,7 +540,7 @@ func (s *Server) getMongoDBConn(localTunnelAddr *net.TCPAddr) (srvConn *srvconn.
host := asset.Address
port := asset.ProtocolPort(protocol)
if localTunnelAddr != nil {
host = "127.0.0.1"
host = localIP
port = localTunnelAddr.Port
}
platform := s.connOpts.authInfo.Platform
Expand Down Expand Up @@ -806,29 +799,6 @@ func (s *Server) getGatewayProxyOptions() []srvconn.SSHClientOptions {
}
return []srvconn.SSHClientOptions{proxyArg}
}
// 多个网关的情况
if s.domainGateways != nil && len(s.domainGateways.Gateways) != 0 {
timeout := config.GlobalConfig.SSHTimeout
proxyArgs := make([]srvconn.SSHClientOptions, 0, len(s.domainGateways.Gateways))
for i := range s.domainGateways.Gateways {
gateway := s.domainGateways.Gateways[i]
loginAccount := gateway.Account
port := gateway.Protocols.GetProtocolPort(model.ProtocolSSH)
proxyArg := srvconn.SSHClientOptions{
Host: gateway.Address,
Port: strconv.Itoa(port),
Username: loginAccount.Username,
Timeout: timeout,
}
if loginAccount.IsSSHKey() {
proxyArg.PrivateKey = loginAccount.Secret
} else {
proxyArg.Password = loginAccount.Secret
}
proxyArgs = append(proxyArgs, proxyArg)
}
return proxyArgs
}
return nil
}

Expand Down Expand Up @@ -982,13 +952,13 @@ func (s *Server) Proxy() {
}
}()
var proxyAddr *net.TCPAddr
if (s.domainGateways != nil && len(s.domainGateways.Gateways) != 0) || s.gateway != nil {
if s.gateway != nil {
protocol := s.connOpts.authInfo.Protocol
switch protocol {
case srvconn.ProtocolSSH, srvconn.ProtocolTELNET:
// ssh 和 telnet 协议不需要本地启动代理
default:
dGateway, err := s.createAvailableGateWay(s.domainGateways)
dGateway, err := s.createAvailableGateWay()
if err != nil {
msg := lang.T("Start domain gateway failed %s")
msg = fmt.Sprintf(msg, err)
Expand Down
4 changes: 4 additions & 0 deletions pkg/session/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ func RemoveSession(s *Session) {
sessManager.Delete(s.ID)
}

func RemoveSessionById(id string) {
sessManager.Delete(id)
}

func newSessionManager() *sessionManager {
return &sessionManager{
data: make(map[string]*Session),
Expand Down
Loading

0 comments on commit 401663b

Please sign in to comment.