@@ -2212,11 +2212,63 @@ def f():
22122212 tm .assert_series_equal (res , exp )
22132213
22142214 # And test NaN handling...
2215- cat = pd . Series (pd . Categorical (["a" ,"b" ,"c" , np .nan ]))
2215+ cat = Series (Categorical (["a" ,"b" ,"c" , np .nan ]))
22162216 exp = Series ([True , True , True , False ])
22172217 res = (cat == cat )
22182218 tm .assert_series_equal (res , exp )
22192219
2220+ def test_cat_equality (self ):
2221+
2222+ # GH 8938
2223+ # allow equality comparisons
2224+ a = Series (list ('abc' ),dtype = "category" )
2225+ b = Series (list ('abc' ),dtype = "object" )
2226+ c = Series (['a' ,'b' ,'cc' ],dtype = "object" )
2227+ d = Series (list ('acb' ),dtype = "object" )
2228+ e = Categorical (list ('abc' ))
2229+ f = Categorical (list ('acb' ))
2230+
2231+ # vs scalar
2232+ self .assertFalse ((a == 'a' ).all ())
2233+ self .assertTrue (((a != 'a' ) == ~ (a == 'a' )).all ())
2234+
2235+ self .assertFalse (('a' == a ).all ())
2236+ self .assertTrue ((a == 'a' )[0 ])
2237+ self .assertTrue (('a' == a )[0 ])
2238+ self .assertFalse (('a' != a )[0 ])
2239+
2240+ # vs list-like
2241+ self .assertTrue ((a == a ).all ())
2242+ self .assertFalse ((a != a ).all ())
2243+
2244+ self .assertTrue ((a == list (a )).all ())
2245+ self .assertTrue ((a == b ).all ())
2246+ self .assertTrue ((b == a ).all ())
2247+ self .assertTrue (((~ (a == b ))== (a != b )).all ())
2248+ self .assertTrue (((~ (b == a ))== (b != a )).all ())
2249+
2250+ self .assertFalse ((a == c ).all ())
2251+ self .assertFalse ((c == a ).all ())
2252+ self .assertFalse ((a == d ).all ())
2253+ self .assertFalse ((d == a ).all ())
2254+
2255+ # vs a cat-like
2256+ self .assertTrue ((a == e ).all ())
2257+ self .assertTrue ((e == a ).all ())
2258+ self .assertFalse ((a == f ).all ())
2259+ self .assertFalse ((f == a ).all ())
2260+
2261+ self .assertTrue (((~ (a == e )== (a != e )).all ()))
2262+ self .assertTrue (((~ (e == a )== (e != a )).all ()))
2263+ self .assertTrue (((~ (a == f )== (a != f )).all ()))
2264+ self .assertTrue (((~ (f == a )== (f != a )).all ()))
2265+
2266+ # non-equality is not comparable
2267+ self .assertRaises (TypeError , lambda : a < b )
2268+ self .assertRaises (TypeError , lambda : b < a )
2269+ self .assertRaises (TypeError , lambda : a > b )
2270+ self .assertRaises (TypeError , lambda : b > a )
2271+
22202272 def test_concat (self ):
22212273 cat = pd .Categorical (["a" ,"b" ], categories = ["a" ,"b" ])
22222274 vals = [1 ,2 ]
0 commit comments