Skip to content

Commit

Permalink
add a root file with rdf test
Browse files Browse the repository at this point in the history
  • Loading branch information
ianna committed Mar 15, 2022
1 parent d62d989 commit 5f37041
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 41 deletions.
56 changes: 46 additions & 10 deletions src/awkward/_v2/_connect/rdataframe/from_rdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,28 @@ def connect_ArrayBuilder(compiler, builder):
return f_cache


def count_dimentions(column_type):
rvec_count = column_type.count("RVec<")
v_count = column_type.count("vector<")
rvec_count = 0 if rvec_count == -1 else rvec_count
v_count = 0 if v_count == -1 else v_count
return rvec_count + v_count


def type_of_nested_data(column_type):
t1 = column_type.rfind("<")
t2 = column_type.find(">")
return column_type if t1 == -1 and t2 == -1 else column_type[t1 + 1 : t2]


def func_real(builder, cpp_reference, n):
for x in cpp_reference:
builder.real(x) if n <= 0 else func_real(builder, x, n - 1)


def to_awkward_array(
data_frame, compiler, columns=None, exclude=None, columns_as_records=True
):
import ROOT

# Find all column names in the dataframe
if not columns:
columns = [str(x) for x in data_frame.GetColumnNames()]
Expand All @@ -261,22 +278,41 @@ def to_awkward_array(
columns = [x for x in columns if x not in exclude]

builder = ak.ArrayBuilder()
b = ROOT.awkward.ArrayBuilderShim(ctypes.cast(builder._layout._ptr, ctypes.c_voidp))

if columns_as_records:
result_ptrs = {}
b.beginlist()
for col in columns:
b.beginrecord_check(col)
builder.begin_record(col)
builder.field(col)

column_type = data_frame.GetColumnType(col)
result_ptrs[col] = data_frame.Take[column_type](col)
cpp_reference = result_ptrs[col].GetValue()
if column_type == "double":

# check if the column type is either ROOT::VecOps::RVec or
# std::vector, possibly nested
result = count_dimentions(column_type)

for _ in range(result):
builder.begin_list()

type = type_of_nested_data(column_type)
if type == "double":

# FIXME: one thread, sequential???
# data_frame.Foreach["std::function<uint8_t(double)>"](b.real, [col])
for x in cpp_reference:
b.real(x)
b.endrecord()
b.endlist()

if result > 0:
func_real(builder, cpp_reference, result)
else:
builder.begin_list()
for x in cpp_reference:
builder.real(x)
builder.end_list()

for _ in range(result):
builder.end_list()

builder.end_record()

return builder.snapshot()
1 change: 0 additions & 1 deletion src/awkward/_v2/_connect/rdataframe/to_rdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def class_type(self, key):
def generate(self, compiler, array, name=None, use_cached=True):
layout = array.layout
generator = ak._v2._connect.cling.togenerator(layout.form)
# lookup = ak._v2._lookup.Lookup(layout)

generator.generate(compiler)
key = generator.entry_type()
Expand Down
2 changes: 1 addition & 1 deletion tests/v2/test_1295-awkward-to-rdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_nested_array_1():
rdf = rdf_source.generate(
rdf_source, compiler=compiler, array=array, name="c++func_name"
)
rdf.Display().Print()
rdf.Count().Print()

# both jitted
### rdf = ak._v2.to_rdataframe({"col1": ak_array_1, "col2": ak_array_2})
Expand Down
146 changes: 117 additions & 29 deletions tests/v2/test_1295-rdataframe-to-awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,45 +232,133 @@ def test_as_awkward():
assert (
str(array.layout.form)
== """{
"class": "ListOffsetArray64",
"offsets": "i64",
"content": {
"class": "UnionArray8_64",
"tags": "i8",
"index": "i64",
"contents": [
{
"class": "RecordArray",
"contents": {},
"parameters": {
"__record__": "x"
"class": "UnionArray8_64",
"tags": "i8",
"index": "i64",
"contents": [
{
"class": "RecordArray",
"contents": {
"x": {
"class": "ListOffsetArray64",
"offsets": "i64",
"content": "float64"
}
},
{
"class": "RecordArray",
"contents": {},
"parameters": {
"__record__": "xx"
"parameters": {
"__record__": "x"
}
},
{
"class": "RecordArray",
"contents": {
"xx": {
"class": "ListOffsetArray64",
"offsets": "i64",
"content": "float64"
}
},
"parameters": {
"__record__": "xx"
}
]
}
}
]
}"""
)

@pytest.mark.skip(reason="FIXME: test root file is not in git yet")
def test_rvec_snapshot():
treeName = "t"
fileName = "tests/samples/snapshot_nestedrvecs.root" #"snapshot_nestedrvecs.root"

@pytest.mark.skip(reason="FIXME: NotImplementedError")
def test_highlevel():
rdf = ROOT.RDataFrame(10).Define("x", "gRandom->Rndm()")
array = ak._v2.from_rdataframe(
rdf, "builder_name", function_for_foreach="my_function"
)
ROOT.gInterpreter.ProcessLine("""
#include <ROOT/RVec.hxx>
struct TwoInts {
int a, b;
};
#pragma link C++ class TwoInts+;
#pragma link C++ class ROOT::VecOps::RVec<TwoInts>+;
#pragma link C++ class ROOT::VecOps::RVec<ROOT::VecOps::RVec<TwoInts>>+;
""")
rdf = ROOT.RDataFrame(treeName, fileName)

array = rdf.AsAwkward(compiler, columns={"vv", "vvv", "vvti"}, columns_as_records=True)
assert (
str(array.layout.form)
== """{
"class": "RecordArray",
"contents": {
"one": "float64"
}
"class": "UnionArray8_64",
"tags": "i8",
"index": "i64",
"contents": [
{
"class": "RecordArray",
"contents": {
"NumpyArray_array": {
"class": "ListOffsetArray64",
"offsets": "i64",
"content": "float64"
}
},
"parameters": {
"__record__": "NumpyArray_array"
}
},
{
"class": "RecordArray",
"contents": {
"px": {
"class": "ListOffsetArray64",
"offsets": "i64",
"content": "float64"
}
},
"parameters": {
"__record__": "px"
}
},
{
"class": "RecordArray",
"contents": {
"py": {
"class": "ListOffsetArray64",
"offsets": "i64",
"content": "float64"
}
},
"parameters": {
"__record__": "py"
}
},
{
"class": "RecordArray",
"contents": {
"E": {
"class": "ListOffsetArray64",
"offsets": "i64",
"content": "float64"
}
},
"parameters": {
"__record__": "E"
}
},
{
"class": "RecordArray",
"contents": {
"pt": {
"class": "ListOffsetArray64",
"offsets": "i64",
"content": "float64"
}
},
"parameters": {
"__record__": "pt"
}
}
]
}"""
)
print(array.to_list())

0 comments on commit 5f37041

Please sign in to comment.