Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Humble enjoy mppi critic pub #18

Open
wants to merge 3 commits into
base: humble
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions nav2_mppi_controller/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ set(XTENSOR_USE_XSIMD 1)
find_package(ament_cmake REQUIRED)
find_package(xtensor REQUIRED)
find_package(xsimd REQUIRED)
find_package(rosidl_default_generators REQUIRED)

include_directories(
include
Expand All @@ -33,6 +34,7 @@ set(dependencies_pkgs
tf2_geometry_msgs
tf2_eigen
tf2_ros
std_msgs
)

foreach(pkg IN LISTS dependencies_pkgs)
Expand All @@ -41,6 +43,11 @@ endforeach()

nav2_package()

rosidl_generate_interfaces(${PROJECT_NAME}
"msg/CriticScores.msg"
DEPENDENCIES std_msgs
)

include(CheckCXXCompilerFlag)

check_cxx_compiler_flag("-mno-avx512f" COMPILER_SUPPORTS_AVX512)
Expand Down Expand Up @@ -119,8 +126,14 @@ if(BUILD_TESTING)
# add_subdirectory(benchmark)
endif()

rosidl_get_typesupport_target(cpp_typesupport_target
${PROJECT_NAME} rosidl_typesupport_cpp)

target_link_libraries(mppi_controller "${cpp_typesupport_target}")

ament_export_libraries(${libraries})
ament_export_dependencies(${dependencies_pkgs})
ament_export_dependencies(rosidl_default_runtime)
ament_export_include_directories(include)
pluginlib_export_plugin_description_file(nav2_core mppic.xml)
pluginlib_export_plugin_description_file(nav2_mppi_controller critics.xml)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "nav2_mppi_controller/tools/trajectory_visualizer.hpp"
#include "nav2_mppi_controller/models/constraints.hpp"
#include "nav2_mppi_controller/tools/utils.hpp"
#include "nav2_mppi_controller/msg/critic_scores.hpp"

#include "nav2_core/controller.hpp"
#include "nav2_core/goal_checker.hpp"
Expand Down Expand Up @@ -121,10 +122,14 @@ class MPPIController : public nav2_core::Controller
TrajectoryVisualizer trajectory_visualizer_;

bool visualize_;
bool publish_critics_;

double reset_period_;
// Last time computeVelocityCommands was called
rclcpp::Time last_time_called_;

std::shared_ptr<rclcpp_lifecycle::LifecyclePublisher<nav2_mppi_controller::msg::CriticScores>>
critics_publisher_;
};

} // namespace nav2_mppi_controller
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ class CriticManager
* @brief Constructor for mppi::CriticManager
*/
CriticManager() = default;

/**
* @brief Virtual Destructor for mppi::CriticManager
*/
virtual ~CriticManager() = default;

/**
* @brief Configure critic manager on bringup and load plugins
* @param parent WeakPtr to node
Expand All @@ -69,6 +69,10 @@ class CriticManager
*/
void evalTrajectoriesScores(CriticData & data) const;

xt::xtensor<float, 1> evalTrajectory(CriticData & data) const;

std::vector<std::string> getCriticNames() const;

protected:
/**
* @brief Get parameters (critics to load)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ class Optimizer
*/
xt::xtensor<float, 2> getOptimizedTrajectory();


/**
* @brief Get the critic costs for given trajectory
* @return Names and costs of the critics
*/
xt::xtensor<float, 1> getOptimizationResults();

std::vector<std::string> getCriticNames() const;

