Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Mar 26, 2021
1 parent 38819cb commit bf78331
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,11 @@ def _assert_half_support(
target: torch tensor with targets
device: determine device, either "cpu" or "cuda"
"""
p = preds[0].half().to(device) if preds[0].is_floating_point() else preds[0].to(device)
t = target[0].half().to(device) if target[0].is_floating_point() else target[0].to(device)
y_hat = preds[0].half().to(device) if preds[0].is_floating_point() else preds[0].to(device)
y = target[0].half().to(device) if target[0].is_floating_point() else target[0].to(device)
metric_module = metric_module.to(device)
assert metric_module(p, t)
assert metric_functional(p, t)
assert metric_module(y_hat, y)
assert metric_functional(y_hat, y)


class MetricTester:
Expand Down

0 comments on commit bf78331

Please sign in to comment.