Skip to content

Commit

Permalink
Add goal pose to CriticData
Browse files Browse the repository at this point in the history
  • Loading branch information
redvinaa committed Dec 20, 2024
1 parent 507b365 commit 4465836
Show file tree
Hide file tree
Showing 21 changed files with 145 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,15 @@ namespace mppi

/**
* @struct mppi::CriticData
* @brief Data to pass to critics for scoring, including state, trajectories, path, costs, and
* important parameters to share
* @brief Data to pass to critics for scoring, including state, trajectories,
* pruned path, global goal, costs, and important parameters to share
*/
struct CriticData
{
const models::State & state;
const models::Trajectories & trajectories;
const models::Path & path;
const geometry_msgs::msg::Pose & goal;

xt::xtensor<float, 1> & costs;
float & model_dt;
Expand Down
11 changes: 7 additions & 4 deletions nav2_mppi_controller/include/nav2_mppi_controller/optimizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class Optimizer
geometry_msgs::msg::TwistStamped evalControl(
const geometry_msgs::msg::PoseStamped & robot_pose,
const geometry_msgs::msg::Twist & robot_speed, const nav_msgs::msg::Path & plan,
nav2_core::GoalChecker * goal_checker);
const geometry_msgs::msg::Pose & goal, nav2_core::GoalChecker * goal_checker);

/**
* @brief Get the trajectories generated in a cycle for visualization
Expand Down Expand Up @@ -132,7 +132,8 @@ class Optimizer
void prepare(
const geometry_msgs::msg::PoseStamped & robot_pose,
const geometry_msgs::msg::Twist & robot_speed,
const nav_msgs::msg::Path & plan, nav2_core::GoalChecker * goal_checker);
const nav_msgs::msg::Path & plan,
const geometry_msgs::msg::Pose & goal, nav2_core::GoalChecker * goal_checker);

/**
* @brief Obtain the main controller's parameters
Expand Down Expand Up @@ -250,10 +251,12 @@ class Optimizer
std::array<mppi::models::Control, 4> control_history_;
models::Trajectories generated_trajectories_;
models::Path path_;
geometry_msgs::msg::Pose goal_;
xt::xtensor<float, 1> costs_;

CriticData critics_data_ =
{state_, generated_trajectories_, path_, costs_, settings_.model_dt, false, nullptr, nullptr,
CriticData critics_data_ = {
state_, generated_trajectories_, path_, goal_,
costs_, settings_.model_dt, false, nullptr, nullptr,
std::nullopt, std::nullopt}; /// Caution, keep references

rclcpp::Logger logger_{rclcpp::get_logger("MPPIController")};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ class PathHandler
*/
nav_msgs::msg::Path transformPath(const geometry_msgs::msg::PoseStamped & robot_pose);

/**
* @brief Get the global goal pose transformed to the local frame
* @return Transformed goal pose
*/
geometry_msgs::msg::PoseStamped getTransformedGoal();

protected:
/**
* @brief Transform a pose to another frame
Expand Down
28 changes: 28 additions & 0 deletions nav2_mppi_controller/include/nav2_mppi_controller/tools/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ inline bool withinPositionGoalTolerance(

/**
* @brief Check if the robot pose is within tolerance to the goal
*
* Deprecated!
* This uses the end of pruned path as the goal, if you want to use
* the global goal, use the other overload
*
* @param pose_tolerance Pose tolerance to use
* @param robot Pose of robot
* @param path Path to retreive goal pose from
Expand Down Expand Up @@ -257,6 +262,29 @@ inline bool withinPositionGoalTolerance(
return false;
}

/**
* @brief Check if the robot pose is within tolerance to the goal
* @param pose_tolerance Pose tolerance to use
* @param robot Pose of robot
* @param path Path to retreive goal pose from
* @return bool If robot is within tolerance to the goal
*/
inline bool withinPositionGoalTolerance(
float pose_tolerance, const CriticData & data)
{
const double & dist_sq =
std::pow(data.goal.position.x - data.state.pose.pose.position.x, 2) +
std::pow(data.goal.position.y - data.state.pose.pose.position.y, 2);

const auto pose_tolerance_sq = pose_tolerance * pose_tolerance;

if (dist_sq < pose_tolerance_sq) {
return true;
}

return false;
}

