Skip to content

Commit

Permalink
Refactor SharedVariable type and interface
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Nov 21, 2022
1 parent bae26ef commit 3ad936f
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 56 deletions.
67 changes: 38 additions & 29 deletions aesara/compile/sharedvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
import copy
from contextlib import contextmanager
from functools import singledispatch
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional

from aesara.graph.basic import Variable
from aesara.graph.utils import add_tag_trace
from aesara.link.basic import Container
from aesara.link.c.type import generic


if TYPE_CHECKING:
from aesara.graph.type import Type


__SHARED_CONTEXT__: Optional[List[Variable]] = None


Expand All @@ -30,14 +34,39 @@ def collect_new_shareds():
class SharedVariable(Variable):
"""Variable that is shared between compiled functions."""

container: Optional[Container] = None
"""
A container to use for this SharedVariable when it is an implicit
function parameter.
"""
def __init__(
self,
type: "Type",
value,
strict: bool,
allow_downcast=None,
container: Optional[Container] = None,
name: Optional[str] = None,
):
r"""
Parameters
----------
type
The `Type` for this variable (see `Variable`).
value
A value to associate with this variable (a new container will be
created).
strict
``True`` means that values assigned to this variable will not be
cast or copied, so they must have the correct `Type`\s.
allow_downcast
Only applies if `strict` is ``False``.
``True`` means that the assigned value can lose precision when cast
during assignment. ``None`` means that only down-casting of a Python
float to a scalar ``floatX`` is allowed.
container
The container to use for this variable. Illegal to pass this as well as
a value.
name
The name for this variable (see `Variable`).
def __init__(self, name, type, value, strict, allow_downcast=None, container=None):
super().__init__(type=type, name=name, owner=None, index=None)
"""
super().__init__(type=type, owner=None, index=None, name=name)

if container is not None:
self.container = container
Expand Down Expand Up @@ -107,26 +136,6 @@ def set_value(self, new_value, borrow=False):
def get_test_value(self):
return self.get_value(borrow=True, return_internal_type=True)

def zero(self, borrow=False):
"""
Set the values of a shared variable to 0.
Parameters
----------
borrow : bbol
True to modify the value of a shared variable directly by using
its previous value. Potentially this can cause problems
regarding to the aliased memory.
Changes done with this function will be visible to all functions using
this SharedVariable.
"""
if borrow:
self.container.value[...] = 0
else:
self.container.value = 0 * self.container.value

def clone(self, **kwargs):
name = kwargs.get("name", self.name)
cp = self.__class__(
Expand Down Expand Up @@ -209,7 +218,7 @@ def shared_constructor(value, name=None, strict=False, allow_downcast=None, **kw
return SharedVariable(
type=generic,
value=value,
name=name,
strict=strict,
allow_downcast=allow_downcast,
name=name,
)
12 changes: 7 additions & 5 deletions aesara/sparse/sharedvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import scipy.sparse

from aesara.compile import SharedVariable, shared_constructor
from aesara.compile import shared_constructor
from aesara.sparse.basic import SparseTensorType, _sparse_py_operators
from aesara.tensor.sharedvar import TensorSharedVariable


class SparseTensorSharedVariable(_sparse_py_operators, SharedVariable):
dtype = property(lambda self: self.type.dtype)
format = property(lambda self: self.type.format)
class SparseTensorSharedVariable(TensorSharedVariable, _sparse_py_operators):
@property
def format(self):
return self.type.format


@shared_constructor.register(scipy.sparse.spmatrix)
Expand All @@ -24,5 +26,5 @@ def sparse_constructor(
value = copy.deepcopy(value)

return SparseTensorSharedVariable(
type=type, value=value, name=name, strict=strict, allow_downcast=allow_downcast
type=type, value=value, strict=strict, allow_downcast=allow_downcast, name=name
)
2 changes: 1 addition & 1 deletion aesara/tensor/random/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def randomgen_constructor(
return rng_sv_type(
type=rng_type,
value=value,
name=name,
strict=strict,
allow_downcast=allow_downcast,
name=name,
)
24 changes: 20 additions & 4 deletions aesara/tensor/sharedvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,25 @@ def load_shared_variable(val):
return tensor_constructor(val)


# _tensor_py_operators is first to have its version of __{gt,ge,lt,le}__
class TensorSharedVariable(_tensor_py_operators, SharedVariable):
pass
def zero(self, borrow: bool = False):
r"""Set the values of a shared variable to 0.
Parameters
----------
borrow
``True`` to modify the value of a shared variable directly by using
its previous value. Potentially this can cause problems regarding
to the aliased memory.
Changes done with this function will be visible to all functions using
this `SharedVariable`.
"""
if borrow:
self.container.value[...] = 0
else:
self.container.value = 0 * self.container.value


@_get_vector_length.register(TensorSharedVariable)
Expand Down Expand Up @@ -69,13 +85,13 @@ def tensor_constructor(
return TensorSharedVariable(
type=type,
value=np.array(value, copy=(not borrow)),
name=name,
strict=strict,
allow_downcast=allow_downcast,
name=name,
)


class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
class ScalarSharedVariable(TensorSharedVariable):
pass


Expand Down
15 changes: 15 additions & 0 deletions tests/sparse/test_sharedvar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import numpy as np
import scipy as sp

import aesara
from aesara.sparse.sharedvar import SparseTensorSharedVariable


def test_shared_basic():
x = aesara.shared(
sp.sparse.csr_matrix(np.eye(100), dtype=np.float64), name="blah", borrow=True
)

assert isinstance(x, SparseTensorSharedVariable)
assert x.format == "csr"
assert x.dtype == "float64"
14 changes: 0 additions & 14 deletions tests/sparse/test_sparse.py

This file was deleted.

28 changes: 25 additions & 3 deletions tests/tensor/test_sharedvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from aesara.tensor import get_vector_length
from aesara.tensor.basic import MakeVector
from aesara.tensor.shape import Shape_i, specify_shape
from aesara.tensor.sharedvar import ScalarSharedVariable, TensorSharedVariable
from tests import unittest_tools as utt


Expand Down Expand Up @@ -649,10 +650,31 @@ class TestSharedOptions:
pass


def test_tensor_shared_zero():
shared_val = np.array([1.0, 3.0], dtype=np.float32)
res = aesara.shared(value=shared_val, borrow=True)
assert isinstance(res, TensorSharedVariable)
assert res.get_value(borrow=True) is shared_val

res.zero(borrow=True)
new_shared_val = res.get_value(borrow=True)
assert new_shared_val is shared_val
assert np.array_equal(new_shared_val, np.zeros((2,), dtype=np.float32))

res.set_value(shared_val, borrow=True)

res.zero(borrow=False)
new_shared_val = res.get_value(borrow=True)
assert new_shared_val is not shared_val
assert np.array_equal(new_shared_val, np.zeros((2,), dtype=np.float32))


def test_scalar_shared_options():
# Simple test to make sure we do not loose that fonctionality.
aesara.shared(value=0.0, name="lk", borrow=True)
aesara.shared(value=np.float32(0.0), name="lk", borrow=True)
res = aesara.shared(value=np.float32(0.0), name="lk", borrow=True)
assert isinstance(res, ScalarSharedVariable)
assert res.type.dtype == "float32"
assert res.name == "lk"
assert res.type.shape == ()


def test_get_vector_length():
Expand Down

0 comments on commit 3ad936f

Please sign in to comment.