Skip to content

Commit

Permalink
Nonlinear program example refactor (#10256)
Browse files Browse the repository at this point in the history
Refactor nonlinear_program_examples and other code so that it compiles
  • Loading branch information
hongkai-dai authored Dec 21, 2018
1 parent 1e8c81b commit 9eeb19f
Show file tree
Hide file tree
Showing 16 changed files with 460 additions and 294 deletions.
1 change: 1 addition & 0 deletions solvers/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,7 @@ drake_cc_googletest(
":mathematical_program",
":mathematical_program_test_util",
":optimization_examples",
":solve",
"//common/test_utilities:eigen_matrix_compare",
],
)
Expand Down
2 changes: 2 additions & 0 deletions solvers/gurobi_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,8 @@ void GurobiSolver::Solve(const MathematicalProgram& prog,
error = GRBoptimize(model);
}

result->set_solver_id(GurobiSolver::id());

SolutionResult solution_result = SolutionResult::kUnknownError;

GurobiSolverDetails& solver_details =
Expand Down
9 changes: 5 additions & 4 deletions solvers/test/gurobi_solver_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,13 @@ GTEST_TEST(GurobiTest, TestCallbacks) {
solver.AddMipNodeCallback(mip_node_callback_function_wrapper);
solver.AddMipSolCallback(mip_sol_callback_function_wrapper);

SolutionResult result = solver.Solve(prog);
EXPECT_EQ(result, SolutionResult::kSolutionFound);
const auto& x_value = prog.GetSolution(x);
MathematicalProgramResult result;
solver.Solve(prog, {}, {}, &result);
EXPECT_EQ(result.get_solution_result(), SolutionResult::kSolutionFound);
const auto& x_value = prog.GetSolution(x, result);
EXPECT_TRUE(CompareMatrices(x_value, x_expected, 1E-6,
MatrixCompareType::absolute));
ExpectSolutionCostAccurate(prog, 1E-6);
ExpectSolutionCostAccurate(prog, result, 1E-6);
EXPECT_TRUE(cb_info.mip_sol_callback_called);
EXPECT_TRUE(cb_info.mip_node_callback_called);
}
Expand Down
47 changes: 29 additions & 18 deletions solvers/test/linear_program_examples.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <gtest/gtest.h>

#include "drake/common/test_utilities/eigen_matrix_compare.h"
#include "drake/solvers/solver_type_converter.h"
#include "drake/solvers/test/mathematical_program_test_util.h"

using Eigen::Vector4d;
Expand Down Expand Up @@ -61,16 +62,17 @@ LinearFeasibilityProgram::LinearFeasibilityProgram(
}
}

