Skip to content

Commit

Permalink
feat: ak.from_rdataframe should accept a single string 'columns'.
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Dec 6, 2022
1 parent d1ff4f1 commit a2437fa
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
32 changes: 29 additions & 3 deletions src/awkward/operations/ak_from_rdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,17 @@ def from_rdataframe(rdf, columns):
Args:
rdf (`ROOT.RDataFrame`): ROOT RDataFrame to convert into an
Awkward Array.
columns (str or tuple of str): A column or multiple columns to be
columns (str or iterable of str): A column or multiple columns to be
converted to Awkward Array.
Converts ROOT Data Frame columns into an Awkward Array.
Converts ROOT RDataFrame columns into an Awkward Array.
If `columns` is a string, the return value represents a single RDataFrame column.
If `columns` is any other iterable, the return value is a record array, in which
each field corresponds to an RDataFrame column. In particular, if the `columns`
iterable contains only one string, it is still a record array, which has only
one field.
See also #ak.to_rdataframe.
"""
Expand All @@ -24,7 +31,26 @@ def from_rdataframe(rdf, columns):
def _impl(data_frame, columns):
import awkward._connect.rdataframe.from_rdataframe # noqa: F401

return ak._connect.rdataframe.from_rdataframe.from_rdataframe(
if isinstance(columns, str):
columns = (columns,)
project = True
else:
columns = tuple(columns)
project = False

if not all(isinstance(x, str) for x in columns):
raise ak._errors.wrap_error(
TypeError(
f"'columns' must be a string or an iterable of strings, not {columns!r}"
)
)

out = ak._connect.rdataframe.from_rdataframe.from_rdataframe(
data_frame,
columns,
)

if project:
return out[columns[0]]
else:
return out
15 changes: 15 additions & 0 deletions tests/test_1473-from-rdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,18 @@ def test_to_from_data_frame_rvec_of_rvec_of_rvec():
)

assert ak_array_in.to_list() == ak_array_out["x"].to_list()


def test_to_from_data_frame_columns_as_string():
ak_array_in = ak.Array(
[[[[1.1]]], [[[2.2], [3.3], [], [4.4]]], [[[], [5.5, 6.6], []]]]
)

data_frame = ak.to_rdataframe({"x": ak_array_in})

ak_array_out = ak.from_rdataframe(
data_frame,
columns="x",
)

assert ak_array_in.to_list() == ak_array_out.to_list()

0 comments on commit a2437fa

Please sign in to comment.