Skip to content

Commit

Permalink
Make FabricIndex be a common thing across all sessions. (#14904)
Browse files Browse the repository at this point in the history
* Move fabric index into session.h instead of SecureSession

* Remove the cast to securesession when getting the fabric index

* Make group sessions use the common per-session fabric index

* Restyle

* Address code review comments

* Use undefined fabric index as an invalid marker
  • Loading branch information
andy31415 authored and pull[bot] committed Feb 20, 2024
1 parent 6d5c3ef commit 4025024
Show file tree
Hide file tree
Showing 14 changed files with 27 additions and 45 deletions.
3 changes: 1 addition & 2 deletions examples/chip-tool/commands/common/CommandInvoker.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,7 @@ CHIP_ERROR InvokeGroupCommand(DeviceProxy * aDevice, void * aContext,
//
// We assume the aDevice already has a Case session which is way we can use he established Secure Session
ReturnErrorOnFailure(invoker->InvokeGroupCommand(aDevice->GetExchangeManager(),
aDevice->GetSecureSession().Value()->AsSecureSession()->GetFabricIndex(),
groupId, aRequestData));
aDevice->GetSecureSession().Value()->GetFabricIndex(), groupId, aRequestData));

// invoker is already deleted and is not to be used
invoker.release();
Expand Down
12 changes: 1 addition & 11 deletions src/app/CommandHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,17 +566,7 @@ TLV::TLVWriter * CommandHandler::GetCommandDataIBTLVWriter()

FabricIndex CommandHandler::GetAccessingFabricIndex() const
{
FabricIndex fabric = kUndefinedFabricIndex;
if (mpExchangeCtx->GetSessionHandle()->IsGroupSession())
{
fabric = mpExchangeCtx->GetSessionHandle()->AsGroupSession()->GetFabricIndex();
}
else
{
fabric = mpExchangeCtx->GetSessionHandle()->AsSecureSession()->GetFabricIndex();
}

return fabric;
return mpExchangeCtx->GetSessionHandle()->GetFabricIndex();
}

CommandHandler * CommandHandler::Handle::Get()
Expand Down
2 changes: 1 addition & 1 deletion src/app/InteractionModelEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ CHIP_ERROR InteractionModelEngine::OnReadInitialRequest(Messaging::ExchangeConte
ChipLogProgress(InteractionModel,
"Deleting previous subscription from NodeId: " ChipLogFormatX64 ", FabricIndex: %" PRIu8,
ChipLogValueX64(apExchangeContext->GetSessionHandle()->AsSecureSession()->GetPeerNodeId()),
apExchangeContext->GetSessionHandle()->AsSecureSession()->GetFabricIndex());
apExchangeContext->GetSessionHandle()->GetFabricIndex());
mReadHandlers.ReleaseObject(handler);
}

Expand Down
4 changes: 2 additions & 2 deletions src/app/ReadClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ CHIP_ERROR ReadClient::SendReadRequest(ReadPrepareParams & aReadPrepareParams)
Messaging::SendFlags(Messaging::SendMessageFlags::kExpectResponse)));

mPeerNodeId = aReadPrepareParams.mSessionHolder->AsSecureSession()->GetPeerNodeId();
mFabricIndex = aReadPrepareParams.mSessionHolder->AsSecureSession()->GetFabricIndex();
mFabricIndex = aReadPrepareParams.mSessionHolder->GetFabricIndex();

MoveToState(ClientState::AwaitingInitialReport);

Expand Down Expand Up @@ -801,7 +801,7 @@ CHIP_ERROR ReadClient::SendSubscribeRequest(ReadPrepareParams & aReadPreparePara
Messaging::SendFlags(Messaging::SendMessageFlags::kExpectResponse)));

mPeerNodeId = aReadPrepareParams.mSessionHolder->AsSecureSession()->GetPeerNodeId();
mFabricIndex = aReadPrepareParams.mSessionHolder->AsSecureSession()->GetFabricIndex();
mFabricIndex = aReadPrepareParams.mSessionHolder->GetFabricIndex();

MoveToState(ClientState::AwaitingInitialReport);

