diff --git a/cmake/oneflow.cmake b/cmake/oneflow.cmake index 1763cbb3bec..5806d7205e4 100644 --- a/cmake/oneflow.cmake +++ b/cmake/oneflow.cmake @@ -174,7 +174,7 @@ foreach(oneflow_single_file ${oneflow_all_src}) continue() endif() - if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/(core|user|xrt)/.*\\.(h|hpp)$") + if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/(core|user|xrt|maybe)/.*\\.(h|hpp)$") if((NOT RPC_BACKEND MATCHES "GRPC") AND "${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/core/control/.*") # skip if GRPC not enabled elseif(APPLE AND "${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/core/comm_network/(epoll|ibverbs)/.*") @@ -228,8 +228,8 @@ foreach(oneflow_single_file ${oneflow_all_src}) endif(BUILD_PYTHON) - if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/(core|user|xrt)/.*\\.cpp$") - if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/(core|user|xrt)/.*_test\\.cpp$") + if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/(core|user|xrt|maybe)/.*\\.cpp$") + if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/(core|user|xrt|maybe)/.*_test\\.cpp$") # test file list(APPEND of_all_test_cc ${oneflow_single_file}) elseif(APPLE AND "${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/core/comm_network/(epoll|ibverbs)/.*") diff --git a/oneflow/maybe/config.h b/oneflow/maybe/config.h new file mode 100644 index 00000000000..6d4562f7c93 --- /dev/null +++ b/oneflow/maybe/config.h @@ -0,0 +1,52 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef ONEFLOW_MAYBE_CONFIG_H_ +#define ONEFLOW_MAYBE_CONFIG_H_ + +#include + +// pre-define it if you use a logging library like glog +#ifndef OF_MAYBE_ASSERT +#define OF_MAYBE_ASSERT(_cond_) assert(_cond_) +#endif + +// ASSERT_EQ is different from ASSERT in logging / testing framework +// pre-define it if you use a logging library like glog +#ifndef OF_MAYBE_ASSERT_EQ +#define OF_MAYBE_ASSERT_EQ(_lhs_, _rhs_) OF_MAYBE_ASSERT(_lhs_ == _rhs_) +#endif + +#if __GNUC__ >= 7 +#define OF_MAYBE_HAS_IS_AGGREGATE +// in old versions of clang, __has_builtin(__is_aggregate) returns false +#elif __clang__ +#if !__is_identifier(__is_aggregate) +#define OF_MAYBE_HAS_IS_AGGREGATE +#endif +#elif __has_builtin(__is_aggregate) +#define OF_MAYBE_HAS_IS_AGGREGATE +#endif + +#ifdef OF_MAYBE_HAS_IS_AGGREGATE +#define OF_MAYBE_IS_AGGREGATE(...) __is_aggregate(__VA_ARGS__) +#else +// decay to POD checking if no such builtin (because implementing __is_aggregate need reflection) +#define OF_MAYBE_IS_AGGREGATE(...) \ + std::is_standard_layout<__VA_ARGS__>::value&& std::is_trivial<__VA_ARGS__>::value +#endif + +#endif // ONEFLOW_MAYBE_CONFIG_H_ diff --git a/oneflow/maybe/type_traits.h b/oneflow/maybe/type_traits.h new file mode 100644 index 00000000000..29bd52543b4 --- /dev/null +++ b/oneflow/maybe/type_traits.h @@ -0,0 +1,152 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef ONEFLOW_MAYBE_TYPE_TRAITS_H_ +#define ONEFLOW_MAYBE_TYPE_TRAITS_H_ + +#include +#include +#include +#include +#include "config.h" + +namespace oneflow { + +namespace maybe { + +// in this file, xxxS represents struct of xxx +// for implementant aspect, xxx is an alias of xxxS::type or xxxS::value + +template +using BoolConstant = std::integral_constant; + +template +using IndexConstant = std::integral_constant; + +constexpr std::size_t NPos = -1; + +template +struct ConjS : std::true_type {}; +template +struct ConjS : B1 {}; +template +struct ConjS : std::conditional_t, B1> {}; + +template +constexpr bool Conj = ConjS::value; + +template +struct DisjS : std::false_type {}; +template +struct DisjS : B1 {}; +template +struct DisjS : std::conditional_t> {}; + +template +constexpr bool Disj = DisjS::value; + +template +struct NegS : BoolConstant {}; + +template +constexpr bool Neg = NegS::value; + +struct TypeNotFound; + +// return TypeNotFound while out of range +template +struct TypeGetS; + +template +struct TypeGetS : TypeGetS {}; + +template +struct TypeGetS<0, T1, Tn...> { + using type = T1; +}; + +template +struct TypeGetS { + using type = TypeNotFound; +}; + +template +using TypeGet = typename TypeGetS::type; + +// return NPos (-1) while not found +template +struct IndexGetFromS; + +template +struct IndexGetFromS : IndexGetFromS {}; + +template +struct IndexGetFromS : IndexConstant {}; + +template +struct IndexGetFromS : IndexConstant {}; + +template +constexpr auto IndexGet = IndexGetFromS<0, T, Ts...>::value; + +template +constexpr auto TypeIn = IndexGet != NPos; + +template +using TypeInS = BoolConstant>; + +template +struct RemoveCVRefS { + using type = std::remove_cv_t>; +}; + +template +using RemoveCVRef = typename RemoveCVRefS::type; + +template +struct IsDifferentTypesS : BoolConstant && IsDifferentTypesS::value> {}; + +template +struct IsDifferentTypesS : std::true_type {}; + +template +constexpr auto IsDifferentTypes = IsDifferentTypesS::value; + +template +struct ConstRefExceptVoidS { + using type = const T&; +}; + +template<> +struct ConstRefExceptVoidS { + using type = void; +}; + +template +using ConstRefExceptVoid = typename ConstRefExceptVoidS::type; + +template +using RemoveRValRef = + std::conditional_t::value, std::remove_reference_t, T>; + +template +constexpr bool IsAggregate = OF_MAYBE_IS_AGGREGATE(T); + +} // namespace maybe + +} // namespace oneflow + +#endif // ONEFLOW_MAYBE_TYPE_TRAITS_H_ diff --git a/oneflow/maybe/type_traits_test.cpp b/oneflow/maybe/type_traits_test.cpp new file mode 100644 index 00000000000..35b6e8f7c46 --- /dev/null +++ b/oneflow/maybe/type_traits_test.cpp @@ -0,0 +1,64 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include "oneflow/maybe/type_traits.h" + +using namespace oneflow::maybe; + +TEST(TypeTraits, Basics) { + static_assert(Conj, ""); + static_assert(!Conj, ""); + static_assert(!Conj, ""); + static_assert(!Conj, ""); + static_assert(!Conj, ""); + + static_assert(Disj, ""); + static_assert(Disj, ""); + static_assert(Disj, ""); + static_assert(!Disj, ""); + static_assert(Disj, ""); + static_assert(!Disj, ""); + + static_assert(std::is_same, int>::value, ""); + static_assert(std::is_same, int>::value, ""); + static_assert(std::is_same, float>::value, ""); + static_assert(std::is_same, bool>::value, ""); + static_assert(std::is_same, TypeNotFound>::value, ""); + static_assert(std::is_same, float>::value, ""); + static_assert(std::is_same, float>::value, ""); + static_assert(std::is_same, TypeNotFound>::value, ""); + + static_assert(IndexGet == 0, ""); + static_assert(IndexGet == NPos, ""); + static_assert(IndexGet == 0, ""); + static_assert(IndexGet == 1, ""); + static_assert(IndexGet == 3, ""); + static_assert(IndexGet == NPos, ""); + + static_assert(!TypeIn, ""); + static_assert(TypeIn, ""); + static_assert(TypeIn, ""); + static_assert(!TypeIn, ""); + static_assert(TypeIn, ""); + static_assert(TypeIn, ""); + + static_assert(IsDifferentTypes, ""); + static_assert(!IsDifferentTypes, ""); + static_assert(IsDifferentTypes, ""); + static_assert(!IsDifferentTypes, ""); +} diff --git a/oneflow/maybe/utility.h b/oneflow/maybe/utility.h new file mode 100644 index 00000000000..c62833a4cde --- /dev/null +++ b/oneflow/maybe/utility.h @@ -0,0 +1,89 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef ONEFLOW_MAYBE_UTILITY_H_ +#define ONEFLOW_MAYBE_UTILITY_H_ + +#include +#include + +namespace oneflow { + +namespace maybe { + +// unlike std::nullopt in c++17, the NullOptType is used in both Variant and Optional, +// so it is more like both std::nullopt and std::monostate (in c++17), +// the advantage of this unification is a more unifed experience, +// i.e. `return NullOpt` can be used in both Variant and Optional context +struct NullOptType { + explicit constexpr NullOptType() = default; + + bool operator==(NullOptType) const { return true; } + bool operator!=(NullOptType) const { return false; } + bool operator<(NullOptType) const { return false; } + bool operator>(NullOptType) const { return false; } + bool operator<=(NullOptType) const { return true; } + bool operator>=(NullOptType) const { return true; } +}; + +constexpr const std::size_t NullOptHash = -3333; + +constexpr NullOptType NullOpt{}; + +struct InPlaceT { + explicit constexpr InPlaceT() = default; +}; + +constexpr InPlaceT InPlace; + +template +struct InPlaceTypeT { + explicit constexpr InPlaceTypeT() = default; +}; + +template +constexpr InPlaceTypeT InPlaceType; + +template +struct InPlaceIndexT { + explicit constexpr InPlaceIndexT() = default; +}; + +template +constexpr InPlaceIndexT InPlaceIndex; + +template +constexpr void HashCombine(std::size_t& seed, const T& v) { + std::hash hasher; + seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +} // namespace maybe + +} // namespace oneflow + +namespace std { + +template<> +struct hash { + size_t operator()(oneflow::maybe::NullOptType) const noexcept { + return oneflow::maybe::NullOptHash; + } +}; + +} // namespace std + +#endif // ONEFLOW_MAYBE_UTILITY_H_ diff --git a/oneflow/maybe/utility_test.cpp b/oneflow/maybe/utility_test.cpp new file mode 100644 index 00000000000..7520e1b6eea --- /dev/null +++ b/oneflow/maybe/utility_test.cpp @@ -0,0 +1,36 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include "oneflow/maybe/utility.h" + +using namespace oneflow::maybe; + +TEST(Utility, NullOpt) { + NullOptType a, b(NullOpt), c(a); // NOLINT + + a = NullOpt; + + a = b; + + ASSERT_EQ(a, NullOptType{}); + ASSERT_EQ(std::hash()(a), std::hash()(NullOpt)); + ASSERT_EQ(NullOpt, a); + ASSERT_GE(NullOpt, a); + ASSERT_LE(NullOpt, a); + ASSERT_FALSE(NullOpt < a); + ASSERT_FALSE(NullOpt > a); +} diff --git a/oneflow/maybe/variant.h b/oneflow/maybe/variant.h new file mode 100644 index 00000000000..413ed2f559f --- /dev/null +++ b/oneflow/maybe/variant.h @@ -0,0 +1,454 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef ONEFLOW_MAYBE_VARIANT_H_ +#define ONEFLOW_MAYBE_VARIANT_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "oneflow/maybe/utility.h" +#include "oneflow/maybe/type_traits.h" + +namespace oneflow { + +namespace maybe { + +template +struct Variant; + +namespace details { + +// there are generally two ways to implement visit (like std::visit in c++17) +// 1. O(N) or O(log N), to iterate for all types or do a binary search on type index recursively +// 2. O(1), to store an static (storage duration) array of function pointers for every (Variant, F) +// where N = Variant::Num, and normally (in most cases) within the range [2, 5] +// the 2nd method is required in std::visit(f, x...) while sizeof...(x) == 1 +// but weakness of the 2nd method is that compilers usually cannot efficiently optimize these +// function pointers (compared to trivial recursion, which is easy to do optimization, and also +// friendly to CPU cache) here we implement visit via the first method: +// 1. for 2 <= N < 4, we use the O(N) algorithm (TrivialRecursiveVisitImpl) for better optimization +// 2. for N >= 4, we use the O(log N) algorithm (BinarySearchVisitImpl) for less recursion rounds + +struct VariantPrivateScope { + template + static R TrivialRecursiveVisitImpl(F&& f, V&& v, InPlaceIndexT::Num - 1>) { + // assume v.Index() == N - 1 now + return static_cast( + std::forward(f)(std::forward(v).template Value::Num - 1>())); + } + + template::Num - 1), int> = 0> + static R TrivialRecursiveVisitImpl(F&& f, V&& v, InPlaceIndexT) { + if (v.Index() == I) { + return static_cast(std::forward(f)(std::forward(v).template Value())); + } + + return TrivialRecursiveVisitImpl(std::forward(f), std::forward(v), + InPlaceIndex); + } + + template::Num), int> = 0> + static R BinarySearchVisitImpl(F&& f, V&& v, InPlaceIndexT, InPlaceIndexT) { + return static_cast(std::forward(f)(std::forward(v).template Value())); + } + + template::Num), int> = 0> + static R BinarySearchVisitImpl(F&& f, V&& v, InPlaceIndexT, InPlaceIndexT) { + constexpr std::size_t M = (I + I + 1) / 2; + constexpr std::size_t N = (M == I) ? I + 1 : I; + + if (v.Index() == M) { + return static_cast(std::forward(f)(std::forward(v).template Value())); + } else { + return static_cast(std::forward(f)(std::forward(v).template Value())); + } + } + + template::Num), int> = 0> + static R BinarySearchVisitImpl(F&& f, V&& v, InPlaceIndexT, InPlaceIndexT) { + constexpr std::size_t M = (L + U) / 2; + + if (v.Index() < M) { + return BinarySearchVisitImpl(std::forward(f), std::forward(v), InPlaceIndex, + InPlaceIndex); + } else if (v.Index() > M) { + return BinarySearchVisitImpl(std::forward(f), std::forward(v), InPlaceIndex, + InPlaceIndex); + } else { + return static_cast(std::forward(f)(std::forward(v).template Value())); + } + } + + template::Num<4, int> = 0> static R VisitImpl(F&& f, V&& v) { + return TrivialRecursiveVisitImpl(std::forward(f), std::forward(v), InPlaceIndex<0>); + } + + template::Num >= 4, int> = 0> + static R VisitImpl(F&& f, V&& v) { + return BinarySearchVisitImpl(std::forward(f), std::forward(v), InPlaceIndex<0>, + InPlaceIndex::Num - 1>); + } +}; + +struct AutoDeducedResultType; + +template +struct VisitResultS { + using type = R; +}; + +template +struct VisitResultS { + using type = std::common_type_t()(std::declval()))...>; +}; + +template +using VisitResult = typename VisitResultS::type; + +} // namespace details + +// preconditions: template type arguments must be no less than 2 different type +// and without reference and cv qualifiers +// this Variant DO NOT guarantee exception safty +template +struct Variant { // NOLINT(cppcoreguidelines-pro-type-member-init) + public: + static_assert(sizeof...(Ts) > 1, "expected more than two types"); + static_assert(Conj>...>, "reference types are not allowed here"); + static_assert(Conj, std::is_volatile>>...>, + "cv qualifiers are not allowed here"); + // important precondition to optimize Visit via binary search + static_assert(IsDifferentTypes, "expected all of different types"); + + static constexpr std::size_t Num = sizeof...(Ts); + + template + static constexpr auto IndexOfType = IndexGet; + + template + static constexpr bool HasType = TypeIn; + + template + using TypeByIndex = TypeGet; + + template, + std::enable_if_t::value, int> = 0> + Variant() { // NOLINT(cppcoreguidelines-pro-type-member-init) + Construct<0>(); + } + + // unlike std::variant, we only accept exact types to avoid wrong construction + template>, int> = 0> + Variant(T&& v) { // NOLINT(cppcoreguidelines-pro-type-member-init, google-explicit-constructor) + Construct>(std::forward(v)); + } + + template>, int> = 0> + explicit Variant(InPlaceTypeT, // NOLINT(cppcoreguidelines-pro-type-member-init) + Args&&... args) { + Construct>(std::forward(args)...); + } + + template = 0> + explicit Variant(InPlaceIndexT, // NOLINT(cppcoreguidelines-pro-type-member-init) + Args&&... args) { + Construct(std::forward(args)...); + } + + template + decltype(auto) Visit(F&& f) & { + using Result = details::VisitResult; + return details::VariantPrivateScope::VisitImpl(std::forward(f), *this); + } + + template + decltype(auto) Visit(F&& f) && { + using Result = details::VisitResult; + return details::VariantPrivateScope::VisitImpl(std::forward(f), std::move(*this)); + } + + template + decltype(auto) Visit(F&& f) const& { + using Result = details::VisitResult; + return details::VariantPrivateScope::VisitImpl(std::forward(f), *this); + } + + Variant(const Variant& v) { // NOLINT(cppcoreguidelines-pro-type-member-init) + CopyConstruct(v); + } + + Variant(Variant&& v) noexcept { // NOLINT(cppcoreguidelines-pro-type-member-init) + CopyConstruct(std::move(v)); + } + + template>, int> = 0> + Variant& operator=(T&& v) { + using Type = RemoveCVRef; + + Emplace(std::forward(v)); + + return *this; + } + + Variant& operator=(const Variant& v) { + Copy(v); + return *this; + } + + Variant& operator=(Variant&& v) noexcept { + Copy(std::move(v)); + return *this; + } + + std::size_t Index() const { return type_index_; } + + template, int> = 0> + bool Is() const { + return type_index_ == IndexOfType; + } + + ~Variant() { Destory(); } + + bool operator==(const Variant& v) const { + if (type_index_ != v.type_index_) return false; + + return v.Visit( + [this](const auto& elem) { return elem == Value>(); }); + } + + bool operator!=(const Variant& v) const { return !operator==(v); } + + bool operator<(const Variant& v) const { + if (type_index_ < v.type_index_) return true; + if (type_index_ > v.type_index_) return false; + + return v.Visit( + [this](const auto& elem) { return Value>() < elem; }); + } + + bool operator>=(const Variant& v) const { return !(*this < v); } + + bool operator>(const Variant& v) const { + if (type_index_ > v.type_index_) return true; + if (type_index_ < v.type_index_) return false; + + return v.Visit( + [this](const auto& elem) { return Value>() > elem; }); + } + + bool operator<=(const Variant& v) const { return !(*this > v); } + + template, int> = 0> + friend bool operator==(const Variant& v, const T& x) { + if (v.type_index_ != IndexOfType) return false; + + return v.Value() == x; + } + + template, int> = 0> + friend bool operator!=(const Variant& v, const T& x) { + return !(v == x); + } + + template, int> = 0> + friend bool operator==(const T& x, const Variant& v) { + return v == x; + } + + template, int> = 0> + friend bool operator!=(const T& x, const Variant& v) { + return !(v == x); + } + + template + T& Emplace(Args&&... args) { + if (Is()) { + return Value() = T(std::forward(args)...); + } else { + Destory(); + Construct(std::forward(args)...); + return Value(); + } + } + + template + decltype(auto) Emplace(Args&&... args) { + return Emplace>(std::forward(args)...); + } + + template, int> = 0> + T& Get() & { + OF_MAYBE_ASSERT_EQ(Index(), IndexOfType); + return Value(); + } + + template, int> = 0> + T&& Get() && { + OF_MAYBE_ASSERT_EQ(Index(), IndexOfType); + return std::move(*this).template Value(); + } + + template, int> = 0> + const T& Get() const& { + OF_MAYBE_ASSERT_EQ(Index(), IndexOfType); + return Value(); + } + + template = 0> + TypeByIndex& Get() & { + OF_MAYBE_ASSERT_EQ(Index(), I); + return Value(); + } + + template = 0> + TypeByIndex&& Get() && { + OF_MAYBE_ASSERT_EQ(Index(), I); + return std::move(*this).template Value(); + } + + template = 0> + const TypeByIndex& Get() const& { + OF_MAYBE_ASSERT_EQ(Index(), I); + return Value(); + } + + protected: + // use std::launder while updating to c++17 + template, int> = 0> + T& Value() & { + return *reinterpret_cast(storage_); + } + + template, int> = 0> + T&& Value() && { + return std::move(*reinterpret_cast(storage_)); + } + + template, int> = 0> + const T& Value() const& { + return *reinterpret_cast(storage_); + } + + template = 0> + TypeByIndex& Value() & { + return *reinterpret_cast*>(storage_); + } + + template = 0> + TypeByIndex&& Value() && { + return std::move(*reinterpret_cast*>(storage_)); + } + + template = 0> + const TypeByIndex& Value() const& { + return *reinterpret_cast*>(storage_); + } + + private: + static constexpr const std::size_t size = std::max({sizeof(Ts)...}); + + alignas(Ts...) unsigned char storage_[size]; + std::uint8_t type_index_; + + friend struct details::VariantPrivateScope; + + template && IsAggregate, int> = 0> + void Construct(Args&&... args) { + new (storage_) T{std::forward(args)...}; + type_index_ = IndexOfType; + } + + template && !IsAggregate, int> = 0> + void Construct(Args&&... args) { + new (storage_) T(std::forward(args)...); + type_index_ = IndexOfType; + } + + template = 0> + void Construct(Args&&... args) { + Construct>(std::forward(args)...); + } + + template + void CopyConstruct(V&& v) { + std::forward(v).Visit([this](auto&& elem) { + using T = RemoveCVRef; + + new (storage_) T(std::forward(elem)); + type_index_ = IndexOfType; + }); + } + + template + void Copy(V&& v) { + std::forward(v).Visit([this](auto&& elem) { + using T = RemoveCVRef; + + if (Is()) { + Value() = std::forward(elem); + } else { + Destory(); + Construct(std::forward(elem)); + } + }); + } + + void Destory() { + Visit([this](auto& elem) { + using T = RemoveCVRef; + + Value().~T(); + }); + } +}; + +template +using OptionalVariant = Variant; + +} // namespace maybe + +} // namespace oneflow + +namespace std { + +template +struct hash> { + size_t operator()(const oneflow::maybe::Variant& v) const noexcept { + size_t seed = hash()(v.Index()); + + v.Visit([&seed](const auto& x) { + using type = oneflow::maybe::RemoveCVRef; + oneflow::maybe::HashCombine(seed, x); + }); + + return seed; + } +}; + +} // namespace std + +#endif // ONEFLOW_MAYBE_VARIANT_H_ diff --git a/oneflow/maybe/variant_test.cpp b/oneflow/maybe/variant_test.cpp new file mode 100644 index 00000000000..7d1ebe128ae --- /dev/null +++ b/oneflow/maybe/variant_test.cpp @@ -0,0 +1,236 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include "oneflow/maybe/variant.h" + +using namespace oneflow::maybe; +using namespace std::string_literals; + +TEST(Variant, Basics) { + Variant a, b(1), c(1.2f), d(InPlaceType, 'a'), e(InPlaceType, 6.66); + ASSERT_TRUE(a.Is()); + ASSERT_EQ(a.Get(), 0); + ASSERT_TRUE(b.Is()); + ASSERT_EQ(b.Get(), 1); + ASSERT_TRUE(c.Is()); + ASSERT_EQ(c.Get(), 1.2f); + ASSERT_TRUE(d.Is()); + ASSERT_EQ(d.Get(), 'a'); + ASSERT_TRUE(e.Is()); + ASSERT_FLOAT_EQ(e.Get(), 6.66); + + Variant f(b), g(c), h(InPlaceIndex<1>, 2.33), i(InPlaceIndex<0>, 2.33); + ASSERT_TRUE(f.Is()); + ASSERT_EQ(f.Get(), 1); + ASSERT_TRUE(g.Is()); + ASSERT_EQ(g.Get(), 1.2f); + ASSERT_TRUE(h.Is()); + ASSERT_FLOAT_EQ(h.Get(), 2.33); + ASSERT_TRUE(i.Is()); + ASSERT_EQ(i.Get(), 2); + + a = 1; + ASSERT_TRUE(a.Is()); + ASSERT_EQ(a.Get(), 1); + + a = 1.3f; + ASSERT_TRUE(a.Is()); + ASSERT_EQ(a.Get(), 1.3f); + + a = b; + ASSERT_TRUE(a.Is()); + ASSERT_EQ(a.Get(), 1); + + a = c; + ASSERT_TRUE(a.Is()); + ASSERT_EQ(a.Get(), 1.2f); + + ASSERT_EQ((b.Visit>([](auto&& x) { return x + 1; })), + (Variant(2))); + ASSERT_EQ((c.Visit>([](auto&& x) { return x + 1; })), + (Variant(2.2f))); + + ASSERT_EQ(a.Emplace<1>(1.3f), 1.3f); + ASSERT_TRUE(a.Is()); + ASSERT_EQ(a.Get<1>(), 1.3f); + + ASSERT_EQ(a.Emplace<0>(233), 233); + ASSERT_TRUE(a.Is()); + ASSERT_EQ(a.Get<0>(), 233); +} + +TEST(Variant, NonPOD) { + Variant> a; + ASSERT_TRUE(a.Is()); + ASSERT_EQ(a.Get(), false); + + a = true; + ASSERT_TRUE(a.Is()); + ASSERT_EQ(a.Get(), true); + + a = std::make_shared(233); + ASSERT_EQ(a.Index(), 1); + ASSERT_EQ(*a.Get<1>(), 233); + ASSERT_EQ(a.Get<1>().use_count(), 1); + + { + Variant> b = a; + ASSERT_EQ(b.Index(), 1); + ASSERT_EQ(*b.Get<1>(), 233); + ASSERT_EQ(a.Get<1>().use_count(), 2); + *b.Get<1>() = 234; + } + + ASSERT_EQ(a.Get<1>().use_count(), 1); + ASSERT_EQ(*a.Get<1>(), 234); + + Variant> b = std::move(a); + ASSERT_EQ(b.Get<1>().use_count(), 1); + ASSERT_EQ(*b.Get<1>(), 234); + + Variant> c = b; + ASSERT_EQ(c.Get<1>().use_count(), 2); + ASSERT_EQ(b, c); + + b = true; + ASSERT_EQ(c.Get<1>().use_count(), 1); + + ASSERT_NE(b, c); +} + +TEST(Variant, Optional) { + OptionalVariant a, b(NullOpt), c(a); + + const char* hello = "hello"; + + std::size_t hash = 0, hash2 = 1, hash3 = 2; + HashCombine(hash, NullOpt); + HashCombine(hash2, 1); + HashCombine(hash3, hello); + + ASSERT_TRUE(a == NullOpt); + ASSERT_EQ(std::hash()(a), hash); + + a = 1; + ASSERT_EQ(a, 1); + ASSERT_EQ(std::hash()(a), hash2); + + a = NullOpt; + ASSERT_EQ(a, NullOpt); + ASSERT_EQ(std::hash()(a), hash); + + a = hello; + ASSERT_EQ(a, hello); + ASSERT_EQ(std::hash()(a), hash3); + + ASSERT_EQ(b, NullOpt); + ASSERT_EQ(c, NullOpt); + ASSERT_NE(a, b); +} + +TEST(Variant, BinarySearchVisit) { + const char* hello = "hello"; + + OptionalVariant x, y(123), z(1.2f), w(true); + OptionalVariant a, b(123), c(1.2f), d(true), e(hello); + + ASSERT_EQ(x, NullOpt); + ASSERT_EQ(y, 123); + ASSERT_EQ(z, 1.2f); + ASSERT_EQ(w, true); + ASSERT_EQ(a, NullOpt); + ASSERT_EQ(b, 123); + ASSERT_EQ(c, 1.2f); + ASSERT_EQ(d, true); + ASSERT_EQ(e, hello); + + OptionalVariant a1(a), b1(b), c1(c), d1(d), e1(e); + + ASSERT_EQ(a1, NullOpt); + ASSERT_EQ(b1, 123); + ASSERT_EQ(c1, 1.2f); + ASSERT_EQ(d1, true); + ASSERT_EQ(e1, hello); + + a = 233; + ASSERT_EQ(a, 233); + + a = hello; + ASSERT_EQ(a, hello); + + a = c; + ASSERT_EQ(a, 1.2f); + ASSERT_EQ(1.2f, a); + ASSERT_EQ(a, c); + ASSERT_NE(a, b); +} + +TEST(Variant, Compare) { + OptionalVariant a, b, c(0), d(5), dd(5), e(-1.2f), f(2.3f), g(false), h(true); + + ASSERT_EQ(a, b); + ASSERT_EQ(d, dd); + ASSERT_NE(a, c); + ASSERT_NE(c, d); + ASSERT_NE(d, e); + ASSERT_NE(e, f); + ASSERT_NE(f, g); + ASSERT_NE(g, h); + ASSERT_LT(a, c); + ASSERT_LT(c, d); + ASSERT_LT(d, e); + ASSERT_LT(e, f); + ASSERT_LT(f, g); + ASSERT_LT(g, h); + ASSERT_GT(c, a); + ASSERT_GT(d, c); + ASSERT_GT(e, d); + ASSERT_GT(f, e); + ASSERT_GT(g, f); + ASSERT_GT(h, g); + ASSERT_LE(a, b); + ASSERT_LE(b, c); + ASSERT_LE(c, d); + ASSERT_LE(d, dd); + + std::set> s{100, 2.3f, true, 3.3f, NullOpt, + 0, false, 22, true, NullOpt}; + ASSERT_EQ(s.size(), 8); + + auto iter = s.begin(); + ASSERT_EQ(*(iter++), NullOpt); + ASSERT_EQ(*(iter++), 0); + ASSERT_EQ(*(iter++), 22); + ASSERT_EQ(*(iter++), 100); + ASSERT_EQ(*(iter++), 2.3f); + ASSERT_EQ(*(iter++), 3.3f); + ASSERT_EQ(*(iter++), false); + ASSERT_EQ(*(iter++), true); +} + +TEST(Variant, UniquePtr) { + Variant> a("hello"s), b(std::make_unique(1)); + + ASSERT_EQ(a, "hello"s); + ASSERT_EQ(*b.Get<1>(), 1); + + Variant> c(std::move(a)), d(std::move(b)); + + ASSERT_EQ(c, "hello"s); + ASSERT_EQ(*d.Get<1>(), 1); +}