diff --git a/libmambapy/src/libmambapy/bindings/expected_caster.hpp b/libmambapy/src/libmambapy/bindings/expected_caster.hpp new file mode 100644 index 0000000000..41801cde3b --- /dev/null +++ b/libmambapy/src/libmambapy/bindings/expected_caster.hpp @@ -0,0 +1,93 @@ +// Copyright (c) 2023, QuantStack and Mamba Contributors +// +// Distributed under the terms of the BSD 3-Clause License. +// +// The full license is in the file LICENSE, distributed with this software. + +#include +#include +#include + +#include +#include +#include +#include + +#ifndef MAMBA_PY_EXPECTED_CASTER +#define MAMBA_PY_EXPECTED_CASTER + +namespace PYBIND11_NAMESPACE +{ + namespace detail + { + namespace + { + template < + typename Expected, + typename T = typename Expected::value_type, + typename E = typename Expected::error_type> + auto expected_to_variant(Expected&& expected) -> std::variant + { + if (expected) + { + return { std::forward(expected).value() }; + } + return { std::forward(expected).error() }; + } + + template < + typename Variant, + typename T = std::decay_t(std::declval()))>, + typename E = std::decay_t(std::declval()))>> + auto expected_to_variant(Variant&& var) -> tl::expected + { + static_assert(std::variant_size_v == 2); + return std::visit( + [](auto&& v) -> tl::expected { return { std::forward(v) }; }, + var + ); + } + } + + /** + * A caster for tl::expected that converts to a union. + * + * The caster works by converting to a the expected to a variant and then calls the + * variant caster. + * + * A future direction could be considered to wrap the union into a Python "Expected", + * with methods such as ``and_then``, ``or_else``, and thowing method like ``value`` + * and ``error``. + */ + template + struct type_caster> + { + using value_type = tl::expected; + using variant_type = std::variant; + using caster_type = std::decay_t())>; + + auto load(handle src, bool convert) -> bool + { + auto caster = make_caster(); + if (caster.load(src, convert)) + { + value = variant_to_expected(cast_op(std::move(caster))); + } + return false; + } + + template + static auto cast(Expected&& src, return_value_policy policy, handle parent) -> handle + { + return caster_type::cast(expected_to_variant(std::forward(src)), policy, parent); + } + + PYBIND11_TYPE_CASTER( + value_type, + const_name(R"(Union[)") + detail::concat(make_caster::name, make_caster::name) + + const_name(R"(])") + ); + }; + } +} +#endif diff --git a/libmambapy/src/libmambapy/bindings/legacy.cpp b/libmambapy/src/libmambapy/bindings/legacy.cpp index 0ca12b13b3..1dfdf877f4 100644 --- a/libmambapy/src/libmambapy/bindings/legacy.cpp +++ b/libmambapy/src/libmambapy/bindings/legacy.cpp @@ -39,6 +39,7 @@ #include "mamba/validation/update_framework_v0_6.hpp" #include "bindings.hpp" +#include "expected_caster.hpp" #include "flat_set_caster.hpp" namespace py = pybind11;