Skip to content

Commit

Permalink
Determine smux keep alive configuration based on SSM agent version check
Browse files Browse the repository at this point in the history
  • Loading branch information
yuting-fan authored and Yangtao-Hua committed Jun 17, 2022
1 parent 31f76d0 commit 986e79d
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 12 deletions.
5 changes: 3 additions & 2 deletions src/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ const (
NonInteractiveCommandsPluginName = "NonInteractiveCommands"

//Agent Versions
TerminateSessionFlagSupportedAfterThisAgentVersion = "2.3.722.0"
TCPMultiplexingSupportedAfterThisAgentVersion = "3.0.196.0"
TerminateSessionFlagSupportedAfterThisAgentVersion = "2.3.722.0"
TCPMultiplexingSupportedAfterThisAgentVersion = "3.0.196.0"
TCPMultiplexingWithSmuxKeepAliveDisabledAfterThisAgentVersion = "3.1.1511.0"
)
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (p *BasicPortForwarding) Stop() {
}

// InitializeStreams establishes connection and initializes the stream
func (p *BasicPortForwarding) InitializeStreams(log log.T) (err error) {
func (p *BasicPortForwarding) InitializeStreams(log log.T, agentVersion string) (err error) {
p.handleControlSignals(log)
if err = p.startLocalConn(log); err != nil {
return
Expand Down
20 changes: 14 additions & 6 deletions src/sessionmanagerplugin/session/portsession/muxportforwarding.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/aws/SSMCLI/src/message"
"github.com/aws/SSMCLI/src/sessionmanagerplugin/session"
"github.com/aws/SSMCLI/src/sessionmanagerplugin/session/sessionutil"
"github.com/aws/SSMCLI/src/version"
"github.com/xtaci/smux"
"golang.org/x/sync/errgroup"
)
Expand Down Expand Up @@ -90,12 +91,12 @@ func (p *MuxPortForwarding) Stop() {
}

// InitializeStreams initializes i/o streams
func (p *MuxPortForwarding) InitializeStreams(log log.T) (err error) {
func (p *MuxPortForwarding) InitializeStreams(log log.T, agentVersion string) (err error) {

p.handleControlSignals(log)
p.socketFile = getUnixSocketPath(p.sessionId, os.TempDir(), "session_manager_plugin_mux.sock")

if err = p.initialize(log); err != nil {
if err = p.initialize(log, agentVersion); err != nil {
p.cleanUp()
}
return
Expand Down Expand Up @@ -142,7 +143,7 @@ func (p *MuxPortForwarding) cleanUp() {
}

// initialize opens a network connection that acts as smux client
func (p *MuxPortForwarding) initialize(log log.T) (err error) {
func (p *MuxPortForwarding) initialize(log log.T, agentVersion string) (err error) {

// open a network listener
var listener net.Listener
Expand All @@ -165,10 +166,17 @@ func (p *MuxPortForwarding) initialize(log log.T) (err error) {
g.Go(func() error {
if muxConn, err := net.Dial(listener.Addr().Network(), listener.Addr().String()); err != nil {
return err
} else if muxSession, err := smux.Client(muxConn, nil); err != nil {
return err
} else {
p.muxClient = &MuxClient{muxConn, muxSession}
smuxConfig := smux.DefaultConfig()
if version.DoesAgentSupportDisableSmuxKeepAlive(log, agentVersion) {
// Disable smux KeepAlive or else it breaks Session Manager idle timeout.
smuxConfig.KeepAliveDisabled = true
}
if muxSession, err := smux.Client(muxConn, smuxConfig); err != nil {
return err
} else {
p.muxClient = &MuxClient{muxConn, muxSession}
}
}
return nil
})
Expand Down
4 changes: 2 additions & 2 deletions src/sessionmanagerplugin/session/portsession/portsession.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type PortSession struct {

type IPortSession interface {
IsStreamNotSet() (status bool)
InitializeStreams(log log.T) (err error)
InitializeStreams(log log.T, agentVersion string) (err error)
ReadStream(log log.T) (err error)
WriteStream(outputMessage message.ClientMessage) (err error)
Stop()
Expand Down Expand Up @@ -111,7 +111,7 @@ func (s *PortSession) Stop() {

// StartSession redirects inputStream/outputStream data to datachannel.
func (s *PortSession) SetSessionHandlers(log log.T) (err error) {
if err = s.portSessionType.InitializeStreams(log); err != nil {
if err = s.portSessionType.InitializeStreams(log, s.DataChannel.GetAgentVersion()); err != nil {
return err
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (p *StandardStreamForwarding) Stop() {
}

// InitializeStreams initializes the streams with its file descriptors
func (p *StandardStreamForwarding) InitializeStreams(log log.T) (err error) {
func (p *StandardStreamForwarding) InitializeStreams(log log.T, agentVersion string) (err error) {
p.inputStream = os.Stdin
p.outputStream = os.Stdout
return
Expand Down
5 changes: 5 additions & 0 deletions src/version/versionvalidator.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ func DoesAgentSupportTCPMultiplexing(log log.T, agentVersion string) (supported
return isAgentVersionGreaterThanSupportedVersion(log, agentVersion, config.TCPMultiplexingSupportedAfterThisAgentVersion)
}

// DoesAgentSupportDisableSmuxKeepAlive returns true if given agentVersion disables smux KeepAlive in TCP multiplexing in port plugin, false otherwise
func DoesAgentSupportDisableSmuxKeepAlive(log log.T, agentVersion string) (supported bool) {
return isAgentVersionGreaterThanSupportedVersion(log, agentVersion, config.TCPMultiplexingWithSmuxKeepAliveDisabledAfterThisAgentVersion)
}

// DoesAgentSupportTerminateSessionFlag returns true if given agentVersion supports TerminateSession flag, false otherwise
func DoesAgentSupportTerminateSessionFlag(log log.T, agentVersion string) (supported bool) {
return isAgentVersionGreaterThanSupportedVersion(log, agentVersion, config.TerminateSessionFlagSupportedAfterThisAgentVersion)
Expand Down
8 changes: 8 additions & 0 deletions src/version/versionvalidator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,11 @@ func TestDoesAgentSupportTerminateSessionFlagForNotSupportedScenario(t *testing.
func TestDoesAgentSupportTerminateSessionFlagWhenAgentVersionIsEqualSupportedAfterVersion(t *testing.T) {
assert.False(t, DoesAgentSupportTerminateSessionFlag(mockLog, "2.3.722.0"))
}

func TestDoesAgentSupportDisableSmuxKeepAliveForNotSupportedScenario(t *testing.T) {
assert.False(t, DoesAgentSupportDisableSmuxKeepAlive(mockLog, "3.1.1476.0"))
}

func TestDoesAgentSupportDisableSmuxKeepAliveForSupportedScenario(t *testing.T) {
assert.True(t, DoesAgentSupportDisableSmuxKeepAlive(mockLog, "3.1.1600.0"))
}

0 comments on commit 986e79d

Please sign in to comment.