/**
* @brief normalize
* Normalizes the angle to be -M_PI circle to +M_PI circle
Expand Down
4 changes: 3 additions & 1 deletion nav2_mppi_controller/src/controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,15 @@ geometry_msgs::msg::TwistStamped MPPIController::computeVelocityCommands(
last_time_called_ = clock_->now();

std::lock_guard<std::mutex> param_lock(*parameters_handler_->getLock());
geometry_msgs::msg::Pose goal = path_handler_.getTransformedGoal().pose;

nav_msgs::msg::Path transformed_plan = path_handler_.transformPath(robot_pose);

nav2_costmap_2d::Costmap2D * costmap = costmap_ros_->getCostmap();
std::unique_lock<nav2_costmap_2d::Costmap2D::mutex_t> costmap_lock(*(costmap->getMutex()));

geometry_msgs::msg::TwistStamped cmd =
optimizer_.evalControl(robot_pose, robot_speed, transformed_plan, goal_checker);
optimizer_.evalControl(robot_pose, robot_speed, transformed_plan, goal, goal_checker);

#ifdef BENCHMARK_TESTING
auto end = std::chrono::system_clock::now();
Expand Down
2 changes: 1 addition & 1 deletion nav2_mppi_controller/src/critics/cost_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ void CostCritic::score(CriticData & data)

// If near the goal, don't apply the preferential term since the goal is near obstacles
bool near_goal = false;
if (utils::withinPositionGoalTolerance(near_goal_distance_, data.state.pose.pose, data.path)) {
if (utils::withinPositionGoalTolerance(near_goal_distance_, data)) {
near_goal = true;
}

Expand Down
3 changes: 1 addition & 2 deletions nav2_mppi_controller/src/critics/goal_angle_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ void GoalAngleCritic::initialize()

void GoalAngleCritic::score(CriticData & data)
{
if (!enabled_ || !utils::withinPositionGoalTolerance(
threshold_to_consider_, data.state.pose.pose, data.path))
if (!enabled_ || !utils::withinPositionGoalTolerance(threshold_to_consider_, data))
{
return;
}
Expand Down
9 changes: 3 additions & 6 deletions nav2_mppi_controller/src/critics/goal_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,13 @@ void GoalCritic::initialize()

void GoalCritic::score(CriticData & data)
{
if (!enabled_ || !utils::withinPositionGoalTolerance(
threshold_to_consider_, data.state.pose.pose, data.path))
if (!enabled_ || !utils::withinPositionGoalTolerance(threshold_to_consider_, data))
{
return;
}

const auto goal_idx = data.path.x.shape(0) - 1;

const auto goal_x = data.path.x(goal_idx);
const auto goal_y = data.path.y(goal_idx);
const auto & goal_x = data.goal.position.x;
const auto & goal_y = data.goal.position.y;

const auto traj_x = xt::view(data.trajectories.x, xt::all(), xt::all());
const auto traj_y = xt::view(data.trajectories.y, xt::all(), xt::all());
Expand Down
2 changes: 1 addition & 1 deletion nav2_mppi_controller/src/critics/obstacles_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ void ObstaclesCritic::score(CriticData & data)

// If near the goal, don't apply the preferential term since the goal is near obstacles
bool near_goal = false;
if (utils::withinPositionGoalTolerance(near_goal_distance_, data.state.pose.pose, data.path)) {
if (utils::withinPositionGoalTolerance(near_goal_distance_, data)) {
near_goal = true;
}

Expand Down
3 changes: 1 addition & 2 deletions nav2_mppi_controller/src/critics/path_align_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ void PathAlignCritic::initialize()
void PathAlignCritic::score(CriticData & data)
{
// Don't apply close to goal, let the goal critics take over
if (!enabled_ ||
utils::withinPositionGoalTolerance(threshold_to_consider_, data.state.pose.pose, data.path))
if (!enabled_ || utils::withinPositionGoalTolerance(threshold_to_consider_, data))
{
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ void PathAlignLegacyCritic::initialize()
void PathAlignLegacyCritic::score(CriticData & data)
{
// Don't apply close to goal, let the goal critics take over
if (!enabled_ ||
utils::withinPositionGoalTolerance(threshold_to_consider_, data.state.pose.pose, data.path))
if (!enabled_ || utils::withinPositionGoalTolerance(threshold_to_consider_, data))
{
return;
}
Expand Down
2 changes: 1 addition & 1 deletion nav2_mppi_controller/src/critics/path_angle_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void PathAngleCritic::score(CriticData & data)
return;
}

if (utils::withinPositionGoalTolerance(threshold_to_consider_, data.state.pose.pose, data.path)) {
if (utils::withinPositionGoalTolerance(threshold_to_consider_, data)) {
return;
}

Expand Down
2 changes: 1 addition & 1 deletion nav2_mppi_controller/src/critics/path_follow_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void PathFollowCritic::initialize()
void PathFollowCritic::score(CriticData & data)
{
if (!enabled_ || data.path.x.shape(0) < 2 ||
utils::withinPositionGoalTolerance(threshold_to_consider_, data.state.pose.pose, data.path))
utils::withinPositionGoalTolerance(threshold_to_consider_, data))
{
return;
}
Expand Down
3 changes: 1 addition & 2 deletions nav2_mppi_controller/src/critics/prefer_forward_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ void PreferForwardCritic::initialize()
void PreferForwardCritic::score(CriticData & data)
{
using xt::evaluation_strategy::immediate;
if (!enabled_ ||
utils::withinPositionGoalTolerance(threshold_to_consider_, data.state.pose.pose, data.path))
if (!enabled_ || utils::withinPositionGoalTolerance(threshold_to_consider_, data))
{
return;
}
Expand Down
12 changes: 9 additions & 3 deletions nav2_mppi_controller/src/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,11 @@ void Optimizer::reset()
geometry_msgs::msg::TwistStamped Optimizer::evalControl(
const geometry_msgs::msg::PoseStamped & robot_pose,
const geometry_msgs::msg::Twist & robot_speed,
const nav_msgs::msg::Path & plan, nav2_core::GoalChecker * goal_checker)
const nav_msgs::msg::Path & plan,
const geometry_msgs::msg::Pose & goal,
nav2_core::GoalChecker * goal_checker)
{
prepare(robot_pose, robot_speed, plan, goal_checker);
prepare(robot_pose, robot_speed, plan, goal, goal_checker);

do {
optimize();
Expand Down Expand Up @@ -183,11 +185,15 @@ bool Optimizer::fallback(bool fail)
void Optimizer::prepare(
const geometry_msgs::msg::PoseStamped & robot_pose,
const geometry_msgs::msg::Twist & robot_speed,
const nav_msgs::msg::Path & plan, nav2_core::GoalChecker * goal_checker)
const nav_msgs::msg::Path & plan,
const geometry_msgs::msg::Pose & goal,
nav2_core::GoalChecker * goal_checker)
{
state_.pose = robot_pose;
state_.speed = robot_speed;
path_ = utils::toTensor(plan);
goal_ = goal;

costs_.fill(0);

critics_data_.fail_flag = false;
Expand Down
14 changes: 14 additions & 0 deletions nav2_mppi_controller/src/path_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,20 @@ void PathHandler::prunePlan(nav_msgs::msg::Path & plan, const PathIterator end)
plan.poses.erase(plan.poses.begin(), end);
}

geometry_msgs::msg::PoseStamped PathHandler::getTransformedGoal()
{
auto goal = global_plan_.poses.back();
goal.header.stamp = rclcpp::Time(0);
if (goal.header.frame_id.empty()) {
throw std::runtime_error("Goal pose has an empty frame_id");
}
geometry_msgs::msg::PoseStamped transformed_goal;
if (!transformPose(costmap_->getGlobalFrameID(), goal, transformed_goal)) {
throw std::runtime_error("Unable to transform goal pose into costmap frame");
}
return transformed_goal;
}

bool PathHandler::isWithinInversionTolerances(const geometry_msgs::msg::PoseStamped & robot_pose)
{
// Keep full path if we are within tolerance of the inversion pose
Expand Down
3 changes: 2 additions & 1 deletion nav2_mppi_controller/test/critic_manager_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,11 @@ TEST(CriticManagerTests, BasicCriticOperations)
models::ControlSequence control_sequence;
models::Trajectories generated_trajectories;
models::Path path;
geometry_msgs::msg::Pose goal;
xt::xtensor<float, 1> costs;
float model_dt = 0.1;
CriticData data =
{state, generated_trajectories, path, costs, model_dt, false, nullptr, nullptr,
{state, generated_trajectories, path, goal, costs, model_dt, false, nullptr, nullptr,
std::nullopt, std::nullopt};

data.fail_flag = true;
Expand Down
Loading

0 comments on commit 4465836

Please sign in to comment.