Skip to content

Commit

Permalink
add contiguous
Browse files Browse the repository at this point in the history
  • Loading branch information
kaihsin committed Apr 17, 2024
1 parent 1f6b59b commit 1ce23f5
Showing 1 changed file with 28 additions and 2 deletions.
30 changes: 28 additions & 2 deletions src/cytnx_torch/unitensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,17 @@ def rank(self) -> int:
return len(self.labels)

@property
def shape(self) -> List[int]:
return [b.dim for b in self.bonds]
def shape(self) -> Tuple[int]:
return tuple([b.dim for b in self.bonds])

@property
@abstractmethod
def is_contiguous(self) -> bool:
raise NotImplementedError("not implement for abstract type trait.")

@abstractmethod
def contiguous(self) -> "AbstractUniTensor":
raise NotImplementedError("not implement for abstract type trait.")

def _relabel(self, old_labels: List[str], new_labels: List[str]) -> None:

Expand Down Expand Up @@ -160,6 +169,13 @@ def _repr_body_diagram(self) -> str:
def is_sym(self) -> bool:
return False

@property
def is_contiguous(self) -> bool:
return self.data.is_contiguous()

def contiguous(self) -> "RegularUniTensor":
return RegularUniTensor(**self._get_generic_meta(), data=self.data.contiguous())

def permute(self, *args, by_label: bool = True) -> "RegularUniTensor":

if by_label:
Expand Down Expand Up @@ -246,6 +262,16 @@ def _repr_body_diagram(self) -> str:
def is_sym(self) -> bool:
return True

@property
def is_contiguous(self) -> bool:
return np.all([blk.is_contiguous() for blk in self.blocks])

def contiguous(self) -> "BlockUniTensor":
return BlockUniTensor(
**self._get_generic_meta(),
blocks=[blk.contiguous() for blk in self.blocks],
)

def permute(self, *args, by_label: bool = True) -> "BlockUniTensor":

if by_label:
Expand Down

0 comments on commit 1ce23f5

Please sign in to comment.