Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions tensornetwork/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,8 @@
from tensornetwork.version import __version__
from tensornetwork.visualization.graphviz import to_graphviz
from tensornetwork import contractors
from tensornetwork import config
from typing import Text, Optional, Type, Union
from tensornetwork.utils import load_nodes, save_nodes
from tensornetwork.matrixproductstates.finite_mps import FiniteMPS
from tensornetwork.matrixproductstates.infinite_mps import InfiniteMPS


def set_default_backend(backend: Union[Text, BaseBackend]) -> None:
config.default_backend = backend
from tensornetwork.backend_contextmanager import DefaultBackend, set_default_backend
41 changes: 41 additions & 0 deletions tensornetwork/backend_contextmanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Text, Union
from tensornetwork.backends.base_backend import BaseBackend

class DefaultBackend():
"""Context manager for setting up backend for nodes"""

def __init__(self, backend: Union[Text, BaseBackend]) -> None:
if not isinstance(backend, (Text, BaseBackend)):
raise ValueError("Item passed to DefaultBackend "
"must be Text or BaseBackend")
self.backend = backend

def __enter__(self):
_default_backend_stack.stack.append(self)

def __exit__(self, exc_type, exc_val, exc_tb):
_default_backend_stack.stack.pop()

class _DefaultBackendStack():
"""A stack to keep track default backends context manager"""

def __init__(self):
self.stack = []
self.default_backend = "numpy"

def get_current_backend(self):
return self.stack[-1].backend if self.stack else self.default_backend

_default_backend_stack = _DefaultBackendStack()

def get_default_backend():
return _default_backend_stack.get_current_backend()

def set_default_backend(backend: Union[Text, BaseBackend]) -> None:
if _default_backend_stack.stack:
raise AssertionError("The default backend should not be changed "
"inside the backend context manager")
if not isinstance(backend, (Text, BaseBackend)):
raise ValueError("Item passed to set_default_backend "
"must be Text or BaseBackend")
_default_backend_stack.default_backend = backend
1 change: 0 additions & 1 deletion tensornetwork/backends/backend_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from tensornetwork.backends.shell import shell_backend
from tensornetwork.backends.pytorch import pytorch_backend
from tensornetwork.backends import base_backend
import tensornetwork.config as config_file

