Skip to content

Commit

Permalink
[secure-transport] separate transport/socket and session states (open…
Browse files Browse the repository at this point in the history
…thread#11022)

This commit updates how state is tracked in the `SecureTransport`
class. It directly tracks whether the transport/socket has been
opened or closed in a new member variable `mIsOpen`. The TLS/DTLS
session state is tracked separately in `mSessionState`. This
separation allows for future changes to support multiple sessions
using the same transport/socket.

This commit also simplifies the session states, adding "disconnected"
and "disconnecting" (replacing "close notify") states.
  • Loading branch information
abtink authored Dec 11, 2024
1 parent c427043 commit c2d5265
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 71 deletions.
116 changes: 68 additions & 48 deletions src/core/meshcop/secure_transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,11 @@ SecureTransport::SecureTransport(Instance &aInstance, LinkSecurityMode aLayerTwo
: InstanceLocator(aInstance)
, mLayerTwoSecurity(aLayerTwoSecurity)
, mDatagramTransport(aDatagramTransport)
, mIsOpen(false)
, mIsServer(true)
, mTimerSet(false)
, mVerifyPeerCertificate(true)
, mState(kStateClosed)
, mSessionState(kSessionDisconnected)
, mCipherSuite(kUnspecifiedCipherSuite)
, mMessageSubType(Message::kSubTypeNone)
, mConnectEvent(kDisconnectedError)
Expand Down Expand Up @@ -123,12 +124,12 @@ void SecureTransport::FreeMbedtls(void)
mbedtls_ssl_free(&mSsl);
}

void SecureTransport::SetState(State aState)
void SecureTransport::SetSessionState(SessionState aSessionState)
{
VerifyOrExit(mState != aState);
VerifyOrExit(mSessionState != aSessionState);

LogInfo("State: %s -> %s", StateToString(mState), StateToString(aState));
mState = aState;
LogInfo("State: %s -> %s", SessionStateToString(mSessionState), SessionStateToString(aSessionState));
mSessionState = aSessionState;

exit:
return;
Expand All @@ -138,16 +139,17 @@ Error SecureTransport::Open(ReceiveHandler aReceiveHandler, ConnectedHandler aCo
{
Error error;

VerifyOrExit(IsStateClosed(), error = kErrorAlready);
VerifyOrExit(!mIsOpen, error = kErrorAlready);

SuccessOrExit(error = mSocket.Open(Ip6::kNetifUnspecified));

mIsOpen = true;
mConnectedCallback.Set(aConnectedHandler, aContext);
mReceiveCallback.Set(aReceiveHandler, aContext);

mRemainingConnectionAttempts = mMaxConnectionAttempts;

SetState(kStateOpen);
SetSessionState(kSessionDisconnected);

exit:
return error;
Expand All @@ -157,7 +159,7 @@ Error SecureTransport::SetMaxConnectionAttempts(uint16_t aMaxAttempts, AutoClose
{
Error error = kErrorNone;

VerifyOrExit(IsStateClosed(), error = kErrorInvalidState);
VerifyOrExit(!mIsOpen, error = kErrorInvalidState);

mMaxConnectionAttempts = aMaxAttempts;
mAutoCloseCallback.Set(aCallback, aContext);
Expand All @@ -170,7 +172,8 @@ Error SecureTransport::Connect(const Ip6::SockAddr &aSockAddr)
{
Error error;

VerifyOrExit(IsStateOpen(), error = kErrorInvalidState);
VerifyOrExit(mIsOpen, error = kErrorInvalidState);
VerifyOrExit(IsSessionDisconnected(), error = kErrorInvalidState);

if (mRemainingConnectionAttempts > 0)
{
Expand All @@ -190,9 +193,9 @@ Error SecureTransport::Connect(const Ip6::SockAddr &aSockAddr)

void SecureTransport::HandleReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
{
VerifyOrExit(!IsStateClosed());
VerifyOrExit(mIsOpen);

if (IsStateOpen())
if (IsSessionDisconnected())
{
if (mRemainingConnectionAttempts > 0)
{
Expand All @@ -215,7 +218,7 @@ void SecureTransport::HandleReceive(Message &aMessage, const Ip6::MessageInfo &a
}

#ifdef MBEDTLS_SSL_SRV_C
if (IsStateConnecting())
if (IsSessionConnecting())
{
mbedtls_ssl_set_client_transport_id(&mSsl, mMessageInfo.GetPeerAddr().GetBytes(), sizeof(Ip6::Address));
}
Expand All @@ -233,7 +236,8 @@ Error SecureTransport::Bind(uint16_t aPort)
{
Error error;

VerifyOrExit(IsStateOpen(), error = kErrorInvalidState);
VerifyOrExit(mIsOpen, error = kErrorInvalidState);
VerifyOrExit(IsSessionDisconnected(), error = kErrorInvalidState);
VerifyOrExit(!mTransportCallback.IsSet(), error = kErrorAlready);

SuccessOrExit(error = mSocket.Bind(aPort));
Expand All @@ -247,7 +251,8 @@ Error SecureTransport::Bind(TransportCallback aCallback, void *aContext)
{
Error error = kErrorNone;

VerifyOrExit(IsStateOpen(), error = kErrorInvalidState);
VerifyOrExit(mIsOpen, error = kErrorInvalidState);
VerifyOrExit(IsSessionDisconnected(), error = kErrorInvalidState);
VerifyOrExit(!mSocket.IsBound(), error = kErrorAlready);
VerifyOrExit(!mTransportCallback.IsSet(), error = kErrorAlready);

Expand All @@ -260,14 +265,15 @@ Error SecureTransport::Bind(TransportCallback aCallback, void *aContext)

Error SecureTransport::Setup(void)
{
int rval;
Error error = kErrorNone;
int rval = 0;

OT_ASSERT(mCipherSuite != kUnspecifiedCipherSuite);

// do not handle new connection before guard time expired
VerifyOrExit(IsStateOpen(), rval = MBEDTLS_ERR_SSL_TIMEOUT);
VerifyOrExit(mIsOpen, error = kErrorInvalidState);
VerifyOrExit(IsSessionDisconnected(), error = kErrorBusy);

SetState(kStateInitializing);
SetSessionState(kSessionInitializing);

//- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
// Setup the mbedtls_ssl_config `mConf`.
Expand Down Expand Up @@ -410,46 +416,54 @@ Error SecureTransport::Setup(void)
mReceiveMessage = nullptr;
mMessageSubType = Message::kSubTypeNone;

SetState(kStateConnecting);
SetSessionState(kSessionConnecting);

Process();

exit:
if (IsStateInitializing() && (rval != 0))
if (mIsOpen && IsSessionInitializing())
{
error = Crypto::MbedTls::MapError(rval);

if ((mMaxConnectionAttempts > 0) && (mRemainingConnectionAttempts == 0))
{
Close();
mAutoCloseCallback.InvokeIfSet();
}
else
{
SetState(kStateOpen);
SetSessionState(kSessionDisconnected);
FreeMbedtls();
}
}

return Crypto::MbedTls::MapError(rval);
return error;
}

void SecureTransport::Close(void)
{
VerifyOrExit(mIsOpen);

Disconnect(kDisconnectedLocalClosed);
SetSessionState(kSessionDisconnected);

SetState(kStateClosed);
mIsOpen = false;
mTimerSet = false;
mTransportCallback.Clear();

IgnoreError(mSocket.Close());
mTimer.Stop();

exit:
return;
}

void SecureTransport::Disconnect(ConnectEvent aEvent)
{
VerifyOrExit(IsStateConnectingOrConnected());
VerifyOrExit(mIsOpen);
VerifyOrExit(IsSessionConnectingOrConnected());

mbedtls_ssl_close_notify(&mSsl);
SetState(kStateCloseNotify);
SetSessionState(kSessionDisconnecting);
mConnectEvent = aEvent;
mTimer.Start(kGuardTimeNewConnectionMilli);

Expand Down Expand Up @@ -726,11 +740,15 @@ void SecureTransport::HandleTimer(Timer &aTimer)

void SecureTransport::HandleTimer(void)
{
if (IsStateConnectingOrConnected())
VerifyOrExit(mIsOpen);

if (IsSessionConnectingOrConnected())
{
Process();
ExitNow();
}
else if (IsStateCloseNotify())

if (IsSessionDisconnecting())
{
if ((mMaxConnectionAttempts > 0) && (mRemainingConnectionAttempts == 0))
{
Expand All @@ -740,11 +758,15 @@ void SecureTransport::HandleTimer(void)
}
else
{
SetState(kStateOpen);
SetSessionState(kSessionDisconnected);
mTimer.Stop();
}

mConnectedCallback.InvokeIfSet(mConnectEvent);
}

exit:
return;
}

void SecureTransport::Process(void)
Expand All @@ -754,15 +776,15 @@ void SecureTransport::Process(void)
ConnectEvent disconnectEvent;
bool shouldReset;

while (IsStateConnectingOrConnected())
while (IsSessionConnectingOrConnected())
{
if (IsStateConnecting())
if (IsSessionConnecting())
{
rval = mbedtls_ssl_handshake(&mSsl);

if (IsMbedtlsHandshakeOver(&mSsl))
{
SetState(kStateConnected);
SetSessionState(kSessionConnected);
mConnectEvent = kConnected;
mConnectedCallback.InvokeIfSet(mConnectEvent);
}
Expand Down Expand Up @@ -879,29 +901,27 @@ void SecureTransport::HandleMbedtlsDebug(int aLevel, const char *aFile, int aLin

#if OT_SHOULD_LOG_AT(OT_LOG_LEVEL_INFO)

const char *SecureTransport::StateToString(State aState)
const char *SecureTransport::SessionStateToString(SessionState aState)
{
static const char *const kStateStrings[] = {
"Closed", // (0) kStateClosed
"Open", // (1) kStateOpen
"Initializing", // (2) kStateInitializing
"Connecting", // (3) kStateConnecting
"Connected", // (4) kStateConnected
"CloseNotify", // (5) kStateCloseNotify
static const char *const kSessionStrings[] = {
"Disconnected", // (0) kSessionDisconnected
"Initializing", // (1) kSessionInitializing
"Connecting", // (2) kSessionConnecting
"Connected", // (3) kSessionConnected
"Disconnecting", // (4) kSessionDisconnecting
};

struct EnumCheck
{
InitEnumValidatorCounter();
ValidateNextEnum(kStateClosed);
ValidateNextEnum(kStateOpen);
ValidateNextEnum(kStateInitializing);
ValidateNextEnum(kStateConnecting);
ValidateNextEnum(kStateConnected);
ValidateNextEnum(kStateCloseNotify);
ValidateNextEnum(kSessionDisconnected);
ValidateNextEnum(kSessionInitializing);
ValidateNextEnum(kSessionConnecting);
ValidateNextEnum(kSessionConnected);
ValidateNextEnum(kSessionDisconnecting);
};

return kStateStrings[aState];
return kSessionStrings[aState];
}

#endif
Expand Down Expand Up @@ -1059,7 +1079,7 @@ Error SecureTransport::Extension::GetPeerCertificateBase64(unsigned char *aPeerC
{
Error error = kErrorNone;

VerifyOrExit(mSecureTransport.IsStateConnected(), error = kErrorInvalidState);
VerifyOrExit(mSecureTransport.IsSessionConnected(), error = kErrorInvalidState);

#if (MBEDTLS_VERSION_NUMBER >= 0x03010000)
VerifyOrExit(
Expand Down
48 changes: 25 additions & 23 deletions src/core/meshcop/secure_transport.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,23 +442,23 @@ class SecureTransport : public InstanceLocator
* @retval TRUE If session is active.
* @retval FALSE If session is not active.
*/
bool IsConnectionActive(void) const { return mState >= kStateConnecting; }
bool IsConnectionActive(void) const { return mSessionState >= kSessionConnecting; }

/**
* Indicates whether or not the session is connected.
*
* @retval TRUE The session is connected.
* @retval FALSE The session is not connected.
*/
bool IsConnected(void) const { return mState == kStateConnected; }
bool IsConnected(void) const { return mSessionState == kSessionConnected; }

/**
* Indicates whether or not the session is closed.
* Indicates whether or not the secure transpose socket is closed.
*
* @retval TRUE The session is closed.
* @retval FALSE The session is not closed.
* @retval TRUE The secure transport socket closed.
* @retval FALSE The secure transport socket is not closed.
*/
bool IsClosed(void) const { return mState == kStateClosed; }
bool IsClosed(void) const { return !mIsOpen; }

/**
* Disconnects the session.
Expand Down Expand Up @@ -527,14 +527,13 @@ class SecureTransport : public InstanceLocator
static constexpr uint16_t kApplicationDataMaxLength = OPENTHREAD_CONFIG_DTLS_APPLICATION_DATA_MAX_LENGTH;
#endif

enum State : uint8_t
enum SessionState : uint8_t
{
kStateClosed, // UDP socket is closed.
kStateOpen, // UDP socket is open.
kStateInitializing, // The service is initializing.
kStateConnecting, // The service is establishing a connection.
kStateConnected, // The service has a connection established.
kStateCloseNotify, // The service is closing a connection.
kSessionDisconnected,
kSessionInitializing,
kSessionConnecting,
kSessionConnected,
kSessionDisconnecting,
};

enum CipherSuite : uint8_t
Expand All @@ -550,14 +549,16 @@ class SecureTransport : public InstanceLocator
kUnspecifiedCipherSuite,
};

bool IsStateClosed(void) const { return mState == kStateClosed; }
bool IsStateOpen(void) const { return mState == kStateOpen; }
bool IsStateInitializing(void) const { return mState == kStateInitializing; }
bool IsStateConnecting(void) const { return mState == kStateConnecting; }
bool IsStateConnected(void) const { return mState == kStateConnected; }
bool IsStateCloseNotify(void) const { return mState == kStateCloseNotify; }
bool IsStateConnectingOrConnected(void) const { return mState == kStateConnecting || mState == kStateConnected; }
void SetState(State aState);
bool IsSessionDisconnected(void) const { return mSessionState == kSessionDisconnected; }
bool IsSessionInitializing(void) const { return mSessionState == kSessionInitializing; }
bool IsSessionConnecting(void) const { return mSessionState == kSessionConnecting; }
bool IsSessionConnected(void) const { return mSessionState == kSessionConnected; }
bool IsSessionDisconnecting(void) const { return mSessionState == kSessionDisconnecting; }
bool IsSessionConnectingOrConnected(void) const
{
return mSessionState == kSessionConnecting || mSessionState == kSessionConnected;
}
void SetSessionState(SessionState aSessionState);

void FreeMbedtls(void);
Error Setup(void);
Expand Down Expand Up @@ -621,7 +622,7 @@ class SecureTransport : public InstanceLocator
void Disconnect(ConnectEvent aEvent);

#if OT_SHOULD_LOG_AT(OT_LOG_LEVEL_INFO)
static const char *StateToString(State aState);
static const char *SessionStateToString(SessionState aState);
#endif

using TransportSocket = Ip6::Udp::SocketIn<SecureTransport, &SecureTransport::HandleReceive>;
Expand All @@ -644,10 +645,11 @@ class SecureTransport : public InstanceLocator

bool mLayerTwoSecurity : 1;
bool mDatagramTransport : 1;
bool mIsOpen : 1;
bool mIsServer : 1;
bool mTimerSet : 1;
bool mVerifyPeerCertificate : 1;
State mState;
SessionState mSessionState;
CipherSuite mCipherSuite;
Message::SubType mMessageSubType;
ConnectEvent mConnectEvent;
Expand Down

0 comments on commit c2d5265

Please sign in to comment.