Skip to content

Commit

Permalink
[ICD] Refactor ICDCounter logic (#31957)
Browse files Browse the repository at this point in the history
* Refactor ICD Check-In counter
Fix initial Check-In counter value

* Add basic integration tests to validate behavior

* Restyled by whitespace

* Restyled by prettier-yaml

* rename source_set

* Rename constant name

* Apply suggestions from code review

Co-authored-by: Boris Zbarsky <bzbarsky@apple.com>

* addres review comments

* Add init/shutdown unit-test for the ICDManager

* Fix Build.gn - multiple inclusion of CheckInMessage.cpp

---------

Co-authored-by: Restyled.io <commits@restyled.io>
Co-authored-by: Boris Zbarsky <bzbarsky@apple.com>
  • Loading branch information
3 people authored Feb 6, 2024
1 parent 21b998a commit e882141
Show file tree
Hide file tree
Showing 11 changed files with 181 additions and 101 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ CHIP_ERROR IcdManagementAttributeAccess::ReadRegisteredClients(EndpointId endpoi

CHIP_ERROR IcdManagementAttributeAccess::ReadICDCounter(EndpointId endpoint, AttributeValueEncoder & encoder)
{
return encoder.Encode(mICDConfigurationData->GetICDCounter());
return encoder.Encode(mICDConfigurationData->GetICDCounter().GetValue());
}

CHIP_ERROR IcdManagementAttributeAccess::ReadClientsSupportedPerFabric(EndpointId endpoint, AttributeValueEncoder & encoder)
Expand Down Expand Up @@ -292,7 +292,7 @@ Status ICDManagementServer::RegisterClient(CommandHandler * commandObj, const Co
TriggerICDMTableUpdatedEvent();
}

icdCounter = mICDConfigurationData->GetICDCounter();
icdCounter = mICDConfigurationData->GetICDCounter().GetValue();
return InteractionModel::Status::Success;
}

Expand Down
1 change: 1 addition & 0 deletions src/app/icd/server/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,6 @@ source_set("configuration-data") {
":icd-server-config",
"${chip_root}/src/lib/core",
"${chip_root}/src/lib/support",
"${chip_root}/src/protocols/secure_channel:check-in-counter",
]
}
8 changes: 4 additions & 4 deletions src/app/icd/server/ICDConfigurationData.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <lib/core/Optional.h>
#include <lib/support/TimeUtils.h>
#include <platform/CHIPDeviceConfig.h>
#include <protocols/secure_channel/CheckInCounter.h>
#include <system/SystemClock.h>

