Skip to content

Commit a5adbd2

Browse files
authored
Merge pull request #100 from neuro-ml/dev
Basic support of pickling
2 parents 5e06553 + 2d7009b commit a5adbd2

File tree

8 files changed

+58
-105
lines changed

8 files changed

+58
-105
lines changed

connectome/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.8.1'
1+
__version__ = '0.9.0'

connectome/containers/base.py

+3-49
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
import warnings
32
from operator import itemgetter
43
from typing import Callable, Optional, Union
54

@@ -11,44 +10,15 @@
1110
from ..utils import NameSet, StringsLike, check_for_duplicates, node_to_dict
1211
from .context import BagContext, ChainContext, Context, NoContext, update_map
1312

14-
__all__ = 'Container', 'EdgesBag'
13+
__all__ = 'EdgesBag',
1514

1615
logger = logging.getLogger(__name__)
1716

1817

19-
class Container:
20-
def __init__(self):
21-
warnings.warn(
22-
'The container interface is deprecated and will be merged with `EdgesBag` soon',
23-
UserWarning, stacklevel=2
24-
)
25-
warnings.warn(
26-
'The container interface is deprecated and will be merged with `EdgesBag` soon',
27-
DeprecationWarning, stacklevel=2
28-
)
29-
30-
def wrap(self, container: 'EdgesBag') -> 'EdgesBag':
31-
raise NotImplementedError
32-
33-
3418
class EdgesBag:
3519
def __init__(self, inputs: Nodes, outputs: Nodes, edges: BoundEdges, context: Optional[Context], *,
36-
virtual_nodes: Optional[NameSet] = None, virtual: Optional[NameSet] = None,
37-
persistent_nodes: Optional[NameSet] = None, persistent: Optional[NameSet] = None,
38-
optional_nodes: Optional[NodeSet] = None, optional: Optional[NodeSet] = None):
39-
if virtual_nodes is not None:
40-
assert virtual is None
41-
warnings.warn('The "virtual_nodes" argument is deprecated. Use `virtual` instead', stacklevel=2)
42-
virtual = virtual_nodes
43-
if optional_nodes is not None:
44-
assert optional is None
45-
warnings.warn('The "optional_nodes" argument is deprecated. Use `optional` instead', stacklevel=2)
46-
optional = optional_nodes
47-
if persistent_nodes is not None:
48-
assert persistent is None
49-
warnings.warn('The "persistent_nodes" argument is deprecated. Use `persistent` instead', stacklevel=2)
50-
persistent = persistent_nodes
51-
20+
virtual: Optional[NameSet] = None, persistent: Optional[NameSet] = None,
21+
optional: Optional[NodeSet] = None):
5222
if virtual is None:
5323
virtual = set()
5424
if persistent is None:
@@ -103,22 +73,6 @@ def loopback(self, func: Callable, inputs: StringsLike, output: StringsLike) ->
10373
virtual=None, persistent=None, optional=state.optional | new_optionals,
10474
)
10575

106-
# TODO: deprecated
107-
@property
108-
def persistent_nodes(self):
109-
warnings.warn('This attribute is deprecated. Use `persistent` instead', stacklevel=2)
110-
return self.persistent
111-
112-
@property
113-
def optional_nodes(self):
114-
warnings.warn('This attribute is deprecated. Use `optional` instead', stacklevel=2)
115-
return self.optional
116-
117-
@property
118-
def virtual_nodes(self):
119-
warnings.warn('This attribute is deprecated. Use `virtual` instead', stacklevel=2)
120-
return self.virtual
121-
12276

12377
def normalize_bag(inputs: Nodes, outputs: Nodes, edges: BoundEdges, virtuals: NameSet, optionals: NodeSet,
12478
persistent_nodes: NameSet, allow_missing_inputs: bool = True):

connectome/engine/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
from .compiler import GraphCompiler
33
from .edges import *
44
from .graph import Graph
5-
from .node_hash import ApplyHash, CustomHash, FilterHash, GraphHash, LeafHash, NodeHash, NodeHashes, TupleHash
5+
from .node_hash import ApplyHash, CustomHash, GraphHash, LeafHash, NodeHash, NodeHashes

connectome/engine/graph.py

-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ def __init__(self, inputs: TreeNodes, output: TreeNode):
2222
self.output = output
2323
self.counts = counts
2424
self.__signature__ = signature
25-
# TODO: deprecate
26-
self.call = self.__call__
2725

2826
def __call__(*args, **kwargs):
2927
self, *args = args

connectome/engine/node_hash.py

-44
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import warnings
21
from typing import Any, Callable, Sequence, Tuple
32

