diff --git a/_unittests/ut_df/test_pandas_groupbynan.py b/_unittests/ut_df/test_pandas_groupbynan.py index 141568b..63af8fc 100644 --- a/_unittests/ut_df/test_pandas_groupbynan.py +++ b/_unittests/ut_df/test_pandas_groupbynan.py @@ -1,3 +1,4 @@ +# coding: utf-8 """ @brief test log(time=1s) """ @@ -5,7 +6,7 @@ import pandas import numpy from scipy.sparse.linalg import lsqr as sparse_lsqr -from pyquickhelper.pycode import ExtTestCase +from pyquickhelper.pycode import ExtTestCase, ignore_warnings from pandas_streaming.df import pandas_groupby_nan, numpy_types @@ -102,6 +103,40 @@ def test_pandas_groupbynan_regular_nanback(self): lambda: pandas_groupby_nan(df, ["a", "cc"], nanback=True).sum(), NotImplementedError) + def test_pandas_groupbynan_doc(self): + data = [dict(a=2, ind="a", n=1), + dict(a=2, ind="a"), + dict(a=3, ind="b"), + dict(a=30)] + df = pandas.DataFrame(data) + gr2 = pandas_groupby_nan(df, ["ind"]).sum() + ind = list(gr2['ind']) + self.assertTrue(numpy.isnan(ind[-1])) + val = list(gr2['a']) + self.assertEqual(val[-1], 30) + + @ignore_warnings(UserWarning) + def test_pandas_groupbynan_doc2(self): + data = [dict(a=2, ind="a", n=1), + dict(a=2, ind="a"), + dict(a=3, ind="b"), + dict(a=30)] + df = pandas.DataFrame(data) + gr2 = pandas_groupby_nan(df, ["ind", "a"], nanback=False).sum() + ind = list(gr2['ind']) + self.assertEqual(ind[-1], "²nan") + + def test_pandas_groupbynan_doc3(self): + data = [dict(a=2, ind="a", n=1), + dict(a=2, ind="a"), + dict(a=3, ind="b"), + dict(a=30)] + df = pandas.DataFrame(data) + self.assertRaise(lambda: pandas_groupby_nan(df, ["ind", "n"]).sum(), + NotImplementedError) + # ind = list(gr2['ind']) + # self.assertTrue(numpy.isnan(ind[-1])) + if __name__ == "__main__": unittest.main() diff --git a/pandas_streaming/df/dataframe_helpers.py b/pandas_streaming/df/dataframe_helpers.py index 18ead2e..7591fdf 100644 --- a/pandas_streaming/df/dataframe_helpers.py +++ b/pandas_streaming/df/dataframe_helpers.py @@ -289,7 +289,7 @@ def pandas_fillna(df, by, hasna=None, suffix=None): :param suffix: use a prefix for the NaN value :return: list of values chosen for each column, new dataframe (new copy) """ - suffix = suffix if suffix else "²" + suffix = suffix if suffix else "²nan" df = df.copy() rep = {} for c in by: @@ -364,7 +364,10 @@ def pandas_groupby_nan(df, by, axis=0, as_index=False, suffix=None, nanback=True from pandas import DataFrame - data = [dict(a=2, ind="a", n=1), dict(a=2, ind="a"), dict(a=3, ind="b"), dict(a=30)] + data = [dict(a=2, ind="a", n=1), + dict(a=2, ind="a"), + dict(a=3, ind="b"), + dict(a=30)] df = DataFrame(data) print(df) gr = df.groupby(["ind"]).sum() @@ -378,7 +381,10 @@ def pandas_groupby_nan(df, by, axis=0, as_index=False, suffix=None, nanback=True from pandas import DataFrame from pandas_streaming.df import pandas_groupby_nan - data = [dict(a=2, ind="a", n=1), dict(a=2, ind="a"), dict(a=3, ind="b"), dict(a=30)] + data = [dict(a=2, ind="a", n=1), + dict(a=2, ind="a"), + dict(a=3, ind="b"), + dict(a=30)] df = DataFrame(data) gr2 = pandas_groupby_nan(df, ["ind"]).sum() print(gr2) @@ -436,10 +442,22 @@ def pandas_groupby_nan(df, by, axis=0, as_index=False, suffix=None, nanback=True res.grouper.groupings[0].grouping_vector = arr if (hasattr(res.grouper.groupings[0], '_cache') and 'result_index' in res.grouper.groupings[0]._cache): - res.grouper.groupings[0]._cache = {} + index = res.grouper.groupings[0]._cache['result_index'] + if len(rep) == 1: + key = list(rep.values())[0] + new_index = numpy.array(index) + for i in range(0, len(new_index)): # pylint: disable=C0200 + if new_index[i] == key: + new_index[i] = numpy.nan + res.grouper.groupings[0]._cache['result_index'] = ( + index.__class__(new_index)) + else: + raise NotImplementedError( + "NaN values not implemented for multiindex.") else: - raise NotImplementedError("Not implemented for type: {0}".format( - type(res.grouper.groupings[0].grouper))) + raise NotImplementedError( + "Not implemented for type: {0}".format( + type(res.grouper.groupings[0].grouper))) res.grouper._cache['result_index'] = res.grouper.groupings[0]._group_index else: if not nanback: