diff --git a/FWCore/ParameterSet/python/Config.py b/FWCore/ParameterSet/python/Config.py index c6d198ba04ed6..e26c65c7c0be6 100644 --- a/FWCore/ParameterSet/python/Config.py +++ b/FWCore/ParameterSet/python/Config.py @@ -1093,6 +1093,11 @@ def globalReplace(self,label,new): if not hasattr(self,label): raise LookupError("process has no item of label "+label) setattr(self,label,new) + def setSwitchProducerCaseForAll(self, switchProducerType, case): + """Set the chosen case to 'case' for all SwitchProducers ot type 'SwitchProducerType'""" + for sp in self.__switchproducers.values(): + if sp.__class__.__name__ == switchProducerType: + sp.setCase_(case) def _insertInto(self, parameterSet, itemDict): for name,value in itemDict.items(): value.insertInto(parameterSet, name) @@ -1850,6 +1855,16 @@ def __init__(self, **kargs): ), **kargs) specialImportRegistry.registerSpecialImportForType(SwitchProducerTest, "from test import SwitchProducerTest") + class SwitchProducerTest2(SwitchProducer): + def __init__(self, **kargs): + super(SwitchProducerTest2,self).__init__( + dict( + test1 = lambda: (True, -7), + test2 = lambda: (True, -5), + test30 = lambda: (True, -10), + ), **kargs) + specialImportRegistry.registerSpecialImportForType(SwitchProducerTest2, "from test import SwitchProducerTest2") + class TestModuleCommand(unittest.TestCase): def setUp(self): """Nothing to do """ @@ -3069,6 +3084,50 @@ def testSwitchProducer(self): self.assertEqual((True,"EDAlias"), p.values["sp@test2"][1].values["@module_edm_type"]) self.assertEqual((True,"Bar"), p.values["sp@test2"][1].values["a"][1][0].values["type"]) + # Forcing the choice + proc = Process("test") + proc.sp1 = SwitchProducerTest(test1 = EDProducer("Foo1"), test3 = EDProducer("Fred1")) + proc.sp2 = SwitchProducerTest(test1 = EDProducer("Foo2"), test2 = EDProducer("Bar2"), test3 = EDProducer("Fred2")) + proc.sp10 = SwitchProducerTest2(test1 = EDProducer("Foo10"), test2 = EDProducer("Bar10"), test30 = EDProducer("Wilma10")) + proc.t = Task(proc.sp1, proc.sp2, proc.sp10) + proc.p = Path(proc.t) + self.assertEqual(proc.sp1._getProducer().type_(), "Fred1") + self.assertEqual(proc.sp2._getProducer().type_(), "Fred2") + self.assertEqual(proc.sp10._getProducer().type_(), "Bar10") + proc.setSwitchProducerCaseForAll("SwitchProducerTest", "test1") + self.assertEqual(proc.sp1._getProducer().type_(), "Foo1") + self.assertEqual(proc.sp2._getProducer().type_(), "Foo2") + self.assertEqual(proc.sp10._getProducer().type_(), "Bar10") + proc.setSwitchProducerCaseForAll("SwitchProducerTest2", "test30") + self.assertEqual(proc.sp1._getProducer().type_(), "Foo1") + self.assertEqual(proc.sp2._getProducer().type_(), "Foo2") + self.assertEqual(proc.sp10._getProducer().type_(), "Wilma10") + proc.setSwitchProducerCaseForAll("SwitchProducerTest", "test2") + self.assertRaises(RuntimeError, proc.sp1._getProducer) + self.assertEqual(proc.sp2._getProducer().type_(), "Bar2") + self.assertEqual(proc.sp10._getProducer().type_(), "Wilma10") + proc.setSwitchProducerCaseForAll("SwitchProducerTest", None) + self.assertEqual(proc.sp1._getProducer().type_(), "Fred1") + self.assertEqual(proc.sp2._getProducer().type_(), "Fred2") + self.assertEqual(proc.sp10._getProducer().type_(), "Wilma10") + p = TestMakePSet() + proc.fillProcessDesc(p) + self.assertEqual((False, "sp1@test3"), p.values["sp1"][1].values["@chosen_case"]) + self.assertEqual((False, "sp2@test3"), p.values["sp2"][1].values["@chosen_case"]) + self.assertEqual((False, "sp10@test30"), p.values["sp10"][1].values["@chosen_case"]) + proc.setSwitchProducerCaseForAll("SwitchProducerTest", "test2") + p = TestMakePSet() + self.assertRaises(RuntimeError, proc.fillProcessDesc, p) + proc.sp1.setCase_("test1") + self.assertEqual(proc.sp1._getProducer().type_(), "Foo1") + self.assertEqual(proc.sp2._getProducer().type_(), "Bar2") + self.assertEqual(proc.sp10._getProducer().type_(), "Wilma10") + p = TestMakePSet() + proc.fillProcessDesc(p) + self.assertEqual((False, "sp1@test1"), p.values["sp1"][1].values["@chosen_case"]) + self.assertEqual((False, "sp2@test2"), p.values["sp2"][1].values["@chosen_case"]) + self.assertEqual((False, "sp10@test30"), p.values["sp10"][1].values["@chosen_case"]) + def testPrune(self): p = Process("test") p.a = EDAnalyzer("MyAnalyzer") diff --git a/FWCore/ParameterSet/python/Modules.py b/FWCore/ParameterSet/python/Modules.py index fa0a9b899e066..630e37ffc4d08 100644 --- a/FWCore/ParameterSet/python/Modules.py +++ b/FWCore/ParameterSet/python/Modules.py @@ -216,6 +216,8 @@ def nameInProcessDesc_(self, myname): # SwitchProducer to be pickleable. def _switch_cpu(): return (True, 1) +def _switch_gpu_test(): + return (True, 2) class SwitchProducer(EDProducer): """This purpose class is to provide a switch of EDProducers for a single module/product label. @@ -245,21 +247,37 @@ def __init__(self, **kargs): produced with one of the producers. It would be good if their output product types and instance names would be the same (or very close). + + The decision can be "forced" with 'foo.setCase_("case1")' and + unset with 'foo.setCase_(None)'. This setting persists through + dumpPython() and picking. """ def __init__(self, caseFunctionDict, **kargs): super(SwitchProducer,self).__init__(None) self._caseFunctionDict = copy.copy(caseFunctionDict) self.__setParameters(kargs) self._isModified = False + self._setCase = None @staticmethod def getCpu(): """Returns a function that returns the priority for a CPU "computing device". Intended to be used by deriving classes.""" return _switch_cpu + def setCase_(self, case): + if case is not None and case not in self._caseFunctionDict: + raise ValueError("Case '{}' is not allowed (allowed ones are {})".format(case, ",".join(self._caseFunctionDict.keys()))) + self._setCase = case + return self + def _chooseCase(self): """Returns the name of the chosen case.""" cases = self.parameterNames_() + if self._setCase is not None: + if self._setCase not in cases: + raise RuntimeError("Case '{}' has been set, but there is no EDProducer/EDAlias for that".format(self._setCase)) + return self._setCase + bestCase = None for case in cases: (enabled, priority) = self._caseFunctionDict[case]() @@ -280,9 +298,9 @@ def __typeIsValid(typ): def __addParameter(self, name, value): if not self.__typeIsValid(value): raise TypeError(name+" does not already exist, so it can only be set to a cms.EDProducer or cms.EDAlias") - if name not in self._caseFunctionDict: + elif name not in self._caseFunctionDict: raise ValueError("Case '%s' is not allowed (allowed ones are %s)" % (name, ",".join(self._caseFunctionDict.keys()))) - if name in self.__dict__: + elif name in self.__dict__: message = "Duplicate insert of member " + name message += "\nThe original parameters are:\n" message += self.dumpPython() + '\n' @@ -340,6 +358,7 @@ def clone(self, **params): returnValue.__init__(**myparams) returnValue._isModified = False returnValue._isFrozen = False + returnValue._setCase = self._setCase saveOrigin(returnValue, 1) return returnValue @@ -355,7 +374,10 @@ def dumpPython(self, options=PrintOptions()): if result[-1] == ",": result = result.rstrip(",") options.unindent() - result += "\n)\n" + result += "\n)" + if self._setCase is not None: + result += ".setCase_('{}')".format(self._setCase) + result += "\n" return result def directDependencies(self): @@ -444,7 +466,10 @@ def __init__(self, **kargs): class SwitchProducerPickleable(SwitchProducer): def __init__(self, **kargs): super(SwitchProducerPickleable,self).__init__( - dict(cpu = SwitchProducer.getCpu()), **kargs) + dict( + cpu = SwitchProducer.getCpu(), + gpu = _switch_gpu_test, + ), **kargs) class TestModules(unittest.TestCase): def testEDAnalyzer(self): @@ -577,6 +602,21 @@ def testSwitchProducer(self): sp = SwitchProducerTest1Dis(test1 = EDProducer("Bar")) self.assertRaises(RuntimeError, sp._getProducer) + # Case forcing + sp = SwitchProducerTest(test1 = EDProducer("Foo"), test2 = EDProducer("Bar")) + sp.setCase_("test1") + self.assertEqual(sp._getProducer().type_(), "Foo") + sp.setCase_("test2") + self.assertEqual(sp._getProducer().type_(), "Bar") + sp.setCase_("test1") + self.assertEqual(sp._getProducer().type_(), "Foo") + sp.setCase_(None) + self.assertEqual(sp._getProducer().type_(), "Bar") + self.assertRaises(ValueError, sp.setCase_, "nonexistent") + sp.setCase_("test3") + self.assertRaises(RuntimeError, sp._getProducer) + + # Mofications from .Types import int32, string, PSet sp = SwitchProducerTest(test1 = EDProducer("Foo", @@ -593,7 +633,7 @@ def testSwitchProducer(self): self.assertEqual(cl.test2.type_(), "Bar") self.assertEqual(cl.test2.aa.value(), 11) self.assertEqual(cl.test2.bb.cc.value(), 12) - self.assertEqual(sp._getProducer().type_(), "Bar") + self.assertEqual(cl._getProducer().type_(), "Bar") # Modify clone cl.test1.a = 3 self.assertEqual(cl.test1.a.value(), 3) @@ -625,6 +665,13 @@ def _assignSwitchProducer(): self.assertRaises(TypeError, lambda: sp.clone(test1 = EDAnalyzer("Foo"))) self.assertRaises(TypeError, lambda: sp.clone(test1 = SwitchProducerTest(test1 = SwitchProducerTest(test1 = EDProducer("Foo"))))) + sp.setCase_("test1") + cl = sp.clone() + self.assertEqual(cl._getProducer().type_(), "Foo") + cl.setCase_(None) + self.assertEqual(cl._getProducer().type_(), "Bar") + self.assertEqual(sp._getProducer().type_(), "Foo") + # Dump sp = SwitchProducerTest(test2 = EDProducer("Foo", a = int32(1), @@ -648,12 +695,57 @@ def _assignSwitchProducer(): ) ) """) + + sp.setCase_("test1") + self.assertEqual(sp.dumpPython(), +"""SwitchProducerTest( + test1 = cms.EDProducer("Bar", + aa = cms.int32(11), + bb = cms.PSet( + cc = cms.int32(12) + ) + ), + test2 = cms.EDProducer("Foo", + a = cms.int32(1), + b = cms.PSet( + c = cms.int32(2) + ) + ) +).setCase_('test1') +""") + sp.setCase_(None) + self.assertEqual(sp.dumpPython(), +"""SwitchProducerTest( + test1 = cms.EDProducer("Bar", + aa = cms.int32(11), + bb = cms.PSet( + cc = cms.int32(12) + ) + ), + test2 = cms.EDProducer("Foo", + a = cms.int32(1), + b = cms.PSet( + c = cms.int32(2) + ) + ) +) +""") + # Pickle import pickle - sp = SwitchProducerPickleable(cpu = EDProducer("Foo")) + sp = SwitchProducerPickleable(cpu = EDProducer("Foo"), gpu=EDProducer("Bar")) + pkl = pickle.dumps(sp) + unpkl = pickle.loads(pkl) + self.assertEqual(unpkl.cpu.type_(), "Foo") + self.assertEqual(unpkl.gpu.type_(), "Bar") + self.assertEqual(unpkl._getProducer().type_(), "Bar") + + sp.setCase_("cpu") pkl = pickle.dumps(sp) unpkl = pickle.loads(pkl) self.assertEqual(unpkl.cpu.type_(), "Foo") + self.assertEqual(unpkl.gpu.type_(), "Bar") + self.assertEqual(unpkl._getProducer().type_(), "Foo") def testSwithProducerWithAlias(self): # Constructor