Skip to content

Commit

Permalink
simplify a bit TC SoAs
Browse files Browse the repository at this point in the history
  • Loading branch information
slava77devel committed Oct 18, 2024
1 parent cb84b9c commit 6f1ea3f
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 59 deletions.
66 changes: 19 additions & 47 deletions RecoTracker/LSTCore/src/alpaka/Event.dev.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1602,16 +1602,16 @@ PixelQuintupletsBuffer<alpaka_common::DevHost>& Event::getPixelQuintuplets(bool
return pixelQuintupletsInCPU_.value();
}

const TrackCandidatesHostCollection& Event::getTrackCandidates(bool sync) {
const TrackCandidatesConst& Event::getTrackCandidatesWithSelection(bool inCMSSW, bool sync) {
if (!trackCandidatesHC_) {
// Get nTrackCanHost parameter to initialize host based instance
auto nTrackCanHost_buf_h = cms::alpakatools::make_host_buffer<unsigned int[]>(queue_, 1u);
alpaka::memcpy(
queue_, nTrackCanHost_buf_h, alpaka::createView(devAcc_, &(*trackCandidatesDC_)->nTrackCandidates(), 1u));
trackCandidatesHC_.emplace(n_max_nonpixel_track_candidates + n_max_pixel_track_candidates, queue_);
alpaka::wait(queue_); // wait here before we get nTrackCanHost and trackCandidatesInCPU becomes usable

auto const nTrackCanHost = *nTrackCanHost_buf_h.data();
trackCandidatesHC_.emplace(nTrackCanHost, queue_);

(*trackCandidatesHC_)->nTrackCandidates() = nTrackCanHost;
alpaka::memcpy(
Expand All @@ -1622,58 +1622,30 @@ const TrackCandidatesHostCollection& Event::getTrackCandidates(bool sync) {
alpaka::memcpy(queue_,
alpaka::createView(cms::alpakatools::host(), (*trackCandidatesHC_)->pixelSeedIndex(), nTrackCanHost),
alpaka::createView(devAcc_, (*trackCandidatesDC_)->pixelSeedIndex(), nTrackCanHost));
alpaka::memcpy(queue_,
alpaka::createView(cms::alpakatools::host(),
(*trackCandidatesHC_)->logicalLayers()->data(),
Params_pT5::kLayers * nTrackCanHost),
alpaka::createView(
devAcc_, (*trackCandidatesDC_)->logicalLayers()->data(), Params_pT5::kLayers * nTrackCanHost));
alpaka::memcpy(
queue_,
alpaka::createView(cms::alpakatools::host(), (*trackCandidatesHC_)->directObjectIndices(), nTrackCanHost),
alpaka::createView(devAcc_, (*trackCandidatesDC_)->directObjectIndices(), nTrackCanHost));
alpaka::memcpy(
queue_,
alpaka::createView(cms::alpakatools::host(), (*trackCandidatesHC_)->objectIndices()->data(), 2 * nTrackCanHost),
alpaka::createView(devAcc_, (*trackCandidatesDC_)->objectIndices()->data(), 2 * nTrackCanHost));
alpaka::memcpy(
queue_,
alpaka::createView(cms::alpakatools::host(), (*trackCandidatesHC_)->trackCandidateType(), nTrackCanHost),
alpaka::createView(devAcc_, (*trackCandidatesDC_)->trackCandidateType(), nTrackCanHost));
if (sync)
alpaka::wait(queue_); // host consumers expect filled data
}
return trackCandidatesHC_.value();
}

const TrackCandidatesHostCollection& Event::getTrackCandidatesInCMSSW(bool sync) {
if (!trackCandidatesHC_) {
// Get nTrackCanHost parameter to initialize host based instance
auto nTrackCanHost_buf_h = cms::alpakatools::make_host_buffer<unsigned int[]>(queue_, 1u);
alpaka::memcpy(
queue_, nTrackCanHost_buf_h, alpaka::createView(devAcc_, &(*trackCandidatesDC_)->nTrackCandidates(), 1u));
trackCandidatesHC_.emplace(n_max_nonpixel_track_candidates + n_max_pixel_track_candidates, queue_);
alpaka::wait(queue_); // wait for the value before using and trackCandidatesInCPU becomes usable

auto const nTrackCanHost = *nTrackCanHost_buf_h.data();

(*trackCandidatesHC_)->nTrackCandidates() = nTrackCanHost;
alpaka::memcpy(
queue_,
alpaka::createView(
cms::alpakatools::host(), (*trackCandidatesHC_)->hitIndices()->data(), Params_pT5::kHits * nTrackCanHost),
alpaka::createView(devAcc_, (*trackCandidatesDC_)->hitIndices()->data(), Params_pT5::kHits * nTrackCanHost));
alpaka::memcpy(queue_,
alpaka::createView(cms::alpakatools::host(), (*trackCandidatesHC_)->pixelSeedIndex(), nTrackCanHost),
alpaka::createView(devAcc_, (*trackCandidatesDC_)->pixelSeedIndex(), nTrackCanHost));
if (not inCMSSW) {
alpaka::memcpy(queue_,
alpaka::createView(cms::alpakatools::host(),
(*trackCandidatesHC_)->logicalLayers()->data(),
Params_pT5::kLayers * nTrackCanHost),
alpaka::createView(
devAcc_, (*trackCandidatesDC_)->logicalLayers()->data(), Params_pT5::kLayers * nTrackCanHost));
alpaka::memcpy(
queue_,
alpaka::createView(cms::alpakatools::host(), (*trackCandidatesHC_)->directObjectIndices(), nTrackCanHost),
alpaka::createView(devAcc_, (*trackCandidatesDC_)->directObjectIndices(), nTrackCanHost));
alpaka::memcpy(queue_,
alpaka::createView(
cms::alpakatools::host(), (*trackCandidatesHC_)->objectIndices()->data(), 2 * nTrackCanHost),
alpaka::createView(devAcc_, (*trackCandidatesDC_)->objectIndices()->data(), 2 * nTrackCanHost));
}
alpaka::memcpy(
queue_,
alpaka::createView(cms::alpakatools::host(), (*trackCandidatesHC_)->trackCandidateType(), nTrackCanHost),
alpaka::createView(devAcc_, (*trackCandidatesDC_)->trackCandidateType(), nTrackCanHost));
if (sync)
alpaka::wait(queue_); // host consumers expect filled data
}
return trackCandidatesHC_.value();
return trackCandidatesHC_.value().const_view();
}

ModulesBuffer<alpaka_common::DevHost>& Event::getModules(bool isFull, bool sync) {
Expand Down
9 changes: 7 additions & 2 deletions RecoTracker/LSTCore/src/alpaka/Event.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,13 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst {
QuintupletsBuffer<DevHost>& getQuintuplets(bool sync = true);
PixelTripletsBuffer<DevHost>& getPixelTriplets(bool sync = true);
PixelQuintupletsBuffer<DevHost>& getPixelQuintuplets(bool sync = true);
const TrackCandidatesHostCollection& getTrackCandidates(bool sync = true);
const TrackCandidatesHostCollection& getTrackCandidatesInCMSSW(bool sync = true);
const TrackCandidatesConst& getTrackCandidatesWithSelection(bool inCMSSW, bool sync);
const TrackCandidatesConst& getTrackCandidates(bool sync = true) {
return getTrackCandidatesWithSelection(false, sync);
}
const TrackCandidatesConst& getTrackCandidatesInCMSSW(bool sync = true) {
return getTrackCandidatesWithSelection(true, sync);
}
ModulesBuffer<DevHost>& getModules(bool isFull = false, bool sync = true);
};

Expand Down
2 changes: 1 addition & 1 deletion RecoTracker/LSTCore/src/alpaka/LST.dev.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ void LST::getOutput(Event& event) {
std::vector<short> tc_trackCandidateType;

HitsBuffer<alpaka::DevCpu>& hitsBuffer = event.getHitsInCMSSW(false); // sync on next line
auto const& trackCandidates = event.getTrackCandidatesInCMSSW().const_view();
auto const& trackCandidates = event.getTrackCandidatesInCMSSW();

unsigned int nTrackCandidates = trackCandidates.nTrackCandidates();

Expand Down
4 changes: 2 additions & 2 deletions RecoTracker/LSTCore/standalone/code/core/AccessHelper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ std::tuple<std::vector<unsigned int>, std::vector<unsigned int>> getHitIdxsAndHi
//____________________________________________________________________________________________
std::vector<unsigned int> getLSsFromTC(Event* event, unsigned int iTC) {
// Get the type of the track candidate
auto const& trackCandidates = event->getTrackCandidates().const_view();
auto const& trackCandidates = event->getTrackCandidates();
short type = trackCandidates.trackCandidateType()[iTC];
unsigned int objidx = trackCandidates.directObjectIndices()[iTC];
switch (type) {
Expand All @@ -435,7 +435,7 @@ std::vector<unsigned int> getLSsFromTC(Event* event, unsigned int iTC) {
std::tuple<std::vector<unsigned int>, std::vector<unsigned int>> getHitIdxsAndHitTypesFromTC(Event* event,
unsigned iTC) {
// Get the type of the track candidate
auto const& trackCandidates = event->getTrackCandidates().const_view();
auto const& trackCandidates = event->getTrackCandidates();
short type = trackCandidates.trackCandidateType()[iTC];
unsigned int objidx = trackCandidates.directObjectIndices()[iTC];
switch (type) {
Expand Down
14 changes: 7 additions & 7 deletions RecoTracker/LSTCore/standalone/code/core/write_lst_ntuple.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ void setOutputBranches(Event* event) {
std::vector<std::vector<int>> tc_matched_simIdx;

// ============ Track candidates =============
auto const& trackCandidates = event->getTrackCandidates().const_view();
auto const& trackCandidates = event->getTrackCandidates();
unsigned int nTrackCandidates = trackCandidates.nTrackCandidates();
for (unsigned int idx = 0; idx < nTrackCandidates; idx++) {
// Compute reco quantities of track candidate based on final object
Expand Down Expand Up @@ -506,7 +506,7 @@ void setGnnNtupleBranches(Event* event) {
Hits const* hitsEvt = event->getHits().data();
Modules const* modules = event->getModules().data();
ObjectRanges const* ranges = event->getRanges().data();
auto const& trackCandidates = event->getTrackCandidates().const_view();
auto const& trackCandidates = event->getTrackCandidates();

std::set<unsigned int> mds_used_in_sg;
std::map<unsigned int, unsigned int> md_index_map;
Expand Down Expand Up @@ -710,7 +710,7 @@ void setGnnNtupleMiniDoublet(Event* event, unsigned int MD) {
//________________________________________________________________________________________________________________________________
std::tuple<int, float, float, float, int, std::vector<int>> parseTrackCandidate(Event* event, unsigned int idx) {
// Get the type of the track candidate
auto const& trackCandidates = event->getTrackCandidates().const_view();
auto const& trackCandidates = event->getTrackCandidates();
short type = trackCandidates.trackCandidateType()[idx];

enum { pT5 = 7, pT3 = 5, T5 = 4, pLS = 8 };
Expand Down Expand Up @@ -744,7 +744,7 @@ std::tuple<int, float, float, float, int, std::vector<int>> parseTrackCandidate(
std::tuple<float, float, float, std::vector<unsigned int>, std::vector<unsigned int>> parsepT5(Event* event,
unsigned int idx) {
// Get relevant information
auto const& trackCandidates = event->getTrackCandidates().const_view();
auto const& trackCandidates = event->getTrackCandidates();
Quintuplets const* quintuplets = event->getQuintuplets().data();
SegmentsPixelConst segmentsPixel = event->getSegments<SegmentsPixelSoA>();

Expand Down Expand Up @@ -856,7 +856,7 @@ std::tuple<float, float, float, std::vector<unsigned int>, std::vector<unsigned
std::tuple<float, float, float, std::vector<unsigned int>, std::vector<unsigned int>> parsepT3(Event* event,
unsigned int idx) {
// Get relevant information
auto const& trackCandidates = event->getTrackCandidates().const_view();
auto const& trackCandidates = event->getTrackCandidates();
Triplets const* triplets = event->getTriplets().data();
SegmentsPixelConst segmentsPixel = event->getSegments<SegmentsPixelSoA>();

Expand Down Expand Up @@ -890,7 +890,7 @@ std::tuple<float, float, float, std::vector<unsigned int>, std::vector<unsigned
//________________________________________________________________________________________________________________________________
std::tuple<float, float, float, std::vector<unsigned int>, std::vector<unsigned int>> parseT5(Event* event,
unsigned int idx) {
auto const& trackCandidates = event->getTrackCandidates().const_view();
auto const& trackCandidates = event->getTrackCandidates();
Quintuplets const* quintuplets = event->getQuintuplets().data();
unsigned int T5 = trackCandidates.directObjectIndices()[idx];
std::vector<unsigned int> hits = getHitsFromT5(event, T5);
Expand Down Expand Up @@ -924,7 +924,7 @@ std::tuple<float, float, float, std::vector<unsigned int>, std::vector<unsigned
//________________________________________________________________________________________________________________________________
std::tuple<float, float, float, std::vector<unsigned int>, std::vector<unsigned int>> parsepLS(Event* event,
unsigned int idx) {
auto const& trackCandidates = event->getTrackCandidates().const_view();
auto const& trackCandidates = event->getTrackCandidates();
SegmentsPixelConst segmentsPixel = event->getSegments<SegmentsPixelSoA>();

// Getting pLS index
Expand Down

0 comments on commit 6f1ea3f

Please sign in to comment.