Skip to content

Commit acd1c62

Browse files
committed
Use the traccc track parameter types everywhere and allow templating of the algebra type
1 parent fb3c355 commit acd1c62

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+128
-112
lines changed

benchmarks/common/benchmarks/toy_detector_benchmark.hpp

+14-16
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,14 @@ class ToyDetectorBenchmark : public benchmark::Fixture {
6565

6666
// Detector type
6767
using detector_type = traccc::toy_detector::host;
68+
using algebra_type = typename detector_type::algebra_type;
6869
using scalar_type = detector_type::scalar_type;
6970

7071
// B field value and its type
7172
// @TODO: Set B field as argument
7273
using b_field_t = covfie::field<detray::bfield::const_bknd_t<scalar_type>>;
7374

74-
static constexpr traccc::vector3 B{0, 0,
75-
2 * traccc::unit<traccc::scalar>::T};
75+
static constexpr traccc::vector3 B{0, 0, 2 * traccc::unit<scalar_type>::T};
7676

7777
ToyDetectorBenchmark() {
7878

@@ -86,20 +86,18 @@ class ToyDetectorBenchmark : public benchmark::Fixture {
8686

8787
// Use deterministic random number generator for testing
8888
using uniform_gen_t = detray::detail::random_numbers<
89-
traccc::scalar, std::uniform_real_distribution<traccc::scalar>>;
89+
scalar_type, std::uniform_real_distribution<scalar_type>>;
9090

9191
// Build the detector
9292
auto [det, name_map] =
93-
detray::build_toy_detector<traccc::default_algebra>(
94-
host_mr, get_toy_config());
93+
detray::build_toy_detector<algebra_type>(host_mr, get_toy_config());
9594

9695
// B field
9796
auto field = detray::bfield::create_const_field<scalar_type>(B);
9897

9998
// Origin of particles
100-
using generator_type =
101-
detray::random_track_generator<traccc::free_track_parameters,
102-
uniform_gen_t>;
99+
using generator_type = detray::random_track_generator<
100+
traccc::free_track_parameters<algebra_type>, uniform_gen_t>;
103101
generator_type::configuration gen_cfg{};
104102
gen_cfg.n_tracks(n_tracks);
105103
gen_cfg.phi_range(phi_range);
@@ -109,12 +107,12 @@ class ToyDetectorBenchmark : public benchmark::Fixture {
109107

110108
// Smearing value for measurements
111109
traccc::measurement_smearer<traccc::default_algebra> meas_smearer(
112-
50 * traccc::unit<traccc::scalar>::um,
113-
50 * traccc::unit<traccc::scalar>::um);
110+
50 * traccc::unit<scalar_type>::um,
111+
50 * traccc::unit<scalar_type>::um);
114112

115113
// Type declarations
116-
using writer_type = traccc::smearing_writer<
117-
traccc::measurement_smearer<traccc::default_algebra>>;
114+
using writer_type =
115+
traccc::smearing_writer<traccc::measurement_smearer<algebra_type>>;
118116

119117
// Writer config
120118
typename writer_type::config smearer_writer_cfg{meas_smearer};
@@ -126,7 +124,7 @@ class ToyDetectorBenchmark : public benchmark::Fixture {
126124

127125
auto sim = traccc::simulator<detector_type, b_field_t, generator_type,
128126
writer_type>(
129-
detray::muon<traccc::scalar>(), n_events, det, field,
127+
detray::muon<scalar_type>(), n_events, det, field,
130128
std::move(generator), std::move(smearer_writer_cfg), full_path);
131129

132130
// Same propagation configuration for sim and reco
@@ -147,14 +145,14 @@ class ToyDetectorBenchmark : public benchmark::Fixture {
147145
detray::io::write_detector(det, name_map, writer_cfg);
148146
}
149147

150-
detray::toy_det_config<traccc::scalar> get_toy_config() const {
148+
detray::toy_det_config<scalar_type> get_toy_config() const {
151149

152150
// Create the toy geometry
153-
detray::toy_det_config<traccc::scalar> toy_cfg{};
151+
detray::toy_det_config<scalar_type> toy_cfg{};
154152
toy_cfg.n_brl_layers(4u).n_edc_layers(7u).do_check(false);
155153

156154
// @TODO: Increase the material budget again
157-
toy_cfg.module_mat_thickness(0.11f * traccc::unit<traccc::scalar>::mm);
155+
toy_cfg.module_mat_thickness(0.11f * traccc::unit<scalar_type>::mm);
158156

159157
return toy_cfg;
160158
}

core/include/traccc/edm/track_candidate.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,6 @@ using track_candidate = measurement;
2323
using track_candidate_collection_types = collection_types<track_candidate>;
2424
/// Declare a track candidates container type
2525
using track_candidate_container_types =
26-
container_types<bound_track_parameters, track_candidate>;
26+
container_types<bound_track_parameters<>, track_candidate>;
2727

2828
} // namespace traccc

core/include/traccc/edm/track_parameters.hpp

+25-13
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,33 @@
2020

2121
namespace traccc {
2222

23-
using free_track_parameters =
24-
detray::free_track_parameters<traccc::default_algebra>;
25-
using bound_track_parameters =
26-
detray::bound_track_parameters<traccc::default_algebra>;
27-
using free_vector = free_track_parameters::vector_type;
28-
using bound_vector = bound_track_parameters::vector_type;
29-
using bound_covariance = bound_track_parameters::covariance_type;
23+
template <detray::concepts::algebra algebra_t = traccc::default_algebra>
24+
using free_track_parameters = detray::free_track_parameters<algebra_t>;
25+
26+
template <detray::concepts::algebra algebra_t = traccc::default_algebra>
27+
using bound_track_parameters = detray::bound_track_parameters<algebra_t>;
28+
29+
template <detray::concepts::algebra algebra_t = traccc::default_algebra>
30+
using free_vector = typename free_track_parameters<algebra_t>::vector_type;
31+
32+
template <detray::concepts::algebra algebra_t = traccc::default_algebra>
33+
using bound_vector = typename bound_track_parameters<algebra_t>::vector_type;
34+
35+
template <detray::concepts::algebra algebra_t = traccc::default_algebra>
36+
using bound_covariance =
37+
typename bound_track_parameters<algebra_t>::covariance_type;
38+
39+
template <detray::concepts::algebra algebra_t = traccc::default_algebra>
40+
using bound_matrix = detray::bound_matrix<algebra_t>;
3041

3142
/// Declare all track_parameters collection types
3243
using bound_track_parameters_collection_types =
33-
collection_types<bound_track_parameters>;
44+
collection_types<bound_track_parameters<>>;
3445

3546
// Wrap the phi of track parameters to [-pi,pi]
36-
TRACCC_HOST_DEVICE
37-
inline void wrap_phi(bound_track_parameters& param) {
47+
template <detray::concepts::algebra algebra_t>
48+
TRACCC_HOST_DEVICE inline void wrap_phi(
49+
bound_track_parameters<algebra_t>& param) {
3850

3951
traccc::scalar phi = param.phi();
4052
static constexpr traccc::scalar TWOPI =
@@ -49,9 +61,9 @@ inline void wrap_phi(bound_track_parameters& param) {
4961
}
5062

5163
/// Covariance inflation used for track fitting
52-
TRACCC_HOST_DEVICE
53-
inline void inflate_covariance(bound_track_parameters& param,
54-
const traccc::scalar inf_fac) {
64+
template <detray::concepts::algebra algebra_t>
65+
TRACCC_HOST_DEVICE inline void inflate_covariance(
66+
bound_track_parameters<algebra_t>& param, const traccc::scalar inf_fac) {
5567
auto& cov = param.covariance();
5668
for (unsigned int i = 0; i < e_bound_size; i++) {
5769
for (unsigned int j = 0; j < e_bound_size; j++) {

core/include/traccc/edm/track_state.hpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ struct fitting_result {
2626
using scalar_type = detray::dscalar<algebra_t>;
2727

2828
/// Fitted track parameter
29-
detray::bound_track_parameters<algebra_t> fit_params;
29+
traccc::bound_track_parameters<algebra_t> fit_params;
3030

3131
/// Number of degree of freedoms of fitted track
3232
scalar_type ndf{0};
@@ -55,8 +55,8 @@ struct track_state {
5555
using size_type = detray::dsize_type<algebra_t>;
5656

5757
using bound_track_parameters_type =
58-
detray::bound_track_parameters<algebra_t>;
59-
using bound_matrix = detray::bound_matrix<algebra_t>;
58+
traccc::bound_track_parameters<algebra_t>;
59+
using bound_matrix_type = traccc::bound_matrix<algebra_t>;
6060
template <size_type ROWS, size_type COLS>
6161
using matrix_type = detray::dmatrix<algebra_t, ROWS, COLS>;
6262

@@ -154,11 +154,11 @@ struct track_state {
154154

155155
/// @return the non-const transport jacobian
156156
TRACCC_HOST_DEVICE
157-
inline bound_matrix& jacobian() { return m_jacobian; }
157+
inline bound_matrix_type& jacobian() { return m_jacobian; }
158158

159159
/// @return the const transport jacobian
160160
TRACCC_HOST_DEVICE
161-
inline const bound_matrix& jacobian() const { return m_jacobian; }
161+
inline const bound_matrix_type& jacobian() const { return m_jacobian; }
162162

163163
/// @return the non-const chi square of filtered parameter
164164
TRACCC_HOST_DEVICE
@@ -210,7 +210,7 @@ struct track_state {
210210
private:
211211
detray::geometry::barcode m_surface_link;
212212
measurement m_measurement;
213-
bound_matrix m_jacobian = matrix::zero<bound_matrix>();
213+
bound_matrix_type m_jacobian = matrix::zero<bound_matrix_type>();
214214
bound_track_parameters_type m_predicted;
215215
scalar_type m_filtered_chi2 = 0.f;
216216
bound_track_parameters_type m_filtered;

core/include/traccc/finding/details/find_tracks.hpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,11 @@ track_candidate_container_types::host find_tracks(
137137
bound_track_parameters_collection_types::const_device seeds{seeds_view};
138138

139139
// Copy seed to input parameters
140-
std::vector<bound_track_parameters> in_params(seeds.size());
140+
std::vector<bound_track_parameters<algebra_type>> in_params(seeds.size());
141141
std::copy(seeds.begin(), seeds.end(), in_params.begin());
142142
std::vector<unsigned int> n_trks_per_seed(seeds.size());
143143

144-
std::vector<bound_track_parameters> out_params;
144+
std::vector<bound_track_parameters<algebra_type>> out_params;
145145

146146
for (unsigned int step = 0u; step < config.max_track_candidates_per_track;
147147
step++) {
@@ -164,12 +164,13 @@ track_candidate_container_types::host find_tracks(
164164
std::fill(n_trks_per_seed.begin(), n_trks_per_seed.end(), 0u);
165165

166166
// Parameters updated by Kalman fitter
167-
std::vector<bound_track_parameters> updated_params;
167+
std::vector<bound_track_parameters<algebra_type>> updated_params;
168168

169169
for (unsigned int in_param_id = 0; in_param_id < n_in_params;
170170
in_param_id++) {
171171

172-
bound_track_parameters& in_param = in_params[in_param_id];
172+
bound_track_parameters<algebra_type>& in_param =
173+
in_params[in_param_id];
173174
const unsigned int orig_param_id =
174175
(step == 0
175176
? in_param_id

core/include/traccc/fitting/kalman_filter/gain_matrix_smoother.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ struct gain_matrix_smoother {
2424
using size_type = detray::dsize_type<algebra_t>;
2525
template <size_type ROWS, size_type COLS>
2626
using matrix_type = detray::dmatrix<algebra_t, ROWS, COLS>;
27-
using bound_vector_type = detray::bound_vector<algebra_t>;
28-
using bound_matrix_type = detray::bound_matrix<algebra_t>;
27+
using bound_vector_type = traccc::bound_vector<algebra_t>;
28+
using bound_matrix_type = traccc::bound_matrix<algebra_t>;
2929

3030
/// Gain matrix smoother operation
3131
///

core/include/traccc/fitting/kalman_filter/gain_matrix_updater.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ struct gain_matrix_updater {
2222
using size_type = detray::dsize_type<algebra_t>;
2323
template <size_type ROWS, size_type COLS>
2424
using matrix_type = detray::dmatrix<algebra_t, ROWS, COLS>;
25-
using bound_vector_type = detray::bound_vector<algebra_t>;
26-
using bound_matrix_type = detray::bound_matrix<algebra_t>;
25+
using bound_vector_type = traccc::bound_vector<algebra_t>;
26+
using bound_matrix_type = traccc::bound_matrix<algebra_t>;
2727

2828
/// Gain matrix updater operation
2929
///
@@ -40,7 +40,7 @@ struct gain_matrix_updater {
4040
TRACCC_HOST_DEVICE inline bool operator()(
4141
const mask_group_t& /*mask_group*/, const index_t& /*index*/,
4242
track_state<algebra_t>& trk_state,
43-
const bound_track_parameters& bound_params) const {
43+
const bound_track_parameters<algebra_t>& bound_params) const {
4444

4545
using shape_type = typename mask_group_t::value_type::shape;
4646

@@ -63,7 +63,7 @@ struct gain_matrix_updater {
6363
template <size_type D, typename shape_t>
6464
TRACCC_HOST_DEVICE inline bool update(
6565
track_state<algebra_t>& trk_state,
66-
const bound_track_parameters& bound_params) const {
66+
const bound_track_parameters<algebra_t>& bound_params) const {
6767

6868
static_assert(((D == 1u) || (D == 2u)),
6969
"The measurement dimension should be 1 or 2");

core/include/traccc/fitting/kalman_filter/two_filters_smoother.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ struct two_filters_smoother {
3535
TRACCC_HOST_DEVICE inline bool operator()(
3636
const mask_group_t& /*mask_group*/, const index_t& /*index*/,
3737
track_state<algebra_t>& trk_state,
38-
bound_track_parameters& bound_params) const {
38+
bound_track_parameters<algebra_t>& bound_params) const {
3939

4040
using shape_type = typename mask_group_t::value_type::shape;
4141

@@ -55,7 +55,7 @@ struct two_filters_smoother {
5555
template <size_type D, typename shape_t>
5656
TRACCC_HOST_DEVICE inline bool smoothe(
5757
track_state<algebra_t>& trk_state,
58-
bound_track_parameters& bound_params) const {
58+
bound_track_parameters<algebra_t>& bound_params) const {
5959

6060
assert(trk_state.filtered().surface_link() ==
6161
bound_params.surface_link());

core/include/traccc/seeding/track_params_estimation_helper.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ inline TRACCC_HOST_DEVICE vector2 uv_transform(const scalar& x,
3939
/// @param bfield is the magnetic field
4040
/// @param mass is the mass of particle
4141
template <typename spacepoint_collection_t>
42-
inline TRACCC_HOST_DEVICE bound_vector
43-
seed_to_bound_vector(const spacepoint_collection_t& sp_collection,
44-
const seed& seed, const vector3& bfield) {
42+
inline TRACCC_HOST_DEVICE bound_vector<> seed_to_bound_vector(
43+
const spacepoint_collection_t& sp_collection, const seed& seed,
44+
const vector3& bfield) {
4545

46-
bound_vector params = matrix::zero<bound_vector>();
46+
bound_vector<> params = matrix::zero<bound_vector<>>();
4747

4848
const auto& spB =
4949
sp_collection.at(static_cast<unsigned int>(seed.spB_link));

core/include/traccc/utils/particle.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ template <typename scalar_t>
7777
TRACCC_HOST_DEVICE inline detray::pdg_particle<scalar_t>
7878
correct_particle_hypothesis(
7979
const detray::pdg_particle<scalar_t>& ptc_hypothesis,
80-
const bound_track_parameters& params) {
80+
const bound_track_parameters<>& params) {
8181

8282
if (ptc_hypothesis.charge() * params.qop() > 0.f) {
8383
return ptc_hypothesis;
@@ -90,7 +90,7 @@ template <typename scalar_t>
9090
TRACCC_HOST_DEVICE inline detray::pdg_particle<scalar_t>
9191
correct_particle_hypothesis(
9292
const detray::pdg_particle<scalar_t>& ptc_hypothesis,
93-
const free_track_parameters& params) {
93+
const free_track_parameters<>& params) {
9494

9595
if (ptc_hypothesis.charge() * params.qop() > 0.f) {
9696
return ptc_hypothesis;

core/include/traccc/utils/seed_generator.hpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,20 @@ struct seed_generator {
4747
///
4848
/// @param vertex vertex of particle
4949
/// @param stddevs standard deviations for track parameter smearing
50-
bound_track_parameters operator()(
50+
bound_track_parameters<algebra_type> operator()(
5151
const detray::geometry::barcode surface_link,
52-
const free_track_parameters& free_param,
52+
const free_track_parameters<algebra_type>& free_param,
5353
const detray::pdg_particle<scalar>& ptc_type) {
5454

5555
// Get bound parameter
5656
const detray::tracking_surface sf{m_detector, surface_link};
5757

5858
const cxt_t ctx{};
5959
auto bound_vec = sf.free_to_bound_vector(ctx, free_param);
60-
auto bound_cov = matrix::zero<detray::bound_matrix<algebra_type>>();
60+
auto bound_cov = matrix::zero<traccc::bound_matrix<algebra_type>>();
6161

62-
bound_track_parameters bound_param{surface_link, bound_vec, bound_cov};
62+
bound_track_parameters<algebra_type> bound_param{surface_link,
63+
bound_vec, bound_cov};
6364

6465
// Type definitions
6566
using interactor_type =

core/src/seeding/track_params_estimation.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ track_params_estimation::output_type track_params_estimation::operator()(
2525
output_type result(num_seeds, &m_mr.get());
2626

2727
for (seed_collection_types::host::size_type i = 0; i < num_seeds; ++i) {
28-
bound_track_parameters track_params;
28+
bound_track_parameters<> track_params;
2929
track_params.set_vector(
3030
seed_to_bound_vector(spacepoints, seeds[i], bfield));
3131

device/common/include/traccc/edm/device/sort_key.hpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ namespace traccc::device {
1616

1717
using sort_key = traccc::scalar;
1818

19-
TRACCC_HOST_DEVICE
20-
inline sort_key get_sort_key(const bound_track_parameters& params) {
19+
template <detray::concepts::algebra algebra_t>
20+
TRACCC_HOST_DEVICE inline sort_key get_sort_key(
21+
const bound_track_parameters<algebra_t>& params) {
2122
// key = |theta - pi/2|
2223
return math::fabs(params.theta() - constant<traccc::scalar>::pi_2);
2324
}

device/common/include/traccc/finding/device/impl/find_tracks.ipp

+1-1
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ TRACCC_DEVICE inline void find_tracks(
191191
owner_local_thread_id +
192192
thread_id.getBlockDimX() * thread_id.getBlockIdX();
193193
assert(in_params_liveness.at(owner_global_thread_id) != 0u);
194-
const bound_track_parameters& in_par =
194+
const bound_track_parameters<>& in_par =
195195
in_params.at(owner_global_thread_id);
196196
const unsigned int meas_idx =
197197
shared_payload.shared_candidates[thread_id.getLocalThreadIdX()]

device/common/include/traccc/seeding/device/impl/estimate_track_params.ipp

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ inline void estimate_track_params(
3535
const seed& this_seed = seeds_device.at(globalIndex);
3636

3737
// Get bound track parameter
38-
bound_track_parameters track_params;
38+
bound_track_parameters<> track_params;
3939
track_params.set_vector(
4040
seed_to_bound_vector(spacepoints_device, this_seed, bfield));
4141

examples/run/cuda/seeding_example_cuda.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
368368
vecmem::get_data(seeds_cuda));
369369

370370
// Compare the track parameters made on the host and on the device.
371-
traccc::collection_comparator<traccc::bound_track_parameters>
371+
traccc::collection_comparator<traccc::bound_track_parameters<>>
372372
compare_track_parameters{"track parameters"};
373373
compare_track_parameters(vecmem::get_data(params),
374374
vecmem::get_data(params_cuda));

examples/run/cuda/seq_example_cuda.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ int seq_run(const traccc::opts::detector& detector_opts,
396396
vecmem::get_data(seeds_cuda));
397397

398398
// Compare the track parameters made on the host and on the device.
399-
traccc::collection_comparator<traccc::bound_track_parameters>
399+
traccc::collection_comparator<traccc::bound_track_parameters<>>
400400
compare_track_parameters{"track parameters"};
401401
compare_track_parameters(vecmem::get_data(params),
402402
vecmem::get_data(params_cuda));

0 commit comments

Comments
 (0)