void LinearFeasibilityProgram::CheckSolution(SolverType) const {
auto x_val = prog()->GetSolution(x_);
void LinearFeasibilityProgram::CheckSolution(
const MathematicalProgramResult& result) const {
auto x_val = prog()->GetSolution(x_, result);
Vector3d A_times_x(x_val(0) + 2 * x_val(1) + 3 * x_val(2),
x_val(1) - 2 * x_val(2), 0);
EXPECT_GE(A_times_x(0), 0 - 1e-10);
EXPECT_LE(A_times_x(0), 10 + 1e-10);
EXPECT_LE(A_times_x(1), 3 + 1E-10);
EXPECT_LE(A_times_x(2), 0 + 1E-10);
EXPECT_GE(A_times_x(2), 0 - 1E-10);
EXPECT_GE(prog()->GetSolution(x_(1)), 1 - 1E-10);
EXPECT_GE(prog()->GetSolution(x_(1), result), 1 - 1E-10);
}

LinearProgram0::LinearProgram0(CostForm cost_form,
Expand Down Expand Up @@ -130,11 +132,13 @@ LinearProgram0::LinearProgram0(CostForm cost_form,
}
}

void LinearProgram0::CheckSolution(SolverType solver_type) const {
double tol = GetSolverSolutionDefaultCompareTolerance(solver_type);
EXPECT_TRUE(CompareMatrices(prog()->GetSolution(x_), x_expected_, tol,
void LinearProgram0::CheckSolution(
const MathematicalProgramResult& result) const {
const double tol = GetSolverSolutionDefaultCompareTolerance(
SolverTypeConverter::IdToType(result.get_solver_id()).value());
EXPECT_TRUE(CompareMatrices(prog()->GetSolution(x_, result), x_expected_, tol,
MatrixCompareType::absolute));
ExpectSolutionCostAccurate(*prog(), tol);
ExpectSolutionCostAccurate(*prog(), result, tol);
}

LinearProgram1::LinearProgram1(CostForm cost_form,
Expand Down Expand Up @@ -171,11 +175,13 @@ LinearProgram1::LinearProgram1(CostForm cost_form,
}
}

void LinearProgram1::CheckSolution(SolverType solver_type) const {
double tol = GetSolverSolutionDefaultCompareTolerance(solver_type);
EXPECT_TRUE(CompareMatrices(prog()->GetSolution(x_), x_expected_, tol,
void LinearProgram1::CheckSolution(
const MathematicalProgramResult& result) const {
const double tol = GetSolverSolutionDefaultCompareTolerance(
SolverTypeConverter::IdToType(result.get_solver_id()).value());
EXPECT_TRUE(CompareMatrices(prog()->GetSolution(x_, result), x_expected_, tol,
MatrixCompareType::absolute));
ExpectSolutionCostAccurate(*prog(), tol);
ExpectSolutionCostAccurate(*prog(), result, tol);
}

LinearProgram2::LinearProgram2(CostForm cost_form,
Expand Down Expand Up @@ -259,11 +265,13 @@ LinearProgram2::LinearProgram2(CostForm cost_form,
}
}

void LinearProgram2::CheckSolution(SolverType solver_type) const {
double tol = GetSolverSolutionDefaultCompareTolerance(solver_type);
EXPECT_TRUE(CompareMatrices(prog()->GetSolution(x_), x_expected_, tol,
void LinearProgram2::CheckSolution(
const MathematicalProgramResult& result) const {
const double tol = GetSolverSolutionDefaultCompareTolerance(
SolverTypeConverter::IdToType(result.get_solver_id()).value());
EXPECT_TRUE(CompareMatrices(prog()->GetSolution(x_, result), x_expected_, tol,
MatrixCompareType::absolute));
ExpectSolutionCostAccurate(*prog(), tol);
ExpectSolutionCostAccurate(*prog(), result, tol);
}

LinearProgram3::LinearProgram3(CostForm cost_form,
Expand Down Expand Up @@ -331,8 +339,11 @@ LinearProgram3::LinearProgram3(CostForm cost_form,
}
}

void LinearProgram3::CheckSolution(SolverType solver_type) const {
void LinearProgram3::CheckSolution(
const MathematicalProgramResult& result) const {
// Mosek has a looser tolerance.
const SolverType solver_type =
SolverTypeConverter::IdToType(result.get_solver_id()).value();
double tol = GetSolverSolutionDefaultCompareTolerance(solver_type);
if (solver_type == SolverType::kMosek) {
tol = 1E-6;
Expand All @@ -342,9 +353,9 @@ void LinearProgram3::CheckSolution(SolverType solver_type) const {
if (solver_type == SolverType::kIpopt) {
cost_tol = 1E-5;
}
EXPECT_TRUE(CompareMatrices(prog()->GetSolution(x_), x_expected_, tol,
EXPECT_TRUE(CompareMatrices(prog()->GetSolution(x_, result), x_expected_, tol,
MatrixCompareType::absolute));
ExpectSolutionCostAccurate(*prog(), cost_tol);
ExpectSolutionCostAccurate(*prog(), result, cost_tol);
}

LinearProgramTest::LinearProgramTest() {
Expand Down
10 changes: 5 additions & 5 deletions solvers/test/linear_program_examples.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class LinearFeasibilityProgram : public OptimizationProgram {

~LinearFeasibilityProgram() override {};

void CheckSolution(SolverType solver_type) const override;
void CheckSolution(const MathematicalProgramResult& result) const override;

private:
VectorDecisionVariable<3> x_;
Expand All @@ -47,7 +47,7 @@ class LinearProgram0 : public OptimizationProgram {

~LinearProgram0() override {};

void CheckSolution(SolverType solver_type) const override;
void CheckSolution(const MathematicalProgramResult& result) const override;

private:
VectorDecisionVariable<2> x_;
Expand All @@ -68,7 +68,7 @@ class LinearProgram1 : public OptimizationProgram {

~LinearProgram1() override {};

void CheckSolution(SolverType solver_type) const override;
void CheckSolution(const MathematicalProgramResult& result) const override;

private:
VectorDecisionVariable<2> x_;
Expand Down Expand Up @@ -96,7 +96,7 @@ class LinearProgram2 : public OptimizationProgram {

~LinearProgram2() override {};

void CheckSolution(SolverType solver_type) const override;
void CheckSolution(const MathematicalProgramResult& result) const override;

private:
VectorDecisionVariable<4> x_;
Expand All @@ -121,7 +121,7 @@ class LinearProgram3 : public OptimizationProgram {

~LinearProgram3() override {};

void CheckSolution(SolverType solver_type) const override;
void CheckSolution(const MathematicalProgramResult& result) const override;

private:
VectorDecisionVariable<3> x_;
Expand Down
7 changes: 4 additions & 3 deletions solvers/test/linear_system_solver_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <gtest/gtest.h>

#include "drake/common/test_utilities/eigen_matrix_compare.h"
#include "drake/solvers/solve.h"
#include "drake/solvers/test/mathematical_program_test_util.h"
#include "drake/solvers/test/optimization_examples.h"

Expand All @@ -12,9 +13,9 @@ namespace test {

namespace {
void TestLinearSystemExample(LinearSystemExample1* example) {
example->prog()->Solve();
CheckSolver(*(example->prog()), LinearSystemSolver::id());
example->CheckSolution();
const MathematicalProgramResult result = Solve(*(example->prog()), {}, {});
EXPECT_EQ(result.get_solution_result(), SolutionResult::kSolutionFound);
example->CheckSolution(result);
}
} // namespace

Expand Down
14 changes: 9 additions & 5 deletions solvers/test/mathematical_program_test_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,23 @@ void CheckSolver(const MathematicalProgram& prog, SolverId desired_solver_id) {
EXPECT_EQ(*solver_id, desired_solver_id);
}

void RunSolver(MathematicalProgram* prog,
const MathematicalProgramSolverInterface& solver) {
MathematicalProgramResult RunSolver(
const MathematicalProgram& prog,
const MathematicalProgramSolverInterface& solver,
const optional<Eigen::VectorXd>& initial_guess) {
if (!solver.available()) {
throw std::runtime_error(
"Solver " + solver.solver_id().name() + " is not available");
}

SolutionResult result = solver.Solve(*prog);
EXPECT_EQ(result, SolutionResult::kSolutionFound);
if (result != SolutionResult::kSolutionFound) {
MathematicalProgramResult result;
solver.Solve(prog, initial_guess, {}, &result);
EXPECT_EQ(result.get_solution_result(), SolutionResult::kSolutionFound);
if (result.get_solution_result() != SolutionResult::kSolutionFound) {
throw std::runtime_error(
"Solver " + solver.solver_id().name() + " fails to find the solution");
}
return result;
}
} // namespace test
} // namespace solvers
Expand Down
6 changes: 4 additions & 2 deletions solvers/test/mathematical_program_test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ void CheckSolver(const MathematicalProgram& prog, SolverId desired_solver_id);
/// not find a solution, stop immediately via an exception. (Were we to
/// continue, testing statements that examine the results would be likely to
/// fail with confusing messages, so best to avoid them entirely.)
void RunSolver(MathematicalProgram* prog,
const MathematicalProgramSolverInterface& solver);
MathematicalProgramResult RunSolver(
const MathematicalProgram& prog,
const MathematicalProgramSolverInterface& solver,
const optional<Eigen::VectorXd>& initial_guess = {});

/// Determine if two bindings are the same. Two bindings are the same if
///
Expand Down
8 changes: 4 additions & 4 deletions solvers/test/mixed_integer_optimization_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ GTEST_TEST(TestMixedIntegerOptimization, TestMixedIntegerLinearProgram1) {
prog.AddLinearConstraint(a2, 1, std::numeric_limits<double>::infinity(),
x.head<2>());

RunSolver(&prog, *solver);
const MathematicalProgramResult result = RunSolver(prog, *solver);

Eigen::Vector3d x_expected(1, 0, 1);
const auto& x_value = prog.GetSolution(x);
const auto& x_value = prog.GetSolution(x, result);
EXPECT_TRUE(CompareMatrices(x_value, x_expected, 1E-6,
MatrixCompareType::absolute));
}
Expand All @@ -67,10 +67,10 @@ GTEST_TEST(TestMixedIntegerOptimization, TestMixedIntegerLinearProgram2) {
prog.AddLinearConstraint(a1, 1.8, std::numeric_limits<double>::infinity(),
x);

RunSolver(&prog, *solver);
const MathematicalProgramResult result = RunSolver(prog, *solver);

Eigen::Vector3d x_expected(1, 1, 1);
const auto& x_value = prog.GetSolution(x);
const auto& x_value = prog.GetSolution(x, result);
EXPECT_TRUE(CompareMatrices(x_value, x_expected, 1E-6,
MatrixCompareType::absolute));
}
Expand Down
Loading

0 comments on commit 9eeb19f

Please sign in to comment.