Skip to content

Commit 5590d65

Browse files
committed
Use the traccc track parameter types everywhere and allow templating of the algebra type
1 parent efb743d commit 5590d65

File tree

49 files changed

+136
-120
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+136
-120
lines changed

benchmarks/common/benchmarks/toy_detector_benchmark.hpp

+14-16
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,14 @@ class ToyDetectorBenchmark : public benchmark::Fixture {
6565

6666
// Detector type
6767
using detector_type = traccc::toy_detector::host;
68+
using algebra_type = typename detector_type::algebra_type;
6869
using scalar_type = detector_type::scalar_type;
6970

7071
// B field value and its type
7172
// @TODO: Set B field as argument
7273
using b_field_t = covfie::field<detray::bfield::const_bknd_t<scalar_type>>;
7374

74-
static constexpr traccc::vector3 B{0, 0,
75-
2 * traccc::unit<traccc::scalar>::T};
75+
static constexpr traccc::vector3 B{0, 0, 2 * traccc::unit<scalar_type>::T};
7676

7777
ToyDetectorBenchmark() {
7878

@@ -86,20 +86,18 @@ class ToyDetectorBenchmark : public benchmark::Fixture {
8686

8787
// Use deterministic random number generator for testing
8888
using uniform_gen_t = detray::detail::random_numbers<
89-
traccc::scalar, std::uniform_real_distribution<traccc::scalar>>;
89+
scalar_type, std::uniform_real_distribution<scalar_type>>;
9090

9191
// Build the detector
9292
auto [det, name_map] =
93-
detray::build_toy_detector<traccc::default_algebra>(
94-
host_mr, get_toy_config());
93+
detray::build_toy_detector<algebra_type>(host_mr, get_toy_config());
9594

9695
// B field
9796
auto field = detray::bfield::create_const_field<scalar_type>(B);
9897

9998
// Origin of particles
100-
using generator_type =
101-
detray::random_track_generator<traccc::free_track_parameters,
102-
uniform_gen_t>;
99+
using generator_type = detray::random_track_generator<
100+
traccc::free_track_parameters<algebra_type>, uniform_gen_t>;
103101
generator_type::configuration gen_cfg{};
104102
gen_cfg.n_tracks(n_tracks);
105103
gen_cfg.phi_range(phi_range);
@@ -109,12 +107,12 @@ class ToyDetectorBenchmark : public benchmark::Fixture {
109107

110108
// Smearing value for measurements
111109
traccc::measurement_smearer<traccc::default_algebra> meas_smearer(
112-
50 * traccc::unit<traccc::scalar>::um,
113-
50 * traccc::unit<traccc::scalar>::um);
110+
50 * traccc::unit<scalar_type>::um,
111+
50 * traccc::unit<scalar_type>::um);
114112

115113
// Type declarations
116-
using writer_type = traccc::smearing_writer<
117-
traccc::measurement_smearer<traccc::default_algebra>>;
114+
using writer_type =
115+
traccc::smearing_writer<traccc::measurement_smearer<algebra_type>>;
118116

119117
// Writer config
120118
typename writer_type::config smearer_writer_cfg{meas_smearer};
@@ -126,7 +124,7 @@ class ToyDetectorBenchmark : public benchmark::Fixture {
126124

127125
auto sim = traccc::simulator<detector_type, b_field_t, generator_type,
128126
writer_type>(
129-
detray::muon<traccc::scalar>(), n_events, det, field,
127+
detray::muon<scalar_type>(), n_events, det, field,
130128
std::move(generator), std::move(smearer_writer_cfg), full_path);
131129

132130
// Same propagation configuration for sim and reco
@@ -147,14 +145,14 @@ class ToyDetectorBenchmark : public benchmark::Fixture {
147145
detray::io::write_detector(det, name_map, writer_cfg);
148146
}
149147

150-
detray::toy_det_config<traccc::scalar> get_toy_config() const {
148+
detray::toy_det_config<scalar_type> get_toy_config() const {
151149

152150
// Create the toy geometry
153-
detray::toy_det_config<traccc::scalar> toy_cfg{};
151+
detray::toy_det_config<scalar_type> toy_cfg{};
154152
toy_cfg.n_brl_layers(4u).n_edc_layers(7u).do_check(false);
155153

156154
// @TODO: Increase the material budget again
157-
toy_cfg.module_mat_thickness(0.11f * traccc::unit<traccc::scalar>::mm);
155+
toy_cfg.module_mat_thickness(0.11f * traccc::unit<scalar_type>::mm);
158156

159157
return toy_cfg;
160158
}

core/include/traccc/edm/track_parameters.hpp

+25-13
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,33 @@
2020

2121
namespace traccc {
2222

23-
using free_track_parameters =
24-
detray::free_track_parameters<traccc::default_algebra>;
25-
using bound_track_parameters =
26-
detray::bound_track_parameters<traccc::default_algebra>;
27-
using free_vector = free_track_parameters::vector_type;
28-
using bound_vector = bound_track_parameters::vector_type;
29-
using bound_covariance = bound_track_parameters::covariance_type;
23+
template <detray::concepts::algebra algebra_t = traccc::default_algebra>
24+
using free_track_parameters = detray::free_track_parameters<algebra_t>;
25+
26+
template <detray::concepts::algebra algebra_t = traccc::default_algebra>
27+
using bound_track_parameters = detray::bound_track_parameters<algebra_t>;
28+
29+
template <detray::concepts::algebra algebra_t = traccc::default_algebra>
30+
using free_vector = typename free_track_parameters<algebra_t>::vector_type;
31+
32+
template <detray::concepts::algebra algebra_t = traccc::default_algebra>
33+
using bound_vector = typename bound_track_parameters<algebra_t>::vector_type;
34+
35+
template <detray::concepts::algebra algebra_t = traccc::default_algebra>
36+
using bound_covariance =
37+
typename bound_track_parameters<algebra_t>::covariance_type;
38+
39+
template <detray::concepts::algebra algebra_t = traccc::default_algebra>
40+
using bound_matrix = detray::bound_matrix<algebra_t>;
3041

3142
/// Declare all track_parameters collection types
3243
using bound_track_parameters_collection_types =
33-
collection_types<bound_track_parameters>;
44+
collection_types<bound_track_parameters<>>;
3445

3546
// Wrap the phi of track parameters to [-pi,pi]
36-
TRACCC_HOST_DEVICE
37-
inline void wrap_phi(bound_track_parameters& param) {
47+
template <detray::concepts::algebra algebra_t>
48+
TRACCC_HOST_DEVICE inline void wrap_phi(
49+
bound_track_parameters<algebra_t>& param) {
3850

3951
traccc::scalar phi = param.phi();
4052
static constexpr traccc::scalar TWOPI =
@@ -49,9 +61,9 @@ inline void wrap_phi(bound_track_parameters& param) {
4961
}
5062

5163
/// Covariance inflation used for track fitting
52-
TRACCC_HOST_DEVICE
53-
inline void inflate_covariance(bound_track_parameters& param,
54-
const traccc::scalar inf_fac) {
64+
template <detray::concepts::algebra algebra_t>
65+
TRACCC_HOST_DEVICE inline void inflate_covariance(
66+
bound_track_parameters<algebra_t>& param, const traccc::scalar inf_fac) {
5567
auto& cov = param.covariance();
5668
for (unsigned int i = 0; i < e_bound_size; i++) {
5769
for (unsigned int j = 0; j < e_bound_size; j++) {

core/include/traccc/edm/track_state.hpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ struct fitting_result {
2626
using scalar_type = detray::dscalar<algebra_t>;
2727

2828
/// Fitted track parameter
29-
detray::bound_track_parameters<algebra_t> fit_params;
29+
traccc::bound_track_parameters<algebra_t> fit_params;
3030

3131
/// Track quality
3232
traccc::track_quality trk_quality;
@@ -40,8 +40,8 @@ struct track_state {
4040
using size_type = detray::dsize_type<algebra_t>;
4141

4242
using bound_track_parameters_type =
43-
detray::bound_track_parameters<algebra_t>;
44-
using bound_matrix = detray::bound_matrix<algebra_t>;
43+
traccc::bound_track_parameters<algebra_t>;
44+
using bound_matrix_type = traccc::bound_matrix<algebra_t>;
4545
template <size_type ROWS, size_type COLS>
4646
using matrix_type = detray::dmatrix<algebra_t, ROWS, COLS>;
4747

@@ -139,11 +139,11 @@ struct track_state {
139139

140140
/// @return the non-const transport jacobian
141141
TRACCC_HOST_DEVICE
142-
inline bound_matrix& jacobian() { return m_jacobian; }
142+
inline bound_matrix_type& jacobian() { return m_jacobian; }
143143

144144
/// @return the const transport jacobian
145145
TRACCC_HOST_DEVICE
146-
inline const bound_matrix& jacobian() const { return m_jacobian; }
146+
inline const bound_matrix_type& jacobian() const { return m_jacobian; }
147147

148148
/// @return the non-const chi square of filtered parameter
149149
TRACCC_HOST_DEVICE
@@ -195,7 +195,7 @@ struct track_state {
195195
private:
196196
detray::geometry::barcode m_surface_link;
197197
measurement m_measurement;
198-
bound_matrix m_jacobian = matrix::zero<bound_matrix>();
198+
bound_matrix_type m_jacobian = matrix::zero<bound_matrix_type>();
199199
bound_track_parameters_type m_predicted;
200200
scalar_type m_filtered_chi2 = 0.f;
201201
bound_track_parameters_type m_filtered;

core/include/traccc/finding/details/find_tracks.hpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,11 @@ track_candidate_container_types::host find_tracks(
133133
bound_track_parameters_collection_types::const_device seeds{seeds_view};
134134

135135
// Copy seed to input parameters
136-
std::vector<bound_track_parameters> in_params(seeds.size());
136+
std::vector<bound_track_parameters<algebra_type>> in_params(seeds.size());
137137
std::copy(seeds.begin(), seeds.end(), in_params.begin());
138138
std::vector<unsigned int> n_trks_per_seed(seeds.size());
139139

140-
std::vector<bound_track_parameters> out_params;
140+
std::vector<bound_track_parameters<algebra_type>> out_params;
141141

142142
for (unsigned int step = 0u; step < config.max_track_candidates_per_track;
143143
step++) {
@@ -160,12 +160,13 @@ track_candidate_container_types::host find_tracks(
160160
std::fill(n_trks_per_seed.begin(), n_trks_per_seed.end(), 0u);
161161

162162
// Parameters updated by Kalman fitter
163-
std::vector<bound_track_parameters> updated_params;
163+
std::vector<bound_track_parameters<algebra_type>> updated_params;
164164

165165
for (unsigned int in_param_id = 0; in_param_id < n_in_params;
166166
in_param_id++) {
167167

168-
bound_track_parameters& in_param = in_params[in_param_id];
168+
bound_track_parameters<algebra_type>& in_param =
169+
in_params[in_param_id];
169170
const unsigned int orig_param_id =
170171
(step == 0
171172
? in_param_id

core/include/traccc/fitting/kalman_filter/gain_matrix_smoother.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ struct gain_matrix_smoother {
2727
using size_type = detray::dsize_type<algebra_t>;
2828
template <size_type ROWS, size_type COLS>
2929
using matrix_type = detray::dmatrix<algebra_t, ROWS, COLS>;
30-
using bound_vector_type = detray::bound_vector<algebra_t>;
31-
using bound_matrix_type = detray::bound_matrix<algebra_t>;
30+
using bound_vector_type = traccc::bound_vector<algebra_t>;
31+
using bound_matrix_type = traccc::bound_matrix<algebra_t>;
3232

3333
/// Gain matrix smoother operation
3434
///

core/include/traccc/fitting/kalman_filter/gain_matrix_updater.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ struct gain_matrix_updater {
2525
using size_type = detray::dsize_type<algebra_t>;
2626
template <size_type ROWS, size_type COLS>
2727
using matrix_type = detray::dmatrix<algebra_t, ROWS, COLS>;
28-
using bound_vector_type = detray::bound_vector<algebra_t>;
29-
using bound_matrix_type = detray::bound_matrix<algebra_t>;
28+
using bound_vector_type = traccc::bound_vector<algebra_t>;
29+
using bound_matrix_type = traccc::bound_matrix<algebra_t>;
3030

3131
/// Gain matrix updater operation
3232
///
@@ -43,7 +43,7 @@ struct gain_matrix_updater {
4343
TRACCC_HOST_DEVICE inline bool operator()(
4444
const mask_group_t& /*mask_group*/, const index_t& /*index*/,
4545
track_state<algebra_t>& trk_state,
46-
const bound_track_parameters& bound_params) const {
46+
const bound_track_parameters<algebra_t>& bound_params) const {
4747

4848
using shape_type = typename mask_group_t::value_type::shape;
4949

@@ -66,7 +66,7 @@ struct gain_matrix_updater {
6666
template <size_type D, typename shape_t>
6767
TRACCC_HOST_DEVICE inline bool update(
6868
track_state<algebra_t>& trk_state,
69-
const bound_track_parameters& bound_params) const {
69+
const bound_track_parameters<algebra_t>& bound_params) const {
7070

7171
static_assert(((D == 1u) || (D == 2u)),
7272
"The measurement dimension should be 1 or 2");

core/include/traccc/fitting/kalman_filter/two_filters_smoother.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ struct two_filters_smoother {
3535
TRACCC_HOST_DEVICE inline bool operator()(
3636
const mask_group_t& /*mask_group*/, const index_t& /*index*/,
3737
track_state<algebra_t>& trk_state,
38-
bound_track_parameters& bound_params) const {
38+
bound_track_parameters<algebra_t>& bound_params) const {
3939

4040
using shape_type = typename mask_group_t::value_type::shape;
4141

@@ -55,7 +55,7 @@ struct two_filters_smoother {
5555
template <size_type D, typename shape_t>
5656
TRACCC_HOST_DEVICE inline bool smoothe(
5757
track_state<algebra_t>& trk_state,
58-
bound_track_parameters& bound_params) const {
58+
bound_track_parameters<algebra_t>& bound_params) const {
5959

6060
assert(trk_state.filtered().surface_link() ==
6161
bound_params.surface_link());

core/include/traccc/seeding/track_params_estimation_helper.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ inline TRACCC_HOST_DEVICE vector2 uv_transform(const scalar& x,
3939
/// @param bfield is the magnetic field
4040
/// @param mass is the mass of particle
4141
template <typename spacepoint_collection_t>
42-
inline TRACCC_HOST_DEVICE bound_vector
43-
seed_to_bound_vector(const spacepoint_collection_t& sp_collection,
44-
const seed& seed, const vector3& bfield) {
42+
inline TRACCC_HOST_DEVICE bound_vector<> seed_to_bound_vector(
43+
const spacepoint_collection_t& sp_collection, const seed& seed,
44+
const vector3& bfield) {
4545

46-
bound_vector params = matrix::zero<bound_vector>();
46+
bound_vector<> params = matrix::zero<bound_vector<>>();
4747

4848
const auto& spB =
4949
sp_collection.at(static_cast<unsigned int>(seed.spB_link));

core/include/traccc/utils/particle.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ template <typename scalar_t>
7777
TRACCC_HOST_DEVICE inline detray::pdg_particle<scalar_t>
7878
correct_particle_hypothesis(
7979
const detray::pdg_particle<scalar_t>& ptc_hypothesis,
80-
const bound_track_parameters& params) {
80+
const bound_track_parameters<>& params) {
8181

8282
if (ptc_hypothesis.charge() * params.qop() > 0.f) {
8383
return ptc_hypothesis;
@@ -90,7 +90,7 @@ template <typename scalar_t>
9090
TRACCC_HOST_DEVICE inline detray::pdg_particle<scalar_t>
9191
correct_particle_hypothesis(
9292
const detray::pdg_particle<scalar_t>& ptc_hypothesis,
93-
const free_track_parameters& params) {
93+
const free_track_parameters<>& params) {
9494

9595
if (ptc_hypothesis.charge() * params.qop() > 0.f) {
9696
return ptc_hypothesis;

core/include/traccc/utils/seed_generator.hpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,20 @@ struct seed_generator {
4242
///
4343
/// @param vertex vertex of particle
4444
/// @param stddevs standard deviations for track parameter smearing
45-
bound_track_parameters operator()(
45+
bound_track_parameters<algebra_type> operator()(
4646
const detray::geometry::barcode surface_link,
47-
const free_track_parameters& free_param,
47+
const free_track_parameters<algebra_type>& free_param,
4848
const detray::pdg_particle<scalar>& ptc_type) {
4949

5050
// Get bound parameter
5151
const detray::tracking_surface sf{m_detector, surface_link};
5252

5353
const cxt_t ctx{};
5454
auto bound_vec = sf.free_to_bound_vector(ctx, free_param);
55-
auto bound_cov = matrix::zero<detray::bound_matrix<algebra_type>>();
55+
auto bound_cov = matrix::zero<traccc::bound_matrix<algebra_type>>();
5656

57-
bound_track_parameters bound_param{surface_link, bound_vec, bound_cov};
57+
bound_track_parameters<algebra_type> bound_param{surface_link,
58+
bound_vec, bound_cov};
5859

5960
// Type definitions
6061
using interactor_type =

core/src/seeding/track_params_estimation.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ track_params_estimation::output_type track_params_estimation::operator()(
2525
output_type result(num_seeds, &m_mr.get());
2626

2727
for (seed_collection_types::host::size_type i = 0; i < num_seeds; ++i) {
28-
bound_track_parameters track_params;
28+
bound_track_parameters<> track_params;
2929
track_params.set_vector(
3030
seed_to_bound_vector(spacepoints, seeds[i], bfield));
3131

device/common/include/traccc/edm/device/sort_key.hpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ namespace traccc::device {
1616

1717
using sort_key = traccc::scalar;
1818

19-
TRACCC_HOST_DEVICE
20-
inline sort_key get_sort_key(const bound_track_parameters& params) {
19+
template <detray::concepts::algebra algebra_t>
20+
TRACCC_HOST_DEVICE inline sort_key get_sort_key(
21+
const bound_track_parameters<algebra_t>& params) {
2122
// key = |theta - pi/2|
2223
return math::fabs(params.theta() - constant<traccc::scalar>::pi_2);
2324
}

device/common/include/traccc/finding/device/impl/find_tracks.ipp

+1-1
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ TRACCC_DEVICE inline void find_tracks(
194194
owner_local_thread_id +
195195
thread_id.getBlockDimX() * thread_id.getBlockIdX();
196196
assert(in_params_liveness.at(owner_global_thread_id) != 0u);
197-
const bound_track_parameters& in_par =
197+
const bound_track_parameters<>& in_par =
198198
in_params.at(owner_global_thread_id);
199199
const unsigned int meas_idx =
200200
shared_payload.shared_candidates[thread_id.getLocalThreadIdX()]

device/common/include/traccc/finding/device/impl/propagate_to_next_surface.ipp

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ TRACCC_DEVICE inline void propagate_to_next_surface(
7474
}
7575

7676
// Input bound track parameter
77-
const bound_track_parameters in_par = params.at(param_id);
77+
const bound_track_parameters<> in_par = params.at(param_id);
7878

7979
// Create propagator
8080
propagator_t propagator(cfg.propagation);

device/common/include/traccc/seeding/device/impl/estimate_track_params.ipp

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ inline void estimate_track_params(
3535
const seed& this_seed = seeds_device.at(globalIndex);
3636

3737
// Get bound track parameter
38-
bound_track_parameters track_params;
38+
bound_track_parameters<> track_params;
3939
track_params.set_vector(
4040
seed_to_bound_vector(spacepoints_device, this_seed, bfield));
4141

examples/run/alpaka/seeding_example_alpaka.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
246246
vecmem::get_data(seeds_alpaka));
247247

248248
// Compare the track parameters made on the host and on the device.
249-
traccc::collection_comparator<traccc::bound_track_parameters>
249+
traccc::collection_comparator<traccc::bound_track_parameters<>>
250250
compare_track_parameters{"track parameters"};
251251
compare_track_parameters(vecmem::get_data(params),
252252
vecmem::get_data(params_alpaka));

examples/run/alpaka/seq_example_alpaka.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ int seq_run(const traccc::opts::detector& detector_opts,
279279
vecmem::get_data(seeds_alpaka));
280280

281281
// Compare the track parameters made on the host and on the device.
282-
traccc::collection_comparator<traccc::bound_track_parameters>
282+
traccc::collection_comparator<traccc::bound_track_parameters<>>
283283
compare_track_parameters{"track parameters"};
284284
compare_track_parameters(vecmem::get_data(params),
285285
vecmem::get_data(params_alpaka));

0 commit comments

Comments
 (0)