diff --git a/python/xorbits/_mars/dataframe/base/map_chunk.py b/python/xorbits/_mars/dataframe/base/map_chunk.py index 96fdad3a0..f2eef9a1f 100644 --- a/python/xorbits/_mars/dataframe/base/map_chunk.py +++ b/python/xorbits/_mars/dataframe/base/map_chunk.py @@ -70,11 +70,18 @@ def _set_inputs(self, inputs): self.input = self.inputs[0] def _infer_attrs_by_call(self, df_or_series): - test_obj = ( - build_df(df_or_series, size=2) - if df_or_series.ndim == 2 - else build_series(df_or_series, size=2, name=df_or_series.name) - ) + if len(df_or_series) == 0: + test_obj = ( + build_empty_df(df_or_series.dtypes) + if df_or_series.ndim == 2 + else (build_empty_series(df_or_series.dtype, name=df_or_series.name)) + ) + else: + test_obj = ( + build_df(df_or_series, size=2) + if df_or_series.ndim == 2 + else build_series(df_or_series, size=2, name=df_or_series.name) + ) kwargs = self.kwargs or dict() if self.with_chunk_index: kwargs["chunk_index"] = (0,) * df_or_series.ndim @@ -256,14 +263,6 @@ def execute(cls, ctx, op: "DataFrameMapChunk"): func = cloudpickle.loads(op.func) inp = ctx[op.input.key] out = op.outputs[0] - if len(inp) == 0: - if op.output_types[0] == OutputType.dataframe: - ctx[out.key] = build_empty_df(out.dtypes) - elif op.output_types[0] == OutputType.series: - ctx[out.key] = build_empty_series(out.dtype, name=out.name) - else: - raise ValueError(f"Chunk can not be empty except for dataframe/series.") - return kwargs = op.kwargs or dict() if op.with_chunk_index: diff --git a/python/xorbits/_mars/dataframe/base/tests/test_base_execution.py b/python/xorbits/_mars/dataframe/base/tests/test_base_execution.py index 529812ea5..d9f594bce 100644 --- a/python/xorbits/_mars/dataframe/base/tests/test_base_execution.py +++ b/python/xorbits/_mars/dataframe/base/tests/test_base_execution.py @@ -3232,3 +3232,48 @@ def test_copy_deep(setup, chunk_size): expected_c["a1"] = expected_c["a1"] + 0.8 pd.testing.assert_frame_equal(xdf_c.execute().fetch(), expected_c) pd.testing.assert_frame_equal(xdf.execute().fetch(), expected) + + +def test_map_chunk_with_empty_input(setup): + df = pd.DataFrame(columns=["a", "b", "c"]) + series = pd.Series(name="hello") + mdf = from_pandas_df(df) + ms = from_pandas_series(series) + + # df to df + def p(d): + if not len(d): + return pd.DataFrame([[None] * d.shape[1]], columns=d.columns) + else: + return d + + res = mdf.map_chunk(p) + expected = pd.DataFrame([[None] * df.shape[1]], columns=df.columns) + pd.testing.assert_frame_equal(res.execute().fetch(), expected) + + # series to series + def x1(d): + if not len(d): + return pd.Series([1], name=d.name) + else: + return d + + res = ms.map_chunk(x1) + expected = pd.Series([1], name=series.name) + pd.testing.assert_series_equal(res.execute().fetch(), expected) + + # series to df + def x2(d): + return pd.DataFrame({d.name: [np.nan, 1, 2]}) + + res = ms.map_chunk(x2) + expected = pd.DataFrame({series.name: [np.nan, 1, 2]}) + pd.testing.assert_frame_equal(res.execute().fetch(), expected) + + # df to series + def x3(d): + return pd.Series(list(d.columns), name=d.columns[1]) + + res = mdf.map_chunk(x3) + expected = pd.Series(list(df.columns), name=df.columns[1]) + pd.testing.assert_series_equal(res.execute().fetch(), expected)