Skip to content

Commit

Permalink
Add a simple variant implementation similiar to std::variant (#6624)
Browse files Browse the repository at this point in the history
  • Loading branch information
kghost authored and pull[bot] committed Aug 21, 2021
1 parent a9eb2ee commit 1610964
Show file tree
Hide file tree
Showing 6 changed files with 423 additions and 54 deletions.
73 changes: 37 additions & 36 deletions src/channel/ChannelContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,17 @@ void ChannelContext::Start(const ChannelBuilder & builder)
ExchangeContext * ChannelContext::NewExchange(ExchangeDelegate * delegate)
{
assert(GetState() == ChannelState::kReady);
return mExchangeManager->NewContext(mStateVars.mReady.mSession, delegate);
return mExchangeManager->NewContext(GetReadyVars().mSession, delegate);
}

bool ChannelContext::MatchNodeId(NodeId nodeId)
{
switch (mState)
{
case ChannelState::kPreparing:
return nodeId == mStateVars.mPreparing.mBuilder.GetPeerNodeId();
return nodeId == GetPrepareVars().mBuilder.GetPeerNodeId();
case ChannelState::kReady: {
auto state = mExchangeManager->GetSessionMgr()->GetPeerConnectionState(mStateVars.mReady.mSession);
auto state = mExchangeManager->GetSessionMgr()->GetPeerConnectionState(GetReadyVars().mSession);
if (state == nullptr)
return false;
return nodeId == state->GetPeerNodeId();
Expand All @@ -63,7 +63,7 @@ bool ChannelContext::MatchTransport(Transport::Type transport)
switch (mState)
{
case ChannelState::kPreparing:
switch (mStateVars.mPreparing.mBuilder.GetTransportPreference())
switch (GetPrepareVars().mBuilder.GetTransportPreference())
{
case ChannelBuilder::TransportPreference::kPreferConnectionOriented:
case ChannelBuilder::TransportPreference::kConnectionOriented:
Expand All @@ -73,7 +73,7 @@ bool ChannelContext::MatchTransport(Transport::Type transport)
}
return false;
case ChannelState::kReady: {
auto state = mExchangeManager->GetSessionMgr()->GetPeerConnectionState(mStateVars.mReady.mSession);
auto state = mExchangeManager->GetSessionMgr()->GetPeerConnectionState(GetReadyVars().mSession);
if (state == nullptr)
return false;
return transport == state->GetPeerAddress().GetTransportType();
Expand Down Expand Up @@ -118,36 +118,38 @@ bool ChannelContext::MatchesBuilder(const ChannelBuilder & builder)

bool ChannelContext::IsCasePairing()
{
return mState == ChannelState::kPreparing && mStateVars.mPreparing.mState == PrepareState::kCasePairing;
return mState == ChannelState::kPreparing && GetPrepareVars().mState == PrepareState::kCasePairing;
}

bool ChannelContext::MatchesSession(SecureSessionHandle session, SecureSessionMgr * ssm)
{
switch (mState)
{
case ChannelState::kPreparing: {
switch (mStateVars.mPreparing.mState)
switch (GetPrepareVars().mState)
{
case PrepareState::kCasePairing: {
auto state = ssm->GetPeerConnectionState(session);
return (state->GetPeerNodeId() == mStateVars.mPreparing.mBuilder.GetPeerNodeId() &&
state->GetPeerKeyID() == mStateVars.mPreparing.mBuilder.GetPeerKeyID());
return (state->GetPeerNodeId() == GetPrepareVars().mBuilder.GetPeerNodeId() &&
state->GetPeerKeyID() == GetPrepareVars().mBuilder.GetPeerKeyID());
}
default:
return false;
}
}
case ChannelState::kReady:
return mStateVars.mReady.mSession == session;
return GetReadyVars().mSession == session;
default:
return false;
}
}

void ChannelContext::EnterPreparingState(const ChannelBuilder & builder)
{
mState = ChannelState::kPreparing;
mStateVars.mPreparing.mBuilder = builder;
mState = ChannelState::kPreparing;

mStateVars.Set<PrepareVars>();
GetPrepareVars().mBuilder = builder;

EnterAddressResolve();
}
Expand All @@ -157,14 +159,14 @@ void ChannelContext::ExitPreparingState() {}
// Address resolve
void ChannelContext::EnterAddressResolve()
{
mStateVars.mPreparing.mState = PrepareState::kAddressResolving;
GetPrepareVars().mState = PrepareState::kAddressResolving;

// Skip address resolve if the address is provided
{
auto addr = mStateVars.mPreparing.mBuilder.GetForcePeerAddress();
auto addr = GetPrepareVars().mBuilder.GetForcePeerAddress();
if (addr.HasValue())
{
mStateVars.mPreparing.mAddress = addr.Value();
GetPrepareVars().mAddress = addr.Value();
ExitAddressResolve();
// Only CASE session is supported
EnterCasePairingState();
Expand All @@ -174,10 +176,10 @@ void ChannelContext::EnterAddressResolve()

// TODO: call mDNS Scanner::SubscribeNode after PR #4459 is ready
// Scanner::RegisterScannerDelegate(this)
// Scanner::SubscribeNode(mStateVars.mPreparing.mBuilder.GetPeerNodeId())
// Scanner::SubscribeNode(GetPrepareVars().mBuilder.GetPeerNodeId())

// The HandleNodeIdResolve may already have been called, recheck the state here before set up the timer
if (mState == ChannelState::kPreparing && mStateVars.mPreparing.mState == PrepareState::kAddressResolving)
if (mState == ChannelState::kPreparing && GetPrepareVars().mState == PrepareState::kAddressResolving)
{
System::Layer * layer = mExchangeManager->GetSessionMgr()->SystemLayer();
layer->StartTimer(CHIP_CONFIG_NODE_ADDRESS_RESOLVE_TIMEOUT_MSECS, AddressResolveTimeout, this);
Expand All @@ -196,7 +198,7 @@ void ChannelContext::AddressResolveTimeout()
{
if (mState != ChannelState::kPreparing)
return;
if (mStateVars.mPreparing.mState != PrepareState::kAddressResolving)
if (GetPrepareVars().mState != PrepareState::kAddressResolving)
return;

ExitAddressResolve();
Expand All @@ -219,7 +221,7 @@ void ChannelContext::HandleNodeIdResolve(CHIP_ERROR error, uint64_t nodeId, cons
return;
}
case ChannelState::kPreparing: {
switch (mStateVars.mPreparing.mState)
switch (GetPrepareVars().mState)
{
case PrepareState::kAddressResolving: {
if (error != CHIP_NO_ERROR)
Expand All @@ -232,8 +234,8 @@ void ChannelContext::HandleNodeIdResolve(CHIP_ERROR error, uint64_t nodeId, cons

if (!address.mAddress.HasValue())
return;
mStateVars.mPreparing.mAddressType = address.mAddressType;
mStateVars.mPreparing.mAddress = address.mAddress.Value();
GetPrepareVars().mAddressType = address.mAddressType;
GetPrepareVars().mAddress = address.mAddress.Value();
ExitAddressResolve();
EnterCasePairingState();
return;
Expand All @@ -253,18 +255,18 @@ void ChannelContext::HandleNodeIdResolve(CHIP_ERROR error, uint64_t nodeId, cons

void ChannelContext::EnterCasePairingState()
{
mStateVars.mPreparing.mState = PrepareState::kCasePairing;
mStateVars.mPreparing.mCasePairingSession = Platform::New<CASESession>();
auto & prepare = GetPrepareVars();
prepare.mCasePairingSession = Platform::New<CASESession>();

ExchangeContext * ctxt = mExchangeManager->NewContext(SecureSessionHandle(), mStateVars.mPreparing.mCasePairingSession);
ExchangeContext * ctxt = mExchangeManager->NewContext(SecureSessionHandle(), prepare.mCasePairingSession);
VerifyOrReturn(ctxt != nullptr);

// TODO: currently only supports IP/UDP paring
Transport::PeerAddress addr;
addr.SetTransportType(Transport::Type::kUdp).SetIPAddress(mStateVars.mPreparing.mAddress);
CHIP_ERROR err = mStateVars.mPreparing.mCasePairingSession->EstablishSession(
addr, &mStateVars.mPreparing.mBuilder.GetOperationalCredentialSet(), mStateVars.mPreparing.mBuilder.GetPeerNodeId(),
mExchangeManager->GetNextKeyId(), ctxt, this);
addr.SetTransportType(Transport::Type::kUdp).SetIPAddress(prepare.mAddress);
CHIP_ERROR err = prepare.mCasePairingSession->EstablishSession(addr, &prepare.mBuilder.GetOperationalCredentialSet(),
prepare.mBuilder.GetPeerNodeId(),
mExchangeManager->GetNextKeyId(), ctxt, this);
if (err != CHIP_NO_ERROR)
{
ExitCasePairingState();
Expand All @@ -275,14 +277,14 @@ void ChannelContext::EnterCasePairingState()

void ChannelContext::ExitCasePairingState()
{
Platform::Delete(mStateVars.mPreparing.mCasePairingSession);
Platform::Delete(GetPrepareVars().mCasePairingSession);
}

void ChannelContext::OnSessionEstablishmentError(CHIP_ERROR error)
{
if (mState != ChannelState::kPreparing)
return;
switch (mStateVars.mPreparing.mState)
switch (GetPrepareVars().mState)
{
case PrepareState::kCasePairing:
ExitCasePairingState();
Expand All @@ -298,11 +300,11 @@ void ChannelContext::OnSessionEstablished()
{
if (mState != ChannelState::kPreparing)
return;
switch (mStateVars.mPreparing.mState)
switch (GetPrepareVars().mState)
{
case PrepareState::kCasePairing:
ExitCasePairingState();
mStateVars.mPreparing.mState = PrepareState::kCasePairingDone;
GetPrepareVars().mState = PrepareState::kCasePairingDone;
// TODO: current CASE paring session API doesn't show how to derive a secure session
return;
default:
Expand All @@ -314,7 +316,7 @@ void ChannelContext::OnNewConnection(SecureSessionHandle session)
{
if (mState != ChannelState::kPreparing)
return;
if (mStateVars.mPreparing.mState != PrepareState::kCasePairingDone)
if (GetPrepareVars().mState != PrepareState::kCasePairingDone)
return;

ExitPreparingState();
Expand All @@ -324,8 +326,7 @@ void ChannelContext::OnNewConnection(SecureSessionHandle session)
void ChannelContext::EnterReadyState(SecureSessionHandle session)
{
mState = ChannelState::kReady;

mStateVars.mReady.mSession = session;
mStateVars.Set<ReadyVars>(session);
mChannelManager->NotifyChannelEvent(this, [](ChannelDelegate * delegate) { delegate->OnEstablished(); });
}

Expand All @@ -344,7 +345,7 @@ void ChannelContext::ExitReadyState()
// Currently SecureSessionManager doesn't provide an interface to close a session

// TODO: call mDNS Scanner::UnubscribeNode after PR #4459 is ready
// Scanner::UnsubscribeNode(mStateVars.mPreparing.mBuilder.GetPeerNodeId())
// Scanner::UnsubscribeNode(GetPrepareVars().mBuilder.GetPeerNodeId())
}

void ChannelContext::EnterFailedState(CHIP_ERROR error)
Expand Down
40 changes: 22 additions & 18 deletions src/channel/ChannelContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <channel/Channel.h>
#include <lib/core/ReferenceCounted.h>
#include <lib/mdns/platform/Mdns.h>
#include <lib/support/Variant.h>
#include <protocols/secure_channel/CASESession.h>
#include <transport/PeerConnectionState.h>
#include <transport/SecureSessionMgr.h>
Expand Down Expand Up @@ -129,25 +130,28 @@ class ChannelContext : public ReferenceCounted<ChannelContext, ChannelContextDel
kCasePairingDone,
};

union StateVars
// mPreparing is pretty big, consider move it outside
struct PrepareVars
{
StateVars() {}

// mPreparing is pretty big, consider move it outside
struct PrepareVars
{
PrepareState mState;
Inet::IPAddressType mAddressType;
Inet::IPAddress mAddress;
CASESession * mCasePairingSession;
ChannelBuilder mBuilder;
} mPreparing;

struct ReadyVars
{
SecureSessionHandle mSession;
} mReady;
} mStateVars;
static constexpr const size_t VariantId = 1;
PrepareState mState;
Inet::IPAddressType mAddressType;
Inet::IPAddress mAddress;
CASESession * mCasePairingSession;
ChannelBuilder mBuilder;
};

struct ReadyVars
{
static constexpr const size_t VariantId = 2;
ReadyVars(SecureSessionHandle session) : mSession(session) {}
const SecureSessionHandle mSession;
};

Variant<PrepareVars, ReadyVars> mStateVars;

PrepareVars & GetPrepareVars() { return mStateVars.Get<PrepareVars>(); }
ReadyVars & GetReadyVars() { return mStateVars.Get<ReadyVars>(); }

// State machine functions
void EnterPreparingState(const ChannelBuilder & builder);
Expand Down
1 change: 1 addition & 0 deletions src/lib/support/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ static_library("support") {
"TimeUtils.h",
"UnitTestRegistration.cpp",
"UnitTestRegistration.h",
"Variant.h",
"logging/CHIPLogging.cpp",
"logging/CHIPLogging.h",
"verhoeff/Verhoeff.cpp",
Expand Down
Loading

0 comments on commit 1610964

Please sign in to comment.