Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Add a mechanism to set the chosen case for SwitchProducer #35510

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions FWCore/ParameterSet/python/Config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,6 +1093,11 @@ def globalReplace(self,label,new):
if not hasattr(self,label):
raise LookupError("process has no item of label "+label)
setattr(self,label,new)
def setSwitchProducerCaseForAll(self, switchProducerType, case):
"""Set the chosen case to 'case' for all SwitchProducers ot type 'SwitchProducerType'"""
for sp in self.__switchproducers.values():
if sp.__class__.__name__ == switchProducerType:
sp.setCase_(case)
def _insertInto(self, parameterSet, itemDict):
for name,value in itemDict.items():
value.insertInto(parameterSet, name)
Expand Down Expand Up @@ -1850,6 +1855,16 @@ def __init__(self, **kargs):
), **kargs)
specialImportRegistry.registerSpecialImportForType(SwitchProducerTest, "from test import SwitchProducerTest")

class SwitchProducerTest2(SwitchProducer):
def __init__(self, **kargs):
super(SwitchProducerTest2,self).__init__(
dict(
test1 = lambda: (True, -7),
test2 = lambda: (True, -5),
test30 = lambda: (True, -10),
), **kargs)
specialImportRegistry.registerSpecialImportForType(SwitchProducerTest2, "from test import SwitchProducerTest2")

class TestModuleCommand(unittest.TestCase):
def setUp(self):
"""Nothing to do """
Expand Down Expand Up @@ -3069,6 +3084,50 @@ def testSwitchProducer(self):
self.assertEqual((True,"EDAlias"), p.values["sp@test2"][1].values["@module_edm_type"])
self.assertEqual((True,"Bar"), p.values["sp@test2"][1].values["a"][1][0].values["type"])

# Forcing the choice
proc = Process("test")
proc.sp1 = SwitchProducerTest(test1 = EDProducer("Foo1"), test3 = EDProducer("Fred1"))
proc.sp2 = SwitchProducerTest(test1 = EDProducer("Foo2"), test2 = EDProducer("Bar2"), test3 = EDProducer("Fred2"))
proc.sp10 = SwitchProducerTest2(test1 = EDProducer("Foo10"), test2 = EDProducer("Bar10"), test30 = EDProducer("Wilma10"))
proc.t = Task(proc.sp1, proc.sp2, proc.sp10)
proc.p = Path(proc.t)
self.assertEqual(proc.sp1._getProducer().type_(), "Fred1")
self.assertEqual(proc.sp2._getProducer().type_(), "Fred2")
self.assertEqual(proc.sp10._getProducer().type_(), "Bar10")
proc.setSwitchProducerCaseForAll("SwitchProducerTest", "test1")
self.assertEqual(proc.sp1._getProducer().type_(), "Foo1")
self.assertEqual(proc.sp2._getProducer().type_(), "Foo2")
self.assertEqual(proc.sp10._getProducer().type_(), "Bar10")
proc.setSwitchProducerCaseForAll("SwitchProducerTest2", "test30")
self.assertEqual(proc.sp1._getProducer().type_(), "Foo1")
self.assertEqual(proc.sp2._getProducer().type_(), "Foo2")
self.assertEqual(proc.sp10._getProducer().type_(), "Wilma10")
proc.setSwitchProducerCaseForAll("SwitchProducerTest", "test2")
self.assertRaises(RuntimeError, proc.sp1._getProducer)
self.assertEqual(proc.sp2._getProducer().type_(), "Bar2")
self.assertEqual(proc.sp10._getProducer().type_(), "Wilma10")
proc.setSwitchProducerCaseForAll("SwitchProducerTest", None)
self.assertEqual(proc.sp1._getProducer().type_(), "Fred1")
self.assertEqual(proc.sp2._getProducer().type_(), "Fred2")
self.assertEqual(proc.sp10._getProducer().type_(), "Wilma10")
p = TestMakePSet()
proc.fillProcessDesc(p)
self.assertEqual((False, "sp1@test3"), p.values["sp1"][1].values["@chosen_case"])
self.assertEqual((False, "sp2@test3"), p.values["sp2"][1].values["@chosen_case"])
self.assertEqual((False, "sp10@test30"), p.values["sp10"][1].values["@chosen_case"])
proc.setSwitchProducerCaseForAll("SwitchProducerTest", "test2")
p = TestMakePSet()
self.assertRaises(RuntimeError, proc.fillProcessDesc, p)
proc.sp1.setCase_("test1")
self.assertEqual(proc.sp1._getProducer().type_(), "Foo1")
self.assertEqual(proc.sp2._getProducer().type_(), "Bar2")
self.assertEqual(proc.sp10._getProducer().type_(), "Wilma10")
p = TestMakePSet()
proc.fillProcessDesc(p)
self.assertEqual((False, "sp1@test1"), p.values["sp1"][1].values["@chosen_case"])
self.assertEqual((False, "sp2@test2"), p.values["sp2"][1].values["@chosen_case"])
self.assertEqual((False, "sp10@test30"), p.values["sp10"][1].values["@chosen_case"])

