Skip to content

Commit

Permalink
Expose a bundle_adjustment function (#146)
Browse files Browse the repository at this point in the history
* Expose a `bundle_adjustment` function
* Formatting
* Cleanup
* Remove overload of increamental_mapping

---------

Signed-off-by: Gaoyang Zhang <gy@blurgy.xyz>
Co-authored-by: Paul-Edouard Sarlin <paul.edouard.sarlin@gmail.com>
  • Loading branch information
blurgyy and sarlinpe authored Sep 19, 2023
1 parent 67d84e1 commit 97484c0
Showing 1 changed file with 91 additions and 34 deletions.
125 changes: 91 additions & 34 deletions pipeline/sfm.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Author: Paul-Edouard Sarlin (skydes)
#include "colmap/exe/sfm.h"

#include "colmap/controllers/bundle_adjustment.h"
#include "colmap/controllers/incremental_mapper.h"
#include "colmap/scene/reconstruction.h"
#include "colmap/sensor/models.h"
Expand Down Expand Up @@ -50,7 +51,6 @@ std::shared_ptr<Reconstruction> triangulate_points(
return reconstruction;
}

// Copied from colmap/exe/sfm.cc
std::map<size_t, std::shared_ptr<Reconstruction>> incremental_mapping(
const py::object database_path_,
const py::object image_path_,
Expand Down Expand Up @@ -109,18 +109,14 @@ std::map<size_t, std::shared_ptr<Reconstruction>> incremental_mapping(
return reconstructions;
}

std::map<size_t, std::shared_ptr<Reconstruction>> incremental_mapping(
const py::object database_path_,
const py::object image_path_,
const py::object output_path_,
const int num_threads,
const int min_num_matches,
const py::object input_path_) {
IncrementalMapperOptions options;
options.num_threads = num_threads;
options.min_num_matches = min_num_matches;
return incremental_mapping(
database_path_, image_path_, output_path_, options, input_path_);
void bundle_adjustment(std::shared_ptr<Reconstruction> reconstruction,
const BundleAdjustmentOptions& options) {
py::gil_scoped_release release;
OptionManager option_manager;
*option_manager.bundle_adjustment = options;
BundleAdjustmentController controller(option_manager, reconstruction);
controller.Start();
PyWait(&controller);
}

void init_sfm(py::module& m) {
Expand Down Expand Up @@ -186,6 +182,83 @@ void init_sfm(py::module& m) {
make_dataclass(PyIncrementalMapperOptions);
auto mapper_options = PyIncrementalMapperOptions().cast<Opts>();

using BAOpts = BundleAdjustmentOptions;
auto PyBALossFunctionType =
py::enum_<BAOpts::LossFunctionType>(m, "LossFunctionType")
.value("TRIVIAL", BAOpts::LossFunctionType::TRIVIAL)
.value("SOFT_L1", BAOpts::LossFunctionType::SOFT_L1)
.value("CAUCHY", BAOpts::LossFunctionType::CAUCHY);
AddStringToEnumConstructor(PyBALossFunctionType);
using CSOpts = ceres::Solver::Options;
auto PyCeresSolverOptions =
py::class_<CSOpts>(
m,
"CeresSolverOptions",
// If ceres::Solver::Options is registered by pycolmap AND a
// downstream library, importing the downstream library results in
// error:
// ImportError: generic_type: type "CeresSolverOptions" is already
// registered!
// Adding a `py::module_local()` fixes this.
// https://github.com/pybind/pybind11/issues/439#issuecomment-1338251822
py::module_local())
.def(py::init<>())
.def_readwrite("function_tolerance", &CSOpts::function_tolerance)
.def_readwrite("gradient_tolerance", &CSOpts::gradient_tolerance)
.def_readwrite("parameter_tolerance", &CSOpts::parameter_tolerance)
.def_readwrite("minimizer_progress_to_stdout",
&CSOpts::minimizer_progress_to_stdout)
.def_readwrite("minimizer_progress_to_stdout",
&CSOpts::minimizer_progress_to_stdout)
.def_readwrite("max_num_iterations", &CSOpts::max_num_iterations)
.def_readwrite("max_linear_solver_iterations",
&CSOpts::max_linear_solver_iterations)
.def_readwrite("max_num_consecutive_invalid_steps",
&CSOpts::max_num_consecutive_invalid_steps)
.def_readwrite("max_consecutive_nonmonotonic_steps",
&CSOpts::max_consecutive_nonmonotonic_steps)
.def_readwrite("num_threads", &CSOpts::num_threads);
make_dataclass(PyCeresSolverOptions);
auto PyBundleAdjustmentOptions =
py::class_<BAOpts>(m, "BundleAdjustmentOptions")
.def(py::init<>())
.def_readwrite("loss_function_type",
&BAOpts::loss_function_type,
"Loss function types: Trivial (non-robust) and Cauchy "
"(robust) loss.")
.def_readwrite("loss_function_scale",
&BAOpts::loss_function_scale,
"Scaling factor determines residual at which "
"robustification takes place.")
.def_readwrite("refine_focal_length",
&BAOpts::refine_focal_length,
"Whether to refine the focal length parameter group.")
.def_readwrite(
"refine_principal_point",
&BAOpts::refine_principal_point,
"Whether to refine the principal point parameter group.")
.def_readwrite("refine_extra_params",
&BAOpts::refine_extra_params,
"Whether to refine the extra parameter group.")
.def_readwrite("refine_extrinsics",
&BAOpts::refine_extrinsics,
"Whether to refine the extrinsic parameter group.")
.def_readwrite("print_summary",
&BAOpts::print_summary,
"Whether to print a final summary.")
.def_readwrite("min_num_residuals_for_multi_threading",
&BAOpts::min_num_residuals_for_multi_threading,
"Minimum number of residuals to enable "
"multi-threading. Note that "
"single-threaded is typically better for small bundle "
"adjustment problems "
"due to the overhead of threading. ")
.def_readwrite("solver_options",
&BAOpts::solver_options,
"Ceres-Solver options.");
make_dataclass(PyBundleAdjustmentOptions);
auto ba_options = PyBundleAdjustmentOptions().cast<BAOpts>();

m.def("triangulate_points",
&triangulate_points,
"reconstruction"_a,
Expand All @@ -198,32 +271,16 @@ void init_sfm(py::module& m) {
"Triangulate 3D points from known camera poses");

m.def("incremental_mapping",
static_cast<std::map<size_t, std::shared_ptr<Reconstruction>> (*)(
const py::object,
const py::object,
const py::object,
const IncrementalMapperOptions&,
const py::object)>(&incremental_mapping),
&incremental_mapping,
"database_path"_a,
"image_path"_a,
"output_path"_a,
"options"_a = mapper_options,
"input_path"_a = py::str(""),
"Triangulate 3D points from known poses");

m.def("incremental_mapping",
static_cast<std::map<size_t, std::shared_ptr<Reconstruction>> (*)(
const py::object,
const py::object,
const py::object,
const int,
const int,
const py::object)>(&incremental_mapping),
"database_path"_a,
"image_path"_a,
"output_path"_a,
"num_threads"_a = mapper_options.num_threads,
"min_num_matches"_a = mapper_options.min_num_matches,
"input_path"_a = py::str(""),
"Triangulate 3D points from known poses");
m.def("bundle_adjustment",
&bundle_adjustment,
"reconstruction"_a,
"options"_a = ba_options);
}

0 comments on commit 97484c0

Please sign in to comment.