Expand Down
2 changes: 1 addition & 1 deletion src/app/ReadHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ bool ReadHandler::IsFromSubscriber(Messaging::ExchangeContext & apExchangeContex
{
return (IsType(InteractionType::Subscribe) &&
GetInitiatorNodeId() == apExchangeContext.GetSessionHandle()->AsSecureSession()->GetPeerNodeId() &&
GetAccessingFabricIndex() == apExchangeContext.GetSessionHandle()->AsSecureSession()->GetFabricIndex());
GetAccessingFabricIndex() == apExchangeContext.GetSessionHandle()->GetFabricIndex());
}

CHIP_ERROR ReadHandler::OnUnknownMsgType(Messaging::ExchangeContext * apExchangeContext, const PayloadHeader & aPayloadHeader,
Expand Down
12 changes: 1 addition & 11 deletions src/app/WriteHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,17 +479,7 @@ CHIP_ERROR WriteHandler::AddStatus(const ConcreteAttributePath & aPath, const St

FabricIndex WriteHandler::GetAccessingFabricIndex() const
{
FabricIndex fabric = kUndefinedFabricIndex;
if (mpExchangeCtx->GetSessionHandle()->IsGroupSession())
{
fabric = mpExchangeCtx->GetSessionHandle()->AsGroupSession()->GetFabricIndex();
}
else
{
fabric = mpExchangeCtx->GetSessionHandle()->AsSecureSession()->GetFabricIndex();
}

return fabric;
return mpExchangeCtx->GetSessionHandle()->GetFabricIndex();
}

const char * WriteHandler::GetStateStr() const
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ bool emberAfGeneralCommissioningClusterCommissioningCompleteCallback(
* Once bindings are implemented, this may no longer be needed.
*/
SessionHandle handle = commandObj->GetExchangeContext()->GetSessionHandle();
server->SetFabricIndex(handle->AsSecureSession()->GetFabricIndex());
server->SetFabricIndex(handle->GetFabricIndex());
server->SetPeerNodeId(handle->AsSecureSession()->GetPeerNodeId());

CheckSuccess(server->CommissioningComplete(), Failure);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ class FabricCleanupExchangeDelegate : public chip::Messaging::ExchangeDelegate
void OnResponseTimeout(chip::Messaging::ExchangeContext * ec) override {}
void OnExchangeClosing(chip::Messaging::ExchangeContext * ec) override
{
FabricIndex currentFabricIndex = ec->GetSessionHandle()->AsSecureSession()->GetFabricIndex();
FabricIndex currentFabricIndex = ec->GetSessionHandle()->GetFabricIndex();
ec->GetExchangeMgr()->GetSessionManager()->ExpireAllPairingsForFabric(currentFabricIndex);
}
};
Expand Down
2 changes: 1 addition & 1 deletion src/controller/CHIPCluster.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ CHIP_ERROR ClusterBase::AssociateWithGroup(DeviceProxy * device, GroupId groupId
{
// Local copy to preserve original SessionHandle for future Unicast communication.
Optional<SessionHandle> session = mDevice->GetExchangeManager()->GetSessionManager()->CreateGroupSession(
groupId, mDevice->GetSecureSession().Value()->AsSecureSession()->GetFabricIndex());
groupId, mDevice->GetSecureSession().Value()->GetFabricIndex());
// Sanity check
if (!session.HasValue() || !session.Value()->IsGroupSession())
{
Expand Down
2 changes: 2 additions & 0 deletions src/credentials/FabricTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ static constexpr FabricIndex kMinValidFabricIndex = 1;
static constexpr FabricIndex kMaxValidFabricIndex = std::min<FabricIndex>(UINT8_MAX - 1, CHIP_CONFIG_MAX_FABRICS);
static constexpr uint8_t kFabricLabelMaxLengthInBytes = 32;

static_assert(kUndefinedFabricIndex < chip::kMinValidFabricIndex, "Undefined fabric index should not be valid");

// KVS store is sensitive to length of key strings, based on the underlying
// platform. Keeping them short.
constexpr char kFabricTableKeyPrefix[] = "Fabric";
Expand Down
4 changes: 1 addition & 3 deletions src/transport/GroupSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace Transport {
class GroupSession : public Session
{
public:
GroupSession(GroupId group, FabricIndex fabricIndex) : mGroupId(group), mFabricIndex(fabricIndex) {}
GroupSession(GroupId group, FabricIndex fabricIndex) : mGroupId(group) { SetFabricIndex(fabricIndex); }
~GroupSession() { NotifySessionReleased(); }

Session::SessionType GetSessionType() const override { return Session::SessionType::kGroup; }
Expand Down Expand Up @@ -59,11 +59,9 @@ class GroupSession : public Session
}

GroupId GetGroupId() const { return mGroupId; }
FabricIndex GetFabricIndex() const { return mFabricIndex; }

private:
const GroupId mGroupId;
const FabricIndex mFabricIndex;
};

/*
Expand Down
4 changes: 2 additions & 2 deletions src/transport/SecureSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ Access::SubjectDescriptor SecureSession::GetSubjectDescriptor() const
subjectDescriptor.authMode = Access::AuthMode::kCase;
subjectDescriptor.subject = mPeerNodeId;
subjectDescriptor.cats = mPeerCATs;
subjectDescriptor.fabricIndex = mFabric;
subjectDescriptor.fabricIndex = GetFabricIndex();
}
else if (IsPAKEKeyId(mPeerNodeId))
{
subjectDescriptor.authMode = Access::AuthMode::kPase;
subjectDescriptor.subject = mPeerNodeId;
subjectDescriptor.fabricIndex = mFabric;
subjectDescriptor.fabricIndex = GetFabricIndex();
}
else
{
Expand Down
15 changes: 6 additions & 9 deletions src/transport/SecureSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ class SecureSession : public Session
FabricIndex fabric, const ReliableMessageProtocolConfig & config) :
mSecureSessionType(secureSessionType),
mPeerNodeId(peerNodeId), mPeerCATs(peerCATs), mLocalSessionId(localSessionId), mPeerSessionId(peerSessionId),
mFabric(fabric), mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config)
{}
mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config)
{
SetFabricIndex(fabric);
}
~SecureSession() { NotifySessionReleased(); }

SecureSession(SecureSession &&) = delete;
Expand Down Expand Up @@ -112,7 +114,6 @@ class SecureSession : public Session

uint16_t GetLocalSessionId() const { return mLocalSessionId; }
uint16_t GetPeerSessionId() const { return mPeerSessionId; }
FabricIndex GetFabricIndex() const { return mFabric; }

// Should only be called for PASE sessions, which start with undefined fabric,
// to migrate to a newly commissioned fabric after successful
Expand All @@ -123,10 +124,10 @@ class SecureSession : public Session
// TODO(#13711): this check won't work until the issue is addressed
if (mSecureSessionType == Type::kPASE)
{
mFabric = fabricIndex;
SetFabricIndex(fabricIndex);
}
#else
mFabric = fabricIndex;
SetFabricIndex(fabricIndex);
#endif
return CHIP_NO_ERROR;
}
Expand All @@ -145,10 +146,6 @@ class SecureSession : public Session
const uint16_t mLocalSessionId;
const uint16_t mPeerSessionId;

// PASE sessions start with undefined fabric, but are migrated to a newly
// commissioned fabric after successful OperationalCredentialsCluster::AddNOC
FabricIndex mFabric;

PeerAddress mPeerAddress;
System::Clock::Timestamp mLastActivityTime;
ReliableMessageProtocolConfig mMRPConfig;
Expand Down
6 changes: 6 additions & 0 deletions src/transport/Session.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <credentials/FabricTable.h>
#include <lib/core/CHIPConfig.h>
#include <messaging/ReliableMessageProtocolConfig.h>
#include <transport/SessionHolder.h>
Expand Down Expand Up @@ -67,6 +68,8 @@ class Session
virtual const ReliableMessageProtocolConfig & GetMRPConfig() const = 0;
virtual System::Clock::Milliseconds32 GetAckTimeout() const = 0;

FabricIndex GetFabricIndex() const { return mFabricIndex; }

SecureSession * AsSecureSession();
UnauthenticatedSession * AsUnauthenticatedSession();
GroupSession * AsGroupSession();
Expand All @@ -85,8 +88,11 @@ class Session
}
}

void SetFabricIndex(FabricIndex index) { mFabricIndex = index; }

private:
IntrusiveList<SessionHolder> mHolders;
FabricIndex mFabricIndex = kUndefinedFabricIndex;
};

} // namespace Transport
Expand Down

0 comments on commit 4025024

Please sign in to comment.