3
3
import copy
4
4
from contextlib import contextmanager
5
5
from functools import singledispatch
6
- from typing import List , Optional
6
+ from typing import TYPE_CHECKING , List , Optional
7
7
8
8
from aesara .graph .basic import Variable
9
9
from aesara .graph .utils import add_tag_trace
10
10
from aesara .link .basic import Container
11
11
from aesara .link .c .type import generic
12
12
13
13
14
+ if TYPE_CHECKING :
15
+ from aesara .graph .type import Type
16
+
17
+
14
18
__SHARED_CONTEXT__ : Optional [List [Variable ]] = None
15
19
16
20
@@ -30,14 +34,39 @@ def collect_new_shareds():
30
34
class SharedVariable (Variable ):
31
35
"""Variable that is shared between compiled functions."""
32
36
33
- container : Optional [Container ] = None
34
- """
35
- A container to use for this SharedVariable when it is an implicit
36
- function parameter.
37
- """
37
+ def __init__ (
38
+ self ,
39
+ type : "Type" ,
40
+ value ,
41
+ strict : bool ,
42
+ allow_downcast = None ,
43
+ container : Optional [Container ] = None ,
44
+ name : Optional [str ] = None ,
45
+ ):
46
+ r"""
47
+ Parameters
48
+ ----------
49
+ type
50
+ The `Type` for this variable (see `Variable`).
51
+ value
52
+ A value to associate with this variable (a new container will be
53
+ created).
54
+ strict
55
+ ``True`` means that values assigned to this variable will not be
56
+ cast or copied, so they must have the correct `Type`\s.
57
+ allow_downcast
58
+ Only applies if `strict` is ``False``.
59
+ ``True`` means that the assigned value can lose precision when cast
60
+ during assignment. ``None`` means that only down-casting of a Python
61
+ float to a scalar ``floatX`` is allowed.
62
+ container
63
+ The container to use for this variable. Illegal to pass this as well as
64
+ a value.
65
+ name
66
+ The name for this variable (see `Variable`).
38
67
39
- def __init__ ( self , name , type , value , strict , allow_downcast = None , container = None ):
40
- super ().__init__ (type = type , name = name , owner = None , index = None )
68
+ """
69
+ super ().__init__ (type = type , owner = None , index = None , name = name )
41
70
42
71
if container is not None :
43
72
self .container = container
@@ -107,26 +136,6 @@ def set_value(self, new_value, borrow=False):
107
136
def get_test_value (self ):
108
137
return self .get_value (borrow = True , return_internal_type = True )
109
138
110
- def zero (self , borrow = False ):
111
- """
112
- Set the values of a shared variable to 0.
113
-
114
- Parameters
115
- ----------
116
- borrow : bbol
117
- True to modify the value of a shared variable directly by using
118
- its previous value. Potentially this can cause problems
119
- regarding to the aliased memory.
120
-
121
- Changes done with this function will be visible to all functions using
122
- this SharedVariable.
123
-
124
- """
125
- if borrow :
126
- self .container .value [...] = 0
127
- else :
128
- self .container .value = 0 * self .container .value
129
-
130
139
def clone (self , ** kwargs ):
131
140
name = kwargs .get ("name" , self .name )
132
141
cp = self .__class__ (
@@ -209,7 +218,7 @@ def shared_constructor(value, name=None, strict=False, allow_downcast=None, **kw
209
218
return SharedVariable (
210
219
type = generic ,
211
220
value = value ,
212
- name = name ,
213
221
strict = strict ,
214
222
allow_downcast = allow_downcast ,
223
+ name = name ,
215
224
)
0 commit comments