Skip to content

Commit

Permalink
test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Oct 27, 2021
1 parent 6542e40 commit 13ec1ee
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,9 @@ def _class_test(
k: v.cpu() if isinstance(v, Tensor) else v
for k, v in (batch_kwargs_update if fragment_kwargs else kwargs_update).items()
}
if isinstance(preds, torch.Tensor):
preds[i].cpu()
target[i].cpu()
sk_batch_result = sk_metric(preds[i], target[i], **batch_kwargs_update)
preds_ = preds[i].cpu() if isinstance(preds, torch.Tensor) else preds[i]
target_ = target[i].cpu() if isinstance(target, torch.Tensor) else target[i]
sk_batch_result = sk_metric(preds_, target_, **batch_kwargs_update)
if isinstance(batch_result, dict):
for key in batch_result.keys():
_assert_allclose(batch_result, sk_batch_result[key].numpy(), atol=atol, key=key)
Expand Down

0 comments on commit 13ec1ee

Please sign in to comment.