-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RFC] Add a mechanism to set the chosen case for SwitchProducer #35510
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -216,6 +216,8 @@ def nameInProcessDesc_(self, myname): | |
# SwitchProducer to be pickleable. | ||
def _switch_cpu(): | ||
return (True, 1) | ||
def _switch_gpu_test(): | ||
return (True, 2) | ||
|
||
class SwitchProducer(EDProducer): | ||
"""This purpose class is to provide a switch of EDProducers for a single module/product label. | ||
|
@@ -245,21 +247,37 @@ def __init__(self, **kargs): | |
produced with one of the producers. It would be good if their | ||
output product types and instance names would be the same (or very | ||
close). | ||
|
||
The decision can be "forced" with 'foo.setCase_("case1")' and | ||
unset with 'foo.setCase_(None)'. This setting persists through | ||
dumpPython() and picking. | ||
""" | ||
def __init__(self, caseFunctionDict, **kargs): | ||
super(SwitchProducer,self).__init__(None) | ||
self._caseFunctionDict = copy.copy(caseFunctionDict) | ||
self.__setParameters(kargs) | ||
self._isModified = False | ||
self._setCase = None | ||
|
||
@staticmethod | ||
def getCpu(): | ||
"""Returns a function that returns the priority for a CPU "computing device". Intended to be used by deriving classes.""" | ||
return _switch_cpu | ||
|
||
def setCase_(self, case): | ||
if case is not None and case not in self._caseFunctionDict: | ||
raise ValueError("Case '{}' is not allowed (allowed ones are {})".format(case, ",".join(self._caseFunctionDict.keys()))) | ||
self._setCase = case | ||
return self | ||
|
||
def _chooseCase(self): | ||
"""Returns the name of the chosen case.""" | ||
cases = self.parameterNames_() | ||
if self._setCase is not None: | ||
if self._setCase not in cases: | ||
raise RuntimeError("Case '{}' has been set, but there is no EDProducer/EDAlias for that".format(self._setCase)) | ||
return self._setCase | ||
|
||
bestCase = None | ||
for case in cases: | ||
(enabled, priority) = self._caseFunctionDict[case]() | ||
|
@@ -280,9 +298,9 @@ def __typeIsValid(typ): | |
def __addParameter(self, name, value): | ||
if not self.__typeIsValid(value): | ||
raise TypeError(name+" does not already exist, so it can only be set to a cms.EDProducer or cms.EDAlias") | ||
if name not in self._caseFunctionDict: | ||
elif name not in self._caseFunctionDict: | ||
raise ValueError("Case '%s' is not allowed (allowed ones are %s)" % (name, ",".join(self._caseFunctionDict.keys()))) | ||
if name in self.__dict__: | ||
elif name in self.__dict__: | ||
message = "Duplicate insert of member " + name | ||
message += "\nThe original parameters are:\n" | ||
message += self.dumpPython() + '\n' | ||
|
@@ -340,6 +358,7 @@ def clone(self, **params): | |
returnValue.__init__(**myparams) | ||
returnValue._isModified = False | ||
returnValue._isFrozen = False | ||
returnValue._setCase = self._setCase | ||
saveOrigin(returnValue, 1) | ||
return returnValue | ||
|
||
|
@@ -355,7 +374,10 @@ def dumpPython(self, options=PrintOptions()): | |
if result[-1] == ",": | ||
result = result.rstrip(",") | ||
options.unindent() | ||
result += "\n)\n" | ||
result += "\n)" | ||
if self._setCase is not None: | ||
result += ".setCase_('{}')".format(self._setCase) | ||
result += "\n" | ||
return result | ||
|
||
def directDependencies(self): | ||
|
@@ -444,7 +466,10 @@ def __init__(self, **kargs): | |
class SwitchProducerPickleable(SwitchProducer): | ||
def __init__(self, **kargs): | ||
super(SwitchProducerPickleable,self).__init__( | ||
dict(cpu = SwitchProducer.getCpu()), **kargs) | ||
dict( | ||
cpu = SwitchProducer.getCpu(), | ||
gpu = _switch_gpu_test, | ||
), **kargs) | ||
|
||
class TestModules(unittest.TestCase): | ||
def testEDAnalyzer(self): | ||
|
@@ -577,6 +602,21 @@ def testSwitchProducer(self): | |
sp = SwitchProducerTest1Dis(test1 = EDProducer("Bar")) | ||
self.assertRaises(RuntimeError, sp._getProducer) | ||
|
||
# Case forcing | ||
sp = SwitchProducerTest(test1 = EDProducer("Foo"), test2 = EDProducer("Bar")) | ||
sp.setCase_("test1") | ||
self.assertEqual(sp._getProducer().type_(), "Foo") | ||
sp.setCase_("test2") | ||
self.assertEqual(sp._getProducer().type_(), "Bar") | ||
sp.setCase_("test1") | ||
self.assertEqual(sp._getProducer().type_(), "Foo") | ||
sp.setCase_(None) | ||
self.assertEqual(sp._getProducer().type_(), "Bar") | ||
self.assertRaises(ValueError, sp.setCase_, "nonexistent") | ||
sp.setCase_("test3") | ||
self.assertRaises(RuntimeError, sp._getProducer) | ||
|
||
|
||
# Mofications | ||
from .Types import int32, string, PSet | ||
sp = SwitchProducerTest(test1 = EDProducer("Foo", | ||
|
@@ -593,7 +633,7 @@ def testSwitchProducer(self): | |
self.assertEqual(cl.test2.type_(), "Bar") | ||
self.assertEqual(cl.test2.aa.value(), 11) | ||
self.assertEqual(cl.test2.bb.cc.value(), 12) | ||
self.assertEqual(sp._getProducer().type_(), "Bar") | ||
self.assertEqual(cl._getProducer().type_(), "Bar") | ||
# Modify clone | ||
cl.test1.a = 3 | ||
self.assertEqual(cl.test1.a.value(), 3) | ||
|
@@ -625,6 +665,13 @@ def _assignSwitchProducer(): | |
self.assertRaises(TypeError, lambda: sp.clone(test1 = EDAnalyzer("Foo"))) | ||
self.assertRaises(TypeError, lambda: sp.clone(test1 = SwitchProducerTest(test1 = SwitchProducerTest(test1 = EDProducer("Foo"))))) | ||
|
||
sp.setCase_("test1") | ||
cl = sp.clone() | ||
self.assertEqual(cl._getProducer().type_(), "Foo") | ||
cl.setCase_(None) | ||
self.assertEqual(cl._getProducer().type_(), "Bar") | ||
self.assertEqual(sp._getProducer().type_(), "Foo") | ||
|
||
# Dump | ||
sp = SwitchProducerTest(test2 = EDProducer("Foo", | ||
a = int32(1), | ||
|
@@ -648,12 +695,57 @@ def _assignSwitchProducer(): | |
) | ||
) | ||
""") | ||
|
||
sp.setCase_("test1") | ||
self.assertEqual(sp.dumpPython(), | ||
"""SwitchProducerTest( | ||
test1 = cms.EDProducer("Bar", | ||
aa = cms.int32(11), | ||
bb = cms.PSet( | ||
cc = cms.int32(12) | ||
) | ||
), | ||
test2 = cms.EDProducer("Foo", | ||
a = cms.int32(1), | ||
b = cms.PSet( | ||
c = cms.int32(2) | ||
) | ||
) | ||
).setCase_('test1') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Dumping the "set case" like this looks a bit weird ( foo = SwitchProducerTest(
...
)
foo.setCase_(...) would require knowing the module label (which is not available here), or extending the functionality of I also considered SwitchProducerTest(
test1 = ...,
test2 = ....,
setCase_ = cms.untracked.string('test1')
) but that appeared to require adding special cases for |
||
""") | ||
sp.setCase_(None) | ||
self.assertEqual(sp.dumpPython(), | ||
"""SwitchProducerTest( | ||
test1 = cms.EDProducer("Bar", | ||
aa = cms.int32(11), | ||
bb = cms.PSet( | ||
cc = cms.int32(12) | ||
) | ||
), | ||
test2 = cms.EDProducer("Foo", | ||
a = cms.int32(1), | ||
b = cms.PSet( | ||
c = cms.int32(2) | ||
) | ||
) | ||
) | ||
""") | ||
|
||
# Pickle | ||
import pickle | ||
sp = SwitchProducerPickleable(cpu = EDProducer("Foo")) | ||
sp = SwitchProducerPickleable(cpu = EDProducer("Foo"), gpu=EDProducer("Bar")) | ||
pkl = pickle.dumps(sp) | ||
unpkl = pickle.loads(pkl) | ||
self.assertEqual(unpkl.cpu.type_(), "Foo") | ||
self.assertEqual(unpkl.gpu.type_(), "Bar") | ||
self.assertEqual(unpkl._getProducer().type_(), "Bar") | ||
|
||
sp.setCase_("cpu") | ||
pkl = pickle.dumps(sp) | ||
unpkl = pickle.loads(pkl) | ||
self.assertEqual(unpkl.cpu.type_(), "Foo") | ||
self.assertEqual(unpkl.gpu.type_(), "Bar") | ||
self.assertEqual(unpkl._getProducer().type_(), "Foo") | ||
|
||
def testSwithProducerWithAlias(self): | ||
# Constructor | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.