Skip to content

Commit

Permalink
cleaner session handling by moving input to goroutine
Browse files Browse the repository at this point in the history
  • Loading branch information
rue-nsilverman committed Jan 3, 2024
1 parent c03513b commit 010f529
Show file tree
Hide file tree
Showing 11 changed files with 167 additions and 119 deletions.
26 changes: 22 additions & 4 deletions src/datachannel/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ type IDataChannel interface {
RegisterOutputStreamHandler(handler OutputStreamDataMessageHandler, isSessionSpecificHandler bool)
DeregisterOutputStreamHandler(handler OutputStreamDataMessageHandler)
IsSessionTypeSet() chan bool
EndSession() error
IsSessionEnded() bool
IsStreamMessageResendTimeout() chan bool
GetSessionType() string
SetSessionType(sessionType string)
Expand Down Expand Up @@ -106,6 +108,8 @@ type DataChannel struct {
isSessionTypeSet chan bool
sessionProperties interface{}

isSessionEnded bool

// Used to detect if resending a streaming message reaches timeout
isStreamMessageResendTimeout chan bool

Expand Down Expand Up @@ -187,6 +191,7 @@ func (dataChannel *DataChannel) Initialize(log log.T, clientId string, sessionId
dataChannel.wsChannel = &communicator.WebSocketChannel{}
dataChannel.encryptionEnabled = false
dataChannel.isSessionTypeSet = make(chan bool, 1)
dataChannel.isSessionEnded = false
dataChannel.isStreamMessageResendTimeout = make(chan bool, 1)
dataChannel.sessionType = ""
dataChannel.IsAwsCliUpgradeNeeded = isAwsCliUpgradeNeeded
Expand All @@ -199,7 +204,7 @@ func (dataChannel *DataChannel) SetWebsocket(log log.T, channelUrl string, chann

// FinalizeHandshake sends the token for service to acknowledge the connection.
func (dataChannel *DataChannel) FinalizeDataChannelHandshake(log log.T, tokenValue string) (err error) {
uuid.SwitchFormat(uuid.CleanHyphen)
uuid.SwitchFormat(uuid.FormatCanonical)
uid := uuid.NewV4().String()

log.Infof("Sending token through data channel %s to acknowledge connection", dataChannel.wsChannel.GetStreamUrl())
Expand Down Expand Up @@ -772,7 +777,7 @@ func (dataChannel *DataChannel) HandleAcknowledgeMessage(
}

// handleChannelClosedMessage exits the shell
func (dataChannel DataChannel) HandleChannelClosedMessage(log log.T, stopHandler Stop, sessionId string, outputMessage message.ClientMessage) {
func (dataChannel *DataChannel) HandleChannelClosedMessage(log log.T, stopHandler Stop, sessionId string, outputMessage message.ClientMessage) {
var (
channelClosedMessage message.ChannelClosed
err error
Expand All @@ -787,6 +792,8 @@ func (dataChannel DataChannel) HandleChannelClosedMessage(log log.T, stopHandler
} else {
fmt.Fprintf(os.Stdout, "\n\nSessionId: %s : %s\n\n", sessionId, channelClosedMessage.Output)
}
dataChannel.EndSession()
dataChannel.Close(log)

stopHandler()
}
Expand Down Expand Up @@ -849,7 +856,7 @@ func (dataChannel *DataChannel) CalculateRetransmissionTimeout(log log.T, stream
func (dataChannel *DataChannel) ProcessKMSEncryptionHandshakeAction(log log.T, actionParams json.RawMessage) (err error) {

if dataChannel.IsAwsCliUpgradeNeeded {
return errors.New("Installed version of CLI does not support Session Manager encryption feature. Please upgrade to the latest version of your CLI (e.g., AWS CLI).")
return errors.New("installed version of CLI does not support Session Manager encryption feature. Please upgrade to the latest version of your CLI (e.g., AWS CLI)")
}
kmsEncRequest := message.KMSEncryptionRequest{}
json.Unmarshal(actionParams, &kmsEncRequest)
Expand Down Expand Up @@ -881,7 +888,7 @@ func (dataChannel *DataChannel) ProcessSessionTypeHandshakeAction(actionParams j
dataChannel.sessionProperties = sessTypeReq.Properties
return nil
default:
return errors.New(fmt.Sprintf("Unknown session type %s", sessTypeReq.SessionType))
return fmt.Errorf("Unknown session type %s", sessTypeReq.SessionType)
}
}

Expand All @@ -890,6 +897,17 @@ func (dataChannel *DataChannel) IsSessionTypeSet() chan bool {
return dataChannel.isSessionTypeSet
}

// IsSessionEnded check if session has ended
func (dataChannel *DataChannel) IsSessionEnded() bool {
return dataChannel.isSessionEnded
}

// IsSessionEnded check if session has ended
func (dataChannel *DataChannel) EndSession() error {
dataChannel.isSessionEnded = true
return nil
}

// IsStreamMessageResendTimeout checks if resending a streaming message reaches timeout
func (dataChannel *DataChannel) IsStreamMessageResendTimeout() chan bool {
return dataChannel.isStreamMessageResendTimeout
Expand Down
14 changes: 7 additions & 7 deletions src/message/messageparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,31 +167,31 @@ func getUuid(log log.T, byteArray []byte, offset int) (result uuid.UUID, err err
byteArrayLength := len(byteArray)
if offset > byteArrayLength-1 || offset+16-1 > byteArrayLength-1 || offset < 0 {
log.Error("getUuid failed: Offset is invalid.")
return nil, errors.New("Offset is outside the byte array.")
return uuid.Nil.UUID(), errors.New("Offset is outside the byte array.")
}

leastSignificantLong, err := getLong(log, byteArray, offset)
if err != nil {
log.Error("getUuid failed: failed to get uuid LSBs Long value.")
return nil, errors.New("Failed to get uuid LSBs long value.")
return uuid.Nil.UUID(), errors.New("Failed to get uuid LSBs long value.")
}

leastSignificantBytes, err := longToBytes(log, leastSignificantLong)
if err != nil {
log.Error("getUuid failed: failed to get uuid LSBs bytes value.")
return nil, errors.New("Failed to get uuid LSBs bytes value.")
return uuid.Nil.UUID(), errors.New("Failed to get uuid LSBs bytes value.")
}

mostSignificantLong, err := getLong(log, byteArray, offset+8)
if err != nil {
log.Error("getUuid failed: failed to get uuid MSBs Long value.")
return nil, errors.New("Failed to get uuid MSBs long value.")
return uuid.Nil.UUID(), errors.New("Failed to get uuid MSBs long value.")
}

mostSignificantBytes, err := longToBytes(log, mostSignificantLong)
if err != nil {
log.Error("getUuid failed: failed to get uuid MSBs bytes value.")
return nil, errors.New("Failed to get uuid MSBs bytes value.")
return uuid.Nil.UUID(), errors.New("Failed to get uuid MSBs bytes value.")
}

uuidBytes := append(mostSignificantBytes, leastSignificantBytes...)
Expand Down Expand Up @@ -414,7 +414,7 @@ func putBytes(log log.T, byteArray []byte, offsetStart int, offsetEnd int, input

// putUuid puts the 128 bit uuid to an array of bytes starting from the offset.
func putUuid(log log.T, byteArray []byte, offset int, input uuid.UUID) (err error) {
if input == nil {
if uuid.IsNil(input) {
log.Error("putUuid failed: input is null.")
return errors.New("putUuid failed: input is null.")
}
Expand Down Expand Up @@ -494,7 +494,7 @@ func SerializeClientMessageWithAcknowledgeContent(log log.T, acknowledgeContent
return
}

uuid.SwitchFormat(uuid.CleanHyphen)
uuid.SwitchFormat(uuid.FormatCanonical)
messageId := uuid.NewV4()
clientMessage := ClientMessage{
MessageType: AcknowledgeMessage,
Expand Down
66 changes: 29 additions & 37 deletions src/sessionmanagerplugin/session/portsession/basicportforwarding.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,34 +34,25 @@ import (
// accepts one client connection at a time
type BasicPortForwarding struct {
port IPortSession
stream *net.Conn
listener *net.Listener
stream net.Conn
listener net.Listener
sessionId string
portParameters PortParameters
session session.Session
}

// getNewListener returns a new listener to given address and type like tcp, unix etc.
var getNewListener = func(listenerType string, listenerAddress string) (listener net.Listener, err error) {
return net.Listen(listenerType, listenerAddress)
}

// acceptConnection returns connection to the listener
var acceptConnection = func(log log.T, listener net.Listener) (tcpConn net.Conn, err error) {
return listener.Accept()
}

// IsStreamNotSet checks if stream is not set
func (p *BasicPortForwarding) IsStreamNotSet() (status bool) {
return p.stream == nil
}

// Stop closes the stream
func (p *BasicPortForwarding) Stop() {
p.listener.Close()
if p.stream != nil {
(*p.stream).Close()
p.stream.Close()
}
os.Exit(0)
return
}

// InitializeStreams establishes connection and initializes the stream
Expand All @@ -77,7 +68,7 @@ func (p *BasicPortForwarding) InitializeStreams(log log.T, agentVersion string)
func (p *BasicPortForwarding) ReadStream(log log.T) (err error) {
msg := make([]byte, config.StreamDataPayloadSize)
for {
numBytes, err := (*p.stream).Read(msg)
numBytes, err := p.stream.Read(msg)
if err != nil {
log.Debugf("Reading from port %s failed with error: %v. Close this connection, listen and accept new one.",
p.portParameters.PortNumber, err)
Expand Down Expand Up @@ -108,7 +99,7 @@ func (p *BasicPortForwarding) ReadStream(log log.T) (err error) {

// WriteStream writes data to stream
func (p *BasicPortForwarding) WriteStream(outputMessage message.ClientMessage) error {
_, err := (*p.stream).Write(outputMessage.Payload)
_, err := p.stream.Write(outputMessage.Payload)
return err
}

Expand All @@ -120,41 +111,40 @@ func (p *BasicPortForwarding) startLocalConn(log log.T) (err error) {
localPortNumber = "0"
}

var listener net.Listener
if listener, err = p.startLocalListener(log, localPortNumber); err != nil {
if err = p.startLocalListener(log, localPortNumber); err != nil {
log.Errorf("Unable to open tcp connection to port. %v", err)
return err
}

var tcpConn net.Conn
if tcpConn, err = acceptConnection(log, listener); err != nil {
log.Errorf("Failed to accept connection with error. %v", err)
return err
if p.stream, err = p.listener.Accept(); err != nil {
if p.session.DataChannel.IsSessionEnded() == false {
log.Errorf("Failed to accept connection with error. %v", err)
return err
}
}
if p.session.DataChannel.IsSessionEnded() == false {
log.Infof("Connection accepted for session %s.", p.sessionId)
fmt.Printf("Connection accepted for session %s.\n", p.sessionId)
}
log.Infof("Connection accepted for session %s.", p.sessionId)
fmt.Printf("Connection accepted for session %s.\n", p.sessionId)

p.listener = &listener
p.stream = &tcpConn

return
}

// startLocalListener starts a local listener to given address
func (p *BasicPortForwarding) startLocalListener(log log.T, portNumber string) (listener net.Listener, err error) {
func (p *BasicPortForwarding) startLocalListener(log log.T, portNumber string) (err error) {
var displayMessage string
switch p.portParameters.LocalConnectionType {
case "unix":
if listener, err = getNewListener(p.portParameters.LocalConnectionType, p.portParameters.LocalUnixSocket); err != nil {
if p.listener, err = net.Listen(p.portParameters.LocalConnectionType, p.portParameters.LocalUnixSocket); err != nil {
return
}
displayMessage = fmt.Sprintf("Unix socket %s opened for sessionId %s.", p.portParameters.LocalUnixSocket, p.sessionId)
default:
if listener, err = getNewListener("tcp", "localhost:"+portNumber); err != nil {
if p.listener, err = net.Listen("tcp", "localhost:"+portNumber); err != nil {
return
}
// get port number the TCP listener opened
p.portParameters.LocalPortNumber = strconv.Itoa(listener.Addr().(*net.TCPAddr).Port)
p.portParameters.LocalPortNumber = strconv.Itoa(p.listener.Addr().(*net.TCPAddr).Port)
displayMessage = fmt.Sprintf("Port %s opened for sessionId %s.", p.portParameters.LocalPortNumber, p.sessionId)
}

Expand All @@ -171,29 +161,31 @@ func (p *BasicPortForwarding) handleControlSignals(log log.T) {
<-c
fmt.Println("Terminate signal received, exiting.")

p.session.DataChannel.EndSession()
if version.DoesAgentSupportTerminateSessionFlag(log, p.session.DataChannel.GetAgentVersion()) {
if err := p.session.DataChannel.SendFlag(log, message.TerminateSession); err != nil {
log.Errorf("Failed to send TerminateSession flag: %v", err)
}
fmt.Fprintf(os.Stdout, "\n\nExiting session with sessionId: %s.\n\n", p.sessionId)
p.Stop()
} else {
p.session.TerminateSession(log)
}
p.Stop()
}()
}

// reconnect closes existing connection, listens to new connection and accept it
func (p *BasicPortForwarding) reconnect(log log.T) (err error) {
// close existing connection as it is in a state from which data cannot be read
(*p.stream).Close()
p.stream.Close()

// wait for new connection on listener and accept it
var conn net.Conn
if conn, err = acceptConnection(log, *p.listener); err != nil {
return log.Errorf("Failed to accept connection with error. %v", err)
if p.stream, err = p.listener.Accept(); err != nil {
if p.session.DataChannel.IsSessionEnded() == false {
log.Errorf("Failed to accept connection with error. %v", err)
return err
}
}
p.stream = &conn

return
}
Loading

0 comments on commit 010f529

Please sign in to comment.