diff --git a/ortools/algorithms/BUILD.bazel b/ortools/algorithms/BUILD.bazel index e6da9cbbae..72eac04dc2 100644 --- a/ortools/algorithms/BUILD.bazel +++ b/ortools/algorithms/BUILD.bazel @@ -534,3 +534,39 @@ cc_test( "//ortools/base:gmock_main", ], ) + +cc_library( + name = "n_choose_k", + srcs = ["n_choose_k.cc"], + hdrs = ["n_choose_k.h"], + deps = [ + ":binary_search", + "//ortools/base:mathutil", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + ], +) + +cc_test( + name = "n_choose_k_test", + srcs = ["n_choose_k_test.cc"], + deps = [ + ":n_choose_k", + "//ortools/base:dump_vars", + "//ortools/base:fuzztest", + "//ortools/base:gmock_main", + "//ortools/base:mathutil", + "//ortools/util:flat_matrix", + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/random", + "@com_google_absl//absl/random:distributions", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_benchmark//:benchmark", + ], +) diff --git a/ortools/algorithms/n_choose_k.cc b/ortools/algorithms/n_choose_k.cc new file mode 100644 index 0000000000..8799a43c9f --- /dev/null +++ b/ortools/algorithms/n_choose_k.cc @@ -0,0 +1,169 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/algorithms/n_choose_k.h" + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/numeric/int128.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "ortools/algorithms/binary_search.h" +#include "ortools/base/logging.h" +#include "ortools/base/mathutil.h" + +namespace operations_research { +namespace { +// This is the actual computation. It's in O(k). +template +Int InternalChoose(Int n, Int k) { + DCHECK_LE(k, n - k); + DCHECK_GT(k, 0); // Having k>0 lets us start with i=2 (small optimization). + // We compute n * (n-1) * ... * (n-k+1) / k! in the best possible order to + // guarantee exact results, while trying to avoid overflows. It's not + // perfect: we finish with a division by k, which means that me may overflow + // even if the result doesn't (by a factor of up to k). + Int result = n; + for (Int i = 2; i <= k; ++i) { + result *= n + 1 - i; + result /= i; // The product of i consecutive numbers is divisible by i!. + } + return result; +} + +// This function precomputes the maximum N such that (N choose K) doesn't +// overflow, for all K. +// When `overflows_intermediate_computation` is true, "overflow" means +// "some overflow happens inside InternalChoose()", and when it's false +// it simply means "the result doesn't fit in an int64_t". +// This is only used in contexts where K ≤ N-K, which implies N ≥ 2K, thus we +// can stop when (2K Choose K) overflows, because at and beyond such K, +// (N Choose K) will always overflow. In practice that happens for K=31 or 34 +// depending on `overflows_intermediate_computation`. +template +std::vector LastNThatDoesNotOverflowForAllK( + bool overflows_intermediate_computation) { + absl::Time start_time = absl::Now(); + // Given the algorithm used in InternalChoose(), it's not hard to + // find out when (N choose K) overflows an int64_t during its internal + // computation: that's when (N choose K) > MAX_INT / k. + + // For K ≤ 2, we hardcode the values of the maximum N. That's because + // the binary search done below uses MathUtil::LogCombinations, which only + // works on int32_t, and that's problematic for the max N we get for K=2. + // + // For K=2, we want N(N-1) ≤ 2^num_digits, or N(N-1)/2 ≤ 2^num_digits if + // !overflows_intermediate_computation, i.e. N(N-1) ≤ 2^(num_digits+1). + // Then, when d is even, N(N-1) ≤ 2^d ⇔ N ≤ 2^(d/2), which is simple. + // When d is odd, it's harder: N(N-1)≈(N-0.5)² and thus we get the bound + // N ≤ pow(2.0, d/2)+0.5. + const int bound_digits = std::numeric_limits::digits + + (overflows_intermediate_computation ? 0 : 1); + std::vector result = { + std::numeric_limits::max(), // K=0 + std::numeric_limits::max(), // K=1 + bound_digits % 2 == 0 + ? Int{1} << (bound_digits / 2) + : static_cast( + 0.5 + std::pow(2.0, 0.5 * std::numeric_limits::digits)), + }; + // We find the last N with binary search, for all K. We stop growing K + // when (2*K Choose K) overflows. + for (Int k = 3;; ++k) { + const double max_log_comb = + overflows_intermediate_computation + ? std::numeric_limits::digits * std::log(2) - std::log(k) + : std::numeric_limits::digits * std::log(2); + result.push_back(BinarySearch( + /*x_true*/ k, + // x_false=X, X needs to be large enough so that X choose 3 overflows: + // (X choose 3)≈(X-1)³/6, so we pick X = 2+6*2^(num_digits/3+1). + /*x_false=*/ + (static_cast( + 2 + 6 * std::pow(2.0, std::numeric_limits::digits / 3 + 1))), + [k, max_log_comb](Int n) { + return MathUtil::LogCombinations(n, k) <= max_log_comb; + })); + if (result.back() < 2 * k) { + result.pop_back(); + break; + } + } + // Some DCHECKs for int64_t, which should validate the general formulaes. + if constexpr (std::numeric_limits::digits == 63) { + DCHECK_EQ(result.size(), + overflows_intermediate_computation + ? 31 // 60 Choose 30 < 2^63/30 but 62 Choose 31 > 2^63/31. + : 34); // 66 Choose 33 < 2^63 but 68 Choose 34 > 2^63. + } + VLOG(1) << "LastNThatDoesNotOverflowForAllK(): " << absl::Now() - start_time; + return result; +} + +template +bool NChooseKIntermediateComputationOverflowsInt(Int n, Int k) { + DCHECK_LE(k, n - k); + static const auto* const result = + new std::vector(LastNThatDoesNotOverflowForAllK( + /*overflows_intermediate_computation=*/true)); + return k < result->size() ? n > (*result)[k] : true; +} + +template +bool NChooseKResultOverflowsInt(Int n, Int k) { + DCHECK_LE(k, n - k); + static const auto* const result = + new std::vector(LastNThatDoesNotOverflowForAllK( + /*overflows_intermediate_computation=*/false)); + return k < result->size() ? n > (*result)[k] : true; +} +} // namespace + +// NOTE(user): If performance ever matters, we could simply precompute and +// store all (N choose K) that don't overflow, there aren't that many of them: +// only a few tens of thousands, after removing simple cases like k ≤ 5. +absl::StatusOr NChooseK(int64_t n, int64_t k) { + if (n < 0) { + return absl::InvalidArgumentError(absl::StrFormat("n is negative (%d)", n)); + } + if (k < 0) { + return absl::InvalidArgumentError(absl::StrFormat("k is negative (%d)", k)); + } + if (k > n) { + return absl::InvalidArgumentError( + absl::StrFormat("k=%d is greater than n=%d", k, n)); + } + if (k > n / 2) k = n - k; + if (k == 0) return 1; + if (n < std::numeric_limits::max() && + !NChooseKIntermediateComputationOverflowsInt(n, k)) { + return static_cast(InternalChoose(n, k)); + } + if (!NChooseKIntermediateComputationOverflowsInt(n, k)) { + return InternalChoose(n, k); + } + if (NChooseKResultOverflowsInt(n, k)) { + return absl::InvalidArgumentError( + absl::StrFormat("(%d choose %d) overflows int64", n, k)); + } + return static_cast(InternalChoose(n, k)); +} + +} // namespace operations_research diff --git a/ortools/algorithms/n_choose_k.h b/ortools/algorithms/n_choose_k.h new file mode 100644 index 0000000000..937faffdce --- /dev/null +++ b/ortools/algorithms/n_choose_k.h @@ -0,0 +1,34 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_ALGORITHMS_N_CHOOSE_K_H_ +#define OR_TOOLS_ALGORITHMS_N_CHOOSE_K_H_ + +#include + +#include "absl/status/statusor.h" + +namespace operations_research { +// Returns the number of ways to choose k elements among n, ignoring the order, +// i.e., the binomial coefficient (n, k). +// This is like std::exp(MathUtil::LogCombinations(n, k)), but faster, with +// perfect accuracy, and returning an error iff the result would overflow an +// int64_t or if an argument is invalid (i.e., n < 0, k < 0, or k > n). +// +// NOTE(user): If you need a variation of this, ask the authors: it's very easy +// to add. E.g., other int types, other behaviors (e.g., return 0 if k > n, or +// std::numeric_limits::max() on overflow, etc). +absl::StatusOr NChooseK(int64_t n, int64_t k); +} // namespace operations_research + +#endif // OR_TOOLS_ALGORITHMS_N_CHOOSE_K_H_ diff --git a/ortools/algorithms/n_choose_k_test.cc b/ortools/algorithms/n_choose_k_test.cc new file mode 100644 index 0000000000..693b2949f6 --- /dev/null +++ b/ortools/algorithms/n_choose_k_test.cc @@ -0,0 +1,323 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/algorithms/n_choose_k.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/numeric/int128.h" +#include "absl/random/distributions.h" +#include "absl/random/random.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "benchmark/benchmark.h" +#include "gtest/gtest.h" +#include "ortools/base/dump_vars.h" +//#include "ortools/base/fuzztest.h" +#include "ortools/base/gmock.h" +#include "ortools/base/mathutil.h" +#include "ortools/util/flat_matrix.h" + +namespace operations_research { +namespace { +//using ::fuzztest::NonNegative; +using ::testing::HasSubstr; +using ::testing::status::IsOkAndHolds; +using ::testing::status::StatusIs; + +constexpr int64_t kint64max = std::numeric_limits::max(); + +TEST(NChooseKTest, TrivialErrorCases) { + absl::BitGen random; + constexpr int kNumTests = 100'000; + for (int t = 0; t < kNumTests; ++t) { + const int64_t x = absl::LogUniform(random, 0, kint64max); + EXPECT_THAT(NChooseK(-1, x), StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("n is negative"))); + EXPECT_THAT(NChooseK(x, -1), StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("k is negative"))); + if (x != kint64max) { + EXPECT_THAT(NChooseK(x, x + 1), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("greater than n"))); + } + ASSERT_FALSE(HasFailure()) << DUMP_VARS(t, x); + } +} + +TEST(NChooseKTest, Symmetry) { + absl::BitGen random; + constexpr int kNumTests = 1'000'000; + for (int t = 0; t < kNumTests; ++t) { + const int64_t n = absl::LogUniform(random, 0, kint64max); + const int64_t k = absl::LogUniform(random, 0, n); + const absl::StatusOr result1 = NChooseK(n, k); + const absl::StatusOr result2 = NChooseK(n, n - k); + if (result1.ok()) { + ASSERT_THAT(result2, IsOkAndHolds(result1.value())) << DUMP_VARS(t, n, k); + } else { + ASSERT_EQ(result2.status().code(), result1.status().code()) + << DUMP_VARS(t, n, k, result1, result2); + } + } +} + +TEST(NChooseKTest, Invariant) { + absl::BitGen random; + constexpr int kNumTests = 1'000'000; + int num_tested_invariants = 0; + for (int t = 0; t < kNumTests; ++t) { + const int64_t n = absl::LogUniform(random, 2, 100); + const int64_t k = absl::LogUniform(random, 1, n - 1); + const absl::StatusOr n_k = NChooseK(n, k); + const absl::StatusOr nm1_k = NChooseK(n - 1, k); + const absl::StatusOr nm1_km1 = NChooseK(n - 1, k - 1); + if (n_k.ok()) { + ++num_tested_invariants; + ASSERT_OK(nm1_k); + ASSERT_OK(nm1_km1); + ASSERT_EQ(n_k.value(), nm1_k.value() + nm1_km1.value()) + << DUMP_VARS(t, n, k, n_k, nm1_k, nm1_km1); + } + } + EXPECT_GE(num_tested_invariants, kNumTests / 10); +} + +TEST(NChooseKTest, ComparisonAgainstClosedFormsForK0) { + for (int64_t n : {int64_t{0}, int64_t{1}, kint64max}) { + EXPECT_THAT(NChooseK(n, 0), IsOkAndHolds(1)) << n; + } + absl::BitGen random; + constexpr int kNumTests = 1'000'000; + for (int t = 0; t < kNumTests; ++t) { + const int64_t n = absl::LogUniform(random, 0, kint64max); + ASSERT_THAT(NChooseK(n, 0), IsOkAndHolds(1)) << DUMP_VARS(n, t); + } +} + +TEST(NChooseKTest, ComparisonAgainstClosedFormsForK1) { + for (int64_t n : {int64_t{1}, kint64max}) { + EXPECT_THAT(NChooseK(n, 1), IsOkAndHolds(n)); + } + absl::BitGen random; + constexpr int kNumTests = 1'000'000; + for (int t = 0; t < kNumTests; ++t) { + const int64_t n = absl::LogUniform(random, 1, kint64max); + ASSERT_THAT(NChooseK(n, 1), IsOkAndHolds(n)) << DUMP_VARS(t); + } +} + +TEST(NChooseKTest, ComparisonAgainstClosedFormsForK2) { + // 2^32 Choose 2 = 2^32 × (2^32-1) / 2 = 2^63 - 2^31 < kint64max, + // but (2^32+1) Choose 2 = 2^63 + 2^31 overflows. + constexpr int64_t max_n = int64_t{1} << 32; + for (int64_t n : {int64_t{2}, max_n}) { + const int64_t n_choose_2 = + static_cast(absl::uint128(n) * (n - 1) / 2); + EXPECT_THAT(NChooseK(n, 2), IsOkAndHolds(n_choose_2)) << DUMP_VARS(n); + } + EXPECT_THAT(NChooseK(max_n + 1, 2), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("overflows int64"))); + + absl::BitGen random; + constexpr int kNumTests = 100'000; + // Random valid results. + for (int t = 0; t < kNumTests; ++t) { + const int64_t n = absl::LogUniform(random, 2, max_n); + const int64_t n_choose_2 = + static_cast(absl::uint128(n) * (n - 1) / 2); + ASSERT_THAT(NChooseK(n, 2), IsOkAndHolds(n_choose_2)) << DUMP_VARS(t, n); + } + // Random overflows. + for (int t = 0; t < kNumTests; ++t) { + const int64_t n = absl::LogUniform(random, max_n + 1, kint64max); + ASSERT_THAT(NChooseK(n, 2), StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("overflows int64"))) + << DUMP_VARS(t, n); + } +} + +TEST(NChooseKTest, ComparisonAgainstClosedFormsForK3) { + // This is 1 + ∛6×2^21. Checked manually on Google's scientific calculator. + const int64_t max_n = + static_cast(1 + std::pow(6, 1.0 / 3) * std::pow(2, 21)); + for (int64_t n : {int64_t{3}, max_n}) { + const int64_t n_choose_3 = + static_cast(absl::uint128(n) * (n - 1) * (n - 2) / 6); + EXPECT_THAT(NChooseK(n, 3), IsOkAndHolds(n_choose_3)) << DUMP_VARS(n); + } + EXPECT_THAT(NChooseK(max_n + 1, 3), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("overflows int64"))); + + absl::BitGen random; + constexpr int kNumTests = 100'000; + // Random valid results. + for (int t = 0; t < kNumTests; ++t) { + const int64_t n = absl::LogUniform(random, 3, max_n); + const int64_t n_choose_3 = + static_cast(absl::uint128(n) * (n - 1) * (n - 2) / 6); + ASSERT_THAT(NChooseK(n, 3), IsOkAndHolds(n_choose_3)) << DUMP_VARS(t, n); + } + // Random overflows. + for (int t = 0; t < kNumTests; ++t) { + const int64_t n = absl::LogUniform(random, max_n + 1, kint64max); + ASSERT_THAT(NChooseK(n, 3), StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("overflows int64"))) + << DUMP_VARS(t, n); + } +} + +TEST(NChooseKTest, ComparisonAgainstClosedFormsForK4) { + // This is 1.5 + ∜24 × 2^(63/4). + // Checked manually on Google's scientific calculator. + const int64_t max_n = + static_cast(1.5 + std::pow(24, 1.0 / 4) * std::pow(2, 63.0 / 4)); + for (int64_t n : {int64_t{4}, max_n}) { + const int64_t n_choose_4 = static_cast(absl::uint128(n) * (n - 1) * + (n - 2) * (n - 3) / 24); + EXPECT_THAT(NChooseK(n, 4), IsOkAndHolds(n_choose_4)) << DUMP_VARS(n); + } + EXPECT_THAT(NChooseK(max_n + 1, 4), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("overflows int64"))); + + absl::BitGen random; + constexpr int kNumTests = 100'000; + // Random valid results. + for (int t = 0; t < kNumTests; ++t) { + const int64_t n = absl::LogUniform(random, 4, max_n); + const int64_t n_choose_4 = static_cast(absl::uint128(n) * (n - 1) * + (n - 2) * (n - 3) / 24); + ASSERT_THAT(NChooseK(n, 4), IsOkAndHolds(n_choose_4)) << DUMP_VARS(t, n); + } + // Random overflows. + for (int t = 0; t < kNumTests; ++t) { + const int64_t n = absl::LogUniform(random, max_n + 1, kint64max); + ASSERT_THAT(NChooseK(n, 4), StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("overflows int64"))) + << DUMP_VARS(t, n); + } +} + +TEST(NChooseKTest, ComparisonAgainstPascalTriangleForK5OrAbove) { + // Fill the Pascal triangle. Use -1 for int64_t overflows. We go up to n = + // 17000 because (17000 Choose 5) ≈ 1.2e19 which overflows an int64_t. + constexpr int max_n = 17000; + FlatMatrix triangle(max_n + 1, max_n + 1); + for (int n = 0; n <= max_n; ++n) { + triangle[n][0] = 1; + triangle[n][n] = 1; + for (int i = 1; i < n; ++i) { + const int64_t a = triangle[n - 1][i - 1]; + const int64_t b = triangle[n - 1][i]; + if (a < 0 || b < 0 || absl::int128(a) + b > kint64max) { + triangle[n][i] = -1; + } else { + triangle[n][i] = a + b; + } + } + } + // Checking all 17000²/2 slots would be too expensive, so we check each + // "column" downwards until the first 10 overflows, and stop. + for (int k = 5; k < max_n; ++k) { + int num_overflows = 0; + for (int n = k + 5; n < max_n; ++n) { + if (num_overflows > 0) EXPECT_EQ(triangle[n][k], -1); + if (triangle[n][k] < 0) { + ++num_overflows; + EXPECT_THAT(NChooseK(n, k), StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("overflows int64"))); + if (num_overflows > 10) break; + } else { + EXPECT_THAT(NChooseK(n, k), IsOkAndHolds(triangle[n][k])); + } + } + } +} + +void MatchesLogCombinations(int n, int k) { + if (n < k) { + std::swap(k, n); + } + const auto exact = NChooseK(n, k); + const double log_approx = MathUtil::LogCombinations(n, k); + if (exact.ok()) { + // We accepted to compute the exact value, make sure that it matches the + // approximation. + ASSERT_NEAR(log(exact.value()), log_approx, 0.0001); + } else { + // We declined to compute the exact value, make sure that we had a good + // reason to, i.e. that the result did indeed overflow. + ASSERT_THAT(exact, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("overflows int64"))); + const double approx = exp(log_approx); + ASSERT_GE(std::nextafter(approx, std::numeric_limits::infinity()), + std::numeric_limits::max()) + << "we declined to compute the exact value of NChooseK(" << n << ", " + << k << "), but the log value is " << log_approx + << " (value: " << approx << "), which fits in int64_t"; + } +} +/* +FUZZ_TEST(NChooseKTest, MatchesLogCombinations) + // Ideally we'd test with `uint64_t`, but `LogCombinations` only accepts + // `int`. + .WithDomains(NonNegative(), NonNegative()); +*/ +template +void BM_NChooseK(benchmark::State& state) { + static constexpr int kNumInputs = 1000; + // Use deterministic random numbers to avoid noise. + std::mt19937 gen(42); + std::uniform_int_distribution random(0, kMaxN); + std::vector> inputs; + inputs.reserve(kNumInputs); + for (int i = 0; i < kNumInputs; ++i) { + int64_t n = random(gen); + int64_t k = random(gen); + if (n < k) { + std::swap(n, k); + } + inputs.emplace_back(n, k); + } + // Force the one-time, costly static initializations of NChooseK() to happen + // before the benchmark starts. + auto result = NChooseK(62, 31); + benchmark::DoNotOptimize(result); + + // Start the benchmark. + for (auto _ : state) { + for (const auto [n, k] : inputs) { + auto result = algo(n, k); + benchmark::DoNotOptimize(result); + } + } + state.SetItemsProcessed(state.iterations() * kNumInputs); +} +BENCHMARK(BM_NChooseK<30, operations_research::NChooseK>); // int32_t domain. +BENCHMARK( + BM_NChooseK<60, operations_research::NChooseK>); // int{32,64} domain. +BENCHMARK( + BM_NChooseK<100, operations_research::NChooseK>); // int{32,64,128} domain. +BENCHMARK( + BM_NChooseK<100, MathUtil::LogCombinations>); // int{32,64,128} domain. + +} // namespace +} // namespace operations_research