Skip to content

Commit

Permalink
fixed full on master pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
Aelphy committed Aug 26, 2019
1 parent b807db9 commit 38fe7db
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions t3nsor/tensor_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, tt_cores, shape=None, tt_ranks=None, convert_to_tensors=True)
self._parameter = None
self._dof = np.sum([np.prod(list(tt_core.shape)) for tt_core in self._tt_cores])
self._total = np.prod(self._shape)


@property
def tt_cores(self):
Expand Down Expand Up @@ -73,15 +73,15 @@ def parameter(self):
return self._parameter
else:
raise ValueError('Not a parameter, run .to_parameter() first')

@property
def dof(self):
return self._dof

@property
def total(self):
return self._total


def to(self, device):
new_cores = []
Expand Down Expand Up @@ -109,7 +109,7 @@ def to_parameter(self):
new_cores.append(core)

tt_p = TensorTrain(new_cores, convert_to_tensors=False)
tt_p._parameter = nn.ParameterList(tt_p.tt_cores)
tt_p._parameter = nn.ParameterList(tt_p.tt_cores)
tt_p._is_parameter = True
return tt_p

Expand All @@ -122,7 +122,7 @@ def full(self):

for i in range(1, num_dims):
res = res.view(-1, ranks[i])
curr_core = self.tt_cores[i].view(ranks[i], -1)
curr_core = self.tt_cores[i].reshape(ranks[i], -1)
res = torch.matmul(res, curr_core)

if self.is_tt_matrix:
Expand All @@ -138,7 +138,7 @@ def full(self):
for i in range(1, 2 * num_dims, 2):
transpose.append(i)
res = res.permute(*transpose)

if self.is_tt_matrix:
res = res.contiguous().view(*shape)
else:
Expand Down Expand Up @@ -268,8 +268,8 @@ def full(self):
for i in range(1, 2 * num_dims, 2):
transpose.append(i + 1)
res = res.permute(transpose)
if self.is_tt_matrix:

if self.is_tt_matrix:
res = res.contiguous().view(*shape)
else:
res = res.view(*shape)
Expand Down

0 comments on commit 38fe7db

Please sign in to comment.