/**
* @brief Set the maximum speed based on the speed limits callback
* @param speed_limit Limit of the speed for use
Expand Down
3 changes: 3 additions & 0 deletions nav2_mppi_controller/msg/CriticScores.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
std_msgs/Header header # ROS time that this log message was sent.
std_msgs/String[] critic_names
std_msgs/Float32[] critic_scores
5 changes: 5 additions & 0 deletions nav2_mppi_controller/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

<buildtool_depend>ament_cmake</buildtool_depend>
<buildtool_depend>ament_cmake_ros</buildtool_depend>
<buildtool_depend>rosidl_default_generators</buildtool_depend>

<exec_depend>rosidl_default_runtime</exec_depend>

<depend>rclcpp</depend>
<depend>nav2_common</depend>
Expand All @@ -33,6 +36,8 @@
<test_depend>ament_lint_auto</test_depend>
<test_depend>ament_lint_common</test_depend>
<test_depend>ament_cmake_gtest</test_depend>

<member_of_group>rosidl_interface_packages</member_of_group>
<export>
<build_type>ament_cmake</build_type>
<nav2_core plugin="${prefix}/mppic.xml" />
Expand Down
32 changes: 32 additions & 0 deletions nav2_mppi_controller/src/controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ void MPPIController::configure(
// Get high-level controller parameters
auto getParam = parameters_handler_->getParamGetter(name_);
getParam(visualize_, "visualize", false);
getParam(publish_critics_, "publish_critics", false);
getParam(reset_period_, "reset_period", 1.0);

// Configure composed objects
Expand All @@ -48,6 +49,11 @@ void MPPIController::configure(
parent_, name_,
costmap_ros_->getGlobalFrameID(), parameters_handler_.get());

if (publish_critics_) {
critics_publisher_ = node->create_publisher<nav2_mppi_controller::msg::CriticScores>(
"/mppi_critic_scores", 1);
}

RCLCPP_INFO(logger_, "Configured MPPI Controller: %s", name_.c_str());
}

Expand All @@ -61,13 +67,15 @@ void MPPIController::cleanup()

void MPPIController::activate()
{
critics_publisher_->on_activate();
trajectory_visualizer_.on_activate();
parameters_handler_->start();
RCLCPP_INFO(logger_, "Activated MPPI Controller: %s", name_.c_str());
}

void MPPIController::deactivate()
{
critics_publisher_->on_deactivate();
trajectory_visualizer_.on_deactivate();
RCLCPP_INFO(logger_, "Deactivated MPPI Controller: %s", name_.c_str());
}
Expand Down Expand Up @@ -110,6 +118,30 @@ geometry_msgs::msg::TwistStamped MPPIController::computeVelocityCommands(
visualize(std::move(transformed_plan));
}

if (publish_critics_) {
std::vector<std::string> critic_names = optimizer_.getCriticNames();
xt::xtensor<float, 1> critic_costs = optimizer_.getOptimizationResults();

// log critic names and costs
for (size_t i = 0; i < critic_names.size(); i++) {
RCLCPP_DEBUG(logger_, "Critic: %s, Cost: %f", critic_names[i].c_str(), critic_costs[i]);
}

// make msg
auto critic_scores_ = std::make_unique<nav2_mppi_controller::msg::CriticScores>();
for (size_t i = 0; i < critic_names.size(); i++) {
std_msgs::msg::String name_msg;
name_msg.data = critic_names[i];
critic_scores_->critic_names.push_back(std::move(name_msg));

std_msgs::msg::Float32 cost_msg;
cost_msg.data = critic_costs[i];
critic_scores_->critic_scores.push_back(std::move(cost_msg));
}
critic_scores_->header.stamp = clock_->now();
critics_publisher_->publish(std::move(critic_scores_));
}

return cmd;
}

Expand Down
25 changes: 25 additions & 0 deletions nav2_mppi_controller/src/critic_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <xtensor/xtensor.hpp>

#include "nav2_mppi_controller/critic_manager.hpp"

namespace mppi
Expand Down Expand Up @@ -64,6 +66,11 @@ std::string CriticManager::getFullName(const std::string & name)
return "mppi::critics::" + name;
}

std::vector<std::string> CriticManager::getCriticNames() const
{
return critic_names_;
}

void CriticManager::evalTrajectoriesScores(
CriticData & data) const
{
Expand All @@ -75,4 +82,22 @@ void CriticManager::evalTrajectoriesScores(
}
}

xt::xtensor<float, 1> CriticManager::evalTrajectory(
CriticData & data) const
{
xt::xtensor<float, 1> critic_scores = xt::zeros<float>({critics_.size()});

for (size_t q = 0; q < critics_.size(); q++) {
if (data.fail_flag) {
break;
}
data.costs = xt::zeros<float>({1});
// log costs values
critics_[q]->score(data);
critic_scores(q) = data.costs[0];
}

return critic_scores;
}

} // namespace mppi
61 changes: 61 additions & 0 deletions nav2_mppi_controller/src/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,67 @@ void Optimizer::optimize()
}
}

xt::xtensor<float, 1> Optimizer::getOptimizationResults()
{
const xt::xtensor<float, 2> optimized_trajectory = getOptimizedTrajectory();
xt::xtensor<float, 1> costs = xt::zeros<float>({1});

/*auto size = optimized_trajectory.size(); // size = 6
auto dim = optimized_trajectory.dimension(); // dim = 2
auto shape = optimized_trajectory.shape(); // shape = {2, 3}

//log size, dim, and shape
RCLCPP_INFO(
logger_, "getOptimizedTrajectory() size: %ld, dim: %ld, shape: {%ld, %ld}",
size, dim, shape[0], shape[1]);*/

models::Trajectories dummy_trajectories;
/*size = dummy_trajectories.x.size();
dim = dummy_trajectories.x.dimension();
shape = dummy_trajectories.x.shape();
RCLCPP_INFO(
logger_, "[after creation] dummy_trajectories size: %ld, dim: %ld, shape: {%ld, %ld}",
size, dim, shape[0], shape[1]);*/

dummy_trajectories.reset(1, settings_.time_steps);
/*size = dummy_trajectories.x.size();
dim = dummy_trajectories.x.dimension();
shape = dummy_trajectories.x.shape();
RCLCPP_INFO(
logger_, "[after reset] dummy_trajectories size: %ld, dim: %ld, shape: {%ld, %ld}",
size, dim, shape[0], shape[1]);*/

dummy_trajectories.x += xt::view(optimized_trajectory, xt::all(), 0);
dummy_trajectories.y += xt::view(optimized_trajectory, xt::all(), 1);
dummy_trajectories.yaws += xt::view(optimized_trajectory, xt::all(), 2);

/*size = dummy_trajectories.x.size();
dim = dummy_trajectories.x.dimension();
shape = dummy_trajectories.x.shape();
RCLCPP_INFO(
logger_, "[after valuedump] dummy_trajectories size: %ld, dim: %ld, shape: {%ld, %ld}",
size, dim, shape[0], shape[1]);*/

CriticData dummy_data = {
state_, dummy_trajectories, path_, costs, settings_.model_dt,
false, critics_data_.goal_checker, critics_data_.motion_model, std::nullopt, std::nullopt};
// dummy_data.goal_checker = critics_data_.goal_checker;
// dummy_data.motion_model = critics_data_.motion_model;
dummy_data.furthest_reached_path_point.reset();
dummy_data.path_pts_valid.reset();

/*RCLCPP_INFO(
logger_, "dummy_data type: %s",
typeid(dummy_data).name());*/

return critic_manager_.evalTrajectory(dummy_data);
}

std::vector<std::string> Optimizer::getCriticNames() const
{
return critic_manager_.getCriticNames();
}

bool Optimizer::fallback(bool fail)
{
static size_t counter = 0;
Expand Down