Skip to content

Commit 3ad936f

Browse files
Refactor SharedVariable type and interface
1 parent bae26ef commit 3ad936f

File tree

7 files changed

+106
-56
lines changed

7 files changed

+106
-56
lines changed

aesara/compile/sharedvalue.py

+38-29
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,18 @@
33
import copy
44
from contextlib import contextmanager
55
from functools import singledispatch
6-
from typing import List, Optional
6+
from typing import TYPE_CHECKING, List, Optional
77

88
from aesara.graph.basic import Variable
99
from aesara.graph.utils import add_tag_trace
1010
from aesara.link.basic import Container
1111
from aesara.link.c.type import generic
1212

1313

14+
if TYPE_CHECKING:
15+
from aesara.graph.type import Type
16+
17+
1418
__SHARED_CONTEXT__: Optional[List[Variable]] = None
1519

1620

@@ -30,14 +34,39 @@ def collect_new_shareds():
3034
class SharedVariable(Variable):
3135
"""Variable that is shared between compiled functions."""
3236

33-
container: Optional[Container] = None
34-
"""
35-
A container to use for this SharedVariable when it is an implicit
36-
function parameter.
37-
"""
37+
def __init__(
38+
self,
39+
type: "Type",
40+
value,
41+
strict: bool,
42+
allow_downcast=None,
43+
container: Optional[Container] = None,
44+
name: Optional[str] = None,
45+
):
46+
r"""
47+
Parameters
48+
----------
49+
type
50+
The `Type` for this variable (see `Variable`).
51+
value
52+
A value to associate with this variable (a new container will be
53+
created).
54+
strict
55+
``True`` means that values assigned to this variable will not be
56+
cast or copied, so they must have the correct `Type`\s.
57+
allow_downcast
58+
Only applies if `strict` is ``False``.
59+
``True`` means that the assigned value can lose precision when cast
60+
during assignment. ``None`` means that only down-casting of a Python
61+
float to a scalar ``floatX`` is allowed.
62+
container
63+
The container to use for this variable. Illegal to pass this as well as
64+
a value.
65+
name
66+
The name for this variable (see `Variable`).
3867
39-
def __init__(self, name, type, value, strict, allow_downcast=None, container=None):
40-
super().__init__(type=type, name=name, owner=None, index=None)
68+
"""
69+
super().__init__(type=type, owner=None, index=None, name=name)
4170

4271
if container is not None:
4372
self.container = container
@@ -107,26 +136,6 @@ def set_value(self, new_value, borrow=False):
107136
def get_test_value(self):
108137
return self.get_value(borrow=True, return_internal_type=True)
109138