namespace chip {
Expand All @@ -44,7 +45,7 @@ class TestICDManager;
class ICDConfigurationData
{
public:
static constexpr uint32_t ICD_CHECK_IN_COUNTER_MIN_INCREMENT = 100;
static constexpr uint32_t kICDCounterPersistenceIncrement = 100;

enum class ICDMode : uint8_t
{
Expand All @@ -60,7 +61,7 @@ class ICDConfigurationData

System::Clock::Milliseconds16 GetActiveModeThreshold() { return mActiveThreshold; }

uint32_t GetICDCounter() { return mICDCounter; }
Protocols::SecureChannel::CheckInCounter & GetICDCounter() { return mICDCounter; }

uint16_t GetClientsSupportedPerFabric() { return mFabricClientsSupported; }

Expand Down Expand Up @@ -99,7 +100,6 @@ class ICDConfigurationData
friend class chip::app::TestICDManager;

void SetICDMode(ICDMode mode) { mICDMode = mode; };
void SetICDCounter(uint32_t count) { mICDCounter = count; }
void SetSlowPollingInterval(System::Clock::Milliseconds32 slowPollInterval) { mSlowPollingInterval = slowPollInterval; };
void SetFastPollingInterval(System::Clock::Milliseconds32 fastPollInterval) { mFastPollingInterval = fastPollInterval; };

Expand Down Expand Up @@ -137,7 +137,7 @@ class ICDConfigurationData

System::Clock::Milliseconds16 mActiveThreshold = System::Clock::Milliseconds16(CHIP_CONFIG_ICD_ACTIVE_MODE_THRESHOLD_MS);

uint32_t mICDCounter = 0;
Protocols::SecureChannel::CheckInCounter mICDCounter;

static_assert((CHIP_CONFIG_ICD_CLIENTS_SUPPORTED_PER_FABRIC) >= 1,
"Spec requires the minimum of supported clients per fabric be equal or greater to 1.");
Expand Down
59 changes: 7 additions & 52 deletions src/app/icd/server/ICDManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ void ICDManager::Init(PersistentStorageDelegate * storage, FabricTable * fabricT
mExchangeManager = exchangeManager;
mSubManager = manager;

VerifyOrDie(InitCounter() == CHIP_NO_ERROR);
VerifyOrDie(ICDConfigurationData::GetInstance().GetICDCounter().Init(mStorage, DefaultStorageKeyAllocator::ICDCheckInCounter(),
ICDConfigurationData::kICDCounterPersistenceIncrement) ==
CHIP_NO_ERROR);

UpdateICDMode();
UpdateOperationState(OperationalState::IdleMode);
Expand Down Expand Up @@ -115,7 +117,8 @@ void ICDManager::SendCheckInMsgs()
#if !CONFIG_BUILD_FOR_HOST_UNIT_TEST
VerifyOrDie(mStorage != nullptr);
VerifyOrDie(mFabricTable != nullptr);
uint32_t counter = ICDConfigurationData::GetInstance().GetICDCounter();

uint32_t counterValue = ICDConfigurationData::GetInstance().GetICDCounter().GetNextCheckInCounterValue();
bool counterIncremented = false;

for (const auto & fabricInfo : *mFabricTable)
Expand Down Expand Up @@ -156,7 +159,7 @@ void ICDManager::SendCheckInMsgs()
{
counterIncremented = true;

if (CHIP_NO_ERROR != IncrementCounter())
if (CHIP_NO_ERROR != ICDConfigurationData::GetInstance().GetICDCounter().Advance())
{
ChipLogError(AppServer, "Incremented ICDCounter but failed to access/save to Persistent storage");
}
Expand All @@ -167,7 +170,7 @@ void ICDManager::SendCheckInMsgs()
ICDCheckInSender * sender = mICDSenderPool.CreateObject(mExchangeManager);
VerifyOrReturn(sender != nullptr, ChipLogError(AppServer, "Failed to allocate ICDCheckinSender"));

if (CHIP_NO_ERROR != sender->RequestResolve(entry, mFabricTable, counter))
if (CHIP_NO_ERROR != sender->RequestResolve(entry, mFabricTable, counterValue))
{
ChipLogError(AppServer, "Failed to send ICD Check-In");
}
Expand All @@ -176,54 +179,6 @@ void ICDManager::SendCheckInMsgs()
#endif // CONFIG_BUILD_FOR_HOST_UNIT_TEST
}

CHIP_ERROR ICDManager::InitCounter()
{
CHIP_ERROR err;
uint32_t temp;
uint16_t size = static_cast<uint16_t>(sizeof(uint32_t));

err = mStorage->SyncGetKeyValue(DefaultStorageKeyAllocator::ICDCheckInCounter().KeyName(), &temp, size);
if (err == CHIP_ERROR_PERSISTED_STORAGE_VALUE_NOT_FOUND)
{
// First time retrieving the counter
temp = chip::Crypto::GetRandU32();
}
else if (err != CHIP_NO_ERROR)
{
return err;
}

ICDConfigurationData::GetInstance().SetICDCounter(temp);
temp += ICDConfigurationData::ICD_CHECK_IN_COUNTER_MIN_INCREMENT;

// Increment the count directly to minimize flash write.
return mStorage->SyncSetKeyValue(DefaultStorageKeyAllocator::ICDCheckInCounter().KeyName(), &temp, size);
}

CHIP_ERROR ICDManager::IncrementCounter()
{
uint32_t temp = 0;
StorageKeyName key = DefaultStorageKeyAllocator::ICDCheckInCounter();
uint16_t size = static_cast<uint16_t>(sizeof(uint32_t));

ICDConfigurationData::GetInstance().mICDCounter++;

if (mStorage == nullptr)
{
return CHIP_ERROR_PERSISTED_STORAGE_VALUE_NOT_FOUND;
}

ReturnErrorOnFailure(mStorage->SyncGetKeyValue(key.KeyName(), &temp, size));

if (temp == ICDConfigurationData::GetInstance().mICDCounter)
{
temp = ICDConfigurationData::GetInstance().mICDCounter + ICDConfigurationData::ICD_CHECK_IN_COUNTER_MIN_INCREMENT;
return mStorage->SyncSetKeyValue(key.KeyName(), &temp, size);
}

return CHIP_NO_ERROR;
}

void ICDManager::UpdateICDMode()
{
assertChipStackLockedByCurrentThread();
Expand Down
4 changes: 0 additions & 4 deletions src/app/icd/server/ICDManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,6 @@ class ICDManager : public ICDListener
*/
static void OnTransitionToIdle(System::Layer * aLayer, void * appState);

// ICD Counter
CHIP_ERROR IncrementCounter();
CHIP_ERROR InitCounter();

uint8_t mOpenExchangeContextCount = 0;
uint8_t mCheckInRequestCount = 0;

Expand Down
25 changes: 16 additions & 9 deletions src/app/tests/TestICDManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <lib/support/TestPersistentStorageDelegate.h>
#include <lib/support/TimeUtils.h>
#include <lib/support/UnitTestContext.h>
#include <lib/support/UnitTestExtendedAssertions.h>
#include <lib/support/UnitTestRegistration.h>
#include <nlunit-test.h>
#include <system/SystemLayerImpl.h>
Expand Down Expand Up @@ -113,7 +114,7 @@ class TestContext : public chip::Test::AppContext
CHIP_ERROR SetUp() override
{
ReturnErrorOnFailure(chip::Test::AppContext::SetUp());
mICDManager.Init(&testStorage, &GetFabricTable(), &mKeystore, &GetExchangeManager(), &mSubManager);
mICDManager.Init(&testStorage, &GetFabricTable(), &mKeystore, &GetExchangeManager(), &mSubInfoProvider);
mICDManager.RegisterObserver(&mICDStateObserver);
return CHIP_NO_ERROR;
}
Expand All @@ -128,12 +129,12 @@ class TestContext : public chip::Test::AppContext
System::Clock::Internal::MockClock mMockClock;
TestSessionKeystoreImpl mKeystore;
app::ICDManager mICDManager;
TestSubscriptionsInfoProvider mSubManager;
TestSubscriptionsInfoProvider mSubInfoProvider;
TestPersistentStorageDelegate testStorage;
TestICDStateObserver mICDStateObserver;

private:
System::Clock::ClockBase * mRealClock;
TestICDStateObserver mICDStateObserver;
};

} // namespace
Expand Down Expand Up @@ -196,7 +197,7 @@ class TestICDManager
ctx->mICDManager.SetTestFeatureMapValue(0x07);

// Set that there are no matching subscriptions
ctx->mSubManager.SetReturnValue(false);
ctx->mSubInfoProvider.SetReturnValue(false);

// Set New durations for test case
Milliseconds32 oldActiveModeDuration = icdConfigData.GetActiveModeDuration();
Expand Down Expand Up @@ -275,7 +276,7 @@ class TestICDManager
ctx->mICDManager.SetTestFeatureMapValue(0x07);

// Set that there are not matching subscriptions
ctx->mSubManager.SetReturnValue(true);
ctx->mSubInfoProvider.SetReturnValue(true);

// Set New durations for test case
Milliseconds32 oldActiveModeDuration = icdConfigData.GetActiveModeDuration();
Expand Down Expand Up @@ -506,10 +507,16 @@ class TestICDManager
static void TestICDCounter(nlTestSuite * aSuite, void * aContext)
{
TestContext * ctx = static_cast<TestContext *>(aContext);
uint32_t counter = ICDConfigurationData::GetInstance().GetICDCounter();
ctx->mICDManager.IncrementCounter();
uint32_t counter2 = ICDConfigurationData::GetInstance().GetICDCounter();
NL_TEST_ASSERT(aSuite, (counter + 1) == counter2);
uint32_t counter = ICDConfigurationData::GetInstance().GetICDCounter().GetValue();

// Shut down and reinit ICDManager to increment counter
ctx->mICDManager.Shutdown();
ctx->mICDManager.Init(&(ctx->testStorage), &(ctx->GetFabricTable()), &(ctx->mKeystore), &(ctx->GetExchangeManager()),
&(ctx->mSubInfoProvider));
ctx->mICDManager.RegisterObserver(&(ctx->mICDStateObserver));

NL_TEST_ASSERT_EQUALS(aSuite, counter + ICDConfigurationData::kICDCounterPersistenceIncrement,
ICDConfigurationData::GetInstance().GetICDCounter().GetValue());
}

static void TestOnSubscriptionReport(nlTestSuite * aSuite, void * aContext)
Expand Down
62 changes: 42 additions & 20 deletions src/app/tests/suites/TestIcdManagementCluster.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ config:
nodeId: 0x12344321
cluster: "ICD Management"
endpoint: 0
beforeRebootICDCounter:
type: int32u
defaultValue: 0

tests:
- label: "Read the commissioner node ID"
Expand All @@ -36,6 +39,26 @@ tests:
- name: "nodeId"
value: nodeId

# chip-tool will register itself with the ICD during the tests.
- label: "Read RegisteredClients For Registration During Commissioning"
command: "readAttribute"
attribute: "RegisteredClients"
response:
value:
[
{
CheckInNodeID: commissionerNodeId,
MonitoredSubject: commissionerNodeId,
},
]

- label: "Unregister Client Registered During Commissioning"
command: "UnregisterClient"
arguments:
values:
- name: "CheckInNodeID"
value: commissionerNodeId

- label: "Read Feature Map"
command: "readAttribute"
attribute: "FeatureMap"
Expand Down Expand Up @@ -68,6 +91,25 @@ tests:
type: int32u
minValue: 0x0
maxValue: 0xFFFFFFFF
saveAs: beforeRebootICDCounter

- label: "Reboot target device"
cluster: "SystemCommands"
command: "Reboot"

- label: "Connect to the device again"
cluster: "DelayCommands"
command: "WaitForCommissionee"
arguments:
values:
- name: "nodeId"
value: nodeId

- label: "Read ICDCounter after reboot"
command: "readAttribute"
attribute: "ICDCounter"
response:
value: beforeRebootICDCounter + 100

- label: "Read UserActiveModeTriggerHint"
command: "readAttribute"
Expand Down Expand Up @@ -119,26 +161,6 @@ tests:
response:
error: NOT_FOUND

# chip-tool will register itself with the ICD during the tests.
- label: "Read RegisteredClients For Registration During Commissioning"
command: "readAttribute"
attribute: "RegisteredClients"
response:
value:
[
{
CheckInNodeID: commissionerNodeId,
MonitoredSubject: commissionerNodeId,
},
]

- label: "Unregister Client Registered During Commissioning"
command: "UnregisterClient"
arguments:
values:
- name: "CheckInNodeID"
value: commissionerNodeId

- label: "Register 1.0 (key too short)"
command: "RegisterClient"
arguments:
Expand Down
28 changes: 21 additions & 7 deletions src/lib/support/PersistedCounter.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,12 @@ class PersistedCounter : public MonotonicallyIncreasingCounter<T>
ReturnErrorOnFailure(ReadStartValue(startValue));

#if CHIP_CONFIG_PERSISTED_COUNTER_DEBUG_LOGGING
// Compiler should optimize these branches.
if (is_same_v<decltype(T), uint64_t>)
if constexpr (std::is_same_v<decltype(startValue), uint64_t>)
{
ChipLogDetail(EventLogging, "PersistedCounter::Init() aEpoch 0x" ChipLogFormatX64 " startValue 0x" ChipLogFormatX64,
ChipLogValueX64(aEpoch), ChipLogValueX64(startValue));
}
else if (is_same_v<decltype(T), uint32_t>)
else if (std::is_same_v<decltype(startValue), uint32_t>)
{
ChipLogDetail(EventLogging, "PersistedCounter::Init() aEpoch 0x%" PRIx32 " startValue 0x%" PRIx32,
static_cast<uint32_t>(aEpoch), static_cast<uint32_t>(startValue));
Expand Down Expand Up @@ -151,8 +150,7 @@ class PersistedCounter : public MonotonicallyIncreasingCounter<T>
{
mNextEpoch = aStartValue;
#if CHIP_CONFIG_PERSISTED_COUNTER_DEBUG_LOGGING
// Compiler should optimize these branches.
if (is_same_v<decltype(T), uint64_t>)
if constexpr (std::is_same_v<decltype(aStartValue), uint64_t>)
{
ChipLogDetail(EventLogging, "PersistedCounter::WriteStartValue() aStartValue 0x" ChipLogFormatX64,
ChipLogValueX64(aStartValue));
Expand All @@ -178,7 +176,7 @@ class PersistedCounter : public MonotonicallyIncreasingCounter<T>
*/
CHIP_ERROR ReadStartValue(T & aStartValue)
{
T valueLE = 0;
T valueLE = GetInitialCounterValue();
uint16_t size = sizeof(valueLE);

VerifyOrReturnError(mKey.IsInitialized(), CHIP_ERROR_INCORRECT_STATE);
Expand Down Expand Up @@ -206,12 +204,28 @@ class PersistedCounter : public MonotonicallyIncreasingCounter<T>
aStartValue = Encoding::LittleEndian::HostSwap<T>(valueLE);

#if CHIP_CONFIG_PERSISTED_COUNTER_DEBUG_LOGGING
ChipLogDetail(EventLogging, "PersistedCounter::ReadStartValue() aStartValue 0x%x", aStartValue);
if constexpr (std::is_same_v<decltype(aStartValue), uint64_t>)
{
ChipLogDetail(EventLogging, "PersistedCounter::ReadStartValue() aStartValue 0x" ChipLogFormatX64,
ChipLogValueX64(aStartValue));
}
else
{
ChipLogDetail(EventLogging, "PersistedCounter::ReadStartValue() aStartValue 0x%" PRIx32,
static_cast<uint32_t>(aStartValue));
}
#endif

return CHIP_NO_ERROR;
}

/**
* @brief Get the Initial Counter Value
*
* By default, persisted counters start off at 0.
*/
virtual inline T GetInitialCounterValue() { return 0; }

PersistentStorageDelegate * mStorage = nullptr; // start value is stored here
StorageKeyName mKey;
T mEpoch = 0; // epoch modulus value
Expand Down
Loading

0 comments on commit e882141

Please sign in to comment.