From 48758fb5bffa093c1207868756670e51fa78d6c0 Mon Sep 17 00:00:00 2001 From: Jan Schulz Date: Sun, 12 Apr 2015 23:37:47 +0200 Subject: [PATCH] Fix for comparisons of categorical and a scalar not in categories Up to now, a comparison of categorical data and a scalar that is not in the categories would return `False` for all elements. Ordering comparisons (e.g. `<`, `>`) should raise a `TypeError` in that case, which they now do. Also fix that `!=` comparisons would return `False` for all elements when the more logical choice would be `True`. --- doc/source/whatsnew/v0.16.1.txt | 2 ++ pandas/core/categorical.py | 9 ++++++++- pandas/tests/test_categorical.py | 27 +++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v0.16.1.txt b/doc/source/whatsnew/v0.16.1.txt index dc5e3ddcefc06..45c4c9b00c97c 100644 --- a/doc/source/whatsnew/v0.16.1.txt +++ b/doc/source/whatsnew/v0.16.1.txt @@ -123,3 +123,5 @@ Bug Fixes - Bug in which ``SparseDataFrame`` could not take `nan` as a column name (:issue:`8822`) - Bug in unequal comparisons between a ``Series`` of dtype `"category"` and a scalar (e.g. ``Series(Categorical(list("abc"), categories=list("cba"), ordered=True)) > "b"``, which wouldn't use the order of the categories but use the lexicographical order. (:issue:`9848`) + +- Bug in unequal comparisons between categorical data and a scalar that was not in the categories (e.g. ``Series(Categorical(list("abc"), ordered=True)) > "d"``). This returned ``False`` for all elements, but now raises a ``TypeError``. Equality comparisons also now return ``False`` for ``==`` and ``True`` for ``!=``. 
(:issue:`9848`) diff --git a/pandas/core/categorical.py b/pandas/core/categorical.py index 991678a8e7d79..b79f2c9b4f6df 100644 --- a/pandas/core/categorical.py +++ b/pandas/core/categorical.py @@ -61,7 +61,14 @@ def f(self, other): i = self.categories.get_loc(other) return getattr(self._codes, op)(i) else: - return np.repeat(False, len(self)) + if op == '__eq__': + return np.repeat(False, len(self)) + elif op == '__ne__': + return np.repeat(True, len(self)) + else: + msg = "Cannot compare a Categorical for op {op} with a scalar, " \ + "which is not a category." + raise TypeError(msg.format(op=op)) else: # allow categorical vs object dtype array comparisons for equality diff --git a/pandas/tests/test_categorical.py b/pandas/tests/test_categorical.py index 4c5678bf6633f..af48774492b11 100644 --- a/pandas/tests/test_categorical.py +++ b/pandas/tests/test_categorical.py @@ -1087,6 +1087,20 @@ def test_reflected_comparison_with_scalars(self): self.assert_numpy_array_equal(cat > cat[0], [False, True, True]) self.assert_numpy_array_equal(cat[0] < cat, [False, True, True]) + def test_comparison_with_unknown_scalars(self): + # https://github.com/pydata/pandas/issues/9836#issuecomment-92123057 and following + # comparisons with scalars not in categories should raise for unequal comps, but not for + # equal/not equal + cat = pd.Categorical([1, 2, 3], ordered=True) + + self.assertRaises(TypeError, lambda: cat < 4) + self.assertRaises(TypeError, lambda: cat > 4) + self.assertRaises(TypeError, lambda: 4 < cat) + self.assertRaises(TypeError, lambda: 4 > cat) + + self.assert_numpy_array_equal(cat == 4 , [False, False, False]) + self.assert_numpy_array_equal(cat != 4 , [True, True, True]) + class TestCategoricalAsBlock(tm.TestCase): _multiprocess_can_split_ = True @@ -2440,6 +2454,19 @@ def f(): cat > "b" self.assertRaises(TypeError, f) + # https://github.com/pydata/pandas/issues/9836#issuecomment-92123057 and following + # comparisons with scalars not in categories should raise 
for unequal comps, but not for + # equal/not equal + cat = Series(Categorical(list("abc"), ordered=True)) + + self.assertRaises(TypeError, lambda: cat < "d") + self.assertRaises(TypeError, lambda: cat > "d") + self.assertRaises(TypeError, lambda: "d" < cat) + self.assertRaises(TypeError, lambda: "d" > cat) + + self.assert_series_equal(cat == "d" , Series([False, False, False])) + self.assert_series_equal(cat != "d" , Series([True, True, True])) + # And test NaN handling... cat = Series(Categorical(["a","b","c", np.nan]))