Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CLI options to configuration type conversion #626

Merged
merged 1 commit into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions examples/options/include/traccc/options/clusterization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,29 @@
#pragma once

// Local include(s).
#include "traccc/options/details/config_provider.hpp"
#include "traccc/options/details/interface.hpp"

namespace traccc::opts {

/// Options for the cell clusterization algorithm(s)
class clusterization : public interface {
class clusterization : public interface,
public config_provider<unsigned short> {

public:
/// Constructor
clusterization();

/// Configuration conversion
operator unsigned short() const override;

private:
/// @name Options
/// @{

/// The number of cells to merge in a partition
unsigned short target_cells_per_partition = 1024;

/// @}

/// Constructor
clusterization();

private:
/// Print the specific options of this class
std::ostream& print_impl(std::ostream& out) const override;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#pragma once

namespace traccc::opts {
/**
* @brief Mixin type to indicate that some set of program options can be
* converted to some configuration type.
*
* @tparam Config The config type to which this can be converted
*/
template <typename Config>
class config_provider {
public:
using config_type = Config;

virtual operator config_type() const = 0;
};
} // namespace traccc::opts
20 changes: 13 additions & 7 deletions examples/options/include/traccc/options/track_finding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#pragma once

// Project include(s).
#include "traccc/finding/finding_config.hpp"
#include "traccc/options/details/config_provider.hpp"
#include "traccc/options/details/interface.hpp"
#include "traccc/options/details/value_array.hpp"

Expand All @@ -20,12 +22,21 @@
namespace traccc::opts {

/// Configuration for track finding
class track_finding : public interface {
class track_finding : public interface,
public config_provider<finding_config<float>>,
public config_provider<finding_config<double>> {

public:
/// Constructor
track_finding();

/// Configuration conversion operators
operator finding_config<float>() const override;
operator finding_config<double>() const override;

private:
/// @name Options
/// @{

/// Number of track candidates per seed
opts::value_array<unsigned int, 2> track_candidates_range{3, 100};
/// Minimum step length that track should make to reach the next surface. It
Expand All @@ -40,13 +51,8 @@ class track_finding : public interface {
unsigned int nmax_per_seed = 10;
/// Maximum allowed number of skipped steps per candidate
unsigned int max_num_skipping_per_cand = 3;

/// @}

/// Constructor
track_finding();

private:
/// Print the specific options of this class
std::ostream& print_impl(std::ostream& out) const override;

Expand Down
21 changes: 12 additions & 9 deletions examples/options/include/traccc/options/track_propagation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#pragma once

// Local include(s).
#include "traccc/options/details/config_provider.hpp"
#include "traccc/options/details/interface.hpp"
#include "traccc/options/details/value_array.hpp"

Expand All @@ -17,17 +18,10 @@
namespace traccc::opts {

/// Command line options used in the propagation tests
class track_propagation : public interface {
class track_propagation : public interface,
public config_provider<detray::propagation::config> {

public:
/// @name Options
/// @{

/// Propagation configuration object
detray::propagation::config config;

/// @}

/// Constructor
track_propagation();

Expand All @@ -37,7 +31,16 @@ class track_propagation : public interface {
///
void read(const boost::program_options::variables_map& vm) override;

/// Configuration provider
operator detray::propagation::config() const override;

private:
/// @name Options
/// @{
/// Propagation configuration object
detray::propagation::config config;
/// @}

/// Print the specific options of this class
std::ostream& print_impl(std::ostream& out) const override;

Expand Down
4 changes: 4 additions & 0 deletions examples/options/src/clusterization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ clusterization::clusterization() : interface("Clusterization Options") {
"The number of cells to merge in a partition");
}

clusterization::operator unsigned short() const {
return target_cells_per_partition;
}

std::ostream& clusterization::print_impl(std::ostream& out) const {

out << " Target cells per partition: " << target_cells_per_partition;
Expand Down
24 changes: 24 additions & 0 deletions examples/options/src/track_finding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,30 @@ track_finding::track_finding() : interface("Track Finding Options") {
"Maximum allowed number of skipped steps per candidate");
}

track_finding::operator finding_config<float>() const {
finding_config<float> out;
out.min_track_candidates_per_track = track_candidates_range[0];
out.max_track_candidates_per_track = track_candidates_range[1];
out.min_step_length_for_next_surface = min_step_length_for_next_surface;
out.max_step_counts_for_next_surface = max_step_counts_for_next_surface;
out.chi2_max = chi2_max;
out.max_num_branches_per_seed = nmax_per_seed;
out.max_num_skipping_per_cand = max_num_skipping_per_cand;
return out;
}

track_finding::operator finding_config<double>() const {
finding_config<double> out;
out.min_track_candidates_per_track = track_candidates_range[0];
out.max_track_candidates_per_track = track_candidates_range[1];
out.min_step_length_for_next_surface = min_step_length_for_next_surface;
out.max_step_counts_for_next_surface = max_step_counts_for_next_surface;
out.chi2_max = chi2_max;
out.max_num_branches_per_seed = nmax_per_seed;
out.max_num_skipping_per_cand = max_num_skipping_per_cand;
return out;
}

std::ostream& track_finding::print_impl(std::ostream& out) const {

out << " Track candidates range : " << track_candidates_range << "\n"
Expand Down
4 changes: 4 additions & 0 deletions examples/options/src/track_propagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ void track_propagation::read(const po::variables_map&) {
config.navigation.search_window = m_search_window;
}

track_propagation::operator detray::propagation::config() const {
return config;
}

std::ostream& track_propagation::print_impl(std::ostream& out) const {

out << config;
Expand Down
23 changes: 6 additions & 17 deletions examples/run/common/throughput_mt.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -137,23 +137,13 @@ int throughput_mt(std::string_view description, int argc, char* argv[],
cached_host_mrs{threading_opts.threads + 1};

// Algorithm configuration(s).
typename FULL_CHAIN_ALG::finding_algorithm::config_type finding_cfg;
finding_cfg.min_track_candidates_per_track =
finding_opts.track_candidates_range[0];
finding_cfg.max_track_candidates_per_track =
finding_opts.track_candidates_range[1];
finding_cfg.min_step_length_for_next_surface =
finding_opts.min_step_length_for_next_surface;
finding_cfg.max_step_counts_for_next_surface =
finding_opts.max_step_counts_for_next_surface;
finding_cfg.chi2_max = finding_opts.chi2_max;
finding_cfg.max_num_branches_per_seed = finding_opts.nmax_per_seed;
finding_cfg.max_num_skipping_per_cand =
finding_opts.max_num_skipping_per_cand;
finding_cfg.propagation = propagation_opts.config;
detray::propagation::config propagation_config(propagation_opts);
typename FULL_CHAIN_ALG::finding_algorithm::config_type finding_cfg(
finding_opts);
finding_cfg.propagation = propagation_config;

typename FULL_CHAIN_ALG::fitting_algorithm::config_type fitting_cfg;
fitting_cfg.propagation = propagation_opts.config;
fitting_cfg.propagation = propagation_config;

// Set up the full-chain algorithm(s). One for each thread.
std::vector<FULL_CHAIN_ALG> algs;
Expand All @@ -170,7 +160,7 @@ int throughput_mt(std::string_view description, int argc, char* argv[],
: static_cast<vecmem::memory_resource&>(uncached_host_mr);
algs.push_back(
{alg_host_mr,
clusterization_opts.target_cells_per_partition,
clusterization_opts,
seeding_opts.seedfinder,
{seeding_opts.seedfinder},
seeding_opts.seedfilter,
Expand Down Expand Up @@ -267,7 +257,6 @@ int throughput_mt(std::string_view description, int argc, char* argv[],
<< "," << threading_opts.threads << "," << input_opts.events
<< "," << throughput_opts.cold_run_events << ","
<< throughput_opts.processed_events << ","
<< clusterization_opts.target_cells_per_partition << ","
<< times.get_time("Warm-up processing").count() << ","
<< times.get_time("Event processing").count() << std::endl;
logFile.close();
Expand Down
23 changes: 6 additions & 17 deletions examples/run/common/throughput_st.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -122,28 +122,17 @@ int throughput_st(std::string_view description, int argc, char* argv[],
}

// Algorithm configuration(s).
typename FULL_CHAIN_ALG::finding_algorithm::config_type finding_cfg;
finding_cfg.min_track_candidates_per_track =
finding_opts.track_candidates_range[0];
finding_cfg.max_track_candidates_per_track =
finding_opts.track_candidates_range[1];
finding_cfg.min_step_length_for_next_surface =
finding_opts.min_step_length_for_next_surface;
finding_cfg.max_step_counts_for_next_surface =
finding_opts.max_step_counts_for_next_surface;
finding_cfg.chi2_max = finding_opts.chi2_max;
finding_cfg.max_num_branches_per_seed = finding_opts.nmax_per_seed;
finding_cfg.max_num_skipping_per_cand =
finding_opts.max_num_skipping_per_cand;
finding_cfg.propagation = propagation_opts.config;
detray::propagation::config propagation_config(propagation_opts);
typename FULL_CHAIN_ALG::finding_algorithm::config_type finding_cfg(
finding_opts);
finding_cfg.propagation = propagation_config;

typename FULL_CHAIN_ALG::fitting_algorithm::config_type fitting_cfg;
fitting_cfg.propagation = propagation_opts.config;
fitting_cfg.propagation = propagation_config;

// Set up the full-chain algorithm.
std::unique_ptr<FULL_CHAIN_ALG> alg = std::make_unique<FULL_CHAIN_ALG>(
alg_host_mr, clusterization_opts.target_cells_per_partition,
seeding_opts.seedfinder,
alg_host_mr, clusterization_opts, seeding_opts.seedfinder,
spacepoint_grid_config{seeding_opts.seedfinder},
seeding_opts.seedfilter, finding_cfg, fitting_cfg,
(detector_opts.use_detray_detector ? &detector : nullptr));
Expand Down
22 changes: 8 additions & 14 deletions examples/run/cpu/seeding_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,27 +132,21 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
seeding_opts.seedfilter, host_mr);
traccc::track_params_estimation tp(host_mr);

// Propagation configuration
detray::propagation::config propagation_config(propagation_opts);

// Finding algorithm configuration
typename traccc::finding_algorithm<rk_stepper_type,
host_navigator_type>::config_type cfg;

cfg.min_track_candidates_per_track = finding_opts.track_candidates_range[0];
cfg.max_track_candidates_per_track = finding_opts.track_candidates_range[1];
cfg.min_step_length_for_next_surface =
finding_opts.min_step_length_for_next_surface;
cfg.max_step_counts_for_next_surface =
finding_opts.max_step_counts_for_next_surface;
cfg.chi2_max = finding_opts.chi2_max;
cfg.max_num_branches_per_seed = finding_opts.nmax_per_seed;
cfg.max_num_skipping_per_cand = finding_opts.max_num_skipping_per_cand;
cfg.propagation = propagation_opts.config;
typename traccc::finding_algorithm<
rk_stepper_type, host_navigator_type>::config_type cfg(finding_opts);

cfg.propagation = propagation_config;

traccc::finding_algorithm<rk_stepper_type, host_navigator_type>
host_finding(cfg);

// Fitting algorithm object
typename traccc::fitting_algorithm<host_fitter_type>::config_type fit_cfg;
fit_cfg.propagation = propagation_opts.config;
fit_cfg.propagation = propagation_config;

traccc::fitting_algorithm<host_fitter_type> host_fitting(fit_cfg);

Expand Down
20 changes: 5 additions & 15 deletions examples/run/cpu/seq_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,23 +131,13 @@ int seq_run(const traccc::opts::input_data& input_opts,
detray::bfield::create_const_field(field_vec);

// Algorithm configuration(s).
finding_algorithm::config_type finding_cfg;
finding_cfg.min_track_candidates_per_track =
finding_opts.track_candidates_range[0];
finding_cfg.max_track_candidates_per_track =
finding_opts.track_candidates_range[1];
finding_cfg.min_step_length_for_next_surface =
finding_opts.min_step_length_for_next_surface;
finding_cfg.max_step_counts_for_next_surface =
finding_opts.max_step_counts_for_next_surface;
finding_cfg.chi2_max = finding_opts.chi2_max;
finding_cfg.max_num_branches_per_seed = finding_opts.nmax_per_seed;
finding_cfg.max_num_skipping_per_cand =
finding_opts.max_num_skipping_per_cand;
finding_cfg.propagation = propagation_opts.config;
detray::propagation::config propagation_config(propagation_opts);

finding_algorithm::config_type finding_cfg(finding_opts);
finding_cfg.propagation = propagation_config;

fitting_algorithm::config_type fitting_cfg;
fitting_cfg.propagation = propagation_opts.config;
fitting_cfg.propagation = propagation_config;

// Algorithms
traccc::host::clusterization_algorithm ca(host_mr);
Expand Down
20 changes: 7 additions & 13 deletions examples/run/cpu/truth_finding_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,27 +111,21 @@ int seq_run(const traccc::opts::track_finding& finding_opts,
1e-4 / detray::unit<traccc::scalar>::GeV,
1e-4 * detray::unit<traccc::scalar>::ns};

// Propagation configuration
detray::propagation::config propagation_config(propagation_opts);

// Finding algorithm configuration
typename traccc::finding_algorithm<rk_stepper_type,
host_navigator_type>::config_type cfg;
cfg.min_track_candidates_per_track = finding_opts.track_candidates_range[0];
cfg.max_track_candidates_per_track = finding_opts.track_candidates_range[1];
cfg.min_step_length_for_next_surface =
finding_opts.min_step_length_for_next_surface;
cfg.max_step_counts_for_next_surface =
finding_opts.max_step_counts_for_next_surface;
cfg.chi2_max = finding_opts.chi2_max;
cfg.max_num_branches_per_seed = finding_opts.nmax_per_seed;
cfg.max_num_skipping_per_cand = finding_opts.max_num_skipping_per_cand;
cfg.propagation = propagation_opts.config;
typename traccc::finding_algorithm<
rk_stepper_type, host_navigator_type>::config_type cfg(finding_opts);
cfg.propagation = propagation_config;

// Finding algorithm object
traccc::finding_algorithm<rk_stepper_type, host_navigator_type>
host_finding(cfg);

// Fitting algorithm object
typename traccc::fitting_algorithm<host_fitter_type>::config_type fit_cfg;
fit_cfg.propagation = propagation_opts.config;
fit_cfg.propagation = propagation_config;

traccc::fitting_algorithm<host_fitter_type> host_fitting(fit_cfg);

Expand Down
2 changes: 1 addition & 1 deletion examples/run/cpu/truth_fitting_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ int main(int argc, char* argv[]) {

// Fitting algorithm object
typename traccc::fitting_algorithm<host_fitter_type>::config_type fit_cfg;
fit_cfg.propagation = propagation_opts.config;
fit_cfg.propagation = propagation_opts;

traccc::fitting_algorithm<host_fitter_type> host_fitting(fit_cfg);

Expand Down
Loading
Loading