Skip to content

Commit b1a97ec

Browse files
committed
Add tests for systematic variations in distributed RDataFrame
1 parent 944b3f2 commit b1a97ec

File tree

4 files changed

+220
-0
lines changed

4 files changed

+220
-0
lines changed

python/distrdf/dask/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,7 @@ ROOTTEST_ADD_TEST(dask_test_rungraphs
2626
ROOTTEST_ADD_TEST(dask_test_reducer_merge
2727
MACRO test_reducer_merge.py)
2828

29+
ROOTTEST_ADD_TEST(dask_test_variations
30+
MACRO test_variations.py)
31+
2932
endif()
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import unittest
2+
3+
import ROOT
4+
5+
import DistRDF
6+
from DistRDF.Backends import Dask
7+
8+
from dask.distributed import Client, LocalCluster
9+
10+
11+
class VariationsTest(unittest.TestCase):
    """Tests usage of systematic variations with Dask backend."""

    @classmethod
    def setUpClass(cls):
        """
        Prepare the test environment for this class:

        - Start a Dask client backed by a `LocalCluster` that spawns two
          single-threaded Python worker processes.
        """
        cls.client = Client(
            LocalCluster(n_workers=2, threads_per_worker=1, processes=True))

    @classmethod
    def tearDownClass(cls):
        """Reset test environment."""
        cls.client.shutdown()
        cls.client.close()

    def _empty_df(self):
        # Distributed RDataFrame with 10 entries split over 2 partitions,
        # shared starting point of every test in this class.
        return Dask.RDataFrame(10, daskclient=self.client, npartitions=2)

    def test_histo(self):
        """A single varied column is propagated to the varied histograms."""
        nominal = self._empty_df().Define("x", "1")
        varied = nominal.Vary("x", "ROOT::RVecI{-2,2}", ["down", "up"])
        histos = DistRDF.VariationsFor(varied.Histo1D("x"))

        for varname, mean in [("nominal", 1), ("x:up", 2), ("x:down", -2)]:
            histo = histos[varname]
            self.assertIsInstance(histo, ROOT.TH1D)
            self.assertEqual(histo.GetEntries(), 10)
            self.assertAlmostEqual(histo.GetMean(), mean)

    def test_graph(self):
        """Variations requested via `nVariations` get numeric tags."""
        frame = self._empty_df().Define("x", "1")
        nominal_graph = frame.Vary("x", "ROOT::RVecI{-1, 2}", nVariations=2).Graph("x", "x")
        graphs = DistRDF.VariationsFor(nominal_graph)

        self.assertAlmostEqual(nominal_graph.GetMean(), 1)

        for varname, mean in [("nominal", 1), ("x:0", -1), ("x:1", 2)]:
            varied_graph = graphs[varname]
            self.assertIsInstance(varied_graph, ROOT.TGraph)
            self.assertAlmostEqual(varied_graph.GetMean(), mean)

    def test_mixed(self):
        """Varying one column leaves an unrelated weight column untouched."""
        frame = self._empty_df().Define("x", "1").Define("y", "42")
        result = frame.Vary("x", "ROOT::RVecI{-1, 2}", variationTags=["down", "up"]).Histo1D("x", "y")
        histos = DistRDF.VariationsFor(result)

        for varname, mean in [("nominal", 1), ("x:down", -1), ("x:up", 2)]:
            histo = histos[varname]
            self.assertIsInstance(histo, ROOT.TH1D)
            # All 10 entries fall in one bin with weight y, so the maximum
            # bin content is the same for every variation.
            self.assertAlmostEqual(histo.GetMaximum(), 420)
            self.assertAlmostEqual(histo.GetMean(), mean)

    def test_simultaneous(self):
        """Two columns varied in lockstep under a single variation name."""
        frame = self._empty_df().Define("x", "1").Define("y", "42")
        result = frame.Vary(["x", "y"],
                            "ROOT::RVec<ROOT::RVecI>{{-1, 2, 3}, {41, 43, 44}}",
                            ["down", "up", "other"], "xy").Histo1D("x", "y")
        histos = DistRDF.VariationsFor(result)

        cases = [("nominal", 1, 420), ("xy:down", -1, 410),
                 ("xy:up", 2, 430), ("xy:other", 3, 440)]
        for varname, mean, maxval in cases:
            histo = histos[varname]
            self.assertIsInstance(histo, ROOT.TH1D)
            self.assertAlmostEqual(histo.GetMaximum(), maxval)
            self.assertAlmostEqual(histo.GetMean(), mean)

    def test_varyfiltersum(self):
        """Variations propagate through a Filter node into a Sum result."""
        frame = self._empty_df().Define("x", "1")
        nominal_sum = frame.Vary("x", "ROOT::RVecI{-1*x, 2*x}", ("down", "up"), "myvariation").Filter("x > 0").Sum("x")

        self.assertAlmostEqual(nominal_sum.GetValue(), 10)

        sums = DistRDF.VariationsFor(nominal_sum)
        for varname, expected in [("nominal", 10),
                                  ("myvariation:down", 0),
                                  ("myvariation:up", 20)]:
            self.assertAlmostEqual(sums[varname], expected)
100+
101+
102+
# Run the tests when executed as a script; pass only the file name in argv
# so unittest does not try to interpret extra command-line arguments.
if __name__ == "__main__":
    unittest.main(argv=[__file__])

python/distrdf/spark/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,7 @@ if (ROOT_test_distrdf_pyspark_FOUND)
4949
MACRO test_reducer_merge.py
5050
ENVIRONMENT ${PYSPARK_ENV_VARS})
5151

