diff --git a/summac/model_summac.py b/summac/model_summac.py index 0026bbe..045915d 100644 --- a/summac/model_summac.py +++ b/summac/model_summac.py @@ -295,7 +295,7 @@ def forward(self, originals, generateds, images=None): histograms.append(histogram) N = len(histograms) - histograms = torch.FloatTensor(histograms).to(self.device) + histograms = torch.FloatTensor(np.array(histograms)).to(self.device) non_zeros = (torch.sum(histograms, dim=-1) != 0.0).long() seq_lengths = non_zeros.sum(dim=-1).tolist()