|
| 1 | +import unittest |
| 2 | +import warnings |
| 3 | + |
| 4 | +import pyspark |
| 5 | + |
| 6 | +import ROOT |
| 7 | + |
| 8 | +import DistRDF |
| 9 | +from DistRDF.Backends import Spark |
| 10 | + |
| 11 | + |
class VariationsTest(unittest.TestCase):
    """Tests usage of systematic variations with Spark backend"""

    @classmethod
    def setUpClass(cls):
        """
        Prepare the shared test environment for this class:

        - Silence the `ResourceWarning: unclosed socket` warning triggered by
          Spark. It is ignored by default in any application, but Python's
          unittest library overrides the default warning filters and would
          otherwise expose it.
        - Start a local SparkContext shared by every test in this class.
        """
        warnings.simplefilter("ignore", ResourceWarning)

        conf = pyspark.SparkConf().setMaster("local[2]")
        cls.sc = pyspark.SparkContext(conf=conf)

    @classmethod
    def tearDownClass(cls):
        """Restore the warning filters and stop the SparkContext."""
        warnings.simplefilter("default", ResourceWarning)

        cls.sc.stop()

    def test_histo(self):
        """A single varied column is propagated to a Histo1D result."""
        df = Spark.RDataFrame(10, sparkcontext=self.sc, npartitions=2).Define("x", "1")
        varied = df.Vary("x", "ROOT::RVecI{-2,2}", ["down", "up"])
        nominal_histo = varied.Histo1D("x")
        histos = DistRDF.VariationsFor(nominal_histo)

        # Expected mean of "x" for the nominal result and each variation.
        expected = {"nominal": 1, "x:up": 2, "x:down": -2}
        for name, mean in expected.items():
            histo = histos[name]
            self.assertIsInstance(histo, ROOT.TH1D)
            self.assertEqual(histo.GetEntries(), 10)
            self.assertAlmostEqual(histo.GetMean(), mean)

    def test_graph(self):
        """Variations requested via nVariations get numeric tags on a Graph."""
        df = Spark.RDataFrame(10, sparkcontext=self.sc, npartitions=2).Define("x", "1")
        nominal_graph = df.Vary("x", "ROOT::RVecI{-1, 2}", nVariations=2).Graph("x", "x")
        graphs = DistRDF.VariationsFor(nominal_graph)

        self.assertAlmostEqual(nominal_graph.GetMean(), 1)

        # With nVariations the tags are the indices "0" and "1".
        for name, mean in (("nominal", 1), ("x:0", -1), ("x:1", 2)):
            graph = graphs[name]
            self.assertIsInstance(graph, ROOT.TGraph)
            self.assertAlmostEqual(graph.GetMean(), mean)

    def test_mixed(self):
        """Varying one column leaves an unvaried weight column untouched."""
        df = Spark.RDataFrame(10, sparkcontext=self.sc, npartitions=2).Define("x", "1").Define("y", "42")
        nominal_histo = df.Vary("x", "ROOT::RVecI{-1, 2}", variationTags=["down", "up"]).Histo1D("x", "y")
        histos = DistRDF.VariationsFor(nominal_histo)

        # "y" is not varied, so the maximum (10 entries * weight 42) is the
        # same for the nominal histogram and for each variation of "x".
        same_max = 420
        for name, mean in (("nominal", 1), ("x:down", -1), ("x:up", 2)):
            histo = histos[name]
            self.assertIsInstance(histo, ROOT.TH1D)
            self.assertAlmostEqual(histo.GetMaximum(), same_max)
            self.assertAlmostEqual(histo.GetMean(), mean)

    def test_simultaneous(self):
        """Two columns varied in lockstep share one variation name."""
        df = Spark.RDataFrame(10, sparkcontext=self.sc, npartitions=2).Define("x", "1").Define("y", "42")
        nominal_histo = df.Vary(["x", "y"],
                                "ROOT::RVec<ROOT::RVecI>{{-1, 2, 3}, {41, 43, 44}}",
                                ["down", "up", "other"], "xy").Histo1D("x", "y")
        histos = DistRDF.VariationsFor(nominal_histo)

        # Each tuple: (variation name, expected mean of x, expected maximum).
        cases = (
            ("nominal", 1, 420),
            ("xy:down", -1, 410),
            ("xy:up", 2, 430),
            ("xy:other", 3, 440),
        )
        for name, mean, maximum in cases:
            histo = histos[name]
            self.assertIsInstance(histo, ROOT.TH1D)
            self.assertAlmostEqual(histo.GetMaximum(), maximum)
            self.assertAlmostEqual(histo.GetMean(), mean)

    def test_varyfiltersum(self):
        """Variations propagate through a Filter into a Sum result."""
        df = Spark.RDataFrame(10, sparkcontext=self.sc, npartitions=2).Define("x", "1")
        nominal_sum = df.Vary("x", "ROOT::RVecI{-1*x, 2*x}", ("down", "up"), "myvariation").Filter("x > 0").Sum("x")

        self.assertAlmostEqual(nominal_sum.GetValue(), 10)

        sums = DistRDF.VariationsFor(nominal_sum)

        # "down" flips x negative so the filter drops every entry (sum 0);
        # "up" doubles x so the sum doubles.
        for name, value in (("nominal", 10), ("myvariation:down", 0), ("myvariation:up", 20)):
            self.assertAlmostEqual(sums[name], value)
| 108 | + |
| 109 | + |
if __name__ == "__main__":
    # Pass an explicit argv so unittest.main does not parse the real
    # command-line arguments (presumably meant for the surrounding test
    # runner) as its own options.
    unittest.main(argv=[__file__])
0 commit comments