From 447f90c10b03f1127956603e6eeba2c5cc54925b Mon Sep 17 00:00:00 2001 From: Qi Chen Date: Thu, 13 Jul 2023 14:00:13 +0100 Subject: [PATCH] Fix flaky test_filter_numeric_isnotin_unsigned (#496) The cause was type_arithmetic_promoted_type would return int64 as the common type for uint64 and any signed int. When we then do the static_cast(*ptr++) before calling the Is[Not]InOperator, the uint64 is converted to int64 and the special overloads in the Operators are never used. + Ability to only_test_encoding_version_v1 in a test/class/module --- .../processing/operation_dispatch_binary.hpp | 10 +++- cpp/arcticdb/processing/operation_types.hpp | 57 +++++++++++++++---- python/tests/conftest.py | 12 +++- .../arcticdb/version_store/test_filtering.py | 51 +++++++++++------ 4 files changed, 99 insertions(+), 31 deletions(-) diff --git a/cpp/arcticdb/processing/operation_dispatch_binary.hpp b/cpp/arcticdb/processing/operation_dispatch_binary.hpp index 5accbeee47..1245c3cc0c 100644 --- a/cpp/arcticdb/processing/operation_dispatch_binary.hpp +++ b/cpp/arcticdb/processing/operation_dispatch_binary.hpp @@ -100,8 +100,14 @@ VariantData binary_membership(const ColumnWithStrings& column_with_strings, Valu auto ptr = reinterpret_cast(block.value().data()); const auto row_count = block.value().row_count(); for (auto i = 0u; i < row_count; ++i, ++pos) { - if(func(static_cast(*ptr++), *typed_value_set)) - inserter = pos; + if constexpr (MembershipOperator::needs_uint64_special_handling) { + // Avoid narrowing conversion on *ptr: + if (func(*ptr++, *typed_value_set, UInt64SpecialHandlingTag{})) + inserter = pos; + } else { + if (func(static_cast(*ptr++), *typed_value_set)) + inserter = pos; + } } } inserter.flush(); diff --git a/cpp/arcticdb/processing/operation_types.hpp b/cpp/arcticdb/processing/operation_types.hpp index f18eb655ed..0461b0883f 100644 --- a/cpp/arcticdb/processing/operation_types.hpp +++ b/cpp/arcticdb/processing/operation_types.hpp @@ -56,6 +56,7 @@ struct PlusOperator; struct MinusOperator; struct TimesOperator; struct DivideOperator; +struct MembershipOperator; namespace arithmetic_promoted_type::details { template @@ -179,8 +180,15 @@ struct type_arithmetic_promoted_type { std::conditional_t<(std::is_signed_v && sizeof(LHS) > sizeof(RHS)) || (std::is_signed_v && sizeof(RHS) > sizeof(LHS)), // If the signed type is strictly larger than the unsigned type, then promote to the signed type typename arithmetic_promoted_type::details::signed_width_t, - // Otherwise, promote to a signed type wider than the unsigned type, so that it can be exactly represented - typename arithmetic_promoted_type::details::signed_width_t<2 * max_width> + // Otherwise, check if the unsigned one is the widest type we support + std::conditional_t || std::is_same_v, + // If so, there's no common type that can completely hold both arguments. We trigger operation-specific handling + std::conditional_t, + RHS, // Retains ValueSetBaseType in binary_membership() + int64_t>, // Retain the broken behaviour for Divide for now (https://github.com/man-group/ArcticDB/issues/594) + // There should be a signed type wider than the unsigned type, so both can be exactly represented + typename arithmetic_promoted_type::details::signed_width_t<2 * max_width> + > > > > @@ -373,18 +381,44 @@ bool operator()(int64_t t, uint64_t u) const { } }; -struct IsInOperator { +struct MembershipOperator { +protected: + /** Returns the high bits beyond what can be held in the signed integer type I, including I's sign bit. */ + template + static constexpr uint64_t incomparable_bits() { + return std::numeric_limits::max() - static_cast(std::numeric_limits::max()); + } + + template + static constexpr bool is_signed_int = std::is_integral_v && std::is_signed_v; + +public: + /** This is tighter than the signatures of the special handling operator()s below to reject argument types smaller + * than uint64 going down the special handling via type promotion. */ + template + static constexpr bool needs_uint64_special_handling = + (std::is_same_v && is_signed_int) || + (std::is_same_v && is_signed_int); +}; + +/** Used as a dummy parameter to ensure we don't pick the non-special handling overloads by mistake. */ +struct UInt64SpecialHandlingTag {}; + +struct IsInOperator: MembershipOperator { template bool operator()(T t, const std::unordered_set& u) const { return u.count(t) > 0; } -bool operator()(uint64_t t, const std::unordered_set& u) const { - if (comparison::msb_set(t)) + +template>> +bool operator()(uint64_t t, const std::unordered_set& u, UInt64SpecialHandlingTag = {}) const { + auto incomparable = t & incomparable_bits(); + if (incomparable) return false; else return u.count(t) > 0; } -bool operator()(int64_t t, const std::unordered_set& u) const { +bool operator()(int64_t t, const std::unordered_set& u, UInt64SpecialHandlingTag = {}) const { if (t < 0) return false; else @@ -397,18 +431,21 @@ bool operator()(T t, const emilib::HashSet& u) const { } }; -struct IsNotInOperator { +struct IsNotInOperator: MembershipOperator { template bool operator()(T t, const std::unordered_set& u) const { return u.count(t) == 0; } -bool operator()(uint64_t t, const std::unordered_set& u) const { - if (comparison::msb_set(t)) + +template>> +bool operator()(uint64_t t, const std::unordered_set& u, UInt64SpecialHandlingTag = {}) const { + auto incomparable = t & incomparable_bits(); + if (incomparable) return true; else return u.count(t) == 0; } -bool operator()(int64_t t, const std::unordered_set& u) const { +bool operator()(int64_t t, const std::unordered_set& u, UInt64SpecialHandlingTag = {}) const { if (t < 0) return true; else diff --git a/python/tests/conftest.py b/python/tests/conftest.py index ff2c2c5f22..7a3d6a62ad 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -253,9 +253,15 @@ class EncodingVersion(enum.IntEnum): V2 = 1 -@pytest.fixture(params=list(EncodingVersion)) -def encoding_version(request) -> EncodingVersion: - return request.param +@pytest.fixture(scope="session") +def only_test_encoding_version_v1(): + """Dummy fixture to reference at module/class level to reduce test cases""" + + +def pytest_generate_tests(metafunc): + if "encoding_version" in metafunc.fixturenames: + only_v1 = "only_test_encoding_version_v1" in metafunc.fixturenames + metafunc.parametrize("encoding_version", [EncodingVersion.V1] if only_v1 else list(EncodingVersion)) @pytest.fixture diff --git a/python/tests/unit/arcticdb/version_store/test_filtering.py b/python/tests/unit/arcticdb/version_store/test_filtering.py index 68f1494049..4fa7a43d43 100644 --- a/python/tests/unit/arcticdb/version_store/test_filtering.py +++ b/python/tests/unit/arcticdb/version_store/test_filtering.py @@ -19,13 +19,14 @@ import pandas as pd import pytest from pytz import timezone +from packaging.version import Version import random import string from arcticdb.exceptions import ArcticNativeException from arcticdb.version_store.processing import QueryBuilder from arcticdb_ext.exceptions import InternalException, UserInputException -from arcticdb.util.test import assert_frame_equal, IS_PANDAS_ZERO +from arcticdb.util.test import assert_frame_equal, PANDAS_VERSION from arcticdb.util.hypothesis import ( use_of_function_scoped_fixtures_in_hypothesis_checked, integral_type_strategies, @@ -39,15 +40,20 @@ from arcticdb_ext import set_config_int +@pytest.fixture(scope="module", autouse=True) +def _restrict_to(only_test_encoding_version_v1): + pass + + def generic_filter_test(version_store, symbol, df, arctic_query, pandas_query, dynamic_strings=True): version_store.write(symbol, df, dynamic_strings=dynamic_strings) expected = df.query(pandas_query) received = version_store.read(symbol, query_builder=arctic_query).data if not np.array_equal(expected, received): - print("Original dataframe\n{}".format(df)) - print("Pandas query\n{}".format(pandas_query)) - print("Expected\n{}".format(expected)) - print("Received\n{}".format(received)) + print(f"\nOriginal dataframe:\n{df}\ndtypes:\n{df.dtypes}") + print(f"\nPandas query: {pandas_query}") + print(f"\nPandas returns:\n{expected}") + print(f"\nQueryBuilder returns:\n{received}") assert False assert True @@ -895,12 +901,6 @@ def test_filter_isin_clashing_sets(lmdb_version_store): generic_filter_test(lmdb_version_store, "test_filter_isin_clashing_sets", df, q, pandas_query) -def numeric_isin_asumptions(df, vals): - assume(not df.empty) - # If df values need a uint64 to hold them then we only support unsigned vals - assume(df["a"].between(-(2**63), 2**63 - 1).all() or all(v >= 0 for v in vals)) - - @use_of_function_scoped_fixtures_in_hypothesis_checked @settings(deadline=None) @given( @@ -908,7 +908,7 @@ def numeric_isin_asumptions(df, vals): vals=st.frozensets(signed_integral_type_strategies(), min_size=1), ) def test_filter_numeric_isin_signed(lmdb_version_store, df, vals): - numeric_isin_asumptions(df, vals) + assume(not df.empty) q = QueryBuilder() q = q[q["a"].isin(vals)] pandas_query = "a in {}".format(list(vals)) @@ -922,7 +922,7 @@ def test_filter_numeric_isin_signed(lmdb_version_store, df, vals): vals=st.frozensets(unsigned_integral_type_strategies(), min_size=1), ) def test_filter_numeric_isin_unsigned(lmdb_version_store, df, vals): - numeric_isin_asumptions(df, vals) + assume(not df.empty) q = QueryBuilder() q = q[q["a"].isin(vals)] pandas_query = "a in {}".format(list(vals)) @@ -968,9 +968,9 @@ def test_filter_numeric_isin_unsigned(lmdb_version_store): df=dataframes_with_names_and_dtypes(["a"], integral_type_strategies()), vals=st.frozensets(unsigned_integral_type_strategies(), min_size=1), ) -@pytest.mark.skipif(IS_PANDAS_ZERO, reason="Early Pandas filtering does not handle unsigned well") +@pytest.mark.skipif(PANDAS_VERSION < Version("1.2"), reason="Early Pandas filtering does not handle unsigned well") def test_filter_numeric_isnotin_unsigned(lmdb_version_store, df, vals): - numeric_isin_asumptions(df, vals) + assume(not df.empty) q = QueryBuilder() q = q[q["a"].isnotin(vals)] pandas_query = "a not in {}".format(list(vals)) @@ -984,7 +984,7 @@ def test_filter_numeric_isnotin_unsigned(lmdb_version_store, df, vals): vals=st.frozensets(signed_integral_type_strategies(), min_size=1), ) def test_filter_numeric_isnotin_signed(lmdb_version_store, df, vals): - numeric_isin_asumptions(df, vals) + assume(not df.empty) q = QueryBuilder() q = q[q["a"].isnotin(vals)] pandas_query = "a not in {}".format(list(vals)) @@ -1010,6 +1010,25 @@ def test_filter_numeric_isnotin_hashing_overflow(lmdb_version_store): assert_frame_equal(df, result) +_uint64_max = np.iinfo(np.uint64).max + + +@pytest.mark.parametrize("op", ("in", "not in")) +@pytest.mark.parametrize("signed_type", (np.int8, np.int16, np.int32, np.int64)) +@pytest.mark.parametrize("uint64_in", ("df", "vals") if PANDAS_VERSION >= Version("1.2") else ("vals",)) +def test_filter_numeric_membership_mixing_int64_and_uint64(lmdb_version_store, op, signed_type, uint64_in): + signed = signed_type(-1) + if uint64_in == "df": + df, vals = pd.DataFrame({"a": [_uint64_max]}), [signed] + else: + df, vals = pd.DataFrame({"a": [signed]}), [_uint64_max] + + q = QueryBuilder() + q = q[q["a"].isin(vals) if op == "in" else q["a"].isnotin(vals)] + pandas_query = f"a {op} {vals}" + generic_filter_test(lmdb_version_store, "test_filter_numeric_mixing", df, q, pandas_query) + + @use_of_function_scoped_fixtures_in_hypothesis_checked @settings(deadline=None) @given(df=dataframes_with_names_and_dtypes(["a"], integral_type_strategies()))