Skip to content

Commit

Permalink
changing MPPI's SG filter to 9-point formulation (prev. 5) (ros-navig…
Browse files Browse the repository at this point in the history
…ation#3444)

* changing filter to 9

* fix tests
  • Loading branch information
SteveMacenski authored Mar 3, 2023
1 parent 8d4f6f4 commit 7aee1e7
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ class Optimizer

models::State state_;
models::ControlSequence control_sequence_;
std::array<mppi::models::Control, 2> control_history_;
std::array<mppi::models::Control, 4> control_history_;
models::Trajectories generated_trajectories_;
models::Path path_;
xt::xtensor<float, 1> costs_;
Expand Down
111 changes: 96 additions & 15 deletions nav2_mppi_controller/include/nav2_mppi_controller/tools/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,17 +436,17 @@ inline double posePointAngle(const geometry_msgs::msg::Pose & pose, double point
*/
inline void savitskyGolayFilter(
models::ControlSequence & control_sequence,
std::array<mppi::models::Control, 2> & control_history,
std::array<mppi::models::Control, 4> & control_history,
const models::OptimizerSettings & settings)
{
// Savitzky-Golay Quadratic, 5-point Coefficients
xt::xarray<float> filter = {-3.0, 12.0, 17.0, 12.0, -3.0};
filter /= 35.0;
// Savitzky-Golay Quadratic, 9-point Coefficients
xt::xarray<float> filter = {-21.0, 14.0, 39.0, 54.0, 59.0, 54.0, 39.0, 14.0, -21.0};
filter /= 231.0;

const unsigned int num_sequences = control_sequence.vx.shape(0);
const unsigned int num_sequences = control_sequence.vx.shape(0) - 1;

// Too short to smooth meaningfully
if (num_sequences < 10) {
if (num_sequences < 20) {
return;
}

Expand All @@ -455,64 +455,145 @@ inline void savitskyGolayFilter(
};

auto applyFilterOverAxis =
[&](xt::xtensor<float, 1> & sequence, const float hist_0, const float hist_1) -> void
[&](xt::xtensor<float, 1> & sequence,
const float hist_0, const float hist_1, const float hist_2, const float hist_3) -> void
{
unsigned int idx = 0;
sequence(idx) = applyFilter(
{
hist_0,
hist_1,
hist_2,
hist_3,
sequence(idx),
sequence(idx + 1),
sequence(idx + 2)});
sequence(idx + 2),
sequence(idx + 3),
sequence(idx + 4)});

idx++;
sequence(idx) = applyFilter(
{
hist_1,
hist_2,
hist_3,
sequence(idx - 1),
sequence(idx),
sequence(idx + 1),
sequence(idx + 2)});
sequence(idx + 2),
sequence(idx + 3),
sequence(idx + 4)});

idx++;
sequence(idx) = applyFilter(
{
hist_2,
hist_3,
sequence(idx - 2),
sequence(idx - 1),
sequence(idx),
sequence(idx + 1),
sequence(idx + 2),
sequence(idx + 3),
sequence(idx + 4)});

for (idx = 2; idx != num_sequences - 3; idx++) {
idx++;
sequence(idx) = applyFilter(
{
hist_3,
sequence(idx - 3),
sequence(idx - 2),
sequence(idx - 1),
sequence(idx),
sequence(idx + 1),
sequence(idx + 2),
sequence(idx + 3),
sequence(idx + 4)});

for (idx = 4; idx != num_sequences - 4; idx++) {
sequence(idx) = applyFilter(
{
sequence(idx - 4),
sequence(idx - 3),
sequence(idx - 2),
sequence(idx - 1),
sequence(idx),
sequence(idx + 1),
sequence(idx + 2)});
sequence(idx + 2),
sequence(idx + 3),
sequence(idx + 4)});
}

idx++;
sequence(idx) = applyFilter(
{
sequence(idx - 4),
sequence(idx - 3),
sequence(idx - 2),
sequence(idx - 1),
sequence(idx),
sequence(idx + 1),
sequence(idx + 2),
sequence(idx + 3),
sequence(idx + 3)});

idx++;
sequence(idx) = applyFilter(
{
sequence(idx - 4),
sequence(idx - 3),
sequence(idx - 2),
sequence(idx - 1),
sequence(idx),
sequence(idx + 1),
sequence(idx + 2),
sequence(idx + 2),
sequence(idx + 2)});

