Skip to content

Commit

Permalink
Clean up unusded code in SecureSessionTable (#11041)
Browse files Browse the repository at this point in the history
  • Loading branch information
kghost authored Oct 27, 2021
1 parent 9b30c45 commit d786c58
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 337 deletions.
5 changes: 3 additions & 2 deletions src/app/tests/TestCommandInteraction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,9 @@ void TestCommandInteraction::TestCommandSenderWithSendCommand(nlTestSuite * apSu
System::PacketBufferHandle buf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize);

AddCommandDataIB(apSuite, apContext, &commandSender, false);
err = commandSender.SendCommandRequest(kTestDeviceNodeId, gFabricIndex, Optional<SessionHandle>::Missing());
NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_NOT_CONNECTED);
err =
commandSender.SendCommandRequest(0 /* nodeid */, 0 /* fabricindex */, Optional<SessionHandle>(ctx.GetSessionBobToAlice()));
NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);

GenerateReceivedCommand(apSuite, apContext, buf, true /*aNeedCommandData*/);
err = commandSender.ProcessCommandMessage(std::move(buf), Command::CommandRoleId::SenderId);
Expand Down
240 changes: 40 additions & 200 deletions src/transport/SecureSessionTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,73 +30,31 @@ namespace Transport {
constexpr const uint16_t kAnyKeyId = 0xffff;

/**
* Handles a set of peer connection states.
* Handles a set of sessions.
*
* Intended for:
* - handle connection active time and expiration
* - allocate and free space for connection states.
* - handle session active time and expiration
* - allocate and free space for sessions.
*/
template <size_t kMaxConnectionCount, Time::Source kTimeSource = Time::Source::kSystem>
template <size_t kMaxSessionCount, Time::Source kTimeSource = Time::Source::kSystem>
class SecureSessionTable
{
public:
/**
* Allocates a new peer connection state state object out of the internal resource pool.
* Allocates a new secure session out of the internal resource pool.
*
* @param address represents the connection state address
* @param state [out] will contain the connection state if one was available. May be null if no return value is desired.
*
* @note the newly created state will have an 'active' time set based on the current time source.
*
* @returns CHIP_NO_ERROR if state could be initialized. May fail if maximum connection count
* has been reached (with CHIP_ERROR_NO_MEMORY).
*/
CHECK_RETURN_VALUE
CHIP_ERROR CreateNewPeerConnectionState(const PeerAddress & address, SecureSession ** state)
{
CHIP_ERROR err = CHIP_ERROR_NO_MEMORY;

if (state)
{
*state = nullptr;
}

for (size_t i = 0; i < kMaxConnectionCount; i++)
{
if (!mStates[i].IsInitialized())
{
mStates[i] = SecureSession(address);
mStates[i].SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs());

if (state)
{
*state = &mStates[i];
}

err = CHIP_NO_ERROR;
break;
}
}

return err;
}

/**
* Allocates a new peer connection state state object out of the internal resource pool.
*
* @param peerNode represents optional peer Node's ID
* @param peerNode represents peer Node's ID
* @param peerSessionId represents the encryption key ID assigned by peer node
* @param localSessionId represents the encryption key ID assigned by local node
* @param state [out] will contain the connection state if one was available. May be null if no return value is desired.
* @param state [out] will contain the session if one was available. May be null if no return value is desired.
*
* @note the newly created state will have an 'active' time set based on the current time source.
*
* @returns CHIP_NO_ERROR if state could be initialized. May fail if maximum connection count
* @returns CHIP_NO_ERROR if state could be initialized. May fail if maximum session count
* has been reached (with CHIP_ERROR_NO_MEMORY).
*/
CHECK_RETURN_VALUE
CHIP_ERROR CreateNewPeerConnectionState(const Optional<NodeId> & peerNode, uint16_t peerSessionId, uint16_t localSessionId,
SecureSession ** state)
CHIP_ERROR CreateNewSecureSession(NodeId peerNode, uint16_t peerSessionId, uint16_t localSessionId, SecureSession ** state)
{
CHIP_ERROR err = CHIP_ERROR_NO_MEMORY;

Expand All @@ -105,20 +63,16 @@ class SecureSessionTable
*state = nullptr;
}

for (size_t i = 0; i < kMaxConnectionCount; i++)
for (size_t i = 0; i < kMaxSessionCount; i++)
{
if (!mStates[i].IsInitialized())
{
mStates[i] = SecureSession();
mStates[i].SetPeerNodeId(peerNode);
mStates[i].SetPeerSessionId(peerSessionId);
mStates[i].SetLocalSessionId(localSessionId);
mStates[i].SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs());

if (peerNode.ValueOr(kUndefinedNodeId) != kUndefinedNodeId)
{
mStates[i].SetPeerNodeId(peerNode.Value());
}

if (state)
{
*state = &mStates[i];
Expand All @@ -133,56 +87,25 @@ class SecureSessionTable
}

/**
* Get a peer connection state given a Peer address.
* Get a secure session given a Node Id.
*
* @param address is the connection to find (based on address)
* @param nodeId is the session to find (based on nodeId).
* @param begin If a member of the pool, will start search from the next item. Can be nullptr to search from start.
*
* @return the state found, nullptr if not found
*/
CHECK_RETURN_VALUE
SecureSession * FindPeerConnectionState(const PeerAddress & address, SecureSession * begin)
SecureSession * FindSecureSession(NodeId nodeId, SecureSession * begin)
{
SecureSession * state = nullptr;
SecureSession * iter = &mStates[0];

if (begin >= iter && begin < &mStates[kMaxConnectionCount])
if (begin >= iter && begin < &mStates[kMaxSessionCount])
{
iter = begin + 1;
}

for (; iter < &mStates[kMaxConnectionCount]; iter++)
{
if (iter->GetPeerAddress() == address)
{
state = iter;
break;
}
}
return state;
}

/**
* Get a peer connection state given a Node Id.
*
* @param nodeId is the connection to find (based on nodeId). Note that initial connections
* do not have a node id set. Use this if you know the node id should be set.
* @param begin If a member of the pool, will start search from the next item. Can be nullptr to search from start.
*
* @return the state found, nullptr if not found
*/
CHECK_RETURN_VALUE
SecureSession * FindPeerConnectionState(NodeId nodeId, SecureSession * begin)
{
SecureSession * state = nullptr;
SecureSession * iter = &mStates[0];

if (begin >= iter && begin < &mStates[kMaxConnectionCount])
{
iter = begin + 1;
}

for (; iter < &mStates[kMaxConnectionCount]; iter++)
for (; iter < &mStates[kMaxSessionCount]; iter++)
{
if (!iter->IsInitialized())
{
Expand All @@ -198,131 +121,48 @@ class SecureSessionTable
}

/**
* Get a peer connection state given a Node Id and Peer's Encryption Key Id.
*
* @param nodeId is the connection to find (based on nodeId). Note that initial connections
* do not have a node id set. Use this if you know the node id should be set.
* @param peerSessionId Encryption key ID used by the peer node.
* @param begin If a member of the pool, will start search from the next item. Can be nullptr to search from start.
* Get a secure session given a Node Id and Peer's Encryption Key Id.
*
* @return the state found, nullptr if not found
*/
CHECK_RETURN_VALUE
SecureSession * FindPeerConnectionState(Optional<NodeId> nodeId, uint16_t peerSessionId, SecureSession * begin)
{
SecureSession * state = nullptr;
SecureSession * iter = &mStates[0];

if (begin >= iter && begin < &mStates[kMaxConnectionCount])
{
iter = begin + 1;
}

for (; iter < &mStates[kMaxConnectionCount]; iter++)
{
if (!iter->IsInitialized())
{
continue;
}
if (peerSessionId == kAnyKeyId || iter->GetPeerSessionId() == peerSessionId)
{
if (nodeId.ValueOr(kUndefinedNodeId) == kUndefinedNodeId || iter->GetPeerNodeId() == kUndefinedNodeId ||
iter->GetPeerNodeId() == nodeId.Value())
{
state = iter;
break;
}
}
}
return state;
}

/**
* Get a peer connection state given the local Encryption Key Id.
*
* @param keyId Encryption key ID assigned by the local node.
* @param begin If a member of the pool, will start search from the next item. Can be nullptr to search from start.
*
* @return the state found, nullptr if not found
*/
CHECK_RETURN_VALUE
SecureSession * FindPeerConnectionState(uint16_t keyId, SecureSession * begin)
{
SecureSession * state = nullptr;
SecureSession * iter = &mStates[0];

VerifyOrDie(begin == nullptr || (begin >= iter && begin < &mStates[kMaxConnectionCount]));

if (begin != nullptr)
{
iter = begin + 1;
}

for (; iter < &mStates[kMaxConnectionCount]; iter++)
{
if (!iter->IsInitialized())
{
continue;
}

if (iter->GetLocalSessionId() == keyId)
{
state = iter;
break;
}
}
return state;
}

/**
* Get a peer connection state given a Node Id and Peer's Encryption Key Id.
*
* @param nodeId is the connection to find (based on peer nodeId). Note that initial connections
* do not have a node id set. Use this if you know the node id should be set.
* @param localSessionId Encryption key ID used by the local node.
* @param begin If a member of the pool, will start search from the next item. Can be nullptr to search from start.
*
* @return the state found, nullptr if not found
*/
CHECK_RETURN_VALUE
SecureSession * FindPeerConnectionStateByLocalKey(Optional<NodeId> nodeId, uint16_t localSessionId, SecureSession * begin)
SecureSession * FindSecureSessionByLocalKey(uint16_t localSessionId, SecureSession * begin)
{
SecureSession * state = nullptr;
SecureSession * iter = &mStates[0];

if (begin >= iter && begin < &mStates[kMaxConnectionCount])
if (begin >= iter && begin < &mStates[kMaxSessionCount])
{
iter = begin + 1;
}

for (; iter < &mStates[kMaxConnectionCount]; iter++)
for (; iter < &mStates[kMaxSessionCount]; iter++)
{
if (!iter->IsInitialized())
{
continue;
}
if (iter->GetLocalSessionId() == localSessionId)
{
if (nodeId.ValueOr(kUndefinedNodeId) == kUndefinedNodeId || iter->GetPeerNodeId() == kUndefinedNodeId ||
iter->GetPeerNodeId() == nodeId.Value())
{
state = iter;
break;
}
state = iter;
break;
}
}
return state;
}

/**
* Get the first peer connection state that matches the given fabric index.
* Get the first session that matches the given fabric index.
*
* @param fabric The fabric index to match
*
* @return the state found, nullptr if not found
* @return the session found, nullptr if not found
*/
CHECK_RETURN_VALUE
SecureSession * FindPeerConnectionStateByFabric(FabricIndex fabric)
SecureSession * FindSecureSessionByFabric(FabricIndex fabric)
{
for (auto & state : mStates)
{
Expand All @@ -338,51 +178,51 @@ class SecureSessionTable
return nullptr;
}

/// Convenience method to mark a peer connection state as active
void MarkConnectionActive(SecureSession * state) { state->SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); }
/// Convenience method to mark a session as active
void MarkSessionActive(SecureSession * state) { state->SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); }

/// Convenience method to expired a peer connection state and fired the related callback
/// Convenience method to expired a session and fired the related callback
template <typename Callback>
void MarkConnectionExpired(SecureSession * state, Callback callback)
void MarkSessionExpired(SecureSession * state, Callback callback)
{
callback(*state);
*state = SecureSession(PeerAddress::Uninitialized());
}

/**
* Iterates through all active connections and expires any connection with an idle time
* Iterates through all active sessions and expires any sessions with an idle time
* larger than the given amount.
*
* Expiring a connection involves callback execution and then clearing the internal state.
* Expiring a session involves callback execution and then clearing the internal state.
*/
template <typename Callback>
void ExpireInactiveConnections(uint64_t maxIdleTimeMs, Callback callback)
void ExpireInactiveSessions(uint64_t maxIdleTimeMs, Callback callback)
{
const uint64_t currentTime = mTimeSource.GetCurrentMonotonicTimeMs();

for (size_t i = 0; i < kMaxConnectionCount; i++)
for (size_t i = 0; i < kMaxSessionCount; i++)
{
if (!mStates[i].GetPeerAddress().IsInitialized())
if (!mStates[i].IsInitialized())
{
continue; // not an active connection
continue; // not an active session
}

uint64_t connectionActiveTime = mStates[i].GetLastActivityTimeMs();
if (connectionActiveTime + maxIdleTimeMs >= currentTime)
uint64_t sessionActiveTime = mStates[i].GetLastActivityTimeMs();
if (sessionActiveTime + maxIdleTimeMs >= currentTime)
{
continue; // not expired
}

MarkConnectionExpired(&mStates[i], callback);
MarkSessionExpired(&mStates[i], callback);
}
}

/// Allows access to the underlying time source used for keeping track of connection active time
/// Allows access to the underlying time source used for keeping track of session active time
Time::TimeSource<kTimeSource> & GetTimeSource() { return mTimeSource; }

private:
Time::TimeSource<kTimeSource> mTimeSource;
SecureSession mStates[kMaxConnectionCount];
SecureSession mStates[kMaxSessionCount];
};

} // namespace Transport
Expand Down
Loading

0 comments on commit d786c58

Please sign in to comment.