Skip to content

Commit

Permalink
Move MANN class into the ML component
Browse files Browse the repository at this point in the history
  • Loading branch information
GiulioRomualdi committed Apr 17, 2023
1 parent 3913646 commit 6886bb6
Show file tree
Hide file tree
Showing 11 changed files with 137 additions and 59 deletions.
6 changes: 5 additions & 1 deletion cmake/BipedalLocomotionFrameworkDependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ framework_dependent_option(FRAMEWORK_COMPILE_Contact

framework_dependent_option(FRAMEWORK_COMPILE_Planners
"Compile Planners libraries?" ON
"FRAMEWORK_USE_Qhull;FRAMEWORK_USE_casadi;FRAMEWORK_USE_onnxruntime;FRAMEWORK_USE_manif;FRAMEWORK_COMPILE_Math;FRAMEWORK_COMPILE_Contact" OFF)
"FRAMEWORK_USE_Qhull;FRAMEWORK_USE_casadi;FRAMEWORK_USE_manif;FRAMEWORK_COMPILE_Math;FRAMEWORK_COMPILE_Contact" OFF)

framework_dependent_option(FRAMEWORK_COMPILE_ContactModels
"Compile ContactModels library?" ON
Expand Down Expand Up @@ -200,6 +200,10 @@ framework_dependent_option(FRAMEWORK_COMPILE_IK
"Compile IK library?" ON
"FRAMEWORK_COMPILE_System;FRAMEWORK_USE_LieGroupControllers;FRAMEWORK_COMPILE_ManifConversions;FRAMEWORK_USE_manif;FRAMEWORK_USE_OsqpEigen" OFF)

framework_dependent_option(FRAMEWORK_COMPILE_ML
"Compile machine learning libraries?" ON
"FRAMEWORK_USE_onnxruntime;FRAMEWORK_USE_manif" OFF)

framework_dependent_option(FRAMEWORK_COMPILE_SimplifiedModelControllers
"Compile SimplifiedModelControllers library?" ON
"FRAMEWORK_USE_manif" OFF)
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ add_subdirectory(TSID)
add_subdirectory(Perception)
add_subdirectory(IK)
add_subdirectory(SimplifiedModelControllers)
add_subdirectory(ML)
19 changes: 19 additions & 0 deletions src/ML/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (C) 2023 Istituto Italiano di Tecnologia (IIT). All rights reserved.
# This software may be modified and distributed under the terms of the
# BSD-3-Clause license.

if (FRAMEWORK_COMPILE_ML)

set(H_PREFIX include/BipedalLocomotion/ML)

add_bipedal_locomotion_library(
NAME ML
PUBLIC_HEADERS ${H_PREFIX}/MANN.h
SOURCES src/MANN.cpp
PUBLIC_LINK_LIBRARIES Eigen3::Eigen BipedalLocomotion::ParametersHandler BipedalLocomotion::System
PRIVATE_LINK_LIBRARIES BipedalLocomotion::TextLogging onnxruntime::onnxruntime
INSTALLATION_FOLDER ML)

add_subdirectory(tests)

endif()
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

namespace BipedalLocomotion
{
namespace Planners
namespace ML
{

/**
Expand Down Expand Up @@ -143,5 +143,5 @@ class MANN : public BipedalLocomotion::System::Advanceable<MANNInput, MANNOutput
std::unique_ptr<Impl> m_pimpl;
};

} // namespace Planners
} // namespace ML
} // namespace BipedalLocomotion
6 changes: 3 additions & 3 deletions src/Planners/src/MANN.cpp → src/ML/src/MANN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
// onnxruntime
#include <onnxruntime_cxx_api.h>

#include <BipedalLocomotion/ML/MANN.h>
#include <BipedalLocomotion/ParametersHandler/IParametersHandler.h>
#include <BipedalLocomotion/Planners/MANN.h>
#include <BipedalLocomotion/System/VariablesHandler.h>
#include <BipedalLocomotion/TextLogging/Logger.h>

using namespace BipedalLocomotion::Planners;
using namespace BipedalLocomotion::ML;
using namespace BipedalLocomotion;

struct MANN::Impl
Expand Down Expand Up @@ -123,7 +123,7 @@ bool MANN::Impl::populateInput(const MANNInput& input)
// y1, y2, ......................................y12]
Eigen::Ref<const Eigen::MatrixXd> tmp
= input.basePositionTrajectory.rightCols(input.basePositionTrajectory.cols() / 2);
double trajectoryLength
const double trajectoryLength
= (tmp.rightCols(tmp.cols() - 1) - tmp.leftCols(tmp.cols() - 1)).colwise().norm().sum();

bool ok = populateVectorData("joint_velocities", input.jointVelocities);
Expand Down
11 changes: 11 additions & 0 deletions src/ML/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (C) 2023 Istituto Italiano di Tecnologia (IIT). All rights reserved.
# This software may be modified and distributed under the terms of the
# BSD-3-Clause license.

include_directories(${CMAKE_CURRENT_BINARY_DIR})
configure_file("${CMAKE_CURRENT_SOURCE_DIR}/FolderPath.h.in" "${CMAKE_CURRENT_BINARY_DIR}/MANNModelFolderPath.h" @ONLY)

add_bipedal_test(
NAME MANN
SOURCES MANNTest.cpp
LINKS BipedalLocomotion::ML)
File renamed without changes.
136 changes: 94 additions & 42 deletions src/Planners/tests/MANNTest.cpp → src/ML/tests/MANNTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
// Catch2
#include <catch2/catch.hpp>

#include <BipedalLocomotion/ML/MANN.h>
#include <BipedalLocomotion/ParametersHandler/StdImplementation.h>
#include <BipedalLocomotion/Planners/MANN.h>
#include <memory>

#include <MANNModelFolderPath.h>

using namespace BipedalLocomotion::Planners;
using namespace BipedalLocomotion::ML;
using namespace BipedalLocomotion::ParametersHandler;

TEST_CASE("MANN")
Expand Down Expand Up @@ -163,7 +163,7 @@ TEST_CASE("MANN")
0.0,
2.7428489205439933e-05,
0.0});
handler->setParameter("input_svd",
handler->setParameter("input_std",
std::vector<double>{0.26721796262751996,
0.17229308341193786,
0.21567989879842053,
Expand Down Expand Up @@ -405,7 +405,7 @@ TEST_CASE("MANN")
0.05994241344425131,
-8.95169574747821e-05,
6.243232025220322e-18});
handler->setParameter("output_svd",
handler->setParameter("output_std",
std::vector<double>{0.05678350766289934,
0.04484781368390023,
0.11198061053544575,
Expand Down Expand Up @@ -520,53 +520,105 @@ TEST_CASE("MANN")
input.jointPositions.resize(32);
input.jointVelocities.resize(32);



std::array<double, 12 * 2> basePositionTemp{0.020465626137500123, -0.028526278638462008, 0.0204657739768999, -0.02852627656493037,
0.020465822166082566, -0.028526210469724575, 0.020465802925311055, -0.028526161526930927,
0.020465772830464538, -0.02852614179540976, 0.0, 0.0,
-0.009585547395551581, -0.02175600333393604, -0.017817201285188853, -0.0332971138031881,
-0.02033611982326177, -0.030715595478538575, -0.016440085949668892, -0.01989833515449035,
-0.008254458531998493, -0.008547476348220699, 0.0, 0.0};
std::array<double, 12 * 2> basePositionTemp{0.020465626137500123,
-0.028526278638462008,
0.0204657739768999,
-0.02852627656493037,
0.020465822166082566,
-0.028526210469724575,
0.020465802925311055,
-0.028526161526930927,
0.020465772830464538,
-0.02852614179540976,
0.0,
0.0,
-0.009585547395551581,
-0.02175600333393604,
-0.017817201285188853,
-0.0332971138031881,
-0.02033611982326177,
-0.030715595478538575,
-0.016440085949668892,
-0.01989833515449035,
-0.008254458531998493,
-0.008547476348220699,
0.0,
0.0};

input.basePositionTrajectory = Eigen::Map<Eigen::MatrixXd>(basePositionTemp.data(), 2, 12);

std::array<double, 12 * 2> facingDirectionTemp{0.9998046173508225, 0.019766818762118485, 0.9998046180476308, 0.019766783517618902,
0.9998046178527036, 0.01976679337697099, 0.9998046175664108, 0.01976680785767637,
0.9998046173199947, 0.019766820321406378, 1.0, 0.0,
1.000195227589745, -0.0019179446719683577, 1.0010723856647088, -0.002920366320472865,
1.0019448535892073, -0.006219539528750589, 1.00220806134242, -0.005959674409629769,
1.0015735374358077, -0.0019812837541796284, 1.0, 0.0};
std::array<double, 12 * 2> facingDirectionTemp{0.9998046173508225,
0.019766818762118485,
0.9998046180476308,
0.019766783517618902,
0.9998046178527036,
0.01976679337697099,
0.9998046175664108,
0.01976680785767637,
0.9998046173199947,
0.019766820321406378,
1.0,
0.0,
1.000195227589745,
-0.0019179446719683577,
1.0010723856647088,
-0.002920366320472865,
1.0019448535892073,
-0.006219539528750589,
1.00220806134242,
-0.005959674409629769,
1.0015735374358077,
-0.0019812837541796284,
1.0,
0.0};
input.facingDirectionTrajectory
= Eigen::Map<Eigen::MatrixXd>(facingDirectionTemp.data(), 2, 12);

std::array<double, 12 * 2> baseVelocitiesTemp{0.02856285193696938, 0.019224769398957234, 0.028562823489490678, 0.019225725173215375,
0.02856269486540722, 0.01922613391891199, 0.028562619171457148, 0.01922618174545223,
0.02856256157025901, 0.019226074516058275, -0.03709010698840217, -0.08405826644833612,
-0.03298855398555943, -0.09948148489499548, -0.022790401966434587, -0.03630620503764005,
-0.006813425765586863, -0.005878922358647386, 0.00036041265075565515, 0.011310742547698205,
0.007901453422695888, 0.0018729868979242416, 0.0, 1.3552527156068805e-20};
std::array<double, 12 * 2> baseVelocitiesTemp{0.02856285193696938,
0.019224769398957234,
0.028562823489490678,
0.019225725173215375,
0.02856269486540722,
0.01922613391891199,
0.028562619171457148,
0.01922618174545223,
0.02856256157025901,
0.019226074516058275,
-0.03709010698840217,
-0.08405826644833612,
-0.03298855398555943,
-0.09948148489499548,
-0.022790401966434587,
-0.03630620503764005,
-0.006813425765586863,
-0.005878922358647386,
0.00036041265075565515,
0.011310742547698205,
0.007901453422695888,
0.0018729868979242416,
0.0,
1.3552527156068805e-20};
input.baseVelocitiesTrajectory = Eigen::Map<Eigen::MatrixXd>(baseVelocitiesTemp.data(), 2, 12);

// 0.026810580122418215

input.jointPositions << -0.08229444762971491, 0.1377980352398303, 0.014799786094367792, -0.17823363484560167,
-0.18445155789555484, -0.13905690533655496, -0.006777036784569264, 0.14304762524836648,
0.031516238814005136, -0.1426330272419353, -0.06098283408307653, -0.1341771055787397,
0.19911560020803554, 0.0328098631623236, 0.024781975683781033, 0.01324623650231824,
0.03895569208112638, -0.5063181281080337, -0.15188637555678797, -0.09926268052516171,
-0.35648960192960166, 0.8564996250726968, 1.3064277172088623, -0.05746928859305393,
0.16694140434265137, -0.1773331496644323, -0.09618105207864167, -0.3637387390826281,
-0.12775099875607665, 0.3775743246078491, -0.06556004402010714, -0.9403902888298035;
input.jointPositions << -0.08229444762971491, 0.1377980352398303, 0.014799786094367792,
-0.17823363484560167, -0.18445155789555484, -0.13905690533655496, -0.006777036784569264,
0.14304762524836648, 0.031516238814005136, -0.1426330272419353, -0.06098283408307653,
-0.1341771055787397, 0.19911560020803554, 0.0328098631623236, 0.024781975683781033,
0.01324623650231824, 0.03895569208112638, -0.5063181281080337, -0.15188637555678797,
-0.09926268052516171, -0.35648960192960166, 0.8564996250726968, 1.3064277172088623,
-0.05746928859305393, 0.16694140434265137, -0.1773331496644323, -0.09618105207864167,
-0.3637387390826281, -0.12775099875607665, 0.3775743246078491, -0.06556004402010714,
-0.9403902888298035;

input.jointVelocities << 0.0982293093461868, 0.12588027548389802, 0.018234994616183622, -0.26306079352524886,
-0.17989701726927035, -0.13636938658177092, -0.018404592896727315, -0.04892431424018461,
0.011816611105509, -0.030103574006318207, -0.054470913637429216, 0.04110477832027632,
0.006149116530581617, 0.0022679536303376294, -0.01772658684486485, 0.0018849445068382577,
0.0027670107680974943, 0.003149718526371251, 0.009372946811675064, -0.019232750841138426,
0.00124201362405774, 0.0019440759412122286, -0.08404139429330826, 2.529428034093097e-05,
0.16792771220207214, -0.02287937013013993, -0.002927608219628366, 0.018047571904430328,
-0.01034426546939903, -0.023162882775068283, 2.750370144352529e-05, 0.0011318349279463291;
input.jointVelocities << 0.0982293093461868, 0.12588027548389802, 0.018234994616183622,
-0.26306079352524886, -0.17989701726927035, -0.13636938658177092, -0.018404592896727315,
-0.04892431424018461, 0.011816611105509, -0.030103574006318207, -0.054470913637429216,
0.04110477832027632, 0.006149116530581617, 0.0022679536303376294, -0.01772658684486485,
0.0018849445068382577, 0.0027670107680974943, 0.003149718526371251, 0.009372946811675064,
-0.019232750841138426, 0.00124201362405774, 0.0019440759412122286, -0.08404139429330826,
2.529428034093097e-05, 0.16792771220207214, -0.02287937013013993, -0.002927608219628366,
0.018047571904430328, -0.01034426546939903, -0.023162882775068283, 2.750370144352529e-05,
0.0011318349279463291;

REQUIRE(mann.setInput(input));
REQUIRE(mann.advance());
Expand Down
File renamed without changes.
5 changes: 2 additions & 3 deletions src/Planners/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ if (FRAMEWORK_COMPILE_Planners)
PUBLIC_HEADERS ${H_PREFIX}/ConvexHullHelper.h ${H_PREFIX}/DCMPlanner.h ${H_PREFIX}/TimeVaryingDCMPlanner.h
${H_PREFIX}/Spline.h ${H_PREFIX}/QuinticSpline.h ${H_PREFIX}/SO3Planner.h
${H_PREFIX}/SO3Planner.tpp ${H_PREFIX}/SwingFootPlanner.h ${H_PREFIX}/CubicSpline.h
${H_PREFIX}/MANN.h
SOURCES src/ConvexHullHelper.cpp src/DCMPlanner.cpp src/TimeVaryingDCMPlanner.cpp src/QuinticSpline.cpp
src/SwingFootPlanner.cpp src/CubicSpline.cpp src/MANN.cpp
src/SwingFootPlanner.cpp src/CubicSpline.cpp
PUBLIC_LINK_LIBRARIES Eigen3::Eigen BipedalLocomotion::ParametersHandler BipedalLocomotion::System BipedalLocomotion::Contacts
PRIVATE_LINK_LIBRARIES Qhull::qhullcpp Qhull::qhull_r casadi iDynTree::idyntree-core BipedalLocomotion::Math BipedalLocomotion::TextLogging onnxruntime::onnxruntime
PRIVATE_LINK_LIBRARIES Qhull::qhullcpp Qhull::qhull_r casadi iDynTree::idyntree-core BipedalLocomotion::Math BipedalLocomotion::TextLogging
INSTALLATION_FOLDER Planners)

add_subdirectory(tests)
Expand Down
8 changes: 0 additions & 8 deletions src/Planners/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,3 @@ add_bipedal_test(
NAME SwingFootPlanner
SOURCES SwingFootPlannerTest.cpp
LINKS BipedalLocomotion::Planners)

include_directories(${CMAKE_CURRENT_BINARY_DIR})
configure_file("${CMAKE_CURRENT_SOURCE_DIR}/FolderPath.h.in" "${CMAKE_CURRENT_BINARY_DIR}/MANNModelFolderPath.h" @ONLY)

add_bipedal_test(
NAME MANN
SOURCES MANNTest.cpp
LINKS BipedalLocomotion::Planners)

0 comments on commit 6886bb6

Please sign in to comment.