diff --git a/src/config/config.go b/src/config/config.go index 47961bf2..08a2454e 100644 --- a/src/config/config.go +++ b/src/config/config.go @@ -45,6 +45,7 @@ const ( NonInteractiveCommandsPluginName = "NonInteractiveCommands" //Agent Versions - TerminateSessionFlagSupportedAfterThisAgentVersion = "2.3.722.0" - TCPMultiplexingSupportedAfterThisAgentVersion = "3.0.196.0" + TerminateSessionFlagSupportedAfterThisAgentVersion = "2.3.722.0" + TCPMultiplexingSupportedAfterThisAgentVersion = "3.0.196.0" + TCPMultiplexingWithSmuxKeepAliveDisabledAfterThisAgentVersion = "3.1.1511.0" ) diff --git a/src/sessionmanagerplugin/session/portsession/basicportforwarding.go b/src/sessionmanagerplugin/session/portsession/basicportforwarding.go index 95b91d81..28c0f2d2 100644 --- a/src/sessionmanagerplugin/session/portsession/basicportforwarding.go +++ b/src/sessionmanagerplugin/session/portsession/basicportforwarding.go @@ -65,7 +65,7 @@ func (p *BasicPortForwarding) Stop() { } // InitializeStreams establishes connection and initializes the stream -func (p *BasicPortForwarding) InitializeStreams(log log.T) (err error) { +func (p *BasicPortForwarding) InitializeStreams(log log.T, agentVersion string) (err error) { p.handleControlSignals(log) if err = p.startLocalConn(log); err != nil { return diff --git a/src/sessionmanagerplugin/session/portsession/muxportforwarding.go b/src/sessionmanagerplugin/session/portsession/muxportforwarding.go index f5a03372..209bf234 100644 --- a/src/sessionmanagerplugin/session/portsession/muxportforwarding.go +++ b/src/sessionmanagerplugin/session/portsession/muxportforwarding.go @@ -34,6 +34,7 @@ import ( "github.com/aws/SSMCLI/src/message" "github.com/aws/SSMCLI/src/sessionmanagerplugin/session" "github.com/aws/SSMCLI/src/sessionmanagerplugin/session/sessionutil" + "github.com/aws/SSMCLI/src/version" "github.com/xtaci/smux" "golang.org/x/sync/errgroup" ) @@ -90,12 +91,12 @@ func (p *MuxPortForwarding) Stop() { } // InitializeStreams initializes i/o streams -func (p *MuxPortForwarding) InitializeStreams(log log.T) (err error) { +func (p *MuxPortForwarding) InitializeStreams(log log.T, agentVersion string) (err error) { p.handleControlSignals(log) p.socketFile = getUnixSocketPath(p.sessionId, os.TempDir(), "session_manager_plugin_mux.sock") - if err = p.initialize(log); err != nil { + if err = p.initialize(log, agentVersion); err != nil { p.cleanUp() } return @@ -142,7 +143,7 @@ func (p *MuxPortForwarding) cleanUp() { } // initialize opens a network connection that acts as smux client -func (p *MuxPortForwarding) initialize(log log.T) (err error) { +func (p *MuxPortForwarding) initialize(log log.T, agentVersion string) (err error) { // open a network listener var listener net.Listener @@ -165,10 +166,17 @@ func (p *MuxPortForwarding) initialize(log log.T) (err error) { g.Go(func() error { if muxConn, err := net.Dial(listener.Addr().Network(), listener.Addr().String()); err != nil { return err - } else if muxSession, err := smux.Client(muxConn, nil); err != nil { - return err } else { - p.muxClient = &MuxClient{muxConn, muxSession} + smuxConfig := smux.DefaultConfig() + if version.DoesAgentSupportDisableSmuxKeepAlive(log, agentVersion) { + // Disable smux KeepAlive or else it breaks Session Manager idle timeout. + smuxConfig.KeepAliveDisabled = true + } + if muxSession, err := smux.Client(muxConn, smuxConfig); err != nil { + return err + } else { + p.muxClient = &MuxClient{muxConn, muxSession} + } } return nil }) diff --git a/src/sessionmanagerplugin/session/portsession/portsession.go b/src/sessionmanagerplugin/session/portsession/portsession.go index b2d6450a..a2b95cb9 100644 --- a/src/sessionmanagerplugin/session/portsession/portsession.go +++ b/src/sessionmanagerplugin/session/portsession/portsession.go @@ -35,7 +35,7 @@ type PortSession struct { type IPortSession interface { IsStreamNotSet() (status bool) - InitializeStreams(log log.T) (err error) + InitializeStreams(log log.T, agentVersion string) (err error) ReadStream(log log.T) (err error) WriteStream(outputMessage message.ClientMessage) (err error) Stop() @@ -111,7 +111,7 @@ func (s *PortSession) Stop() { // StartSession redirects inputStream/outputStream data to datachannel. func (s *PortSession) SetSessionHandlers(log log.T) (err error) { - if err = s.portSessionType.InitializeStreams(log); err != nil { + if err = s.portSessionType.InitializeStreams(log, s.DataChannel.GetAgentVersion()); err != nil { return err } diff --git a/src/sessionmanagerplugin/session/portsession/standardstreamforwarding.go b/src/sessionmanagerplugin/session/portsession/standardstreamforwarding.go index f61278fe..52e5da50 100644 --- a/src/sessionmanagerplugin/session/portsession/standardstreamforwarding.go +++ b/src/sessionmanagerplugin/session/portsession/standardstreamforwarding.go @@ -46,7 +46,7 @@ func (p *StandardStreamForwarding) Stop() { } // InitializeStreams initializes the streams with its file descriptors -func (p *StandardStreamForwarding) InitializeStreams(log log.T) (err error) { +func (p *StandardStreamForwarding) InitializeStreams(log log.T, agentVersion string) (err error) { p.inputStream = os.Stdin p.outputStream = os.Stdout return diff --git a/src/version/versionvalidator.go b/src/version/versionvalidator.go index ea3c9a02..6f3f4b52 100644 --- a/src/version/versionvalidator.go +++ b/src/version/versionvalidator.go @@ -24,6 +24,11 @@ func DoesAgentSupportTCPMultiplexing(log log.T, agentVersion string) (supported return isAgentVersionGreaterThanSupportedVersion(log, agentVersion, config.TCPMultiplexingSupportedAfterThisAgentVersion) } +// DoesAgentSupportDisableSmuxKeepAlive returns true if given agentVersion disables smux KeepAlive in TCP multiplexing in port plugin, false otherwise +func DoesAgentSupportDisableSmuxKeepAlive(log log.T, agentVersion string) (supported bool) { + return isAgentVersionGreaterThanSupportedVersion(log, agentVersion, config.TCPMultiplexingWithSmuxKeepAliveDisabledAfterThisAgentVersion) +} + // DoesAgentSupportTerminateSessionFlag returns true if given agentVersion supports TerminateSession flag, false otherwise func DoesAgentSupportTerminateSessionFlag(log log.T, agentVersion string) (supported bool) { return isAgentVersionGreaterThanSupportedVersion(log, agentVersion, config.TerminateSessionFlagSupportedAfterThisAgentVersion) diff --git a/src/version/versionvalidator_test.go b/src/version/versionvalidator_test.go index 7ef12e83..471e8b46 100644 --- a/src/version/versionvalidator_test.go +++ b/src/version/versionvalidator_test.go @@ -148,3 +148,11 @@ func TestDoesAgentSupportTerminateSessionFlagForNotSupportedScenario(t *testing. func TestDoesAgentSupportTerminateSessionFlagWhenAgentVersionIsEqualSupportedAfterVersion(t *testing.T) { assert.False(t, DoesAgentSupportTerminateSessionFlag(mockLog, "2.3.722.0")) } + +func TestDoesAgentSupportDisableSmuxKeepAliveForNotSupportedScenario(t *testing.T) { + assert.False(t, DoesAgentSupportDisableSmuxKeepAlive(mockLog, "3.1.1476.0")) +} + +func TestDoesAgentSupportDisableSmuxKeepAliveForSupportedScenario(t *testing.T) { + assert.True(t, DoesAgentSupportDisableSmuxKeepAlive(mockLog, "3.1.1600.0")) +}