43
NODE_TYPES = set()
@@ -75,46 +74,3 @@ def __init__(self, marker: Any, *children: NodeHash):
7574
(self.type, marker, *(h.value for h in children)),
7675
(self.type, marker, *children),
7776
)
78-
79-
80-
# TODO: deprecated
81-
class CompoundBase(NodeHash):
82-
type = None
83-
84-
def __init__(self, *children: NodeHash):
85-
warnings.warn('This interface is deprecated', DeprecationWarning)
86-
warnings.warn('This interface is deprecated', UserWarning)
87-
super().__init__(
88-
(self.type, *(h.value for h in children)),
89-
(self.type, *children),
90-
)
91-
92-
93-
class TupleHash(ApplyHash):
94-
type = -1
95-
96-
def __init__(self, *children: NodeHash):
97-
warnings.warn('This interface is deprecated. Use ApplyHash instead', DeprecationWarning)
98-
warnings.warn('This interface is deprecated. Use ApplyHash instead', UserWarning)
99-
super().__init__(tuple, *children)
100-
101-
102-
class FilterHash(ApplyHash):
103-
type = -2
104-
105-
def __init__(self, graph: GraphHash, values: NodeHash):
106-
warnings.warn('This interface is deprecated. Use ApplyHash instead', DeprecationWarning)
107-
warnings.warn('This interface is deprecated. Use ApplyHash instead', UserWarning)
108-
super().__init__(filter, graph, values)
109-
110-
111-
class MergeHash(CustomHash):
112-
type = -3
113-
114-
def __init__(self, *children: NodeHash):
115-
warnings.warn('This interface is deprecated. Use CustomHash instead', DeprecationWarning)
116-
warnings.warn('This interface is deprecated. Use CustomHash instead', UserWarning)
117-
super().__init__('connectome.SwitchEdge', *children)
118-
119-
120-
PrecomputeHash = NodeHash

connectome/interface/factory.py

-7
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,6 @@ def is_detectable(value):
5858
BUILTIN_DECORATORS = staticmethod, classmethod, property
5959

6060

61-
class FactoryLayer(CallableLayer):
62-
def __init__(self, container: EdgesBag, properties: Iterable[str], special_methods: Iterable[str]):
63-
warnings.warn('This class is deprecated', DeprecationWarning)
64-
warnings.warn('This class is deprecated', UserWarning)
65-
super().__init__(container, properties)
66-
67-
6861
class GraphFactory:
6962
layer_cls: Type[Layer] = CallableLayer
7063

connectome/interface/metaclasses.py

+13
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from ..layers import CallableLayer, Layer
55
from ..utils import MultiDict
6+
from .decorators import RuntimeAnnotation
67
from .factory import GraphFactory, SourceFactory, TransformFactory, add_from_mixins, add_quals, items_to_container
78

89
logger = logging.getLogger(__name__)
@@ -14,6 +15,16 @@ class APIMeta(type):
1415
def __prepare__(mcs, *args, **kwargs):
1516
return MultiDict()
1617

18+
def __getattr__(self, item):
19+
# we need this behaviour mostly to support pickling of functions defined inside the class
20+
try:
21+
value = self.__original__scope__[item]
22+
while isinstance(value, RuntimeAnnotation):
23+
value = value.__func__
24+
return value
25+
except KeyError:
26+
raise AttributeError(item) from None
27+
1728
def __new__(mcs, class_name, bases, namespace, **flags):
1829
if '__factory' in flags:
1930
factory = flags.pop('__factory')
@@ -46,6 +57,8 @@ def __new__(mcs, class_name, bases, namespace, **flags):
4657
add_from_mixins(namespace, bases)
4758
scope = factory.make_scope(class_name, namespace)
4859

60+
# TODO: need a standardized set of magic fields
61+
scope['__original__scope__'] = namespace
4962
return super().__new__(mcs, class_name, (main,), scope, **flags)
5063

5164

tests/test_interface/test_metaclasses.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import pickle
2+
13
import pytest
24

3-
from connectome import Mixin, Source, Transform
5+
from connectome import Mixin, Source, Transform, meta
46

57

68
def test_subclasses():
@@ -33,3 +35,40 @@ class G(C):
3335
pass
3436

3537
assert str(B()) == 'B()'
38+
39+
40+
class A(Transform):
41+
def x(x):
42+
return x
43+
44+
def _t(x):
45+
return x ** 2
46+
47+
def y(x, _t):
48+
return x + _t
49+
50+
51+
class B(Transform):
52+
@meta
53+
def ids():
54+
return '123'
55+
56+
def x(x):
57+
return x
58+
59+
def _t(x):
60+
return x ** 2
61+
62+
def y(x, _t):
63+
return x + _t
64+
65+
66+
def test_pickleable():
67+
a = A()
68+
b = B()
69+
assert a.x != A.x
70+
71+
for f in a.x, a.y, a._compile(dir(a)), b._compile(dir(b)):
72+
pickled = pickle.dumps(f)
73+
g = pickle.loads(pickled)
74+
assert f(0) == g(0)

0 commit comments

Comments
 (0)