Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion tensornetwork/network_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,18 @@ def __init__(self,
"""

self.is_disabled = False
self.name = name if name is not None else '__unnamed_node__'
if not name:
name = '__unnamed_node__'
else:
if not isinstance(name, str):
raise TypeError("Node name should be str type")
self.name = name
self.backend = backend
self._shape = shape
if axis_names is not None:
for axis_name in axis_names:
if not isinstance(axis_name, str):
raise TypeError("axis_names should be str type")
self._edges = [
Edge(node1=self, axis1=i, name=edge_name)
for i, edge_name in enumerate(axis_names)
Expand Down Expand Up @@ -125,6 +133,9 @@ def add_axis_names(self, axis_names: List[Text]) -> None:
raise ValueError("axis_names is not the same length as the tensor shape."
"axis_names length: {}, tensor.shape length: {}".format(
len(axis_names), len(self.shape)))
for axis_name in axis_names:
if not isinstance(axis_name, str):
raise TypeError("axis_names should be str type")
self.axis_names = axis_names[:]

def add_edge(self,
Expand Down Expand Up @@ -312,6 +323,8 @@ def get_all_dangling(self) -> Set["Edge"]:
return {edge for edge in self.edges if edge.is_dangling()}

def set_name(self, name) -> None:
if not isinstance(name, str):
raise TypeError("Node name should be str type")
self.name = name

def has_nondangling_edge(self) -> bool:
Expand Down Expand Up @@ -373,6 +386,16 @@ def edges(self, edges: List) -> None:
self.name))
self._edges = edges

@property
def name(self) -> Text:
return self._name

@name.setter
def name(self, name) -> None:
if not isinstance(name, str):
raise TypeError("Node name should be str type")
self._name = name

@property
def axis_names(self) -> List[Text]:
return self._axis_names
Expand All @@ -382,8 +405,12 @@ def axis_names(self, axis_names: List[Text]) -> None:
if len(axis_names) != len(self.shape):
raise ValueError("Expected {} names, only got {}.".format(
len(self.shape), len(axis_names)))
for axis_name in axis_names:
if not isinstance(axis_name, str):
raise TypeError("axis_names should be str type")
self._axis_names = axis_names


@property
def signature(self) -> Optional[int]:
if self.is_disabled:
Expand Down Expand Up @@ -810,6 +837,9 @@ def __init__(self,
self.is_disabled = False
if not name:
name = '__unnamed_edge__'
else:
if not isinstance(name, str):
raise TypeError("Edge name should be str type")
self._name = name
self.node1 = node1
self._axis1 = axis1
Expand Down Expand Up @@ -844,6 +874,8 @@ def name(self, name) -> None:
if self.is_disabled:
raise ValueError(
'Edge has been disabled, setting its name is no longer possible')
if not isinstance(name, str):
raise TypeError("Edge name should be str type")
self._name = name

@property
Expand Down Expand Up @@ -988,6 +1020,8 @@ def is_being_used(self) -> bool:
return result

def set_name(self, name: Text) -> None:
if not isinstance(name, str):
raise TypeError("Edge name should be str type")
self.name = name

def _save_edge(self, edge_group: h5py.Group) -> None:
Expand Down