def testPrune(self):
p = Process("test")
p.a = EDAnalyzer("MyAnalyzer")
Expand Down
104 changes: 98 additions & 6 deletions FWCore/ParameterSet/python/Modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ def nameInProcessDesc_(self, myname):
# SwitchProducer to be pickleable.
def _switch_cpu():
return (True, 1)
def _switch_gpu_test():
return (True, 2)

class SwitchProducer(EDProducer):
"""This purpose class is to provide a switch of EDProducers for a single module/product label.
Expand Down Expand Up @@ -245,21 +247,37 @@ def __init__(self, **kargs):
produced with one of the producers. It would be good if their
output product types and instance names would be the same (or very
close).

The decision can be "forced" with 'foo.setCase_("case1")' and
unset with 'foo.setCase_(None)'. This setting persists through
dumpPython() and picking.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dumpPython() and picking.
dumpPython() and pickling.

"""
def __init__(self, caseFunctionDict, **kargs):
super(SwitchProducer,self).__init__(None)
self._caseFunctionDict = copy.copy(caseFunctionDict)
self.__setParameters(kargs)
self._isModified = False
self._setCase = None

@staticmethod
def getCpu():
"""Returns a function that returns the priority for a CPU "computing device". Intended to be used by deriving classes."""
return _switch_cpu

def setCase_(self, case):
if case is not None and case not in self._caseFunctionDict:
raise ValueError("Case '{}' is not allowed (allowed ones are {})".format(case, ",".join(self._caseFunctionDict.keys())))
self._setCase = case
return self

def _chooseCase(self):
"""Returns the name of the chosen case."""
cases = self.parameterNames_()
if self._setCase is not None:
if self._setCase not in cases:
raise RuntimeError("Case '{}' has been set, but there is no EDProducer/EDAlias for that".format(self._setCase))
return self._setCase

bestCase = None
for case in cases:
(enabled, priority) = self._caseFunctionDict[case]()
Expand All @@ -280,9 +298,9 @@ def __typeIsValid(typ):
def __addParameter(self, name, value):
if not self.__typeIsValid(value):
raise TypeError(name+" does not already exist, so it can only be set to a cms.EDProducer or cms.EDAlias")
if name not in self._caseFunctionDict:
elif name not in self._caseFunctionDict:
raise ValueError("Case '%s' is not allowed (allowed ones are %s)" % (name, ",".join(self._caseFunctionDict.keys())))
if name in self.__dict__:
elif name in self.__dict__:
message = "Duplicate insert of member " + name
message += "\nThe original parameters are:\n"
message += self.dumpPython() + '\n'
Expand Down Expand Up @@ -340,6 +358,7 @@ def clone(self, **params):
returnValue.__init__(**myparams)
returnValue._isModified = False
returnValue._isFrozen = False
returnValue._setCase = self._setCase
saveOrigin(returnValue, 1)
return returnValue

Expand All @@ -355,7 +374,10 @@ def dumpPython(self, options=PrintOptions()):
if result[-1] == ",":
result = result.rstrip(",")
options.unindent()
result += "\n)\n"
result += "\n)"
if self._setCase is not None:
result += ".setCase_('{}')".format(self._setCase)
result += "\n"
return result

def directDependencies(self):
Expand Down Expand Up @@ -444,7 +466,10 @@ def __init__(self, **kargs):
class SwitchProducerPickleable(SwitchProducer):
def __init__(self, **kargs):
super(SwitchProducerPickleable,self).__init__(
dict(cpu = SwitchProducer.getCpu()), **kargs)
dict(
cpu = SwitchProducer.getCpu(),
gpu = _switch_gpu_test,
), **kargs)

class TestModules(unittest.TestCase):
def testEDAnalyzer(self):
Expand Down Expand Up @@ -577,6 +602,21 @@ def testSwitchProducer(self):
sp = SwitchProducerTest1Dis(test1 = EDProducer("Bar"))
self.assertRaises(RuntimeError, sp._getProducer)

