Skip to content

Commit

Permalink
Add Process.setSwitchProducerCaseForAll() function
Browse files Browse the repository at this point in the history
  • Loading branch information
makortel committed Oct 1, 2021
1 parent 7a09c3e commit e0da71a
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions FWCore/ParameterSet/python/Config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,6 +1093,11 @@ def globalReplace(self,label,new):
if not hasattr(self,label):
raise LookupError("process has no item of label "+label)
setattr(self,label,new)
def setSwitchProducerCaseForAll(self, switchProducerType, case):
"""Set the chosen case to 'case' for all SwitchProducers ot type 'SwitchProducerType'"""
for sp in self.__switchproducers.values():
if sp.__class__.__name__ == switchProducerType:
sp.setCase_(case)
def _insertInto(self, parameterSet, itemDict):
for name,value in itemDict.items():
value.insertInto(parameterSet, name)
Expand Down Expand Up @@ -1850,6 +1855,16 @@ def __init__(self, **kargs):
), **kargs)
specialImportRegistry.registerSpecialImportForType(SwitchProducerTest, "from test import SwitchProducerTest")

class SwitchProducerTest2(SwitchProducer):
def __init__(self, **kargs):
super(SwitchProducerTest2,self).__init__(
dict(
test1 = lambda: (True, -7),
test2 = lambda: (True, -5),
test30 = lambda: (True, -10),
), **kargs)
specialImportRegistry.registerSpecialImportForType(SwitchProducerTest2, "from test import SwitchProducerTest2")

class TestModuleCommand(unittest.TestCase):
def setUp(self):
"""Nothing to do """
Expand Down Expand Up @@ -3069,6 +3084,50 @@ def testSwitchProducer(self):
self.assertEqual((True,"EDAlias"), p.values["sp@test2"][1].values["@module_edm_type"])
self.assertEqual((True,"Bar"), p.values["sp@test2"][1].values["a"][1][0].values["type"])

# Forcing the choice
proc = Process("test")
proc.sp1 = SwitchProducerTest(test1 = EDProducer("Foo1"), test3 = EDProducer("Fred1"))
proc.sp2 = SwitchProducerTest(test1 = EDProducer("Foo2"), test2 = EDProducer("Bar2"), test3 = EDProducer("Fred2"))
proc.sp10 = SwitchProducerTest2(test1 = EDProducer("Foo10"), test2 = EDProducer("Bar10"), test30 = EDProducer("Wilma10"))
proc.t = Task(proc.sp1, proc.sp2, proc.sp10)
proc.p = Path(proc.t)
self.assertEqual(proc.sp1._getProducer().type_(), "Fred1")
self.assertEqual(proc.sp2._getProducer().type_(), "Fred2")
self.assertEqual(proc.sp10._getProducer().type_(), "Bar10")
proc.setSwitchProducerCaseForAll("SwitchProducerTest", "test1")
self.assertEqual(proc.sp1._getProducer().type_(), "Foo1")
self.assertEqual(proc.sp2._getProducer().type_(), "Foo2")
self.assertEqual(proc.sp10._getProducer().type_(), "Bar10")
proc.setSwitchProducerCaseForAll("SwitchProducerTest2", "test30")
self.assertEqual(proc.sp1._getProducer().type_(), "Foo1")
self.assertEqual(proc.sp2._getProducer().type_(), "Foo2")
self.assertEqual(proc.sp10._getProducer().type_(), "Wilma10")
proc.setSwitchProducerCaseForAll("SwitchProducerTest", "test2")
self.assertRaises(RuntimeError, proc.sp1._getProducer)
self.assertEqual(proc.sp2._getProducer().type_(), "Bar2")
self.assertEqual(proc.sp10._getProducer().type_(), "Wilma10")
proc.setSwitchProducerCaseForAll("SwitchProducerTest", None)
self.assertEqual(proc.sp1._getProducer().type_(), "Fred1")
self.assertEqual(proc.sp2._getProducer().type_(), "Fred2")
self.assertEqual(proc.sp10._getProducer().type_(), "Wilma10")
p = TestMakePSet()
proc.fillProcessDesc(p)
self.assertEqual((False, "sp1@test3"), p.values["sp1"][1].values["@chosen_case"])
self.assertEqual((False, "sp2@test3"), p.values["sp2"][1].values["@chosen_case"])
self.assertEqual((False, "sp10@test30"), p.values["sp10"][1].values["@chosen_case"])
proc.setSwitchProducerCaseForAll("SwitchProducerTest", "test2")
p = TestMakePSet()
self.assertRaises(RuntimeError, proc.fillProcessDesc, p)
proc.sp1.setCase_("test1")
self.assertEqual(proc.sp1._getProducer().type_(), "Foo1")
self.assertEqual(proc.sp2._getProducer().type_(), "Bar2")
self.assertEqual(proc.sp10._getProducer().type_(), "Wilma10")
p = TestMakePSet()
proc.fillProcessDesc(p)
self.assertEqual((False, "sp1@test1"), p.values["sp1"][1].values["@chosen_case"])
self.assertEqual((False, "sp2@test2"), p.values["sp2"][1].values["@chosen_case"])
self.assertEqual((False, "sp10@test30"), p.values["sp10"][1].values["@chosen_case"])

def testPrune(self):
p = Process("test")
p.a = EDAnalyzer("MyAnalyzer")
Expand Down

0 comments on commit e0da71a

Please sign in to comment.