110-
def zero(self, borrow=False):
111-
"""
112-
Set the values of a shared variable to 0.
113-
114-
Parameters
115-
----------
116-
borrow : bbol
117-
True to modify the value of a shared variable directly by using
118-
its previous value. Potentially this can cause problems
119-
regarding to the aliased memory.
120-
121-
Changes done with this function will be visible to all functions using
122-
this SharedVariable.
123-
124-
"""
125-
if borrow:
126-
self.container.value[...] = 0
127-
else:
128-
self.container.value = 0 * self.container.value
129-
130139
def clone(self, **kwargs):
131140
name = kwargs.get("name", self.name)
132141
cp = self.__class__(
@@ -209,7 +218,7 @@ def shared_constructor(value, name=None, strict=False, allow_downcast=None, **kw
209218
return SharedVariable(
210219
type=generic,
211220
value=value,
212-
name=name,
213221
strict=strict,
214222
allow_downcast=allow_downcast,
223+
name=name,
215224
)

aesara/sparse/sharedvar.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22

33
import scipy.sparse
44

5-
from aesara.compile import SharedVariable, shared_constructor
5+
from aesara.compile import shared_constructor
66
from aesara.sparse.basic import SparseTensorType, _sparse_py_operators
7+
from aesara.tensor.sharedvar import TensorSharedVariable
78

89

9-
class SparseTensorSharedVariable(_sparse_py_operators, SharedVariable):
10-
dtype = property(lambda self: self.type.dtype)
11-
format = property(lambda self: self.type.format)
10+
class SparseTensorSharedVariable(TensorSharedVariable, _sparse_py_operators):
11+
@property
12+
def format(self):
13+
return self.type.format
1214

1315

1416
@shared_constructor.register(scipy.sparse.spmatrix)
@@ -24,5 +26,5 @@ def sparse_constructor(
2426
value = copy.deepcopy(value)
2527

2628
return SparseTensorSharedVariable(
27-
type=type, value=value, name=name, strict=strict, allow_downcast=allow_downcast
29+
type=type, value=value, strict=strict, allow_downcast=allow_downcast, name=name
2830
)

aesara/tensor/random/var.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def randomgen_constructor(
3737
return rng_sv_type(
3838
type=rng_type,
3939
value=value,
40-
name=name,
4140
strict=strict,
4241
allow_downcast=allow_downcast,
42+
name=name,
4343
)

aesara/tensor/sharedvar.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,25 @@ def load_shared_variable(val):
1919
return tensor_constructor(val)
2020

2121

22-
# _tensor_py_operators is first to have its version of __{gt,ge,lt,le}__
2322
class TensorSharedVariable(_tensor_py_operators, SharedVariable):
24-
pass
23+
def zero(self, borrow: bool = False):
24+
r"""Set the values of a shared variable to 0.
25+
26+
Parameters
27+
----------
28+
borrow
29+
``True`` to modify the value of a shared variable directly by using
30+
its previous value. Potentially this can cause problems regarding
31+
to the aliased memory.
32+
33+
Changes done with this function will be visible to all functions using
34+
this `SharedVariable`.
35+
36+
"""
37+
if borrow:
38+
self.container.value[...] = 0
39+
else:
40+
self.container.value = 0 * self.container.value
2541

2642

2743
@_get_vector_length.register(TensorSharedVariable)
@@ -69,13 +85,13 @@ def tensor_constructor(
6985
return TensorSharedVariable(
7086
type=type,
7187
value=np.array(value, copy=(not borrow)),
72-
name=name,
7388
strict=strict,
7489
allow_downcast=allow_downcast,
90+
name=name,
7591
)
7692

7793

78-
class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
94+
class ScalarSharedVariable(TensorSharedVariable):
7995
pass
8096

8197

tests/sparse/test_sharedvar.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import numpy as np
2+
import scipy as sp
3+
4+
import aesara
5+
from aesara.sparse.sharedvar import SparseTensorSharedVariable
6+
7+
8+
def test_shared_basic():
9+
x = aesara.shared(
10+
sp.sparse.csr_matrix(np.eye(100), dtype=np.float64), name="blah", borrow=True
11+
)
12+
13+
assert isinstance(x, SparseTensorSharedVariable)
14+
assert x.format == "csr"
15+
assert x.dtype == "float64"

tests/sparse/test_sparse.py

-14
This file was deleted.

tests/tensor/test_sharedvar.py

+25-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from aesara.tensor import get_vector_length
1111
from aesara.tensor.basic import MakeVector
1212
from aesara.tensor.shape import Shape_i, specify_shape
13+
from aesara.tensor.sharedvar import ScalarSharedVariable, TensorSharedVariable
1314
from tests import unittest_tools as utt
1415

1516

@@ -649,10 +650,31 @@ class TestSharedOptions:
649650
pass
650651

651652

653+
def test_tensor_shared_zero():
654+
shared_val = np.array([1.0, 3.0], dtype=np.float32)
655+
res = aesara.shared(value=shared_val, borrow=True)
656+
assert isinstance(res, TensorSharedVariable)
657+
assert res.get_value(borrow=True) is shared_val
658+
659+
res.zero(borrow=True)
660+
new_shared_val = res.get_value(borrow=True)
661+
assert new_shared_val is shared_val
662+
assert np.array_equal(new_shared_val, np.zeros((2,), dtype=np.float32))
663+
664+
res.set_value(shared_val, borrow=True)
665+
666+
res.zero(borrow=False)
667+
new_shared_val = res.get_value(borrow=True)
668+
assert new_shared_val is not shared_val
669+
assert np.array_equal(new_shared_val, np.zeros((2,), dtype=np.float32))
670+
671+
652672
def test_scalar_shared_options():
653-
# Simple test to make sure we do not loose that fonctionality.
654-
aesara.shared(value=0.0, name="lk", borrow=True)
655-
aesara.shared(value=np.float32(0.0), name="lk", borrow=True)
673+
res = aesara.shared(value=np.float32(0.0), name="lk", borrow=True)
674+
assert isinstance(res, ScalarSharedVariable)
675+
assert res.type.dtype == "float32"
676+
assert res.name == "lk"
677+
assert res.type.shape == ()
656678

657679

658680
def test_get_vector_length():

0 commit comments

Comments
 (0)