# Case forcing
sp = SwitchProducerTest(test1 = EDProducer("Foo"), test2 = EDProducer("Bar"))
sp.setCase_("test1")
self.assertEqual(sp._getProducer().type_(), "Foo")
sp.setCase_("test2")
self.assertEqual(sp._getProducer().type_(), "Bar")
sp.setCase_("test1")
self.assertEqual(sp._getProducer().type_(), "Foo")
sp.setCase_(None)
self.assertEqual(sp._getProducer().type_(), "Bar")
self.assertRaises(ValueError, sp.setCase_, "nonexistent")
sp.setCase_("test3")
self.assertRaises(RuntimeError, sp._getProducer)


# Mofications
from .Types import int32, string, PSet
sp = SwitchProducerTest(test1 = EDProducer("Foo",
Expand All @@ -593,7 +633,7 @@ def testSwitchProducer(self):
self.assertEqual(cl.test2.type_(), "Bar")
self.assertEqual(cl.test2.aa.value(), 11)
self.assertEqual(cl.test2.bb.cc.value(), 12)
self.assertEqual(sp._getProducer().type_(), "Bar")
self.assertEqual(cl._getProducer().type_(), "Bar")
# Modify clone
cl.test1.a = 3
self.assertEqual(cl.test1.a.value(), 3)
Expand Down Expand Up @@ -625,6 +665,13 @@ def _assignSwitchProducer():
self.assertRaises(TypeError, lambda: sp.clone(test1 = EDAnalyzer("Foo")))
self.assertRaises(TypeError, lambda: sp.clone(test1 = SwitchProducerTest(test1 = SwitchProducerTest(test1 = EDProducer("Foo")))))

sp.setCase_("test1")
cl = sp.clone()
self.assertEqual(cl._getProducer().type_(), "Foo")
cl.setCase_(None)
self.assertEqual(cl._getProducer().type_(), "Bar")
self.assertEqual(sp._getProducer().type_(), "Foo")

# Dump
sp = SwitchProducerTest(test2 = EDProducer("Foo",
a = int32(1),
Expand All @@ -648,12 +695,57 @@ def _assignSwitchProducer():
)
)
""")

sp.setCase_("test1")
self.assertEqual(sp.dumpPython(),
"""SwitchProducerTest(
test1 = cms.EDProducer("Bar",
aa = cms.int32(11),
bb = cms.PSet(
cc = cms.int32(12)
)
),
test2 = cms.EDProducer("Foo",
a = cms.int32(1),
b = cms.PSet(
c = cms.int32(2)
)
)
).setCase_('test1')
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dumping the "set case" like this looks a bit weird (setCase_() returns self so it should work), but doing it along

foo = SwitchProducerTest(
...
)
foo.setCase_(...)

would require knowing the module label (which is not available here), or extending the functionality of specialImportRegistry (would need to think more what exactly that would imply).

I also considered

SwitchProducerTest(
    test1 = ...,
    test2 = ....,
    setCase_ = cms.untracked.string('test1')
)

but that appeared to require adding special cases for setCase_ in many places complicating the implementation even more. Also given that this setting should(?) be somewhat special case, and that it is not passed on to C++, making it a function call felt better choice than a "real parameter".

""")
sp.setCase_(None)
self.assertEqual(sp.dumpPython(),
"""SwitchProducerTest(
test1 = cms.EDProducer("Bar",
aa = cms.int32(11),
bb = cms.PSet(
cc = cms.int32(12)
)
),
test2 = cms.EDProducer("Foo",
a = cms.int32(1),
b = cms.PSet(
c = cms.int32(2)
)
)
)
""")

# Pickle
import pickle
sp = SwitchProducerPickleable(cpu = EDProducer("Foo"))
sp = SwitchProducerPickleable(cpu = EDProducer("Foo"), gpu=EDProducer("Bar"))
pkl = pickle.dumps(sp)
unpkl = pickle.loads(pkl)
self.assertEqual(unpkl.cpu.type_(), "Foo")
self.assertEqual(unpkl.gpu.type_(), "Bar")
self.assertEqual(unpkl._getProducer().type_(), "Bar")

sp.setCase_("cpu")
pkl = pickle.dumps(sp)
unpkl = pickle.loads(pkl)
self.assertEqual(unpkl.cpu.type_(), "Foo")
self.assertEqual(unpkl.gpu.type_(), "Bar")
self.assertEqual(unpkl._getProducer().type_(), "Foo")

def testSwithProducerWithAlias(self):
# Constructor
Expand Down