Skip to content

Commit

Permalink
refactor!: TrackStateType as bitset view instead of bitset (#2068)
Browse files Browse the repository at this point in the history
This decouples the storage a bit from the semantics. It allows the backends to store `uint64_t` and have the rest of the code transparently interpret this as a bitset.
  • Loading branch information
paulgessinger authored May 3, 2023
1 parent 2a9c039 commit 37bc755
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 12 deletions.
90 changes: 85 additions & 5 deletions Core/include/Acts/EventData/MultiTrajectory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,85 @@ enum TrackStateFlag {
NumTrackStateFlags = 6
};

using TrackStateType = std::bitset<TrackStateFlag::NumTrackStateFlags>;
class ConstTrackStateType;

/// View type over a bitset stored in a 64 bit integer
/// This view allows modifications.
class TrackStateType {
public:
using raw_type = unsigned long long;
/// Constructor from a reference to the underlying value container
/// @param raw the value container
TrackStateType(raw_type& raw) : m_raw{&raw} {}

/// Assign the value from another set of flags
/// @param other the other set of flags to assign
/// @return this object
TrackStateType& operator=(const TrackStateType& other) {
*m_raw = *other.m_raw;
return *this;
}

/// Assign the value from another set of flags
/// @param other the other set of flags to assign
/// @return this object
TrackStateType& operator=(const ConstTrackStateType& other);

/// Return if the bit at position @p pos is 1
/// @param pos the bit position
/// @return if the bit at @p pos is one or not
bool test(std::size_t pos) const {
std::bitset<sizeof(raw_type)> bs{*m_raw};
return bs.test(pos);
}

/// Change the value of the bit at position @p pos to @p value.
/// @param pos the position of the bit to change
/// @param value the value to change the bit to
void set(std::size_t pos, bool value = true) {
std::bitset<sizeof(raw_type)> bs{*m_raw};
bs.set(pos, value);
*m_raw = bs.to_ullong();
}

/// Change the value of the bit at position at @p pos to @c false
/// @param pos the position of the bit to change
void reset(std::size_t pos) { set(pos, false); }

private:
raw_type* m_raw{nullptr};
};

/// View type over a bitset stored in a 64 bit integer
/// This view does not allow modifications
class ConstTrackStateType {
public:
using raw_type = unsigned long long;

/// Constructor from a reference to the underlying value container
/// @param raw the value container
ConstTrackStateType(const raw_type& raw) : m_raw{&raw} {}

/// Return if the bit at position @p pos is 1
/// @param pos the bit position
/// @return if the bit at @p pos is one or not
bool test(std::size_t pos) const {
std::bitset<sizeof(raw_type)> bs{*m_raw};
return bs.test(pos);
}

private:
friend class TrackStateType;
const raw_type* m_raw{nullptr};
};

inline TrackStateType& TrackStateType::operator=(
const ConstTrackStateType& other) {
*m_raw = *other.m_raw;
return *this;
}

// using TrackStateType = std::bitset<TrackStateFlag::NumTrackStateFlags>;

// forward declarations
template <typename derived_t>
Expand Down Expand Up @@ -944,14 +1022,16 @@ class TrackStateProxy {
/// reference.
/// @return reference to the type flags.
template <bool RO = ReadOnly, typename = std::enable_if_t<!RO>>
TrackStateType& typeFlags() {
return component<TrackStateType, hashString("typeFlags")>();
TrackStateType typeFlags() {
return TrackStateType{
component<TrackStateType::raw_type, hashString("typeFlags")>()};
}

/// Getter for the type flags. Returns a copy of the type flags value.
/// @return The type flags of this track state
TrackStateType typeFlags() const {
return component<TrackStateType, hashString("typeFlags")>();
ConstTrackStateType typeFlags() const {
return ConstTrackStateType{
component<TrackStateType::raw_type, hashString("typeFlags")>()};
}

template <bool RO = ReadOnly, typename = std::enable_if_t<!RO>>
Expand Down
2 changes: 1 addition & 1 deletion Core/include/Acts/EventData/VectorMultiTrajectory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class VectorMultiTrajectoryBase {

double chi2 = 0;
double pathLength = 0;
TrackStateType typeFlags;
TrackStateType::raw_type typeFlags{};

IndexType iuncalibrated = kInvalid;
IndexType icalibratedsourcelink = kInvalid;
Expand Down
4 changes: 2 additions & 2 deletions Core/include/Acts/TrackFinding/CombinatorialKalmanFilter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ class CombinatorialKalmanFilter {
trackState.allocateCalibrated(candidateTrackState.calibratedSize());
trackState.copyFrom(candidateTrackState, mask, false);

auto& typeFlags = trackState.typeFlags();
auto typeFlags = trackState.typeFlags();
if (trackState.referenceSurface().surfaceMaterial() != nullptr) {
typeFlags.set(TrackStateFlag::MaterialFlag);
}
Expand Down Expand Up @@ -1013,7 +1013,7 @@ class CombinatorialKalmanFilter {
// parameter

// Set the track state flags
auto& typeFlags = trackStateProxy.typeFlags();
auto typeFlags = trackStateProxy.typeFlags();
if (trackStateProxy.referenceSurface().surfaceMaterial() != nullptr) {
typeFlags.set(TrackStateFlag::MaterialFlag);
}
Expand Down
4 changes: 2 additions & 2 deletions Core/include/Acts/TrackFitting/Chi2Fitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ class Chi2Fitter {
//====================================

// Get and set the type flags
auto& typeFlags = trackStateProxy.typeFlags();
auto typeFlags = trackStateProxy.typeFlags();
typeFlags.set(TrackStateFlag::ParameterFlag);
if (surface->surfaceMaterial() != nullptr) {
typeFlags.set(TrackStateFlag::MaterialFlag);
Expand Down Expand Up @@ -501,7 +501,7 @@ class Chi2Fitter {
trackStateProxy.setReferenceSurface(surface->getSharedPtr());

// Set the track state flags
auto& typeFlags = trackStateProxy.typeFlags();
auto typeFlags = trackStateProxy.typeFlags();
typeFlags.set(TrackStateFlag::ParameterFlag);
if (surface->surfaceMaterial() != nullptr) {
typeFlags.set(TrackStateFlag::MaterialFlag);
Expand Down
4 changes: 2 additions & 2 deletions Core/include/Acts/TrackFitting/detail/KalmanUpdateHelpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ auto kalmanHandleMeasurement(
extensions.calibrator(state.geoContext, trackStateProxy);

// Get and set the type flags
auto &typeFlags = trackStateProxy.typeFlags();
auto typeFlags = trackStateProxy.typeFlags();
typeFlags.set(TrackStateFlag::ParameterFlag);
if (surface.surfaceMaterial() != nullptr) {
typeFlags.set(TrackStateFlag::MaterialFlag);
Expand Down Expand Up @@ -142,7 +142,7 @@ auto kalmanHandleNoMeasurement(
trackStateProxy.setReferenceSurface(surface.getSharedPtr());

// Set the track state flags
auto &typeFlags = trackStateProxy.typeFlags();
auto typeFlags = trackStateProxy.typeFlags();
typeFlags.set(TrackStateFlag::ParameterFlag);
if (surface.surfaceMaterial() != nullptr) {
typeFlags.set(TrackStateFlag::MaterialFlag);
Expand Down

0 comments on commit 37bc755

Please sign in to comment.