Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ContactResultMap class to reduce heap allocations for multiple contact requests #869

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tesseract_collision/bullet/src/bullet_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -688,10 +688,10 @@ btScalar addDiscreteSingleResult(btManifoldPoint& cp,
const auto* cd0 = static_cast<const CollisionObjectWrapper*>(colObj0Wrap->getCollisionObject()); // NOLINT
const auto* cd1 = static_cast<const CollisionObjectWrapper*>(colObj1Wrap->getCollisionObject()); // NOLINT

ObjectPairKey pc = getObjectPairKey(cd0->getName(), cd1->getName());
ObjectPairKey pc = tesseract_common::makeOrderedLinkPair(cd0->getName(), cd1->getName());

const auto& it = collisions.res->find(pc);
bool found = (it != collisions.res->end());
const auto it = collisions.res->find(pc);
bool found = (it != collisions.res->end() && !it->second.empty());

// size_t l = 0;
// if (found)
Expand Down Expand Up @@ -823,8 +823,8 @@ btScalar addCastSingleResult(btManifoldPoint& cp,
std::make_pair(cd0->getName(), cd1->getName()) :
std::make_pair(cd1->getName(), cd0->getName());

auto it = collisions.res->find(pc);
bool found = it != collisions.res->end();
const auto it = collisions.res->find(pc);
bool found = (it != collisions.res->end() && !it->second.empty());

// size_t l = 0;
// if (found)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,6 @@ namespace tesseract_collision
{
using ObjectPairKey = std::pair<std::string, std::string>;

/**
* @brief Get a key for two object to search the collision matrix
* @param obj1 First collision object name
* @param obj2 Second collision object name
* @return The collision pair key
*/
ObjectPairKey getObjectPairKey(const std::string& obj1, const std::string& obj2);

/**
* @brief Get a vector of possible collision object pairs
* @todo Should this also filter out links without geometry?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ void load(Archive& ar, tesseract_collision::ContactResult& g, const unsigned int
template <class Archive>
void serialize(Archive& ar, tesseract_collision::ContactResult& g, const unsigned int version); // NOLINT

template <class Archive>
void save(Archive& ar, const tesseract_collision::ContactResultMap& g, const unsigned int version); // NOLINT

template <class Archive>
void load(Archive& ar, tesseract_collision::ContactResultMap& g, const unsigned int version); // NOLINT

template <class Archive>
void serialize(Archive& ar, tesseract_collision::ContactResultMap& g, const unsigned int version); // NOLINT
} // namespace boost::serialization

#endif // TESSERACT_COLLISION_SERIALIZATION_H
141 changes: 131 additions & 10 deletions tesseract_collision/core/include/tesseract_collision/core/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,137 @@ struct ContactResult
};

using ContactResultVector = tesseract_common::AlignedVector<ContactResult>;
using ContactResultMap = tesseract_common::AlignedMap<std::pair<std::string, std::string>, ContactResultVector>;
class ContactResultMap
{
public:
EIGEN_MAKE_ALIGNED_OPERATOR_NEW
using KeyType = std::pair<std::string, std::string>;
using MappedType = ContactResultVector;
using ContainerType = tesseract_common::AlignedMap<KeyType, MappedType>;
using ConstReferenceType = typename tesseract_common::AlignedMap<KeyType, MappedType>::const_reference;
using ConstIteratorType = typename tesseract_common::AlignedMap<KeyType, MappedType>::const_iterator;
using PairType = typename std::pair<const KeyType, MappedType>;
using FilterFn = std::function<void(PairType&)>;

/**
* @brief Add contact results for the provided key
* @param key The key to append the results to
* @param result The results to add
*/
ContactResult& addContactResult(const KeyType& key, ContactResult result);

/**
* @brief Add contact results for the provided key
* @param key The key to append the results to
* @param result The results to add
*/
ContactResult& addContactResult(const KeyType& key, const MappedType& results);

/**
* @brief Set contact results for the provided key
* @param key The key to append the results to
* @param result The results to add
*/
ContactResult& setContactResult(const KeyType& key, ContactResult result);

/**
* @brief Set contact results for the provided key
* @param key The key to append the results to
* @param result The results to add
*/
ContactResult& setContactResult(const KeyType& key, const MappedType& results);

/**
* @brief This processes interpolated contact results by updating the cc_time and cc_type and then adds the result
* @details This is copied from the trajopt utility processInterpolatedCollisionResults
* @param sub_segment_results The interpolated results to process
* @param sub_segment_index The current sub segment index
* @param sub_segment_last_index The last sub segment index
* @param active_link_names The active link names
* @param segment_dt The segment dt
* @param discrete If discrete contact checker was used
* @param filter An option filter to exclude results
*/
void addInterpolatedCollisionResults(ContactResultMap& sub_segment_results,
int sub_segment_index,
int sub_segment_last_index,
const std::vector<std::string>& active_link_names,
double segment_dt,
bool discrete,
const ContactResultMap::FilterFn& filter = nullptr);

// Flatten functions
void flattenMoveResults(ContactResultVector& v);
void flattenCopyResults(ContactResultVector& v) const;
void flattenWrapperResults(std::vector<std::reference_wrapper<ContactResult>>& v);
void flattenWrapperResults(std::vector<std::reference_wrapper<const ContactResult>>& v) const;

/**
* @brief Filter out results using the provided function
* @param fn The filter function
*/
void filter(const FilterFn& filter);

/**
* @brief Get the total number of contact results storted
* @return The number of contact results
*/
long count() const;

/**
* @brief Get the size of the map
* @details This loops over the internal map and counts entries with contacts
* @return The number of entries with contacts
*/
std::size_t size() const;

/**
* @brief Check if results are present
* @return
*/
bool empty() const;

/**
* @brief This is a consurvative clear.
* @details This does not call clear on the internal map but instead loops over each link pair entry and calls clear
* on the underlying vector. This way the vector capacity remains the same to avoid uneccessary heap allocation for
* subsequent contact requests.
* @note Use release to fully clear the internal data structure
*/
void clear();

/** @brief Fully clear all internal data */
void release();

/**
* @brief Get the underlying container
* @warning Do not use this for anything other than debugging or serialization
*/
const ContainerType& getContainer() const;

///////////////
// Iterators //
///////////////
/** @brief returns an iterator to the beginning */
ConstIteratorType begin() const;
/** @brief returns an iterator to the end */
ConstIteratorType end() const;
/** @brief returns an iterator to the beginning */
ConstIteratorType cbegin() const;
/** @brief returns an iterator to the end */
ConstIteratorType cend() const;

////////////////////
// Element Access //
////////////////////
/** @brief access specified element with bounds checking */
const ContactResultVector& at(const KeyType& key) const;
ConstIteratorType find(const KeyType& key) const;

private:
ContainerType data_;
long cnt_{ 0 };
};

/**
* @brief Should return true if contact results are valid, otherwise false.
Expand Down Expand Up @@ -163,15 +293,6 @@ struct ContactRequest
ContactRequest(ContactTestType type = ContactTestType::ALL);
};

std::size_t flattenMoveResults(ContactResultMap&& m, ContactResultVector& v);

std::size_t flattenCopyResults(const ContactResultMap& m, ContactResultVector& v);

std::size_t flattenWrapperResults(ContactResultMap& m, std::vector<std::reference_wrapper<ContactResult>>& v);

std::size_t flattenWrapperResults(const ContactResultMap& m,
std::vector<std::reference_wrapper<const ContactResult>>& v);

/**
* @brief This data is intended only to be used internal to the collision checkers as a container and should not
* be externally used by other libraries or packages.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,15 @@ static void BM_LARGE_DATASET_MULTILINK(benchmark::State& state,
checker->setCollisionMarginData(CollisionMarginData(0.1));
checker->setCollisionObjectsTransform(location);

ContactResultMap result;
ContactResultVector result_vector;

for (auto _ : state) // NOLINT
{
ContactResultMap result;
result.clear();
result_vector.clear();
checker->contactTest(result, ContactTestType::ALL);
flattenMoveResults(std::move(result), result_vector);
result.flattenMoveResults(result_vector);
}
}

Expand Down Expand Up @@ -181,14 +182,15 @@ static void BM_LARGE_DATASET_SINGLELINK(benchmark::State& state,
checker->setCollisionMarginData(CollisionMarginData(0.1));
// checker->setCollisionObjectsTransform(location);

ContactResultMap result;
ContactResultVector result_vector;

for (auto _ : state) // NOLINT
{
ContactResultMap result;
result.clear();
result_vector.clear();
checker->contactTest(result, ContactTestType::ALL);
flattenMoveResults(std::move(result), result_vector);
result.flattenMoveResults(result_vector);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ inline void runTest(ContinuousContactManager& checker)
checker.contactTest(result, ContactRequest(t));

ContactResultVector result_vector;
flattenMoveResults(std::move(result), result_vector);
result.flattenMoveResults(result_vector);

EXPECT_TRUE(!result_vector.empty());
EXPECT_NEAR(result_vector[0].distance, -0.2475, 0.001);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ inline void runTestTyped(DiscreteContactManager& checker, ContactTestType test_t
checker.contactTest(result, ContactRequest(test_type));

ContactResultVector result_vector;
flattenMoveResults(std::move(result), result_vector);
result.flattenMoveResults(result_vector);

EXPECT_TRUE(!result_vector.empty());
EXPECT_NEAR(result_vector[0].distance, -1.30, 0.001);
Expand Down Expand Up @@ -183,7 +183,6 @@ inline void runTestTyped(DiscreteContactManager& checker, ContactTestType test_t
////////////////////////////////////////////////
{
location["box_link"].translation() = Eigen::Vector3d(1.60, 0, 0);
result = ContactResultMap();
result.clear();
result_vector.clear();

Expand All @@ -192,7 +191,7 @@ inline void runTestTyped(DiscreteContactManager& checker, ContactTestType test_t
tesseract_common::VectorIsometry3d transforms = { location["box_link"] };
checker.setCollisionObjectsTransform(names, transforms);
checker.contactTest(result, test_type);
flattenCopyResults(result, result_vector);
result.flattenCopyResults(result_vector);

EXPECT_TRUE(result_vector.empty());
}
Expand All @@ -207,7 +206,6 @@ inline void runTestTyped(DiscreteContactManager& checker, ContactTestType test_t
EXPECT_EQ(checker.getCollisionMarginData().getMaxCollisionMargin(), 1.7);
EXPECT_NEAR(checker.getCollisionMarginData().getPairCollisionMargin("box_link", "second_box_link"), 0.1, 1e-5);
location["box_link"].translation() = Eigen::Vector3d(1.60, 0, 0);
result = ContactResultMap();
result.clear();
result_vector.clear();

Expand All @@ -216,15 +214,14 @@ inline void runTestTyped(DiscreteContactManager& checker, ContactTestType test_t
tesseract_common::VectorIsometry3d transforms = { location["box_link"] };
checker.setCollisionObjectsTransform(names, transforms);
checker.contactTest(result, test_type);
flattenMoveResults(std::move(result), result_vector);
result.flattenMoveResults(result_vector);

EXPECT_TRUE(result_vector.empty());
}
/////////////////////////////////////////////
// Test object inside the contact distance only for this link pair
/////////////////////////////////////////////
{
result = ContactResultMap();
result.clear();
result_vector.clear();

Expand All @@ -235,7 +232,7 @@ inline void runTestTyped(DiscreteContactManager& checker, ContactTestType test_t
EXPECT_NEAR(checker.getCollisionMarginData().getPairCollisionMargin("box_link", "second_box_link"), 0.25, 1e-5);
EXPECT_NEAR(checker.getCollisionMarginData().getMaxCollisionMargin(), 0.25, 1e-5);
checker.contactTest(result, ContactRequest(test_type));
flattenCopyResults(result, result_vector);
result.flattenCopyResults(result_vector);

EXPECT_TRUE(!result_vector.empty());
EXPECT_NEAR(result_vector[0].distance, 0.1, 0.001);
Expand Down Expand Up @@ -266,14 +263,13 @@ inline void runTestTyped(DiscreteContactManager& checker, ContactTestType test_t
// Test object inside the contact distance
/////////////////////////////////////////////
{
result = ContactResultMap();
result.clear();
result_vector.clear();

checker.setCollisionMarginData(CollisionMarginData(0.25));
EXPECT_NEAR(checker.getCollisionMarginData().getMaxCollisionMargin(), 0.25, 1e-5);
checker.contactTest(result, ContactRequest(test_type));
flattenMoveResults(std::move(result), result_vector);
result.flattenMoveResults(result_vector);

EXPECT_TRUE(!result_vector.empty());
EXPECT_NEAR(result_vector[0].distance, 0.1, 0.001);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ inline void runTest(DiscreteContactManager& checker)
checker.contactTest(result, ContactRequest(ContactTestType::CLOSEST));

ContactResultVector result_vector;
flattenMoveResults(std::move(result), result_vector);
result.flattenMoveResults(result_vector);

EXPECT_TRUE(!result_vector.empty());
EXPECT_NEAR(result_vector[0].distance, -0.55, 0.0001);
Expand Down Expand Up @@ -163,27 +163,25 @@ inline void runTest(DiscreteContactManager& checker)
// Test object is out side the contact distance
////////////////////////////////////////////////
location["capsule_link"].translation() = Eigen::Vector3d(0, 0, 1);
result = ContactResultMap();
result.clear();
result_vector.clear();
checker.setCollisionObjectsTransform(location);

checker.contactTest(result, ContactRequest(ContactTestType::CLOSEST));
flattenCopyResults(result, result_vector);
result.flattenCopyResults(result_vector);

EXPECT_TRUE(result_vector.empty());

/////////////////////////////////////////////
// Test object inside the contact distance
/////////////////////////////////////////////
result = ContactResultMap();
result.clear();
result_vector.clear();

checker.setCollisionMarginData(CollisionMarginData(0.251));
EXPECT_NEAR(checker.getCollisionMarginData().getMaxCollisionMargin(), 0.251, 1e-5);
checker.contactTest(result, ContactRequest(ContactTestType::CLOSEST));
flattenMoveResults(std::move(result), result_vector);
result.flattenMoveResults(result_vector);

EXPECT_TRUE(!result_vector.empty());
EXPECT_NEAR(result_vector[0].distance, 0.125, 0.001);
Expand Down
Loading