Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve numerical stability of CCA variance #629

Merged
merged 1 commit into from
Jun 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 49 additions & 39 deletions core/include/traccc/clusterization/impl/measurement_creation.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

#pragma once

#include <cassert>

namespace traccc::details {

TRACCC_HOST_DEVICE
Expand All @@ -28,6 +30,8 @@ inline vector2 position_from_cell(const cell& cell, const cell_module& mod) {
TRACCC_HOST_DEVICE inline void calc_cluster_properties(
const cell_collection_types::const_device& cluster, const cell_module& mod,
point2& mean, point2& var, scalar& totalWeight) {
point2 offset{0., 0.};
bool first_processed = false;

// Loop over the cells of the cluster.
for (const cell& cell : cluster) {
Expand All @@ -37,20 +41,30 @@ TRACCC_HOST_DEVICE inline void calc_cluster_properties(

// Only consider cells over a minimum threshold.
if (weight > mod.threshold) {
totalWeight += weight;
scalar weight_factor = weight / totalWeight;

// Update all output properties with this cell.
totalWeight += cell.activation;
const point2 cell_position = position_from_cell(cell, mod);
const point2 prev = mean;
const point2 diff = cell_position - prev;
point2 cell_position = position_from_cell(cell, mod);

mean = prev + (weight / totalWeight) * diff;
for (std::size_t i = 0; i < 2; ++i) {
var[i] =
var[i] + weight * (diff[i]) * (cell_position[i] - mean[i]);
if (!first_processed) {
offset = cell_position;
first_processed = true;
}

cell_position = cell_position - offset;

const point2 diff_old = cell_position - mean;
mean = mean + diff_old * weight_factor;
const point2 diff_new = cell_position - mean;

var[0] = (1.f - weight_factor) * var[0] +
weight_factor * (diff_old[0] * diff_new[0]);
var[1] = (1.f - weight_factor) * var[1] +
weight_factor * (diff_old[1] * diff_new[1]);
}
}

mean = mean + offset;
}

TRACCC_HOST_DEVICE inline void fill_measurement(
Expand All @@ -70,38 +84,34 @@ TRACCC_HOST_DEVICE inline void fill_measurement(
// edition, chapter 4.2.2.

// Calculate the cluster properties
scalar totalWeight = 0.;
point2 mean{0., 0.}, var{0., 0.};
scalar totalWeight = 0.f;
point2 mean{0.f, 0.f}, var{0.f, 0.f};
calc_cluster_properties(cluster, mod, mean, var, totalWeight);

if (totalWeight > 0.) {

// Access the measurement in question.
measurement& m = measurements[measurement_index];

m.module_link = mod_link;
m.surface_link = mod.surface_link;
// normalize the cell position
m.local = mean;
// normalize the variance
m.variance[0] = var[0] / totalWeight;
m.variance[1] = var[1] / totalWeight;
// plus pitch^2 / 12
const auto pitch = mod.pixel.get_pitch();
m.variance =
m.variance + point2{pitch[0] * pitch[0] / static_cast<scalar>(12.),
pitch[1] * pitch[1] / static_cast<scalar>(12.)};
// @todo add variance estimation

// For the ambiguity resolution algorithm, give a unique measurement ID
m.measurement_id = measurement_index;

// Adjust the measurement object for 1D surfaces.
if (mod.pixel.dimension == 1) {
m.meas_dim = 1;
m.local[1] = 0.f;
m.variance[1] = mod.pixel.variance_y;
}
assert(totalWeight > 0.f);

// Access the measurement in question.
measurement& m = measurements[measurement_index];

m.module_link = mod_link;
m.surface_link = mod.surface_link;
// normalize the cell position
m.local = mean;

// plus pitch^2 / 12
const auto pitch = mod.pixel.get_pitch();
m.variance = var + point2{pitch[0] * pitch[0] / static_cast<scalar>(12.),
pitch[1] * pitch[1] / static_cast<scalar>(12.)};
// @todo add variance estimation

// For the ambiguity resolution algorithm, give a unique measurement ID
m.measurement_id = measurement_index;

// Adjust the measurement object for 1D surfaces.
if (mod.pixel.dimension == 1) {
m.meas_dim = 1;
m.local[1] = 0.f;
m.variance[1] = mod.pixel.variance_y;
}
}

Expand Down
10 changes: 0 additions & 10 deletions core/src/clusterization/measurement_creation_algorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,6 @@ measurement_creation_algorithm::operator()(

// Process the clusters one-by-one.
for (std::size_t i = 0; i < clusters.size(); ++i) {
// To calculate the mean and variance with high numerical stability
// we use a weighted variant of Welford's algorithm. This is a
// single-pass online algorithm that works well for large numbers
// of samples, as well as samples with very high values.
//
// To learn more about this algorithm please refer to:
// [1] https://doi.org/10.1080/00401706.1962.10490022
// [2] The Art of Computer Programming, Donald E. Knuth, second
// edition, chapter 4.2.2.

// Get the cluster.
cluster_container_types::device::item_vector::const_reference cluster =
clusters.get_items()[i];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,43 @@ inline void aggregate_cluster(
vecmem::device_vector<unsigned int> cell_links_device(cell_links);

/*
* Now, we iterate over all other cells to check if they belong
* to our cluster. Note that we can start at the current index
* because no cell is ever a child of a cluster owned by a cell
* with a higher ID.
* Now, we iterate over all other cells to check if they belong to our
* cluster. Note that we can start at the current index because no cell is
* ever a child of a cluster owned by a cell with a higher ID.
*
* Implemented here is a weighted version of Welford's algorithm. To read
* more about this algorithm, see the following sources:
*
* [1] https://doi.org/10.1080/00401706.1962.10490022
* [2] The Art of Computer Programming, Donald E. Knuth, second edition,
* chapter 4.2.2.
*
* The core idea of Welford's algorithm is to use the recurrence relation
*
* $$\sigma^2_n = (1 - \frac{w_n}{W_n}) \sigma^2_{n-1} + \frac{w_n}{W_n}
* (x_n - \mu_n) (x_n - \mu_{n-1})$$
*
* Which makes the algorithm less prone to catastrophic cancellation and
* other unwanted effects. In addition, we offset the entire computation
* by the first cell in the cluster, which brings the entire computation
* closer to zero where floating point precision is higher. This relies on
* the following:
*
* $$\mu(x_1, \ldots, x_n) = \mu(x_1 - C, \ldots, x_n - C) + C$$
*
* and
*
* $$\sigma^2(x_1, \ldots, x_n) = \sigma^2(x_1 - C, \ldots, x_n - C)$$
*/
scalar totalWeight = 0.;
point2 mean{0., 0.}, var{0., 0.};
point2 mean{0., 0.}, var{0., 0.}, offset{0., 0.};

const auto module_link = cells[cid + start].module_link;
const cell_module this_module = modules.at(module_link);
const unsigned short partition_size = end - start;

bool first_processed = false;

channel_id maxChannel1 = std::numeric_limits<channel_id>::min();

for (unsigned short j = cid; j < partition_size; j++) {
Expand Down Expand Up @@ -67,15 +93,26 @@ inline void aggregate_cluster(

if (weight > this_module.threshold) {
totalWeight += weight;
const point2 cell_position =
scalar weight_factor = weight / totalWeight;

point2 cell_position =
traccc::details::position_from_cell(this_cell, this_module);
const point2 prev = mean;
const point2 diff = cell_position - prev;

mean = prev + (weight / totalWeight) * diff;
for (char i = 0; i < 2; ++i) {
var[i] += weight * (diff[i]) * (cell_position[i] - mean[i]);
if (!first_processed) {
offset = cell_position;
first_processed = true;
}

cell_position = cell_position - offset;

const point2 diff_old = cell_position - mean;
mean = mean + diff_old * weight_factor;
const point2 diff_new = cell_position - mean;

var[0] = (1.f - weight_factor) * var[0] +
weight_factor * (diff_old[0] * diff_new[0]);
var[1] = (1.f - weight_factor) * var[1] +
weight_factor * (diff_old[1] * diff_new[1]);
}

cell_links_device.at(pos) = link;
Expand All @@ -89,20 +126,15 @@ inline void aggregate_cluster(
break;
}
}
if (totalWeight > static_cast<scalar>(0.)) {
#pragma unroll
for (char i = 0; i < 2; ++i) {
var[i] /= totalWeight;
}
const auto pitch = this_module.pixel.get_pitch();
var = var + point2{pitch[0] * pitch[0] / static_cast<scalar>(12.),
pitch[1] * pitch[1] / static_cast<scalar>(12.)};
}

const auto pitch = this_module.pixel.get_pitch();
var = var + point2{pitch[0] * pitch[0] / static_cast<scalar>(12.),
pitch[1] * pitch[1] / static_cast<scalar>(12.)};

/*
* Fill output vector with calculated cluster properties
*/
out.local = mean;
out.local = mean + offset;
out.variance = var;
out.surface_link = this_module.surface_link;
out.module_link = module_link;
Expand Down
15 changes: 10 additions & 5 deletions tests/common/tests/cca_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,10 @@ class ConnectedComponentAnalysisTests
traccc::cell_collection_types::host &cells = data.cells;
traccc::cell_module_collection_types::host &modules = data.modules;

traccc::scalar pitch = 1.f;

for (std::size_t i = 0; i < modules.size(); i++) {
modules.at(i).pixel = {-0.5f, -0.5f, 1.f, 1.f};
modules.at(i).pixel = {-0.5f, -0.5f, pitch, pitch};
}

std::map<traccc::geometry_id, vecmem::vector<traccc::measurement>>
Expand All @@ -112,15 +114,16 @@ class ConnectedComponentAnalysisTests

cca_truth_hit_reader truth_reader(file_truth);

traccc::scalar var_adjustment = (pitch * pitch) / 12.f;

cca_truth_hit io_truth;
while (truth_reader.read(io_truth)) {
ASSERT_TRUE(result.find(io_truth.geometry_id) != result.end());

const vecmem::vector<traccc::measurement> &meas =
result.at(io_truth.geometry_id);

traccc::scalar tol = std::max(
0.1, 0.0001 * std::max(io_truth.channel0, io_truth.channel1));
const traccc::scalar tol = 0.0001f;

auto match = std::find_if(
meas.begin(), meas.end(),
Expand All @@ -133,8 +136,10 @@ class ConnectedComponentAnalysisTests

EXPECT_NEAR(match->local[0], io_truth.channel0, tol);
EXPECT_NEAR(match->local[1], io_truth.channel1, tol);
EXPECT_NEAR(match->variance[0], io_truth.variance0, tol);
EXPECT_NEAR(match->variance[1], io_truth.variance1, tol);
EXPECT_NEAR(match->variance[0], io_truth.variance0 + var_adjustment,
tol);
EXPECT_NEAR(match->variance[1], io_truth.variance1 + var_adjustment,
tol);

++total_truth;
}
Expand Down
Loading