diff --git a/tensornetwork/__init__.py b/tensornetwork/__init__.py index fd6269031..96a490fe4 100644 --- a/tensornetwork/__init__.py +++ b/tensornetwork/__init__.py @@ -10,12 +10,8 @@ from tensornetwork.version import __version__ from tensornetwork.visualization.graphviz import to_graphviz from tensornetwork import contractors -from tensornetwork import config from typing import Text, Optional, Type, Union from tensornetwork.utils import load_nodes, save_nodes from tensornetwork.matrixproductstates.finite_mps import FiniteMPS from tensornetwork.matrixproductstates.infinite_mps import InfiniteMPS - - -def set_default_backend(backend: Union[Text, BaseBackend]) -> None: - config.default_backend = backend +from tensornetwork.backend_contextmanager import DefaultBackend, set_default_backend diff --git a/tensornetwork/backend_contextmanager.py b/tensornetwork/backend_contextmanager.py new file mode 100644 index 000000000..814d6d7bf --- /dev/null +++ b/tensornetwork/backend_contextmanager.py @@ -0,0 +1,41 @@ +from typing import Text, Union +from tensornetwork.backends.base_backend import BaseBackend + +class DefaultBackend(): + """Context manager for setting up backend for nodes""" + + def __init__(self, backend: Union[Text, BaseBackend]) -> None: + if not isinstance(backend, (Text, BaseBackend)): + raise ValueError("Item passed to DefaultBackend " + "must be Text or BaseBackend") + self.backend = backend + + def __enter__(self): + _default_backend_stack.stack.append(self) + + def __exit__(self, exc_type, exc_val, exc_tb): + _default_backend_stack.stack.pop() + +class _DefaultBackendStack(): + """A stack to keep track default backends context manager""" + + def __init__(self): + self.stack = [] + self.default_backend = "numpy" + + def get_current_backend(self): + return self.stack[-1].backend if self.stack else self.default_backend + +_default_backend_stack = _DefaultBackendStack() + +def get_default_backend(): + return _default_backend_stack.get_current_backend() + +def set_default_backend(backend: Union[Text, BaseBackend]) -> None: + if _default_backend_stack.stack: + raise AssertionError("The default backend should not be changed " + "inside the backend context manager") + if not isinstance(backend, (Text, BaseBackend)): + raise ValueError("Item passed to set_default_backend " + "must be Text or BaseBackend") + _default_backend_stack.default_backend = backend diff --git a/tensornetwork/backends/backend_factory.py b/tensornetwork/backends/backend_factory.py index 52d0bfb2e..859d829a4 100644 --- a/tensornetwork/backends/backend_factory.py +++ b/tensornetwork/backends/backend_factory.py @@ -19,7 +19,6 @@ from tensornetwork.backends.shell import shell_backend from tensornetwork.backends.pytorch import pytorch_backend from tensornetwork.backends import base_backend -import tensornetwork.config as config_file _BACKENDS = { "tensorflow": tensorflow_backend.TensorFlowBackend, diff --git a/tensornetwork/config.py b/tensornetwork/config.py index 8c347b4fa..0a54b0cb6 100644 --- a/tensornetwork/config.py +++ b/tensornetwork/config.py @@ -11,5 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -default_backend = "numpy" diff --git a/tensornetwork/ncon_interface.py b/tensornetwork/ncon_interface.py index f4dde1593..330bdc044 100644 --- a/tensornetwork/ncon_interface.py +++ b/tensornetwork/ncon_interface.py @@ -16,7 +16,7 @@ import warnings from typing import Any, Sequence, List, Optional, Union, Text, Tuple, Dict from tensornetwork import network_components -from tensornetwork import config +from tensornetwork.backend_contextmanager import get_default_backend from tensornetwork.backends import backend_factory Tensor = Any @@ -67,8 +67,8 @@ def ncon(tensors: Sequence[Union[network_components.BaseNode, Tensor]], structure. con_order: List of edge labels specifying the contraction order. out_order: List of edge labels specifying the output order. - backend: String specifying the backend to use. Defaults to - `tensornetwork.config.default_backend`. + backend: String specifying the backend to use. Defaults to + `tensornetwork.backend_contextmanager.get_default_backend`. Returns: The result of the contraction. The result is returned as a `Node` @@ -78,7 +78,7 @@ def ncon(tensors: Sequence[Union[network_components.BaseNode, Tensor]], if backend and (backend not in backend_factory._BACKENDS): raise ValueError("Backend '{}' does not exist".format(backend)) if backend is None: - backend = config.default_backend + backend = get_default_backend() are_nodes = [isinstance(t, network_components.BaseNode) for t in tensors] nodes = {t for t in tensors if isinstance(t, network_components.BaseNode)} diff --git a/tensornetwork/network_components.py b/tensornetwork/network_components.py index 1f6de8917..8d564df5c 100644 --- a/tensornetwork/network_components.py +++ b/tensornetwork/network_components.py @@ -21,10 +21,10 @@ import h5py #pylint: disable=useless-import-alias -import tensornetwork.config as config from tensornetwork import ops from tensornetwork.backends import backend_factory from tensornetwork.backends.base_backend import BaseBackend +from tensornetwork.backend_contextmanager import get_default_backend string_type = h5py.special_dtype(vlen=str) Tensor = Any @@ -525,8 +525,8 @@ def __init__(self, """Create a node. Args: - tensor: The concrete that is represented by this node, or a `BaseNode` - object. If a tensor is passed, it can be + tensor: The concrete that is represented by this node, or a `BaseNode` + object. If a tensor is passed, it can be be either a numpy array or the tensor-type of the used backend. If a `BaseNode` is passed, the passed node has to have the same \ backend as given by `backend`. @@ -543,7 +543,7 @@ def __init__(self, backend = tensor.backend tensor = tensor.tensor if not backend: - backend = config.default_backend + backend = get_default_backend() if isinstance(backend, BaseBackend): backend_obj = backend else: @@ -633,13 +633,13 @@ def __init__(self, backend: An optional backend for the node. If `None`, a default backend is used dtype: The dtype used to initialize a numpy-copy node. - Note that this dtype has to be a numpy dtype, and it has to be + Note that this dtype has to be a numpy dtype, and it has to be compatible with the dtype of the backend, e.g. for a tensorflow backend with a tf.Dtype=tf.floa32, `dtype` has to be `np.float32`. """ if not backend: - backend = config.default_backend + backend = get_default_backend() backend_obj = backend_factory.get_backend(backend) self.rank = rank @@ -1092,14 +1092,14 @@ def disconnect(self, edge2_name: Optional[Text] = None) -> Tuple["Edge", "Edge"]: """ Break an existing non-dangling edge. - This updates both Edge.node1 and Edge.node2 by removing the + This updates both Edge.node1 and Edge.node2 by removing the connecting edge from `Edge.node1.edges` and `Edge.node2.edges` and adding new dangling edges instead Args: edge1_name: A name for the new dangling edge at `self.node1` edge2_name: A name for the new dangling edge at `self.node2` Returns: - (new_edge1, new_edge2): The new `Edge` objects of + (new_edge1, new_edge2): The new `Edge` objects of `self.node1` and `self.node2` """ if self.is_dangling(): @@ -1155,7 +1155,7 @@ def get_parallel_edges(edge: Edge) -> Set[Edge]: edge: The given edge. Returns: - A `set` of all of the edges parallel to the given edge + A `set` of all of the edges parallel to the given edge (including the given edge). """ return get_shared_edges(edge.node1, edge.node2) @@ -1389,8 +1389,8 @@ def split_edge(edge: Edge, shape: Tuple[int, ...], new_edge_names: Optional[List[Text]] = None) -> List[Edge]: """Split an `Edge` into multiple edges according to `shape`. Reshapes - the underlying tensors connected to the edge accordingly. - + the underlying tensors connected to the edge accordingly. + This method acts as the inverse operation of flattening edges and distinguishes between the following edge cases when adding new edges: 1) standard edge connecting two different nodes: reshape node dimensions @@ -1772,7 +1772,7 @@ def disconnect(edge, edge2_name: Optional[Text] = None) -> Tuple[Edge, Edge]: """ Break an existing non-dangling edge. - This updates both Edge.node1 and Edge.node2 by removing the + This updates both Edge.node1 and Edge.node2 by removing the connecting edge from `Edge.node1.edges` and `Edge.node2.edges` and adding new dangling edges instead """ @@ -1894,9 +1894,9 @@ def outer_product_final_nodes(nodes: Iterable[BaseNode], edge_order: List[Edge]) -> BaseNode: """Get the outer product of `nodes` - For example, if there are 3 nodes remaining in `nodes` with + For example, if there are 3 nodes remaining in `nodes` with shapes :math:`(2, 3)`, :math:`(4, 5, 6)`, and :math:`(7)` - respectively, the newly returned node will have shape + respectively, the newly returned node will have shape :math:`(2, 3, 4, 5, 6, 7)`. Args: diff --git a/tensornetwork/network_operations.py b/tensornetwork/network_operations.py index 5aea214db..c2e228792 100644 --- a/tensornetwork/network_operations.py +++ b/tensornetwork/network_operations.py @@ -19,7 +19,6 @@ import numpy as np #pylint: disable=useless-import-alias -import tensornetwork.config as config #pylint: disable=line-too-long from tensornetwork.network_components import BaseNode, Node, CopyNode, Edge, disconnect from tensornetwork.backends import backend_factory diff --git a/tensornetwork/tests/backend_contextmanager_test.py b/tensornetwork/tests/backend_contextmanager_test.py new file mode 100644 index 000000000..60f6e833b --- /dev/null +++ b/tensornetwork/tests/backend_contextmanager_test.py @@ -0,0 +1,46 @@ +import tensornetwork as tn +from tensornetwork.backend_contextmanager import _default_backend_stack +import pytest +import numpy as np + +def test_contextmanager_simple(): + with tn.DefaultBackend("tensorflow"): + a = tn.Node(np.ones((10,))) + b = tn.Node(np.ones((10,))) + assert a.backend.name == b.backend.name + +def test_contextmanager_default_backend(): + tn.set_default_backend("pytorch") + with tn.DefaultBackend("numpy"): + assert _default_backend_stack.default_backend == "pytorch" + +def test_contextmanager_interruption(): + tn.set_default_backend("pytorch") + with pytest.raises(AssertionError): + with tn.DefaultBackend("numpy"): + tn.set_default_backend("tensorflow") + +def test_contextmanager_nested(): + with tn.DefaultBackend("tensorflow"): + a = tn.Node(np.ones((10,))) + assert a.backend.name == "tensorflow" + with tn.DefaultBackend("numpy"): + b = tn.Node(np.ones((10,))) + assert b.backend.name == "numpy" + c = tn.Node(np.ones((10,))) + assert c.backend.name == "tensorflow" + d = tn.Node(np.ones((10,))) + assert d.backend.name == "numpy" + +def test_contextmanager_wrong_item(): + a = tn.Node(np.ones((10,))) + with pytest.raises(ValueError): + with tn.DefaultBackend(a): # pytype: disable=wrong-arg-types + pass + +def test_contextmanager_BaseBackend(): + tn.set_default_backend("pytorch") + a = tn.Node(np.ones((10,))) + with tn.DefaultBackend(a.backend): + b = tn.Node(np.ones((10,))) + assert b.backend.name == "pytorch" diff --git a/tensornetwork/tests/tensornetwork_test.py b/tensornetwork/tests/tensornetwork_test.py index 259aff3ac..10cb4ce38 100644 --- a/tensornetwork/tests/tensornetwork_test.py +++ b/tensornetwork/tests/tensornetwork_test.py @@ -13,6 +13,7 @@ # limitations under the License. import tensornetwork as tn +from tensornetwork.backend_contextmanager import _default_backend_stack import pytest import numpy as np import tensorflow as tf @@ -522,7 +523,7 @@ def test_set_node2(backend): def test_set_default(backend): tn.set_default_backend(backend) - assert tn.config.default_backend == backend + assert _default_backend_stack.default_backend == backend a = tn.Node(np.eye(2)) assert a.backend.name == backend