_BACKENDS = {
"tensorflow": tensorflow_backend.TensorFlowBackend,
Expand Down
2 changes: 0 additions & 2 deletions tensornetwork/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

default_backend = "numpy"
8 changes: 4 additions & 4 deletions tensornetwork/ncon_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import warnings
from typing import Any, Sequence, List, Optional, Union, Text, Tuple, Dict
from tensornetwork import network_components
from tensornetwork import config
from tensornetwork.backend_contextmanager import get_default_backend
from tensornetwork.backends import backend_factory
Tensor = Any

Expand Down Expand Up @@ -67,8 +67,8 @@ def ncon(tensors: Sequence[Union[network_components.BaseNode, Tensor]],
structure.
con_order: List of edge labels specifying the contraction order.
out_order: List of edge labels specifying the output order.
backend: String specifying the backend to use. Defaults to
`tensornetwork.config.default_backend`.
backend: String specifying the backend to use. Defaults to
`tensornetwork.backend_contextmanager.get_default_backend`.

Returns:
The result of the contraction. The result is returned as a `Node`
Expand All @@ -78,7 +78,7 @@ def ncon(tensors: Sequence[Union[network_components.BaseNode, Tensor]],
if backend and (backend not in backend_factory._BACKENDS):
raise ValueError("Backend '{}' does not exist".format(backend))
if backend is None:
backend = config.default_backend
backend = get_default_backend()

are_nodes = [isinstance(t, network_components.BaseNode) for t in tensors]
nodes = {t for t in tensors if isinstance(t, network_components.BaseNode)}
Expand Down
28 changes: 14 additions & 14 deletions tensornetwork/network_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
import h5py

#pylint: disable=useless-import-alias
import tensornetwork.config as config
from tensornetwork import ops
from tensornetwork.backends import backend_factory
from tensornetwork.backends.base_backend import BaseBackend
from tensornetwork.backend_contextmanager import get_default_backend

string_type = h5py.special_dtype(vlen=str)
Tensor = Any
Expand Down Expand Up @@ -525,8 +525,8 @@ def __init__(self,
"""Create a node.

Args:
tensor: The concrete that is represented by this node, or a `BaseNode`
object. If a tensor is passed, it can be
tensor: The concrete that is represented by this node, or a `BaseNode`
object. If a tensor is passed, it can be
be either a numpy array or the tensor-type of the used backend.
If a `BaseNode` is passed, the passed node has to have the same \
backend as given by `backend`.
Expand All @@ -543,7 +543,7 @@ def __init__(self,
backend = tensor.backend
tensor = tensor.tensor
if not backend:
backend = config.default_backend
backend = get_default_backend()
if isinstance(backend, BaseBackend):
backend_obj = backend
else:
Expand Down Expand Up @@ -633,13 +633,13 @@ def __init__(self,
backend: An optional backend for the node. If `None`, a default
backend is used
dtype: The dtype used to initialize a numpy-copy node.
Note that this dtype has to be a numpy dtype, and it has to be
Note that this dtype has to be a numpy dtype, and it has to be
compatible with the dtype of the backend, e.g. for a tensorflow
backend with a tf.Dtype=tf.floa32, `dtype` has to be `np.float32`.
"""

if not backend:
backend = config.default_backend
backend = get_default_backend()
backend_obj = backend_factory.get_backend(backend)

self.rank = rank
Expand Down Expand Up @@ -1092,14 +1092,14 @@ def disconnect(self,
edge2_name: Optional[Text] = None) -> Tuple["Edge", "Edge"]:
"""
Break an existing non-dangling edge.
This updates both Edge.node1 and Edge.node2 by removing the
This updates both Edge.node1 and Edge.node2 by removing the
connecting edge from `Edge.node1.edges` and `Edge.node2.edges`
and adding new dangling edges instead
Args:
edge1_name: A name for the new dangling edge at `self.node1`
edge2_name: A name for the new dangling edge at `self.node2`
Returns:
(new_edge1, new_edge2): The new `Edge` objects of
(new_edge1, new_edge2): The new `Edge` objects of
`self.node1` and `self.node2`
"""
if self.is_dangling():
Expand Down Expand Up @@ -1155,7 +1155,7 @@ def get_parallel_edges(edge: Edge) -> Set[Edge]:
edge: The given edge.

Returns:
A `set` of all of the edges parallel to the given edge
A `set` of all of the edges parallel to the given edge
(including the given edge).
"""
return get_shared_edges(edge.node1, edge.node2)
Expand Down Expand Up @@ -1389,8 +1389,8 @@ def split_edge(edge: Edge,
shape: Tuple[int, ...],
new_edge_names: Optional[List[Text]] = None) -> List[Edge]:
"""Split an `Edge` into multiple edges according to `shape`. Reshapes
the underlying tensors connected to the edge accordingly.
the underlying tensors connected to the edge accordingly.

This method acts as the inverse operation of flattening edges and
distinguishes between the following edge cases when adding new edges:
1) standard edge connecting two different nodes: reshape node dimensions
Expand Down Expand Up @@ -1772,7 +1772,7 @@ def disconnect(edge,
edge2_name: Optional[Text] = None) -> Tuple[Edge, Edge]:
"""
Break an existing non-dangling edge.
This updates both Edge.node1 and Edge.node2 by removing the
This updates both Edge.node1 and Edge.node2 by removing the
connecting edge from `Edge.node1.edges` and `Edge.node2.edges`
and adding new dangling edges instead
"""
Expand Down Expand Up @@ -1894,9 +1894,9 @@ def outer_product_final_nodes(nodes: Iterable[BaseNode],
edge_order: List[Edge]) -> BaseNode:
"""Get the outer product of `nodes`

For example, if there are 3 nodes remaining in `nodes` with
For example, if there are 3 nodes remaining in `nodes` with
shapes :math:`(2, 3)`, :math:`(4, 5, 6)`, and :math:`(7)`
respectively, the newly returned node will have shape
respectively, the newly returned node will have shape
:math:`(2, 3, 4, 5, 6, 7)`.

Args:
Expand Down
1 change: 0 additions & 1 deletion tensornetwork/network_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import numpy as np

#pylint: disable=useless-import-alias
import tensornetwork.config as config
#pylint: disable=line-too-long
from tensornetwork.network_components import BaseNode, Node, CopyNode, Edge, disconnect
from tensornetwork.backends import backend_factory
Expand Down
46 changes: 46 additions & 0 deletions tensornetwork/tests/backend_contextmanager_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import tensornetwork as tn
from tensornetwork.backend_contextmanager import _default_backend_stack
import pytest
import numpy as np

def test_contextmanager_simple():
with tn.DefaultBackend("tensorflow"):
a = tn.Node(np.ones((10,)))
b = tn.Node(np.ones((10,)))
assert a.backend.name == b.backend.name

def test_contextmanager_default_backend():
tn.set_default_backend("pytorch")
with tn.DefaultBackend("numpy"):
assert _default_backend_stack.default_backend == "pytorch"

def test_contextmanager_interruption():
tn.set_default_backend("pytorch")
with pytest.raises(AssertionError):
with tn.DefaultBackend("numpy"):
tn.set_default_backend("tensorflow")

def test_contextmanager_nested():
with tn.DefaultBackend("tensorflow"):
a = tn.Node(np.ones((10,)))
assert a.backend.name == "tensorflow"
with tn.DefaultBackend("numpy"):
b = tn.Node(np.ones((10,)))
assert b.backend.name == "numpy"
c = tn.Node(np.ones((10,)))
assert c.backend.name == "tensorflow"
d = tn.Node(np.ones((10,)))
assert d.backend.name == "numpy"

def test_contextmanager_wrong_item():
a = tn.Node(np.ones((10,)))
with pytest.raises(ValueError):
with tn.DefaultBackend(a): # pytype: disable=wrong-arg-types
pass

def test_contextmanager_BaseBackend():
tn.set_default_backend("pytorch")
a = tn.Node(np.ones((10,)))
with tn.DefaultBackend(a.backend):
b = tn.Node(np.ones((10,)))
assert b.backend.name == "pytorch"
3 changes: 2 additions & 1 deletion tensornetwork/tests/tensornetwork_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import tensornetwork as tn
from tensornetwork.backend_contextmanager import _default_backend_stack
import pytest
import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -522,7 +523,7 @@ def test_set_node2(backend):

def test_set_default(backend):
tn.set_default_backend(backend)
assert tn.config.default_backend == backend
assert _default_backend_stack.default_backend == backend
a = tn.Node(np.eye(2))
assert a.backend.name == backend

Expand Down