idx++;
sequence(idx) = applyFilter(
{
sequence(idx - 4),
sequence(idx - 3),
sequence(idx - 2),
sequence(idx - 1),
sequence(idx),
sequence(idx + 1),
sequence(idx + 1),
sequence(idx + 1),
sequence(idx + 1)});

idx++;
sequence(idx) = applyFilter(
{
sequence(idx - 4),
sequence(idx - 3),
sequence(idx - 2),
sequence(idx - 1),
sequence(idx),
sequence(idx),
sequence(idx),
sequence(idx),
sequence(idx)});
};

// Filter trajectories
applyFilterOverAxis(control_sequence.vx, control_history[0].vx, control_history[1].vx);
applyFilterOverAxis(control_sequence.vy, control_history[0].vy, control_history[1].vy);
applyFilterOverAxis(control_sequence.wz, control_history[0].wz, control_history[1].wz);
applyFilterOverAxis(
control_sequence.vx, control_history[0].vx,
control_history[1].vx, control_history[2].vx, control_history[3].vx);
applyFilterOverAxis(
control_sequence.vy, control_history[0].vy,
control_history[1].vy, control_history[2].vy, control_history[3].vy);
applyFilterOverAxis(
control_sequence.wz, control_history[0].wz,
control_history[1].wz, control_history[2].wz, control_history[3].wz);

// Update control history
unsigned int offset = settings.shift_control_sequence ? 1 : 0;
control_history[0] = control_history[1];
control_history[1] = {
control_history[1] = control_history[2];
control_history[2] = control_history[3];
control_history[3] = {
control_sequence.vx(offset),
control_sequence.vy(offset),
control_sequence.wz(offset)};
Expand Down
2 changes: 2 additions & 0 deletions nav2_mppi_controller/src/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ void Optimizer::reset()
control_sequence_.reset(settings_.time_steps);
control_history_[0] = {0.0, 0.0, 0.0};
control_history_[1] = {0.0, 0.0, 0.0};
control_history_[2] = {0.0, 0.0, 0.0};
control_history_[3] = {0.0, 0.0, 0.0};

costs_ = xt::zeros<float>({settings_.batch_size});
generated_trajectories_.reset(settings_.batch_size, settings_.time_steps);
Expand Down
20 changes: 13 additions & 7 deletions nav2_mppi_controller/test/utils_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,13 @@ TEST(UtilsTests, SmootherTest)
noisey_sequence.wz += noises;
sequence_init = noisey_sequence;

std::array<mppi::models::Control, 2> history, history_init;
std::array<mppi::models::Control, 4> history, history_init;
history[3].vx = 0.1;
history[3].vy = 0.0;
history[3].wz = 0.3;
history[2].vx = 0.1;
history[2].vy = 0.0;
history[2].wz = 0.3;
history[1].vx = 0.1;
history[1].vy = 0.0;
history[1].wz = 0.3;
Expand All @@ -332,14 +338,14 @@ TEST(UtilsTests, SmootherTest)
savitskyGolayFilter(noisey_sequence, history, settings);

// Check history is propogated backward
EXPECT_NEAR(history_init[1].vx, history[0].vx, 0.02);
EXPECT_NEAR(history_init[1].vy, history[0].vy, 0.02);
EXPECT_NEAR(history_init[1].wz, history[0].wz, 0.02);
EXPECT_NEAR(history_init[3].vx, history[2].vx, 0.02);
EXPECT_NEAR(history_init[3].vy, history[2].vy, 0.02);
EXPECT_NEAR(history_init[3].wz, history[2].wz, 0.02);

// Check history element is updated for first command
EXPECT_NEAR(history[1].vx, 0.2, 0.05);
EXPECT_NEAR(history[1].vy, 0.0, 0.02);
EXPECT_NEAR(history[1].wz, 0.23, 0.02);
EXPECT_NEAR(history[3].vx, 0.2, 0.05);
EXPECT_NEAR(history[3].vy, 0.0, 0.035);
EXPECT_NEAR(history[3].wz, 0.23, 0.02);

// Check that path is smoother
float smoothed_val, original_val;
Expand Down

0 comments on commit 7aee1e7

Please sign in to comment.