13
13
#include < vecmem/utils/cuda/async_copy.hpp>
14
14
15
15
#include " tests/cca_test.hpp"
16
+ #include " traccc/clusterization/clustering_config.hpp"
16
17
#include " traccc/cuda/clusterization/clusterization_algorithm.hpp"
17
18
#include " traccc/cuda/utils/stream.hpp"
18
19
19
20
namespace {
20
21
21
- cca_function_t f = [](const traccc::cell_collection_types::host& cells,
22
- const traccc::cell_module_collection_types::host&
23
- modules) {
24
- std::map<traccc::geometry_id, vecmem::vector<traccc::measurement>> result;
22
+ cca_function_t get_f_with (traccc::clustering_config cfg) {
23
+ return [cfg](const traccc::cell_collection_types::host& cells,
24
+ const traccc::cell_module_collection_types::host& modules) {
25
+ std::map<traccc::geometry_id, vecmem::vector<traccc::measurement>>
26
+ result;
25
27
26
- vecmem::host_memory_resource host_mr;
27
- traccc::cuda::stream stream;
28
- vecmem::cuda::device_memory_resource device_mr;
29
- vecmem::cuda::async_copy copy{stream.cudaStream ()};
28
+ vecmem::host_memory_resource host_mr;
29
+ traccc::cuda::stream stream;
30
+ vecmem::cuda::device_memory_resource device_mr;
31
+ vecmem::cuda::async_copy copy{stream.cudaStream ()};
30
32
31
- traccc::cuda::clusterization_algorithm cc ({device_mr}, copy, stream,
32
- default_ccl_test_config () );
33
+ traccc::cuda::clusterization_algorithm cc ({device_mr}, copy, stream,
34
+ cfg );
33
35
34
- traccc::cell_collection_types::buffer cells_buffer{
35
- static_cast <traccc::cell_collection_types::buffer::size_type>(
36
- cells.size ()),
37
- device_mr};
38
- copy.setup (cells_buffer);
39
- copy (vecmem::get_data (cells), cells_buffer)->ignore ();
36
+ traccc::cell_collection_types::buffer cells_buffer{
37
+ static_cast <traccc::cell_collection_types::buffer::size_type>(
38
+ cells.size ()),
39
+ device_mr};
40
+ copy.setup (cells_buffer);
41
+ copy (vecmem::get_data (cells), cells_buffer)->ignore ();
40
42
41
- traccc::cell_module_collection_types::buffer modules_buffer{
42
- static_cast <traccc::cell_module_collection_types::buffer::size_type>(
43
- modules.size ()),
44
- device_mr};
45
- copy.setup (modules_buffer);
46
- copy (vecmem::get_data (modules), modules_buffer)->ignore ();
43
+ traccc::cell_module_collection_types::buffer modules_buffer{
44
+ static_cast <
45
+ traccc::cell_module_collection_types::buffer::size_type>(
46
+ modules.size ()),
47
+ device_mr};
48
+ copy.setup (modules_buffer);
49
+ copy (vecmem::get_data (modules), modules_buffer)->ignore ();
47
50
48
- auto measurements_buffer = cc (cells_buffer, modules_buffer);
49
- traccc::measurement_collection_types::host measurements{&host_mr};
50
- copy (measurements_buffer, measurements)->wait ();
51
+ auto measurements_buffer = cc (cells_buffer, modules_buffer);
52
+ traccc::measurement_collection_types::host measurements{&host_mr};
53
+ copy (measurements_buffer, measurements)->wait ();
51
54
52
- for (std::size_t i = 0 ; i < measurements.size (); i++) {
53
- result[modules.at (measurements.at (i).module_link ).surface_link .value ()]
54
- .push_back (measurements.at (i));
55
- }
55
+ for (std::size_t i = 0 ; i < measurements.size (); i++) {
56
+ result[modules.at (measurements.at (i).module_link )
57
+ .surface_link .value ()]
58
+ .push_back (measurements.at (i));
59
+ }
56
60
57
- return result;
58
- };
61
+ return result;
62
+ };
63
+ }
59
64
} // namespace
60
65
61
66
TEST_P (ConnectedComponentAnalysisTests, Run) {
@@ -65,6 +70,14 @@ TEST_P(ConnectedComponentAnalysisTests, Run) {
65
70
INSTANTIATE_TEST_SUITE_P (
66
71
CUDAFastSvAlgorithm, ConnectedComponentAnalysisTests,
67
72
::testing::Combine (
68
- ::testing::Values (f ),
73
+ ::testing::Values (get_f_with(default_ccl_test_config()) ),
69
74
::testing::ValuesIn(ConnectedComponentAnalysisTests::get_test_files())),
70
75
ConnectedComponentAnalysisTests::get_test_name);
76
+
77
+ INSTANTIATE_TEST_SUITE_P (
78
+ CUDAFastSvAlgorithmWithScratch, ConnectedComponentAnalysisTests,
79
+ ::testing::Combine (
80
+ ::testing::Values (get_f_with(tiny_ccl_test_config())),
81
+ ::testing::ValuesIn(
82
+ ConnectedComponentAnalysisTests::get_test_files_short ())),
83
+ ConnectedComponentAnalysisTests::get_test_name);
0 commit comments