52+
ROOTTEST_ADD_TEST(spark_test_variations
53+
MACRO test_variations.py
54+
ENVIRONMENT ${PYSPARK_ENV_VARS})
5255
endif()
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import unittest
2+
import warnings
3+
4+
import pyspark
5+
6+
import ROOT
7+
8+
import DistRDF
9+
from DistRDF.Backends import Spark
10+
11+
12+
class VariationsTest(unittest.TestCase):
    """Tests usage of systematic variations with Spark backend."""

    @classmethod
    def setUpClass(cls):
        """
        Prepare the test environment for this class:

        - Silence the `ResourceWarning: unclosed socket` warning triggered by
          Spark. It is ignored by default in any application, but Python's
          unittest library overrides the default warning filters, exposing it.
        - Start a SparkContext running on two local threads.
        """
        warnings.simplefilter("ignore", ResourceWarning)

        cls.sc = pyspark.SparkContext(
            conf=pyspark.SparkConf().setMaster("local[2]"))

    @classmethod
    def tearDownClass(cls):
        """Reset test environment."""
        warnings.simplefilter("default", ResourceWarning)

        cls.sc.stop()

    def _empty_df(self):
        # Distributed RDataFrame with 10 entries split over 2 partitions,
        # shared starting point of every test in this class.
        return Spark.RDataFrame(10, sparkcontext=self.sc, npartitions=2)

    def test_histo(self):
        """A single varied column is propagated to the varied histograms."""
        nominal = self._empty_df().Define("x", "1")
        varied = nominal.Vary("x", "ROOT::RVecI{-2,2}", ["down", "up"])
        histos = DistRDF.VariationsFor(varied.Histo1D("x"))

        for varname, mean in [("nominal", 1), ("x:up", 2), ("x:down", -2)]:
            histo = histos[varname]
            self.assertIsInstance(histo, ROOT.TH1D)
            self.assertEqual(histo.GetEntries(), 10)
            self.assertAlmostEqual(histo.GetMean(), mean)

    def test_graph(self):
        """Variations requested via `nVariations` get numeric tags."""
        frame = self._empty_df().Define("x", "1")
        nominal_graph = frame.Vary("x", "ROOT::RVecI{-1, 2}", nVariations=2).Graph("x", "x")
        graphs = DistRDF.VariationsFor(nominal_graph)

        self.assertAlmostEqual(nominal_graph.GetMean(), 1)

        for varname, mean in [("nominal", 1), ("x:0", -1), ("x:1", 2)]:
            varied_graph = graphs[varname]
            self.assertIsInstance(varied_graph, ROOT.TGraph)
            self.assertAlmostEqual(varied_graph.GetMean(), mean)

    def test_mixed(self):
        """Varying one column leaves an unrelated weight column untouched."""
        frame = self._empty_df().Define("x", "1").Define("y", "42")
        result = frame.Vary("x", "ROOT::RVecI{-1, 2}", variationTags=["down", "up"]).Histo1D("x", "y")
        histos = DistRDF.VariationsFor(result)

        for varname, mean in [("nominal", 1), ("x:down", -1), ("x:up", 2)]:
            histo = histos[varname]
            self.assertIsInstance(histo, ROOT.TH1D)
            # All 10 entries fall in one bin with weight y, so the maximum
            # bin content is the same for every variation.
            self.assertAlmostEqual(histo.GetMaximum(), 420)
            self.assertAlmostEqual(histo.GetMean(), mean)

    def test_simultaneous(self):
        """Two columns varied in lockstep under a single variation name."""
        frame = self._empty_df().Define("x", "1").Define("y", "42")
        result = frame.Vary(["x", "y"],
                            "ROOT::RVec<ROOT::RVecI>{{-1, 2, 3}, {41, 43, 44}}",
                            ["down", "up", "other"], "xy").Histo1D("x", "y")
        histos = DistRDF.VariationsFor(result)

        cases = [("nominal", 1, 420), ("xy:down", -1, 410),
                 ("xy:up", 2, 430), ("xy:other", 3, 440)]
        for varname, mean, maxval in cases:
            histo = histos[varname]
            self.assertIsInstance(histo, ROOT.TH1D)
            self.assertAlmostEqual(histo.GetMaximum(), maxval)
            self.assertAlmostEqual(histo.GetMean(), mean)

    def test_varyfiltersum(self):
        """Variations propagate through a Filter node into a Sum result."""
        frame = self._empty_df().Define("x", "1")
        nominal_sum = frame.Vary("x", "ROOT::RVecI{-1*x, 2*x}", ("down", "up"), "myvariation").Filter("x > 0").Sum("x")

        self.assertAlmostEqual(nominal_sum.GetValue(), 10)

        sums = DistRDF.VariationsFor(nominal_sum)
        for varname, expected in [("nominal", 10),
                                  ("myvariation:down", 0),
                                  ("myvariation:up", 20)]:
            self.assertAlmostEqual(sums[varname], expected)
108+
109+
110+
# Run the tests when executed as a script; pass only the file name in argv
# so unittest does not try to interpret extra command-line arguments.
if __name__ == "__main__":
    unittest.main(argv=[__file__])

0 commit comments

Comments
 (0)